1 小时入门 PyTorch:从张量到多 GPU 神经网络训练
Sebastian Raschka(极客时间编译)
ML/AI Research Engineer
32 人已学习
免费领取
1 小时入门 PyTorch:从张量到多 GPU 神经网络训练
15
15
1.0x
00:00/00:00
登录|注册

06|设置高效的数据加载器

在前一节中,我们定义了一个自定义神经网络模型。在训练这个模型之前,我们必须简要讨论如何在 PyTorch 中创建高效的数据加载器,这些数据加载器将在训练模型时进行迭代。PyTorch 中数据加载的整体思路如图 1 所示。
图 1 展示了 PyTorch 实现的 DatasetDataLoader 类。其中,Dataset 类用于实例化定义如何加载每条数据记录的对象,而 DataLoader 则负责处理数据的打乱和批量组装方式。
根据图 1 中的说明,在本节中,我们将实现一个自定义的 Dataset 类,我们将使用该类创建训练集和测试集,然后使用这些数据集创建数据加载器。
下面我们通过创建一个简单的玩具数据集来开始,这个数据集包含 5 个训练样本,每个样本有两个特征。与训练样本相对应,我们还会创建一个包含对应类别标签的张量:其中三个样本属于类别 0,两个样本属于类别 1。此外,我们还会创建一个由两个样本组成的测试集。创建这个数据集的代码如下:
X_train = torch.tensor([
[-1.2, 3.1],
[-0.9, 2.9],
[-0.5, 2.6],
[2.3, -1.1],
[2.7, -1.5]
])
y_train = torch.tensor([0, 0, 0, 1, 1])
确认放弃笔记?
放弃后所记笔记将不保留。
新功能上线,你的历史笔记已初始化为私密笔记,是否一键批量公开?
批量公开的笔记不会为你同步至部落
公开
同步至部落
取消
完成
0/2000
荧光笔
直线
曲线
笔记
复制
AI
  • 深入了解
  • 翻译
    • 英语
    • 中文简体
    • 法语
    • 德语
    • 日语
    • 韩语
    • 俄语
    • 西班牙语
  • 解释
  • 总结

1. PyTorch中数据加载的整体思路是通过`Dataset`和`DataLoader`类实现数据加载和批量组装方式。 2. 自定义`Dataset`类的三个主要组件是`__init__`构造函数、`__getitem__`方法和`__len__`方法,用于实例化定义如何加载每条数据记录的对象。 3. 在`__init__`方法中,设置属性以访问后续的`__getitem__`和`__len__`方法,对于内存中的张量数据集,只需将X和y赋值给这些属性。 4. 在`__getitem__`方法中,通过索引从数据集中精确返回单个项目的指令,返回与单个训练样本或测试实例对应的特征和类别标签。 5. `__len__`方法包含获取数据集长度的指令,使用张量的`.shape`属性来返回特征数组中的行数。 6. 使用`DataLoader`类从自定义的`Dataset`类中采样数据,可以设置批次大小和是否打乱数据。 7. 遍历训练数据加载器,每个训练样本恰好被访问一次,可以设置`drop_last=True`来丢弃每个周期的最后一个批次。 8. 在`DataLoader`中的`num_workers`参数对于并行化数据加载和预处理至关重要,设置合适的值能显著提升效率,但需根据具体数据集规模和计算环境进行调整以达到最佳效果。 9. 在实际应用中,理解`num_workers`参数的权衡并审慎设置至关重要,最佳设置取决于硬件以及用于加载`Dataset`类中定义的训练示例的代码。 10. 设置`num_workers=4`通常会在许多真实数据集上带来最佳性能,但需根据具体情况进行调整。

该试读文章来自《1 小时入门 PyTorch:从张量到多 GPU 神经网络训练》,如需阅读全部文章,
请先领取课程
免费领取
登录 后留言

精选留言

由作者筛选后的优质留言将会公开显示,欢迎踊跃留言。
收起评论
显示
设置
留言
收藏
沉浸
阅读
分享
手机端
快捷键
回顶部