PyTorch 入门没那么难,摸鱼也能学会

App数据
2025-12-22 06:11
阅读 870

上周五晚上,我正瘫在工位上刷知乎“程序员如何优雅地躺平”,突然收到运营同学的消息:“咱们能不能搞个智能推荐?用户最近老说首页内容太杂。”
我第一反应是——又来了。每次大促前都是这样,产品画完饼,运营背指标,最后锅全甩给后端和算法。

不过这次有点不一样。我们后端主栈是 Spring Boot,前端也还行,但算法团队去年裁员裁得只剩一个人,现在连个 baseline 都跑不起来。领导看我在 GitHub 上 fork 过几个 AI 项目(其实只是收藏夹吃灰),直接点名:“你不是参加过几次技术分享会嘛?试试用 PyTorch 搞个 demo 出来。”

行吧,反正最近在摸鱼准备跳槽,学点新东西简历也好看。于是这个周末,我没打游戏,没追剧,硬着头皮翻开了《深度学习入门:从零开始写神经网络》——哦不对,是官方 PyTorch 教程。


起手式:别被环境配置劝退

说实话,光是装 PyTorch 就差点让我放弃。conda、pip、CUDA 版本、cuDNN……我电脑还是公司发的那台三年前的 ThinkPad,显卡连 GPU 加速都不支持。还好 PyTorch 支持 CPU 模式,不然真得掏钱买新电脑了(老板肯定不批)。

# 我的安装命令,简单粗暴
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

装完跑个 hello world:

import torch
x = torch.rand(5, 3)
print(x)

输出一堆随机数,但至少没报错。那一刻,我仿佛看到了升职加薪(误)的曙光。


业务场景:做个简单的用户点击预测

我们的需求其实不复杂:给定用户的历史行为(比如点击过哪些文章),预测他会不会点击某篇新文章。这本质上是个二分类问题——点 or 不点。

数据从哪儿来?运营给了一个 CSV,包含用户 ID、文章 ID、是否点击,时间戳等字段。数据量不大,也就 10 万条,但脏得离谱:有重复记录、空值、甚至用户 ID 是字符串“test_user_001”。

吐槽一句:运营同学的数据清洗能力,大概和我的做饭水平差不多——能吃,但别细看。

我先用 pandas 做了基础清洗:

import pandas as pd

df = pd.read_csv("user_clicks.csv")
df = df.drop_duplicates()
df = df.dropna()
df["clicked"] = df["clicked"].astype(int)  # 确保标签是 0/1

然后把用户 ID 和文章 ID 转成 embedding 的输入索引。这里用了 sklearnLabelEncoder,虽然有点 overkill,但胜在简单:

from sklearn.preprocessing import LabelEncoding

user_enc = LabelEncoder()
item_enc = LabelEncoder()

df["user_idx"] = user_enc.fit_transform(df["user_id"])
df["item_idx"] = item_enc.fit_transform(df["item_id"])

n_users = len(user_enc.classes_)
n_items = len(item_enc.classes_)

模型设计:别一上来就搞 Transformer

很多新手(包括我)一接触深度学习就想上 BERT、GNN,结果连反向传播都调不明白。这次我决定佛系一点:先搞个 Wide & Deep 的简化版——其实就是两个 embedding 层拼接 + 全连接。

为啥选这个?因为:

  • 可解释性还行(至少能说清楚“用户偏好”和“物品特征”)
  • 训练快(我的破电脑扛得住)
  • 代码短(摸鱼时间有限)
import torch.nn as nn

class ClickPredictor(nn.Module):
    def __init__(self, n_users, n_items, embed_dim=32):
        super().__init__()
        self.user_embed = nn.Embedding(n_users, embed_dim)
        self.item_embed = nn.Embedding(n_items, embed_dim)
        self.fc = nn.Sequential(
            nn.Linear(embed_dim * 2, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
            nn.Sigmoid()  # 输出 0~1 的概率
        )

    def forward(self, user_ids, item_ids):
        user_emb = self.user_embed(user_ids)
        item_emb = self.item_embed(item_ids)
        x = torch.cat([user_emb, item_emb], dim=1)
        return self.fc(x).squeeze()

注意最后用了 Sigmoid,因为我们要的是点击概率,不是 logits。这点我一开始忘了,loss 直接爆炸,训练一小时 loss 还是 0.69(对,就是 ln2,说明模型在瞎猜)。


数据加载:别手写 for 循环

PyTorch 的 DatasetDataLoader 是真的香。以前我写模型都是自己切 batch,结果经常内存溢出或者 shuffle 不均匀。这次直接用官方工具:

from torch.utils.data import Dataset, DataLoader

class ClickDataset(Dataset):
    def __init__(self, df):
        self.users = torch.tensor(df["user_idx"].values, dtype=torch.long)
        self.items = torch.tensor(df["item_idx"].values, dtype=torch.long)
        self.labels = torch.tensor(df["clicked"].values, dtype=torch.float32)

    def __len__(self):
        return len(self.users)

    def __getitem__(self, idx):
        return self.users[idx], self.items[idx], self.labels[idx]

# 划分训练集和验证集
train_df = df.sample(frac=0.8)
val_df = df.drop(train_df.index)

train_loader = DataLoader(ClickDataset(train_df), batch_size=256, shuffle=True)
val_loader = DataLoader(ClickDataset(val_df), batch_size=256, shuffle=False)

shuffle=True 在训练时很重要,否则模型会“记住”数据顺序,泛化能力直接崩盘。


训练循环:loss 下不去?先看 learning rate

训练代码其实就那几行,但坑不少。我第一次用 Adam 优化器,lr=0.01,结果 loss 一路飙升。后来查文档才发现,embedding 层通常需要更小的学习率。

model = ClickPredictor(n_users, n_items)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCELoss()  # 二分类交叉熵

for epoch in range(10):
    model.train()
    total_loss = 0
    for user_ids, item_ids, labels in train_loader:
        optimizer.zero_grad()
        preds = model(user_ids, item_ids)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    
    # 验证
    model.eval()
    val_acc = 0
    with torch.no_grad():
        for user_ids, item_ids, labels in val_loader:
            preds = model(user_ids, item_ids)
            val_acc += ((preds > 0.5) == labels).float().mean().item()
    print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}, Val Acc: {val_acc/len(val_loader):.4f}")

关键点:

  • optimizer.zero_grad() 别漏掉,否则梯度会累积(我第一次忘加,loss 直接 NaN)
  • 验证时记得 model.eval()torch.no_grad(),省显存还防 bug
  • BCELoss 而不是 CrossEntropyLoss,因为输出已经是概率了

跑了 10 个 epoch,验证准确率到了 82%。虽然比不上线上那些 fancy 模型,但比随机猜测(50%)强多了。运营看了 demo 说“可以先上线 A/B 测试”,我长舒一口气——终于不用改需求文档了!


和 Spring Boot 对接:API 怎么搞?

模型训好了,怎么给后端用?总不能让 Java 程序员直接调 Python 脚本吧(运维会疯)。

我用了两种方案对比:

方案 优点 缺点
Flask 封装 API 快速、简单 需要额外部署 Python 服务
ONNX 导出 + Java 推理 无 Python 依赖 调试麻烦,版本兼容问题多

考虑到团队全是 Java 技术栈,而且 Spring Boot 项目已经有一套完善的监控和日志体系,我最终选择了 Flask 封装。虽然多一个服务,但开发效率高,还能复用现有的 CI/CD 流程。

Flask 代码就十几行:

from flask import Flask, request, jsonify
import torch

app = Flask(__name__)
model = torch.load("click_model.pth")
model.eval()

@app.route("/predict", methods=["POST"])
def predict():
    data = request.json
    user_id = data["user_id"]
    item_id = data["item_id"]
    
    # 这里要处理 ID 映射(略)
    user_idx = user_enc.transform([user_id])[0]
    item_idx = item_enc.transform([item_id])[0]
    
    with torch.no_grad():
        prob = model(
            torch.tensor([user_idx]),
            torch.tensor([item_idx])
        ).item()
    
    return jsonify({"click_prob": prob})

Spring Boot 那边就当调第三方 API:

// 伪代码
RestTemplate restTemplate = new RestTemplate();
Map<String, Object> payload = Map.of("user_id", userId, "item_id", itemId);
ResponseEntity<Prediction> response = restTemplate.postForEntity(
    "http://pytorch-service:8000/predict", 
    payload, 
    Prediction.class
);

虽然有点“微服务过度设计”的嫌疑,但好歹能跑。运维老大看了架构图直摇头:“又多一个服务要监控”,但也没拦着——毕竟双11快到了,能跑就行。


算法选择心得:别追求 SOTA

很多人一上来就想搞 SOTA(State-of-the-Art)模型,结果调参调到秃头,效果还不如逻辑回归。我这次的体会是:

  • 业务指标比学术指标更重要:运营关心的是点击率提升多少,不是 AUC 多高
  • 可维护性优先:代码越简单,线上出问题越容易 fix
  • 数据质量 > 模型复杂度:我花 70% 时间清洗数据,30% 调模型

另外,PyTorch 的动态图机制真的适合快速实验。想改模型结构?加一行代码就行。不像 TensorFlow 1.x,改个 shape 都要重写整个 graph。


最后:躺平程序员的 AI 学习建议

作为一个在公司混了三年、天天想着跳槽的佛系程序员,我对想入门深度学习的同行有几点建议:

  1. 别等“准备好”再开始:我就是在 deadline 逼迫下学的,反而效率最高
  2. 从小任务入手:别一上来就搞 CV/NLP 大模型,先解决手头的小问题
  3. 重视工程落地:模型再准,不能集成到现有系统也是白搭
  4. 多参加技术分享会:我上次听了个 PyTorch Lightning 的分享,省了我三天 debug 时间

现在,这个点击预测模型已经在线上跑了一周,A/B 测试显示点击率提升了 5%。运营请我喝了杯瑞幸,领导说“可以考虑给你加点算法绩效”——虽然可能只是客套话。

但无所谓了。至少我学会了 PyTorch,简历上又能多写一行“熟悉深度学习框架”。下次面试,终于不用只聊 Spring Boot 的循环依赖问题了。

摸鱼学习,真香。

评论 0

最热最新
暂无评论
匿名用户Lv.1
0
影响力
0
文章
0
粉丝