PyTorch快速入门:深度学习框架初探
上周五晚上十点半,我正坐在工位上盯着终端里疯狂滚动的训练日志,旁边产品经理小王又发来消息:“模型效果能不能再提两个点?明天演示要用。” 我盯着那条消息看了足足三分钟,差点把MacBook Air(对,就是那台8G内存还舍不得换的古董)摔了。但转念一想,算了,毕竟这个项目从立项开始就没给过完整的需求文档,现在能跑起来已经算烧高香了。
我是你们的老朋友阿哲,一个在某中厂AI组摸爬滚打快两年的算法工程师。日常开发用Mac,Windows只用来装测试环境——主要是怕自己哪天手滑把生产环境搞崩了。最近被领导“建议”深入学习PyTorch,说是公司接下来的重点方向要转向自研模型。其实我心里清楚,无非是之前那个TensorFlow 1.x项目上线后各种兼容性问题,运维老李已经找我喝了三次茶了。
为什么是PyTorch?
说实话,我一开始是抗拒的。毕竟在TF1.x时代,我已经习惯了写tf.Session()和sess.run()那一套繁琐的流程。虽然TF2.x也转向了eager execution,但总觉得API设计有点别扭。而PyTorch呢?动态图、Pythonic、调试友好,简直是为我们这些“调参民工”量身定做的。
更重要的是,社区生态真的香。你看HuggingFace Transformers、Detectron2、FastAI,哪个不是PyTorch系的?连我们组最近要做的那个电商商品识别项目,开源方案清一色都是PyTorch实现。产品经理说要做“类似淘宝拍立淘”的功能,结果我搜了一圈,发现人家官方demo就是用PyTorch写的。
不过说到前端……咳咳,我知道你们看到标题里的“前端”可能会疑惑:一个算法工程师为啥要提前端?其实这里有个小故事。我们组去年双11期间上线了一个智能推荐模块,后端用Flask搭了个简单的API,前端同事直接把推理结果渲染到页面上。结果测试时发现,前端传过来的图片格式五花八门——有的带EXIF旋转信息,有的是WebP格式,甚至还有CMYK色彩空间的!最后还是我这个“算法狗”去帮前端写了图像预处理逻辑。所以说,在实际项目中,算法和前端的边界早就模糊了,你不了解点前端知识,连数据都接不住。
环境搭建:Mac用户的痛
先说个扎心的事实:很多PyTorch教程默认你用Linux或者Windows,但对我们Mac用户来说,GPU加速基本是奢望(除非你用M1/M2芯片的新机型)。我这台2019款MacBook Pro,只能靠CPU硬扛。不过好消息是,PyTorch对Apple Silicon的支持越来越好,如果你用的是M系列芯片,记得安装torch的arm64版本:
# 对于Intel Mac
pip3 install torch torchvision torchaudio
# 对于Apple Silicon Mac (M1/M2)
pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
我试过用Docker,但Mac上的Docker Desktop性能实在感人,尤其是挂载volume的时候。后来干脆直接在本地虚拟环境里跑,反而更流畅。当然,正式训练肯定得扔到公司GPU服务器上——感谢运维老李,至少给我们配了4张V100。
动手实战:从零开始训练一个图像分类器
咱们不玩MNIST那种玩具数据集了,直接上CIFAR-10。虽然只有32x32的小图,但好歹是个彩色图像,而且类别也够多(10类)。最重要的是,加载方便:
import torch
import torchvision
import torchvision.transforms as transforms
# 数据预处理:标准化 + 数据增强
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
shuffle=False, num_workers=2)
这里有几个坑我得提醒:
num_workers在Mac上别设太大,否则会报BrokenPipeError。我一般设2,再多就卡死。- 标准化参数
(0.4914, 0.4822, 0.4465)这些是从CIFAR-10训练集统计出来的均值和标准差,别瞎改。 - 如果你用M1芯片,记得把
num_workers设为0,不然会触发一个已知bug。
接下来是模型定义。为了省事,我直接用ResNet18,但去掉最后的全连接层,换成适合10分类的:
import torch.nn as nn
import torchvision.models as models
class CIFAR10ResNet(nn.Module):
def __init__(self, num_classes=10):
super(CIFAR10ResNet, self).__init__()
# 加载预训练的ResNet18
self.resnet = models.resnet18(pretrained=True)
# 替换第一层卷积,因为CIFAR-10是32x32,而ImageNet是224x224
self.resnet.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.resnet.maxpool = nn.Identity() # 去掉maxpool,防止特征图太小
# 替换最后的全连接层
self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)
def forward(self, x):
return self.resnet(x)
model = CIFAR10ResNet()
吐槽时间:你知道吗?我第一次跑这段代码时,忘了改
conv1,结果训练了半小时才发现准确率卡在10%(随机猜测水平)。当时真的想砸电脑,还好有Git,回滚了事。
训练循环是PyTorch最优雅的部分之一——完全像写普通Python代码:
import torch.optim as optim
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
for epoch in range(10): # 训练10个epoch
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data[0].to(device), data[1].to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 200 == 199: # 每200个batch打印一次
print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 200:.3f}')
running_loss = 0.0
print('Finished Training')
注意这里的optimizer.zero_grad()——这是新手最容易忘的步骤。我见过太多人问“为什么loss不下降”,结果就是忘了清零梯度。PyTorch的梯度是累加的,不像TF会自动重置。
调参与避坑指南
说到调参,这可是我们“炼丹师”的看家本领。上面的代码里,我故意用了比较保守的参数(比如lr=0.01),实际项目中肯定要调整。以下是我踩过的几个大坑:
| 参数 | 初始值 | 调整后 | 效果 |
|---|---|---|---|
| learning_rate | 0.01 | 0.1 | 训练速度提升3倍,但需要配合warmup |
| batch_size | 128 | 256 | 显存占用增加,但收敛更稳定 |
| optimizer | SGD | AdamW | 少了手动调lr的烦恼,但最终精度略低 |
另外,一定要用验证集监控过拟合!我之前有个项目,训练集准确率99%,验证集才70%,结果上线后被用户骂惨了。现在我的习惯是在每个epoch结束后跑一遍验证:
# 验证代码片段
model.eval()
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data[0].to(device), data[1].to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Accuracy of the network on the 10000 test images: {100 * correct / total}%')
注意model.eval()和torch.no_grad()这两个关键点。前者会关闭BatchNorm和Dropout,后者会禁用梯度计算,节省显存和计算资源。
和前端对接的那些事
回到开头提到的“前端”问题。假设你现在要把训练好的模型部署成一个Web服务,供前端调用。最简单的方式是用Flask:
from flask import Flask, request, jsonify
import torch
from PIL import Image
import io
app = Flask(__name__)
model = torch.load('cifar10_resnet.pth')
model.eval()
@app.route('/predict', methods=['POST'])
def predict():
if 'file' not in request.files:
return jsonify({'error': 'no file'}), 400
file = request.files['file']
# 处理前端上传的图片
img_bytes = file.read()
img = Image.open(io.BytesIO(img_bytes)).convert('RGB')
# 应用和训练时相同的预处理
transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
img_tensor = transform(img).unsqueeze(0) # 增加batch维度
with torch.no_grad():
output = model(img_tensor)
_, predicted = torch.max(output, 1)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
return jsonify({'prediction': classes[predicted.item()]})
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000)
这里的关键是:预处理必须和训练时完全一致!我之前就吃过亏,训练时用了RandomHorizontalFlip,但推理时忘了去掉,导致同一个图片多次预测结果不一样。前端同事差点以为我们的API有bug。
另外,前端传图经常会有尺寸问题。比如手机拍的照片可能是4000x3000,直接resize到32x32会失真。更好的做法是先等比例缩放到短边32,再中心裁剪——但这需要和前端约定好,或者在服务端做兼容处理。
性能优化:从小时级到分钟级
最后说说性能。我第一次在Mac上跑CIFAR-10训练,10个epoch花了将近3小时。扔到公司GPU服务器上,同样代码只要8分钟。但即使这样,我们还是要优化:
- 混合精度训练:用
torch.cuda.amp,显存占用减半,速度提升30% - DataLoader优化:增加
pin_memory=True,加快CPU到GPU的数据传输 - 模型量化:推理时用
torch.quantization,体积缩小4倍,速度提升2倍
特别是第三点,对前端特别友好——模型小了,JS加载更快(如果用ONNX.js的话)。虽然PyTorch原生不支持直接输出JS可用的格式,但可以先导出ONNX,再转成Web-friendly格式:
# 导出ONNX
dummy_input = torch.randn(1, 3, 32, 32)
torch.onnx.export(model, dummy_input, "cifar10.onnx",
export_params=True, opset_version=11)
写在最后
从被逼着学PyTorch,到现在的真香现场,我最大的感受是:框架只是工具,核心还是算法思维。PyTorch之所以流行,是因为它让研究者能快速验证想法,而不是被工程细节绊住手脚。
当然,现实项目远比CIFAR-10复杂。我们组最近在做的商品识别,光数据清洗就花了两周——前端上传的图片里有水印、有遮挡、有反光,甚至还有一张是纯黑的(用户手抖拍的)。但有了PyTorch这套灵活的工具链,至少让我们能把精力集中在解决业务问题上,而不是和框架斗智斗勇。
哦对了,上周五那个演示……最后我用了一个小技巧:在损失函数里加了类别权重,重点提升“电子产品”类的准确率(因为产品经理说这次演示主要给数码部门看)。果然,演示时效果看起来棒极了。至于其他类别?等他们提需求再说吧,反正deadline已过,我又可以安心摸鱼学新东西了。
下次想看什么?PyTorch Lightning?还是TorchScript部署实战?留言区告诉我,反正我周末加班也闲着(不是)。

评论 0