PyTorch Lightning解密:深度学习代码量减少60%
深度学习的小伙伴们应该都写过PyTorch代码,每次写训练循环都要搞一堆模板代码,烦不烦?今天给大家介绍个好东西 - PyTorch Lightning。这玩意儿简直是写深度学习代码的神器,不仅代码量能砍掉一大半,还能帮你自动处理各种训练细节。说真的,用了Lightning之后再也不想碰原生PyTorch了。
Lightning是个啥
Lightning说白了就是对PyTorch的一层封装,它把深度学习中常见的那些训练步骤都给标准化了。啥意思呢?就是你不用再写那些烦人的epoch循环、梯度更新、验证过程这些破事儿了,Lightning全给你包好了。
import pytorch_lightning as pl
class MyModel(pl.LightningModule):
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
return loss
看看这代码,干净利落吧?原来写PyTorch时那一大堆训练循环全没了。
模型定义超简单
在Lightning里面,咱们主要就是要继承个LightningModule
类,然后实现几个关键方法就完事了。
class MyModel(pl.LightningModule):
def __init__(self):
super().__init__()
self.layer = nn.Linear(28*28, 10)
def forward(self, x):
return self.layer(x)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.02)
温馨提示:configure_optimizers
这个方法可千万别忘了写,不然模型训练不了。这个坑我之前就踩过...
训练过程全自动
Lightning最爽的地方就是训练特别简单,你只要定义好每一步咋处理数据就行:
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('train_loss', loss)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
val_loss = F.cross_entropy(y_hat, y)
self.log('val_loss', val_loss)
啥?你说想看训练进度条?想保存模型checkpoint?想用多GPU训练?这些Lightning都给你整好了,一行代码的事:
trainer = pl.Trainer(max_epochs=10, gpus=2)
trainer.fit(model, train_loader, val_loader)
炫酷的callbacks
还记得以前写PyTorch时想加个新功能要改好几个地方吗?用Lightning直接加个callback就搞定:
class MyCallback(pl.Callback):
def on_train_epoch_end(self, trainer, pl_module):
print(f"哈哈,又训练完一个epoch啦!")
trainer = pl.Trainer(callbacks=[MyCallback()])
温馨提示:callback里面能加的钩子可多了,比如on_train_start
、on_validation_end
啥的,想在啥时候插入代码都行。
数据加载也很灵活
Lightning对数据加载也做了封装,但是保留了很大的灵活性:
class MyDataModule(pl.LightningDataModule):
def setup(self, stage=None):
self.train_data = MNIST(...)
self.val_data = MNIST(...)
def train_dataloader(self):
return DataLoader(self.train_data, batch_size=32)
def val_dataloader(self):
return DataLoader(self.val_data, batch_size=32)
说实话,Lightning真是省心省力的好东西。代码写起来超级整洁,功能还贼全。不过刚开始用的时候可能会有点不习惯,毕竟跟原生PyTorch的写法差挺多。但是习惯了之后,真的是爽到飞起。建议大家有空试试看,保证让你的代码量直接砍掉一大半!
来源: