PyTorch快速入门:深度学习框架初探

Agent实验员
2025-06-30 08:28
阅读 908

引言:一次从零开始的尝试

引言:一次从零开始的尝试

我第一次接触PyTorch,是在公司的一个图像分类项目中。那时候我们刚转型进入AI领域,团队里没有专门的算法工程师,大部分人都来自传统后端或前端背景。面对TensorFlow和PyTorch的选择时,我们犹豫了很久——TensorFlow有Keras这样的高级封装,看起来更容易上手;但最终还是选了PyTorch,因为它的动态计算图更符合我们这些开发者的编程直觉。

现在回头看,这绝对是一个正确的决定。随着PyTorch在学术界和工业界的迅猛发展,特别是Transformer、扩散模型等前沿技术大多基于PyTorch实现,掌握这个框架已经成了AI时代的一项必备技能。这篇文章想结合我自己的真实经历,分享一下我是怎么从零开始,一步步搞懂PyTorch,并成功用它完成一个实战项目的。


项目背景:图像分类任务的挑战

项目背景:图像分类任务的挑战

事情要回到2022年年初,我们接了一个新需求:为一家医疗影像公司搭建一个肺部CT图像异常检测系统。简单来说,就是要把海量的CT图像自动标注出是否存在结节、钙化、磨玻璃影等异常特征。当时我们拿到的数据量大概是3万张左右,标注相对完整,格式也统一(都是PNG),看似是一个“标准”的图像分类问题。

但实际操作下来才发现,没那么简单。

首先,数据分布不均。虽然整体样本有3万张,但某些类别比如钙化斑块,数量特别少,导致模型很容易过拟合到那些常见的标签上。

其次,图像质量参差不齐。有些CT图亮度不均匀、对比度低,甚至部分有伪影,这对模型泛化提出了很高的要求。

更重要的是,我们在训练初期用了一些传统的CV方法(比如SIFT+SVM)效果并不理想,准确率始终在70%以下。这时候大家意识到,必须引入深度学习模型,才能把准确率提上来。

于是,我们选择了PyTorch作为深度学习框架,开始了这段“摸着石头过河”的旅程。


选择PyTorch的理由:为什么不是TensorFlow?

AI模型训练过程-1

虽然TensorFlow也有它的优势,比如更适合部署生产环境,但我们最终选择PyTorch,主要有以下几个原因:

  1. 易读性高:PyTorch是动态图机制,代码写起来就像Python一样直观,调试起来非常方便。
  2. 社区活跃:很多最新的论文实现都发布在GitHub上,并且是以PyTorch为主的。我们当时想复现ResNet-50、EfficientNet等经典结构时,可以很方便地找到开源参考代码。
  3. 适合研发阶段:我们处于快速试错、不断迭代的阶段,PyTorch的灵活性更能满足这种需求。

还有一个很重要的原因是:PyTorch的文档和教程比TensorFlow更加贴近开发者视角。虽然两者官方文档都不错,但PyTorch的官方教程往往带着“动手实践”的味道,这对于非算法出身的人来说非常友好。


技术方案:从零搭起一个图像分类系统

我们的目标是构建一个能识别多种肺部异常类型的图像分类系统。最终采用的技术方案如下:

  • 模型架构:以ResNet-50为主干网络,做微调(fine-tune)
  • 数据增强:使用torchvision.transforms进行随机裁剪、旋转、调整亮度/对比度
  • 优化器:Adam,配合Cosine退火调度策略
  • 损失函数:Focal Loss,缓解类别不平衡问题
  • 硬件配置:一块NVIDIA A6000显卡,80G内存,Ubuntu 20.04 + PyTorch 1.13

整个流程大致可以分为以下几个步骤:

  1. 数据准备与预处理
  2. 构建模型结构
  3. 定义损失函数和优化器
  4. 编写训练循环
  5. 模型评估与调优

下面我会详细介绍一下其中几个关键环节。


数据处理:不只是加载而已

PyTorch提供了一个非常好用的torch.utils.data.Dataset类,允许我们自定义数据集。最初我直接继承这个类,自己写了一个读取CSV文件(包含图像路径和label)的代码:

from torch.utils.data import Dataset
import cv2
import pandas as pd

class LungCTDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.annotations = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        img_path = self.annotations.iloc[idx, 0]
        image = cv2.imread(img_path)
        y_label = torch.tensor(int(self.annotations.iloc[idx, 1]))

        if self.transform:
            image = self.transform(image)

        return image, y_label

这段代码看起来没什么问题,但实际运行的时候发现:

  • cv2.imread()读出来的图像通道顺序是BGR,而PyTorch默认处理RGB,结果颜色完全对不上!
  • 图像尺寸不一致,有的是512x512,有的是256x256,不做统一的话后续会报错。
  • 类别严重不平衡,需要额外采样或加权。

经过几次修改之后,我们用了更稳健的方式处理图像:

transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

另外,我们也采用了WeightedRandomSampler来解决类别不平衡的问题:

from torch.utils.data.sampler import WeightedRandomSampler

# 计算每个类别的权重
class_counts = ... # 各类数量
weights = 1. / torch.tensor(class_counts, dtype=torch.float)
samples_weights = weights[labels]

sampler = WeightedRandomSampler(weights=samples_weights,
                               num_samples=len(samples_weights),
                               replacement=True)

train_loader = DataLoader(dataset, batch_size=64, sampler=sampler)

这样做的好处是,在训练时,小样本类别会被频繁抽样,防止模型偏向多数类。


模型搭建与迁移学习

我们一开始尝试从头开始训练一个ResNet50模型,结果训练了快一周还没收敛。后来果断改成迁移学习方式,用ImageNet预训练好的ResNet50,冻结底层参数,只训练最后几层:

model = resnet50(pretrained=True)
for param in model.parameters():
    param.requires_grad = False

# 替换最后一层全连接层
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)  # num_classes是我们自己的分类数

为了更好地调整学习率,我们还引入了余弦退火:

from torch.optim.lr_scheduler import CosineAnnealingLR

optimizer = Adam(model.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=10)

这部分的经验教训就是:

不要一上来就从头训练大模型,尤其是当你的数据量不是非常大的时候。迁移学习+Fine Tune几乎总是更快见效。


踩过的坑与解决方案

坑一:图像通道顺序搞反了

前面提到,OpenCV是按BGR格式读取图像,而大多数预训练模型期望的是RGB格式。如果你忘了转换,颜色就会出错,严重影响训练效果。

解决方案:使用transforms.ToPILImage()替代cv2.imread(),或者手动转换通道顺序:

image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

坑二:模型输出值全是NaN

有一段时间,我们的模型输出全是NaN,loss也不下降。排查了半天才发现是归一化参数搞错了!

PyTorch推荐使用:

transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

这是针对ImageNet数据集的统计值,如果你的数据差异比较大,会导致输入数值超出标准范围,出现梯度爆炸。

解决方案:如果是自己的数据集,最好自己计算mean和std:

def calculate_mean_std(loader):
    mean = 0.
    std = 0.
    total_images_count = 0
    for images, _ in loader:
        batch_samples = images.size(0)
        images = images.view(batch_samples, images.size(1), -1)
        mean += images.mean(2).sum(0)
        std += images.std(2).sum(0)
        total_images_count += batch_samples

    mean /= total_images_count
    std /= total_images_count
    return mean, std

坑三:多GPU训练的陷阱

当我们尝试在多个GPU上训练模型时,出现了显存暴涨的问题。

刚开始我们用了:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model = nn.DataParallel(model)

但结果发现,DataParallel的性能反而不如单卡。这是因为:

  • DataParallel会把数据复制到每个GPU上,然后分别前向传播,再合并结果。对于大模型来说,这会占用大量内存。
  • 如果你想真正高效地使用多GPU,应该使用DistributedDataParallel(DDP)。

不过DDP配置更复杂一些,涉及到进程管理。如果你只是想加速训练,又不想花时间研究分布式训练细节,用单个GPU其实更好。


效果总结:准确率稳步提升

我们一共训练了大约15轮(epoch),每一轮平均耗时约12分钟。随着训练的推进,准确率逐渐上升:

Epoch Train Accuracy Validation Accuracy
1 72.3% 68.2%
5 81.4% 76.5%
10 87.6% 83.1%
15 90.8% 85.6%

虽然没有达到预期的90%以上,但相比之前的CV方法,已经是质的飞跃。而且考虑到数据质量和标注偏差的存在,这个结果还算说得过去。

最终我们将模型打包成REST API接口,通过FastAPI对外提供服务,支持图像上传并返回预测结果。整个系统上线后运行稳定,得到了客户的认可。


我的几点经验总结

在这次项目实践中,我积累了不少关于PyTorch的实际使用经验。这里总结几点给正在入门的朋友们:

  1. 先跑起来,再优化
    刚接触PyTorch时,不要一上来就追求完美设计,而是先跑通一个完整的流程,哪怕只是一个简单的分类任务。这样你会更直观地理解各个模块是怎么协作的。

  2. 多看官方文档和教程
    PyTorch官网有很多高质量的示例教程,比如“Deep Learning with PyTorch: A 60 Minute Blitz”,非常适合初学者。建议一边看一遍动手敲。

  3. 善用Jupyter Notebook练手
    在调试模型结构和数据预处理时,Notebook是非常好用的工具。你可以随时修改一部分代码并看到结果,省去了反复启动训练的麻烦。

  4. 学会打印中间变量
    动态图的好处是可以随时print各种tensor的shape、type、device信息,帮助你快速定位问题。

  5. 记录训练过程,画个图看看趋势
    用TensorBoard可视化loss和accuracy曲线,不仅有助于调参,也可以让你直观地看出模型是否过拟合或欠拟合。

  6. 遇到问题多搜索,但不要迷信StackOverflow
    PyTorch更新很快,有时候网上搜到的答案可能是旧版本的方法,执行时会出错。记得看一下PyTorch的release notes,确认你用的是最新推荐做法。


结语:PyTorch是一把好刀

说实话,刚开始学PyTorch的时候我也很头疼。那堆module、tensor、backward、optim这些东西,看得人眼花缭乱。但是当你真的把它用在一个实际项目上,亲手调出第一个能收敛的模型,那种成就感真的是无可替代。

PyTorch就像一把功能强大但略显复杂的瑞士军刀,只要你愿意花时间熟悉它的每一把刃,它就能帮你解决无数技术难题。

希望这篇从实战出发、带着些许个人成长痕迹的文章,能帮你在学习PyTorch的路上少走弯路,走得更稳、更远。

如果你有任何问题,欢迎留言交流。毕竟,谁不是从一行import torch开始的呢?🚀

评论 0

最热最新
暂无评论
匿名用户Lv.1
0
影响力
0
文章
0
粉丝