一个被逼到用AI写代码的程序员,如何在TensorFlow、PyTorch和JAX之间反复横跳?
大家好,我是那个在GitHub上扒了上百个开源项目源码、组里干了快两年、每天都在为“可读性”和“可维护性”跟产品经理battle的后端老油条。说来惭愧,我其实是个前端出身——别笑!早年确实写过不少Vue和React组件,直到某次被线上OOM搞崩溃后,才彻底投奔了Python阵营。
但最近,我又被迫回到了“综合战场”:公司要搞一个智能推荐模块,不仅要训练模型,还得把推理服务嵌入到现有前端系统里。于是,我这个“伪全栈”就被领导点名:“你不是啥都懂点嘛?深度学习框架调研就交给你了。”
行吧,反正周末也睡不着(双11前夜谁睡得着?),干脆把TensorFlow、PyTorch和JAX三大主流框架拉出来遛一遛。这篇文章就是我在踩坑、debug、甚至差点砸键盘之后的真实血泪总结。不讲论文,不堆术语,只聊实战——尤其是怎么让这些框架在真实业务里跑起来还不被运维骂死。
起因:产品经理说“我们要实时个性化推荐”
事情是这样的。我们有个电商H5页面,用户滑动商品时,希望实时根据行为调整排序。听起来很简单?但问题在于:
- 前端必须轻量,不能塞个300MB的模型进去
- 后端API响应要在200ms内(否则用户直接划走)
- 模型得支持在线学习(用户刚看了几件衣服,就得立刻“记住”)
团队一开始想直接上TF Serving,毕竟公司老技术栈都是TensorFlow系。但测试发现,冷启动加载模型就要1.2秒,前端同学当场表演原地去世:“哥,用户早就关页面了!”
于是,我开始横向对比三大框架:TensorFlow 2.x、PyTorch 2.0、JAX + Flax。目标很明确:训练灵活、部署轻量、前端集成友好、代码别太难看(这点对我这种代码洁癖患者很重要)。
实战场景:用Movielens数据集模拟用户偏好
为了贴近业务,我拿了一个简化版的Movielens 100K数据集:943个用户对1682部电影的评分。任务是训练一个矩阵分解模型(Matrix Factorization),预测用户对未看影片的评分。
为啥选这么“古董”的模型?因为简单、可解释、适合做POC,而且能清晰看出框架差异。真上生产我们肯定用双塔DNN,但那是另一个故事了。
1. TensorFlow 2.x:老将稳重,但有点“爹味”
TensorFlow的优势不用多说:生态成熟、TF.js能直接跑前端、SavedModel部署一气呵成。但它的API设计……怎么说呢,像极了那种特别负责但总爱替你做决定的长辈。
# TensorFlow 2.x 实现矩阵分解
import tensorflow as tf
class MatrixFactorization(tf.keras.Model):
def __init__(self, n_users, n_items, emb_dim=64):
super().__init__()
self.user_emb = tf.keras.layers.Embedding(n_users, emb_dim)
self.item_emb = tf.keras.layers.Embedding(n_items, emb_dim)
def call(self, inputs):
user_ids, item_ids = inputs
user_vec = self.user_emb(user_ids)
item_vec = self.item_emb(item_ids)
return tf.reduce_sum(user_vec * item_vec, axis=1)
model = MatrixFactorization(943, 1682)
model.compile(optimizer='adam', loss='mse')
model.fit([user_ids, item_ids], ratings, epochs=10)
看起来挺清爽?但当你想自定义训练循环(比如加在线学习逻辑)时,@tf.function 的图模式会把你整懵。上周五晚上我就卡在一个TypeError: Cannot convert a symbolic Tensor to a numpy array上两小时——最后发现是因为在tf.function里用了Python的print()。
部署方面,SavedModel导出确实方便:
model.save('mf_model')
然后TF Serving一键加载。但模型体积?32MB。前端?别想了,TF.js加载它至少1.5秒。
吐槽点:TensorFlow的文档看似全面,实则绕晕人。你想找“如何自定义梯度”,结果翻了三页全是Keras高级API。
2. PyTorch 2.0:灵活如德芙,但部署是道坎
作为PyTorch粉(虽然以前黑过它动态图慢),这次体验让我直呼“真香”。它的Eager模式写起来就像写NumPy,调试时print(tensor)直接出数值,再也不用.numpy() .detach()满天飞。
# PyTorch 2.0 实现
import torch
import torch.nn as nn
class MF(nn.Module):
def __init__(self, n_users, n_items, emb_dim=64):
super().__init__()
self.user_emb = nn.Embedding(n_users, emb_dim)
self.item_emb = nn.Embedding(n_items, emb_dim)
def forward(self, user_ids, item_ids):
return (self.user_emb(user_ids) * self.item_emb(item_ids)).sum(1)
model = MF(943, 1682)
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(10):
pred = model(user_ids, item_ids)
loss = nn.MSELoss()(pred, ratings)
loss.backward()
optimizer.step()
训练过程丝滑得像德芙巧克力。而且PyTorch 2.0的torch.compile()真的快!在我的MacBook Pro上,训练时间比TF快了18%。
但问题来了:怎么部署?
- TorchScript?导出后模型28MB,但前端不支持。
- ONNX?转换过程一堆op不兼容,报错
Unsupported node type: aten::embedding。 - 最终方案:用TorchServe,但需要额外写handler,还要配YAML。运维小哥看到后直接翻白眼:“又来?”
不过好消息是,PyTorch现在和前端也能联动了——通过ONNX Runtime Web,可以把模型转成WebAssembly跑在浏览器里。虽然折腾,但至少可行。
3. JAX + Flax:极客玩具,还是未来之光?
说实话,一开始我是拒绝JAX的。函数式编程?纯函数?自动微分?听起来像学术派在自嗨。但当我看到它在TPU上的表现,以及超小的模型体积,我动摇了。
# JAX + Flax 实现
import jax
import jax.numpy as jnp
from flax import linen as nn
class MF(nn.Module):
n_users: int
n_items: int
emb_dim: int = 64
@nn.compact
def __call__(self, user_ids, item_ids):
user_emb = nn.Embed(self.n_users, self.emb_dim)(user_ids)
item_emb = nn.Embed(self.n_items, self.emb_dim)(item_ids)
return jnp.sum(user_emb * item_emb, axis=-1)
# 训练用纯函数 + jit
def loss_fn(params, user_ids, item_ids, targets):
pred = MF(943, 1682).apply(params, user_ids, item_ids)
return jnp.mean((pred - targets) ** 2)
grad_fn = jax.jit(jax.value_and_grad(loss_fn))
JAX的核心优势在于:一切皆函数。没有隐藏状态,没有副作用,调试时可以直接print(params)。而且jax.jit编译后速度飞起,训练比PyTorch还快5%。
最让我惊喜的是模型导出:直接序列化成msgpack或pickle,只有8MB! 因为JAX模型本质上就是参数字典+纯函数,根本不需要笨重的运行时。
但代价是:学习曲线陡峭。你得理解PRNGKey、vmap、pmap这些概念。第一次写的时候,我把随机种子搞错了,导致每次推理结果都不一样,差点以为模型疯了。
综合对比:不只是速度,更是工程体验
我整理了一张实战对比表,结合了训练、部署、前端集成和代码维护性:
| 维度 | TensorFlow 2.x | PyTorch 2.0 | JAX + Flax |
|---|---|---|---|
| 训练灵活性 | 中(图模式限制多) | 高(Eager模式自由) | 极高(纯函数+jit) |
| 模型体积 | 32MB | 28MB | 8MB |
| 前端集成 | TF.js(直接支持) | ONNX Runtime Web(需转换) | 需自研推理引擎 |
| 调试体验 | 差(符号张量坑多) | 优秀(像NumPy) | 好(但需适应函数式) |
| 代码可读性 | 中(Keras封装深) | 高(直观) | 中(需理解JAX范式) |
| 团队上手成本 | 低(国内普及率高) | 中 | 高(小众) |
从综合角度看:
- 如果你有现成TF生态,且不介意模型体积,TensorFlow仍是稳妥选择。
- 如果追求开发效率和灵活性,PyTorch是大多数人的归宿。
- 如果你敢赌未来,且团队有极客精神,JAX在性能和轻量化上碾压对手。
最终方案:PyTorch + ONNX + 前端微服务
经过三轮AB测试,我们最终选了PyTorch训练 → ONNX导出 → C++推理服务的混合方案:
- 用PyTorch训练模型(开发快)
- 导出为ONNX(虽然踩坑,但社区方案成熟)
- 用ONNX Runtime C++ API写一个轻量推理服务(启动<50ms,内存<50MB)
- 前端通过WebSocket长连接获取实时Embedding(避免频繁HTTP请求)
代码片段:
# PyTorch导出ONNX
torch.onnx.export(
model,
(dummy_user, dummy_item),
"mf.onnx",
input_names=["user_id", "item_id"],
output_names=["score"],
dynamic_axes={"user_id": {0: "batch"}, "item_id": {0: "batch"}}
)
上线后,P99延迟从1200ms降到180ms,前端同学终于没再@我骂街。
写在最后:工具只是工具,人才是核心
说实话,写这篇文章时我还在用Cursor(没错,就是那个AI编程工具)。不是我不努力,而是当你要同时看三个框架的文档、调五个环境、回十个群消息时,AI真的能救命。比如刚才那段JAX代码,我让Cursor帮我补全了loss_fn的梯度计算——省了我查Flax文档半小时。
但工具再强,也替代不了对业务的理解。深度学习框架没有银弹,关键看你用在什么场景。TensorFlow适合大厂稳态系统,PyTorch适合快速迭代,JAX适合追求极致性能的团队。
而我?作为一个曾经的前端,现在的“综合”工程师,只希望下次产品经理提需求时,能先问问:“这事儿,能用规则搞定吗?非得上模型?”
(完)
P.S. 上周双11大促,我们的推荐模块零故障。运维请我喝了杯瑞幸——他说:“这次没半夜打电话,真难得。” 我笑了笑,默默打开了Cursor,开始写下周的复盘报告……

评论 0