PyTorch 深度学习实战
方远
LINE China 数据科学家
10381 人已学习
新⼈⾸单¥59
登录后,你可以任选3讲全文学习
课程目录
已完结/共 32 讲
开篇词 (1讲)
PyTorch 深度学习实战
15
15
1.0x
00:00/00:00
登录|注册

18 | 图像分类(下):如何构建一个图像分类模型?

你好,我是方远。欢迎来到第 18 节课的学习。
我相信经过上节课的学习,你已经了解了图像分类的原理,还初步认识了一些经典的卷积神经网络。
正所谓“纸上得来终觉浅,绝知此事要躬行”,今天就让我们把上节课的理论知识应用起来,一起从数据的准备、模型训练以及模型评估,从头至尾一起来完成一个完整的图像分类项目实践。
课程代码你可以从这里下载。

问题回顾

我们先来回顾一下问题背景,我们要解决的问题是,在众多图片中自动识别出极客时间 Logo 的图片。想要实现自动识别,首先需要分析数据集里的图片是啥样子的。
那我们先来看一张包含极客时间 Logo 的图片,如下所示。
你可以看到,Logo 占整张图片的比例还是比较小的,所以说,如果这个项目是真实存在的,目标检测其实更加合适。不过,我们可以将问题稍微修改一下,修改成自动识别极客时间宣传海报,这其实就很适合图像分类任务了。

数据准备

相比目标检测与图像分割来说,图像分类的数据准备还是比较简单的。在图像分类中,我们只需要将每个类别的图片放到指定的文件夹里就行了。
下图是我的图片组织方式,文件夹就是图片所属的类别。
logo 文件夹中存放的是 10 张极客时间海报的图片。
而 others 中,理论上应该是各种其它类型的图片,但这里为了简化问题,我这个文件夹中存放的都是小猫的图片。
确认放弃笔记?
放弃后所记笔记将不保留。
新功能上线,你的历史笔记已初始化为私密笔记,是否一键批量公开?
批量公开的笔记不会为你同步至部落
公开
同步至部落
取消
完成
0/2000
荧光笔
直线
曲线
笔记
复制
AI
  • 深入了解
  • 翻译
    • 英语
    • 中文简体
    • 中文繁体
    • 法语
    • 德语
    • 日语
    • 韩语
    • 俄语
    • 西班牙语
    • 阿拉伯语
  • 解释
  • 总结

构建图像分类模型的关键步骤包括数据准备、模型训练和模型评估。文章介绍了如何使用EfficientNet模型进行图像分类,包括加载预训练模型、模型微调和设定损失函数与优化方法。在模型评估方面,重点介绍了精确率与召回率的计算方法,并指出根据业务需求选择侧重的指标。作者强调了数据准备的重要性,模型训练中采用卷积神经网络,并提出了对模型调整的建议。整体而言,本文以实际项目实践为背景,结合了理论知识和实际操作,为读者提供了一个完整的图像分类项目实践过程,同时介绍了EfficientNet模型及其在图像分类中的应用。

仅可试看部分内容,如需阅读全部内容,请付费购买文章所属专栏
《PyTorch 深度学习实战》
新⼈⾸单¥59
立即购买
登录 后留言

全部留言(42)

  • 最新
  • 精选
  • vcjmhg
    思考题: 召回率衡量的是,在整个验证集中,模型能找到多少 Logo 图片。因此在尽可能的把线上所有极客时间的海报都找到,允许一些误召回的情况下,训练的的模型应该更侧重召回率。

    作者回复: 👍🏻👍🏻^^

    2021-11-24
    3
  • pencilCool
    精确率: 滥竽充数者几何 召回率:漏网之鱼者几何

    作者回复: 👍🏻👍🏻👍🏻👍🏻👍🏻👍🏻

    2023-08-09归属地:中国台湾
    1
  • Geek_be7ab2
    老师好,我通过nohup python train.py > myout.log 2>&1 & 方式运行train.py文件没有仍何问题,但是使用predict.py时发现只能对一张图片进行预测。代码如下: python predict.py --path ./data/val/logo/14.jpeg 这是为什么?

    作者回复: 你好,感谢留言。 因为我给出的predict只能对一张图片预测啊。 https://github.com/syuu1987/geekTime-image-classification/blob/main/predict.py#L19 想预测多张图片的话,可以自己手动改一下。

    2022-03-14
    1
  • narsil的梦
    老师好,请问下面这段代码里为什么调用 model 的时候一次传入了一个 batch 的图片数量而不是一张一张传入并计算 output?如果一次传入多张图片输入,是不是输出的 output 是平均值?但我看 EfficientNet 的 forward 函数确实一次只接受一张图片输入 for i, (images, target) in enumerate(train_loader): # compute output output = model(images) loss = criterion(output, target) print('Epoch ', epoch, loss)

    作者回复: 你好,narsil的梦,感谢你的留言。 forward接受的也是一个batch的数据。 forward的输入即为网络的输入,它是一个shape为(batch_size, 通道数,高,宽)的tensor。 output的输出也是一个batch_size的输出,而不是平均值。 当预测一张图片的时候,输入应该为(1, 通道数,高,宽)的tensor。

    2021-11-24
    3
    1
  • zhaobk
    老师好。可以说说练完以后,要这么进行验证吗?

    作者回复: 你好,zhaobk,谢谢留言。你说的验证是指如何评估吗? 评估的话可以按照文稿中的精确与召回对每个模型进行评估,然后选择合适的模型。 代码的话,sklearn中提供了混淆矩阵、精确和召回的方法。

    2021-11-22
    1
  • Wayne
    老师你好,我想请问一下如何找最新的CNN网络模型呢

    作者回复: 你好,简单点的可以直接去Pytorch与Tensorflow的github上找。 https://github.com/tensorflow/models https://www.kaggle.com/models?tfhub-redirect=true https://pytorch.org/hub/ 其次是去CVPR、ICCV等地方找,每年都会发很多新的相关论文。

    2023-12-29归属地:福建
  • ifelse
    问题:应该关注召回率

    作者回复: 👍🏻

    2023-12-07归属地:浙江
  • Zeurd
    老师想问一下,运行的话会报AttributeError: 'NoneType' object has no attribute 'read'的错,是因为predict文件里test路径的问题么? 还有就是我在尝试别的数据集的时候,那是个4分类任务,我对标签tosenor之后,标签size为啥是[1,1]啊,4个文件夹4种标签不应该是[1,4]么,但是这样模型的liner层我的class只能设1,怎么预测都是100%正确率,我知道应该是4的,但不知道标签那个target_transform应该怎么改

    作者回复: 你好, >>”AttributeError: 'NoneType' object has no attribute 'read'“ 这个问题一般都是路径的问题,数据没有正确读取到。 >>还有就是我在尝试别的数据集的时候,那是个4分类任务,我对标签tosenor之后,标签size为啥是[1,1]啊 类别是1,2,3,4这样的吧。one-hot之后会变成4维

    2022-10-02归属地:北京
  • gavin
    两个问题没弄明白: 1.transforms.Lambda(_norm_advprop)中不用传img参数进去的吗? 函数定义时不是有个img参数吗 2.train.py中,model.train()是什么用?如果是模型内部就定义了训练,后面那段训练还有什么用

    作者回复: 你好,感谢留言。 1. 属于Python语法的问题。_norm_advprop 是个函数 https://github.com/syuu1987/geekTime-image-classification/blob/main/dataset.py#L6 2. https://github.com/syuu1987/geekTime-image-classification/blob/main/train.py#L16与https://github.com/syuu1987/geekTime-image-classification/blob/main/predict.py#L18 是配合使用的。 但是model.train(mode=True)默认是True,一般来说不写也可以。 它的作用是会启动模型中的BN层与dropout层。 而model.eval()则是关闭BN与dropout,进行预测。

    2022-09-12归属地:北京
  • 赵启明
    老师好,如下代码执行后: model = EfficientNet.from_name('efficientnet-b0') num_ftrs = model._fc.in_features model._fc = nn.Linear(num_ftrs, 2) model.load_state_dict(torch.load('./data/checkpoint/checkpoint.pth.tar.epoch_9')) 会报这个错误: lib\site-packages\torch\nn\modules\module.py", line 1379, in load_state_dict state_dict = state_dict.copy() AttributeError: '_IncompatibleKeys' object has no attribute 'copy' 是什么原因?

    作者回复: 抱歉,回复迟了。 感觉是版本的问题,你看看你的pytorch版本跟课程的一不一样?

    2022-08-20归属地:北京
收起评论
显示
设置
留言
42
收藏
沉浸
阅读
分享
手机端
快捷键
回顶部