BERT重计算:用22.5%的训练时间节省5倍的显存开销(附代码)

论坛 期权论坛     
选择匿名的用户   2021-5-31 09:53   1293   0
<div id="js_content">
<p><img src="https://beijingoptbbs.oss-cn-beijing.aliyuncs.com/cs/5606289-b0a938c2f8bb0fb992195eb546b47f5f"></p>
<p>一只小狐狸带你解锁 <strong>炼丹术&amp;<strong>NLP</strong></strong><strong> </strong>秘籍</p>
<h2>作者:夕小瑶、rumor酱</h2>
<h1>前言</h1>
<p>虽然TPU的显存令人羡慕,但是由于众所周知的原因,绝大部分人还是很难日常化使用的。英伟达又一直在挤牙膏,至今单卡的最大显存也仅仅到32G(参考V100、DGX-2)。然而,训练一个24层的BERT Large模型的时候,如果sequence length开满512,那么batch size仅仅开到8(有时候能到10)就把这寥寥32G的显存打满了。如果想训练一个48层乃至100层的BERT Large,那完全是土豪们的游戏了,需要疯狂的模型并行&#43;分布式多机训练。<br></p>
<p>但!是!万能的小夕前不久在Daxiang Dong大佬的安利下,发现了&#64;陈天奇 大佬2016年的一篇宝藏paper!</p>
<img src="https://beijingoptbbs.oss-cn-beijing.aliyuncs.com/cs/5606289-22e88058d0f629ccfc616520ac495cea">
<p>简单的划一下重点:</p>
<p>这篇paper用时间换空间的思想,<strong>在前向时只保存部分中间节点,在反向时重新计算没保存的部分</strong>。论文通过这种机制,在每个batch只多计算一次前向的情况下,把n层网络的占用显存优化到了
    <svg style="vertical-align: -0.566ex;width: 6.774ex;height: 2.473ex;" viewbox="0 -843 2994 1093">
     <g fill="currentColor" stroke="currentColor" stroke-width="0" transform="matrix(1 0 0 -1 0 0)">
      <g>
       <g>
        <path d="M740 435Q740 320 676 213T511 42T304 -22Q207 -22 138 35T51 201Q50 209 50 244Q50 346 98 438T227 601Q351 704 476 704Q514 704 524 703Q621 689 680 617T740 435ZM637 476Q637 565 591 615T476 665Q396 665 322 605Q242 542 200 428T157 216Q157 126 200 73T314 19Q404 19 485 98T608 313Q637 408 637 476Z"></path>
       </g>
       <g transform="translate(763, 0)">
        <path d="M94 250Q94 319 104 381T127 488T164 576T202 643T244 695T277 729T302 750H315H319Q333 750 333 741Q333 738 316 720T275 667T226 581T184 443T167 250T184 58T225 -81T274 -167T316 -220T333 -241Q333 -250 318 -250H315H302L274 -226Q180 -141 137 -14T94 250Z"></path>
       </g>
       <g transform="translate(1152, 0)">
        <g transform="translate(853, 0)">
         <g>
          <path d="M21 287Q22 293 24 303T36 341T56 388T89 425T135 442Q171 442 195 424T225 390T231 369Q231 367 232 367L243 378Q304 442 382 442Q436 442 469 415T503 336T465 179T427 52Q427 26 444 26Q450 26 453 27Q482 32 505 65T540 145Q542 153 560 153Q580 153 580 145Q580 144 576 130Q568 101 554 73T508 17T439 -10Q392 -10 371 17T350 73Q350 92 386 193T423 345Q423 404 379 404H374Q288 404 229 303L222 291L189 157Q156 26 151 16Q138 -11 108 -11Q95 -11 87 -5T76 7T74 17Q74 30 112 180T152 343Q153 348 153 366Q153 405 129 405Q91 405 66 305Q60 285 60 284Q58 278 41 278H27Q21 284 21 287Z"></path>
         </g>
        </g>
        <g transform="translate(0, -17)">
         <path d="M95 178Q89 178 81 186T72 200T103 230T169 280T207 309Q209 311 212 311H213Q219 311 227 294T281 177Q300 134 312 108L397 -77Q398 -77 501 136T707 565T814 786Q820 800 834 800Q841 800 846 794T853 782V776L620 293L385 -193Q381 -200 366 -200Q357 -200 354 -197Q352 -195 256 15L160 225L144 214Q129 202 113 190T95 178Z"></path>
        </g>
       </g>
       <g transform="translate(2605, 0)">
        <path d="M60 749L64 750Q69 750 74 750H86L114 726Q208 641 251 514T294 250Q294 182 284 119T261 12T224 -76T186 -143T145 -194T113 -227T90 -246Q87 -249 86 -250H74Q66 -250 63 -250T58 -247T55 -238Q56 -237 66 -225Q221 -64 221 250T66 725Q56 737 55 738Q55 746 60 749Z"></path>
       </g>
      </g>
     </g>
    </svg>。在极端情况下,仍可用
    <svg style="vertical-align: -0.566ex;width: 9.052ex;height: 2.262ex;" viewbox="0 -750 4001 1000">
     <g fill="currentColor" stroke="currentColor" stroke-width="0" transform="matrix(1 0 0 -1 0 0)">
      <g>
       <g>
        <path d="M740 435Q740 320 676 213T511 42T304 -22Q207 -22 138 35T51 201Q50 209 50 244Q50 346 98 438T227 601Q351 704 476 704Q514 704 524 703Q621 689 680 617T740 435ZM637 476Q637 565 591 615T476 665Q396 665 322 605Q242 542 200 428T157 216Q157 126 200 73T314 19Q404 19 485 98T608 313Q637 408 637 476Z"></path>
       </g>
       <g transform="translate(763, 0)">
        <path d="M94 250Q94 319 104 381T127 488T164 576T202 643T244 695T277 729T302 750H315H319Q333 750 333 741Q333 738 316 720T275 667T226 581T184 443T167 250T184 58T225 -81T274 -167T316 -220T333 -241Q333 -250 318 -250H315H302L274 -226Q180 -141 137 -14T94 250Z"></path>
       </g>
       <g transform="translate(1152, 0)">
        <path d="M21 287Q22 293 24 303T36 341T56 388T89 425T135 442Q171 442 195 424T225 390T231 369Q231 367 232 367L243 378Q304 442 382 442Q436 442 469 415T503 336T465 179T427 52Q427 26 444 26Q450 26 453 27Q482 32 505 65T540 145Q542 153 560 153Q580 153 580 145Q580 144 576 130Q568 101 554 73T508 17T439 -10Q392 -10 371 17T350 73Q350 92 386 193T423 345Q423 404 379 404H374Q288 404 229 303L222 291L189 157Q156 26 151 16Q138 -11 108 -11Q95 -11 87 -5T76 7T74 17Q74 30 112 180T152 343Q153 348 153 366Q153 405 129 405Q91 405 66 305Q60 285 60 284Q58 278 41 278H27Q21 284 21 287Z"></path>
       </g>
       <g transform="translate(1752, 0)">
        <path d="M117
分享到 :
0 人收藏
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

积分:3875789
帖子:775174
精华:0
期权论坛 期权论坛
发布
内容

下载期权论坛手机APP