这样分析,我读懂了生成对抗网络(GAN)

当下优秀的图像生成模型中,大家有没有发现人物的生成,特别是人的脸部特征几乎是无可挑剔,在很多的AI生成图片中,模型生成的人脸与真实照片对比几乎分辨不出来,如此真实的生成图片,它的背后就离不开强大的生成对抗网络(GAN)的支持。本篇将为大家揭开生成对抗网络的面纱,聊聊它是如何工作的。

01

什么是生成对抗网络

生成对抗网络(Generative Adversarial Networks,简称GAN)是一种深度学习模型。它通过两个神经网络——生成器(Generator)和判别器(Discriminator)——之间的对抗训练来生成新的、与训练数据类似的数据。

图片

1. GAN的工作原理

GAN的核心思想是让生成器和判别器进行博弈:

  • 生成器(G):接收随机噪声作为输入,生成模拟的样本。生成器的目标是生成足够逼真的数据,以“欺骗”判别器。
  • 判别器(D):接收输入数据(可能是真实样本或生成样本),输出该样本为真实样本的概率。判别器的目标是准确区分真实样本和生成样本。

2. 训练对抗过程:

(1)生成器不断改进生成的数据,试图让判别器将其误判为真实数据。

(2)判别器则不断优化,提高对真实数据和生成数据的区分能力。 这种对抗过程可以类比为一个“博弈游戏”,最终达到纳什均衡,生成器能够生成高质量的假数据,就如同图像生成模型中生成的假人脸。

3. 举例说明:

我们来举一个生动的例子,就如下图所示,我们先安排几个角色:

图片

  • 盗贼 - 生成器(Generator)
  • 警察 - 判别器(Discriminator)
  • 真钞 - 原始数据
  • 假钞 - 生成数据

盗贼从银行正常获取到钞票后,开始制作假钞,开始时制假技术低劣,很容易就被警察识别到了,但盗贼不断地提高技术,让假钞越变越真,警察就越来越难分辨了,然后警察也不断提高识别假钞的能力,直到盗贼的制假技术不断提高,到了警察都无法分辨时,那盗贼就完全胜利了,从此生成的钞票就可以弄假成真,无法分辨了。这个博弈的过程需要很多个回合,直到最后生成可以以假乱真的数据。以上这个例子也为我们揭示了图像生成模型的博弈原理。

02

用Pytorch实现GAN

下面是一个使用 PyTorch 实现生成对抗网络的简单示例。我们将通过一个简单的例子来演示它的基本工作原理:生成器(Generator)生成逼真的数据,判别器(Discriminator)区分真假数据。

1. 示例说明:生成手写数字图像

我们将使用 MNIST 数据集,这是一个包含手写数字图像的经典数据集。生成器将生成看起来像手写数字的图像,判别器将尝试区分真实图像和生成图像。

2. 代码实现

import torchimport torch.nn as nnimport torch.optim as optimfrom torchvision import datasets, transformsfrom torchvision.utils import make_gridimport matplotlib.pyplot as plt
# 设置设备device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 超参数batch_size = 64learning_rate = 0.0002num_epochs = 5nz = 100  # 噪声向量的维度
# 数据加载与预处理transform = transforms.Compose([    transforms.ToTensor(),    transforms.Normalize((0.5,), (0.5,))])train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# 定义生成器class Generator(nn.Module):    def __init__(self):        super(Generator, self).__init__()        self.model = nn.Sequential(            nn.Linear(nz, 256),            nn.ReLU(True),            nn.Linear(256512),            nn.ReLU(True),            nn.Linear(5121024),            nn.ReLU(True),            nn.Linear(102428 * 28),            nn.Tanh()  # 输出范围为 [-11]        )    def forward(self, x):        return self.model(x).view(-112828)
# 定义判别器class Discriminator(nn.Module):    def __init__(self):        super(Discriminator, self).__init__()        self.model = nn.Sequential(            nn.Linear(28 * 28512),            nn.LeakyReLU(0.2),            nn.Linear(512256),            nn.LeakyReLU(0.2),            nn.Linear(2561),            nn.Sigmoid()  # 输出范围为 [01]        )    def forward(self, x):        x = x.view(-128 * 28)        return self.model(x)
# 初始化模型generator = Generator().to(device)discriminator = Discriminator().to(device)
# 定义损失函数和优化器criterion = nn.BCELoss()g_optimizer = optim.Adam(generator.parameters(), lr=learning_rate)d_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate)
# 训练过程for epoch in range(num_epochs):    for i, (images, _) in enumerate(train_loader):                # 创建标签        real_labels = torch.ones(batch_size, 1).to(device)        fake_labels = torch.zeros(batch_size, 1).to(device)                # 训练判别器        # 真实图像        images = images.to(device)        outputs = discriminator(images)        d_loss_real = criterion(outputs, real_labels)        real_score = outputs                # 生成图像        noise = torch.randn(batch_size, nz).to(device)        fake_images = generator(noise)        outputs = discriminator(fake_images.detach())  # 使用detach()避免计算生成器的梯度        d_loss_fake = criterion(outputs, fake_labels)        fake_score = outputs                # 反向传播和优化        d_loss = d_loss_real + d_loss_fake        d_optimizer.zero_grad()        g_optimizer.zero_grad()        d_loss.backward()        d_optimizer.step()                # 训练生成器        fake_images = generator(noise)        outputs = discriminator(fake_images)        g_loss = criterion(outputs, real_labels)                # 反向传播和优化        d_optimizer.zero_grad()        g_optimizer.zero_grad()        g_loss.backward()        g_optimizer.step()                if (i + 1) % 300 == 0:            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], '                  f'd_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}, '                  f'D(x): {real_score.mean().item():.2f}, D(G(z)): {fake_score.mean().item():.2f}')        # 保存生成的图像    fake_images = fake_images.reshape(fake_images.size(0), 12828)    grid = make_grid(fake_images[:25], nrow=5)    plt.figure(figsize=(55))    plt.imshow(grid.permute(120).cpu().detach().numpy())    plt.axis('off')    plt.show()
# 保存模型torch.save(generator.state_dict(), 'generator.pth')torch.save(discriminator.state_dict(), 'discriminator.pth')

3. 代码解释

(1)数据加载与预处理:

  • 使用 `torchvision.datasets.MNIST` 加载 MNIST 数据集,并进行归一化处理。
  • 使用 `DataLoader` 将数据加载到训练过程中。

(2)定义生成器和判别器:

  • 生成器:将随机噪声向量转换为 28x28 的图像。
  • 判别器:将输入图像分类为真实图像或生成图像。

(3)训练过程:

  • 训练判别器:

使用真实图像和生成图像分别计算损失。

优化判别器,使其能够更好地区分真假图像。

  • 训练生成器:

生成图像并让判别器判断。

优化生成器,使其生成的图像更逼真。

(4)可视化:

每个 epoch 结束时,生成并显示一些生成的图像,观察生成器的进展。

(5)保存模型:

训练完成后,保存生成器和判别器的模型参数。

4.  输出结果

运行代码后,你会看到生成器生成的图像逐渐变得逼真。最终,生成器能够生成看起来像手写数字的图像。

03

总结

图片

生成对抗网络(GAN)就像是一个生成器(造假者)和判别器(警察)之间的博弈游戏。通过不断地对抗和学习,造假者(生成器)会越来越擅长制造逼真的数据,而警察(判别器)会越来越擅长鉴别真假。这种对抗训练的方式让生成器能够生成高质量的数据,广泛应用于图像生成、视频生成、艺术创作等领域。

来源:码农随心笔记

THE END