PyTorch 入门没那么难,摸鱼也能学会
上周五晚上,我正瘫在工位上刷知乎“程序员如何优雅地躺平”,突然收到运营同学的消息:“咱们能不能搞个智能推荐?用户最近老说首页内容太杂。”
我第一反应是——又来了。每次大促前都是这样,产品画完饼,运营背指标,最后锅全甩给后端和算法。
不过这次有点不一样。我们后端主栈是 Spring Boot,前端也还行,但算法团队去年裁员裁得只剩一个人,现在连个 baseline 都跑不起来。领导看我在 GitHub 上 fork 过几个 AI 项目(其实只是收藏夹吃灰),直接点名:“你不是参加过几次技术分享会嘛?试试用 PyTorch 搞个 demo 出来。”
行吧,反正最近在摸鱼准备跳槽,学点新东西简历也好看。于是这个周末,我没打游戏,没追剧,硬着头皮翻开了《深度学习入门:从零开始写神经网络》——哦不对,是官方 PyTorch 教程。
起手式:别被环境配置劝退
说实话,光是装 PyTorch 就差点让我放弃。conda、pip、CUDA 版本、cuDNN……我电脑还是公司发的那台三年前的 ThinkPad,显卡连 GPU 加速都不支持。还好 PyTorch 支持 CPU 模式,不然真得掏钱买新电脑了(老板肯定不批)。
# 我的安装命令,简单粗暴
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
装完跑个 hello world:
import torch
x = torch.rand(5, 3)
print(x)
输出一堆随机数,但至少没报错。那一刻,我仿佛看到了升职加薪(误)的曙光。
业务场景:做个简单的用户点击预测
我们的需求其实不复杂:给定用户的历史行为(比如点击过哪些文章),预测他会不会点击某篇新文章。这本质上是个二分类问题——点 or 不点。
数据从哪儿来?运营给了一个 CSV,包含用户 ID、文章 ID、是否点击,时间戳等字段。数据量不大,也就 10 万条,但脏得离谱:有重复记录、空值、甚至用户 ID 是字符串“test_user_001”。
吐槽一句:运营同学的数据清洗能力,大概和我的做饭水平差不多——能吃,但别细看。
我先用 pandas 做了基础清洗:
import pandas as pd
df = pd.read_csv("user_clicks.csv")
df = df.drop_duplicates()
df = df.dropna()
df["clicked"] = df["clicked"].astype(int) # 确保标签是 0/1
然后把用户 ID 和文章 ID 转成 embedding 的输入索引。这里用了 sklearn 的 LabelEncoder,虽然有点 overkill,但胜在简单:
from sklearn.preprocessing import LabelEncoding
user_enc = LabelEncoder()
item_enc = LabelEncoder()
df["user_idx"] = user_enc.fit_transform(df["user_id"])
df["item_idx"] = item_enc.fit_transform(df["item_id"])
n_users = len(user_enc.classes_)
n_items = len(item_enc.classes_)
模型设计:别一上来就搞 Transformer
很多新手(包括我)一接触深度学习就想上 BERT、GNN,结果连反向传播都调不明白。这次我决定佛系一点:先搞个 Wide & Deep 的简化版——其实就是两个 embedding 层拼接 + 全连接。
为啥选这个?因为:
- 可解释性还行(至少能说清楚“用户偏好”和“物品特征”)
- 训练快(我的破电脑扛得住)
- 代码短(摸鱼时间有限)
import torch.nn as nn
class ClickPredictor(nn.Module):
def __init__(self, n_users, n_items, embed_dim=32):
super().__init__()
self.user_embed = nn.Embedding(n_users, embed_dim)
self.item_embed = nn.Embedding(n_items, embed_dim)
self.fc = nn.Sequential(
nn.Linear(embed_dim * 2, 64),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(64, 1),
nn.Sigmoid() # 输出 0~1 的概率
)
def forward(self, user_ids, item_ids):
user_emb = self.user_embed(user_ids)
item_emb = self.item_embed(item_ids)
x = torch.cat([user_emb, item_emb], dim=1)
return self.fc(x).squeeze()
注意最后用了 Sigmoid,因为我们要的是点击概率,不是 logits。这点我一开始忘了,loss 直接爆炸,训练一小时 loss 还是 0.69(对,就是 ln2,说明模型在瞎猜)。
数据加载:别手写 for 循环
PyTorch 的 Dataset 和 DataLoader 是真的香。以前我写模型都是自己切 batch,结果经常内存溢出或者 shuffle 不均匀。这次直接用官方工具:
from torch.utils.data import Dataset, DataLoader
class ClickDataset(Dataset):
def __init__(self, df):
self.users = torch.tensor(df["user_idx"].values, dtype=torch.long)
self.items = torch.tensor(df["item_idx"].values, dtype=torch.long)
self.labels = torch.tensor(df["clicked"].values, dtype=torch.float32)
def __len__(self):
return len(self.users)
def __getitem__(self, idx):
return self.users[idx], self.items[idx], self.labels[idx]
# 划分训练集和验证集
train_df = df.sample(frac=0.8)
val_df = df.drop(train_df.index)
train_loader = DataLoader(ClickDataset(train_df), batch_size=256, shuffle=True)
val_loader = DataLoader(ClickDataset(val_df), batch_size=256, shuffle=False)
shuffle=True 在训练时很重要,否则模型会“记住”数据顺序,泛化能力直接崩盘。
训练循环:loss 下不去?先看 learning rate
训练代码其实就那几行,但坑不少。我第一次用 Adam 优化器,lr=0.01,结果 loss 一路飙升。后来查文档才发现,embedding 层通常需要更小的学习率。
model = ClickPredictor(n_users, n_items)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCELoss() # 二分类交叉熵
for epoch in range(10):
model.train()
total_loss = 0
for user_ids, item_ids, labels in train_loader:
optimizer.zero_grad()
preds = model(user_ids, item_ids)
loss = criterion(preds, labels)
loss.backward()
optimizer.step()
total_loss += loss.item()
# 验证
model.eval()
val_acc = 0
with torch.no_grad():
for user_ids, item_ids, labels in val_loader:
preds = model(user_ids, item_ids)
val_acc += ((preds > 0.5) == labels).float().mean().item()
print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}, Val Acc: {val_acc/len(val_loader):.4f}")
关键点:
optimizer.zero_grad()别漏掉,否则梯度会累积(我第一次忘加,loss 直接 NaN)- 验证时记得
model.eval()和torch.no_grad(),省显存还防 bug - 用
BCELoss而不是CrossEntropyLoss,因为输出已经是概率了
跑了 10 个 epoch,验证准确率到了 82%。虽然比不上线上那些 fancy 模型,但比随机猜测(50%)强多了。运营看了 demo 说“可以先上线 A/B 测试”,我长舒一口气——终于不用改需求文档了!
和 Spring Boot 对接:API 怎么搞?
模型训好了,怎么给后端用?总不能让 Java 程序员直接调 Python 脚本吧(运维会疯)。
我用了两种方案对比:
| 方案 | 优点 | 缺点 |
|---|---|---|
| Flask 封装 API | 快速、简单 | 需要额外部署 Python 服务 |
| ONNX 导出 + Java 推理 | 无 Python 依赖 | 调试麻烦,版本兼容问题多 |
考虑到团队全是 Java 技术栈,而且 Spring Boot 项目已经有一套完善的监控和日志体系,我最终选择了 Flask 封装。虽然多一个服务,但开发效率高,还能复用现有的 CI/CD 流程。
Flask 代码就十几行:
from flask import Flask, request, jsonify
import torch
app = Flask(__name__)
model = torch.load("click_model.pth")
model.eval()
@app.route("/predict", methods=["POST"])
def predict():
data = request.json
user_id = data["user_id"]
item_id = data["item_id"]
# 这里要处理 ID 映射(略)
user_idx = user_enc.transform([user_id])[0]
item_idx = item_enc.transform([item_id])[0]
with torch.no_grad():
prob = model(
torch.tensor([user_idx]),
torch.tensor([item_idx])
).item()
return jsonify({"click_prob": prob})
Spring Boot 那边就当调第三方 API:
// 伪代码
RestTemplate restTemplate = new RestTemplate();
Map<String, Object> payload = Map.of("user_id", userId, "item_id", itemId);
ResponseEntity<Prediction> response = restTemplate.postForEntity(
"http://pytorch-service:8000/predict",
payload,
Prediction.class
);
虽然有点“微服务过度设计”的嫌疑,但好歹能跑。运维老大看了架构图直摇头:“又多一个服务要监控”,但也没拦着——毕竟双11快到了,能跑就行。
算法选择心得:别追求 SOTA
很多人一上来就想搞 SOTA(State-of-the-Art)模型,结果调参调到秃头,效果还不如逻辑回归。我这次的体会是:
- 业务指标比学术指标更重要:运营关心的是点击率提升多少,不是 AUC 多高
- 可维护性优先:代码越简单,线上出问题越容易 fix
- 数据质量 > 模型复杂度:我花 70% 时间清洗数据,30% 调模型
另外,PyTorch 的动态图机制真的适合快速实验。想改模型结构?加一行代码就行。不像 TensorFlow 1.x,改个 shape 都要重写整个 graph。
最后:躺平程序员的 AI 学习建议
作为一个在公司混了三年、天天想着跳槽的佛系程序员,我对想入门深度学习的同行有几点建议:
- 别等“准备好”再开始:我就是在 deadline 逼迫下学的,反而效率最高
- 从小任务入手:别一上来就搞 CV/NLP 大模型,先解决手头的小问题
- 重视工程落地:模型再准,不能集成到现有系统也是白搭
- 多参加技术分享会:我上次听了个 PyTorch Lightning 的分享,省了我三天 debug 时间
现在,这个点击预测模型已经在线上跑了一周,A/B 测试显示点击率提升了 5%。运营请我喝了杯瑞幸,领导说“可以考虑给你加点算法绩效”——虽然可能只是客套话。
但无所谓了。至少我学会了 PyTorch,简历上又能多写一行“熟悉深度学习框架”。下次面试,终于不用只聊 Spring Boot 的循环依赖问题了。
摸鱼学习,真香。

评论 0