飞道的博客

带你从零掌握迭代器及构建最简DataLoader

620人阅读  评论(0)

点蓝色字关注“机器学习算法工程师

设为星标,干货直达!

AI编辑:深度眸

0 摘要

    本文本意是写 pytorch 中 DataLoader 源码学习心得,但是发现自己对迭代器和生成器的掌握比较水,不够牢固,而我也没有搜到能够解决我所有疑问的解答文章,因此诞生了这篇文章。通过本文你将能够零基础深入掌握 python 迭代器相关知识、并且能够一步步理解 DataLoader 的实现原理以及背后涉及的设计模式

    本文最终目的是通过源码学习自己实现一个功能比较完善的 DataLoader 类,为了达到这个目的,本文写作流程是:

  • 先深入浅出分析 python 中迭代器、生成器等实现原理,包括 Iterable、Iterator、for .. in ..、__getitem__、yield 生成器 5个部分

  • 再实现了一个最简单版本的 DataLoader,目的是理解 DataLoader 与 Dataset、Sampler、BatchSampler和 collate_fn 之间的调用关系

  • 最后对该实现进行深入全面分析,读者可以清晰的理解每个类的作用

    但是 DataLoader 功能其实非常复杂,故本文属于系列文章的第一篇,后面文章会不断完善、调整,最终实现 DataLoader 所有功能。或者说本文是后续文章的基础,如果基础内容没有理解非常透彻,后面的多进程、分布式版本就更难以理解了。

    虽然本文比较简单,但是由于涉及到代码,故为了方便,有必要的读者可以 clone rep 进行学习(需要特意说明的是:rep 里面代码是学习目的的,质量不高,不要要求那么多)

github:  https://github.com/hhaAndroid/miniloader

由于本人水平有限,某些环节理解可能有偏颇,欢迎指正。手机对于代码显示效果不太好,建议电脑端阅读。

1 python 迭代器深入浅出理解

1.1 可迭代对象 Iterable

    可迭代对象 Iterable:表示该对象可迭代,其并不是指某种具体数据类型。简单来说只要是实现了 `__iter__` 方法的类就是可迭代对象


   
  1. from collections.abc import Iterable, Iterator
  2. class A(object):
  3. def __init__(self):
  4. self.a = [ 1, 2, 3]
  5. def __iter__(self):
  6. # 此处返回啥无所谓
  7. return self.a
  8. cls_a = A()
  9. # True
  10. print(isinstance(cls_a, Iterable))

    但是对象如果是 Iterable 的,看起来好像也没有特别大的用途,因为你依然无法迭代,实际上 Iterable 仅仅是提供了一种抽象规范接口:


   
  1. for a in cls_a:
  2. print(a)
  3. # 程序报错,要理解这个错误的含义
  4. TypeError: iter() returned non-iterator of type 'list'

    我们可以检查下 Iterable 接口:


   
  1. class Iterable(metaclass=ABCMeta):
  2. # 如果实现了这个方法,那么就是 Iterable
  3. @abstractmethod
  4. def __iter__(self):
  5. while False:
  6. yield None
  7. @classmethod
  8. def __subclasshook__(cls, C):
  9. if cls is Iterable:
  10. return _check_methods(C, "__iter__")
  11. return NotImplemented

看起来实现 Iterable 接口用途不大,其实不是的,其有很多用途的,例如简化代码等,在后面的高级语法糖中会频繁用到,后面会分析。

1.2 迭代器 Iterator

    迭代器 Iterator:其和 Iterable 之间是一个包含与被包含的关系,如果一个对象是迭代器 Iterator,那么这个对象肯定是可迭代 Iterable;但是反过来,如果一个对象是可迭代 Iterable,那么这个对象不一定是迭代器 Iterator,可以通过接口协议看出:


   
  1. class Iterator(Iterable):
  2. # 迭代具体实现
  3. @abstractmethod
  4. def __next__(self):
  5. 'Return the next item from the iterator. When exhausted, raise StopIteration'
  6. raise StopIteration
  7.      # 返回自身,因为自身有 __next__ 方法(如果自身没有 __next__,那么返回自身没有意义)
  8. def __iter__(self):
  9.          return self
  10.         
  11. @classmethod
  12. def __subclasshook__(cls, C):
  13. if cls is Iterator:
  14. return _check_methods(C, '__iter__', '__next__')
  15. return NotImplemented

可以发现:实现了 `__next__` 和 `__iter__` 方法的类才能称为迭代器,就可以被 for 遍历了


   
  1. class A(object):
  2. def __init__(self):
  3. self.index = - 1
  4. self.a = [ 1, 2, 3]
  5.      #必须要返回一个实现了 __next__ 方法的对象,否则后面无法 for 遍历
  6.      #因为本类自身实现了 __next__,所以通常都是返回 self 对象即可
  7. def __iter__(self):
  8. return self
  9. def __next__(self):
  10. self.index += 1
  11. if self.index < len( self.a):
  12. return self.a[ self.index]
  13. else:
  14.              #抛异常,for 内部会自动捕获,表示迭代完成
  15. raise StopIteration( "遍历完了")
  16. cls_a = A()
  17. print(isinstance(cls_a, Iterable)) # True
  18. print(isinstance(cls_a, Iterator)) # True
  19. print(isinstance(iter(cls_a), Iterator)) # True
  20. for a in cls_a:
  21. print(a)
  22. # 打印 1 2 3

再次明确,一个对象如果要是 Iterator ,那么必须要实现 `__next__` 和 `__iter__` 方法,但是要理解其内部迭代流程,还需要理解 for .. in .. 流程。

1.3 for .. in .. 本质流程

    for .. in .. 也就是常见的迭代操作了,其被 python 编译器编译后,实际上代码是:


   
  1. # 实际调用了 __iter__ 方法返回自身,包括了 __next__ 方法的对象
  2. cls_a = iter(cls_a)
  3. while True:
  4. try:
  5. # 然后调用对象的 __next__ 方法,不断返回元素
  6. value = next(cls_a)
  7. print(value)
  8. # 如果迭代完成,则捕获异常即可
  9. except StopIteration:
  10. break

可以看出,任何一个对象如果要能够被 for 遍历,必须要实现  `__iter__` 和 `__next__` 方法,缺一不可

    明白了上述流程,那么迭代器对象 A,我们可以采用如下方式进行遍历:


   
  1. myiter = iter(cls_a)
  2. print( next(myiter))
  3. print( next(myiter))
  4. print( next(myiter))
  5. # 因为遍历完了,故此时会出现错误:StopIteration: 遍历完了
  6. print( next(myiter))

我们再来思考 python 内置对象 list 为啥可以被迭代


   
  1. b= list([ 1, 2, 3])
  2. print(isinstance(b, Iterable)) # True
  3. print(isinstance(b, Iterator)) # False

    可以发现 list 类型是可迭代对象,但是其不是迭代器(即 list 没有 `__next__` 方法),那为啥 for .. in .. 可以迭代呢?

    原因是 list 内部的 `__iter__` 方法内部返回了具备 `__next__` 方法的类,或者说调用 iter() 后返回的对象本身就是一个迭代器,当然可以 for 循环了


   
  1. b= list([ 1, 2, 3])
  2. print(dir(b)) # 可以发现其存在 __iter__ 方法,不存在 __next__
  3. b=iter(b) # 调用 list 内部的 __iter__,返回了具备 __next__ 的对象
  4. print(isinstance(b, Iterable)) # True
  5. print(isinstance(b, Iterator)) # True
  6. print(dir(b)) # 同时具备 __iter__ 和 __next__ 方法

基于上述理解我们可以对 A 类代码进行改造,使其更加简单:


   
  1. class A(object):
  2. def __init__(self):
  3. self.a = [ 1, 2, 3]
  4. # 我们内部又调用了 list 对象的 __iter__ 方法,故此时返回的对象是迭代器对象
  5. def __iter__(self):
  6. return iter(self.a)
  7. cls_a = A()
  8. print(isinstance(cls_a, Iterable)) # True
  9. print(isinstance(cls_a, Iterator)) # False
  10. for a in cls_a:
  11. print(a)
  12. # 输出:1 2 3

    此时我们就实现了仅仅实现 Iterable 规范接口,但是又具备了 for .. in .. 功能,代码是不是比最开始的实现简单很多?这种写法应用也非常广泛,因为其不需要自己再次实现 `__next__` 方法。

    如果你想理解的更加透彻,那么可以看下面例子:


   
  1. # 仅仅实现 __iter__
  2. class A(object):
  3. def __init__(self):
  4. self.b = B()
  5. def __iter__(self):
  6. return self.b
  7. # 仅仅实现 __next__
  8. class B(object):
  9. def __init__(self):
  10. self.index = - 1
  11. self.a = [ 1, 2, 3]
  12. def __next__(self):
  13. self.index += 1
  14. if self.index < len( self.a):
  15. return self.a[ self.index]
  16. else:
  17. # 内部会自动捕获,表示迭代完成
  18. raise StopIteration( "遍历完了")
  19. cls_a = A()
  20. cls_b = B()
  21. print(isinstance(cls_a, Iterable)) # True
  22. print(isinstance(cls_a, Iterator)) # False
  23. print(isinstance(cls_b, Iterable)) # False
  24. print(isinstance(cls_b, Iterator)) # False
  25. print(type(iter(cls_a))) # B 对象
  26. print(isinstance(iter(cls_a), Iterator)) # False
  27. for a in cls_a:
  28. print(a)
  29. # 输出:1 2 3

    自此我们知道了:一个对象要能够被 for .. in .. 迭代,那么不管你是直接实现 `__iter__` 和 `__next__` 方法(对象必然是 Iterator),还是只实现 `__iter__`(不是 Iterator),但是内部间接返回了具备 `__next__` 对象的类,都是可行的

    但是除了这两种实现,还有其他高级语法糖,可以进一步精简代码。

1.4  __ getitem__ 理解

    上面说过 for .. in .. 的本质就是调用对象的 `__iter__` 和 `__next__` 方法,但是有一种更加简单的写法,你通过仅仅实现 `__getitem__` 方法就可以让对象实现迭代功能。实际上任何一个类,如果实现了`__getitem__` 方法,那么当调用 iter(类实例) 时候会自动具备`__iter__` 和 `__next__`方法,从而可迭代了。

    通过下面例子可以看出,`__getitem__` 实际上是属于 __iter__` 和 `__next__` 方法的高级封装,也就是我们常说的语法糖,只不过这个转化是通过编译器完成,内部自动转化,非常方便。


   
  1. class A(object):
  2. def __init__(self):
  3. self.a = [ 1, 2, 3]
  4. def __getitem__(self, item):
  5.          return self.a[item]
  6.         
  7. cls_a = A()
  8. print(isinstance(cls_a, Iterable)) # False
  9. print(isinstance(cls_a, Iterator)) # False
  10. print(dir(cls_a)) # 仅仅具备 __getitem__ 方法
  11. cls_a = iter(cls_a)
  12. print(dir(cls_a)) # 具备 __iter__ 和 __next__ 方法
  13. print(isinstance(cls_a, Iterable)) # True
  14. print(isinstance(cls_a, Iterator)) # True
  15. # 等价于 for .. in ..
  16. while True:
  17. try:
  18.          # 然后调用对象的 __next__ 方法,不断返回元素
  19. value = next(cls_a)
  20. print(value)
  21. # 如果迭代完成,则捕获异常即可
  22. except StopIteration:
  23. break
  24. # 输出:1 2 3

而且 `__getitem__` 还可以通过索引直接访问元素,非常方便


   
  1. a[0]  # 1  
  2. a[4] # 错误,索引越界

如果你想该对象具备 list 等对象一样的长度属性,则只需要实现 `__len__` 方法即可


   
  1. class A(object):
  2. def __init__(self):
  3. self.a = [ 1, 2, 3]
  4. def __getitem__(self, item):
  5. return self.a[item]
  6. def __len__(self):
  7. return len( self.a)
  8. cls_a = A()
  9. print(len(cls_a)) # 3

    到目前为止,我们已经知道了第一种高级语法糖实现迭代器功能,下面分析另一个更简单的可以直接作用于函数的语法糖。

1.5 yield 生成器

    生成器是一个在行为上和迭代器非常类似的对象,二者功能上差不多,但是生成器更优雅,只需要用关键字 yield 来返回,作用于函数上叫生成器函数,函数被调用时会返回一个生成器对象,生成器本质就是迭代器,其最大特点是代码简洁。


   
  1. def func():
  2. for a in [ 1, 2, 3]:
  3. yield a
  4. cls_g = func()
  5. print(isinstance(cls_g, Iterator)) # True
  6. print(dir(cls_g)) # 自动具备 __iter__ 和 __next__ 方法
  7. for a in cls_g:
  8. print(a)
  9. # 输出: 1 2 3
  10. # 一种更简单的写法是用 ()
  11. cls_g = (i for i in [ 1, 2, 3])

    直观感觉和 `__getitem__` 一样,也是高级语法糖,但是比 `__getitem__` 更加简单,更加好用。

    使用 yield 函数与使用 return 函数,在执行时差别在于:包含 yield 的方法一般用于迭代,每次执行时遇到 yield 就返回 yield 后的结果,但内部会保留上次执行的状态,下次继续迭代时,会继续执行 yield 之后的代码,直到再次遇到 yield 后返回。生成器是懒加载模式,特别适合解决内存占用大的集合问题。假设创建一个包含10万个元素的列表,如果用 list 返回不仅占用很大的存储空间,如果我们仅仅需要访问前面几个元素,那后面绝大多数元素占用的空间都白白浪费了,这种场景就适合采用生成器,在迭代过程中推算出后续元素,而不需要一次性全部算出。

1.6 小结

  •  list set dict等内置对象都是容器 container 对象,容器是一种把多个元素组织在一起的数据结构,可以逐个迭代获取其中的元素。容器可以用 in 来判断容器中是否包含某个元素。大多数容器都是可迭代对象,可以使用某种方式访问容器中的每一个元素。

  • 在迭代对象基础上,如果实现了 `__next__`  方法则是迭代器对象,该对象在调用 next()  的时候返回下一个值,如果容器中没有更多元素了,则抛出 StopIteration 异常。

  • 对于采用语法糖 `__getitem__` 实现的迭代器对象,其本身实例既不是可迭代对象,更不是迭代器,但是其可以被 for in 迭代,原因是对该对象采用 iter(类实例) 操作后就会自动变成迭代器。

  • 生成器是一种特殊迭代器,但是不需要像迭代器一样实现`__iter__`和`__next__`方法,只需要使用关键字 yield 就可以,生成器的构造可以通过生成器表达式 (),或者对函数返回值加入 yield 关键字实现。

  • 对于在类的 `__iter__` 方法中采用语法糖 yield 实现的迭代器对象,其本身实例是可迭代对象,但不是迭代器,但是其可以被 for .. in .. 迭代,原因是对该对象采用 iter(类实例) 操作后就会自动变成迭代器。

2 DataLoader 最简版本 V1

    这里说的最简版本是指:没有任何花哨、高级实现技巧,仅仅以实现最基础功能为目的。具体来说是包括必备的5个对象:Dataset、Sampler、BatchSampler、DataLoader 和 collate_fn。其作用可以简要描述为如下:

  • Dataset 提供整个数据集的随机访问功能,每次调用都返回单个对象,例如一张图片和对应 target 等等

  • Sampler 提供整个数据集随机访问的索引列表,每次调用都返回所有列表中的单个索引,常用子类是 SequentialSampler 用于提供顺序输出的索引 和 RandomSampler 用于提供随机输出的索引

  • BatchSampler 内部调用 Sampler 实例,输出指定 `batch_size` 个索引,然后将索引作用于 Dataset 上从而输出 `batch_size` 个数据对象,例如 batch 张图片和 batch 个 target

  • collate_fn 用于将 batch 个数据对象在 batch 维度进行聚合,生成 (b,...) 格式的数据输出,如果待聚合对象是 numpy,则会自动转化为 tensor,此时就可以输入到网络中了

迭代一次伪代码如下(非迭代器版本):


   
  1. class DataLoader(object):
  2. def __init__(self):
  3. # 假设数据长度是100,batch_size 是4
  4. self.dataset = [[img 0, target 0], [img1, target1], ..., [img99, target99]]
  5. # 假设 sampler 是 SequentialSampler,那么实际上就是 [0,1,...,99] 列表而已
  6. # 如果 sampler 是 RandomSampler,那么可能是 [30,1,34,2,6,...,0] 列表
  7. self.sampler = [ 0, 1, 2, 3, 4, ..., 99]
  8. self.batch_size = 4
  9. self.index = 0
  10. def collate_fn(self, data):
  11. # batch 维度聚合数据
  12. batch_img = torch.stack(data[ 0], 0)
  13. batch_target = torch.stack(data[ 1], 0)
  14. return batch_img, batch_target
  15. def __next__(self):
  16. # 0.batch_index 输出,实际上就是 BatchSampler 做的事情
  17. i = 0
  18. batch_index = []
  19. while i < self. batch_size:
  20. # 内部会调用 sampler 对象取单个索引
  21. batch_index.append( self.sampler[ self.index])
  22. self.index += 1
  23. i += 1
  24. # 1.得到 batch 个数据了,调用 dataset 对象
  25. data = [ self.dataset[idx] for idx in batch_index]
  26. # 2. 调用 collate_fn 在 batch 维度拼接输出
  27. batch_data = self.collate_fn(data)
  28. return batch_data
  29. def __iter__(self):
  30. return self

    以上就是最抽象的 DataLoader 运行流程以及和 Dataset、Sampler、BatchSampler、collate_fn 的关系。

2.1 整体对象理解

    首先需要强调的是 Dataset、Sampler、BatchSampler 和 DataLoader 都直接或间接实现了迭代器,你必须要先理解第一小节内容,否则本节内容会比较难理解,具体为:

  •  Dataset 通过实现 `__getitem__` 方法使其可迭代

  •  Sampler 对象是一个可迭代的基类对象,其常用子类 SequentialSampler 在 `__iter__` 内部返回迭代器,RandomSampler 在 `__iter__` 内部通过 yield 关键字返回迭代器

  •  BatchSampler 也是在 `__iter__` 内部通过 yield 关键字返回迭代器

  •  DataLoader 通过直接实现 `__next__` 和 `__iter__` 变成迭代器

    注意除了 DataLoader 本身是迭代器外,其余对象本身不是迭代器,但是都能被 for .. in .. 迭代。下面一个简单例子证明:


   
  1. from simplev1_datatset  import SimpleV1Dataset
  2. from libv1 import SequentialSampler, RandomSampler
  3. from collections  import Iterator, Iterable   
  4. simple_dataset = SimpleV1Dataset() 
  5. dataloader = DataLoader(simple_dataset, batch_size= 2, collate_fn=default_collate)
  6. print(isinstance(simple_dataset, Iterable)) # False
  7. print(isinstance(simple_dataset, Iterator)) # False
  8. print(isinstance(iter(simple_dataset), Iterator)) # True
  9. print(isinstance(SequentialSampler(simple_dataset), Iterable)) # True
  10. print(isinstance(SequentialSampler(simple_dataset), Iterator)) # False
  11. print(isinstance(iter(SequentialSampler(simple_dataset)), Iterator)) # True
  12. # BatchSampler 和 RandomSampler 内部实现结构一样,结果也是一样
  13. print(isinstance(RandomSampler(simple_dataset), Iterable)) # True
  14. print(isinstance(RandomSampler(simple_dataset), Iterator)) # False
  15. print(isinstance(iter(RandomSampler(simple_dataset)), Iterator))  # True
  16. print(isinstance(dataloader, Iterator)) # True

    在 DataLoader 中主要涉及3个类,其内部实例传递关系如下:

    由于 DataLoader 类写的非常通用,故 Dataset、Sampler、BatchSampler 都可以外部传入,除了 Dataset 必须输入外,其余两个类都有默认实现,最典型的 Sampler 就是 SequentialSampler 和 RandomSampler。

    需要注意的是 Sampler 对象其实在大部分时候都不需要传入 Dataset 实例对象,因为其功能仅仅是返回索引而已,并没有直接接触数据。

2.2 DataLoader 运行流程

    最简单版本 DataLoader,具备如下功能:

  • Dataset 内部返回需要是 numpy 或者 tensor 对象

  • Sampler 直接 SequentialSampler 和 RandomSampler

  • BatchSampler 已经实现

  • collate_fn 仅仅考虑了 numpy 或者 tensor 对象

  • 仅仅支持 num_works=0 即单进程

看起来功能非常单一,但是其实已经搭建起了整个框架,理解了这个最简框架才能去理解高级实现,其核心运行逻辑为:


   
  1. def __next__(self):
  2.      # 返回 batch 个索引
  3.     index =  next( self.batch_sampler)
  4.      # 利用索引去取数据
  5.     data = [ self.dataset[idx]  for idx  in index
  6.      # batch 维度聚合
  7.     data =  self.collate_fn(data)
  8.      return data

然后为了方便大家理解,特意绘制了如下代码运行流程图:

    还是那句话:一定要对第1小节内容非常熟悉,否则里面这么多迭代器、生成器的调用,可能会把你绕晕。详细代码描述如下:

  1. `self.batch_sampler = iter(batch_sampler)`。在 DataLoader 的类初始化,需要得到 BatchSampler 的迭代器对象

  2. `index = next(self.batch_sampler)`。对于每次迭代,DataLoader 对象首先会调用 BatchSampler 的迭代器进行下一次迭代,具体是调用 BatchSampler 对象的  `__iter__`  方法

  3. 而 BatchSampler 对象的 `__iter__` 方法实际上是需要依靠 Sampler 对象进行迭代输出索引,Sampler 对象也是一个迭代器,当迭代 `batch_size` 次后就可以得到 `batch_size` 个数据索引

  4. `data = [self.dataset[idx] for idx in index]`。有了 batch 个索引就可以通过不断调用  dataset 的 `__getitem__` 方法返回数据对象,此时 data 就包含了 batch 个对象

  5. `data = self.collate_fn(data)`。将 batch 个对象输入给聚合函数,在第0个维度也就是 batch 维度进行聚合,得到类似 (b,...) 的对象

  6. 不断重复1-5步,就可以不断的输出一个一个 batch 的数据了

以上就是完整流程,如果理解有困难,你可以先看下一小结的代码实现,然后再返回去理解

2.3 最简V1版本源代码

(1) Dataset


   
  1. class Dataset(object):
  2.      # 只要实现了 __getitem__ 方法就可以变成迭代器
  3. def __getitem__(self, index):
  4. raise NotImplementedError
  5. # 用于获取数据集长度
  6. def __len__(self):
  7.          raise NotImplementedError

(2) Sampler


   
  1. class Sampler(object):
  2.      def __init__(self, data_source):
  3. pass
  4. def __iter__(self):
  5. raise NotImplementedError
  6. def __len__(self):
  7. raise NotImplementedError

   
  1. class SequentialSampler(Sampler):
  2. def __init__(self, data_source):
  3. super(SequentialSampler, self).__init_ _(data_source)
  4. self.data_source = data_source
  5. def __iter__(self):
  6. # 返回迭代器,不然无法 for .. in ..
  7. return iter(range(len( self.data_source)))
  8. def __len__(self):
  9. return len( self.data_source)

   
  1. class BatchSampler(Sampler):
  2. def __init__( self, sampler, batch_size, drop_last):
  3. self.sampler = sampler
  4. self.batch_size = batch_size
  5. self.drop_last = drop_last
  6. def __iter__( self):
  7. batch = []
  8. # 调用 sampler 内部的迭代器对象
  9. for idx in self.sampler:
  10. batch.append(idx)
  11. # 如果已经得到了 batch 个 索引,则可以通过 yield
  12. # 关键字生成生成器返回,得到迭代器对象
  13. if len(batch) == self.batch_size:
  14. yield batch
  15. batch = []
  16. if len(batch) > 0 and not self.drop_last:
  17. yield batch
  18. def __len__( self):
  19. if self.drop_last:
  20. # 如果最后的索引数不够一个 batch,则抛弃
  21. return len( self.sampler) // self.batch_size
  22.          else:  
  23.              return  (len( self.sampler) +  self.batch_size -  1// self.batch_size

(3) DataLoader


   
  1. class DataLoader(object):
  2. def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
  3.                  batch_sampler=None, collate_fn=None, drop_last=False)
  4. self.dataset = dataset
  5. # 因为这两个功能是冲突的,假设 shuffle=True,
  6. # 但是 sampler 里面是 SequentialSampler,那么就违背设计思想了
  7. if sampler is not None and shuffle:
  8. raise ValueError( 'sampler option is mutually exclusive with '
  9. 'shuffle')
  10. if batch_sampler is not None:
  11. # 一旦设置了 batch_sampler,那么 batch_size、shuffle、sampler
  12. # 和 drop_last 四个参数就不能传入
  13. # 因为这4个参数功能和 batch_sampler 功能冲突了
  14. if batch_size != 1 or shuffle or sampler is not None or drop_last:
  15. raise ValueError( 'batch_sampler option is mutually exclusive '
  16. 'with batch_size, shuffle, sampler, and '
  17. 'drop_last')
  18. batch_size = None
  19. drop_last = False
  20. if sampler is None:
  21. if shuffle:
  22. sampler = RandomSampler(dataset)
  23. else:
  24. sampler = SequentialSampler(dataset)
  25. # 也就是说 batch_sampler 必须要存在,你如果没有设置,那么采用默认类
  26. if batch_sampler is None:
  27. batch_sampler = BatchSampler(sampler, batch_size, drop_last)
  28. self.batch_size = batch_size
  29. self.drop_last = drop_last
  30. self.sampler = sampler
  31. self.batch_sampler = iter(batch_sampler)
  32. if collate_fn is None:
  33. collate_fn = default_collate
  34. self.collate_fn = collate_fn
  35. # 核心代码
  36. def __next__(self):
  37. index = next(self.batch_sampler)
  38. data = [self.dataset[idx] for idx in index]
  39. data = self.collate_fn(data)
  40. return data
  41. # 返回自身,因为自身实现了 __next__
  42. def __iter__(self):
  43. return self

(4) collate_fn


   
  1. def default_collate(batch):
  2. elem = batch[ 0]
  3. elem_type = type(elem)
  4. if isinstance(elem, torch.Tensor):
  5. return torch.stack(batch, 0)
  6. elif elem_type.__module__ == 'numpy':
  7. return default_collate([torch.as_tensor(b) for b in batch])
  8. else:
  9. raise NotImplementedError

(5) 调用完整例子


   
  1. class SimpleV1Dataset(Dataset):
  2. def __init__(self):
  3. # 伪造数据
  4. self.imgs = np.arange( 0, 16).reshape( 8, 2)
  5. def __getitem__(self, index):
  6. return self.imgs[index]
  7. def __len__(self):
  8. return self.imgs.shape[ 0]
  9. from simplev1_datatset import SimpleV1Dataset
  10. simple_dataset = SimpleV1Dataset()
  11. dataloader = DataLoader(simple_dataset, batch_size= 2, collate_fn=default_collate)
  12. for data in dataloader:
  13. print(data)

3 总结

    本文是最小 DataLoader 系列文章的第一篇,重点是分析了 python 中迭代器相关知识,然后构建一个最简单的 DataLoader 类,用于加深到 DataLoader 流程的理解,功能比较简单。

    后面慢慢完善,希望最终能实现完整功能。

github: https://github.com/hhaAndroid/miniloader

推荐阅读

PyTorch 源码解读之 torch.autograd

PyTorch 源码解读之 BN & SyncBN

机器学习算法工程师


                                            一个用心的公众号


 


转载:https://blog.csdn.net/l7H9JA4/article/details/111939787
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场