## Network Helper



In [10]:
import torch.nn as nn
import inspect
import torch
import math
from einops import rearrange
from torch import einsum

In [2]:
def exists(x):
    """Return True when ``x`` is anything other than ``None``.

    Falsy values such as ``0``, ``False`` or ``""`` still count as existing.
    """
    if x is None:
        return False
    return True

def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise return a fallback from ``d``.

    If ``d`` is a plain Python function (``inspect.isfunction``, which includes
    lambdas) it is called with no arguments to build the fallback lazily.
    Any other object is returned unchanged, including classes and builtins.
    """
    if exists(val):
        return val
    if inspect.isfunction(d):
        return d()
    return d

class Residual(nn.Module):
    """Wrap ``fn`` with a skip connection: ``forward(x, ...) = fn(x, ...) + x``.

    Args:
        fn: Callable (usually a module) whose output must be addable to ``x``.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        # Extra positional/keyword arguments are forwarded to ``fn`` untouched.
        branch = self.fn(x, *args, **kwargs)
        return branch + x

# 上采样(反卷积)
def Upsample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)

def Downsample(dim):
    """Return a strided convolution that halves H and W (for even sizes).

    The number of channels is unchanged. With kernel 4, stride 2 and
    padding 1, the output size is ``floor((H - 2) / 2) + 1``.
    """
    return nn.Conv2d(
        in_channels=dim,
        out_channels=dim,
        kernel_size=4,
        stride=2,
        padding=1,
    )

## Positional embedding

目的是让网络知道当前处于哪一个时间步（step）。DDPM 采用正弦位置编码（sinusoidal position embedding）。

输入是 shape 为 (batch_size, 1) 的 tensor，表示 batch 中每一个 sample 所处的时间步 t；该模块将其转换为 shape 为 (batch_size, dim) 的 tensor。
这个 tensor 会被加到每一个残差模块中。

总之，就是将 $t$ 编码为 embedding，与原本的输入一起送入网络，让网络“知道”当前的输入属于哪一个 step。

In [3]:
class SinusolidalPositionEmbedding(nn.Module):
    """Sinusoidal embedding of diffusion timesteps.

    Each scalar timestep ``t`` is mapped to a ``dim``-dimensional vector.
    The first half holds ``sin(t * f_i)`` and the second half holds
    ``cos(t * f_i)``. The frequencies ``f_i`` decay geometrically from 1
    down to 1/10000.

    Args:
        dim: Size of the output embedding. If ``dim`` is odd, the last column
            is zero-padded.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Embed ``time``.

        Args:
            time: Timesteps of shape ``(batch,)`` or ``(batch, 1)``.

        Returns:
            Tensor of shape ``(batch, dim)``. More generally, a trailing
            singleton axis is dropped and ``dim`` is appended to the
            remaining shape.
        """
        device = time.device
        half_dim = self.dim // 2
        # max(..., 1) avoids a ZeroDivisionError when dim is 2 or 3 (half_dim == 1).
        exponent = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -exponent)
        # Drop a trailing singleton axis so (batch, 1) yields (batch, dim)
        # rather than (batch, 1, dim).
        if time.dim() > 1 and time.shape[-1] == 1:
            time = time.squeeze(-1)
        args = time[..., None] * freqs
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1)
        # An odd dim leaves one column unfilled; pad so the size is exactly dim.
        if self.dim % 2 == 1:
            embeddings = nn.functional.pad(embeddings, (0, 1))
        return embeddings
        

## ResNet/ConvNeXt block

In [4]:
# Block builds a convolution, a normalization layer and an activation in __init__.
# The forward pass applies them in order: convolve, normalize, then activate.
class Block(nn.Module):
    """3x3 convolution -> GroupNorm -> optional scale/shift -> SiLU.

    Args:
        dim: Number of input channels.
        dim_out: Number of output channels.
        groups: Number of GroupNorm groups; must divide ``dim_out``.
    """

    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        # Feature extraction; padding=1 with a 3x3 kernel keeps H and W unchanged.
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        # Normalization speeds up training by rescaling intermediate activations,
        # which helps keep gradients stable.
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        """Run the block on ``x``.

        Args:
            x: Input feature map with ``dim`` channels.
            scale_shift: Optional ``(scale, shift)`` pair. If given, the
                normalized features become ``feat * (scale + 1) + shift``
                before the activation.
        """
        feat = self.norm(self.proj(x))

        if exists(scale_shift):
            scale, shift = scale_shift
            feat = feat * (scale + 1) + shift

        return self.act(feat)

class ResnetBlock(nn.Module):
    """Residual block of two ``Block`` layers, optionally conditioned on a time embedding.

    Args:
        dim: Number of input channels.
        dim_out: Number of output channels.
        time_emb_dim: Size of the time embedding. When ``None``, no projection
            MLP is built and any ``time_emb`` passed to ``forward`` is ignored.
        groups: Number of GroupNorm groups used inside each ``Block``.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
        super().__init__()
        # Projects the time embedding to dim_out channels so it can be added to
        # the feature map produced by block1.
        self.mlp = (
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_emb_dim, dim_out)
            )
            if time_emb_dim is not None else None
        )
        # The first block changes the channel count from dim to dim_out.
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out=dim_out, groups=groups)
        # A 1x1 conv matches the skip path's channel count when dim != dim_out.
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb = None):
        """Apply the block.

        Args:
            x: Feature map with ``dim`` channels, presumably
                ``(batch, dim, H, W)`` given the Conv2d layers.
            time_emb: Optional time embedding, expected to be
                ``(batch, time_emb_dim)``.

        Returns:
            Feature map with ``dim_out`` channels and the same spatial size as ``x``.
        """
        h = self.block1(x)

        if self.mlp is not None and time_emb is not None:
            # Feed the time embedding through the MLP.
            time_emb = self.mlp(time_emb)
            # (batch, dim_out) -> (batch, dim_out, 1, 1) so each channel gets its
            # own offset, broadcast over H and W. The previous 'b n -> b () n'
            # reshape produced (batch, 1, dim_out), which does not broadcast
            # against (batch, dim_out, H, W).
            h = h + time_emb[:, :, None, None]

        h = self.block2(h)
        return h + self.res_conv(x)
    
    

In [None]:
class ConvNextBlock(nn.Module):
    """ConvNeXt-style residual block with optional time conditioning.

    The block works in these steps:
    1. A depthwise 7x7 convolution is applied.
    2. If conditioning is enabled, the projected time embedding is added as
       a per-channel offset.
    3. The result goes through norm -> 3x3 conv (to ``dim_out * mult``
       channels) -> GELU -> norm -> 3x3 conv (back to ``dim_out``).
    4. The output is added to a skip connection.

    Args:
        dim: Number of input channels.
        dim_out: Number of output channels.
        time_emb_dim: Size of the time embedding; ``None`` disables conditioning.
        mult: Channel expansion factor for the hidden layer.
        norm: Whether to apply GroupNorm before the first 3x3 convolution.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim = None, mult = 2, norm = True):
        super().__init__()
        # Projects the time embedding to dim channels (the width of ds_conv's output).
        self.mlp = (
            nn.Sequential(
                nn.GELU(),
                nn.Linear(time_emb_dim, dim)
            )
            if time_emb_dim is not None else None
        )

        # Depthwise convolution (groups=dim): one 7x7 filter per channel,
        # with padding=3 preserving H and W.
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)

        self.net = nn.Sequential(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
        )
        # A 1x1 conv matches the skip path's channel count when dim != dim_out.
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        """Apply the block.

        Args:
            x: Feature map of shape ``(batch, dim, H, W)``.
            time_emb: Optional time embedding, expected to be
                ``(batch, time_emb_dim)``.

        Returns:
            Feature map of shape ``(batch, dim_out, H, W)``.
        """
        h = self.ds_conv(x)

        if self.mlp is not None and time_emb is not None:
            # Use the projected embedding (dim channels), not the raw time_emb,
            # reshaped to (batch, dim, 1, 1) to broadcast over H and W.
            condition = self.mlp(time_emb)
            h = h + condition[:, :, None, None]

        h = self.net(h)
        return h + self.res_conv(x)

## Attention module

In [None]:
class Attention(nn.Module):
    """Multi-head softmax self-attention across the spatial positions of a feature map.

    Every position attends to every other position, so the attention matrix
    is ``(H*W) x (H*W)`` per head.

    Args:
        dim: Number of input (and output) channels.
        heads: Number of attention heads.
        dim_head: Number of channels per head for queries, keys and values.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        # 1/sqrt(dim_head) scaling of the dot products.
        self.scale = dim_head ** -0.5
        self.heads = heads
        # Combined width of all heads. This must be heads * dim_head so that
        # each head really gets dim_head channels, matching self.scale.
        hidden_dim = dim_head * heads
        # A single 1x1 conv produces q, k and v stacked along the channel axis.
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        """Apply attention to ``x`` and return a tensor of the same shape.

        Args:
            x: Feature map of shape ``(batch, dim, H, W)``.
        """
        _, _, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        # Split channels into heads and flatten the spatial grid:
        # (batch, heads, dim_head, H*W).
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads),
            qkv
        )
        q = q * self.scale

        # Similarity between every pair of positions i, j.
        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        # Subtract the row max for numerical stability (softmax is shift-invariant).
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)


In [None]:
class LinearAttention(nn.Module):
    """Attention over spatial positions whose cost grows linearly with H * W.

    Keys and values are first combined into a small per-head
    ``(dim_head x dim_head)`` context. Queries then read from that context,
    so no ``(H*W) x (H*W)`` matrix is ever formed.

    Args:
        dim: Number of input (and output) channels.
        heads: Number of attention heads.
        dim_head: Number of channels per head.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = heads * dim_head
        # One 1x1 conv yields q, k and v stacked along the channel axis.
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)
        # Project back to dim channels, then normalize with a single-group GroupNorm.
        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.GroupNorm(1, dim),
        )

    def forward(self, x):
        """Apply linear attention to ``x`` of shape ``(batch, dim, H, W)``."""
        _, _, height, width = x.shape
        # Per head: (batch, heads, dim_head, H*W).
        q, k, v = (
            rearrange(part, 'b (h c) x y -> b h c (x y)', h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=1)
        )

        # Queries are normalized over their feature axis, keys over positions.
        q = q.softmax(dim=-2) * self.scale
        k = k.softmax(dim=-1)

        # Aggregate values by key first: (b h d n), (b h e n) -> (b h d e).
        kv = torch.einsum("b h d n, b h e n -> b h d e", k, v)
        # Each query position reads from the aggregated context.
        attended = torch.einsum("b h d e, b h d n -> b h e n", kv, q)
        merged = rearrange(
            attended, "b h c (x y) -> b (h c) x y", h=self.heads, x=height, y=width
        )
        return self.to_out(merged)

## Group Normalization

In [None]:
class PreNorm(nn.Module):
    """Normalize the input with a single-group GroupNorm, then call the wrapped ``fn``.

    Args:
        dim: Number of channels of the input tensor.
        fn: Callable (typically a module) applied to the normalized tensor.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        # One group: statistics are shared across all channels and positions.
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        return self.fn(self.norm(x))

## Conditional U-Net

In [None]:
class Unet(nn.Module):
    