引言
在上一篇博客中,我们详细梳理了变分自编码器(VAE)的理论基础。本篇博客将通过 PyTorch 实现一个简单的 VAE,并结合代码逐步解析其背后的数学原理,帮助大家更好地理解 VAE 的工作机制。
1. 数据准备
我们以经典的 MNIST 数据集为例,使用 PyTorch 的 DataLoader 加载数据:
# --- 数据准备 (MNIST) ---
transform = transforms.ToTensor()
train_loader = DataLoader(
datasets.MNIST('./data', train=True, download=True, transform=transform),
batch_size=batch_size,
shuffle=True
)
原理解析
MNIST 数据集包含 28x28 的灰度图像。我们将其展平为 784 维的向量,作为 VAE 的输入。VAE 的目标是学习这些图像的隐变量表示,并能够从隐变量中重建图像。
2. 模型定义
VAE 的核心由三个部分组成:编码器、重参数化层和解码器。
class VAE(nn.Module):
def __init__(self):
super(VAE, self).__init__()
# 编码器
self.fc1 = nn.Linear(784, 400)
self.fc2_mu = nn.Linear(400, latent_dim)
self.fc2_logvar = nn.Linear(400, latent_dim)
# 解码器
self.fc3 = nn.Linear(latent_dim, 400)
self.fc4 = nn.Linear(400, 784)
def encode(self, x):
h = F.relu(self.fc1(x))
return self.fc2_mu(h), self.fc2_logvar(h)
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z):
h = F.relu(self.fc3(z))
return torch.sigmoid(self.fc4(h))
def forward(self, x):
mu, logvar = self.encode(x.view(-1, 784))
z = self.reparameterize(mu, logvar)
return self.decode(z), mu, logvar
原理解析
编码器(网络前2层):将输入数据 $x$ 映射到隐变量的均值 $\mu$ 和对数方差 $\log \sigma^2$。
数学公式:$q_\phi(z\mid x) = \mathcal{N}(z; \mu_\phi(x), \text{diag}(\sigma_\phi^2(x)))$
代码实现:fc1 提取特征,fc2_mu 和 fc2_logvar 分别输出均值和对数方差。
重参数化层:通过重参数化技巧从隐变量分布中采样。
数学公式:$z = \mu + \epsilon \cdot \sigma$,其中 $\epsilon \sim \mathcal{N}(0, I)$
代码实现:reparameterize 函数。
解码器:将隐变量 $z$ 映射回原始数据空间。
代码实现:fc3 和 fc4 还原数据,并通过 sigmoid 将输出限制在 [0, 1](这里限制到[0,1]是针对MNSIT特别改动,因为MNSIT为黑白图片(即数据是二值的)经过totensor后会被限制到0~1,因此这里我们把$p_\theta(x\mid z)$看做了伯努利分布,因此用BCE替换了二次范数)。
3. 损失函数
def loss_function(recon_x, x, mu, logvar):
# 1. 重构损失 (Binary Cross Entropy)
BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
# 2. KL 散度公式
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + KLD
原理解析
重构损失:衡量重建图像与原始图像之间的差异。
| 数学公式:$\mathbb{E}{q\phi(z | x)}[\log p_\theta(x | z)]$ |
代码实现:使用二元交叉熵(BCE)计算像素级误差。
KL 散度:约束隐变量分布接近标准正态分布。
| 数学公式:$D_{KL}(q_\phi(z | x) | p(z))$ |
代码实现:根据 KL 散度的闭式解直接计算。
4. 训练循环
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(epochs):
model.train()
train_loss = 0
for batch_idx, (data, _) in enumerate(train_loader):
data = data.to(device)
optimizer.zero_grad()
recon_batch, mu, logvar = model(data)
loss = loss_function(recon_batch, data, mu, logvar)
loss.backward()
train_loss += loss.item()
optimizer.step()
print(f'Epoch {epoch+1}, Average Loss: {train_loss / len(train_loader.dataset):.4f}')
原理解析
- 前向传播:输入数据经过编码器、重参数化层和解码器,得到重建图像。
- 计算损失:调用 loss_function 计算重构损失和 KL 散度。
- 反向传播:通过 loss.backward() 计算梯度,并更新模型参数。
5. 生成新图像
训练完成后,我们可以在隐空间中随机采样,并通过解码器生成新图像:
def generate_digit(model):
model.eval()
with torch.no_grad():
# 在隐空间随机采样一个向量
sample = torch.randn(1, latent_dim).to(device)
sample = model.decode(sample).cpu()
plt.imshow(sample.view(28, 28), cmap='gray')
plt.show()
generate_digit(model)
原理解析
隐空间采样:从标准正态分布中采样隐变量 $z$。
解码生成:通过解码器将 $z$ 映射回图像空间。