点蓝色字关注“机器学习算法工程师”
设为星标,干货直达!
AI编辑:深度眸
0 摘要
本文本意是写 pytorch 中 DataLoader 源码学习心得,但是发现自己对迭代器和生成器的掌握比较水,不够牢固,而我也没有搜到能够解决我所有疑问的解答文章,因此诞生了这篇文章。通过本文你将能够零基础深入掌握 python 迭代器相关知识、并且能够一步步理解 DataLoader 的实现原理以及背后涉及的设计模式。
本文最终目的是通过源码学习自己实现一个功能比较完善的 DataLoader 类,为了达到这个目的,本文写作流程是:
先深入浅出分析 python 中迭代器、生成器等实现原理,包括 Iterable、Iterator、for .. in ..、__getitem__、yield 生成器 5个部分
再实现了一个最简单版本的 DataLoader,目的是理解 DataLoader 与 Dataset、Sampler、BatchSampler和 collate_fn 之间的调用关系
最后对该实现进行深入全面分析,读者可以清晰的理解每个类的作用
但是 DataLoader 功能其实非常复杂,故本文属于系列文章的第一篇,后面文章会不断完善、调整,最终实现 DataLoader 所有功能。或者说本文是后续文章的基础,如果基础内容没有理解非常透彻,后面的多进程、分布式版本就更难以理解了。
虽然本文比较简单,但是由于涉及到代码,故为了方便,有必要的读者可以 clone rep 进行学习(需要特意说明的是:rep 里面代码是学习目的的,质量不高,不要要求那么多)
github: https://github.com/hhaAndroid/miniloader
由于本人水平有限,某些环节理解可能有偏颇,欢迎指正。手机对于代码显示效果不太好,建议电脑端阅读。
1 python 迭代器深入浅出理解
1.1 可迭代对象 Iterable
可迭代对象 Iterable:表示该对象可迭代,其并不是指某种具体数据类型。简单来说只要是实现了 `__iter__` 方法的类就是可迭代对象。
-
from collections.abc
import Iterable, Iterator
-
class A(object):
-
def __init__(self):
-
self.a = [
1,
2,
3]
-
def __iter__(self):
-
# 此处返回啥无所谓
-
return self.a
-
cls_a = A()
-
# True
-
print(isinstance(cls_a, Iterable))
但是对象如果是 Iterable 的,看起来好像也没有特别大的用途,因为你依然无法迭代,实际上 Iterable 仅仅是提供了一种抽象规范接口:
-
for a
in cls_a:
-
print(a)
-
-
-
# 程序报错,要理解这个错误的含义
-
TypeError: iter() returned non-iterator of
type
'list'
我们可以检查下 Iterable 接口:
-
class Iterable(metaclass=ABCMeta):
-
-
-
# 如果实现了这个方法,那么就是 Iterable
-
@abstractmethod
-
def __iter__(self):
-
while
False:
-
yield
None
-
-
-
@classmethod
-
def __subclasshook__(cls, C):
-
if cls
is Iterable:
-
return _check_methods(C,
"__iter__")
-
return
NotImplemented
看起来实现 Iterable 接口用途不大,其实不是的,其有很多用途的,例如简化代码等,在后面的高级语法糖中会频繁用到,后面会分析。
1.2 迭代器 Iterator
迭代器 Iterator:其和 Iterable 之间是一个包含与被包含的关系,如果一个对象是迭代器 Iterator,那么这个对象肯定是可迭代 Iterable;但是反过来,如果一个对象是可迭代 Iterable,那么这个对象不一定是迭代器 Iterator,可以通过接口协议看出:
-
class Iterator(Iterable):
-
-
-
# 迭代具体实现
-
@abstractmethod
-
def __next__(self):
-
'Return the next item from the iterator. When exhausted, raise StopIteration'
-
raise StopIteration
-
-
-
# 返回自身,因为自身有 __next__ 方法(如果自身没有 __next__,那么返回自身没有意义)
-
def __iter__(self):
-
return self
-
-
@classmethod
-
def __subclasshook__(cls, C):
-
if cls
is Iterator:
-
return _check_methods(C,
'__iter__',
'__next__')
-
return
NotImplemented
可以发现:实现了 `__next__` 和 `__iter__` 方法的类才能称为迭代器,就可以被 for 遍历了。
-
class A(object):
-
def __init__(self):
-
self.index = -
1
-
self.a = [
1,
2,
3]
-
-
-
#必须要返回一个实现了 __next__ 方法的对象,否则后面无法 for 遍历
-
#因为本类自身实现了 __next__,所以通常都是返回 self 对象即可
-
def __iter__(self):
-
return
self
-
-
-
def __next__(self):
-
self.index +=
1
-
if
self.index < len(
self.a):
-
return
self.a[
self.index]
-
else:
-
#抛异常,for 内部会自动捕获,表示迭代完成
-
raise StopIteration(
"遍历完了")
-
cls_a = A()
-
print(isinstance(cls_a, Iterable))
# True
-
print(isinstance(cls_a, Iterator))
# True
-
print(isinstance(iter(cls_a), Iterator))
# True
-
-
-
for a
in
cls_a:
-
print(a)
-
# 打印 1 2 3
再次明确,一个对象如果要是 Iterator ,那么必须要实现 `__next__` 和 `__iter__` 方法,但是要理解其内部迭代流程,还需要理解 for .. in .. 流程。
1.3 for .. in .. 本质流程
for .. in .. 也就是常见的迭代操作了,其被 python 编译器编译后,实际上代码是:
-
# 实际调用了 __iter__ 方法返回自身,包括了 __next__ 方法的对象
-
cls_a = iter(cls_a)
-
while
True:
-
try:
-
# 然后调用对象的 __next__ 方法,不断返回元素
-
value = next(cls_a)
-
print(value)
-
# 如果迭代完成,则捕获异常即可
-
except StopIteration:
-
break
可以看出,任何一个对象如果要能够被 for 遍历,必须要实现 `__iter__` 和 `__next__` 方法,缺一不可。
明白了上述流程,那么迭代器对象 A,我们可以采用如下方式进行遍历:
-
myiter = iter(cls_a)
-
print(
next(myiter))
-
print(
next(myiter))
-
print(
next(myiter))
-
# 因为遍历完了,故此时会出现错误:StopIteration: 遍历完了
-
print(
next(myiter))
我们再来思考 python 内置对象 list 为啥可以被迭代?
-
b=
list([
1,
2,
3])
-
print(isinstance(b,
Iterable))
# True
-
print(isinstance(b,
Iterator))
# False
可以发现 list 类型是可迭代对象,但是其不是迭代器(即 list 没有 `__next__` 方法),那为啥 for .. in .. 可以迭代呢?
原因是 list 内部的 `__iter__` 方法内部返回了具备 `__next__` 方法的类,或者说调用 iter() 后返回的对象本身就是一个迭代器,当然可以 for 循环了。
-
b=
list([
1,
2,
3])
-
print(dir(b))
# 可以发现其存在 __iter__ 方法,不存在 __next__
-
-
-
b=iter(b)
# 调用 list 内部的 __iter__,返回了具备 __next__ 的对象
-
print(isinstance(b,
Iterable))
# True
-
print(isinstance(b,
Iterator))
# True
-
print(dir(b))
# 同时具备 __iter__ 和 __next__ 方法
基于上述理解我们可以对 A 类代码进行改造,使其更加简单:
-
class A(object):
-
def __init__(self):
-
self.a = [
1,
2,
3]
-
# 我们内部又调用了 list 对象的 __iter__ 方法,故此时返回的对象是迭代器对象
-
def __iter__(self):
-
return iter(self.a)
-
-
-
cls_a = A()
-
print(isinstance(cls_a, Iterable))
# True
-
print(isinstance(cls_a, Iterator))
# False
-
-
-
for a
in cls_a:
-
print(a)
-
# 输出:1 2 3
此时我们就实现了仅仅实现 Iterable 规范接口,但是又具备了 for .. in .. 功能,代码是不是比最开始的实现简单很多?这种写法应用也非常广泛,因为其不需要自己再次实现 `__next__` 方法。
如果你想理解的更加透彻,那么可以看下面例子:
-
# 仅仅实现 __iter__
-
class A(object):
-
def __init__(self):
-
self.b = B()
-
-
-
def __iter__(self):
-
return
self.b
-
-
-
# 仅仅实现 __next__
-
class B(object):
-
def __init__(self):
-
self.index = -
1
-
self.a = [
1,
2,
3]
-
-
-
def __next__(self):
-
self.index +=
1
-
if
self.index < len(
self.a):
-
return
self.a[
self.index]
-
else:
-
# 内部会自动捕获,表示迭代完成
-
raise StopIteration(
"遍历完了")
-
-
-
-
-
cls_a = A()
-
cls_b = B()
-
print(isinstance(cls_a, Iterable))
# True
-
print(isinstance(cls_a, Iterator))
# False
-
print(isinstance(cls_b, Iterable))
# False
-
print(isinstance(cls_b, Iterator))
# False
-
-
-
print(type(iter(cls_a)))
# B 对象
-
print(isinstance(iter(cls_a), Iterator))
# False
-
-
-
for a
in
cls_a:
-
print(a)
-
-
-
# 输出:1 2 3
自此我们知道了:一个对象要能够被 for .. in .. 迭代,那么不管你是直接实现 `__iter__` 和 `__next__` 方法(对象必然是 Iterator),还是只实现 `__iter__`(不是 Iterator),但是内部间接返回了具备 `__next__` 对象的类,都是可行的。
但是除了这两种实现,还有其他高级语法糖,可以进一步精简代码。
1.4 __ getitem__ 理解
上面说过 for .. in .. 的本质就是调用对象的 `__iter__` 和 `__next__` 方法,但是有一种更加简单的写法,你通过仅仅实现 `__getitem__` 方法就可以让对象实现迭代功能。实际上任何一个类,如果实现了`__getitem__` 方法,那么当调用 iter(类实例) 时候会自动具备`__iter__` 和 `__next__`方法,从而可迭代了。
通过下面例子可以看出,`__getitem__` 实际上是属于 __iter__` 和 `__next__` 方法的高级封装,也就是我们常说的语法糖,只不过这个转化是通过编译器完成,内部自动转化,非常方便。
-
class A(object):
-
def __init__(self):
-
self.a = [
1,
2,
3]
-
-
-
def __getitem__(self, item):
-
return self.a[item]
-
-
cls_a = A()
-
print(isinstance(cls_a, Iterable))
# False
-
print(isinstance(cls_a, Iterator))
# False
-
print(dir(cls_a))
# 仅仅具备 __getitem__ 方法
-
-
-
cls_a = iter(cls_a)
-
print(dir(cls_a))
# 具备 __iter__ 和 __next__ 方法
-
-
-
print(isinstance(cls_a, Iterable))
# True
-
print(isinstance(cls_a, Iterator))
# True
-
-
-
# 等价于 for .. in ..
-
while
True:
-
try:
-
# 然后调用对象的 __next__ 方法,不断返回元素
-
value = next(cls_a)
-
print(value)
-
# 如果迭代完成,则捕获异常即可
-
except StopIteration:
-
break
-
-
-
# 输出:1 2 3
而且 `__getitem__` 还可以通过索引直接访问元素,非常方便
-
a[0]
# 1
-
a[4]
# 错误,索引越界
如果你想该对象具备 list 等对象一样的长度属性,则只需要实现 `__len__` 方法即可
-
class A(object):
-
def __init__(self):
-
self.a = [
1,
2,
3]
-
-
-
def __getitem__(self, item):
-
return
self.a[item]
-
-
-
def __len__(self):
-
return len(
self.a)
-
-
-
cls_a = A()
-
print(len(cls_a))
# 3
到目前为止,我们已经知道了第一种高级语法糖实现迭代器功能,下面分析另一个更简单的可以直接作用于函数的语法糖。
1.5 yield 生成器
生成器是一个在行为上和迭代器非常类似的对象,二者功能上差不多,但是生成器更优雅,只需要用关键字 yield 来返回,作用于函数上叫生成器函数,函数被调用时会返回一个生成器对象,生成器本质就是迭代器,其最大特点是代码简洁。
-
def func():
-
for a
in [
1,
2,
3]:
-
yield a
-
-
-
cls_g = func()
-
print(isinstance(cls_g, Iterator))
# True
-
print(dir(cls_g))
# 自动具备 __iter__ 和 __next__ 方法
-
-
-
for a
in cls_g:
-
print(a)
-
-
-
# 输出: 1 2 3
-
-
-
# 一种更简单的写法是用 ()
-
cls_g = (i
for i
in [
1,
2,
3])
直观感觉和 `__getitem__` 一样,也是高级语法糖,但是比 `__getitem__` 更加简单,更加好用。
使用 yield 函数与使用 return 函数,在执行时差别在于:包含 yield 的方法一般用于迭代,每次执行时遇到 yield 就返回 yield 后的结果,但内部会保留上次执行的状态,下次继续迭代时,会继续执行 yield 之后的代码,直到再次遇到 yield 后返回。生成器是懒加载模式,特别适合解决内存占用大的集合问题。假设创建一个包含10万个元素的列表,如果用 list 返回不仅占用很大的存储空间,如果我们仅仅需要访问前面几个元素,那后面绝大多数元素占用的空间都白白浪费了,这种场景就适合采用生成器,在迭代过程中推算出后续元素,而不需要一次性全部算出。
1.6 小结
list set dict等内置对象都是容器 container 对象,容器是一种把多个元素组织在一起的数据结构,可以逐个迭代获取其中的元素。容器可以用 in 来判断容器中是否包含某个元素。大多数容器都是可迭代对象,可以使用某种方式访问容器中的每一个元素。
在迭代对象基础上,如果实现了 `__next__` 方法则是迭代器对象,该对象在调用 next() 的时候返回下一个值,如果容器中没有更多元素了,则抛出 StopIteration 异常。
对于采用语法糖 `__getitem__` 实现的迭代器对象,其本身实例既不是可迭代对象,更不是迭代器,但是其可以被 for in 迭代,原因是对该对象采用 iter(类实例) 操作后就会自动变成迭代器。
生成器是一种特殊迭代器,但是不需要像迭代器一样实现`__iter__`和`__next__`方法,只需要使用关键字 yield 就可以,生成器的构造可以通过生成器表达式 (),或者对函数返回值加入 yield 关键字实现。
对于在类的 `__iter__` 方法中采用语法糖 yield 实现的迭代器对象,其本身实例是可迭代对象,但不是迭代器,但是其可以被 for .. in .. 迭代,原因是对该对象采用 iter(类实例) 操作后就会自动变成迭代器。
2 DataLoader 最简版本 V1
这里说的最简版本是指:没有任何花哨、高级实现技巧,仅仅以实现最基础功能为目的。具体来说是包括必备的5个对象:Dataset、Sampler、BatchSampler、DataLoader 和 collate_fn。其作用可以简要描述为如下:
Dataset 提供整个数据集的随机访问功能,每次调用都返回单个对象,例如一张图片和对应 target 等等
Sampler 提供整个数据集随机访问的索引列表,每次调用都返回所有列表中的单个索引,常用子类是 SequentialSampler 用于提供顺序输出的索引 和 RandomSampler 用于提供随机输出的索引
BatchSampler 内部调用 Sampler 实例,输出指定 `batch_size` 个索引,然后将索引作用于 Dataset 上从而输出 `batch_size` 个数据对象,例如 batch 张图片和 batch 个 target
collate_fn 用于将 batch 个数据对象在 batch 维度进行聚合,生成 (b,...) 格式的数据输出,如果待聚合对象是 numpy,则会自动转化为 tensor,此时就可以输入到网络中了
迭代一次伪代码如下(非迭代器版本):
-
class DataLoader(object):
-
def __init__(self):
-
# 假设数据长度是100,batch_size 是4
-
self.dataset = [[img
0, target
0], [img1, target1], ..., [img99, target99]]
-
# 假设 sampler 是 SequentialSampler,那么实际上就是 [0,1,...,99] 列表而已
-
# 如果 sampler 是 RandomSampler,那么可能是 [30,1,34,2,6,...,0] 列表
-
self.sampler = [
0,
1,
2,
3,
4, ...,
99]
-
self.batch_size =
4
-
self.index =
0
-
-
-
def collate_fn(self, data):
-
# batch 维度聚合数据
-
batch_img = torch.stack(data[
0],
0)
-
batch_target = torch.stack(data[
1],
0)
-
return batch_img, batch_target
-
-
-
def __next__(self):
-
# 0.batch_index 输出,实际上就是 BatchSampler 做的事情
-
i =
0
-
batch_index = []
-
while i <
self.
batch_size:
-
# 内部会调用 sampler 对象取单个索引
-
batch_index.append(
self.sampler[
self.index])
-
self.index +=
1
-
i +=
1
-
-
-
# 1.得到 batch 个数据了,调用 dataset 对象
-
data = [
self.dataset[idx]
for idx
in batch_index]
-
-
-
# 2. 调用 collate_fn 在 batch 维度拼接输出
-
batch_data =
self.collate_fn(data)
-
return batch_data
-
-
-
def __iter__(self):
-
return
self
以上就是最抽象的 DataLoader 运行流程以及和 Dataset、Sampler、BatchSampler、collate_fn 的关系。
2.1 整体对象理解
首先需要强调的是 Dataset、Sampler、BatchSampler 和 DataLoader 都直接或间接实现了迭代器,你必须要先理解第一小节内容,否则本节内容会比较难理解,具体为:
Dataset 通过实现 `__getitem__` 方法使其可迭代
Sampler 对象是一个可迭代的基类对象,其常用子类 SequentialSampler 在 `__iter__` 内部返回迭代器,RandomSampler 在 `__iter__` 内部通过 yield 关键字返回迭代器
BatchSampler 也是在 `__iter__` 内部通过 yield 关键字返回迭代器
DataLoader 通过直接实现 `__next__` 和 `__iter__` 变成迭代器
注意除了 DataLoader 本身是迭代器外,其余对象本身不是迭代器,但是都能被 for .. in .. 迭代。下面一个简单例子证明:
-
from simplev1_datatset
import SimpleV1Dataset
-
from libv1
import SequentialSampler, RandomSampler
-
from collections
import Iterator, Iterable
-
-
-
simple_dataset = SimpleV1Dataset()
-
dataloader = DataLoader(simple_dataset, batch_size=
2, collate_fn=default_collate)
-
-
-
print(isinstance(simple_dataset, Iterable))
# False
-
print(isinstance(simple_dataset, Iterator))
# False
-
print(isinstance(iter(simple_dataset), Iterator))
# True
-
-
-
print(isinstance(SequentialSampler(simple_dataset), Iterable))
# True
-
print(isinstance(SequentialSampler(simple_dataset), Iterator))
# False
-
print(isinstance(iter(SequentialSampler(simple_dataset)), Iterator))
# True
-
-
-
# BatchSampler 和 RandomSampler 内部实现结构一样,结果也是一样
-
print(isinstance(RandomSampler(simple_dataset), Iterable))
# True
-
print(isinstance(RandomSampler(simple_dataset), Iterator))
# False
-
print(isinstance(iter(RandomSampler(simple_dataset)), Iterator))
# True
-
-
-
print(isinstance(dataloader, Iterator))
# True
在 DataLoader 中主要涉及3个类,其内部实例传递关系如下:
由于 DataLoader 类写的非常通用,故 Dataset、Sampler、BatchSampler 都可以外部传入,除了 Dataset 必须输入外,其余两个类都有默认实现,最典型的 Sampler 就是 SequentialSampler 和 RandomSampler。
需要注意的是 Sampler 对象其实在大部分时候都不需要传入 Dataset 实例对象,因为其功能仅仅是返回索引而已,并没有直接接触数据。
2.2 DataLoader 运行流程
最简单版本 DataLoader,具备如下功能:
Dataset 内部返回需要是 numpy 或者 tensor 对象
Sampler 直接 SequentialSampler 和 RandomSampler
BatchSampler 已经实现
collate_fn 仅仅考虑了 numpy 或者 tensor 对象
仅仅支持 num_works=0 即单进程
看起来功能非常单一,但是其实已经搭建起了整个框架,理解了这个最简框架才能去理解高级实现,其核心运行逻辑为:
-
def __next__(self):
-
# 返回 batch 个索引
-
index =
next(
self.batch_sampler)
-
# 利用索引去取数据
-
data = [
self.dataset[idx]
for idx
in index
-
# batch 维度聚合
-
data =
self.collate_fn(data)
-
return data
然后为了方便大家理解,特意绘制了如下代码运行流程图:
还是那句话:一定要对第1小节内容非常熟悉,否则里面这么多迭代器、生成器的调用,可能会把你绕晕。详细代码描述如下:
`self.batch_sampler = iter(batch_sampler)`。在 DataLoader 的类初始化,需要得到 BatchSampler 的迭代器对象
`index = next(self.batch_sampler)`。对于每次迭代,DataLoader 对象首先会调用 BatchSampler 的迭代器进行下一次迭代,具体是调用 BatchSampler 对象的 `__iter__` 方法
而 BatchSampler 对象的 `__iter__` 方法实际上是需要依靠 Sampler 对象进行迭代输出索引,Sampler 对象也是一个迭代器,当迭代 `batch_size` 次后就可以得到 `batch_size` 个数据索引
`data = [self.dataset[idx] for idx in index]`。有了 batch 个索引就可以通过不断调用 dataset 的 `__getitem__` 方法返回数据对象,此时 data 就包含了 batch 个对象
`data = self.collate_fn(data)`。将 batch 个对象输入给聚合函数,在第0个维度也就是 batch 维度进行聚合,得到类似 (b,...) 的对象
不断重复1-5步,就可以不断的输出一个一个 batch 的数据了
以上就是完整流程,如果理解有困难,你可以先看下一小结的代码实现,然后再返回去理解。
2.3 最简V1版本源代码
(1) Dataset
-
class Dataset(object):
-
# 只要实现了 __getitem__ 方法就可以变成迭代器
-
def __getitem__(self, index):
-
raise NotImplementedError
-
# 用于获取数据集长度
-
def __len__(self):
-
raise NotImplementedError
(2) Sampler
-
class Sampler(object):
-
def __init__(self, data_source):
-
pass
-
-
-
def __iter__(self):
-
raise NotImplementedError
-
-
-
def __len__(self):
-
raise NotImplementedError
-
class SequentialSampler(Sampler):
-
-
-
def __init__(self, data_source):
-
super(SequentialSampler,
self).__init_
_(data_source)
-
self.data_source = data_source
-
-
-
def __iter__(self):
-
# 返回迭代器,不然无法 for .. in ..
-
return iter(range(len(
self.data_source)))
-
-
-
def __len__(self):
-
return len(
self.data_source)
-
class BatchSampler(Sampler):
-
def __init__(
self, sampler, batch_size, drop_last):
-
self.sampler = sampler
-
self.batch_size = batch_size
-
self.drop_last = drop_last
-
-
-
def __iter__(
self):
-
batch = []
-
# 调用 sampler 内部的迭代器对象
-
for idx
in
self.sampler:
-
batch.append(idx)
-
# 如果已经得到了 batch 个 索引,则可以通过 yield
-
# 关键字生成生成器返回,得到迭代器对象
-
if len(batch) ==
self.batch_size:
-
yield batch
-
batch = []
-
if len(batch) >
0 and not
self.drop_last:
-
yield batch
-
-
-
def __len__(
self):
-
if
self.drop_last:
-
# 如果最后的索引数不够一个 batch,则抛弃
-
return len(
self.sampler)
// self.batch_size
-
else:
-
return (len(
self.sampler) +
self.batch_size -
1)
// self.batch_size
(3) DataLoader
-
class DataLoader(object):
-
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
-
batch_sampler=None, collate_fn=None, drop_last=False)
-
self.dataset = dataset
-
-
-
# 因为这两个功能是冲突的,假设 shuffle=True,
-
# 但是 sampler 里面是 SequentialSampler,那么就违背设计思想了
-
if sampler is not None and shuffle:
-
raise ValueError(
'sampler option is mutually exclusive with '
-
'shuffle')
-
-
-
if batch_sampler
is
not
None:
-
# 一旦设置了 batch_sampler,那么 batch_size、shuffle、sampler
-
# 和 drop_last 四个参数就不能传入
-
# 因为这4个参数功能和 batch_sampler 功能冲突了
-
if batch_size !=
1
or shuffle
or sampler
is
not
None
or drop_last:
-
raise ValueError(
'batch_sampler option is mutually exclusive '
-
'with batch_size, shuffle, sampler, and '
-
'drop_last')
-
batch_size =
None
-
drop_last =
False
-
-
-
if sampler
is
None:
-
if shuffle:
-
sampler = RandomSampler(dataset)
-
else:
-
sampler = SequentialSampler(dataset)
-
-
-
# 也就是说 batch_sampler 必须要存在,你如果没有设置,那么采用默认类
-
if batch_sampler
is
None:
-
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
-
-
-
self.batch_size = batch_size
-
self.drop_last = drop_last
-
self.sampler = sampler
-
self.batch_sampler = iter(batch_sampler)
-
-
-
if collate_fn
is
None:
-
collate_fn = default_collate
-
self.collate_fn = collate_fn
-
-
-
# 核心代码
-
def __next__(self):
-
index = next(self.batch_sampler)
-
data = [self.dataset[idx]
for idx
in index]
-
data = self.collate_fn(data)
-
return data
-
-
-
# 返回自身,因为自身实现了 __next__
-
def __iter__(self):
-
return self
(4) collate_fn
-
def default_collate(batch):
-
elem = batch[
0]
-
elem_type = type(elem)
-
if isinstance(elem, torch.Tensor):
-
return torch.stack(batch,
0)
-
elif elem_type.__module__ ==
'numpy':
-
return default_collate([torch.as_tensor(b)
for b
in batch])
-
else:
-
raise NotImplementedError
(5) 调用完整例子
-
class SimpleV1Dataset(Dataset):
-
def __init__(self):
-
# 伪造数据
-
self.imgs = np.arange(
0,
16).reshape(
8,
2)
-
-
-
def __getitem__(self, index):
-
return
self.imgs[index]
-
-
-
def __len__(self):
-
return
self.imgs.shape[
0]
-
-
-
-
-
from simplev1_datatset import SimpleV1Dataset
-
simple_dataset = SimpleV1Dataset()
-
dataloader = DataLoader(simple_dataset, batch_size=
2, collate_fn=default_collate)
-
for data
in
dataloader:
-
print(data)
3 总结
本文是最小 DataLoader 系列文章的第一篇,重点是分析了 python 中迭代器相关知识,然后构建一个最简单的 DataLoader 类,用于加深到 DataLoader 流程的理解,功能比较简单。
后面慢慢完善,希望最终能实现完整功能。
github: https://github.com/hhaAndroid/miniloader
推荐阅读
机器学习算法工程师
一个用心的公众号
转载:https://blog.csdn.net/l7H9JA4/article/details/111939787