从零开始用PyTorch:一次真实项目中的深度学习初探

产品经理别看我
2025-06-14 09:04
阅读 1031

作为一个在互联网公司做AI开发的普通程序员,我经常面对的是业务驱动的技术挑战。去年年底,我们团队接到了一个需求:为公司的短视频平台搭建一个视频标签推荐系统,目标是通过分析视频内容自动打上多个关键词标签,用于提升内容分发效率。

这个项目对我来说是一次很好的契机,可以实践我在学校学过的深度学习知识,并且第一次深入使用 PyTorch 来完成真实世界的建模任务。这篇文章我就以这个项目为背景,分享一下自己初次使用 PyTorch 的经历,从“不知道怎么下手”到“有点上手”的全过程,希望能给刚开始接触深度学习框架的朋友一些启发和帮助。


问题描述:从“一脸懵”开始的需求

问题描述:从“一脸懵”开始的需求

我们的项目一开始其实是一个非常常见的多标签图像分类问题。但因为是基于视频,每个视频我们会抽帧处理,变成对每一帧图片进行分类。虽然问题不复杂,但我们面临几个关键挑战:

  1. 没有现成的标注数据集:需要人工标注训练集,时间紧迫。
  2. 模型结构不确定:ResNet、VGG、EfficientNet……该选哪个?要不要加注意力机制?
  3. 训练环境配置困难:本地环境跑不动大模型,必须上服务器,但又怕写错代码浪费GPU资源。
  4. 部署和线上服务集成还不知道怎么做:这虽然是后话,但也得提前考虑可扩展性。

那时候我其实已经了解过 TensorFlow,但听同事说 PyTorch 更适合研究和快速迭代,尤其适合我们这种边实验边改的情况,于是决定先用 PyTorch 上车试试看。


解决方案:选型 + 构建流程

解决方案:选型 + 构建流程

我们在确定了任务大致方向后,迅速制定了一个“试错式”的开发流程:

第一阶段:数据准备与特征提取

  • 视频抽帧(ffmpeg) → 图片裁剪统一尺寸 → 手动打标签(Excel表格+LabelImg)
  • 数据格式最终统一为 Image Folder 格式,方便后面用 ImageFolder 加载器

第二阶段:模型选择与构建

  • 最终选定 ResNet18 做 backbone,加了一个全连接层做多标签输出(Sigmoid)
  • 使用 torchvision.models.resnet18(pretrained=True) 预训练权重加速收敛
  • 冻结部分网络参数,只训练最后的分类层

第三阶段:训练调参与验证

  • 损失函数选择 BCEWithLogitsLoss(更适合多标签任务)
  • 优化器选用 AdamW,学习率设置为 5e-4,weight_decay=1e-2
  • 使用混合精度训练 (torch.cuda.amp) 提升训练速度
  • 在测试集上计算 precision / recall / mAP 作为评估指标

整个项目的周期大概是两周左右完成初版上线,后续还在持续优化中。


代码实践:PyTorch 快速上手实战

代码实践:PyTorch 快速上手实战

下面我放出项目中最核心的一段模型定义和训练循环代码,让大家直观感受 PyTorch 的写法和风格。

import torch
import torchvision.models as models
import torch.nn as nn

class VideoTagModel(nn.Module):
    def __init__(self, num_tags):
        super(VideoTagModel, self).__init__()
        # 使用预训练的 ResNet18 作为基础模型
        self.base_model = models.resnet18(pretrained=True)
        
        # 替换最后一层,改为支持多标签的输出
        in_features = self.base_model.fc.in_features
        self.base_model.fc = nn.Linear(in_features, num_tags)

    def forward(self, x):
        return self.base_model(x)

然后是训练循环的一个简化版本:

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = VideoTagModel(num_tags=20).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-2)

for epoch in range(10):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        with torch.cuda.amp.autocast():
            outputs = model(images)
            loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

这段代码可能看起来有些简陋,但它完整地涵盖了 PyTorch 开发的核心流程:

  • 模型定义
  • 数据加载(可以用 Dataset + DataLoader 实现)
  • 损失函数和优化器
  • 单卡训练循环 + 使用混合精度加速

如果你之前用过 TensorFlow,可能会觉得 PyTorch “更像编程语言”,而不是“声明式框架”。这种灵活性也正是它适合研究和探索的根源所在。


踩坑经验:那些让人抓狂的瞬间

神经网络结构图-1

踩坑经验:那些让人抓狂的瞬间

说实话,在刚接触 PyTorch 的时候,踩了不少坑,这里分享几个印象比较深的:

❗ 1. 数据维度搞错了,损失一直不下降

一开始用了 BCELoss,但忘了把 logits 直接喂进去,而是加了 sigmoid。结果损失怎么都不下降,后来才发现应该用 BCEWithLogitsLoss,内置了 sigmoid 和 log 运算,还能防止梯度爆炸。

教训:PyTorch 的很多 loss 函数已经做了内部封装,不要手动加 activation。


❗ 2. 多标签的 label 编码方式不对

标签一开始我用的是 one-hot 编码成 [batch_size, 1],后来发现这样没法处理多个标签。正确的做法应该是每个样本对应一个长度为标签数的向量,元素是 0 或者 1,表示是否属于该类别。

教训:多标签任务要确保 label 是 float 类型,shape 是 [N, C],值域为 0/1。


❗ 3. 混合精度训练导致 NaN

某次加了 amp 后,训练过程突然出现 NaN,查了很久才意识到某些操作不支持 fp16 精度(比如除法、指数等),这时候需要加入 scale 操作或者调整某些层的 dtype。

教训:amp 不是万能的,建议初期先不用,确认模型稳定后再加。


效果总结:小数据也能跑出不错的结果

尽管我们的训练数据只有不到 1000 个样本(每条视频平均抽 5~10 帧),但由于使用了 ResNet18 的预训练模型,并加上了一些简单的数据增强(Rotate, Flip, ColorJitter),最终在验证集上的 F2 score 达到了约 0.72,基本满足初期上线要求。

当然,我们也在思考进一步改进的方向:

  • 引入 Transformer-based 模型(如 ViT)
  • 增加样本数量,引入弱监督或伪标签策略
  • 使用 TTA(Test Time Augmentation)提高预测稳定性

经验分享:给新手的一些建议

计算机视觉应用-2

  1. PyTorch 真的很好入门,特别是有 Python 基础的同学,几乎不需要太多前置知识就能写出训练脚本。
  2. 建议多看看官方文档和教程,尤其是 torchvision, torch.utils.data, nn, optim 这几个模块,90% 的任务都能搞定。
  3. 模型不要太复杂,先跑起来再说。很多时候不是你不会写,而是不敢动手。
  4. 调试时注意 GPU 显存分配问题,尤其是在多卡环境下,避免 OOM。
  5. 善用 Jupyter Notebook 快速尝试小片段代码,别上来就写完整训练脚本。

总结:PyTorch 让深度学习变得更“接地气”

回顾这次用 PyTorch 完成视频标签识别的经历,我觉得最大的收获不是算法本身,而是理解了“如何用工程化的思维去解决 AI 问题”。

PyTorch 并不是一个黑盒工具,它让我们能够真正控制每一个细节,同时又能快速验证想法。对于刚入门的开发者来说,它是一种友好的引导;而对于老手而言,它是强大的表达媒介。

如果你也在犹豫要不要开始用 PyTorch,我的建议是:先写一个小 demo,跑起来再说。哪怕只是一个最简单的线性回归,只要你让它跑起来了,你就离真正的“AI工程师”不远了。

也欢迎你在评论区和我一起交流你在 PyTorch 学习过程中的心得和问题,咱们一起进步 💪


📌 下一篇可能会写《如何用 PyTorch 部署模型到生产环境》,敬请期待 😄

评论 0

最热最新
暂无评论
匿名用户Lv.1
0
影响力
0
文章
0
粉丝