From 591780d79e1ab812f93419fbb996e40c01b37e2e Mon Sep 17 00:00:00 2001 From: Hanzhang Ma Date: Thu, 9 May 2024 13:42:22 +0200 Subject: [PATCH] init diffusion --- main.ipynb | 171 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 main.ipynb diff --git a/main.ipynb b/main.ipynb new file mode 100644 index 0000000..8b7714e --- /dev/null +++ b/main.ipynb @@ -0,0 +1,171 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Network Helper\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch.nn as nn\n", + "import inspect\n", + "import torch\n", + "import math" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def exists(x):\n", + " return x is not None\n", + "\n", + "def default(val, d):\n", + " if exists(val):\n", + " return val\n", + " return d() if inspect.isfunction(d) else d\n", + "\n", + "class Residual(nn.Module):\n", + " def __init__(self, fn):\n", + " super().__init__()\n", + " self.fn = fn\n", + "\n", + " def forward(self, x, *args, **kwargs):\n", + " return self.fn(x, *args, **kwargs) + x\n", + "\n", + "# 上采样(反卷积)\n", + "def Upsample(dim):\n", + " return nn.ConvTranspose2d(dim, dim, 4, 2, 1)\n", + "\n", + "def Downsample(dim):\n", + " return nn.Conv2d(dim, dim, 4, 2 ,1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Positional embedding\n", + "\n", + "目的是让网络知道\n", + "当前是哪一个step. \n", + "ddpm采用正弦位置编码\n", + "\n", + "输入是shape为(batch_size, 1)的tensor, batch中每一个sample所处的t,并且将这个tensor转换为shape为(batch_size, dim)的tensor.\n", + "这个tensor会被加到每一个残差模块中.\n", + "\n", + "总之就是将$t$编码为embedding,和原本的输入一起进入网络,让网络“知道”当前的输入属于哪个step" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class SinusolidalPositionEmbedding(nn.Module):\n", + " def __init__(self, dim):\n", + " super().__init__()\n", + " self.dim = dim\n", + "\n", + " def forward(self, time):\n", + " device = time.device\n", + " half_dim = self.dim // 2\n", + " embeddings = math.log(10000) / (half_dim - 1)\n", + " embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)\n", + " embeddings = time[:, :, None] * embeddings[None, None, :]\n", + " embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)\n", + " return embeddings\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## ResNet/ConvNeXT block" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class Block(nn.Module):\n", + " def __init__(self, dim, dim_out, groups = 8):\n", + " super().__init__()\n", + " self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)\n", + " self.norm = nn.GroupNorm(groups, dim_out)\n", + " self.act = nn.SiLU()\n", + " \n", + " def forward(self, x, scale_shift = None):\n", + " x = self.proj(x)\n", + " x = self.norm(x)\n", + "\n", + " if exists(scale_shift):\n", + " scale, shift = scale_shift\n", + " x = x * (scale + 1) + shift\n", + "\n", + " x = self.act(x)\n", + " return x\n", + "\n", + "class ResnetBlock(nn.Module):\n", + " def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):\n", + " super().__init__()\n", + " self.mlp = (\n", + " nn.Sequential(\n", + " nn.SiLU(), \n", + " nn.Linear(time_emb_dim, dim_out)\n", + " )\n", + " if exists(time_emb_dim) else None\n", + " )\n", + " self.block1 = Block(dim, dim_out, groups=groups)\n", + " self.block2 = Block(dim_out, dim_out=dim_out, groups=groups)\n", + " self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()\n", + " \n", + " def forward(self, x, time_emb = None):\n", + " h = self.block1(x)\n", + "\n", + " if exists(self.mlp) and exists(time_emb):\n", + " time_emb = self.mlp(time_emb)\n", + " h = rearrange(time_emb, 'b n -> b () n') + h\n", + "\n", + " h = self.block2(h)\n", + " return h + self.res_conv(x)\n", + " \n", + " " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "arch2vec39", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}