深度学习框架实战对比:从PyTorch到TensorFlow,再到JS的离谱尝试

Java老码农
2026-01-13 18:32
阅读 505

成都的夏天又湿又闷,上周五晚上我窝在玉林路的小出租屋里,一边啃着兔头一边调试一个模型——别误会,不是我在搞什么AI艺术项目,而是被产品经理“温柔”地催着上线一个智能推荐模块。他说:“这个功能双11前必须上,不然你简历上就只能写‘精通加班’了。”

我翻了个白眼,但还是打开了终端。作为一个喜欢折腾新技术的老码农,这几年试过不少AI编程工具:GitHub Copilot 刚出来时天天用,结果经常给我生成一堆“看起来很对但跑不通”的代码;CodeWhisperer 也不错,但对中文注释支持一言难尽;最后兜兜转转,还是选了 Cursor —— 不光因为它的上下文理解强,更因为它能直接帮我重构整个训练脚本,省下不少头发。

这次要做的推荐系统,其实不算复杂:用户行为日志 + 商品特征,输出点击概率。但问题来了:用哪个深度学习框架?


起手就是PyTorch?别急,先看场景

很多新人一听到深度学习就喊“PyTorch yyds”,但实际工作中,真不是这么回事。我们团队之前有个实习生,上来就用 PyTorch 写了个模型,本地跑得飞起,结果部署到生产环境直接崩——运维小哥盯着那堆 .pth 文件和自定义 DataLoader,脸都绿了:“这玩意儿怎么塞进 Docker 镜像?”

所以这次我学乖了:先列需求:

  • 训练阶段:需要灵活调试、可视化方便(毕竟我这种懒人不想反复 print)
  • 推理阶段:要能快速集成到现有 Java 后端(别问,问就是历史包袱)
  • 团队协作:同事里有会 TensorFlow 的,也有只会 Keras 的
  • 附加要求:产品经理突发奇想,说能不能搞个“区块链存证”的 demo?(是的,他又看了某篇公众号)

好吧,那就来一场实战对比吧。


第一轮:PyTorch vs TensorFlow(Keras)

我先用公开的 Criteo CTR 数据集 做了个简化版实验。任务是二分类(是否点击),输入包含 13 个数值特征 + 26 个类别特征。

PyTorch 版本(灵活性拉满)

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

class CTRModel(nn.Module):
    def __init__(self, num_embeddings, embedding_dim=8):
        super().__init__()
        self.embeddings = nn.ModuleList([
            nn.Embedding(n, embedding_dim) for n in num_embeddings
        ])
        self.dense = nn.Sequential(
            nn.Linear(13 + len(num_embeddings)*embedding_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
    
    def forward(self, dense_x, sparse_x):
        emb_out = torch.cat([
            emb(sparse_x[:, i]) for i, emb in enumerate(self.embeddings)
        ], dim=1)
        x = torch.cat([dense_x, emb_out], dim=1)
        return self.dense(x).squeeze()

优点很明显:结构清晰,debug 时打个断点就能看中间张量形状。而且配合 TensorBoard 或 Weights & Biases,训练曲线一目了然。

但缺点也致命:部署麻烦。虽然可以用 TorchScript 导出,但我们的后端是 Spring Boot,还得额外写 JNI 或启动 gRPC 服务——运维听了直摇头。

TensorFlow/Keras 版本(为生产而生)

import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Embedding, Concatenate
from tensorflow.keras.models import Model

def build_keras_model(feature_columns, embedding_dim=8):
    inputs = []
    embeddings = []
    
    for col in feature_columns:
        if col['type'] == 'numeric':
            inp = Input(shape=(1,), name=col['name'])
            inputs.append(inp)
            embeddings.append(inp)
        else:
            inp = Input(shape=(1,), name=col['name'])
            emb = Embedding(col['vocab_size'], embedding_dim)(inp)
            emb = tf.squeeze(emb, axis=1)
            inputs.append(inp)
            embeddings.append(emb)
    
    concat = Concatenate()(embeddings)
    x = Dense(128, activation='relu')(concat)
    x = Dense(64, activation='relu')(x)
    output = Dense(1, activation='sigmoid')(x)
    
    model = Model(inputs=inputs, outputs=output)
    return model

Keras 的好处在于:一行 model.save('model.h5') 就能搞定保存,TensorFlow Serving 直接加载,Java 侧通过 REST API 调用就行。而且同事老王(那个只会 Keras 的)也能看懂,合作起来不费劲。

性能方面,在相同 GPU 上,两个框架训练速度差不多(PyTorch 略快 5%),但 TensorFlow 的推理延迟更低,尤其是在 batch size 较大时。

框架 训练速度 (samples/sec) 推理延迟 (ms) 部署难度 团队接受度
PyTorch 12,500 8.2 ⭐⭐⭐⭐ ⭐⭐
TensorFlow 11,800 6.5 ⭐⭐ ⭐⭐⭐⭐

结论:内部实验用 PyTorch,上线选 TensorFlow。稳字当头,毕竟谁也不想半夜被 PagerDuty 叫醒。


第二轮:JavaScript?真有人用 JS 做深度学习?

这时候产品经理又冒出来了:“听说现在前端也能跑 AI?要不咱们整个 Web 版 demo,用户行为实时预测,还能上链存证!”

我差点把可乐喷他脸上。但转念一想:反正周末没事,试试 TensorFlow.js 玩玩?

于是搞了个极简版:用 TF.js 在浏览器里加载预训练好的模型,输入表单数据,输出点击概率。代码长这样:

// 在浏览器中运行
const model = await tf.loadLayersModel('/model/model.json');

document.getElementById('predict-btn').addEventListener('click', async () => {
  const denseInputs = [age, income, ...]; // 13个数值
  const sparseInputs = [category_id, brand_id, ...]; // 26个类别ID
  
  // 构造输入张量
  const inputs = {};
  for (let i = 0; i < 13; i++) {
    inputs[`dense_${i}`] = tf.tensor2d([[denseInputs[i]]]);
  }
  for (let i = 0; i < 26; i++) {
    inputs[`sparse_${i}`] = tf.tensor2d([[sparseInputs[i]]], [1, 1], 'int32');
  }
  
  const pred = await model.predict(inputs);
  const prob = await pred.data();
  alert(`点击概率: ${(prob[0] * 100).toFixed(2)}%`);
});

结果?慢得感人。就算我把模型剪枝到只剩两层全连接,预测一次也要 300ms+。而且内存占用爆炸,Chrome 直接弹出“页面无响应”。

至于“区块链存证”?我用 ethers.js 把预测结果哈希一下发到 Rinkeby 测试网,gas fee 虽然免了,但等了 2 分钟才上链。用户早关网页了。

所以结论很明确:JS 做深度学习?除非是玩具项目,否则别碰。不过写在简历上倒是能显得“技术栈广”——面试官要是问“你用过 TF.js 吗?”,你可以淡定地说:“用过,还踩过坑。”


最终方案:混合架构 + 模型蒸馏

综合下来,我们采用了 训练用 PyTorch + 导出 ONNX + TensorFlow Serving 推理 的混合方案:

  1. 用 PyTorch 快速迭代模型(方便调参、可视化)
  2. 训练完后用 torch.onnx.export() 转成 ONNX 格式
  3. 用 TensorFlow 的 tf2onnx 工具转成 SavedModel
  4. 部署到 TF Serving,Java 后端通过 gRPC 调用

这样既保留了开发灵活性,又满足了生产稳定性。而且模型体积从 200MB 压缩到 60MB(用了量化 + 剪枝),QPS 提升了 3 倍。

最关键的是——双11当天没挂。运维请我喝了杯瑞幸,说“这次终于不用通宵了”。


写在最后:技术选型不是炫技

很多人(包括曾经的我)总想着用最新最酷的框架,觉得这样简历才好看。但现实是:线上系统稳定 > 技术新颖

我在成都这座节奏舒服的城市待久了,反而更明白:技术人的价值,不在于你用了多少“黑科技”,而在于你能不能用最稳妥的方式解决问题。

当然,如果你正在准备跳槽,那不妨在简历上写“熟悉 PyTorch/TensorFlow/ONNX/TensorFlow.js 全栈 AI 开发”——反正面试官也不会让你现场跑 JS 模型。

对了,最近我又用 Cursor 帮我自动优化了一段数据预处理脚本,效率提升明显。看来当初放弃 Copilot 是对的——有时候,工具选对了,比框架更重要。

下次再有人问我“该学 PyTorch 还是 TensorFlow”,我会说:都学,但上线时,听运维的。

评论 0

最热最新
暂无评论
匿名用户Lv.1
0
影响力
0
文章
0
粉丝