一个被逼到用AI写代码的程序员,如何在TensorFlow、PyTorch和JAX之间反复横跳?

字段又改名了
2025-12-15 14:29
阅读 743

大家好,我是那个在GitHub上扒了上百个开源项目源码、组里干了快两年、每天都在为“可读性”和“可维护性”跟产品经理battle的后端老油条。说来惭愧,我其实是个前端出身——别笑!早年确实写过不少Vue和React组件,直到某次被线上OOM搞崩溃后,才彻底投奔了Python阵营。

但最近,我又被迫回到了“综合战场”:公司要搞一个智能推荐模块,不仅要训练模型,还得把推理服务嵌入到现有前端系统里。于是,我这个“伪全栈”就被领导点名:“你不是啥都懂点嘛?深度学习框架调研就交给你了。”

行吧,反正周末也睡不着(双11前夜谁睡得着?),干脆把TensorFlow、PyTorch和JAX三大主流框架拉出来遛一遛。这篇文章就是我在踩坑、debug、甚至差点砸键盘之后的真实血泪总结。不讲论文,不堆术语,只聊实战——尤其是怎么让这些框架在真实业务里跑起来还不被运维骂死


起因:产品经理说“我们要实时个性化推荐”

事情是这样的。我们有个电商H5页面,用户滑动商品时,希望实时根据行为调整排序。听起来很简单?但问题在于:

  • 前端必须轻量,不能塞个300MB的模型进去
  • 后端API响应要在200ms内(否则用户直接划走)
  • 模型得支持在线学习(用户刚看了几件衣服,就得立刻“记住”)

团队一开始想直接上TF Serving,毕竟公司老技术栈都是TensorFlow系。但测试发现,冷启动加载模型就要1.2秒,前端同学当场表演原地去世:“哥,用户早就关页面了!”

于是,我开始横向对比三大框架:TensorFlow 2.x、PyTorch 2.0、JAX + Flax。目标很明确:训练灵活、部署轻量、前端集成友好、代码别太难看(这点对我这种代码洁癖患者很重要)。


实战场景:用Movielens数据集模拟用户偏好

为了贴近业务,我拿了一个简化版的Movielens 100K数据集:943个用户对1682部电影的评分。任务是训练一个矩阵分解模型(Matrix Factorization),预测用户对未看影片的评分。

为啥选这么“古董”的模型?因为简单、可解释、适合做POC,而且能清晰看出框架差异。真上生产我们肯定用双塔DNN,但那是另一个故事了。

1. TensorFlow 2.x:老将稳重,但有点“爹味”

TensorFlow的优势不用多说:生态成熟、TF.js能直接跑前端、SavedModel部署一气呵成。但它的API设计……怎么说呢,像极了那种特别负责但总爱替你做决定的长辈。

# TensorFlow 2.x 实现矩阵分解
import tensorflow as tf

class MatrixFactorization(tf.keras.Model):
    def __init__(self, n_users, n_items, emb_dim=64):
        super().__init__()
        self.user_emb = tf.keras.layers.Embedding(n_users, emb_dim)
        self.item_emb = tf.keras.layers.Embedding(n_items, emb_dim)

    def call(self, inputs):
        user_ids, item_ids = inputs
        user_vec = self.user_emb(user_ids)
        item_vec = self.item_emb(item_ids)
        return tf.reduce_sum(user_vec * item_vec, axis=1)

model = MatrixFactorization(943, 1682)
model.compile(optimizer='adam', loss='mse')
model.fit([user_ids, item_ids], ratings, epochs=10)

看起来挺清爽?但当你想自定义训练循环(比如加在线学习逻辑)时,@tf.function 的图模式会把你整懵。上周五晚上我就卡在一个TypeError: Cannot convert a symbolic Tensor to a numpy array上两小时——最后发现是因为在tf.function里用了Python的print()

部署方面,SavedModel导出确实方便:

model.save('mf_model')

然后TF Serving一键加载。但模型体积?32MB。前端?别想了,TF.js加载它至少1.5秒。

吐槽点:TensorFlow的文档看似全面,实则绕晕人。你想找“如何自定义梯度”,结果翻了三页全是Keras高级API。


2. PyTorch 2.0:灵活如德芙,但部署是道坎

作为PyTorch粉(虽然以前黑过它动态图慢),这次体验让我直呼“真香”。它的Eager模式写起来就像写NumPy,调试时print(tensor)直接出数值,再也不用.numpy() .detach()满天飞。

# PyTorch 2.0 实现
import torch
import torch.nn as nn

class MF(nn.Module):
    def __init__(self, n_users, n_items, emb_dim=64):
        super().__init__()
        self.user_emb = nn.Embedding(n_users, emb_dim)
        self.item_emb = nn.Embedding(n_items, emb_dim)

    def forward(self, user_ids, item_ids):
        return (self.user_emb(user_ids) * self.item_emb(item_ids)).sum(1)

model = MF(943, 1682)
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(10):
    pred = model(user_ids, item_ids)
    loss = nn.MSELoss()(pred, ratings)
    loss.backward()
    optimizer.step()

训练过程丝滑得像德芙巧克力。而且PyTorch 2.0的torch.compile()真的快!在我的MacBook Pro上,训练时间比TF快了18%。

但问题来了:怎么部署?

  • TorchScript?导出后模型28MB,但前端不支持。
  • ONNX?转换过程一堆op不兼容,报错Unsupported node type: aten::embedding
  • 最终方案:用TorchServe,但需要额外写handler,还要配YAML。运维小哥看到后直接翻白眼:“又来?”

不过好消息是,PyTorch现在和前端也能联动了——通过ONNX Runtime Web,可以把模型转成WebAssembly跑在浏览器里。虽然折腾,但至少可行。


3. JAX + Flax:极客玩具,还是未来之光?

说实话,一开始我是拒绝JAX的。函数式编程?纯函数?自动微分?听起来像学术派在自嗨。但当我看到它在TPU上的表现,以及超小的模型体积,我动摇了。

# JAX + Flax 实现
import jax
import jax.numpy as jnp
from flax import linen as nn

class MF(nn.Module):
    n_users: int
    n_items: int
    emb_dim: int = 64

    @nn.compact
    def __call__(self, user_ids, item_ids):
        user_emb = nn.Embed(self.n_users, self.emb_dim)(user_ids)
        item_emb = nn.Embed(self.n_items, self.emb_dim)(item_ids)
        return jnp.sum(user_emb * item_emb, axis=-1)

# 训练用纯函数 + jit
def loss_fn(params, user_ids, item_ids, targets):
    pred = MF(943, 1682).apply(params, user_ids, item_ids)
    return jnp.mean((pred - targets) ** 2)

grad_fn = jax.jit(jax.value_and_grad(loss_fn))

JAX的核心优势在于:一切皆函数。没有隐藏状态,没有副作用,调试时可以直接print(params)。而且jax.jit编译后速度飞起,训练比PyTorch还快5%。

最让我惊喜的是模型导出:直接序列化成msgpack或pickle,只有8MB! 因为JAX模型本质上就是参数字典+纯函数,根本不需要笨重的运行时。

但代价是:学习曲线陡峭。你得理解PRNGKeyvmappmap这些概念。第一次写的时候,我把随机种子搞错了,导致每次推理结果都不一样,差点以为模型疯了。


综合对比:不只是速度,更是工程体验

我整理了一张实战对比表,结合了训练、部署、前端集成和代码维护性:

维度 TensorFlow 2.x PyTorch 2.0 JAX + Flax
训练灵活性 中(图模式限制多) 高(Eager模式自由) 极高(纯函数+jit)
模型体积 32MB 28MB 8MB
前端集成 TF.js(直接支持) ONNX Runtime Web(需转换) 需自研推理引擎
调试体验 差(符号张量坑多) 优秀(像NumPy) 好(但需适应函数式)
代码可读性 中(Keras封装深) 高(直观) 中(需理解JAX范式)
团队上手成本 低(国内普及率高) 高(小众)

综合角度看:

  • 如果你有现成TF生态,且不介意模型体积,TensorFlow仍是稳妥选择
  • 如果追求开发效率和灵活性,PyTorch是大多数人的归宿
  • 如果你敢赌未来,且团队有极客精神,JAX在性能和轻量化上碾压对手

最终方案:PyTorch + ONNX + 前端微服务

经过三轮AB测试,我们最终选了PyTorch训练 → ONNX导出 → C++推理服务的混合方案:

  1. 用PyTorch训练模型(开发快)
  2. 导出为ONNX(虽然踩坑,但社区方案成熟)
  3. 用ONNX Runtime C++ API写一个轻量推理服务(启动<50ms,内存<50MB)
  4. 前端通过WebSocket长连接获取实时Embedding(避免频繁HTTP请求)

代码片段:

# PyTorch导出ONNX
torch.onnx.export(
    model,
    (dummy_user, dummy_item),
    "mf.onnx",
    input_names=["user_id", "item_id"],
    output_names=["score"],
    dynamic_axes={"user_id": {0: "batch"}, "item_id": {0: "batch"}}
)

上线后,P99延迟从1200ms降到180ms,前端同学终于没再@我骂街。


写在最后:工具只是工具,人才是核心

说实话,写这篇文章时我还在用Cursor(没错,就是那个AI编程工具)。不是我不努力,而是当你要同时看三个框架的文档、调五个环境、回十个群消息时,AI真的能救命。比如刚才那段JAX代码,我让Cursor帮我补全了loss_fn的梯度计算——省了我查Flax文档半小时。

但工具再强,也替代不了对业务的理解。深度学习框架没有银弹,关键看你用在什么场景。TensorFlow适合大厂稳态系统,PyTorch适合快速迭代,JAX适合追求极致性能的团队。

而我?作为一个曾经的前端,现在的“综合”工程师,只希望下次产品经理提需求时,能先问问:“这事儿,能用规则搞定吗?非得上模型?”

(完)

P.S. 上周双11大促,我们的推荐模块零故障。运维请我喝了杯瑞幸——他说:“这次没半夜打电话,真难得。” 我笑了笑,默默打开了Cursor,开始写下周的复盘报告……

评论 0

最热最新
暂无评论
匿名用户Lv.1
0
影响力
0
文章
0
粉丝