Cycle-GAN代码解读
1  model.py⽂件
1.1  初始化函数
import glob
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
# Weight initialisation helper.
def weights_init_normal(m):
    """Initialise the parameters of module ``m`` in place.

    - Any module whose class name contains "Conv" (Conv*d, ConvTranspose*d): weight ~ N(0, 0.02),
      bias set to 0 when present.
    - Any module whose class name contains "BatchNorm2d": weight ~ N(1, 0.02), bias set to 0
      (skipped when the layer was built with ``affine=False`` and has no weight/bias).
    - All other modules are left untouched.

    Args:
        m: a ``torch.nn.Module``; called once per module (e.g. via ``Module.apply``).
    """
    classname = m.__class__.__name__
    if "Conv" in classname:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        if hasattr(m, "bias") and m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)
    elif "BatchNorm2d" in classname:
        # An affine=False BatchNorm has weight/bias set to None; nothing to initialise then.
        if getattr(m, "weight", None) is not None:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        if getattr(m, "bias", None) is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)
1.2  ResNet 模块定义
class ResidualBlock(nn.Module):
    """Residual block with two reflection-padded 3x3 convolutions and an identity skip.

    Layout of ``self.block``: pad -> conv -> InstanceNorm -> ReLU -> pad -> conv -> InstanceNorm.
    The channel count and spatial size are preserved, so the block output can be added
    directly to its input.

    Args:
        in_features: number of input (and output) channels.
    """

    def __init__(self, in_features):
        super().__init__()

        def padded_conv_norm(width):
            # A reflection pad of 1 before the 3x3 conv keeps height and width unchanged.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(width, width, 3),
                nn.InstanceNorm2d(width),
            ]

        layers = padded_conv_norm(in_features)
        layers.append(nn.ReLU(inplace=True))
        layers.extend(padded_conv_norm(in_features))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x + block(x)``."""
        residual = self.block(x)
        return x + residual
从生成器中截取一个ResNet模块，其结构如下所示。
1.3  模型定义
生成器定义：模型一上来是3个“卷积块”，每个卷积块包含：一个2D卷积层、一个Instance Normalization层和一个ReLU。其中第一个卷积块是7×7卷积（步长为1），不改变图像大小；后两个卷积块是步长为2的3×3卷积，用来降采样，每次将宽高减半、通道数加倍。然后是9个“残差块”（数量由参数num_residual_blocks指定），每个残差块包含2个卷积层，每个卷积层后面都有一个Instance Normalization层，第一个Instance Normalization层后面是ReLU激活函数，残差块的输入通过残差连接与其输出相加。然后是2个“上采样块”，每个块包含一个2倍上采样层（nn.Upsample）和一个3×3的2D卷积层（本实现以此代替2D转置卷积），以及1个Instance Normalization层和1个ReLU激活函数。最后一层是一个7×7的2D卷积层，使用tanh作为激活函数，该层生成的形状为
(256,256,3)的图像（PyTorch中按通道在前的顺序表示为(3,256,256)）。这个Generator的输入和输出的大小是一模一样的，都是(256,256,3)。
class GeneratorResNet(nn.Module):
    """ResNet-based CycleGAN generator.

    Layout: 7x7 conv stem -> 2 stride-2 downsampling convs -> ``num_residual_blocks``
    residual blocks -> 2 upsampling stages (``nn.Upsample`` + 3x3 conv) -> 7x7 conv + tanh.
    For inputs whose height and width are multiples of 4, the output has the same shape
    as the input, with values in [-1, 1].

    Args:
        input_shape: ``(channels, height, width)``; only ``channels`` is read.
        num_residual_blocks: number of ``ResidualBlock`` modules in the bottleneck.
    """

    def __init__(self, input_shape, num_residual_blocks):
        super(GeneratorResNet, self).__init__()

        channels = input_shape[0]

        # Initial 7x7 convolution block. A reflection pad of 3 (= 7 // 2) keeps H and W unchanged.
        # (Padding by `channels` only worked by coincidence for 3-channel images.)
        out_features = 64
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(channels, out_features, 7),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

        # Downsampling: two stride-2 convs, each halves H and W and doubles the channels (64 -> 128 -> 256).
        for _ in range(2):
            out_features *= 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Residual blocks at the bottleneck resolution.
        for _ in range(num_residual_blocks):
            model += [ResidualBlock(out_features)]

        # Upsampling: Upsample doubles H and W, then a 3x3 conv halves the channels (256 -> 128 -> 64).
        for _ in range(2):
            out_features //= 2
            model += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Output layer: back to `channels` channels, tanh squashes values into [-1, 1].
        model += [nn.ReflectionPad2d(3), nn.Conv2d(out_features, channels, 7), nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Map a batch of images ``x`` to translated images of the same shape."""
        return self.model(x)
判别器定义：判别网络的架构类似于PatchGAN中的判别网络架构，是一个包含几个卷积块的深度卷积神经网络。它的输出不是单个标量，而是形状为(1, H/16, W/16)的得分图，每个元素对应输入图像中一个局部区域（patch）的真/假判断。
class Discriminator(nn.Module):
    """PatchGAN discriminator.

    Four stride-2 convolution blocks divide H and W by 16. A final 4x4 conv then maps the
    512-channel feature map to a single-channel grid of scores, one score per image patch.

    Args:
        input_shape: ``(channels, height, width)`` of the input images.
    """

    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        channels, height, width = input_shape

        # Output shape of the PatchGAN discriminator: (1, H / 16, W / 16).
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        def discriminator_block(in_filters, out_filters, normalize=True):
            """Return the layers of one downsampling block (4x4 conv, stride 2: halves H and W)."""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            # The first block is not normalised.
            *discriminator_block(channels, 64, normalize=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            # Pad left/top by 1 so that the final 4x4 conv with padding=1 keeps the H/16 x W/16 size.
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1),
        )

    def forward(self, img):
        """Return the patch score map, shape ``(N, *self.output_shape)`` when H and W are multiples of 16."""
        return self.model(img)
判别器的结构如下所示：
2  datasets.py文件
主要是ImageDataset类的操作：__init__将trainA和trainB中的图片路径读入files_A和files_B；__getitem__对两个文件夹的图片进行读取，若不是RGB图片则转换为RGB；__len__返回两个文件夹中图片数量的较大值。
import os
from torch.utils.data import Dataset
from PIL import Image
ansforms as transforms
# Convert an image to RGB.
def to_rgb(image):
    """Return a new RGB image with the contents of ``image``.

    A blank RGB canvas of the same size is created and ``image`` is pasted onto it.
    This relies on ``Image.paste`` converting the pasted image to the canvas mode
    (see the Pillow ``Image.paste`` documentation).
    """
    rgb_image = Image.new("RGB", image.size)
    rgb_image.paste(image)
    return rgb_image
# Dataset that reads images from the two (unpaired) domains A and B.
class ImageDataset(Dataset):
    """Two-domain image dataset for CycleGAN.

    Reads every file in ``<root>/<mode>A`` and ``<root>/<mode>B``. With the default
    ``mode="train"`` these are ``trainA`` and ``trainB``; with ``mode="test"`` they are
    ``testA`` and ``testB``.

    Args:
        root: dataset root directory.
        transforms_: list of torchvision transforms applied to both images. ``None`` means
            no transform, so the PIL images are returned as-is.
        unaligned: if True, image B is drawn at random rather than by index, so A/B pairs
            are not fixed.
        mode: split prefix of the domain folders.
    """

    def __init__(self, root, transforms_=None, unaligned=False, mode="train"):
        # Compose(None) would only fail later, inside __getitem__; an empty list acts as identity.
        self.transform = transforms.Compose(transforms_ if transforms_ is not None else [])
        self.unaligned = unaligned

        # The split prefix comes from `mode`. Hard-coding "trainA"/"trainB" made mode="test"
        # silently return training images.
        self.files_A = sorted(glob.glob(os.path.join(root, f"{mode}A") + "/*.*"))
        self.files_B = sorted(glob.glob(os.path.join(root, f"{mode}B") + "/*.*"))

    def __getitem__(self, index):
        """Return ``{"A": item_A, "B": item_B}`` for sample ``index``."""
        # Indices wrap around with %, because __len__ is the size of the larger domain.
        image_A = Image.open(self.files_A[index % len(self.files_A)])

        if self.unaligned:
            image_B = Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)])
        else:
            image_B = Image.open(self.files_B[index % len(self.files_B)])

        # Convert non-RGB (e.g. grayscale) images to RGB.
        if image_A.mode != "RGB":
            image_A = to_rgb(image_A)
        if image_B.mode != "RGB":
            image_B = to_rgb(image_B)

        item_A = self.transform(image_A)
        item_B = self.transform(image_B)
        return {"A": item_A, "B": item_B}

    def __len__(self):
        """Number of samples: the size of the larger of the two domains."""
        return max(len(self.files_A), len(self.files_B))
3  utils.py文件
主要关注学习率衰减(LambdaLR)。
import datetime
import sys
from torch.autograd import Variable
import torch
import numpy as np
from torchvision.utils import save_image
class ReplayBuffer:
    """Fixed-size pool of past samples that mixes stored and new samples in each returned batch.

    Args:
        max_size: maximum number of stored samples; must be positive.
    """

    def __init__(self, max_size=50):
        assert max_size > 0, "Empty buffer or trying to create a black hole. Be careful."
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Push each sample of ``data`` into the pool and return a batch of the same size.

        While the pool is not full, every sample is stored and returned unchanged. Once it is
        full, each sample has a 50% chance to replace a randomly chosen stored sample, in which
        case a clone of the evicted sample is returned instead. Otherwise the new sample is
        returned and the pool is left unchanged.

        Args:
            data: batch tensor; iterated along dim 0.

        Returns:
            Tensor with as many rows as ``data``, built from detached tensors.
        """
        to_return = []
        # `.data` detaches from the autograd graph, so the pool never holds onto graphs.
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            elif random.uniform(0, 1) > 0.5:
                i = random.randint(0, self.max_size - 1)
                to_return.append(self.data[i].clone())
                self.data[i] = element
            else:
                to_return.append(element)
        # `torch.autograd.Variable` is deprecated; a plain tensor is what it returned anyway.
        return torch.cat(to_return)
class LambdaLR:
    """Linear learning-rate decay schedule.

    ``step(epoch)`` returns a multiplier that stays at 1.0 until ``decay_start_epoch``
    (shifted by ``offset``), then falls linearly and reaches 0.0 at ``n_epochs``.
    Presumably used as the ``lr_lambda`` of ``torch.optim.lr_scheduler.LambdaLR`` (confirm
    at the call site).

    Args:
        n_epochs: epoch at which the multiplier reaches 0.
        offset: epoch offset added to the ``epoch`` passed to ``step`` (e.g. when resuming).
        decay_start_epoch: epoch at which the linear decay begins; must be < ``n_epochs``.
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert n_epochs > decay_start_epoch, "Decay must start before the training session ends!"
        self.decay_start_epoch = decay_start_epoch
        self.offset = offset
        self.n_epochs = n_epochs

    def step(self, epoch):
        """Return the learning-rate multiplier for ``epoch``."""
        decay_span = self.n_epochs - self.decay_start_epoch
        epochs_into_decay = max(0, epoch + self.offset - self.decay_start_epoch)
        return 1.0 - epochs_into_decay / decay_span
4  cyclegan.py文件
4.1  导入相关库以及进行参数设置
导入相关库