这样分析,我读懂了生成对抗网络(GAN)
当下优秀的图像生成模型中,大家有没有发现人物的生成,特别是人的脸部特征几乎是无可挑剔,在很多的AI生成图片中,模型生成的人脸与真实照片对比几乎分辨不出来,如此真实的生成图片,它的背后就离不开强大的生成对抗网络(GAN)的支持。本篇将为大家揭开生成对抗网络的面纱,聊聊它是如何工作的。
01
什么是生成对抗网络
生成对抗网络(Generative Adversarial Networks,简称GAN)是一种深度学习模型。它通过两个神经网络——生成器(Generator)和判别器(Discriminator)——之间的对抗训练来生成新的、与训练数据类似的数据。
1. GAN的工作原理
GAN的核心思想是让生成器和判别器进行博弈:
- 生成器(G):接收随机噪声作为输入,生成模拟的样本。生成器的目标是生成足够逼真的数据,以“欺骗”判别器。
- 判别器(D):接收输入数据(可能是真实样本或生成样本),输出该样本为真实样本的概率。判别器的目标是准确区分真实样本和生成样本。
2. 训练对抗过程:
(1)生成器不断改进生成的数据,试图让判别器将其误判为真实数据。
(2)判别器则不断优化,提高对真实数据和生成数据的区分能力。 这种对抗过程可以类比为一个“博弈游戏”,最终达到纳什均衡,生成器能够生成高质量的假数据,就如同图像生成模型中生成的假人脸。
3. 举例说明:
我们来举一个生动的例子,就如下图所示,我们先安排几个角色:
- 盗贼 - 生成器(Generator)
- 警察 - 判别器(Discriminator)
- 真钞 - 原始数据
- 假钞 - 生成数据
盗贼从银行正常获取到钞票后,开始制作假钞,开始时制假技术低劣,很容易就被警察识别到了,但盗贼不断地提高技术,让假钞越变越真,警察就越来越难分辨了,然后警察也不断提高识别假钞的能力,直到盗贼的制假技术不断提高,到了警察都无法分辨时,那盗贼就完全胜利了,从此生成的钞票就可以弄假成真,无法分辨了。这个博弈的过程需要很多个回合,直到最后生成可以以假乱真的数据。以上这个例子也为我们揭示了图像生成模型的博弈原理。
02
用Pytorch实现GAN
下面是一个使用 PyTorch 实现生成对抗网络的简单示例。我们将通过一个简单的例子来演示它的基本工作原理:生成器(Generator)生成逼真的数据,判别器(Discriminator)区分真假数据。
1. 示例说明:生成手写数字图像
我们将使用 MNIST 数据集,这是一个包含手写数字图像的经典数据集。生成器将生成看起来像手写数字的图像,判别器将尝试区分真实图像和生成图像。
2. 代码实现
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 超参数
batch_size = 64
learning_rate = 0.0002
num_epochs = 5
nz = 100 # 噪声向量的维度
# 数据加载与预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# 定义生成器
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(nz, 256),
nn.ReLU(True),
nn.Linear(256, 512),
nn.ReLU(True),
nn.Linear(512, 1024),
nn.ReLU(True),
nn.Linear(1024, 28 * 28),
nn.Tanh() # 输出范围为 [-1, 1]
)
def forward(self, x):
return self.model(x).view(-1, 1, 28, 28)
# 定义判别器
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(28 * 28, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid() # 输出范围为 [0, 1]
)
def forward(self, x):
x = x.view(-1, 28 * 28)
return self.model(x)
# 初始化模型
generator = Generator().to(device)
discriminator = Discriminator().to(device)
# 定义损失函数和优化器
criterion = nn.BCELoss()
g_optimizer = optim.Adam(generator.parameters(), lr=learning_rate)
d_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate)
# 训练过程
for epoch in range(num_epochs):
for i, (images, _) in enumerate(train_loader):
# 创建标签
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)
# 训练判别器
# 真实图像
images = images.to(device)
outputs = discriminator(images)
d_loss_real = criterion(outputs, real_labels)
real_score = outputs
# 生成图像
noise = torch.randn(batch_size, nz).to(device)
fake_images = generator(noise)
outputs = discriminator(fake_images.detach()) # 使用detach()避免计算生成器的梯度
d_loss_fake = criterion(outputs, fake_labels)
fake_score = outputs
# 反向传播和优化
d_loss = d_loss_real + d_loss_fake
d_optimizer.zero_grad()
g_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()
# 训练生成器
fake_images = generator(noise)
outputs = discriminator(fake_images)
g_loss = criterion(outputs, real_labels)
# 反向传播和优化
d_optimizer.zero_grad()
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()
if (i + 1) % 300 == 0:
print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], '
f'd_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}, '
f'D(x): {real_score.mean().item():.2f}, D(G(z)): {fake_score.mean().item():.2f}')
# 保存生成的图像
fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
grid = make_grid(fake_images[:25], nrow=5)
plt.figure(figsize=(5, 5))
plt.imshow(grid.permute(1, 2, 0).cpu().detach().numpy())
plt.axis('off')
plt.show()
# 保存模型
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')
3. 代码解释
(1)数据加载与预处理:
- 使用 `torchvision.datasets.MNIST` 加载 MNIST 数据集,并进行归一化处理。
- 使用 `DataLoader` 将数据加载到训练过程中。
(2)定义生成器和判别器:
- 生成器:将随机噪声向量转换为 28x28 的图像。
- 判别器:将输入图像分类为真实图像或生成图像。
(3)训练过程:
- 训练判别器:
使用真实图像和生成图像分别计算损失。
优化判别器,使其能够更好地区分真假图像。
- 训练生成器:
生成图像并让判别器判断。
优化生成器,使其生成的图像更逼真。
(4)可视化:
每个 epoch 结束时,生成并显示一些生成的图像,观察生成器的进展。
(5)保存模型:
训练完成后,保存生成器和判别器的模型参数。
4. 输出结果
运行代码后,你会看到生成器生成的图像逐渐变得逼真。最终,生成器能够生成看起来像手写数字的图像。
03
总结
生成对抗网络(GAN)就像是一个生成器(造假者)和判别器(警察)之间的博弈游戏。通过不断地对抗和学习,造假者(生成器)会越来越擅长制造逼真的数据,而警察(判别器)会越来越擅长鉴别真假。这种对抗训练的方式让生成器能够生成高质量的数据,广泛应用于图像生成、视频生成、艺术创作等领域。
来源:码农随心笔记