PyTorch Lightning解密:深度学习代码量减少60%

深度学习的小伙伴们应该都写过PyTorch代码,每次写训练循环都要搞一堆模板代码,烦不烦?今天给大家介绍个好东西 - PyTorch Lightning。这玩意儿简直是写深度学习代码的神器,不仅代码量能砍掉一大半,还能帮你自动处理各种训练细节。说真的,用了Lightning之后再也不想碰原生PyTorch了。

Lightning是个啥

Lightning说白了就是对PyTorch的一层封装,它把深度学习中常见的那些训练步骤都给标准化了。啥意思呢?就是你不用再写那些烦人的epoch循环、梯度更新、验证过程这些破事儿了,Lightning全给你包好了。

import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

看看这代码,干净利落吧?原来写PyTorch时那一大堆训练循环全没了。

模型定义超简单

在Lightning里面,咱们主要就是要继承个LightningModule类,然后实现几个关键方法就完事了。

class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(28*28, 10)
    
    def forward(self, x):
        return self.layer(x)
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

温馨提示:configure_optimizers这个方法可千万别忘了写,不然模型训练不了。这个坑我之前就踩过...

训练过程全自动

Lightning最爽的地方就是训练特别简单,你只要定义好每一步咋处理数据就行:

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    loss = F.cross_entropy(y_hat, y)
    self.log('train_loss', loss)
    return loss

def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    val_loss = F.cross_entropy(y_hat, y)
    self.log('val_loss', val_loss)

啥?你说想看训练进度条?想保存模型checkpoint?想用多GPU训练?这些Lightning都给你整好了,一行代码的事:

trainer = pl.Trainer(max_epochs=10, gpus=2)
trainer.fit(model, train_loader, val_loader)

炫酷的callbacks

还记得以前写PyTorch时想加个新功能要改好几个地方吗?用Lightning直接加个callback就搞定:

class MyCallback(pl.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        print(f"哈哈,又训练完一个epoch啦!")

trainer = pl.Trainer(callbacks=[MyCallback()])

温馨提示:callback里面能加的钩子可多了,比如on_train_starton_validation_end啥的,想在啥时候插入代码都行。

数据加载也很灵活

Lightning对数据加载也做了封装,但是保留了很大的灵活性:

class MyDataModule(pl.LightningDataModule):
    def setup(self, stage=None):
        self.train_data = MNIST(...)
        self.val_data = MNIST(...)
    
    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size=32)
    
    def val_dataloader(self):
        return DataLoader(self.val_data, batch_size=32)

说实话,Lightning真是省心省力的好东西。代码写起来超级整洁,功能还贼全。不过刚开始用的时候可能会有点不习惯,毕竟跟原生PyTorch的写法差挺多。但是习惯了之后,真的是爽到飞起。建议大家有空试试看,保证让你的代码量直接砍掉一大半!

来源:久米 久米的小律法

THE END