从哄睡到GPU显存爆满:一位全职妈妈的PyTorch初体验

徐娟
2025-12-29 09:22
阅读 851

凌晨2点47分,娃终于睡了。我蹑手蹑脚爬回电脑前,屏幕还亮着——PyTorch训练脚本卡在第3个epoch,显存又爆了。这时候我已经连续三天每天只睡4小时,白天陪玩、做饭、收拾“战场”,晚上等娃睡了才偷偷摸摸写代码。别笑,这大概就是“代码人生”的真实写照:左手奶瓶,右手梯度下降。

说来你可能不信,我上一份正经工作是在一家做Spring Boot微服务的电商公司干前端,整天和React、动画交互打交道。但去年跳槽时,面试官突然甩出一道题:“如果让你用深度学习优化商品推荐系统的点击率,你会怎么做?”我当时就懵了——我连TensorFlow都没跑通过!更离谱的是,那家公司居然还在搞“区块链+AI”的概念项目(是的,2024年还有人在炒这个)。虽然最后没去成,但那道面试题像根刺扎在我心里。

于是,趁着产假尾巴,我决定啃下PyTorch这块硬骨头。不为别的,就为了下次面试能挺直腰板说:“姐不仅会写CSS动画,还会调参!”

为什么是PyTorch?而不是TensorFlow?

说实话,一开始我差点被TensorFlow劝退。安装依赖像在拼俄罗斯套娃,文档写得跟天书似的。直到我在GitHub上看到一个PyTorch实现GAN生成卡通头像的项目——代码干净得像刚洗过的白衬衫,注释比我妈唠叨还详细。那一刻我知道,就是它了。

而且我们组里最近也在悄悄试水AI。上周五站会上,后端小哥神秘兮兮地说:“老板想用深度学习预测用户流失,Spring Boot服务要接模型推理接口。”产品经理立刻眼睛发亮:“能不能结合区块链存证用户行为数据?”……我当场翻了个白眼——区块链存证?用户点个“不喜欢”还要上链?但吐槽归吐槽,活儿得干啊。

第一行代码:从“Hello World”到“Hello Tensor”

装环境永远是第一道坎。conda、pip、CUDA版本……我折腾了整整一个周末,娃在旁边把积木搭了拆、拆了搭,仿佛在嘲讽我的无能。最终靠Docker才搞定:

# 这是我能跑通的最简配置(泪目)
docker run --gpus all -it --rm -v $(pwd):/workspace pytorch/pytorch:2.1.0-cuda11.8-devel

进入容器后,第一行代码必须致敬经典:

import torch
print("Hello, PyTorch! 我娃刚吐奶了,但我还能写代码!")

输出正常!那一刻的喜悦,堪比看到娃第一次翻身。

PyTorch最让我上头的是它的动态计算图。作为前端出身的人,我习惯“所见即所得”——写一行代码,马上能看到结果。PyTorch的Eager Mode完美契合这种思维。比如创建一个张量:

# 像NumPy一样自然
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = torch.rand(2, 3)
z = x + y  # 立刻执行!不用sess.run()

对比之下,TensorFlow 1.x那种先建图再执行的方式,简直反人类。难怪现在大家都说:“PyTorch赢在用户体验。”

实战:用MNIST手写数字识别入门

光说不练假把式。我选了经典的MNIST数据集——毕竟带娃的手经常抖,说不定哪天自己也能写出“手写体”数据集(手动狗头)。

数据加载:DataLoader救我狗命

以前在前端处理大量DOM节点时,总会用虚拟滚动优化性能。PyTorch的DataLoader简直就是后端版的虚拟滚动:

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)

这里num_workers=4让我踩了个大坑:在Mac上跑没问题,但部署到Linux服务器时直接报错BrokenPipeError。后来才知道是多进程共享内存的问题——赶紧改成num_workers=0,世界清净了。运维同事看我的眼神仿佛在说:“前端转AI,果然不行。”

模型搭建:比写React组件还爽

定义神经网络就像搭乐高。我用nn.Sequential快速组装了一个简单CNN:

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

注意到没?forward方法就是数据流的声明式描述——这不就是React的render函数吗?只不过输入是tensor,输出也是tensor。瞬间亲切感拉满!

训练循环:从OOM到收敛

前面提到显存爆炸的问题。最初我把batch_size设成256,结果GPU直接罢工:

RuntimeError: CUDA out of memory. Tried to allocate 256.00 MiB...

我一度怀疑是不是娃半夜偷用了我的GPU挖矿(毕竟他最近迷上了“挖宝藏”游戏)。后来查资料发现,PyTorch有个神器叫torch.cuda.empty_cache(),但治标不治本。真正解决方案是:

  1. 减小batch size:从256降到64
  2. 用混合精度训练torch.cuda.amp.autocast()
  3. 及时释放中间变量:避免在for循环里累积tensor

调整后的训练循环长这样:

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for epoch in range(10):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        
        with autocast():  # 自动混合精度
            output = model(data)
            loss = F.nll_loss(output, target)
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        if batch_idx % 100 == 0:
            print(f'Epoch {epoch}, Loss: {loss.item():.6f}')
            # 显存监控:nvidia-smi显示稳定在4GB以下!

当准确率突破98%时,我激动得差点把娃吵醒——这可是我亲手调出来的第一个模型啊!

PyTorch vs 其他框架:一张表说清楚

作为技术分享会常客,我总被问到框架选型问题。这里总结下实战感受:

维度 PyTorch TensorFlow JAX
上手难度 ⭐⭐⭐⭐⭐(动态图真香) ⭐⭐⭐(TF2.0改善很多) ⭐⭐(函数式编程门槛高)
调试体验 ⭐⭐⭐⭐⭐(像普通Python) ⭐⭐⭐(需tf.debugging) ⭐⭐(不可变性反直觉)
生产部署 ⭐⭐⭐(TorchScript稍弱) ⭐⭐⭐⭐⭐(TF Serving成熟) ⭐(生态不完善)
社区资源 ⭐⭐⭐⭐(论文复现首选) ⭐⭐⭐⭐ ⭐⭐
移动端支持 ⭐⭐(PyTorch Mobile一般) ⭐⭐⭐⭐(TF Lite强大)

注:评分基于个人2024年实测,不代表官方立场

有意思的是,我们组最近有个项目要用Spring Boot提供AI服务。后端同学坚持用TensorFlow SavedModel,因为TF Serving和K8s集成更顺。但我偷偷用PyTorch导出ONNX模型,再通过ONNX Runtime部署——效果居然更好!看来框架之争,终究要看场景。

那些年踩过的坑:血泪教训

1. 设备不一致:CPU/GPU混用陷阱

有次本地测试好好的,推到服务器就报错:

Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

原因是我加载预训练权重时没指定设备:

# 错误示范
model.load_state_dict(torch.load('model.pth'))

# 正确姿势
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(torch.load('model.pth', map_location=device))

从此以后,我写代码必加device参数,比检查娃尿布还勤快。

2. 随机种子玄学

模型训练结果忽高忽低,搞得我以为自己代码有bug。后来才知道PyTorch默认不固定随机种子。现在我的脚本开头必加这段:

def seed_everything(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)

seed_everything()

虽然牺牲了点性能(cudnn.benchmark=False),但至少结果可复现——这点对面试吹牛特别重要!

3. 内存泄漏:闭包的锅

有次训练时显存缓慢增长,最后OOM。排查半天发现是回调函数里不小心引用了模型:

# 危险代码!
def on_epoch_end():
    print(model.state_dict())  # 引用导致无法释放

# 安全做法
def on_epoch_end(weights):
    print(weights)  # 只传需要的数据

这让我想起前端里常见的闭包内存泄漏——技术债不分前后端啊!

结语:代码人生,不止一种活法

回看这两个月,PyTorch带给我的不仅是技术提升。当我在技术分享会上演示手写数字识别模型时,台下有人问:“你是怎么平衡带娃和学习的?”我笑着说:“哪有什么平衡,不过是碎片时间榨干罢了。”

其实深度学习没那么玄乎。就像前端从jQuery到React的演进,PyTorch代表的是一种更直观、更人性化的编程哲学。至于那些“区块链+AI”的PPT项目?随它去吧。我更关心明天娃会不会发烧,以及我的模型能不能在deadline前跑通。

最后送大家一句我贴在显示器边的话:“每个epoch都是新的开始,就像每个清晨娃的笑脸。

哦对了,如果你也在带娃写代码,欢迎留言交流——说不定我们能组个“深夜码农妈妈互助群”,一起对抗OOM和夜奶!

评论 0

最热最新
暂无评论
匿名用户Lv.1
0
影响力
0
文章
0
粉丝