"""Common layers"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def get_act(config):
    """Get the activation function named in the config file.

    Reads ``config.model.nonlinearity`` (matched case-insensitively) and
    returns a freshly constructed activation module.

    Supported names: ``'elu'``, ``'relu'``, ``'lrelu'`` (LeakyReLU with a
    negative slope of 0.2), ``'swish'`` (``nn.SiLU``) and ``'tanh'``.

    Raises:
        NotImplementedError: If the configured name is not one of the above.
    """
    name = config.model.nonlinearity.lower()
    # Map each name to a zero-argument factory so that every call returns a
    # new module instance, as the original if/elif chain did.
    factories = {
        'elu': nn.ELU,
        'relu': nn.ReLU,
        'lrelu': lambda: nn.LeakyReLU(negative_slope=0.2),
        'swish': nn.SiLU,
        'tanh': nn.Tanh,
    }
    try:
        factory = factories[name]
    except KeyError:
        raise NotImplementedError('activation function does not exist!') from None
    return factory()


def conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, padding=0):
    """Build a 2-D convolution with a 1x1 kernel.

    All arguments are forwarded unchanged to ``nn.Conv2d`` with
    ``kernel_size=1``.

    Returns:
        The constructed ``nn.Conv2d`` module.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=bias,
        dilation=dilation,
        padding=padding,
    )


# from DDPM
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    """Compute sinusoidal embeddings for a 1-D tensor of timesteps.

    The first ``embedding_dim // 2`` columns hold ``sin(t * f_i)`` and the
    next ``embedding_dim // 2`` columns hold ``cos(t * f_i)``, where the
    frequencies ``f_i`` decay geometrically from 1 down to
    ``1 / max_positions``. When ``embedding_dim`` is odd, one trailing zero
    column is appended.

    Args:
        timesteps: 1-D tensor of timesteps; it is cast to float32.
        embedding_dim: Width of the returned embedding.
        max_positions: Controls the lowest frequency. The default of 10000
            is the value used by the Transformer positional encoding.

    Returns:
        A float32 tensor of shape ``(len(timesteps), embedding_dim)`` on the
        same device as ``timesteps``.
    """
    assert len(timesteps.shape) == 1
    half_dim = embedding_dim // 2
    # With a single frequency (embedding_dim of 2 or 3) ``half_dim - 1`` is
    # zero and the original division raised ZeroDivisionError. Clamping the
    # divisor to 1 yields the single frequency 1.0 and leaves every other
    # case unchanged (for half_dim == 0 the frequency tensor is empty anyway).
    log_step = math.log(max_positions) / max(half_dim - 1, 1)
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
        * -log_step
    )
    # Outer product: one row per timestep, one column per frequency.
    angles = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode='constant')
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb