
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Network Helper\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"import inspect\n",
"import torch\n",
"import math\n",
"from einops import rearrange\n",
"from torch import einsum"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def exists(x):\n",
" return x is not None\n",
"\n",
"def default(val, d):\n",
" if exists(val):\n",
" return val\n",
" return d() if inspect.isfunction(d) else d\n",
"\n",
"class Residual(nn.Module):\n",
" def __init__(self, fn):\n",
" super().__init__()\n",
" self.fn = fn\n",
"\n",
" def forward(self, x, *args, **kwargs):\n",
" return self.fn(x, *args, **kwargs) + x\n",
"\n",
"# 上采样(反卷积)\n",
"def Upsample(dim):\n",
" return nn.ConvTranspose2d(dim, dim, 4, 2, 1)\n",
"\n",
"def Downsample(dim):\n",
" return nn.Conv2d(dim, dim, 4, 2 ,1)"
]
},
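{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick shape check (my own sketch, not part of the original network): `Downsample` should halve the spatial resolution, `Upsample` should double it, and `Residual` should leave shapes untouched."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sanity check of the spatial behaviour of the helpers\n",
"x = torch.randn(2, 16, 32, 32)  # (batch, channels, height, width)\n",
"print(Downsample(16)(x).shape)  # expected: torch.Size([2, 16, 16, 16])\n",
"print(Upsample(16)(x).shape)    # expected: torch.Size([2, 16, 64, 64])\n",
"print(Residual(nn.Identity())(x).shape)  # expected: torch.Size([2, 16, 32, 32])"
]
},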
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Positional embedding\n",
"\n",
"目的是让网络知道\n",
"当前是哪一个step. \n",
"ddpm采用正弦位置编码\n",
"\n",
"输入是shape为(batch_size, 1)的tensor, batch中每一个sample所处的t,并且将这个tensor转换为shape为(batch_size, dim)的tensor.\n",
"这个tensor会被加到每一个残差模块中.\n",
"\n",
"总之就是将$t$编码为embedding,和原本的输入一起进入网络,让网络“知道”当前的输入属于哪个step"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class SinusolidalPositionEmbedding(nn.Module):\n",
" def __init__(self, dim):\n",
" super().__init__()\n",
" self.dim = dim\n",
"\n",
" def forward(self, time):\n",
" device = time.device\n",
" half_dim = self.dim // 2\n",
" embeddings = math.log(10000) / (half_dim - 1)\n",
" embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)\n",
" embeddings = time[:, :, None] * embeddings[None, None, :]\n",
" embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)\n",
" return embeddings\n",
" "
]
},
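{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small usage sketch (my own addition): embed the timesteps of a batch of four samples into 32-dimensional vectors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check: timesteps of shape (batch,) become embeddings of shape (batch, dim)\n",
"t = torch.tensor([0, 10, 100, 999])\n",
"emb = SinusoidalPositionEmbedding(dim=32)(t)\n",
"print(emb.shape)  # expected: torch.Size([4, 32])"
]
},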
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ResNet/ConvNeXT block"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Block在init的时候创建了一个卷积层一个归一化层一个激活函数\n",
"# 前向的过程包括,首先卷积,再归一化,最后激活\n",
"class Block(nn.Module):\n",
" def __init__(self, dim, dim_out, groups = 8):\n",
" super().__init__()\n",
" self.proj = nn.Conv2d(dim, dim_out, 3, padding=1) # 提取特征\n",
" self.norm = nn.GroupNorm(groups, dim_out) # 归一化, 使得网络训练更快速,调整和缩放神经网络中间层来实现梯度的稳定\n",
" self.act = nn.SiLU()\n",
" \n",
" def forward(self, x, scale_shift = None):\n",
" x = self.proj(x)\n",
" x = self.norm(x)\n",
"\n",
" if exists(scale_shift):\n",
" scale, shift = scale_shift\n",
" x = x * (scale + 1) + shift\n",
"\n",
" x = self.act(x)\n",
" return x\n",
"\n",
"class ResnetBlock(nn.Module):\n",
" def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):\n",
" super().__init__()\n",
" self.mlp = (\n",
" nn.Sequential(\n",
" nn.SiLU(), \n",
" nn.Linear(time_emb_dim, dim_out)\n",
" )\n",
" if exists(time_emb_dim) else None\n",
" )\n",
" # 第一个块的作用是将输入的特征图的维度从dim变成dim_out\n",
" self.block1 = Block(dim, dim_out, groups=groups)\n",
" self.block2 = Block(dim_out, dim_out=dim_out, groups=groups)\n",
" self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()\n",
" \n",
" def forward(self, x, time_emb = None):\n",
" h = self.block1(x)\n",
"\n",
" if exists(self.mlp) and exists(time_emb):\n",
" # 时间序列送到多层感知机里\n",
" time_emb = self.mlp(time_emb)\n",
" # 为了让时间序列的维度和特征图的维度一致,所以需要增加一个维度\n",
" h = rearrange(time_emb, 'b n -> b () n') + h\n",
"\n",
" h = self.block2(h)\n",
" return h + self.res_conv(x)\n",
" \n",
" "
]
},
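{
"cell_type": "markdown",
"metadata": {},
"source": [
"A shape sketch (my own addition, with an arbitrarily chosen time-embedding width of 28): the block maps `dim` input channels to `dim_out` output channels while conditioning on the time embedding."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check: ResnetBlock changes the channel count and accepts a time embedding\n",
"block = ResnetBlock(16, 32, time_emb_dim=28)\n",
"x = torch.randn(2, 16, 32, 32)\n",
"t_emb = torch.randn(2, 28)\n",
"print(block(x, t_emb).shape)  # expected: torch.Size([2, 32, 32, 32])"
]
},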
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class ConvNextBlock(nn.Module):\n",
" def __init__(self, dim, dim_out, *, time_emb_dim = None, mult = 2, norm = True):\n",
" super().__init()\n",
" self.mlp = (\n",
" nn.Sequential(\n",
" nn.GELU(),\n",
" nn.Linear(time_emb_dim, dim)\n",
" )\n",
" if exists(time_emb_dim) else None \n",
" )\n",
"\n",
" self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)\n",
"\n",
" self.net = nn.Sequential(\n",
" nn.GroupNorm(1, dim) if norm else nn.Identity(),\n",
" nn.Conv2d(dim, dim_out * mult, 3, padding=1),\n",
" nn.GELU(),\n",
" nn.GroupNorm(1, dim_out * mult),\n",
" nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),\n",
" )\n",
" self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()\n",
"\n",
" def forward(self, x, time_emb=None):\n",
" h = self.ds_conv(x)\n",
"\n",
" if exists(self.mlp) and exists(time_emb):\n",
" condition = self.mlp(time_emb)\n",
" h = rearrange(time_emb, 'b c -> b c 1 1') + h\n",
" h = self.net(h)\n",
" return h + self.res_conv(x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Attention module"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Attention(nn.Module):\n",
" def __init__(self, dim, heads=4, dim_head=32):\n",
" super().__init__()\n",
" self.scale = dim_head ** -0.5\n",
" self.heads = heads\n",
" hidden_dim = dim_head * dim # 计算隐藏层的维度\n",
" self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) # 通过卷积层将输入的特征图的维度变成hidden_dim * 3\n",
" self.to_out = nn.Conv2d(hidden_dim, dim, 1)\n",
" \n",
" def forward(self, x):\n",
" b, c, h ,w = x.shape\n",
" qkv = self.to_qkv(x).chunk(3, dim=1)\n",
" q, k, v= map(\n",
" lambda t: rearrange(t, \"b (h c) x y -> b h c (x y)\", h=self.heads),\n",
" qkv\n",
" )\n",
" q = q * self.scale\n",
"\n",
" sim = einsum(\"b h d i, b h d j -> b h i j\", q, k)\n",
" sim = sim - sim.amax(dim=-1, keepdim=True).detach()\n",
" attn = sim.softmax(dim=-1)\n",
"\n",
" out = einsum(\"b h i j, b h d j -> b h i d\", attn, v)\n",
" out = rearrange(out, \"b h (x y) d -> b (h d) x y\", x=h, y=w)\n",
" return self.to_out(out)\n"
]
},
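{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check (my own addition): attention preserves the input shape, so it can be dropped into the U-Net at any resolution."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check: Attention keeps the (b, c, h, w) shape\n",
"attn = Attention(dim=32)\n",
"x = torch.randn(2, 32, 16, 16)\n",
"print(attn(x).shape)  # expected: torch.Size([2, 32, 16, 16])"
]
},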
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LinearAttention(nn.Module):\n",
" def __init__(self, dim, heads=4, dim_head=32):\n",
" super().__init__()\n",
" self.scale = dim_head ** -0.5\n",
" self.heads = heads\n",
" hidden_dim = dim_head * heads\n",
" self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)\n",
" self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1),\n",
" nn.GroupNorm(1, dim))\n",
"\n",
" def forward(self, x):\n",
" b, c, h, w = x.shape\n",
" qkv = self.to_qkv(x).chunk(3, dim=1)\n",
" q, k, v = map(\n",
" lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), \n",
" qkv\n",
" )\n",
"\n",
" q = q.softmax(dim=-2)\n",
" k = k.softmax(dim=-1)\n",
"\n",
" q = q * self.scale\n",
" context = torch.einsum(\"b h d n, b h e n -> b h d e\", k, v)\n",
"\n",
" out = torch.einsum(\"b h d e, b h d n -> b h e n\", context, q)\n",
" out = rearrange(out, \"b h c (x y) -> b (h c) x y\", h=self.heads, x=h, y=w)\n",
" return self.to_out(out)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Group Normalization"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class PreNorm(nn.Module):\n",
" def __init__(self, dim, fn):\n",
" super().__init__()\n",
" self.fn = fn\n",
" self.norm = nn.GroupNorm(1, dim)\n",
" \n",
" def forward(self, x):\n",
" x = self.norm(x)\n",
" return self.fn(x)"
]
},
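{
"cell_type": "markdown",
"metadata": {},
"source": [
"How these pieces are typically combined (a sketch of the usual pattern, not code from the original notebook): normalize first, attend, then add a skip connection."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative composition: pre-norm + linear attention wrapped in a residual connection\n",
"attn_layer = Residual(PreNorm(32, LinearAttention(32)))\n",
"x = torch.randn(2, 32, 16, 16)\n",
"print(attn_layer(x).shape)  # expected: torch.Size([2, 32, 16, 16])"
]
},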
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conditional U-Net"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Unet(nn.Module):\n",
" "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "arch2vec39",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}