diffusionNAG/NAS-Bench-201/main_exp/diffusion/run_lib.py
2024-03-15 14:38:51 +00:00

286 lines
8.8 KiB
Python

import torch
import sys
import numpy as np
import random
sys.path.append('.')
import sampling
import datasets_nas
from models import cate
from models import digcn
from models import digcn_meta
from models import utils as mutils
from models.ema import ExponentialMovingAverage
import sde_lib
from utils import *
from analysis.arch_functions import BasicArchMetricsMeta
from all_path import *
def get_sampling_fn_meta(config):
    """Build the conditional sampling function together with its SDE.

    Args:
        config: Configuration object. This function reads
            ``training.sde``, the SDE hyper-parameters under ``model``
            (``beta_min``/``beta_max`` or ``sigma_min``/``sigma_max``,
            ``num_scales``, ``num_sample``), ``eval.batch_size``,
            ``data.max_node``, ``data.n_vocab`` and
            ``sampling.check_dataname``.

    Returns:
        A ``(sampling_fn, sde)`` tuple.

    Raises:
        NotImplementedError: If ``config.training.sde`` is not ``vpsde``,
            ``subvpsde`` or ``vesde`` (compared case-insensitively).
    """
    sde_name = config.training.sde.lower()

    # Select the SDE; each family is paired with its own ``eps`` value.
    if sde_name in ('vpsde', 'subvpsde'):
        vp_cls = sde_lib.VPSDE if sde_name == 'vpsde' else sde_lib.subVPSDE
        sde = vp_cls(
            beta_min=config.model.beta_min,
            beta_max=config.model.beta_max,
            N=config.model.num_scales)
        sampling_eps = 1e-3
    elif sde_name == 'vesde':
        sde = sde_lib.VESDE(
            sigma_min=config.model.sigma_min,
            sigma_max=config.model.sigma_max,
            N=config.model.num_scales)
        sampling_eps = 1e-5
    else:
        raise NotImplementedError(f"SDE {config.training.sde} unknown.")

    # Inverse of the data normalizer, handed to the sampler.
    inverse_scaler = datasets_nas.get_data_inverse_scaler(config)

    # One sampling call produces a (batch, max_node, n_vocab) shaped batch.
    sampling_shape = (
        config.eval.batch_size,
        config.data.max_node,
        config.data.n_vocab,
    )
    sampling_fn = sampling.get_sampling_fn(
        config=config,
        sde=sde,
        shape=sampling_shape,
        inverse_scaler=inverse_scaler,
        eps=sampling_eps,
        conditional=True,
        data_name=config.sampling.check_dataname,
        num_sample=config.model.num_sample)
    return sampling_fn, sde
def get_score_model(config):
    """Instantiate the score network from a checkpoint and load its weights.

    The checkpoint at ``config.scorenet_ckpt_path`` is tried first. If its
    stored config cannot be read for any reason (attribute missing, file
    missing, unreadable file, ...), ``config.scorenet_ckpt_path`` is
    overwritten with ``SCORENET_CKPT_PATH`` and that checkpoint is used
    instead.

    Args:
        config: Configuration object. ``scorenet_ckpt_path`` is read (and
            possibly rewritten, see above); ``device`` is used when
            restoring the weights.

    Returns:
        A ``(score_model, score_ema, score_config)`` tuple, where the model
        parameters have been replaced by the restored EMA parameters.
    """
    def _read_score_config(path):
        # Only the stored config is needed at this point, so map every
        # tensor to the CPU: a checkpoint written on a GPU then stays
        # readable on a CPU-only host instead of triggering the fallback.
        return torch.load(path, map_location='cpu')['config']

    try:
        score_config = _read_score_config(config.scorenet_ckpt_path)
    except Exception:
        # Deliberate best-effort fallback to the default checkpoint.
        # ``Exception`` (not a bare ``except``) so that KeyboardInterrupt
        # and SystemExit still propagate.
        config.scorenet_ckpt_path = SCORENET_CKPT_PATH
        score_config = _read_score_config(config.scorenet_ckpt_path)
    ckpt_path = config.scorenet_ckpt_path

    score_model = mutils.create_model(score_config)
    score_ema = ExponentialMovingAverage(
        score_model.parameters(), decay=score_config.model.ema_rate)
    score_state = dict(
        model=score_model, ema=score_ema, step=0, config=score_config)

    # Loads the stored weights into ``score_model`` and ``score_ema``.
    restore_checkpoint(
        ckpt_path, score_state,
        device=config.device, resume=True)

    # Use the EMA parameters as the model parameters.
    score_ema.copy_to(score_model.parameters())
    return score_model, score_ema, score_config
def get_surrogate(config):
    """Create the surrogate model described by ``config``.

    Args:
        config: Configuration object forwarded to ``mutils.create_model``.

    Returns:
        The model returned by ``mutils.create_model(config)``.
    """
    return mutils.create_model(config)
def get_adj(except_inout=False):
    """Return the fixed 8-node adjacency matrix as a float32 CPU tensor.

    Args:
        except_inout: If truthy, drop the first and last node (row and
            column) and return the inner 6x6 part as a view.

    Returns:
        An ``(8, 8)`` tensor, or a ``(6, 6)`` one when ``except_inout`` is
        set. Entry ``[i, j]`` is 1 when there is an edge from node ``i`` to
        node ``j`` and 0 otherwise.
    """
    num_nodes = 8
    # Outgoing edges of every node; node 7 has none.
    successors = {
        0: (1, 2, 3),
        1: (4, 5),
        2: (6,),
        3: (7,),
        4: (6,),
        5: (7,),
        6: (7,),
    }
    rows = [
        [1 if dst in successors.get(src, ()) else 0
         for dst in range(num_nodes)]
        for src in range(num_nodes)
    ]
    adj = torch.tensor(rows, dtype=torch.float32, device=torch.device('cpu'))
    return adj[1:-1, 1:-1] if except_inout else adj
def generate_archs_meta(
        config,
        sampling_fn,
        score_model,
        score_ema,
        meta_surrogate_model,
        num_samples,
        args=None,
        task=None,
        patient_factor=20,
        batch_size=256,):
    """Sample architectures until enough distinct valid ones are found.

    Each round draws one batch from ``sampling_fn``, quantizes it, keeps the
    entries that ``BasicArchMetricsMeta.compute_validity`` reports as valid,
    then lowers the classifier scale and reseeds the RNGs so that the next
    round yields different samples.

    Args:
        config: Configuration object; only ``config.device`` is read here.
        sampling_fn: Callable invoked as ``sampling_fn(score_model, mask,
            meta_surrogate_model, classifier_scale=..., task=...)``.
        score_model: Score network; receives the EMA parameters and is put
            in eval mode.
        score_ema: EMA object whose parameters are copied into
            ``score_model``.
        meta_surrogate_model: Surrogate model; put in eval mode and passed
            to ``sampling_fn``.
        num_samples: Number of distinct valid architectures wanted.
        args: Object providing ``classifier_scale`` and ``scale_step``.
            Required despite the ``None`` default: passing ``None`` raises
            ``AttributeError``.
        task: Forwarded unchanged to ``sampling_fn``.
        patient_factor: Multiplier on the number of sampling rounds allowed.
        batch_size: Only used to derive the round budget.

    Returns:
        A list of the distinct valid architectures collected, in unspecified
        order. It may hold more than ``num_samples`` entries, or fewer when
        the round budget runs out first.
    """
    metrics = BasicArchMetricsMeta()

    # Mask derived from the fixed cell adjacency.
    mask_s = aug_mask(get_adj())[0].to(config.device)

    score_ema.copy_to(score_model.parameters())
    score_model.eval()
    meta_surrogate_model.eval()

    c_scale = args.classifier_scale

    # Round budget: ``patient_factor`` rounds per batch needed to reach
    # ``num_samples``, with at least ``patient_factor`` rounds overall.
    if num_samples > batch_size:
        max_rounds = int(np.ceil(num_samples / batch_size) * patient_factor)
    else:
        max_rounds = int(patient_factor)

    unique_archs = set()
    for _ in range(max_rounds):
        sample = sampling_fn(score_model,
                             mask_s,
                             meta_surrogate_model,
                             classifier_scale=c_scale,
                             task=task)
        quantized_sample = quantize(sample)
        _, _, valid_arch_str, _ = metrics.compute_validity(quantized_sample)
        unique_archs.update(valid_arch_str)

        # to sample various architectures
        c_scale -= args.scale_step
        reset_seed(int(random.random() * 10000))

        # stop sampling if we have enough samples
        if len(unique_archs) >= num_samples:
            break
    return list(unique_archs)
def save_checkpoint(ckpt_dir, state, epoch, is_best):
    """Write the state for ``epoch`` and keep only the best checkpoint.

    Args:
        ckpt_dir: Directory for checkpoint files; created if missing.
        state: Mapping to persist. The values under ``'optimizer'``,
            ``'model'`` and ``'ema'`` are stored via their ``state_dict()``;
            every other value is stored as is.
        epoch: Used in the file name ``checkpoint_{epoch}.pth.tar``.
        is_best: If truthy, the file just written is also copied to
            ``model_best.pth.tar``.
    """
    stateful_keys = ['optimizer', 'model', 'ema']
    saved_state = {
        key: state[key].state_dict() if key in stateful_keys else state[key]
        for key in state
    }

    os.makedirs(ckpt_dir, exist_ok=True)
    epoch_path = os.path.join(ckpt_dir, f'checkpoint_{epoch}.pth.tar')
    best_path = os.path.join(ckpt_dir, 'model_best.pth.tar')

    torch.save(saved_state, epoch_path)
    if is_best:
        shutil.copy(epoch_path, best_path)

    # remove the ckpt except is_best state
    # NOTE(review): this runs whether or not ``is_best`` is set and also
    # deletes the file written above, so only ``model_best.pth.tar`` ever
    # survives a call -- confirm this is intended.
    for entry in sorted(os.listdir(ckpt_dir)):
        entry_path = os.path.join(ckpt_dir, entry)
        if entry.startswith('checkpoint') and entry_path != best_path:
            os.remove(entry_path)
def restore_checkpoint(ckpt_dir, state, device, resume=False):
    """Load a checkpoint file into ``state`` when resuming.

    Args:
        ckpt_dir: Path of the checkpoint *file* (despite the name).
        state: Mapping to fill in place. Values under ``'optimizer'``,
            ``'model'`` and ``'ema'`` are updated through
            ``load_state_dict``; every other key is overwritten with the
            loaded value.
        device: Passed to ``torch.load`` as ``map_location``.
        resume: If false, nothing is loaded.

    Returns:
        The same ``state`` object. It is returned unchanged when ``resume``
        is false or when no file exists at ``ckpt_dir`` (a warning is logged
        in the latter case). In both of those cases the parent directory of
        ``ckpt_dir`` is created if it is missing.
    """
    # A bare file name has an empty dirname, and ``os.makedirs('')`` raises
    # FileNotFoundError, so only create the parent when there is one.
    parent_dir = os.path.dirname(ckpt_dir)

    if not resume:
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        return state

    if not os.path.exists(ckpt_dir):
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        logging.warning(f"No checkpoint found at {ckpt_dir}. "
                        f"Returned the same state as input")
        return state

    loaded_state = torch.load(ckpt_dir, map_location=device)
    for k in state:
        if k in ['optimizer', 'model', 'ema']:
            state[k].load_state_dict(loaded_state[k])
        else:
            state[k] = loaded_state[k]
    return state
def floyed(r):
    """Add every multi-hop connection to an adjacency matrix, in place.

    After the call, ``r[i, j]`` is set to 1 whenever node ``j`` can be
    reached from node ``i`` through one or more intermediate nodes.
    Entries that were already positive are otherwise left as they were.

    :param r: a square NxN numpy array or torch tensor; an entry > 0 means
        "edge present". A tensor (or Tensor subclass) is converted with
        ``.cpu().numpy()``.
    :return: a numpy NxN array. A numpy input is modified and returned
        itself; for a CPU tensor the returned array shares memory with the
        tensor, so the tensor is modified as well.
    """
    # isinstance also accepts Tensor subclasses such as nn.Parameter.
    if isinstance(r, torch.Tensor):
        r = r.cpu().numpy()
    num_nodes = r.shape[0]
    for k in range(num_nodes):
        # Every node that reaches k also reaches every node k reaches.
        # Row k and column k cannot gain new positive entries during this
        # step, so taking both masks up front matches the element-wise loop.
        reaches_k = r[:, k] > 0
        reached_from_k = r[k, :] > 0
        r[np.outer(reaches_k, reached_from_k)] = 1
    return r
def aug_mask(adj, algo='floyed', data='NASBench201'):
    """Build augmented masks from one adjacency matrix or a batch of them.

    Args:
        adj: Tensor of shape ``(N, N)`` or ``(B, N, N)``. A 2-D input is
            treated as a batch of one. ``adj`` itself is never modified.
        algo: ``'floyed'`` or ``'long_range'`` selects the helper of the
            same name; any other value uses the adjacency matrix itself as
            the mask.
        data: For ``'nasbench201'`` or ``'ofa'`` (case-insensitive) the mask
            is computed from ``adj[0]`` only and repeated for every batch
            entry. Otherwise one mask is computed per batch entry.

    Returns:
        A tensor of stacked masks with one entry per batch element.
    """
    if len(adj.shape) == 2:
        adj = adj.unsqueeze(0)

    def _mask_of(matrix):
        # ``matrix`` must be a private copy: floyed/long_range write into
        # the buffer they are given.
        if algo == 'long_range':
            return torch.from_numpy(long_range(matrix)).float().to(adj.device)
        if algo == 'floyed':
            return torch.from_numpy(floyed(matrix)).float().to(adj.device)
        return matrix

    if data.lower() in ['nasbench201', 'ofa']:
        assert len(adj.shape) == 3
        shared_mask = _mask_of(adj[0].clone().detach())
        return torch.stack([shared_mask] * adj.size(0))

    # Clone each entry first. Passing views of ``adj`` straight to the
    # helpers would overwrite a CPU input in place (``Tensor.numpy`` shares
    # memory) while leaving a GPU input untouched.
    return torch.stack([_mask_of(r.clone()) for r in adj])
def long_range(r):
    """Connect each node to the predecessors of its predecessors, in place.

    Columns are processed left to right. For column ``j``, every ``i < j``
    with ``r[i, j] > 0`` is taken as a predecessor of ``j``; for each such
    ``i``, every ``k < i`` with ``r[k, i] > 0`` gets ``r[k, j] = 1``. Only
    entries above the diagonal are read or written. Because earlier columns
    are already augmented when a later column is processed, connections
    accumulate along the column order.

    :param r: an NxN numpy array or torch tensor; an entry > 0 means "edge
        present". A tensor (or Tensor subclass) is converted with
        ``.cpu().numpy()``.
    :return: a numpy NxN array. A numpy input is modified and returned
        itself; for a CPU tensor the returned array shares memory with the
        tensor, so the tensor is modified as well.
    """
    # isinstance also accepts Tensor subclasses such as nn.Parameter.
    if isinstance(r, torch.Tensor):
        r = r.cpu().numpy()
    num_nodes = r.shape[0]
    for j in range(1, num_nodes):
        # Snapshot of j's predecessors, taken before column j is extended.
        preds_of_j = np.flatnonzero(r[:j, j] > 0)
        for i in preds_of_j:
            # Column i (< j) is not written while column j is processed.
            preds_of_i = np.flatnonzero(r[:i, i] > 0)
            r[preds_of_i, j] = 1
    return r
def quantize(x):
    """Binarize a batch tensor in place and split it into numpy arrays.

    Entries >= 0.5 become 1 and entries < 0.5 become 0. The thresholding
    writes into ``x`` itself, so the caller's tensor is modified.

    Args:
        x: [Batch_size, Max_node, N_vocab]

    Returns:
        A list with one numpy array per batch entry, taken from the
        thresholded ``x``.
    """
    # discretization
    above = x >= 0.5
    x[above] = 1.
    below = x < 0.5
    x[below] = 0.
    return [x[idx].cpu().numpy() for idx in range(x.shape[0])]
def reset_seed(seed):
    """Seed the torch (CPU and CUDA), numpy and ``random`` generators.

    Also sets ``torch.backends.cudnn.deterministic`` to ``True``.

    Args:
        seed: Seed value handed to every generator.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True