diff --git a/main.ipynb b/main.ipynb index 8b7714e..a787d72 100644 --- a/main.ipynb +++ b/main.ipynb @@ -10,14 +10,16 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "import inspect\n", "import torch\n", - "import math" + "import math\n", + "from einops import rearrange\n", + "from torch import einsum" ] }, { @@ -68,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -97,15 +99,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ + "# Block在init的时候创建了一个卷积层,一个归一化层,一个激活函数\n", + "# 前向的过程包括,首先卷积,再归一化,最后激活\n", "class Block(nn.Module):\n", " def __init__(self, dim, dim_out, groups = 8):\n", " super().__init__()\n", - " self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)\n", - " self.norm = nn.GroupNorm(groups, dim_out)\n", + " self.proj = nn.Conv2d(dim, dim_out, 3, padding=1) # 提取特征\n", + " self.norm = nn.GroupNorm(groups, dim_out) # 归一化, 使得网络训练更快速,调整和缩放神经网络中间层来实现梯度的稳定\n", " self.act = nn.SiLU()\n", " \n", " def forward(self, x, scale_shift = None):\n", @@ -129,6 +133,7 @@ " )\n", " if exists(time_emb_dim) else None\n", " )\n", + " # 第一个块的作用是,将输入的特征图的维度从dim变成dim_out\n", " self.block1 = Block(dim, dim_out, groups=groups)\n", " self.block2 = Block(dim_out, dim_out=dim_out, groups=groups)\n", " self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()\n", @@ -137,7 +142,9 @@ " h = self.block1(x)\n", "\n", " if exists(self.mlp) and exists(time_emb):\n", + " # 时间序列送到多层感知机里\n", " time_emb = self.mlp(time_emb)\n", + " # 为了让时间序列的维度和特征图的维度一致,所以需要增加一个维度\n", " h = rearrange(time_emb, 'b n -> b () n') + h\n", "\n", " h = self.block2(h)\n", @@ -145,6 +152,174 @@ " \n", " " ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class ConvNextBlock(nn.Module):\n", + " def __init__(self, dim, dim_out, *, time_emb_dim = None, mult = 2, norm = True):\n", + " super().__init()\n", + " self.mlp = (\n", + " nn.Sequential(\n", + " nn.GELU(),\n", + " nn.Linear(time_emb_dim, dim)\n", + " )\n", + " if exists(time_emb_dim) else None \n", + " )\n", + "\n", + " self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)\n", + "\n", + " self.net = nn.Sequential(\n", + " nn.GroupNorm(1, dim) if norm else nn.Identity(),\n", + " nn.Conv2d(dim, dim_out * mult, 3, padding=1),\n", + " nn.GELU(),\n", + " nn.GroupNorm(1, dim_out * mult),\n", + " nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),\n", + " )\n", + " self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()\n", + "\n", + " def forward(self, x, time_emb=None):\n", + " h = self.ds_conv(x)\n", + "\n", + " if exists(self.mlp) and exists(time_emb):\n", + " condition = self.mlp(time_emb)\n", + " h = rearrange(time_emb, 'b c -> b c 1 1') + h\n", + " h = self.net(h)\n", + " return h + self.res_conv(x)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Attention module" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class Attention(nn.Module):\n", + " def __init__(self, dim, heads=4, dim_head=32):\n", + " super().__init__()\n", + " self.scale = dim_head ** -0.5\n", + " self.heads = heads\n", + " hidden_dim = dim_head * dim # 计算隐藏层的维度\n", + " self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) # 通过卷积层将输入的特征图的维度变成hidden_dim * 3\n", + " self.to_out = nn.Conv2d(hidden_dim, dim, 1)\n", + " \n", + " def forward(self, x):\n", + " b, c, h ,w = x.shape\n", + " qkv = self.to_qkv(x).chunk(3, dim=1)\n", + " q, k, v= map(\n", + " lambda t: rearrange(t, \"b (h c) x y -> b h c (x y)\", h=self.heads),\n", + " qkv\n", + " )\n", + " q = q * self.scale\n", + "\n", + " sim = einsum(\"b h d i, b h d j -> b h i j\", q, k)\n", + " sim = sim - sim.amax(dim=-1, keepdim=True).detach()\n", + " attn = sim.softmax(dim=-1)\n", + "\n", + " out = einsum(\"b h i j, b h d j -> b h i d\", attn, v)\n", + " out = rearrange(out, \"b h (x y) d -> b (h d) x y\", x=h, y=w)\n", + " return self.to_out(out)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class LinearAttention(nn.Module):\n", + " def __init__(self, dim, heads=4, dim_head=32):\n", + " super().__init__()\n", + " self.scale = dim_head ** -0.5\n", + " self.heads = heads\n", + " hidden_dim = dim_head * heads\n", + " self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)\n", + " self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1),\n", + " nn.GroupNorm(1, dim))\n", + "\n", + " def forward(self, x):\n", + " b, c, h, w = x.shape\n", + " qkv = self.to_qkv(x).chunk(3, dim=1)\n", + " q, k, v = map(\n", + " lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), \n", + " qkv\n", + " )\n", + "\n", + " q = q.softmax(dim=-2)\n", + " k = k.softmax(dim=-1)\n", + "\n", + " q = q * self.scale\n", + " context = torch.einsum(\"b h d n, b h e n -> b h d e\", k, v)\n", + "\n", + " out = torch.einsum(\"b h d e, b h d n -> b h e n\", context, q)\n", + " out = rearrange(out, \"b h c (x y) -> b (h c) x y\", h=self.heads, x=h, y=w)\n", + " return self.to_out(out)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Group Normalization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class PreNorm(nn.Module):\n", + " def __init__(self, dim, fn):\n", + " super().__init__()\n", + " self.fn = fn\n", + " self.norm = nn.GroupNorm(1, dim)\n", + " \n", + " def forward(self, x):\n", + " x = self.norm(x)\n", + " return self.fn(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conditional U-Net" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class Unet(nn.Module):\n", + " " + ] } ], "metadata": { diff --git a/test.ipynb b/test.ipynb new file mode 100644 index 0000000..16a661c --- /dev/null +++ b/test.ipynb @@ -0,0 +1,62 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[ 1, 2],\n", + " [ 7, 8],\n", + " [13, 14]])\n", + "tensor([[ 3, 4],\n", + " [ 9, 10],\n", + " [15, 16]])\n", + "tensor([[ 5, 6],\n", + " [11, 12],\n", + " [17, 18]])\n" + ] + } + ], + "source": [ + "import torch\n", + "\n", + "# 创建一个大小为(3, 6)的张量\n", + "tensor = torch.tensor([[1, 2, 3, 4, 5, 6],\n", + " [7, 8, 9, 10, 11, 12],\n", + " [13, 14, 15, 16, 17, 18]])\n", + "\n", + "# 沿着第二个维度分成3份\n", + "chunks = tensor.chunk(3, dim=1)\n", + "\n", + "# 打印分割后的张量\n", + "for chunk in chunks:\n", + " print(chunk)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "arch2vec39", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}