286 lines
8.8 KiB
Python
286 lines
8.8 KiB
Python
import torch
|
|
import sys
|
|
import numpy as np
|
|
import random
|
|
|
|
sys.path.append('.')
|
|
import sampling
|
|
import datasets_nas
|
|
from models import cate
|
|
from models import digcn
|
|
from models import digcn_meta
|
|
from models import utils as mutils
|
|
from models.ema import ExponentialMovingAverage
|
|
import sde_lib
|
|
from utils import *
|
|
from analysis.arch_functions import BasicArchMetricsMeta
|
|
from all_path import *
|
|
|
|
|
|
def get_sampling_fn_meta(config):
    """Build the conditional sampling function and its SDE from `config`.

    Returns:
        (sampling_fn, sde): the sampler produced by `sampling.get_sampling_fn`
        and the SDE object it integrates.
    """
    sde_name = config.training.sde.lower()

    ## Set SDE: VP/subVP variants use a beta schedule, VE uses sigma bounds.
    if sde_name == 'vpsde':
        sde = sde_lib.VPSDE(beta_min=config.model.beta_min,
                            beta_max=config.model.beta_max,
                            N=config.model.num_scales)
        sampling_eps = 1e-3
    elif sde_name == 'subvpsde':
        sde = sde_lib.subVPSDE(beta_min=config.model.beta_min,
                               beta_max=config.model.beta_max,
                               N=config.model.num_scales)
        sampling_eps = 1e-3
    elif sde_name == 'vesde':
        sde = sde_lib.VESDE(sigma_min=config.model.sigma_min,
                            sigma_max=config.model.sigma_max,
                            N=config.model.num_scales)
        sampling_eps = 1e-5
    else:
        raise NotImplementedError(f"SDE {config.training.sde} unknown.")

    ## Get data normalizer inverse (maps samples back to data range)
    inverse_scaler = datasets_nas.get_data_inverse_scaler(config)

    ## Get sampling function; one sample is a [max_node, n_vocab] matrix.
    sampling_shape = (config.eval.batch_size, config.data.max_node, config.data.n_vocab)
    sampling_fn = sampling.get_sampling_fn(config=config,
                                           sde=sde,
                                           shape=sampling_shape,
                                           inverse_scaler=inverse_scaler,
                                           eps=sampling_eps,
                                           conditional=True,
                                           data_name=config.sampling.check_dataname,
                                           num_sample=config.model.num_sample)

    return sampling_fn, sde
|
|
|
|
|
|
def get_score_model(config):
    """Load the pretrained score network and its EMA weights.

    Reads the checkpoint at ``config.scorenet_ckpt_path``; if that attribute
    is missing or the checkpoint cannot be loaded, falls back to the global
    ``SCORENET_CKPT_PATH`` (and records it back on ``config``).

    Returns:
        (score_model, score_ema, score_config): the model with EMA weights
        copied in, the EMA tracker, and the config stored in the checkpoint.
    """
    try:
        ckpt_path = config.scorenet_ckpt_path
        score_config = torch.load(ckpt_path)['config']
    except Exception:
        # Missing attribute, missing file, or malformed checkpoint:
        # fall back to the default checkpoint path.
        # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
        config.scorenet_ckpt_path = SCORENET_CKPT_PATH
        ckpt_path = config.scorenet_ckpt_path
        score_config = torch.load(ckpt_path)['config']

    score_model = mutils.create_model(score_config)
    score_ema = ExponentialMovingAverage(
        score_model.parameters(), decay=score_config.model.ema_rate)
    score_state = dict(
        model=score_model, ema=score_ema, step=0, config=score_config)
    score_state = restore_checkpoint(
        ckpt_path, score_state,
        device=config.device, resume=True)
    # Evaluate with the smoothed (EMA) parameters.
    score_ema.copy_to(score_model.parameters())
    return score_model, score_ema, score_config
|
|
|
|
|
|
def get_surrogate(config):
    """Instantiate the surrogate predictor model described by `config`."""
    return mutils.create_model(config)
|
|
|
|
|
|
def get_adj(except_inout=False):
    """Return the fixed 8-node DAG adjacency matrix used for cells.

    Args:
        except_inout: if True, drop the first (input) and last (output)
            nodes and return the inner 6x6 sub-matrix.

    Returns:
        A float32 CPU tensor of shape [8, 8] (or [6, 6] when trimmed).
    """
    # Same topology as the original literal matrix, written as an edge list.
    edges = [
        (0, 1), (0, 2), (0, 3),
        (1, 4), (1, 5),
        (2, 6),
        (3, 7),
        (4, 6),
        (5, 7),
        (6, 7),
    ]
    adj = torch.zeros(8, 8, dtype=torch.float32, device=torch.device('cpu'))
    for src, dst in edges:
        adj[src, dst] = 1.0
    if except_inout:
        adj = adj[1:-1, 1:-1]
    return adj
|
|
|
|
|
|
def generate_archs_meta(
    config,
    sampling_fn,
    score_model,
    score_ema,
    meta_surrogate_model,
    num_samples,
    args=None,
    task=None,
    patient_factor=20,
    batch_size=256):
    """Sample architectures with classifier guidance until enough unique
    valid architecture strings are collected (or patience runs out).

    Args:
        config: experiment config; provides ``config.device``.
        sampling_fn: conditional sampler from ``get_sampling_fn_meta``.
        score_model: score network used by the sampler.
        score_ema: EMA shadow of the score network's parameters.
        meta_surrogate_model: guidance predictor for the target ``task``.
        num_samples: number of unique valid architectures requested.
        args: must provide ``classifier_scale`` and ``scale_step``.
        task: task identifier/embedding forwarded to the sampler.
        patient_factor: multiplier on the sampling-round budget.
        batch_size: per-round sample count used to size the round budget.

    Returns:
        List of unique valid architecture strings; may hold fewer than
        ``num_samples`` entries if the round budget is exhausted.
    """
    metrics = BasicArchMetricsMeta()

    ## Get the adj and mask (sampler copy and checker copy must agree)
    adj_s = get_adj()
    mask_s = aug_mask(adj_s)[0]
    adj_c = get_adj()
    mask_c = aug_mask(adj_c)[0]
    assert (adj_s == adj_c).all() and (mask_s == mask_c).all()
    adj_s, mask_s, adj_c, mask_c = \
        adj_s.to(config.device), mask_s.to(config.device), adj_c.to(config.device), mask_c.to(config.device)

    # Sample with the smoothed (EMA) weights, in eval mode.
    score_ema.copy_to(score_model.parameters())
    score_model.eval()
    meta_surrogate_model.eval()
    c_scale = args.classifier_scale

    num_sampling_rounds = int(np.ceil(num_samples / batch_size) * patient_factor) if num_samples > batch_size else int(patient_factor)
    all_samples = []
    # `round_idx` (not `round`) so the builtin is not shadowed; the bounded
    # `for` replaces the original `while True and round < n` pattern.
    for round_idx in range(num_sampling_rounds):
        sample = sampling_fn(score_model,
                             mask_s,
                             meta_surrogate_model,
                             classifier_scale=c_scale,
                             task=task)
        quantized_sample = quantize(sample)
        _, _, valid_arch_str, _ = metrics.compute_validity(quantized_sample)
        if len(valid_arch_str) > 0:
            all_samples += valid_arch_str
        # to sample various architectures: anneal guidance scale and reseed
        c_scale -= args.scale_step
        seed = int(random.random() * 10000)
        reset_seed(seed)
        # stop sampling if we have enough samples
        if len(set(all_samples)) >= num_samples:
            break

    return list(set(all_samples))
|
|
|
|
|
|
def save_checkpoint(ckpt_dir, state, epoch, is_best):
    """Serialize `state` to ``ckpt_dir/checkpoint_{epoch}.pth.tar``.

    Stateful entries ('optimizer', 'model', 'ema') are stored via their
    ``state_dict()``; everything else is stored verbatim.  When `is_best`,
    the checkpoint is also copied to ``model_best.pth.tar``.  Afterwards all
    ``checkpoint_*`` files are deleted, so only ``model_best.pth.tar``
    persists between calls.
    """
    stateful_keys = ('optimizer', 'model', 'ema')
    saved_state = {
        key: (value.state_dict() if key in stateful_keys else value)
        for key, value in state.items()
    }

    os.makedirs(ckpt_dir, exist_ok=True)
    ckpt_path = os.path.join(ckpt_dir, f'checkpoint_{epoch}.pth.tar')
    best_path = os.path.join(ckpt_dir, 'model_best.pth.tar')
    torch.save(saved_state, ckpt_path)
    if is_best:
        shutil.copy(ckpt_path, best_path)

    # remove the ckpt except is_best state
    for ckpt_file in sorted(os.listdir(ckpt_dir)):
        if not ckpt_file.startswith('checkpoint'):
            continue
        candidate = os.path.join(ckpt_dir, ckpt_file)
        if candidate != best_path:
            os.remove(candidate)
|
|
|
|
|
|
def restore_checkpoint(ckpt_dir, state, device, resume=False):
    """Restore `state` from the checkpoint file at `ckpt_dir`.

    When `resume` is False, only the parent directory is created and
    `state` is returned untouched.  When the checkpoint file is absent, a
    warning is logged and `state` is returned as-is.  Otherwise stateful
    entries ('optimizer', 'model', 'ema') are restored through
    ``load_state_dict``; every other key is replaced by its stored value.
    """
    if not resume:
        os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
        return state

    if not os.path.exists(ckpt_dir):
        parent = os.path.dirname(ckpt_dir)
        if not os.path.exists(parent):
            os.makedirs(parent)
        logging.warning(f"No checkpoint found at {ckpt_dir}. "
                        f"Returned the same state as input")
        return state

    loaded_state = torch.load(ckpt_dir, map_location=device)
    stateful_keys = ('optimizer', 'model', 'ema')
    for key in state:
        if key in stateful_keys:
            state[key].load_state_dict(loaded_state[key])
        else:
            state[key] = loaded_state[key]
    return state
|
|
|
|
|
|
def floyed(r):
    """Transitive closure of an adjacency matrix (Floyd–Warshall style).

    :param r: an NxN 0/1 matrix (numpy array or torch tensor).  Torch input
        is converted with ``.cpu().numpy()``; for CPU tensors that view
        shares memory, so the closure is effectively computed in place.
    :return: a numpy NxN 0/1 matrix where r[i, j] == 1 iff node j is
        reachable from node i through existing edges.
    """
    # isinstance (not `type(r) == torch.Tensor`) also accepts Tensor subclasses.
    if isinstance(r, torch.Tensor):
        r = r.cpu().numpy()
    n = r.shape[0]
    # Route i -> k -> j implies a direct edge i -> j.
    for k in range(n):
        for i in range(n):
            for j in range(n):
                if r[i, k] > 0 and r[k, j] > 0:
                    r[i, j] = 1
    return r
|
|
|
|
|
|
def _closure_mask(r, algo, device):
    """One attention mask from one adjacency via the selected closure algo.

    Mirrors the original per-branch behavior exactly: closure outputs are
    converted to float tensors on `device`; any other `algo` returns `r`
    unchanged (no dtype/device move).
    """
    if algo == 'long_range':
        return torch.from_numpy(long_range(r)).float().to(device)
    if algo == 'floyed':
        return torch.from_numpy(floyed(r)).float().to(device)
    return r


def aug_mask(adj, algo='floyed', data='NASBench201'):
    """Build a batch of attention masks from adjacency matrices.

    Args:
        adj: [N, N] or [B, N, N] adjacency; a 2-D input is treated as a
            batch of one.
        algo: 'floyed' (transitive closure), 'long_range', or anything else
            to pass the adjacency through unchanged.
        data: for 'nasbench201'/'ofa' all cells share one topology, so the
            mask is computed once from adj[0] and replicated across the
            batch; otherwise one mask is computed per adjacency.

    Returns:
        A stacked [B, N, N] tensor of masks.
    """
    if len(adj.shape) == 2:
        adj = adj.unsqueeze(0)

    if data.lower() in ['nasbench201', 'ofa']:
        assert len(adj.shape) == 3
        # Shared topology: compute once on a detached copy, then replicate.
        r = adj[0].clone().detach()
        mask_i = _closure_mask(r, algo, adj.device)
        return torch.stack([mask_i] * adj.size(0))
    else:
        # Heterogeneous batch: one mask per adjacency.
        # NOTE(review): floyed/long_range may write through to `adj` here,
        # since numpy views of CPU tensors share memory — preserved from the
        # original; confirm callers do not rely on `adj` staying pristine.
        masks = [_closure_mask(r, algo, adj.device) for r in adj]
        return torch.stack(masks)
|
|
|
|
|
|
def long_range(r):
    """Add two-hop shortcut edges: for every edge i -> j, connect each
    predecessor of i directly to j.

    Only entries above the diagonal are examined (column j is scanned over
    rows < j), so `r` is assumed to be an upper-triangular DAG adjacency.

    :param r: an NxN 0/1 matrix (numpy array or torch tensor).  Torch input
        is converted with ``.cpu().numpy()``; for CPU tensors that view
        shares memory, so the update happens in place.
    :return: a numpy NxN 0/1 matrix.
    """
    # isinstance (not `type(r) == torch.Tensor`) also accepts Tensor subclasses.
    if isinstance(r, torch.Tensor):
        r = r.cpu().numpy()
    n = r.shape[0]
    for j in range(1, n):
        # Direct predecessors of node j among nodes 0..j-1.
        preds_j = [i for i, val in enumerate(r[:j, j]) if val > 0]
        for i in preds_j:
            # Each predecessor of i reaches j in two hops: add the edge.
            preds_i = [k for k, val in enumerate(r[:i, i]) if val > 0]
            for k in preds_i:
                r[k, j] = 1
    return r
|
|
|
|
|
|
def quantize(x):
    """Convert the PyTorch tensor x, adj matrices to numpy array.

    Binarizes `x` at 0.5 in place (>= 0.5 -> 1.0, else 0.0) and returns the
    batch as a list of per-sample numpy arrays.

    Args:
        x: [Batch_size, Max_node, N_vocab]
    """
    # discretization (in place, over the whole batch at once)
    x[x >= 0.5] = 1.
    x[x < 0.5] = 0.

    # Split the batch dimension into a list of numpy arrays.
    return [sample.cpu().numpy() for sample in x]
|
|
|
|
|
|
def reset_seed(seed):
    """Seed every RNG in use (python, numpy, torch CPU and CUDA) and force
    deterministic cuDNN kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op on CPU-only machines
    torch.backends.cudnn.deterministic = True