{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Network Helper\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import inspect\n",
    "import math\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def exists(x):\n",
    "    \"\"\"Return True when ``x`` is not ``None``.\"\"\"\n",
    "    return x is not None\n",
    "\n",
    "\n",
    "def default(val, d):\n",
    "    \"\"\"Return ``val`` unless it is ``None``; otherwise fall back to ``d``.\n",
    "\n",
    "    When ``d`` is a plain Python function (as judged by ``inspect.isfunction``)\n",
    "    it is called with no arguments and its result is returned, so the fallback\n",
    "    is only built when it is needed. Any other ``d`` is returned unchanged.\n",
    "    \"\"\"\n",
    "    if exists(val):\n",
    "        return val\n",
    "    return d() if inspect.isfunction(d) else d\n",
    "\n",
    "\n",
    "class Residual(nn.Module):\n",
    "    \"\"\"Wrap ``fn`` with a skip connection: ``fn(x, *args, **kwargs) + x``.\"\"\"\n",
    "\n",
    "    def __init__(self, fn):\n",
    "        super().__init__()\n",
    "        self.fn = fn\n",
    "\n",
    "    def forward(self, x, *args, **kwargs):\n",
    "        return self.fn(x, *args, **kwargs) + x\n",
    "\n",
    "\n",
    "def Upsample(dim):\n",
    "    \"\"\"Upsampling layer (transposed convolution).\n",
    "\n",
    "    Kernel 4, stride 2, padding 1 with ``dim`` input and output channels,\n",
    "    which doubles the spatial height and width.\n",
    "    \"\"\"\n",
    "    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)\n",
    "\n",
    "\n",
    "def Downsample(dim):\n",
    "    \"\"\"Downsampling layer (strided convolution).\n",
    "\n",
    "    Kernel 4, stride 2, padding 1 with ``dim`` input and output channels,\n",
    "    which halves an even spatial height and width.\n",
    "    \"\"\"\n",
    "    return nn.Conv2d(dim, dim, 4, 2, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Positional embedding\n",
    "\n",
    "目的是让网络知道\n",
    "当前是哪一个step.\n",
    "ddpm采用正弦位置编码\n",
    "\n",
    "输入是shape为(batch_size,)的tensor(也接受(batch_size, 1)), 表示batch中每一个sample所处的t,并且将这个tensor转换为shape为(batch_size, dim)的tensor((batch_size, 1)的输入对应(batch_size, 1, dim)).\n",
    "这个tensor会被加到每一个残差模块中.\n",
    "\n",
    "总之就是将$t$编码为embedding,和原本的输入一起进入网络,让网络“知道”当前的输入属于哪个step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SinusolidalPositionEmbedding(nn.Module):\n",
    "    \"\"\"Sinusoidal embedding of the diffusion timestep.\n",
    "\n",
    "    Args:\n",
    "        dim: Requested embedding size. The output carries ``2 * (dim // 2)``\n",
    "            features, so ``dim`` should be even to get exactly ``dim``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim):\n",
    "        super().__init__()\n",
    "        self.dim = dim\n",
    "\n",
    "    def forward(self, time):\n",
    "        \"\"\"Embed ``time``.\n",
    "\n",
    "        Args:\n",
    "            time: Tensor of timesteps, e.g. shape ``(batch,)`` or ``(batch, 1)``.\n",
    "\n",
    "        Returns:\n",
    "            Tensor with the shape of ``time`` plus one trailing axis of size\n",
    "            ``2 * (dim // 2)``: sines in the first half, cosines in the second.\n",
    "        \"\"\"\n",
    "        half_dim = self.dim // 2\n",
    "        # max(..., 1) avoids a ZeroDivisionError when dim == 2 (one frequency, equal to 1).\n",
    "        step = math.log(10000) / max(half_dim - 1, 1)\n",
    "        freqs = torch.exp(torch.arange(half_dim, device=time.device) * -step)\n",
    "        # A new trailing axis on ``time`` broadcasts against the frequencies\n",
    "        # for any input rank, so 1-D timesteps work as well as (batch, 1).\n",
    "        angles = time[..., None] * freqs\n",
    "        return torch.cat((angles.sin(), angles.cos()), dim=-1)\n",
    "\n",
    "\n",
    "# Correctly spelled alias; the original (misspelled) name is kept for existing callers.\n",
    "SinusoidalPositionEmbedding = SinusolidalPositionEmbedding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ResNet/ConvNeXT block"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Block(nn.Module):\n",
    "    \"\"\"3x3 convolution -> GroupNorm -> optional scale/shift -> SiLU.\n",
    "\n",
    "    Args:\n",
    "        dim: Number of input channels.\n",
    "        dim_out: Number of output channels; ``nn.GroupNorm`` requires it to\n",
    "            be divisible by ``groups``.\n",
    "        groups: Number of groups for ``nn.GroupNorm``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim, dim_out, groups = 8):\n",
    "        super().__init__()\n",
    "        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)\n",
    "        self.norm = nn.GroupNorm(groups, dim_out)\n",
    "        self.act = nn.SiLU()\n",
    "\n",
    "    def forward(self, x, scale_shift = None):\n",
    "        \"\"\"Apply the block to ``x``.\n",
    "\n",
    "        Args:\n",
    "            x: Input feature map.\n",
    "            scale_shift: Optional ``(scale, shift)`` pair; when given, the\n",
    "                normalized activations become ``x * (scale + 1) + shift``.\n",
    "        \"\"\"\n",
    "        x = self.proj(x)\n",
    "        x = self.norm(x)\n",
    "\n",
    "        if exists(scale_shift):\n",
    "            scale, shift = scale_shift\n",
    "            x = x * (scale + 1) + shift\n",
    "\n",
    "        x = self.act(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "class ResnetBlock(nn.Module):\n",
    "    \"\"\"Two ``Block`` layers with a residual connection and optional time conditioning.\n",
    "\n",
    "    Args:\n",
    "        dim: Number of input channels.\n",
    "        dim_out: Number of output channels.\n",
    "        time_emb_dim: Size of the time embedding. When ``None`` no projection\n",
    "            is created and any ``time_emb`` passed to ``forward`` is ignored.\n",
    "        groups: Number of groups for the ``nn.GroupNorm`` inside each ``Block``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):\n",
    "        super().__init__()\n",
    "        self.mlp = (\n",
    "            nn.Sequential(\n",
    "                nn.SiLU(),\n",
    "                nn.Linear(time_emb_dim, dim_out)\n",
    "            )\n",
    "            if exists(time_emb_dim) else None\n",
    "        )\n",
    "        self.block1 = Block(dim, dim_out, groups=groups)\n",
    "        self.block2 = Block(dim_out, dim_out=dim_out, groups=groups)\n",
    "        # A 1x1 convolution matches channel counts on the skip path when they differ.\n",
    "        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()\n",
    "\n",
    "    def forward(self, x, time_emb = None):\n",
    "        \"\"\"Run the block.\n",
    "\n",
    "        Args:\n",
    "            x: Input feature map of shape ``(batch, dim, height, width)``.\n",
    "            time_emb: Optional time embedding of shape ``(batch, time_emb_dim)``\n",
    "                or ``(batch, 1, time_emb_dim)``.\n",
    "\n",
    "        Returns:\n",
    "            Feature map of shape ``(batch, dim_out, height, width)``.\n",
    "        \"\"\"\n",
    "        h = self.block1(x)\n",
    "\n",
    "        if exists(self.mlp) and exists(time_emb):\n",
    "            time_emb = self.mlp(time_emb)\n",
    "            # Reshape to (batch, dim_out, 1, 1) so the embedding is added per\n",
    "            # channel and broadcast over height and width.\n",
    "            h = time_emb.reshape(time_emb.shape[0], -1, 1, 1) + h\n",
    "\n",
    "        h = self.block2(h)\n",
    "        return h + self.res_conv(x)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "arch2vec39",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}