diffusionNAG/MobileNetV3/models/layers.py

"""Common layers"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def get_act(config):
    """Get the activation function specified in the config."""
    nonlinearity = config.model.nonlinearity.lower()
    if nonlinearity == 'elu':
        return nn.ELU()
    elif nonlinearity == 'relu':
        return nn.ReLU()
    elif nonlinearity == 'lrelu':
        return nn.LeakyReLU(negative_slope=0.2)
    elif nonlinearity == 'swish':
        return nn.SiLU()
    elif nonlinearity == 'tanh':
        return nn.Tanh()
    else:
        raise NotImplementedError(f'activation function {nonlinearity} does not exist!')
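
# Usage sketch (illustrative): get_act only needs an object exposing
# `config.model.nonlinearity`, so a plain namespace can stand in for the
# project's real config:
#
#   from types import SimpleNamespace
#   config = SimpleNamespace(model=SimpleNamespace(nonlinearity='swish'))
#   act = get_act(config)  # nn.SiLU()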


def conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, padding=0):
    """1x1 convolution: projects channels, leaving spatial size unchanged when stride=1 and padding=0."""
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias,
                     dilation=dilation, padding=padding)
    return conv
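
# Usage sketch (illustrative): with the defaults, a 1x1 conv changes only the
# channel count, e.g. conv1x1(16, 32)(torch.randn(2, 16, 8, 8)) -> (2, 32, 8, 8).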


# from DDPM
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    """Sinusoidal timestep embedding as in DDPM / Transformer positional encodings."""
    assert len(timesteps.shape) == 1
    half_dim = embedding_dim // 2
    # magic number 10000 is from transformers
    emb = math.log(max_positions) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode='constant')
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb
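

# Smoke test (illustrative, not part of the library API): checks the embedding
# shape for an even embedding_dim and for an odd one, which exercises the
# zero-pad branch above.
if __name__ == '__main__':
    t = torch.arange(4)
    for dim in (8, 9):
        emb = get_timestep_embedding(t, dim)
        assert emb.shape == (4, dim)
    print('get_timestep_embedding OK')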