diffusion-model/main.ipynb

952 lines
2.0 MiB
Plaintext
Raw Normal View History

2024-05-09 13:42:22 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Network Helper\n",
"\n"
]
},
{
"cell_type": "code",
2024-05-11 23:59:23 +02:00
"execution_count": 2,
2024-05-09 13:42:22 +02:00
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"import inspect\n",
"import torch\n",
2024-05-11 18:46:56 +02:00
"import math\n",
"from einops import rearrange\n",
2024-05-11 23:59:23 +02:00
"from torch import einsum\n",
"import torch.nn.functional as F"
2024-05-09 13:42:22 +02:00
]
},
{
"cell_type": "code",
2024-05-11 23:59:23 +02:00
"execution_count": 3,
2024-05-09 13:42:22 +02:00
"metadata": {},
"outputs": [],
"source": [
"def exists(x):\n",
" return x is not None\n",
"\n",
"def default(val, d):\n",
" if exists(val):\n",
" return val\n",
" return d() if inspect.isfunction(d) else d\n",
"\n",
"class Residual(nn.Module):\n",
" def __init__(self, fn):\n",
" super().__init__()\n",
" self.fn = fn\n",
"\n",
" def forward(self, x, *args, **kwargs):\n",
" return self.fn(x, *args, **kwargs) + x\n",
"\n",
"# 上采样(反卷积)\n",
"def Upsample(dim):\n",
" return nn.ConvTranspose2d(dim, dim, 4, 2, 1)\n",
"\n",
"def Downsample(dim):\n",
" return nn.Conv2d(dim, dim, 4, 2 ,1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Positional embedding\n",
"\n",
"目的是让网络知道\n",
"当前是哪一个step. \n",
"ddpm采用正弦位置编码\n",
"\n",
"输入是shape为(batch_size, 1)的tensor, batch中每一个sample所处的t,并且将这个tensor转换为shape为(batch_size, dim)的tensor.\n",
"这个tensor会被加到每一个残差模块中.\n",
"\n",
"总之就是将$t$编码为embedding,和原本的输入一起进入网络,让网络“知道”当前的输入属于哪个step"
]
},
{
"cell_type": "code",
2024-05-11 23:59:23 +02:00
"execution_count": 4,
2024-05-09 13:42:22 +02:00
"metadata": {},
"outputs": [],
"source": [
"class SinusolidalPositionEmbedding(nn.Module):\n",
" def __init__(self, dim):\n",
" super().__init__()\n",
" self.dim = dim\n",
"\n",
2024-05-11 23:59:23 +02:00
"\n",
2024-05-09 13:42:22 +02:00
" def forward(self, time):\n",
" device = time.device\n",
" half_dim = self.dim // 2\n",
2024-05-11 23:59:23 +02:00
" embeddings = math.log(10000) / (half_dim - 1)# 1/10000^(2i/d) = 1/10000^(i/(d/2)) = 10000^(-i/(d/2)\n",
2024-05-09 13:42:22 +02:00
" embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)\n",
" embeddings = time[:, :, None] * embeddings[None, None, :]\n",
" embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)\n",
" return embeddings\n",
" "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ResNet/ConvNeXT block"
]
},
{
"cell_type": "code",
2024-05-11 23:59:23 +02:00
"execution_count": 5,
2024-05-09 13:42:22 +02:00
"metadata": {},
"outputs": [],
"source": [
2024-05-11 18:46:56 +02:00
"# Block在init的时候创建了一个卷积层一个归一化层一个激活函数\n",
"# 前向的过程包括,首先卷积,再归一化,最后激活\n",
2024-05-09 13:42:22 +02:00
"class Block(nn.Module):\n",
" def __init__(self, dim, dim_out, groups = 8):\n",
" super().__init__()\n",
2024-05-11 18:46:56 +02:00
" self.proj = nn.Conv2d(dim, dim_out, 3, padding=1) # 提取特征\n",
" self.norm = nn.GroupNorm(groups, dim_out) # 归一化, 使得网络训练更快速,调整和缩放神经网络中间层来实现梯度的稳定\n",
2024-05-09 13:42:22 +02:00
" self.act = nn.SiLU()\n",
" \n",
" def forward(self, x, scale_shift = None):\n",
" x = self.proj(x)\n",
" x = self.norm(x)\n",
"\n",
" if exists(scale_shift):\n",
" scale, shift = scale_shift\n",
" x = x * (scale + 1) + shift\n",
"\n",
" x = self.act(x)\n",
" return x\n",
"\n",
"class ResnetBlock(nn.Module):\n",
" def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):\n",
" super().__init__()\n",
" self.mlp = (\n",
" nn.Sequential(\n",
" nn.SiLU(), \n",
" nn.Linear(time_emb_dim, dim_out)\n",
" )\n",
" if exists(time_emb_dim) else None\n",
" )\n",
2024-05-11 18:46:56 +02:00
" # 第一个块的作用是将输入的特征图的维度从dim变成dim_out\n",
2024-05-09 13:42:22 +02:00
" self.block1 = Block(dim, dim_out, groups=groups)\n",
" self.block2 = Block(dim_out, dim_out=dim_out, groups=groups)\n",
" self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()\n",
" \n",
" def forward(self, x, time_emb = None):\n",
" h = self.block1(x)\n",
"\n",
" if exists(self.mlp) and exists(time_emb):\n",
2024-05-11 18:46:56 +02:00
" # 时间序列送到多层感知机里\n",
2024-05-09 13:42:22 +02:00
" time_emb = self.mlp(time_emb)\n",
2024-05-11 18:46:56 +02:00
" # 为了让时间序列的维度和特征图的维度一致,所以需要增加一个维度\n",
2024-05-09 13:42:22 +02:00
" h = rearrange(time_emb, 'b n -> b () n') + h\n",
"\n",
" h = self.block2(h)\n",
" return h + self.res_conv(x)\n",
" \n",
" "
]
2024-05-11 18:46:56 +02:00
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
2024-05-11 23:59:23 +02:00
"execution_count": 6,
2024-05-11 18:46:56 +02:00
"metadata": {},
"outputs": [],
"source": [
"class ConvNextBlock(nn.Module):\n",
" def __init__(self, dim, dim_out, *, time_emb_dim = None, mult = 2, norm = True):\n",
" super().__init()\n",
" self.mlp = (\n",
" nn.Sequential(\n",
" nn.GELU(),\n",
" nn.Linear(time_emb_dim, dim)\n",
" )\n",
" if exists(time_emb_dim) else None \n",
" )\n",
"\n",
" self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)\n",
"\n",
" self.net = nn.Sequential(\n",
" nn.GroupNorm(1, dim) if norm else nn.Identity(),\n",
" nn.Conv2d(dim, dim_out * mult, 3, padding=1),\n",
" nn.GELU(),\n",
" nn.GroupNorm(1, dim_out * mult),\n",
" nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),\n",
" )\n",
" self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()\n",
"\n",
" def forward(self, x, time_emb=None):\n",
" h = self.ds_conv(x)\n",
"\n",
" if exists(self.mlp) and exists(time_emb):\n",
" condition = self.mlp(time_emb)\n",
" h = rearrange(time_emb, 'b c -> b c 1 1') + h\n",
" h = self.net(h)\n",
" return h + self.res_conv(x)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Attention module"
]
},
{
"cell_type": "code",
2024-05-11 23:59:23 +02:00
"execution_count": 7,
2024-05-11 18:46:56 +02:00
"metadata": {},
"outputs": [],
"source": [
"class Attention(nn.Module):\n",
" def __init__(self, dim, heads=4, dim_head=32):\n",
" super().__init__()\n",
" self.scale = dim_head ** -0.5\n",
" self.heads = heads\n",
" hidden_dim = dim_head * dim # 计算隐藏层的维度\n",
" self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) # 通过卷积层将输入的特征图的维度变成hidden_dim * 3\n",
" self.to_out = nn.Conv2d(hidden_dim, dim, 1)\n",
" \n",
" def forward(self, x):\n",
" b, c, h ,w = x.shape\n",
" qkv = self.to_qkv(x).chunk(3, dim=1)\n",
" q, k, v= map(\n",
" lambda t: rearrange(t, \"b (h c) x y -> b h c (x y)\", h=self.heads),\n",
" qkv\n",
" )\n",
" q = q * self.scale\n",
"\n",
" sim = einsum(\"b h d i, b h d j -> b h i j\", q, k)\n",
" sim = sim - sim.amax(dim=-1, keepdim=True).detach()\n",
" attn = sim.softmax(dim=-1)\n",
"\n",
" out = einsum(\"b h i j, b h d j -> b h i d\", attn, v)\n",
" out = rearrange(out, \"b h (x y) d -> b (h d) x y\", x=h, y=w)\n",
" return self.to_out(out)\n"
]
},
{
"cell_type": "code",
2024-05-11 23:59:23 +02:00
"execution_count": 8,
2024-05-11 18:46:56 +02:00
"metadata": {},
"outputs": [],
"source": [
"class LinearAttention(nn.Module):\n",
" def __init__(self, dim, heads=4, dim_head=32):\n",
" super().__init__()\n",
" self.scale = dim_head ** -0.5\n",
" self.heads = heads\n",
" hidden_dim = dim_head * heads\n",
" self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)\n",
" self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1),\n",
" nn.GroupNorm(1, dim))\n",
"\n",
" def forward(self, x):\n",
" b, c, h, w = x.shape\n",
" qkv = self.to_qkv(x).chunk(3, dim=1)\n",
" q, k, v = map(\n",
" lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), \n",
" qkv\n",
" )\n",
"\n",
" q = q.softmax(dim=-2)\n",
" k = k.softmax(dim=-1)\n",
"\n",
" q = q * self.scale\n",
" context = torch.einsum(\"b h d n, b h e n -> b h d e\", k, v)\n",
"\n",
" out = torch.einsum(\"b h d e, b h d n -> b h e n\", context, q)\n",
" out = rearrange(out, \"b h c (x y) -> b (h c) x y\", h=self.heads, x=h, y=w)\n",
" return self.to_out(out)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Group Normalization"
]
},
{
"cell_type": "code",
2024-05-11 23:59:23 +02:00
"execution_count": 9,
2024-05-11 18:46:56 +02:00
"metadata": {},
"outputs": [],
"source": [
"class PreNorm(nn.Module):\n",
" def __init__(self, dim, fn):\n",
" super().__init__()\n",
" self.fn = fn\n",
" self.norm = nn.GroupNorm(1, dim)\n",
" \n",
" def forward(self, x):\n",
" x = self.norm(x)\n",
" return self.fn(x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conditional U-Net"
]
},
2024-05-11 23:59:23 +02:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. 首先输入通过一个卷积层同时计算step所对应得embedding\n",
"2. 通过一系列的下采样stage每个stage都包含2个ResNet/ConvNeXT blocks + groupnorm + attention + residual connection + downsample operation \n",
"3. 在网络中间应用一个带attention的ResNet或者ConvNeXT\n",
"4. 通过一系列的上采样stage每个stage都包含2个ResNet/ConvNeXT blocks + groupnorm + attention + residual connection + upsample operation\n",
"5. 最终通过一个ResNet/ConvNeXT blocl和一个卷积层。"
]
},
2024-05-11 18:46:56 +02:00
{
"cell_type": "code",
2024-05-11 23:59:23 +02:00
"execution_count": 10,
2024-05-11 18:46:56 +02:00
"metadata": {},
"outputs": [],
"source": [
2024-05-11 23:59:23 +02:00
"from functools import partial\n",
2024-05-11 18:46:56 +02:00
"class Unet(nn.Module):\n",
2024-05-11 23:59:23 +02:00
" def __init__(\n",
" self,\n",
" dim, # 这里的dim是特征图的维度\n",
" init_dim = None,\n",
" out_dim = None,\n",
" dim_mults=(1, 2, 4, 8),\n",
" channels = 3,\n",
" with_time_emb = True,\n",
" resnet_block_groups = 8,\n",
" use_convnext = True,\n",
" convnext_mult = 2\n",
" ):\n",
" super().__init__()\n",
" self.channels = channels\n",
"\n",
" init_dim = default(init_dim, dim // 3 * 2) # 初始化维度 为 dim // 3 * 2 /dim // 3 * 2可能是经验得来的\n",
" self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3) # 这里的kernel_size是7padding是3这样的设置是为了保持特征图的大小不变\n",
"\n",
" dims = [init_dim, *map(lambda m: dim * m, dim_mults)] # dim * m是为了增加特征图的维度\n",
" in_out = list(zip(dims[:-1], dims[1:])) # 创建输入维度和输出维度的元组\n",
"\n",
" if use_convnext:\n",
" block_klass = partial(ConvNextBlock, mult = convnext_mult)\n",
" else:\n",
" block_klass = partial(ResnetBlock, groups=resnet_block_groups)\n",
"\n",
" # time_embeddings:\n",
" if with_time_emb:\n",
" time_dim = dim * 4\n",
" self.time_mlp = nn.Sequential(\n",
" SinusolidalPositionEmbedding(dim),\n",
" nn.Linear(dim, time_dim),\n",
" nn.GELU(),\n",
" nn.Linear(time_dim, time_dim),\n",
" )\n",
" else:\n",
" time_dim = None\n",
" self.time_mlp = None\n",
"\n",
" # 降采样的layers\n",
" self.downs = nn.ModuleList([])\n",
" self.ups = nn.ModuleList([])\n",
" num_resolutions = len(in_out)\n",
"\n",
" for ind, (dim_in, dim_out) in enumerate(in_out):\n",
" is_last = ind >= (num_resolutions - 1)\n",
" self.downs.append(\n",
" nn.ModuleList(\n",
" [\n",
" block_klass(dim_in, dim_out, time_emb_dim=time_dim),\n",
" block_klass(dim_out, dim_out, time_emb_dim=time_dim),\n",
" Residual(PreNorm(dim_out, LinearAttention(dim_out))),\n",
" Downsample(dim_out) if not is_last else nn.Identity()\n",
" ]\n",
" )\n",
" )\n",
" \n",
" # 因为中间的layer是单独的或者说不是对称的一部分所以需要单独处理\n",
" mid_dim = dims[-1] # 下降到最后一层的维度 是dims的最后一个维度\n",
" self.mid_blocks = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)\n",
" self.mid_attn = Residual(PreNorm(mid_dim, LinearAttention))\n",
" self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)\n",
"\n",
" # 上采样的layers\n",
" for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):\n",
" is_last = ind >= (num_resolutions - 1)\n",
" self.ups.append(\n",
" nn.ModuleList(\n",
" [\n",
" block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),\n",
" block_klass(dim_in, dim_in, time_emb_dim=time_dim),\n",
" Residual(PreNorm(dim_in, LinearAttention(dim_in))),\n",
" Upsample(dim_in) if not is_last else nn.Identity()\n",
" ]\n",
" )\n",
" )\n",
" \n",
" out_dim = default(out_dim, channels)\n",
" self.final_conv = nn.Sequential(\n",
" block_klass(dim, dim),\n",
" nn.Conv2d(dim, out_dim, 1)\n",
" )\n",
"\n",
" def forward(self, x, time):\n",
" x = self.init_conv(x)\n",
" t = self.time_mlp(time) if exists(self.time_mlp) else None\n",
" h = []\n",
"\n",
" # 下采样\n",
" for block1, block2, attn, downsample in self.downs:\n",
" x = block1(x, t)\n",
" x = block2(x, t)\n",
" x = attn(x)\n",
" h.append(x)\n",
" x = downsample(x)\n",
"\n",
" # 中间的layer\n",
" x = self.mid_blocks(x, t)\n",
" x = self.mid_attn(x)\n",
" x = self.mid_block2(x, t)\n",
"\n",
" # 上采样\n",
" for block1, block2, attn, upsample in self.ups:\n",
" x = torch.cat((x, h.pop()), dim = 1)\n",
" x = block1(x, t)\n",
" x = block2(x, t)\n",
" x = attn(x)\n",
" x = upsample(x)\n",
"\n",
" return self.final_conv(x)\n",
"\n",
" "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 前向扩散"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def cosine_beta_schedule(timesteps, s = 0.008):\n",
" steps = timesteps + 1\n",
" x = torch.linspace(0, timesteps, steps)\n",
" alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2\n",
" alphas_cumprod = alphas_cumprod / alphas_cumprod[0]\n",
" betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])\n",
" return torch.clip(betas, 0.0001, 0.9999)\n",
"\n",
"def linear_beta_schedule(timesteps):\n",
" beta_start = 0.0001\n",
" beta_end = 0.02\n",
" return torch.linspace(beta_start, beta_end, timesteps)\n",
"\n",
"def quadratic_beta_schedule(timesteps):\n",
" beta_start = 0.0001\n",
" beta_end = 0.02\n",
" return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2\n",
"\n",
"def sigmoid_beta_schedule(timesteps):\n",
" beta_start = 0.0001\n",
" beta_end = 0.02\n",
" betas = torch.linspace(-6, 6, timesteps)\n",
" return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"timesteps = 200\n",
"\n",
"betas = linear_beta_schedule(timesteps=timesteps)\n",
"\n",
"# betas是添加噪声的方差\n",
"# alpha是去掉噪声的量\n",
"# alpha_cumprod是去掉噪声的量的累积乘积\n",
"# ddpm中的方差是不变的, 计算公式是var = beta_t * (1 - alpha_{t-1}^cumprod) / (1 - alpha_t)\n",
"# 正向的计算的数学公式是这样: x_t = sqrt(alpha_t^cumprod) * x_0 + sqrt(1 - alpha_t^cumprod) * N(0, 1)\n",
"\n",
"\n",
"alphas = 1. - betas\n",
"alphas_cumprod = torch.cumprod(alphas, axis=0)\n",
"alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1,0), value=1.0)\n",
"sqrt_recip_alphas = torch.sqrt(1.0 / alphas)\n",
"\n",
"# 计算diffusion q(x_t | x_{t-1}) 和其他的\n",
"sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)\n",
"sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)\n",
"\n",
"# 计算后验方差q(x_{t-1} | x_t, x_0)\n",
"posteriors_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)\n",
"\n",
"def extract(a, t, x_shape):\n",
" batch_size = t.shape[0]\n",
" out = a.gather(-1, t.cpu())\n",
" return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)\n",
" # return out.reshape(batch_size, *((1,) * (len(x_shape - 1))).to(t.device))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAoAAAAHgCAIAAAC6s0uzAAEAAElEQVR4nDT8ya6t29amB7WkZ182sjnnSvbeJ/szBxEBhrAAWxQochMUuAPqFmVAlhE3YlwFUQAJCxSILOxwhCPi/0++915rzWyM8WU9743CgVto0tsKb/Lgf/r5P1QwkJiSAwn29siM6/6W8s+n6XtSAxGFHRF7Y/G2/OnxcAAVQiq1dswc4tWqo8UPt/w6dofq8enBRf+t+KlV7g9LAgqxYtNWG5aWS6ialHMWBJBjTLWI1jrFeBjGFGKEPHSXbc3aCNJKnFuDFLHXjzG/aN2FeD1MD+tM1jWEHmgtsWc2Sm+53p22y7IohYKqgdjexZxKSdnvg3UEuMStcxcmu6zXw3Sq2cWSrcNEPw71oyo2y+7O3XXOnMrBtSws6Jq6EhxUM8S5lkOEm+5P8/Itteeht739tM2xdyrnPFDZvHLD422f2WatSSqngOcjPX+7Pz2eQghjd9jCT0a3Gi9gIedq0Ea/GycgRDQuW3HdKM2neH98OOactn1pkJWW0iYf3seTReSw4MPp8+7XHLfeubF7jGlZ1tfj8fO2VbLVOpCtGCe3m0yHhwzPfs+fnv7udn/zYTtdLts9KBkOfXe9/9Rf9F6DWo9i8Z6uHz6e/PWdG4AyW84H+WXmF+FNtUG1HrCRAlC8pG+D+hS3+/k0cjvc9+duSve5nO1TKrvpVU5Ys+qdyfE2r1+nxyPVviRBGRkfycb37V/2R6zbSMq3PI7juNzvTg/nE9zfa5RldLgjQ3TSFPdjXa/Ybwy4ruswdEqp1e/GGCJa13Xq/kYap3itcD0MFwXHELeQv4L6yMwx3sZxCMnH6pEMa4LmdTpZxa01onHev2g7d2aC9iQiSlOub1VWlg81uwpvH07fvb1d+/FEir332iBRm+/v2vB0uMzb1wqi6SMzDn1b19BaySk4Y0jY2XHbNjvwur2d+x+WbckS+6m/r7NS+nabx+GQcGVyLCxQAaQhEgxaT6k+p4gaj0ZRLt/QzIgYPXOajg/mZb5pfbKmW+cvljTkDo4hpLtil1M/jZf78jOxCBgDCoFzgul4itGnuguEUnej/jaWux1g229EBJU1aKN74a+IGEs1tssJcsBa6HJ8fNn/KCJD70gAgObbYk3fdV1Gb+lhWf9o9FCLsrbzq50uxcfXHIDYCSvdq+vb89N0kpjvOfZj7/1mOxdLbQKkXIy5ueCA930vzLVW1SCloqZ+zXEqWlW1Y25cHdRU4pzT4TR6H6SS0R0gV5DSasyZpBCzECKiMnrb787oEMJoj4h12xelGJGksVNDKTDiwsZutURBEGLUBnWKcdNBRBgl59x34/t97vph9/GjMbbvZr+x1rkWo1khpRCzQM9aV0Cmaji1CqV2rGsBUpShhBqrNEVs2KBA8S1ZJBAbMgF6RztI2aM2zZpDivB/i3+ohr9r+oM5/n/a/tswFwQAUkIFGmhhkIdMTwinbvyQzG9lnxV8jJWM+xOHh6Z1hR5Vr5QCaaUaRMcaCxOBMQagldJERJEFgG1/toePexbM0ZmqLHkvqvXmEFOIqtKoRgW2xaxV6wedN9TWNhDS6r68K03O6fvtnWyvrSmlmoYuy6kfSstbCYMbWJdtTYBW2VbEbJuwbSygtMStSHPTMKZ06yytq0eDRNB1NsZcC5EeU2Kt+pT3UvI4mWV7bpKNc8DjvITJRtQmNWhFHLkSyzgeQk5v6Xa0g21soHuf1+F8jHX1YR47Lth8CH3XQW41VUFth55BLdsmVI3T+75b20ETAtJw8vJTw2jowUDX5FZoQ1Z9mpKTgJVyYw/MelZp1VW/r5+fvvsS/SUpEby7NlXMNeF/8uE/UirUJFP3AzUSmGPyQ/+RIS330k9jyhsCKz2mvHRj6vCw+R3ZiuhSY2nr2PfbWsx4SRERY/DL4+UX19vXYVJ+5wc7skahmPMdW3R2EBpW3yylEFLXjzFmZywitZIVccBXRaNfkckyo7FUcyilGHXSfMr1JlBIbam8l3gYxwMyx1hESpPQ6QPTWKpvcBu66eXtFbXqhv5+v6PAYF1JuYKxXQPh6O3gXJYvMa3T+Kig7CELaGxmsKeSMqjXjF8G9bcCNpZo1BEquF72OBe49/JPs/zYajZ6bNXvKw72hOoquitV+xK2+PVwtDVLy+As1/QEFIhTLq/WtLQ9dPqR7fvb2hSCwvHp8vTy9rvL5fH9mo/H6WX5f07DB2xWM+77ayqhcyelj/s+p1QeHz/c5qtROmzhMJ39FowdjZpaTSKISKH9SahB+rU2izZtXn8ax4PkXtNYa2ySCnaaD03i6FLcVs3n/nT6w8//7oen6b6XSv3Un27Pz8xxOtPb8u2k/nHGn/aw9PqzwtaP27q2sH/WY2FxDJh8sGZkQ1u49cMx7S9971ItMUhOwIBdx5rLffP9iFRPrRplduO61/eZ7F23UeFjU7eQ7tYaBTZ7p/Vdobomc4R8x93wg9r3MLUP6ofX9zfb25SC0thaqpC1ZiIAiiklpVEzpYg5WKu1dbC1RSGUcFqW5W//9h/9+PNXNlup0uKht+JMCf7GTmozhiao3OhbXn/NzBW/iVTjBDCAOCOHijlmTwoRse+OcW/W9sv643jC1kzYqSFoPhnqAO/FR9aFCKQ6kXOphLb69G5FCbbX+dvpcni9vj59eNhWfzwer7e1FmW4Z9ateOFIKueaff6jkl+d+t+sy6tz+n7bu8mR8hKsshLq3bjinPMbQu6cHtfyrpSS2kmZGhagFalaffDpz4ptjipFQdC2c+PkfFiFUwyKqQtlsW5pbceqWulYPlSopfqQ935QWglSLTEmdABACCHsnbUETMgx5oZOqR1gJ7lorYVvKZXb7J8eTnHPzFawgZbNb6MbqGJThxDXCqXU3fb69f1dW5Nz7ti+5F0p9WSGdd1Sp5BJ+9JsS0KTOXAobOFbeDfGdE297DcQMsoyq9JqkQqIDZsuvVIqt9pASmspx7E3iNhSMJZ88s7Zfd+dGhCVVt26Z8MKW2UmoeJrDFi3HE/KGGNKKYwSUsm1sTYNUJXNdK6URqhijEoZhQwACtqorMQsImbsU8k1poPrl1orSsHaQGrNrQGjUkhb3bXWiJhzrCkDQFGQGT4k+61Tf5D4z+dv//1u+jWNmdV/Nv8ZAAoCCCmgAg24aYQPVX1U6hP3txgWjVabrmLLaeYKjR2pgbRGgFq4wajUoGzNjZkBGwkQkSItgq3CNhbzVkc1fW1z6TKl+Oge110UMUDC5ntFg+lJOOcM0hxapW2uzfTd/X4/Ho8ALSbfWhZFACipdKiM4saylzQojLsl5Uv1l/Nfb34WTtIOVV47fspZBFerzb5K50wsL7GeCNAZI7UIVMRWWxQUMm5b4+COAHBb3o6X8e326vretY6dnvet74ew+MENITUgHF2ci0el8pon6ghARn6L1x4fK/gCW4rL0/GcdlDquG5Nqc32w7LurLSxOueIULTi2GiO32pro/vsqE9liXlv6JKhockBuTUIhDGXHmFo8KzChYYokpFba1zraMxb2fB/9eEfH4dz2FPvOilCpJZlOV+m5l3JiAqJo3NuXarr2PSpzo9d3397+3E4tVzWbjhta3La9a0xXtbtRZlMMJDVwEmAMUypCShGya3tTgEqjqkMncSQkc267sa4qR+gCSKuu1W6Cu6KGiJqrbdtEShWfYy7Nl1yzq37N2WC4U+11qa/1fBYa5/L3I8IYq0+xVCJXmLOpDRb9/ryrpQeu9EZ+3b9se9dqTHFTWvt1DFFGPrRpmmuEXvLknS+E+xkjnPqejmyC7msts8+3Ds3+XQlUlr/qhbK5dk61cIkUFhvLas5PoOM1ozOOWw2LIkptvq6hsPDp7Sus8IDoYLCiIj6vXIn1TltS85a2W/PP374/FAy5dpQUso7NtPZBwSNnPbwojhoPjawAFLaXWrozBHEZumEvxLw4A4hvZakgXTBP4FyBn+jubZyxToatOv2cpwe3vYynQyIsu28z0s37GDKFsGWk51wSV+rrJYJm83B2P4Ebcm
"text/plain": [
"<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x480>"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from PIL import Image\n",
"import requests\n",
"url = 'http://images.cocodataset.org/val2017/000000039769.jpg'\n",
"image = Image.open(requests.get(url, stream=True).raw)\n",
"image"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"处理照片\n",
"\n",
"定义一个transform 接受一张图片([0, 255]),输出一个tensor([-1, 1])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 3, 128, 128])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resize\n",
"\n",
"image_size = 128\n",
"transform = Compose([\n",
" Resize(image_size),\n",
" CenterCrop(image_size),\n",
" ToTensor(),\n",
" Lambda(lambda t: (t * 2) - 1)\n",
"])\n",
"\n",
"x_start = transform(image).unsqueeze(0)\n",
"x_start.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"逆过程"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"reverse_transform = Compose([\n",
" Lambda(lambda x: (x + 1) / 2),\n",
" Lambda(lambda t: t.permute(1, 2, 0)),\n",
" Lambda(lambda t: t * 255.),\n",
" Lambda(lambda t: t.numpy().astype(np.uint8)),\n",
" ToPILImage(),\n",
"])\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAIAAAACACAIAAABMXPacAACAJUlEQVR4nDz92ZPsWZIehrn7WX5b7JH7evOudWuv6qrqdXq6Z4AhZgAQMFGiCKORMslMj3qS6U1m+hck8UGPgowyykyUSI5EEOAADWC6Z+ul9u1W3S33zMiMPX7xW89x10NU6yUtH8LCMs/i/n2f+/kc/8uDPyMgg5pQETEpw8AiDCwOBAQRPAEKEiEJeBTnhQFIoSUUhkqTRrCEBsExMYpjBo0GRCESIAujAkZUXpygQ1EsrJBEGBABPIpiEBZQAISexAI6QgTUCCxACE4ESs+Vrzx6BAUgKMjiABAAWLwgiQiC8lg5FhJx4hHBCxMggvIABFyDK3xdsqvEe/YeGIQrdl6ABRBJoEZgAeNZamBCdgwaxYAySETgGGphBqy9m0rxlIszX135aofUJgWGKALVQd1WgSXthSqpLKiEVEwBIAk4BDGKNCgG0kBakBmBCBENEQgAiRWqjWDlPSkrLABeiAAEQCvBmj0pREAlMaIIkSCLAIJSRAKrrWOPDkWBYgEC8CTkIWD0AqoWR6CRBIGERZAUIKFnUAKi0QAqBkYhBMq9qxhZHBAwExGDgEcQRBAAFAFk9kII4FlQwNeIAIQAQITAiADiC3YLX1a+EhAv4kEceAEQJEABAQABUQxUgWdkQmARq0wARMACULNUIILAwgX6c66WzCOuDUpL6RIgBEhQCUIJbAEbSjHELMToShSDHoFE0AsSESFojcaBY2AnCOCEQQhAHCAQCQk74NXqKfEgSOgJtSavMQRBJARkEIeEIiCMNTAKeGQEJcAoBIAeHIJCYAFRSMzeCwoRiAMBBEBBRmbvEBFIsRBLhUCOfcVcuRIQAbSIF0BgQRAhYu8RDXHFAB4QvzvICAgACpABRKEXUUtXzn1eSsXsURAAK3EMAEAeHAo6EUARQQ/eg4ggISjQCkEhAkgJ7IURCAAq4ULkissrrm+4qkTWSIugAhQAj5CQIdAViAaOyWpQhAqRAJCBERBFmAEJtQMGAGEA8owKQMR7FgABQAZRIOzBAVciBhEQyLMgGWZGRAEEAQYUZhFEEREvSMyCICCs0DMqQFHfhQvx6EFQAFkqZPEiDkslRgExoAYgFkFw7GtfVN4LihdGEQLhVfQBYqnAMwiKuIKFkBhrFkZhXp0KQBY/93Xl6yWXjj0AAyADsiCA94Je2GPlwSOgB2ABEQFgQjKIKIqQBMADMHAlgEBexAFn4ibeP3X5kKu5OIMYol6K14QGNQI4gIi0RnECFXuPEoJBAEYhUQw1CNXABbP2XAOARxIWBd6v9kZYAEgcgfYiDARAzKTJ14AgojwDEaAGrmF1nYVIwAmw1AoDAELwAMDAKAxCDpwCLeBIFAGx1AA1inFSiRATEijCmkF7zyJZDbUXDygAiKsF5RJFeRRAFFAeHAGKeI+OxZAQAQAACpfsKp9lvirYexRhQQQEYQDPUoMAewbxIE4YxAESiAIBAiRUBEIigMqBE8CK2QuzgAOuADz7QvxTl82knosTgYioBh8AsogisqQsGi+if397HPhcOACLICJMhAigwCvQmtACeBZmYBYkQUSqwTEzMgMWAkojCioE8cJeHAJ6IYISpQIAEA9oyXMt4MEx1wpAI3oCAFZMiDUDECoQZPAgrNCxrBJgKcII2nvniVFqAHbCIChEhOLFOyGCGtkzAqIXAQHg7zZVQEAAURyDCLhKIPNl6aWWioFFAAQBxImwgAPv2bOAgAiIB0YRREWiBABBEJUgMIhHduJZ2AMAeCdSitTAJXMt7orrK1/WwCxgAAkwE69RK6AKRARQHJFygihKExggQfAkWhBXu4IkAgpEryKGIDMLMpTgUACRBcSDJ8DVp1GxeCnZ1+BJgASQFQEhCQEC1AhYSyXIzADgFCoUxeAQFEgtKBoCBZUiEhEHgvJd+AJEDxUJOo9OKkFPq1DlCLAWAUJhYYUKQRi8CLB4WQE0Ic8egCspK+9ydhWzgAhoAUbxDhCRPDCLiAiDePSr+4AiCpEREEhABAgRAMGDoIiIFOicwOqvdSK5uFz8mOuBVFMQMWSU7SA45wvnECAFDsBHYnOUANCIIiIEqQFEvEXt2TEoDcp7DwAaAVDpiktCQgJBQMGK2aEn8QoDTVqh8qvIWHMNToRFuAbRgAKMICQr8IcenBcmWB0ZUIIKA0RmESIC8RV4i+iBQbwIamREEAERLNk7rkCAQAuVBAQACEyiAIABRTRjzSJuBV6AhIWFWbjiKue6YmbwDOTFARCAQ/GCzIDMHhCYVyhNEDUhEyhBYmAABARgEWBA8uIFpAashZ1wIb5iycTPxM25nnA9RQnicKvTbsfBdmc90ng5Hr+8uLJINcOsLJzkGetSXJtsA4IICVE5RBY2KAAgWCtRiIgoCKJZkMGBJwYB8E68Y0+omApmhVh5EQQREQc1IAAQCFRAJA4AUZwgAhADongWL6AIQAhrqRA8ihIWixrJOQQUAvCeufTM5D2zE2FghaCACJlYKgCGigA1ETMAFB6YxTGssAMwV064ZvbAwp7BsSAAeZQVJgMgRg8iAITfQVX/HTgSBNQkUgMAMAEJICACQi3sAGrhXFzJbs5+JNVC/JwdB7ZkqZkCaxpJI2kme1sb693WYjTqhPiTx/eSJLyZpPP5zNd8O5qnVdngap2qBtkYjEVtUYEoIEQmTQoBvAgBayERERbPwoSCSCjiRZjrSjQCCKBAzYyCDAACzIAMuRLtgUW8oGjRComBnfjVf6UBAEoC0oAAXJMCBvYZMzhwLCJC4B0jAIACBYi1CIHz3gswgkL0xAyIxODQASgvzOy9/H7dERFWPEwhKBYGcQoMIIt4EQREBhARAahlBbBRoTBIBcLfMQhB8QDAIoVI6t2Qi1uuxuBzFLKm1Yi7QRAE4XSW5mURBzbQuJnESWRrMeMCpnl91AgcOuF6a6NlAI42usNFNRnNrhapdlXb24hURLpBOhSjgBwwgdeoDRL+H3d/JuBRNCASEBFWPi9YCIGQVpwYBTxwLahABLgWD+ABCBCBkYEVCoE4QQ/CIgSgUSsQBkBURE5EowCjiFQkCkEBAoBnYI2IoAgIARAVi0NBRARwKJoBCRygAAALsYhfoRgAD14Ea/aA4IQVrKiXcsJenBcR8AJKQFYRCwQIFAPU38V59iumIFUJMvfVkOtbcalBnYTdVnun2+q0DJnQUXBydrXM804rAqG8zBpEjUa0ubvb2zq8GVxNby4CcVFIxdI1Q9PpJGvN9dvryd9+9mQ8m5GTELCFJkJqKGtQB0gGNSIaFJ17J+CVAJFFcAhQiatFAJAEAKFmh+AEFIOsWDiIrLKkxxXGAAXOCzIgoAchAKiRlVoldy8sCkEheFEE2gkq9PAdurReQMCDiEKEVdZERGFBFPGA3guIKAQvXHlZoRVmACeewXtRsmKwgELMLB48CwMoARRgQLVCVQ4QhJ04DwBAjl0JrhKes5txORFeoI973UdHu5v9jZ2NbmwkjpuTeVazjxLrvKx3WplzVgfFPM1rN54v1naKh6+8ettZm9yej4eXi3RZSagis72T7CctR3x+PTy+GCzSJbNfils6H6EOUYWkNSKh6IprBkTwxDkKC5Bnztmt2CkiMjAIIiCjOKmdF0QGIAIBRgfesxOBGoEQUIDQkSgQJEGFBICI4sWjkCIWcUDEIixAgk4cAJCwR/4OCTAKkCLw3jlh+Q6G0opFI4AAeHErHKkAWL6L7Q49iHggEV5B0lX8WW2j8+zAewEPvOLAS3ZL8XMul8JzcWDtwZ3D9959U7zrNJpxaJx3l8Pr6SwdD2dVWUSxtq7VbHWaUbjRXguTzvXtrav5/OIkCO3GwdFwschGmXAemphtRKSSxBwdbtYit5MZ1zKbzmauDLhqKhuzj4Asop76glDBd6oLihciXbIwlwKMqAAVM3op/IpvIQoAgAiiXsVlcCIinhFRkBSSBhYAWX0QMQBNgIyiwIuIY8fgWVgAV7BVAJi9ILEIICIwInnhkj0iAIAFQhCClQiHwqswpRS
"text/plain": [
"<PIL.Image.Image image mode=RGB size=128x128>"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"reverse_transform(x_start.squeeze())"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"def q_sample(x_start, t, noise=None):\n",
" if noise is None:\n",
" noise = torch.randn_like(x_start)\n",
" sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)\n",
" sqrt_one_minus_alphas_cumprod_t = extract(\n",
" sqrt_one_minus_alphas_cumprod, t, x_start.shape\n",
" )\n",
" return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise\n",
"\n",
"def get_noisy_image(x_start, t):\n",
" # add noise\n",
" x_noisy = q_sample(x_start, t = t)\n",
" noisy_image = reverse_transform(x_noisy.squeeze())\n",
" return noisy_image\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"def plot(imgs, with_orig=False, row_title=None, **imshow_kwargs):\n",
" if not isinstance(imgs[0], list):\n",
" # 即使只有一行也以2d形式展示\n",
" imgs = [imgs]\n",
"\n",
" num_rows = len(imgs)\n",
" num_cols = len(imgs[0]) + with_orig\n",
" fig, axs = plt.subplots(figsize=(200, 200), nrows = num_rows, ncols=num_cols,squeeze=False)\n",
" for row_idx, row in enumerate(imgs):\n",
" row = [image] + row if with_orig else row\n",
" for col_idx, img in enumerate(row):\n",
" ax = axs[row_idx, col_idx]\n",
" ax.imshow(img, **imshow_kwargs)\n",
" ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])\n",
" \n",
" if with_orig:\n",
" axs[0, 0].set(title=\"Original\")\n",
" axs[0, 0].title.set_size(8)\n",
" if row_title is not None:\n",
" for row_idx in range(num_rows):\n",
" axs[row_idx, 0].set(ylabel=row_title[row_idx])\n",
" \n",
" plt.tight_layout()\n",
"\n",
2024-05-11 18:46:56 +02:00
" "
]
2024-05-11 23:59:23 +02:00
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAThUAAA+hCAYAAABaOrb1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAE4VElEQVR4nOzdZ5xsZZ3+619VdU5kNznnnAUBARFQkKhIUJBkImcEJSgICEoQEAZQEJQs2YCIpEFykiCSkRzdhA07dHedFzOcz/+8OYNHpu6Z43W93vBd3V21wrOe9axGu91uFwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA/xKa6Q0AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAKBzmukNAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACgc5rpDQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAoHOa6Q0AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAKBzmukNAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACgc5rpDQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAoHOa6Q0AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAKBzuj7MPxofH68XX3yxhoeHq9Fo/HdvEwAAAAAAAAAAAAAAAAAAAB+xdrtd77zzTs0+++zVbDb/of/WM+cAAAAAAAAAAAAAAAAAAAD/u3nmHAAAAAAAAAAAAAAAAAAA4F+XZ84BAAAAAAAAAAAAAAAAAAD+dXnmHAAAAAAAAAAAAAAAAAAA4F/Xh33mvOvD/M9efPHFmmuuuT6yjQMAAAAAAAAAAAAAAAAAACDjueeeqznnnPMf+m88cw4AAAAAAAAAAAAAAAAAAPD/D545BwAAAAAAAAAAAAAAAAAA+NflmXMAAAAAAAAAAAAAAAAAAIB/XZ45BwAAAAAAAAAAAAAAAAAA+Nf1Xz1z3vVh/ifDw8NVVXXC7GtVf/ND/ScfqWaz1fHmB7objVi7Od4da3f1TIu1q6oa7aFYe6ym5tqN0Vi7NZ77nk2t8Vi7OdaMtRuNybF2s5X7e9dY548j/6fkzz42ntu/VHsslh5ttmPt1njuON7V6Iu1G8Hf+ejY+7F2d6sn1q6q6qrcuVsFz5ercp+38UbwHKJy58vt0dzvvKvVG2u3x3Pnbe1m8Jwx+B37jw3I9Rtjwevx7uA542jwnLGRO2es4DVpO3hN2ggewxvjuTGIqqp27isevT5qtIPXxMF9enKksRm8Lmw1c/uXrsqdt2X/4lWN4Ge9/eFuufy3CP7YVZU7X67WlFg6eY0ybSy4Tx/LjWtXK3t9NG00+VnP/ezd7eQ5RPAapYLXpMH2eHDcqR0c4hxtB7/fVdUTvD6aErx3NR68Hm+PBr9nwfPlRiN3vjpWuXGI5LlyV/K+fGWvzrrGc8fxKcF7pVOm5T7rk9u5v/ho8JxxPDjOOS14D6cZ/HtXVU0dy/3eRxvB+3bjufOX0cC83Q80mrnxl+St0nZwDkh7LPdZm5ZLV1dw/mxV1ZTR4Fz1nty5U3fwgrxRuXarlfuwjwc/a6PBOQHBS4QaDc+9mdh+L9Z+qp07jr8wLdd+KTjmNUczN491pq7cPeLu4Llydyt33jZjIzhvuKpGgvOO+tq5z/po8P70lODzZj3BeayDwXPG/uC85VZwfHd0LDiPtKqawYH1nuB+vRW8P92uzH51SnusTpj41//7+fF/xAf/zVXz7VqDzc5/V8+c/hMdb37g4+/9JdZefODxWLsG3861q+rN5m9i7f1bS8ba2/5w1lh7yW0nxdqXtW6NtSfMdEKsveCkXPuVC1eNtactdUesXVW1xiIbxtp/nzYh1n7iY7+NtUdqxlh7+evvjbUvX3uFWHvmscVj7fE5z42137r3e7F2VdWqcz4ba7/8hz/H2u+PfCzWHljqzVj7lYsmxtqLrnBLrP3EpB/E2jM+/E6sveypZ8bav3gw9x2rqjrmvtz10eGvzRFrL9i+K9Z+fcYHYu3V51kt1u55ZDDWfnbVBWLtGZ7MHct2XuayWLuq6vrZ54+1x8ami7VvmTt3Tdy3/fKx9pbb5z7r5x4eS9dKX/lcrL3xdffH2gdulbsmraparOfyWLvv3iVi7dd2miHWXukPL8baL12Xm4f68EFPxNpznLx+rF3rrh5LTx7+ZKxdVfWZf8+Ne40+e3asfVp/7t7VW998JNZ+/KiRWHu2zXLPkmwx6bOx9ql35+ag3jRv9t74QTP+LtZe7o31Yu37hnPrmJw04YxY+0uTjom153sxN7b8+rtzx9pzPJAbdxqf9blYu6pq8uy59lODK8XaW+yZ+569uP1nYu23vjAx1l71odw1yokb5s4ZFzo9N99p4hKfirWrql54NTfOOf3dv4+1Z1oxlq7WcjvG2i+fnBsDWXSLpWLtG393X6y9RvumWPvRtzaIte/6zr/H2lVV67/w5Vi7r+fBWHvCb3Njbo8vklsPbOLcf4y1Pzk2S6x967Mbx9ovPbFvrL3oJrvE2lVVbz94Q6z927t/Fms3VszNof3NQ7lrsyWX/UmsPfinL8Xa237zyli72bVprN19Zm5NiaqqhzZeN9ZecDR3LO27eZlYe+CZ3PoKb2z4t1h7+h1y75V996LcPIw3n8jNMdtrufNi7aqqR25ZJNZ+YDj3fGPX33JjID3f3CjSnTw6rY6648p/6pnzPWe7vnqbnb8X89JOe3S8+YFdH809EzplwYti7W3/kBsTr6rapvvAWPu2A38ea+/+i9tj7We6to+1P79obg2S0ya8EmufcnTu+eNdJ+TmeldVffqBV2PtAxfLzdV6fKHcc3rHHZg7puxzfu5884/bfjPWPnHzzWPt3U7MzY9baNN9Yu2qqgv2zK3h9vyiuXvbSy29Way9xA5fi7WHV9o61j54ptz6LV/aLzfnecfnd461q6o+tvAlsfZ1e+SOZxMPyT2DfOeVuXGrd9/8e6z97Z/n7vd95fO548ma888Ta09cLHcPpKpq/qfPirVn+Nh3Yu3v3TNfrH3cbbn5KzM/njueHNl+IdZ+afMFY+2Nztoi1v7zjNm177Z+4MZY+3Nb5e7rv3L9T2PtkaP/FGs/9e3ZYu3L5sqt0XTZkrnzl2Vm+WqsXVU17YLcOcy6Xd+NtZ/49Eyx9k175NY4WHef3FoiR52Ym3/bmD+3DtsSX3go1p79wIdj7aqqGy47Ltb+w6lXx9p3XZCb3//wKs/E2rcemZsn9pPLDoi1n7l3v1h7i798Jdauqvr6b3L3kD57f2bOUlXVW6/kxiFWnf0fnyv1UemePvdsw6x3587bJmyYewb56qteirWfPSJ3z66qarNHr4m1rzsgd312Wjs3v+8bj+bWDj7z/Nz8/ou7c9eke9yYG1ued97sOvhXLZdbC227Ha6KtQ/5Qm6NpgVeyK0Bt8t6ueftX/5Bbl3yA3ffKtZ++rLcz73dVrl1saqqdvjFtrH21Yvm5mzfsXBuHuv6S+TWEpn56ht
"text/plain": [
"<Figure size 20000x20000 with 5 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot([get_noisy_image(x_start, torch.tensor([t])) for t in [0, 50, 100, 150, 199]])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 损失函数\n",
"\n",
"注意理解这里的顺序\n",
"\n",
"1. 先采样噪声\n",
"2. 用这个噪声去加噪图片\n",
"3. 根据加噪了的图片去预测第一步中采样的噪声"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"def p_losses(denoise_model, x_start, t, noise=None, loss_type=\"l1\"):\n",
" # 采样噪声\n",
" if noise is None:\n",
" noise = torch.randn_like(x_start)\n",
"\n",
" # 用采样的噪声生成噪声图像\n",
" x_noisy = q_sample(x_start, t=t, noise=noise)\n",
" predicted_noise = denoise_model(x_noisy, t)\n",
"\n",
" if loss_type == 'l1':\n",
" loss = F.l1_loss(noise, predicted_noise)\n",
" elif loss_type == 'l2':\n",
" loss = F.mse_loss(noise, predicted_noise)\n",
" elif loss_type == \"huber\":\n",
" loss = F.smooth_l1_loss(noise, predicted_noise)\n",
" else:\n",
" raise NotImplementedError()\n",
" \n",
" return loss\n",
" \n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 定义数据集PyTorch Dataset 和 DataLoader"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "846cd4de5dc54577932d82e25d12a3eb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data: 0%| | 0.00/30.9M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "eafd5138ee6641609ef9ce06d2ef5054",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data: 0%| | 0.00/5.18M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e1ea36ae914141a1828ae64507deb962",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating train split: 0%| | 0/60000 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "06b2043f4d744a058618e5371b5ab960",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating test split: 0%| | 0/10000 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from datasets import load_dataset\n",
"\n",
"dataset = load_dataset(\"fashion_mnist\")\n",
"image_size = 28\n",
"channels = 1\n",
"batch_size = 128\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"from torchvision import transforms\n",
"from torch.utils.data import DataLoader\n",
"\n",
"transform = Compose([\n",
" transforms.RandomHorizontalFlip(),\n",
" transforms.ToTensor(),\n",
" transforms.Lambda(lambda t: t * 2 - 1)\n",
"])\n"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dict_keys(['pixel_values'])\n"
]
}
],
"source": [
"def transforms(examples):\n",
" examples[\"pixel_values\"] = [transform(image.convert(\"L\")) for image in examples[\"image\"]]\n",
" del examples[\"image\"]\n",
" return examples\n",
"\n",
"transformed_dataset = dataset.with_transform(transforms).remove_columns(\"label\")\n",
"dataloader = DataLoader(transformed_dataset[\"train\"], batch_size=batch_size, shuffle=True)\n",
"batch = next(iter(dataloader))\n",
"print(batch.keys())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 采样/去噪\n",
"![algorithm2](https://pic2.zhimg.com/80/v2-1fe0665f5014f1bcd9ea89301d78b629_720w.webp)\n",
"\n",
"采样过程发生在反向去噪时。对于一张纯噪声,扩散模型一步步地去除噪声最终得到真实图片,采样事实上就是定义的去除噪声这一行为。 观察上图中第四行, t-1 步的图片是由 t 步的图片减去一个噪声得到的,只不过这个噪声是由 $\\theta$ 网络拟合出来,并且 rescale 过的而已。 这里要注意第四行式子的最后一项,采样时每一步也都会加上一个从正态分布采样的纯噪声。\n",
"\n",
"理想情况下,最终我们会得到一张看起来像是从真实数据分布中采样得到的图片。\n",
"\n",
"我们将上述过程写成代码,这里的代码相比于原论文的代码略有简化但是效果接近。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"@torch.no_grad()\n",
"def p_sample(model, x, t, t_index):\n",
" betas_t = extract(betas, t, x.shape)\n",
" sqrt_one_minus_alphas_cumprod_t = extract(\n",
" sqrt_one_minus_alphas_cumprod, t, x.shape\n",
" )\n",
" sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)\n",
"\n",
" model_mean = sqrt_recip_alphas_t * (\n",
" x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t\n",
" )\n",
"\n",
" if t_index == 0:\n",
" return model_mean\n",
" else:\n",
" posterior_variance_t = extract(posteriors_variance, t, x.shape)\n",
" noise = torch.randn_like(x)\n",
" return model_mean + torch.sqrt(posterior_variance_t) * noise\n",
"\n",
"@torch.no_grad()\n",
"\n",
"def p_sample_loop(model, shape):\n",
" device = next(model.parameters()).device\n",
"\n",
" b = shape[0]"
]
2024-05-09 13:42:22 +02:00
}
],
"metadata": {
"kernelspec": {
"display_name": "arch2vec39",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2024-05-11 23:59:23 +02:00
"version": "3.9.16"
2024-05-09 13:42:22 +02:00
}
},
"nbformat": 4,
"nbformat_minor": 2
}