avatar

Yiyu Qiu

Undergraduate Student
Shanghai Jiao Tong University
3500063778@sjtu.edu.cn


VAE代码实践与原理解析

· Yiyu Qiu VAE DDPM DDIM SMLD SDE/ODE PyTorch


引言

在上一篇博客中,我们详细梳理了变分自编码器(VAE)的理论基础。本篇博客将通过 PyTorch 实现一个简单的 VAE,并结合代码逐步解析其背后的数学原理,帮助大家更好地理解 VAE 的工作机制。


1. 数据准备

我们以经典的 MNIST 数据集为例,使用 PyTorch 的 DataLoader 加载数据:

# --- 数据准备 (MNIST) ---
transform = transforms.ToTensor()
train_loader = DataLoader(
    datasets.MNIST('./data', train=True, download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True
)

原理解析

MNIST 数据集包含 28x28 的灰度图像。我们将其展平为 784 维的向量,作为 VAE 的输入。VAE 的目标是学习这些图像的隐变量表示,并能够从隐变量中重建图像。

2. 模型定义

VAE 的核心由三个部分组成:编码器、重参数化层和解码器。

class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        # 编码器
        self.fc1 = nn.Linear(784, 400)
        self.fc2_mu = nn.Linear(400, latent_dim)
        self.fc2_logvar = nn.Linear(400, latent_dim)
        
        # 解码器
        self.fc3 = nn.Linear(latent_dim, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc2_mu(h), self.fc2_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

原理解析

编码器(网络前2层):将输入数据 $x$ 映射到隐变量的均值 $\mu$ 和对数方差 $\log \sigma^2$。

数学公式:$q_\phi(z\mid x) = \mathcal{N}(z; \mu_\phi(x), \text{diag}(\sigma_\phi^2(x)))$

代码实现:fc1 提取特征,fc2_mu 和 fc2_logvar 分别输出均值和对数方差。

重参数化层:通过重参数化技巧从隐变量分布中采样。

数学公式:$z = \mu + \epsilon \cdot \sigma$,其中 $\epsilon \sim \mathcal{N}(0, I)$

代码实现:reparameterize 函数。

解码器:将隐变量 $z$ 映射回原始数据空间。

代码实现:fc3 和 fc4 还原数据,并通过 sigmoid 将输出限制在 [0, 1](这里限制到[0,1]是针对MNSIT特别改动,因为MNSIT为黑白图片(即数据是二值的)经过totensor后会被限制到0~1,因此这里我们把$p_\theta(x\mid z)$看做了伯努利分布,因此用BCE替换了二次范数)。

3. 损失函数

def loss_function(recon_x, x, mu, logvar):
    # 1. 重构损失 (Binary Cross Entropy)
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    # 2. KL 散度公式
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

原理解析

重构损失:衡量重建图像与原始图像之间的差异。

数学公式:$\mathbb{E}{q\phi(z x)}[\log p_\theta(x z)]$

代码实现:使用二元交叉熵(BCE)计算像素级误差。

KL 散度:约束隐变量分布接近标准正态分布。

数学公式:$D_{KL}(q_\phi(z x) p(z))$

代码实现:根据 KL 散度的闭式解直接计算。

4. 训练循环

model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(epochs):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    
    print(f'Epoch {epoch+1}, Average Loss: {train_loss / len(train_loader.dataset):.4f}')

原理解析

  1. 前向传播:输入数据经过编码器、重参数化层和解码器,得到重建图像。
  2. 计算损失:调用 loss_function 计算重构损失和 KL 散度。
  3. 反向传播:通过 loss.backward() 计算梯度,并更新模型参数。

5. 生成新图像

训练完成后,我们可以在隐空间中随机采样,并通过解码器生成新图像:

def generate_digit(model):
    model.eval()
    with torch.no_grad():
        # 在隐空间随机采样一个向量
        sample = torch.randn(1, latent_dim).to(device)
        sample = model.decode(sample).cpu()
        plt.imshow(sample.view(28, 28), cmap='gray')
        plt.show()

generate_digit(model)

原理解析

隐空间采样:从标准正态分布中采样隐变量 $z$。

解码生成:通过解码器将 $z$ 映射回图像空间。


Powered by Jekyll and Minimal Light theme.