深度学习框架实战对比:从PyTorch到TensorFlow,再到JS的离谱尝试
成都的夏天又湿又闷,上周五晚上我窝在玉林路的小出租屋里,一边啃着兔头一边调试一个模型——别误会,不是我在搞什么AI艺术项目,而是被产品经理“温柔”地催着上线一个智能推荐模块。他说:“这个功能双11前必须上,不然你简历上就只能写‘精通加班’了。”
我翻了个白眼,但还是打开了终端。作为一个喜欢折腾新技术的老码农,这几年试过不少AI编程工具:GitHub Copilot 刚出来时天天用,结果经常给我生成一堆“看起来很对但跑不通”的代码;CodeWhisperer 也不错,但对中文注释支持一言难尽;最后兜兜转转,还是选了 Cursor —— 不光因为它的上下文理解强,更因为它能直接帮我重构整个训练脚本,省下不少头发。
这次要做的推荐系统,其实不算复杂:用户行为日志 + 商品特征,输出点击概率。但问题来了:用哪个深度学习框架?
起手就是PyTorch?别急,先看场景
很多新人一听到深度学习就喊“PyTorch yyds”,但实际工作中,真不是这么回事。我们团队之前有个实习生,上来就用 PyTorch 写了个模型,本地跑得飞起,结果部署到生产环境直接崩——运维小哥盯着那堆 .pth 文件和自定义 DataLoader,脸都绿了:“这玩意儿怎么塞进 Docker 镜像?”
所以这次我学乖了:先列需求:
- 训练阶段:需要灵活调试、可视化方便(毕竟我这种懒人不想反复 print)
- 推理阶段:要能快速集成到现有 Java 后端(别问,问就是历史包袱)
- 团队协作:同事里有会 TensorFlow 的,也有只会 Keras 的
- 附加要求:产品经理突发奇想,说能不能搞个“区块链存证”的 demo?(是的,他又看了某篇公众号)
好吧,那就来一场实战对比吧。
第一轮:PyTorch vs TensorFlow(Keras)
我先用公开的 Criteo CTR 数据集 做了个简化版实验。任务是二分类(是否点击),输入包含 13 个数值特征 + 26 个类别特征。
PyTorch 版本(灵活性拉满)
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
class CTRModel(nn.Module):
def __init__(self, num_embeddings, embedding_dim=8):
super().__init__()
self.embeddings = nn.ModuleList([
nn.Embedding(n, embedding_dim) for n in num_embeddings
])
self.dense = nn.Sequential(
nn.Linear(13 + len(num_embeddings)*embedding_dim, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, 1)
)
def forward(self, dense_x, sparse_x):
emb_out = torch.cat([
emb(sparse_x[:, i]) for i, emb in enumerate(self.embeddings)
], dim=1)
x = torch.cat([dense_x, emb_out], dim=1)
return self.dense(x).squeeze()
优点很明显:结构清晰,debug 时打个断点就能看中间张量形状。而且配合 TensorBoard 或 Weights & Biases,训练曲线一目了然。
但缺点也致命:部署麻烦。虽然可以用 TorchScript 导出,但我们的后端是 Spring Boot,还得额外写 JNI 或启动 gRPC 服务——运维听了直摇头。
TensorFlow/Keras 版本(为生产而生)
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Embedding, Concatenate
from tensorflow.keras.models import Model
def build_keras_model(feature_columns, embedding_dim=8):
inputs = []
embeddings = []
for col in feature_columns:
if col['type'] == 'numeric':
inp = Input(shape=(1,), name=col['name'])
inputs.append(inp)
embeddings.append(inp)
else:
inp = Input(shape=(1,), name=col['name'])
emb = Embedding(col['vocab_size'], embedding_dim)(inp)
emb = tf.squeeze(emb, axis=1)
inputs.append(inp)
embeddings.append(emb)
concat = Concatenate()(embeddings)
x = Dense(128, activation='relu')(concat)
x = Dense(64, activation='relu')(x)
output = Dense(1, activation='sigmoid')(x)
model = Model(inputs=inputs, outputs=output)
return model
Keras 的好处在于:一行 model.save('model.h5') 就能搞定保存,TensorFlow Serving 直接加载,Java 侧通过 REST API 调用就行。而且同事老王(那个只会 Keras 的)也能看懂,合作起来不费劲。
性能方面,在相同 GPU 上,两个框架训练速度差不多(PyTorch 略快 5%),但 TensorFlow 的推理延迟更低,尤其是在 batch size 较大时。
| 框架 | 训练速度 (samples/sec) | 推理延迟 (ms) | 部署难度 | 团队接受度 |
|---|---|---|---|---|
| PyTorch | 12,500 | 8.2 | ⭐⭐⭐⭐ | ⭐⭐ |
| TensorFlow | 11,800 | 6.5 | ⭐⭐ | ⭐⭐⭐⭐ |
结论:内部实验用 PyTorch,上线选 TensorFlow。稳字当头,毕竟谁也不想半夜被 PagerDuty 叫醒。
第二轮:JavaScript?真有人用 JS 做深度学习?
这时候产品经理又冒出来了:“听说现在前端也能跑 AI?要不咱们整个 Web 版 demo,用户行为实时预测,还能上链存证!”
我差点把可乐喷他脸上。但转念一想:反正周末没事,试试 TensorFlow.js 玩玩?
于是搞了个极简版:用 TF.js 在浏览器里加载预训练好的模型,输入表单数据,输出点击概率。代码长这样:
// 在浏览器中运行
const model = await tf.loadLayersModel('/model/model.json');
document.getElementById('predict-btn').addEventListener('click', async () => {
const denseInputs = [age, income, ...]; // 13个数值
const sparseInputs = [category_id, brand_id, ...]; // 26个类别ID
// 构造输入张量
const inputs = {};
for (let i = 0; i < 13; i++) {
inputs[`dense_${i}`] = tf.tensor2d([[denseInputs[i]]]);
}
for (let i = 0; i < 26; i++) {
inputs[`sparse_${i}`] = tf.tensor2d([[sparseInputs[i]]], [1, 1], 'int32');
}
const pred = await model.predict(inputs);
const prob = await pred.data();
alert(`点击概率: ${(prob[0] * 100).toFixed(2)}%`);
});
结果?慢得感人。就算我把模型剪枝到只剩两层全连接,预测一次也要 300ms+。而且内存占用爆炸,Chrome 直接弹出“页面无响应”。
至于“区块链存证”?我用 ethers.js 把预测结果哈希一下发到 Rinkeby 测试网,gas fee 虽然免了,但等了 2 分钟才上链。用户早关网页了。
所以结论很明确:JS 做深度学习?除非是玩具项目,否则别碰。不过写在简历上倒是能显得“技术栈广”——面试官要是问“你用过 TF.js 吗?”,你可以淡定地说:“用过,还踩过坑。”
最终方案:混合架构 + 模型蒸馏
综合下来,我们采用了 训练用 PyTorch + 导出 ONNX + TensorFlow Serving 推理 的混合方案:
- 用 PyTorch 快速迭代模型(方便调参、可视化)
- 训练完后用
torch.onnx.export()转成 ONNX 格式 - 用 TensorFlow 的
tf2onnx工具转成 SavedModel - 部署到 TF Serving,Java 后端通过 gRPC 调用
这样既保留了开发灵活性,又满足了生产稳定性。而且模型体积从 200MB 压缩到 60MB(用了量化 + 剪枝),QPS 提升了 3 倍。
最关键的是——双11当天没挂。运维请我喝了杯瑞幸,说“这次终于不用通宵了”。
写在最后:技术选型不是炫技
很多人(包括曾经的我)总想着用最新最酷的框架,觉得这样简历才好看。但现实是:线上系统稳定 > 技术新颖。
我在成都这座节奏舒服的城市待久了,反而更明白:技术人的价值,不在于你用了多少“黑科技”,而在于你能不能用最稳妥的方式解决问题。
当然,如果你正在准备跳槽,那不妨在简历上写“熟悉 PyTorch/TensorFlow/ONNX/TensorFlow.js 全栈 AI 开发”——反正面试官也不会让你现场跑 JS 模型。
对了,最近我又用 Cursor 帮我自动优化了一段数据预处理脚本,效率提升明显。看来当初放弃 Copilot 是对的——有时候,工具选对了,比框架更重要。
下次再有人问我“该学 PyTorch 还是 TensorFlow”,我会说:都学,但上线时,听运维的。

评论 0