init diffusion

This commit is contained in:
Hanzhang Ma 2024-05-09 13:42:22 +02:00
commit 591780d79e

171
main.ipynb Normal file
View File

@ -0,0 +1,171 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Network Helper\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"import inspect\n",
"import torch\n",
"import math"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def exists(x):\n",
" return x is not None\n",
"\n",
"def default(val, d):\n",
" if exists(val):\n",
" return val\n",
" return d() if inspect.isfunction(d) else d\n",
"\n",
"class Residual(nn.Module):\n",
" def __init__(self, fn):\n",
" super().__init__()\n",
" self.fn = fn\n",
"\n",
" def forward(self, x, *args, **kwargs):\n",
" return self.fn(x, *args, **kwargs) + x\n",
"\n",
"# 上采样(反卷积)\n",
"def Upsample(dim):\n",
" return nn.ConvTranspose2d(dim, dim, 4, 2, 1)\n",
"\n",
"def Downsample(dim):\n",
" return nn.Conv2d(dim, dim, 4, 2 ,1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Positional embedding\n",
"\n",
"目的是让网络知道\n",
"当前是哪一个step. \n",
"ddpm采用正弦位置编码\n",
"\n",
"输入是shape为(batch_size, 1)的tensor, batch中每一个sample所处的t,并且将这个tensor转换为shape为(batch_size, dim)的tensor.\n",
"这个tensor会被加到每一个残差模块中.\n",
"\n",
"总之就是将$t$编码为embedding,和原本的输入一起进入网络,让网络“知道”当前的输入属于哪个step"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class SinusolidalPositionEmbedding(nn.Module):\n",
" def __init__(self, dim):\n",
" super().__init__()\n",
" self.dim = dim\n",
"\n",
" def forward(self, time):\n",
" device = time.device\n",
" half_dim = self.dim // 2\n",
" embeddings = math.log(10000) / (half_dim - 1)\n",
" embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)\n",
" embeddings = time[:, :, None] * embeddings[None, None, :]\n",
" embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)\n",
" return embeddings\n",
" "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ResNet/ConvNeXT block"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Block(nn.Module):\n",
" def __init__(self, dim, dim_out, groups = 8):\n",
" super().__init__()\n",
" self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)\n",
" self.norm = nn.GroupNorm(groups, dim_out)\n",
" self.act = nn.SiLU()\n",
" \n",
" def forward(self, x, scale_shift = None):\n",
" x = self.proj(x)\n",
" x = self.norm(x)\n",
"\n",
" if exists(scale_shift):\n",
" scale, shift = scale_shift\n",
" x = x * (scale + 1) + shift\n",
"\n",
" x = self.act(x)\n",
" return x\n",
"\n",
"class ResnetBlock(nn.Module):\n",
" def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):\n",
" super().__init__()\n",
" self.mlp = (\n",
" nn.Sequential(\n",
" nn.SiLU(), \n",
" nn.Linear(time_emb_dim, dim_out)\n",
" )\n",
" if exists(time_emb_dim) else None\n",
" )\n",
" self.block1 = Block(dim, dim_out, groups=groups)\n",
" self.block2 = Block(dim_out, dim_out=dim_out, groups=groups)\n",
" self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()\n",
" \n",
" def forward(self, x, time_emb = None):\n",
" h = self.block1(x)\n",
"\n",
" if exists(self.mlp) and exists(time_emb):\n",
" time_emb = self.mlp(time_emb)\n",
" h = rearrange(time_emb, 'b n -> b () n') + h\n",
"\n",
" h = self.block2(h)\n",
" return h + self.res_conv(x)\n",
" \n",
" "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "arch2vec39",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}