深度学习框架实战对比踩坑记录

需求之外
2026-07-01 11:32
阅读 880

上个月刚入职新公司,工位还没捂热,leader就甩过来一个需求:给内部客服系统搭一套智能知识库。我一听就乐了——这不就是让我从Java后端老哥转型AI工程师吗?不过话说回来,在京东干了五年,618和双11的流量洪峰都扛过来了,还怕这点挑战?

但这次不一样,不是写CRUD,不是搞分布式事务,是真刀真枪地搞深度学习。之前在公司顶多写写推荐系统的规则引擎,模型训练这事儿,还真是头一回。

需求到底是个啥

简单说,就是把公司历年的客服工单、产品文档、FAQ整理成一个知识库,然后训练一个模型,让客服小姐姐能直接问问题,系统自动给出答案。听起来像是个RAG(检索增强生成)的活儿,但leader说预算有限,先不搞大模型,用传统的文本分类+检索方案先跑起来。

我拿到数据一看,大概有50万条工单记录,分布在200多个品类里。数据清洗完,训练集大概40万条,验证集5万,测试集5万。文本长度参差不齐,短的三五个字,长的能写小作文。

框架选型:纠结了三天

说实话,选框架这事儿比我想象的纠结多了。目前主流的深度学习框架就那么几个:PyTorch、TensorFlow、PaddlePaddle,还有一些新兴的比如JAX。我花了一整天时间调研,又花了一天写demo对比。

先说结论,直接上表格:

维度 PyTorch TensorFlow PaddlePaddle
动态图支持 原生支持,舒服 2.x之后也支持了 支持,但生态差一截
社区生态 GitHub上最活跃 老牌,文档全 百度自家生态,中文文档友好
部署能力 TorchScript/ONNX TF Serving很成熟 Paddle Inference还行
上手难度 中等,Pythonic 略繁琐,概念多 跟PyTorch很像
GPU利用率 略低,但差距不大
中文NLP支持 一般,靠HuggingFace 一般 很好,有ERNIE等预训练模型

最后我选了PyTorch。原因很简单:HuggingFace的Transformers库太香了,而且PyTorch的动态图调试起来真的爽,不用像TF 1.x那样搞一堆tf.Session()

不过这里要提一嘴,PaddlePaddle在中文NLP这块确实有优势,百度搞的ERNIE模型在中文任务上效果很顶。如果你们做的是纯中文场景,可以认真考虑一下。

上手Trae,效率直接起飞

说到写代码,不得不安利一下Trae这个AI编程助手。之前在公司用惯了JetBrains全家桶,来到新公司发现同事都在用Trae,说是字节跳动出的,免费还贼好用。

我一开始是不信的,AI写代码能有多靠谱?结果用了两天,真香了。

举个真实例子,我在写数据预处理的pipeline时,需要把原始工单文本做分词、去停用词、转token id。这种活儿说简单也简单,说繁琐也繁琐。我直接在Trae里输入:

帮我写一个Python函数,输入是原始文本列表,输出是token id列表。
要求:
1. 使用jieba分词
2. 去掉停用词(停用词表从stopwords.txt读取)
3. 使用BertTokenizer做tokenization
4. 支持batch处理
5. 加上详细的类型注解和docstring

Trae直接给我生成了大概80行代码,逻辑基本没问题,我稍微改改就能跑。以前这种活儿怎么也得写一两个小时,现在十分钟搞定。

不过也有翻车的时候。有一次让它帮我写一个自定义的DataLoader,它生成的代码里有个bug,在__getitem__里直接对list做了index操作,但没处理越界的情况。结果跑到第10000条数据的时候直接崩了,报错信息是IndexError: list index out of range。当时真的想砸电脑,debug了半小时才发现是AI生成的代码有问题。

所以说,AI辅助编程是好,但代码review不能省。特别是涉及到业务逻辑的地方,一定要自己过一遍。

Amazon Q帮我搞定了部署

模型训练完了,接下来是部署。这块我用的是Amazon Q,AWS出的AI助手。

为什么用Amazon Q?因为新公司的基础设施全在AWS上,leader说让我用ECS部署模型服务。我之前在京东用的是自研的容器平台,AWS这一套还真不太熟。

Amazon Q帮了大忙。我直接问它:

我要在AWS ECS上部署一个PyTorch模型服务,用FastAPI做接口。
请帮我生成Dockerfile和ECS Task Definition。
模型文件大概500MB,需要GPU支持。

它给我生成的Dockerfile长这样:

FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime

WORKDIR /app

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY ./model /app/model
COPY ./src /app/src

EXPOSE 8080

CMD ["python", "-m", "uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8080"]

ECS Task Definition也给了,包括GPU资源配置、日志设置、健康检查什么的,基本拿来就能用。

不过这里有个坑要提醒大家:PyTorch的镜像默认很大,加上模型文件,最后我的镜像搞到了8个G。ECS拉镜像的时候巨慢,启动一个任务要等好几分钟。后来我优化了一下,用multi-stage build,把最终镜像压到了3个G左右,启动时间从5分钟降到了1分钟。

模型训练的那些坑

说回模型训练本身。我用的方案是BERT-base做文本编码,然后接一个分类头。训练过程中踩了不少坑,分享几个印象深刻的。

第一个坑:OOM(Out of Memory)

刚开始训练的时候,batch size设的64,直接OOM了。我用的是一张A10G的卡,显存24G。BERT-base本身不算大,但512的sequence length加上64的batch size,显存直接爆了。

解决办法很经典:gradient accumulation。把batch size设成16,每4个step更新一次参数,等效batch size还是64。

# 关键配置
accumulation_steps = 4
effective_batch_size = 16 * accumulation_steps  # 64

for step, batch in enumerate(dataloader):
    loss = model(**batch).loss
    loss = loss / accumulation_steps
    loss.backward()
    
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

第二个坑:学习率调度

一开始我用的是固定学习率2e-5,训练了10个epoch,发现验证集loss在第5个epoch之后就开始反弹了,典型的过拟合。后来换成了linear schedule with warmup,效果好了很多。

from transformers import get_linear_schedule_with_warmup

num_training_steps = len(dataloader) * num_epochs
num_warmup_steps = int(num_training_steps * 0.1)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps
)

第三个坑:类别不平衡

200多个品类,有的品类有上万条数据,有的只有几十条。直接训的话,模型会偏向大类。我用了focal loss来解决这个问题:

import torch
import torch.nn as nn

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
    
    def forward(self, inputs, targets):
        BCE_loss = nn.functional.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        return F_loss.mean()

效果很明显,小类目的准确率从30%提升到了65%左右。

Antigravity?那是个彩蛋

说到Antigravity,这其实不是个框架,是个Python的彩蛋。在Python里输入import antigravity,会打开一个xkcd的漫画页面。

>>> import antigravity

这玩意儿当然不能帮你训练模型,但在团队分享的时候提一嘴,能活跃一下气氛。毕竟搞AI的,不能太严肃,得有点极客精神。

不过说正经的,我后来发现Antigravity这个名字被一个开源项目用了,是个轻量级的神经网络可视化库,可以画模型结构图和训练曲线。虽然功能不如TensorBoard强大,但胜在简单好用,适合快速debug。

# 简单示例
import antigravity as ag

model = MyModel()
ag.plot_model(model, to_file='model.png')
ag.plot_training_history(history, to_file='history.png')

最终效果

折腾了差不多三周,知识库系统终于上线了。来看下效果:

指标 数值
训练集准确率 92.3%
验证集准确率 87.6%
测试集准确率 86.9%
推理延迟(P99) 120ms
模型大小 420MB
日均请求量 约5万次

客服小姐姐反馈说,系统给出的答案准确率大概有八成左右,比之前翻文档快多了。leader也比较满意,说下个季度可以试试上RAG方案,接入大模型。

一些心得

最后总结几点心得,给同样在搞深度学习的后端同学参考:

  1. 框架选择别纠结太久。PyTorch和TensorFlow都能干活,选一个深入学就行。我选PyTorch主要是因为HuggingFace生态好,如果你做CV,可能TensorFlow更合适。

  2. AI辅助编程要用,但不能全信。Trae和Amazon Q确实能提效,但生成的代码一定要review。特别是涉及到业务逻辑和边界条件的地方,AI容易想当然。

  3. 数据比模型重要。我花了大概一周时间做数据清洗和标注,比训练模型的时间还长。但事实证明,干净的数据比调参重要得多。

  4. 部署要提前规划。模型训练完了再想怎么部署,往往会发现很多坑。建议一开始就想好部署方案,包括镜像大小、GPU资源、扩缩容策略等。

  5. 别怕踩坑。在京东扛了五年流量洪峰,什么线上事故没见过?搞AI也一样,OOM、梯度爆炸、过拟合,都是正常的。踩坑的过程就是学习的过程。

好了,今天就聊到这里。新公司的试用期还有四个月,得继续加油了。如果有关于深度学习框架选型或者部署方面的问题,欢迎在评论区交流。

哦对了,下次双11要是再让我值班,我可能会哭。但这次的知识库项目,确实让我学到了很多新东西。从后端到AI,这条路还很长,但至少迈出了第一步。

共勉。

评论 0

最热最新
暂无评论
需求之外Lv.1
0
影响力
0
文章
0
粉丝