Pytorch-Sampler类学习笔记
前言
我们在训练神经网络时,如果数据量太大,无法一次性将数据放入到网络中进行训练,所以需要进行分批处理数据读取。这一个问题涉及到如何从数据集中进行读取数据的问题,pytorch框提供了Sampler基类与多个子类实现不同方式的数据采样。子类包括:
all = [
“BatchSampler”,
“RandomSampler”,
“Sampler”,
“SequentialSampler”,
“SubsetRandomSampler”,
“WeightedRandomSampler”,
]
它决定了在训练过程中如何从数据集(Dataset
)中选择样本
1.基类Sampler
1 | class Sampler(object): |
- 对于所有的采样器来说,都需要继承采样器类,**必须实现的方法为_iter_()**,也就是定义迭代器行为,返回可
选代对象。除此之外,采样器类并没有定义任何其它的方法
2、顺序采样Sequential Sampler
1 | class SequentialSampler(Sampler[int]): |
- 顺序采样类并没有定义过多的方法,其中初始化方法仅仅需要一个Dataset类对象作为参数。
对于 len ()只负责返回数据源包含的数据个数;iter()方法负责返回一个可迭代对象,这个可选代对象是
由range产生的顺序数值序列,也就是说选代是按照顺序进行的。 - 常用于验证集或测试集上,因为测试过程中我们通常不需要对数据进行打乱,按照顺序采样即可。
前述几种方法都只需要self.data source实现了 len ()方法,因为这几种方法都仅仅使用了
len(self.data source)函数。
所以下面采用同样实现了 len()的list类型来代替Dataset类型做测试:
1 | # 定义数据和对应的采样器 |
3、随机采样RandomSampler
1 | class RandomSampler(Sampler[int]): |
- iter()方法,定义了核心的索引生成行为,其中if replacement判断处返回了两种随机值,根据是否在初始化中给出replacement参数决定是否重复采样,区别核心在于randint()函数生成的随机数学列是可能包含重复数值的,而randperm()函数生成的随机数序列是绝对不包含重复数值的
RandomSampler
从数据集中随机选择样本,且每个样本被选择的概率是相等的。通常用于打乱数据集中的样本顺序,特别是在训练阶段。每个样本的选择都是独立且均匀的。
下面分别测试是否使用replacement作为输入的情况,首先是不使用时:
1 | ran_sampler = sampler.RandomSampler(data_source=data) |
可以看出生成的随机索引是不重复的,下面是采用replacement参数的情况
1 | ran_sampler = sampler.RandomSampler(data_source=data, replacement=True) |
此时生成的随机索引是有重复的(1出现两次),也就是说第“1”条数据可能会被重复的采样。
4.子集随机采样Subset Random Sampler
1 | class SubsetRandomSampler(Sampler[int]): |
- 上述代码中len()的作用与前面几个类的相同,依旧是返回数据集的长度,区别在于iter()返回的并不是
随机数序列,而是通过随机数序列作为indices的索引,进而返回打乱的数据本身。需要注意的仍然是采样是不重复的,也是通过randperm()函数实现的。按照网上可以搜集到的资料,Subset Random sampler应该用于训练集、测试集和验证集的划分,下面将data划分为train和val两个部分,再次指出iter()返回的的不是索引,而是索引对应的数据: - 可以在指定的索引子集中进行随机采样,这样你可以控制哪些数据被用于训练或验证,而不是整个数据集。
1 | print('***********') |
5.加权随机采样WeightedRandomSampler
1 | class WeightedRandomSampler(Sampler[int]): |
- 对于Weighted Random Sampler类的 init()来说,replacement参数依旧用于控制采样是否是有放回的;
num sampler用于控制生成的个数;weights参数对应的是“样本”的权重而不是“类别的权重”。其中 iter_()方法返回的数值为随机数序列,只不过生成的随机数序列是按照weights指定的权重确定的 WeightedRandomSampler
按照给定的样本权重随机采样。每个样本的选择概率是与其权重成正比的。它通常用于数据集不平衡的情况,赋予少数类样本更大的权重,以增加其被采样的机会。在处理类别不平衡的数据时,可以通过设置每个样本的权重,使得少数类样本有更高的采样概率,帮助模型学习到更好的分类边界。
1 | # 加权随机采样 |
6、批采样BatchSampler
1 | class BatchSampler(Sampler[List[int]]): |
通过将 Sampler
和批次大小结合,BatchSampler
提供了一种高效的批量采样方式。它的返回值是一个批量样本的索引序列:
1 | seq_sampler = sampler.SequentialSampler(data_source=data) |