快速入门指南¶
这里我们设置一个简单的初学者示例,用于在 MNIST
图像数据集上进行训练。
安装¶
首先我们需要安装 vegans
。您可以通过以下方式之一进行安装:
pip install vegans
或者通过
git clone https://github.com/unit8co/vegans.git
cd vegans
pip install -e .
通过以下方式测试模块是否可以导入:
python -c "import vegans"
加载数据¶
这里我们将快速加载数据。专门的数据加载器适用于:
MNIST:MNISTLoader
FashionMNIST:FashionMNISTLoader
CelebA:CelebALoader
CIFAR10:Cifar10Loader
CIFAR100:Cifar100Loader
只有前两个数据集会自动下载。让我们使用 loading
模块加载 MNIST
数据:
import vegans.utils.loading as loading
loader = loading.MNISTLoader(root=None)
X_train, y_train, X_test, y_test = loader.load()
如果数据尚未存在,此操作会将其下载到 root
(默认路径为: {{ Home directory }}/.vegans) 目录中。mnist 数据集的每张图片形状为 (1, 32, 32)
,而标签形状为 [10, 1],这是原始标签的独热编码(one-hot encoded)版本。
现在我们可以开始定义网络了。
模型定义¶
需要定义哪种类型的网络取决于您使用的算法。主要有三种不同的选择:
- GAN1v1 需要
生成器
判别器
- GANGAE 需要
生成器
判别器
编码器
- VAE 需要
编码器
解码器
在本指南中,我们将使用属于第一类算法的 VanillaGAN
。
首先,我们需要确定所有网络的输入和输出维度。对于无监督 / 无条件的情况,这很简单:
- 生成器
输入:
z_dim
潜在维度(超参数)输出:
x_dim
图像维度
- 判别器
输入:
x_dim
图像维度输出:
1
单个输出节点(可能有所不同)
对于有监督 / 有条件算法,会稍微复杂一些:
- 条件生成器
输入:
z_dim + y_dim
潜在维度和标签维度输出:
x_dim
图像维度
- 条件判别器
输入:
x_dim + y_dim
图像维度和标签维度输出:
1
单个输出节点(可能有所不同)
我们可以通过以下方式获取这些维度:
x_dim = X_train.shape[1:]
y_dim = y_train.shape[1:]
z_dim = 64
gen_in_dim = vegans.utils.utils.get_input_dim(z_dim, y_dim)
adv_in_dim = vegans.utils.utils.get_input_dim(x_dim, y_dim)
定义生成器和判别器架构无疑是 GAN 训练中最重要(也是最困难)的部分。我们将使用以下架构:
class MyGenerator(nn.Module):
def __init__(self, gen_in_dim, x_dim):
super().__init__()
self.encoding = nn.Sequential(
nn.Conv2d(in_channels=nr_channels, out_channels=64, kernel_size=5, stride=2, padding=2),
nn.BatchNorm2d(num_features=64),
nn.LeakyReLU(0.2),
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=2, padding=2),
nn.BatchNorm2d(num_features=128),
nn.LeakyReLU(0.2),
nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(num_features=256),
nn.LeakyReLU(0.2),
)
self.decoding = nn.Sequential(
nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(num_features=128),
nn.LeakyReLU(0.2),
nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(num_features=64),
nn.LeakyReLU(0.2),
nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(num_features=32),
nn.LeakyReLU(0.2),
nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=3, stride=1, padding=1),
)
self.output = nn.Sigmoid()
def forward(self, x):
x = self.encoding(x)
x = self.decoding(x)
return self.output(x)
generator = MyGenerator(gen_in_dim=gen_in_dim, x_dim=x_dim)
几乎相同的架构可以再次从 loading
模块中一行代码加载,该模块会负责选择正确的输入维度。
generator = loader.load_generator(x_dim=x_dim, z_dim=z_dim, y_dim=y_dim)
discriminator = loader.load_adversary(x_dim=x_dim, y_dim=y_dim, adv_type="Discriminator")
gan_model = ConditionalVanillaGAN(
generator=generator, adversary=discriminator, x_dim=x_dim, z_dim=z_dim, y_dim=y_dim,
)
gan_model.summary()
模型训练¶
现在模型训练可以一行代码完成:
gan_model.fit(X_train=X_train, y_train=y_train)
此步骤中有相当多可选的超参数可供选择。请参阅下面的完整代码示例。GAN 的训练可能需要一段时间,具体取决于您的网络大小、训练样本数量和您的硬件。
模型评估¶
最后我们可以使用以下方式查看 GAN 的结果:
samples, losses = gan_model.get_training_results(by_epoch=False)
fixed_labels = np.argmax(gan_model.get_fixed_labels(), axis=1)
fig, axs = plot_images(images=samples, labels=fixed_labels, show=False)
plt.show()
从现在开始,您还可以通过提供标签作为输入来生成示例。
test_labels = np.eye(N=10)
test_samples = gan_model.generate(y=test_labels)
fig, axs = plot_images(images=test_samples, labels=np.argmax(test_labels, axis=1))
保存和加载模型¶
网络训练完成后,可以使用以下方式轻松保存:
gan_model.save("model.torch")
并在之后加载:
gan_model = VanillaGAN.load("model.torch")
或者
gan_model = torch.load("model.torch")
完整代码片段¶
这是之前代码的单个完整块:
import numpy as np
import vegans.utils.loading as loading
from vegans.utils.utils import plot_images
from vegans.GAN import ConditionalVanillaGAN
loader = loading.MNISTLoader(root=None)
X_train, y_train, X_test, y_test = loader.load()
x_dim = X_train.shape[1:]
y_dim = y_train.shape[1:]
z_dim = 64
generator = loader.load_generator(x_dim=x_dim, z_dim=z_dim, y_dim=y_dim)
discriminator = loader.load_adversary(x_dim=x_dim, y_dim=y_dim, adv_type="Discriminator")
gan_model = ConditionalVanillaGAN(
generator=generator, adversary=discriminator,
x_dim=x_dim, z_dim=z_dim, y_dim=y_dim,
optim=None, optim_kwargs=None, # Optional
feature_layer=None, # Optional
fixed_noise_size=32, # Optional
device=None, # Optional
ngpu=None, # Optional
folder=None, # Optional
secure=True # Optional
)
gan_model.summary()
gan_model.fit(
X_train=X_train,
y_train=y_train,
X_test=X_test, # Optional
y_test=y_test, # Optional
batch_size=32, # Optional
epochs=2, # Optional
steps=None, # Optional
print_every="0.2e", # Optional
save_model_every=None, # Optional
save_images_every=None, # Optional
save_losses_every=10, # Optional
enable_tensorboard=False # Optional
)
samples, losses = gan_model.get_training_results(by_epoch=False)
fixed_labels = np.argmax(gan_model.get_fixed_labels(), axis=1)
fig, axs = plot_images(images=samples, labels=fixed_labels)
test_labels = np.eye(N=10)
test_samples = gan_model.generate(y=test_labels)
fig, axs = plot_images(images=test_samples, labels=np.argmax(test_labels, axis=1))