PyTorch快速入门:深度学习框架初探

张红_程序员
2025-12-13 09:48
阅读 736

上周五晚上十点半,我正坐在工位上盯着终端里疯狂滚动的训练日志,旁边产品经理小王又发来消息:“模型效果能不能再提两个点?明天演示要用。” 我盯着那条消息看了足足三分钟,差点把MacBook Air(对,就是那台8G内存还舍不得换的古董)摔了。但转念一想,算了,毕竟这个项目从立项开始就没给过完整的需求文档,现在能跑起来已经算烧高香了。

我是你们的老朋友阿哲,一个在某中厂AI组摸爬滚打快两年的算法工程师。日常开发用Mac,Windows只用来装测试环境——主要是怕自己哪天手滑把生产环境搞崩了。最近被领导“建议”深入学习PyTorch,说是公司接下来的重点方向要转向自研模型。其实我心里清楚,无非是之前那个TensorFlow 1.x项目上线后各种兼容性问题,运维老李已经找我喝了三次茶了。

为什么是PyTorch?

说实话,我一开始是抗拒的。毕竟在TF1.x时代,我已经习惯了写tf.Session()sess.run()那一套繁琐的流程。虽然TF2.x也转向了eager execution,但总觉得API设计有点别扭。而PyTorch呢?动态图、Pythonic、调试友好,简直是为我们这些“调参民工”量身定做的。

更重要的是,社区生态真的香。你看HuggingFace Transformers、Detectron2、FastAI,哪个不是PyTorch系的?连我们组最近要做的那个电商商品识别项目,开源方案清一色都是PyTorch实现。产品经理说要做“类似淘宝拍立淘”的功能,结果我搜了一圈,发现人家官方demo就是用PyTorch写的。

不过说到前端……咳咳,我知道你们看到标题里的“前端”可能会疑惑:一个算法工程师为啥要提前端?其实这里有个小故事。我们组去年双11期间上线了一个智能推荐模块,后端用Flask搭了个简单的API,前端同事直接把推理结果渲染到页面上。结果测试时发现,前端传过来的图片格式五花八门——有的带EXIF旋转信息,有的是WebP格式,甚至还有CMYK色彩空间的!最后还是我这个“算法狗”去帮前端写了图像预处理逻辑。所以说,在实际项目中,算法和前端的边界早就模糊了,你不了解点前端知识,连数据都接不住。

环境搭建:Mac用户的痛

先说个扎心的事实:很多PyTorch教程默认你用Linux或者Windows,但对我们Mac用户来说,GPU加速基本是奢望(除非你用M1/M2芯片的新机型)。我这台2019款MacBook Pro,只能靠CPU硬扛。不过好消息是,PyTorch对Apple Silicon的支持越来越好,如果你用的是M系列芯片,记得安装torch的arm64版本:

# 对于Intel Mac
pip3 install torch torchvision torchaudio

# 对于Apple Silicon Mac (M1/M2)
pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu

我试过用Docker,但Mac上的Docker Desktop性能实在感人,尤其是挂载volume的时候。后来干脆直接在本地虚拟环境里跑,反而更流畅。当然,正式训练肯定得扔到公司GPU服务器上——感谢运维老李,至少给我们配了4张V100。

动手实战:从零开始训练一个图像分类器

咱们不玩MNIST那种玩具数据集了,直接上CIFAR-10。虽然只有32x32的小图,但好歹是个彩色图像,而且类别也够多(10类)。最重要的是,加载方便

import torch
import torchvision
import torchvision.transforms as transforms

# 数据预处理:标准化 + 数据增强
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                         shuffle=False, num_workers=2)

这里有几个坑我得提醒:

  • num_workers在Mac上别设太大,否则会报BrokenPipeError。我一般设2,再多就卡死。
  • 标准化参数(0.4914, 0.4822, 0.4465)这些是从CIFAR-10训练集统计出来的均值和标准差,别瞎改。
  • 如果你用M1芯片,记得把num_workers设为0,不然会触发一个已知bug。

接下来是模型定义。为了省事,我直接用ResNet18,但去掉最后的全连接层,换成适合10分类的:

import torch.nn as nn
import torchvision.models as models

class CIFAR10ResNet(nn.Module):
    def __init__(self, num_classes=10):
        super(CIFAR10ResNet, self).__init__()
        # 加载预训练的ResNet18
        self.resnet = models.resnet18(pretrained=True)
        # 替换第一层卷积,因为CIFAR-10是32x32,而ImageNet是224x224
        self.resnet.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.resnet.maxpool = nn.Identity()  # 去掉maxpool,防止特征图太小
        # 替换最后的全连接层
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)
    
    def forward(self, x):
        return self.resnet(x)

model = CIFAR10ResNet()

吐槽时间:你知道吗?我第一次跑这段代码时,忘了改conv1,结果训练了半小时才发现准确率卡在10%(随机猜测水平)。当时真的想砸电脑,还好有Git,回滚了事。

训练循环是PyTorch最优雅的部分之一——完全像写普通Python代码:

import torch.optim as optim

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

for epoch in range(10):  # 训练10个epoch
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if i % 200 == 199:    # 每200个batch打印一次
            print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 200:.3f}')
            running_loss = 0.0

print('Finished Training')

注意这里的optimizer.zero_grad()——这是新手最容易忘的步骤。我见过太多人问“为什么loss不下降”,结果就是忘了清零梯度。PyTorch的梯度是累加的,不像TF会自动重置。

调参与避坑指南

说到调参,这可是我们“炼丹师”的看家本领。上面的代码里,我故意用了比较保守的参数(比如lr=0.01),实际项目中肯定要调整。以下是我踩过的几个大坑:

参数 初始值 调整后 效果
learning_rate 0.01 0.1 训练速度提升3倍,但需要配合warmup
batch_size 128 256 显存占用增加,但收敛更稳定
optimizer SGD AdamW 少了手动调lr的烦恼,但最终精度略低

另外,一定要用验证集监控过拟合!我之前有个项目,训练集准确率99%,验证集才70%,结果上线后被用户骂惨了。现在我的习惯是在每个epoch结束后跑一遍验证:

# 验证代码片段
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total}%')

注意model.eval()torch.no_grad()这两个关键点。前者会关闭BatchNorm和Dropout,后者会禁用梯度计算,节省显存和计算资源。

和前端对接的那些事

回到开头提到的“前端”问题。假设你现在要把训练好的模型部署成一个Web服务,供前端调用。最简单的方式是用Flask:

from flask import Flask, request, jsonify
import torch
from PIL import Image
import io

app = Flask(__name__)
model = torch.load('cifar10_resnet.pth')
model.eval()

@app.route('/predict', methods=['POST'])
def predict():
    if 'file' not in request.files:
        return jsonify({'error': 'no file'}), 400
    
    file = request.files['file']
    # 处理前端上传的图片
    img_bytes = file.read()
    img = Image.open(io.BytesIO(img_bytes)).convert('RGB')
    
    # 应用和训练时相同的预处理
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    img_tensor = transform(img).unsqueeze(0)  # 增加batch维度
    
    with torch.no_grad():
        output = model(img_tensor)
        _, predicted = torch.max(output, 1)
    
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    return jsonify({'prediction': classes[predicted.item()]})

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)

这里的关键是:预处理必须和训练时完全一致!我之前就吃过亏,训练时用了RandomHorizontalFlip,但推理时忘了去掉,导致同一个图片多次预测结果不一样。前端同事差点以为我们的API有bug。

另外,前端传图经常会有尺寸问题。比如手机拍的照片可能是4000x3000,直接resize到32x32会失真。更好的做法是先等比例缩放到短边32,再中心裁剪——但这需要和前端约定好,或者在服务端做兼容处理。

性能优化:从小时级到分钟级

最后说说性能。我第一次在Mac上跑CIFAR-10训练,10个epoch花了将近3小时。扔到公司GPU服务器上,同样代码只要8分钟。但即使这样,我们还是要优化:

  1. 混合精度训练:用torch.cuda.amp,显存占用减半,速度提升30%
  2. DataLoader优化:增加pin_memory=True,加快CPU到GPU的数据传输
  3. 模型量化:推理时用torch.quantization,体积缩小4倍,速度提升2倍

特别是第三点,对前端特别友好——模型小了,JS加载更快(如果用ONNX.js的话)。虽然PyTorch原生不支持直接输出JS可用的格式,但可以先导出ONNX,再转成Web-friendly格式:

# 导出ONNX
dummy_input = torch.randn(1, 3, 32, 32)
torch.onnx.export(model, dummy_input, "cifar10.onnx", 
                  export_params=True, opset_version=11)

写在最后

从被逼着学PyTorch,到现在的真香现场,我最大的感受是:框架只是工具,核心还是算法思维。PyTorch之所以流行,是因为它让研究者能快速验证想法,而不是被工程细节绊住手脚。

当然,现实项目远比CIFAR-10复杂。我们组最近在做的商品识别,光数据清洗就花了两周——前端上传的图片里有水印、有遮挡、有反光,甚至还有一张是纯黑的(用户手抖拍的)。但有了PyTorch这套灵活的工具链,至少让我们能把精力集中在解决业务问题上,而不是和框架斗智斗勇。

哦对了,上周五那个演示……最后我用了一个小技巧:在损失函数里加了类别权重,重点提升“电子产品”类的准确率(因为产品经理说这次演示主要给数码部门看)。果然,演示时效果看起来棒极了。至于其他类别?等他们提需求再说吧,反正deadline已过,我又可以安心摸鱼学新东西了。

下次想看什么?PyTorch Lightning?还是TorchScript部署实战?留言区告诉我,反正我周末加班也闲着(不是)。

评论 0

最热最新
暂无评论
匿名用户Lv.1
0
影响力
0
文章
0
粉丝