快速入门指南

这里我们设置一个简单的初学者示例,用于在 MNIST 图像数据集上进行训练。

安装

首先我们需要安装 vegans。您可以通过以下方式之一进行安装:

pip install vegans

或者通过

git clone https://github.com/unit8co/vegans.git
cd vegans
pip install -e .

通过以下方式测试模块是否可以导入:

python -c "import vegans"

加载数据

这里我们将快速加载数据。专门的数据加载器适用于:

  • MNIST:MNISTLoader

  • FashionMNIST:FashionMNISTLoader

  • CelebA:CelebALoader

  • CIFAR10:Cifar10Loader

  • CIFAR100:Cifar100Loader

只有前两个数据集会自动下载。让我们使用 loading 模块加载 MNIST 数据:

import vegans.utils.loading as loading
loader = loading.MNISTLoader(root=None)
X_train, y_train, X_test, y_test = loader.load()

如果数据尚未存在,此操作会将其下载到 root (默认路径为: {{ Home directory }}/.vegans) 目录中。mnist 数据集的每张图片形状为 (1, 32, 32),而标签形状为 [10, 1],这是原始标签的独热编码(one-hot encoded)版本。

现在我们可以开始定义网络了。

模型定义

需要定义哪种类型的网络取决于您使用的算法。主要有三种不同的选择:

  1. GAN1v1 需要
    • 生成器

    • 判别器

  2. GANGAE 需要
    • 生成器

    • 判别器

    • 编码器

  3. VAE 需要
    • 编码器

    • 解码器

在本指南中,我们将使用属于第一类算法的 VanillaGAN

首先,我们需要确定所有网络的输入和输出维度。对于无监督 / 无条件的情况,这很简单:

  • 生成器
    • 输入:z_dim 潜在维度(超参数)

    • 输出:x_dim 图像维度

  • 判别器
    • 输入:x_dim 图像维度

    • 输出:1 单个输出节点(可能有所不同)

对于有监督 / 有条件算法,会稍微复杂一些:

  • 条件生成器
    • 输入:z_dim + y_dim 潜在维度和标签维度

    • 输出:x_dim 图像维度

  • 条件判别器
    • 输入:x_dim + y_dim 图像维度和标签维度

    • 输出:1 单个输出节点(可能有所不同)

我们可以通过以下方式获取这些维度:

x_dim = X_train.shape[1:]
y_dim = y_train.shape[1:]
z_dim = 64

gen_in_dim = vegans.utils.utils.get_input_dim(z_dim, y_dim)
adv_in_dim = vegans.utils.utils.get_input_dim(x_dim, y_dim)

定义生成器和判别器架构无疑是 GAN 训练中最重要(也是最困难)的部分。我们将使用以下架构:

class MyGenerator(nn.Module):
    def __init__(self, gen_in_dim, x_dim):
        super().__init__()

        self.encoding = nn.Sequential(
            nn.Conv2d(in_channels=nr_channels, out_channels=64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(num_features=64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(num_features=128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=256),
            nn.LeakyReLU(0.2),
        )
        self.decoding = nn.Sequential(
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=32),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=3, stride=1, padding=1),
        )
        self.output = nn.Sigmoid()

    def forward(self, x):
        x = self.encoding(x)
        x = self.decoding(x)
        return self.output(x)

generator = MyGenerator(gen_in_dim=gen_in_dim, x_dim=x_dim)

几乎相同的架构可以再次从 loading 模块中一行代码加载,该模块会负责选择正确的输入维度。

generator = loader.load_generator(x_dim=x_dim, z_dim=z_dim, y_dim=y_dim)
discriminator = loader.load_adversary(x_dim=x_dim, y_dim=y_dim, adv_type="Discriminator")

gan_model = ConditionalVanillaGAN(
    generator=generator, adversary=discriminator, x_dim=x_dim, z_dim=z_dim, y_dim=y_dim,
)
gan_model.summary()

模型训练

现在模型训练可以一行代码完成:

gan_model.fit(X_train=X_train, y_train=y_train)

此步骤中有相当多可选的超参数可供选择。请参阅下面的完整代码示例。GAN 的训练可能需要一段时间,具体取决于您的网络大小、训练样本数量和您的硬件。

模型评估

最后我们可以使用以下方式查看 GAN 的结果:

samples, losses = gan_model.get_training_results(by_epoch=False)

fixed_labels = np.argmax(gan_model.get_fixed_labels(), axis=1)
fig, axs = plot_images(images=samples, labels=fixed_labels, show=False)
plt.show()

从现在开始,您还可以通过提供标签作为输入来生成示例。

test_labels = np.eye(N=10)
test_samples = gan_model.generate(y=test_labels)
fig, axs = plot_images(images=test_samples, labels=np.argmax(test_labels, axis=1))

保存和加载模型

网络训练完成后,可以使用以下方式轻松保存:

gan_model.save("model.torch")

并在之后加载:

gan_model = VanillaGAN.load("model.torch")

或者

gan_model = torch.load("model.torch")

完整代码片段

这是之前代码的单个完整块:

import numpy as np
import vegans.utils.loading as loading
from vegans.utils.utils import plot_images
from vegans.GAN import ConditionalVanillaGAN

loader = loading.MNISTLoader(root=None)
X_train, y_train, X_test, y_test = loader.load()

x_dim = X_train.shape[1:]
y_dim = y_train.shape[1:]
z_dim = 64

generator = loader.load_generator(x_dim=x_dim, z_dim=z_dim, y_dim=y_dim)
discriminator = loader.load_adversary(x_dim=x_dim, y_dim=y_dim, adv_type="Discriminator")

gan_model = ConditionalVanillaGAN(
    generator=generator, adversary=discriminator,
    x_dim=x_dim, z_dim=z_dim, y_dim=y_dim,
    optim=None, optim_kwargs=None,                # Optional
    feature_layer=None,                           # Optional
    fixed_noise_size=32,                          # Optional
    device=None,                                  # Optional
    ngpu=None,                                    # Optional
    folder=None,                                  # Optional
    secure=True                                   # Optional
)

gan_model.summary()
gan_model.fit(
    X_train=X_train,
    y_train=y_train,
    X_test=X_test,           # Optional
    y_test=y_test,           # Optional
    batch_size=32,           # Optional
    epochs=2,                # Optional
    steps=None,              # Optional
    print_every="0.2e",      # Optional
    save_model_every=None,   # Optional
    save_images_every=None,  # Optional
    save_losses_every=10,    # Optional
    enable_tensorboard=False # Optional
)
samples, losses = gan_model.get_training_results(by_epoch=False)

fixed_labels = np.argmax(gan_model.get_fixed_labels(), axis=1)
fig, axs = plot_images(images=samples, labels=fixed_labels)

test_labels = np.eye(N=10)
test_samples = gan_model.generate(y=test_labels)
fig, axs = plot_images(images=test_samples, labels=np.argmax(test_labels, axis=1))