特斯拉AI总监的MNIST训练之旅

1. 引言

今天我们来重点介绍一下特斯拉AI总监的一篇博客《Deep Neural Nets: 33 years ago and 33 years from now》,这篇文章深入浅出地介绍了DNN最近三十年来的发展和趋势。

恩,闲话少说,我们直接开始吧!

2. 实验内容

Andrej Karpathy 主要复现了深度学习开山之作LeNet,该模型主要用于手写字符识别。他尝试利用这33年来人类一些新的改进的trick,来提升模型的效果。基于 lecun1989-repro 进行相应的实验内容如下:

  1. baseline:
eval: split train. loss 4.073383e-03. error 0.62%. misses: 45
eval: split test . loss 2.838382e-02. error 4.09%. misses: 82
  1. 原文的MSE loss换成如今多分类最为常用的损失函数Cross Entropy Loss
eval: split train. loss 9.536698e-06. error 0.00%. misses: 0
eval: split test . loss 9.536698e-06. error 4.38%. misses: 87
  1. 结果并没有提升,怀疑原始代码中的SGD优化器不够给力,于是切换成了最新的AdamW优化器,并使用“大家都知道”的最优初始学习率3e-4,还加了点weight decay,得到结果如下:
eval: split train. loss 0.000000e+00. error 0.00%. misses: 0
eval: split test . loss 0.000000e+00. error 3.59%. misses: 72
  1. 初步尝到了甜头,但仔细看评价结果,可以发现train/test的差别仍很大,提示这可能是过拟合的现象。于是决定稍微添加一些数据增强。
eval: split train. loss 8.780676e-04. error 1.70%. misses: 123
eval: split test . loss 8.780676e-04. error 2.19%. misses: 43
  1. 感觉还是有一些过拟合,遂增加dropout,并把tanh激活函数换成了ReLU。
eval: split train. loss 2.601336e-03. error 1.47%. misses: 106
eval: split test . loss 2.601336e-03. error 1.59%. misses: 32

通过上面一步一步地改进,Andrej Karpathy 总监成功把33年前经典分类问题的错误率又降低了60%!这几步虽然很常见,但也体现了总监扎实的基本功。总监还是不满意,对推理结果进行了可视化,得到模型的错例如下:
在这里插入图片描述

虽然他又尝试了一些例如Vision Transformer之类更新潮酷炫的东西,但比较遗憾的是模型都没有再涨点了。最后从本质出发,通过可视化模型的错例,对其增加了相应的数据,这使得错误率进一步降低,达到了1.25%。

eval: split train. loss 3.238392e-04. error 1.07%. misses: 31
eval: split test . loss 3.238392e-04. error 1.25%. misses: 24

观察上面的错例其实大家也能感受到有些错误模型应该是可搞对的,此时增加相应的数据确实是一个好办法。但更重要的是,希望大家也能养成把模型推理结果可视化出来审视的好习惯!

3. 实验总结

接着,Andrej Karpathy 对上述实验,进行了相应的总结:

  1. 做的事情本质并没有改变,还是可微分的神经网络、基于梯度优化那一套理论

  2. 当时的数据集规模好小啊,MNIST只有7000多张,如今比如CLIP训练图片有400百万张,而且每张图的分辨率都大得多

  3. 受限于训练资源,当时网络好小啊

  4. 当时的训练好慢啊,7000多张图+这么小的网络要跑3天,现在使用总监的Macbook可以90s训练完成

    扫描二维码关注公众号,回复: 15561782 查看本文章
  5. 针对该问题还是有进步的,可以用现在的技巧使错误率下降60%

  6. 纯增大数据集效果不大,还得配上各种训练技巧才能驾驭

  7. 再往前走得靠大模型了,就得大算力

4. 总结

尽管手写字符识别问题在如今看来已经是很成熟的算法啦,但是本文Andrej Karpathy 总监通过利用最新的深度学习策略,使得模型的错误率又下降了60%,同时也展现了其扎实的深度学习基础知识,希望大家可以都像其一样优秀。

嗯嗯,您学废了嘛?

猜你喜欢

转载自blog.csdn.net/sgzqc/article/details/129914961