PyTorch快速入门:深度学习框架初探

赵玉
2025-12-14 18:19
阅读 890

大家好,我是老K,一个在杭州卷了快五年的后端工程师。GitHub Copilot 付费用户(没错,就是那个被某些人骂“AI抄代码”的工具),用了快两年,Vim 没卸过,IDE 基本只在看同事代码时打开——主要是怕他们用花里胡哨的插件把我带偏了。目前在一家和阿里、网易抢人的中型互联网公司混日子,主业写 Springboot,副业研究怎么让简历在 Boss 直聘上多收到几个“你好,在看你的简历”……

去年双11前夜,我们组突然接到一个“战略性需求”:给推荐系统加个基于用户行为序列的深度排序模型。产品经理原话是:“要不整个 DNN 吧?我看隔壁组搞了个 Transformer,效果贼好。” 我当时正一边用 vim 改着第 8 版接口文档,一边心里默念“求你别提算法”,结果领导拍板:“老K,你不是简历上写了‘了解机器学习’吗?正好练练手。”

于是,我这个常年和 @Autowiredapplication.yml 打交道的 Java 老兵,被迫拿起了 Python 和 PyTorch


为什么是 PyTorch?而不是 TensorFlow?

说实话,一开始我想直接跑路。但考虑到跳槽时如果能在简历上加上“独立完成深度学习模型开发与部署”,那至少能多拿两个面试机会——尤其现在杭州这边大厂都在招“懂业务+会 AI”的全栈型后端。

选 PyTorch 而不是 TensorFlow,主要有三个原因:

  1. 生态友好:PyTorch 的社区文档对新手极其友好,不像 TF 那样动不动就让你先学 tf.function 再理解 graph mode
  2. 动态图调试爽:作为一个 Vim 党,习惯了逐行调试,PyT Torch 的 eager execution 模式让我可以直接 print(tensor) 看中间结果,不用像 TF 1.x 那样画计算图到怀疑人生。
  3. 求职市场更吃香:翻了下最近半年阿里 P6/P7 的 JD,80% 以上都写着“熟悉 PyTorch 优先”。网易伏羲实验室那边甚至直接要求“熟练使用 PyTorch 构建自定义模型”。

🤯 小插曲:第一次跑 import torch 报错 CUDA not available,我以为显卡坏了,后来发现是我本地没装 CUDA 驱动……运维小哥笑我说:“你连 GPU 服务器都没申请,就想跑训练?”


实战:从零构建一个点击率预估模型(CTR)

我们的业务场景很典型:用户浏览商品列表 → 点击或不点击 → 用历史行为预测下次点击概率。

数据来自公司内部埋点,字段包括:

  • user_id, item_id
  • 用户画像(年龄、城市、设备)
  • 商品特征(类目、价格、销量)
  • 行为序列(最近点击的 10 个 item_id)

目标:输入这些特征,输出点击概率(0~1)。

第一步:环境搭建(别跳过!)

# 强烈建议用 conda,不然 pip 装 torch 可能把你 Linux 系统搞崩
conda create -n dl python=3.9
conda activate dl
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install pandas numpy scikit-learn matplotlib

💡 个人经验:别在 Windows 上折腾深度学习。我上周五晚上加班到 11 点,就因为 WSL2 的路径权限问题,差点把键盘扔了。


第二步:数据预处理(比模型还重要)

PyTorch 的 DatasetDataLoader 是神器,尤其是配合 collate_fn 处理变长序列。

from torch.utils.data import Dataset, DataLoader
import pandas as pd
import torch

class CTRDataset(Dataset):
    def __init__(self, df, user_map, item_map):
        self.df = df
        self.user_map = user_map
        self.item_list = [item_map.get(x, 0) for x in df['item_seq'].apply(eval)]  # 简化处理

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        user_id = self.user_map[row['user_id']]
        item_id = self.item_map.get(row['item_id'], 0)
        item_seq = torch.tensor(self.item_list[idx], dtype=torch.long)
        label = torch.tensor(row['label'], dtype=torch.float)
        return user_id, item_id, item_seq, label

# 使用示例
train_loader = DataLoader(
    CTRDataset(train_df, user2id, item2id),
    batch_size=512,
    shuffle=True,
    collate_fn=lambda x: tuple(map(torch.stack, zip(*x)))  # 自动 padding 不在这里做,简化
)

🙃 吐槽:产品经理说“数据已经清洗好了”,结果我打开 CSV 发现 item_seq 字段是字符串 "[1,2,3]",还得 eval() 解析……这要是线上事故,我得背锅。


第三步:模型搭建(别一上来就 Transformer!)

作为新手,我一开始也想直接上 BERT4Rec,但被 leader 劝住了:“先跑通 LR + Embedding,再迭代。”

最终我选择了经典的 Wide & Deep 结构,兼顾记忆性(wide)和泛化性(deep):

import torch.nn as nn

class WideAndDeep(nn.Module):
    def __init__(self, user_num, item_num, embed_dim=64, hidden_dims=[128, 64]):
        super().__init__()
        # Embedding 层
        self.user_embed = nn.Embedding(user_num, embed_dim)
        self.item_embed = nn.Embedding(item_num, embed_dim)
        
        # Deep 部分:MLP
        input_dim = embed_dim * 2  # user + item
        layers = []
        for h in hidden_dims:
            layers.append(nn.Linear(input_dim, h))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(0.2))
            input_dim = h
        self.deep = nn.Sequential(*layers)
        
        # Wide 部分:直接连接原始特征(这里简化为 user_id + item_id 的交叉)
        self.wide = nn.Linear(embed_dim * 2, 1)
        self.final = nn.Linear(hidden_dims[-1] + 1, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, user_id, item_id, _):
        user_emb = self.user_embed(user_id)
        item_emb = self.item_embed(item_id)
        deep_input = torch.cat([user_emb, item_emb], dim=1)
        deep_out = self.deep(deep_input)
        wide_out = self.wide(deep_input)
        combined = torch.cat([deep_out, wide_out], dim=1)
        logits = self.final(combined)
        return self.sigmoid(logits).squeeze()

✅ 为什么不用 item_seq?因为第一版先验证 pipeline 是否通。等基础模型 AUC > 0.7 再引入序列模型(比如 GRU 或 DIN)。


第四步:训练 & 调优(GPU 是刚需)

model = WideAndDeep(user_num, item_num).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCELoss()

for epoch in range(10):
    model.train()
    total_loss = 0
    for user, item, seq, label in train_loader:
        user, item, label = user.cuda(), item.cuda(), label.cuda()
        pred = model(user, item, seq)
        loss = criterion(pred, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch}, Loss: {total_loss / len(train_loader):.4f}")

踩坑记录:

问题 解决方案
loss 一直是 0.6931(log(2)) 标签没归一化,或者模型没学到东西 → 检查 label 是否为 float 类型
显存爆了 减小 batch_size,或用 torch.cuda.empty_cache()(治标不治本)
AUC 上不去 加入更多特征交叉,或换更深的网络

😭 最惨一次:训练跑了 3 小时,最后发现 label 全是 0,因为测试集漏了正样本……当时真的想砸电脑。


第五步:效果评估(别只看 loss!)

在推荐/广告场景,AUCLogLoss 比 accuracy 更有意义。

from sklearn.metrics import roc_auc_score

model.eval()
preds, labels = [], []
with torch.no_grad():
    for user, item, seq, label in val_loader:
        user, item = user.cuda(), item.cuda()
        pred = model(user, item, seq).cpu().numpy()
        preds.extend(pred)
        labels.extend(label.numpy())

auc = roc_auc_score(labels, preds)
print(f"Validation AUC: {auc:.4f}")

我们的第一版 Wide&Deep 模型 AUC 达到了 0.721,虽然不高,但比原来的 LR(0.68)有提升。leader 说:“够上线灰度了。”


和 Springboot 对接:让模型真正落地

模型训完不能只在 Jupyter Notebook 里 plt.show(),得服务化。

我们用 Flask 包了一层 REST API,然后被 Springboot 服务调用:

# inference.py
from flask import Flask, request, jsonify
import torch

app = Flask(__name__)
model = torch.load('ctr_model.pth').eval().cuda()

@app.route('/predict', methods=['POST'])
def predict():
    data = request.json
    user_id = torch.tensor([data['user_id']]).cuda()
    item_id = torch.tensor([data['item_id']]).cuda()
    with torch.no_grad():
        prob = model(user_id, item_id, None).item()
    return jsonify({'click_prob': prob})

Springboot 侧用 RestTemplate 调用:

// Java 代码(Springboot)
ResponseEntity<Map> response = restTemplate.postForEntity(
    "http://ml-service:5000/predict",
    Map.of("user_id", 123, "item_id", 456),
    Map.class
);
Double prob = (Double) response.getBody().get("click_prob");

🔥 真实场景:上线第一天 QPS 超过 1000,Flask 直接崩了。后来换成 TorchServe + Docker 容器化,才扛住流量。运维说:“你这 Python 服务比 Java 还娇气。”


写在最后:PyTorch 对求职真的有用吗?

坦白讲,如果你只是想在简历上加一行“熟悉 PyTorch”,那意义不大。但现在大厂(尤其是阿里、字节、网易)的后端岗,越来越看重“工程+算法”复合能力。

我在上个月面试某大厂时,面试官看到我简历里写了“使用 PyTorch 构建 CTR 模型并上线”,直接跳过八股文,问了半小时模型结构、特征工程、线上监控——最后给了 offer。

所以,别为了学而学。带着业务问题去学 PyTorch,才是最快的成长路径

另外,GitHub Copilot 在写 PyTorch 时真的香!比如我打 # define a GRU-based sequential model,它能自动生成完整 class,虽然要改,但省了查文档的时间。当然,别全信,有一次它把 nn.GRU 写成 nn.RNN,害我 debug 一小时……


给新手的几点建议

  • 别一上来就搞大模型:先跑通 MNIST,再上业务数据。
  • 重视数据质量:80% 的时间在清洗和构造特征。
  • GPU 是生产力工具:公司不给配?自己租阿里云 ECS(学生机便宜)。
  • 和业务方对齐指标:别只盯着 loss,问清楚“什么算好效果”。
  • 简历要写具体:别写“使用深度学习提升效果”,写“通过 Wide&Deep 模型将 AUC 从 0.68 提升至 0.72,带来 GMV +3%”。

现在我已经开始研究 DIN(Deep Interest Network)了,准备把 item_seq 加进去。如果顺利,下个月就能在简历里写“精通序列建模”(手动狗头)。

如果你也在杭州,正在用 Springboot 写 CRUD,但又想往 AI 方向靠——别犹豫,从 PyTorch 开始吧。说不定哪天,你也能在双11前夜,一边喝着瑞幸,一边看着自己的模型在线上跑出收益。

共勉。

评论 0

最热最新
暂无评论
匿名用户Lv.1
0
影响力
0
文章
0
粉丝