PyTorch快速入门:一个被算法同事“卷”进深度学习的后端老油条

慢慢写代码
2025-12-16 05:49
阅读 694

上周五晚上十一点,我瘫在深圳南山科技园某栋写字楼的工位上,盯着屏幕上一行又一行红得发紫的报错信息,心里只有一句话:“早知道就不答应帮算法组跑这个模型了……”

但事情还得从头说起。

我是滴滴干了四年的后端开发,主要搞司机端的核心业务——订单匹配、接单策略、位置上报这些。日常就是写 Java、调 RPC、怼 Redis,偶尔和产品经理battle一下“这个需求真的不能下周上线吗?”。用 Mac 写代码,Windows 只在测试兼容性时才勉强打开(说实话,那蓝屏频率比我司上线事故还高)。

去年双11前,公司搞了个“智能调度2.0”项目,说要用强化学习优化司机派单效率。我们后端团队本来以为就是改改接口、加个开关,结果算法组甩过来一句:“这次模型要实时推理,你们后端得支持 PyTorch 模型部署。”

我:???

我当时连 TensorFlow 和 PyTorch 有啥区别都说不清,只知道这俩名字经常出现在招聘要求里,后面还跟着“熟悉者优先”——懂的都懂,那就是“必须会”。

为了不被卷成麻花,我硬着头皮开始学 PyTorch。这篇笔记,就是我从“Hello World”到成功在测试环境跑通第一个 LSTM 预测模型的血泪史。如果你和我一样是个只会 CRUD 的后端,想快速上手深度学习框架,那这篇文章可能能帮你少踩几个坑。


为啥是 PyTorch?而不是 TensorFlow?

说实话,一开始我也纠结过。毕竟 TensorFlow 背靠 Google,社区大、文档全,连我司某些老系统还在用 TF Serving。但跟算法同事聊了几次后,发现他们清一色用 PyTorch。理由很实在:

  • Python 原生体验好:PyTorch 几乎就是 NumPy + 自动微分 + GPU 加速,写起来像普通 Python,调试也方便。
  • 动态图机制:不像 TF 1.x 那种先建图再执行的“声明式”风格,PyTorch 是“命令式”的,print() 随便打,断点随便设。
  • 学术界主流:GitHub 上新论文的代码 90% 以上都是 PyTorch 实现,复现起来快。

而且,我们团队有个不成文的规定:能用 Python 解决的问题,绝不碰 Java 写模型(虽然最终部署还是得转 ONNX 或 TorchServe)。所以,PyTorch 成了唯一选择。


环境搭建:别被 CUDA 劝退

作为后端,我对 Python 的印象还停留在 Flask + requests + pandas 的组合。结果一装 PyTorch,直接给我整不会了:

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

注意那个 cu121 —— 这代表 CUDA 12.1。如果你像我一样,在 Mac 上开发(M1/M2 芯片),恭喜你,不用操心 CUDA,直接:

pip install torch torchvision torchaudio

但!一旦你要在 Linux 服务器上跑训练(比如公司内网的 A10 机器),就得确认 CUDA 版本是否匹配。我第一次在测试机上跑,因为驱动版本太低,直接报:

RuntimeError: Found no NVIDIA driver on your system.

当时真的想砸键盘。后来才知道,运维那边给的 Docker 镜像已经预装了正确的 PyTorch + CUDA,直接 pull 就行。血泪教训:别自己瞎装,问清楚团队的标准环境。


第一个模型:不是 MNIST,而是司机上线预测

网上教程清一色用 MNIST 手写数字识别当“Hello World”。但对我们这种业务导向的工程师来说,脱离业务的 demo 都是耍流氓

于是我和算法同事对齐了一个小目标:预测某个司机在未来 30 分钟是否会上线接单。输入是他过去 7 天的上线/下线时间序列,输出是 0 或 1。

数据长这样(简化版):

driver_id timestamp is_online
d1001 2024-05-01 08:00:00 1
d1001 2024-05-01 12:30:00 0
... ... ...

我们用滑动窗口生成样本,每个样本是 144 个时间点(每 5 分钟一个点,共 12 小时),标签是接下来 6 个点是否有任意一个为 1。

模型选了最简单的 LSTM —— 别笑,对我们这种刚入门的来说,CNN、Transformer 还是太重了。LSTM 足够轻量,又能捕捉时序依赖。

代码结构如下:

import torch
import torch.nn as nn

class DriverOnlineLSTM(nn.Module):
    def __init__(self, input_size=1, hidden_size=64, num_layers=2, output_size=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x shape: (batch, seq_len, 1)
        out, _ = self.lstm(x)
        # 取最后一个时间步的输出
        out = self.fc(out[:, -1, :])
        return self.sigmoid(out)

训练循环也很直白:

model = DriverOnlineLSTM()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(50):
    for batch_x, batch_y in dataloader:
        optimizer.zero_grad()
        pred = model(batch_x)
        loss = criterion(pred, batch_y)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

关键点来了:你以为这就完了?No!

第一次跑,loss 降不下去,一直卡在 0.69(也就是 ln2,说明模型在瞎猜)。我差点以为自己不适合搞 AI。

后来发现两个问题:

  1. 数据没归一化:时间戳直接转成 Unix 时间戳喂进去,数值太大,梯度爆炸。
  2. 正负样本极度不均衡:司机大部分时间是离线的,正样本(上线)不到 10%。

解决办法:

  • 把时间戳转成“距离当天 0 点的分钟数”,再除以 1440 归一到 [0,1]
  • WeightedRandomSampler 重采样,或者直接在 BCELoss 里加 pos_weight
pos_weight = torch.tensor([9.0])  # 负样本是正样本的9倍
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

改完之后,loss 哗哗往下掉,AUC 从 0.55 一路冲到 0.82。那一刻,我仿佛看到了年终奖在向我招手。


开发心得:后端视角下的 PyTorch 陷阱

作为一个习惯了 Spring Boot 的后端,我在 PyTorch 里踩了不少“非传统”坑:

1. device 不统一,GPU 白搭

PyTorch 默认在 CPU 上跑。你想用 GPU?得手动 .to(device)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
batch_x = batch_x.to(device)
batch_y = batch_y.to(device)

有次我忘了把 label tensor 移到 GPU,结果报错:

Expected all tensors to be on the same device, but found at least two devices...

这种错误在本地 Mac 上根本不会出现(因为没 GPU),一上测试机就炸。建议:封装一个 to_device 工具函数,所有 tensor 统一处理。

2. 模型保存别只存 weights

新手常犯的错误:

torch.save(model.state_dict(), "model.pth")

这只能保存参数,加载时还得重新定义模型结构。更好的做法是:

torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch,
    'loss': loss,
}, "checkpoint.pth")

或者直接保存整个模型(虽然不推荐,但调试快):

torch.save(model, "full_model.pth")

3. 别在训练时开 eval 模式

有次我为了看中间结果,不小心在训练循环里写了 model.eval(),导致 BatchNorm 和 Dropout 行为异常,loss 直接起飞。查了两小时才发现是模式错了。

记住:

  • 训练:model.train()
  • 推理/验证:model.eval()

综合:如何把模型塞进我们的后端服务?

光会训练不够,得上线。我们后端的要求很简单:

  • 低延迟:单次推理 < 50ms
  • 高并发:QPS > 1000
  • 无缝集成:最好能用 gRPC 调

方案对比:

方案 优点 缺点
直接嵌入 Python 服务 快,调试方便 性能差,难扩缩容
TorchServe 官方方案,支持多模型 配置复杂,监控弱
转 ONNX + Triton 高性能,跨语言 转换可能失败,需验证精度

我们最终选了 TorchServe,因为团队已经有 Python 微服务基础,而且它支持动态批处理(dynamic batching),能扛住早高峰流量。

部署流程:

  1. 导出 TorchScript 模型:
    traced_model = torch.jit.trace(model, example_input)
    torch.jit.save(traced_model, "driver_online.pt")
    
  2. 写 handler.py 处理请求/响应
  3. 启动 TorchServe:
    torchserve --start --model-store model_store --models driver_online=driver_online.pt
    

上线后压测,P99 延迟 32ms,完美达标。产品经理终于没再问“能不能明天上线”。


最后:AI 不是魔法,是工程

学 PyTorch 这两个月,最大的感悟是:深度学习不是黑箱,而是一堆需要精心调校的工程组件

数据清洗、特征工程、模型选择、超参调优、部署监控——每一步都能让你怀疑人生。但当你看到 A/B 测试结果显示派单成功率提升 2.3%,而 PM 在群里@你说“感谢后端兄弟给力支持”时,那种成就感,比修完一个 P0 线上 bug 还爽。

如果你也是后端,别被“算法”两个字吓住。PyTorch 对 Python 开发者极其友好,它不是要你变成数学家,而是让你多一把解决问题的锤子

至于我?现在已经开始看 Transformer 了。毕竟,听说下个项目要用 LLM 做司机意图识别……

(完)

注:本文所有代码均在 Mac M2 + Python 3.10 + PyTorch 2.2 环境下验证通过。线上环境为 Ubuntu 20.04 + CUDA 12.1 + A10 GPU。
警告:别在周五晚上十点答应算法同事帮忙跑模型,除非你想加班到凌晨三点。

评论 0

最热最新
暂无评论
匿名用户Lv.1
0
影响力
0
文章
0
粉丝