import os
import shutil
import logging
import sys

import torch
import numpy as np
from scipy.stats import pearsonr, spearmanr
from torch.utils.data import DataLoader

sys.path.append('.')

import sampling
import datasets_nas
from models import pgsn
from models import digcn
from models import cate
from models import dagformer
from models import digcn_meta
from models import regressor
from models.GDSS import scorenetx
from models import utils as mutils
from models.ema import ExponentialMovingAverage
import sde_lib
from utils import *
import losses

from analysis.arch_functions import BasicArchMetricsOFA
from analysis.arch_functions import NUM_STAGE, MAX_LAYER_PER_STAGE
from all_path import *


def get_sampling_fn(config, p=1, prod_w=False, weight_ratio_abs=False):
    # Setup SDEs
    if config.training.sde.lower() == 'vpsde':
        sde = sde_lib.VPSDE(
            beta_min=config.model.beta_min,
            beta_max=config.model.beta_max,
            N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'subvpsde':
        sde = sde_lib.subVPSDE(
            beta_min=config.model.beta_min,
            beta_max=config.model.beta_max,
            N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'vesde':
        sde = sde_lib.VESDE(
            sigma_min=config.model.sigma_min,
            sigma_max=config.model.sigma_max,
            N=config.model.num_scales)
        sampling_eps = 1e-5
    else:
        raise NotImplementedError(f"SDE {config.training.sde} unknown.")

    # Create the inverse data scaler (maps samples back to the data range)
    inverse_scaler = datasets_nas.get_data_inverse_scaler(config)

    sampling_shape = (
        config.eval.batch_size, config.data.max_node, config.data.n_vocab)  # ofa: 1024, 20, 28
    sampling_fn = sampling.get_sampling_fn(
        config, sde, sampling_shape, inverse_scaler,
        sampling_eps, config.data.name, conditional=True,
        p=p, prod_w=prod_w, weight_ratio_abs=weight_ratio_abs)

    return sampling_fn, sde
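
# Usage sketch (hedged; `config` is assumed to be the ml_collections-style config used
# throughout this repo, with config.training.sde, config.model.* and config.data.* set):
#
#   sampling_fn, sde = get_sampling_fn(config)
#   # sampling_fn is invoked later, e.g. in generate_archs(), as
#   # sampling_fn(score_model, mask, classifier_model, ...).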


def get_sampling_fn_meta(config, p=1, prod_w=False, weight_ratio_abs=False, init=False, n_init=5):
    # Setup SDEs
    if config.training.sde.lower() == 'vpsde':
        sde = sde_lib.VPSDE(
            beta_min=config.model.beta_min,
            beta_max=config.model.beta_max,
            N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'subvpsde':
        sde = sde_lib.subVPSDE(
            beta_min=config.model.beta_min,
            beta_max=config.model.beta_max,
            N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'vesde':
        sde = sde_lib.VESDE(
            sigma_min=config.model.sigma_min,
            sigma_max=config.model.sigma_max,
            N=config.model.num_scales)
        sampling_eps = 1e-5
    else:
        raise NotImplementedError(f"SDE {config.training.sde} unknown.")

    # Create the inverse data scaler (maps samples back to the data range)
    inverse_scaler = datasets_nas.get_data_inverse_scaler(config)

    if init:
        sampling_shape = (
            n_init, config.data.max_node, config.data.n_vocab)
    else:
        sampling_shape = (
            config.eval.batch_size, config.data.max_node, config.data.n_vocab)  # ofa: 1024, 20, 28
    sampling_fn = sampling.get_sampling_fn(
        config, sde, sampling_shape, inverse_scaler,
        sampling_eps, config.data.name, conditional=True,
        is_meta=True, data_name=config.sampling.check_dataname,
        num_sample=config.model.num_sample)

    return sampling_fn, sde


def get_score_model(config, pos_enc_type=2):
    # Load the pre-trained score network from config.scorenet_ckpt_path
    score_config = torch.load(config.scorenet_ckpt_path)['config']
    ckpt_path = config.scorenet_ckpt_path
    score_config.sampling.corrector = 'langevin'
    score_config.model.pos_enc_type = pos_enc_type

    score_model = mutils.create_model(score_config)
    score_ema = ExponentialMovingAverage(
        score_model.parameters(), decay=score_config.model.ema_rate)
    score_state = dict(
        model=score_model, ema=score_ema, step=0, config=score_config)
    score_state = restore_checkpoint(
        ckpt_path, score_state,
        device=config.device, resume=True)
    score_ema.copy_to(score_model.parameters())
    return score_model, score_ema, score_config
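
# Note (hedged, based on the code above): restore_checkpoint() loads the EMA shadow
# parameters from the checkpoint into `score_ema`, and copy_to() then writes those
# smoothed weights into `score_model`, so the returned model already carries the EMA
# version of the weights used for sampling.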


def get_predictor(config):
    classifier_model = mutils.create_model(config)

    return classifier_model


def get_adj(data_name, except_inout):
    if data_name == 'NASBench201':
        # Fixed DAG of the NAS-Bench-201 cell: 8 nodes including the input/output nodes.
        _adj = np.asarray(
            [[0, 1, 1, 1, 0, 0, 0, 0],
             [0, 0, 0, 0, 1, 1, 0, 0],
             [0, 0, 0, 0, 0, 0, 1, 0],
             [0, 0, 0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0, 1, 0],
             [0, 0, 0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0, 0, 0]]
        )
        _adj = torch.tensor(_adj, dtype=torch.float32, device=torch.device('cpu'))
        if except_inout:
            _adj = _adj[1:-1, 1:-1]
    elif data_name == 'ofa':
        assert except_inout
        # OFA search space: a simple chain over NUM_STAGE * MAX_LAYER_PER_STAGE nodes.
        num_nodes = NUM_STAGE * MAX_LAYER_PER_STAGE
        _adj = torch.zeros(num_nodes, num_nodes)
        for i in range(num_nodes - 1):
            _adj[i, i + 1] = 1
    else:
        raise NotImplementedError(f"Unknown data_name {data_name}.")
    return _adj
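
# Quick examples (derived from the branches above, not from separate tests):
#   get_adj('NASBench201', except_inout=False)  ->  8 x 8 cell adjacency (incl. input/output)
#   get_adj('NASBench201', except_inout=True)   ->  6 x 6 (input/output rows/columns dropped)
#   get_adj('ofa', except_inout=True)           ->  chain adjacency over NUM_STAGE * MAX_LAYER_PER_STAGE nodes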


def generate_archs(
        config, sampling_fn, score_model, score_ema, classifier_model,
        num_samples, patient_factor, batch_size=512, classifier_scale=None,
        task=None):

    metrics = BasicArchMetricsOFA()
    # algo = 'none'
    adj_s = get_adj(config.data.name, config.data.except_inout)
    mask_s = aug_mask(adj_s, algo=config.data.aug_mask_algo)[0]
    adj_c = get_adj(config.data.name, config.data.except_inout)
    mask_c = aug_mask(adj_c, algo=config.data.aug_mask_algo)[0]
    assert (adj_s == adj_c).all() and (mask_s == mask_c).all()
    adj_s, mask_s, adj_c, mask_c = \
        adj_s.to(config.device), mask_s.to(config.device), adj_c.to(config.device), mask_c.to(config.device)

    # Generate and save samples
    score_ema.copy_to(score_model.parameters())
    if num_samples > batch_size:
        num_sampling_rounds = int(np.ceil(num_samples / batch_size) * patient_factor)
    else:
        num_sampling_rounds = int(patient_factor)
    print(f'==> Sampling for {num_sampling_rounds} rounds...')

    r = 0
    all_samples = []
    classifier_scales = list(range(100000, 0, -int(classifier_scale)))

    while r < num_sampling_rounds:
        classifier_scale = classifier_scales[r]
        print(f'==> round {r} classifier_scale {classifier_scale}')
        sample, _, sample_chain, (score_grad_norm_p, classifier_grad_norm_p, score_grad_norm_c, classifier_grad_norm_c) \
            = sampling_fn(score_model, mask_s, classifier_model,
                          eval_chain=True,
                          number_chain_steps=config.sampling.number_chain_steps,
                          classifier_scale=classifier_scale,
                          task=task, sample_bs=num_samples)
        try:
            sample_list = quantize(sample, adj_s)  # quantization
            _, validity, valid_arch_str, _, _ = metrics.compute_validity(sample_list, adj_s, mask_s)
        except Exception:
            validity = 0.
            valid_arch_str = []
        print(f' ==> [Validity]: {round(validity, 4)}')

        if len(valid_arch_str) > 0:
            all_samples += valid_arch_str
        print(f' ==> [# Unique Arch]: {len(set(all_samples))}')

        if len(set(all_samples)) >= num_samples:
            break

        r += 1

    return list(set(all_samples))[:num_samples]
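
# End-to-end sketch (hedged; concrete config fields, checkpoint paths and the predictor
# config are assumed to be prepared elsewhere in this repo):
#
#   score_model, score_ema, score_config = get_score_model(config)
#   predictor = get_predictor(config)          # assumes `config` describes the predictor net
#   sampling_fn, sde = get_sampling_fn_meta(config)
#   archs = generate_archs(config, sampling_fn, score_model, score_ema, predictor,
#                          num_samples=64, patient_factor=10, classifier_scale=10000,
#                          task=task)           # `task` is a task-embedding tensor (assumption)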


def noise_aware_meta_predictor_fit(config,
                                   predictor_model=None,
                                   xtrain=None,
                                   seed=None,
                                   sde=None,
                                   batch_size=5,
                                   epochs=50,
                                   save_best_p_corr=False,
                                   save_path=None):
    assert save_best_p_corr
    reset_seed(seed)

    data_loader = DataLoader(xtrain,
                             batch_size=batch_size,
                             shuffle=True,
                             drop_last=True)

    # Create the data normalizer
    scaler = datasets_nas.get_data_scaler(config)

    # Initialize model.
    optimizer = losses.get_optimizer(config, predictor_model.parameters())
    state = dict(optimizer=optimizer,
                 model=predictor_model,
                 step=0,
                 config=config)

    # Build one-step training and evaluation functions
    optimize_fn = losses.optimization_manager(config)
    continuous = config.training.continuous
    reduce_mean = config.training.reduce_mean
    likelihood_weighting = config.training.likelihood_weighting
    train_step_fn = losses.get_step_fn_predictor(sde, train=True, optimize_fn=optimize_fn,
                                                 reduce_mean=reduce_mean, continuous=continuous,
                                                 likelihood_weighting=likelihood_weighting,
                                                 data=config.data.name, label_list=config.data.label_list,
                                                 noised=config.training.noised,
                                                 t_spot=config.training.t_spot,
                                                 is_meta=True)

    # temp
    # epochs = len(xtrain) * 100
    is_best = False
    best_p_corr = -1
    ckpt_dir = os.path.join(save_path, 'loop')
    print(f'==> Training for {epochs} epochs')
    for epoch in range(epochs):
        pred_list, labels_list = list(), list()
        for step, batch in enumerate(data_loader):
            x = batch['x'].to(config.device)  # (5, 5, 20, 9)???
            adj = get_adj(config.data.name, config.data.except_inout)
            task = batch['task']
            extra = batch
            mask = aug_mask(adj,
                            algo=config.data.aug_mask_algo,
                            data=config.data.name)
            x = scaler(x.to(config.device))
            adj = adj.to(config.device)
            mask = mask.to(config.device)
            task = task.to(config.device)
            batch = (x, adj, mask, extra, task)
            # Execute one training step
            loss, pred, labels = train_step_fn(state, batch)
            pred_list += [v.detach().item() for v in pred.squeeze()]
            labels_list += [v.detach().item() for v in labels.squeeze()]
        p_corr = pearsonr(np.array(pred_list), np.array(labels_list))[0]
        s_corr = spearmanr(np.array(pred_list), np.array(labels_list))[0]
        if epoch % 50 == 0:
            print(f'==> [Epoch-{epoch}] P corr: {round(p_corr, 4)} | S corr: {round(s_corr, 4)}')

        if save_best_p_corr:
            if p_corr > best_p_corr:
                is_best = True
                best_p_corr = p_corr
                os.makedirs(ckpt_dir, exist_ok=True)
                save_checkpoint(ckpt_dir, state, epoch, is_best)
    if save_best_p_corr:
        loaded_state = torch.load(os.path.join(ckpt_dir, 'model_best.pth.tar'), map_location=config.device)
        predictor_model.load_state_dict(loaded_state['model'])
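
# Expected batch format (inferred from the loop above): each element yielded by
# `data_loader` is a dict with at least 'x' (architecture encodings) and 'task'
# (task embedding); the whole dict is also forwarded as `extra` so that the step
# function built by losses.get_step_fn_predictor can read the labels named in
# config.data.label_list.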


def save_checkpoint(ckpt_dir, state, epoch, is_best):
    saved_state = {}
    for k in state:
        if k in ['optimizer', 'model', 'ema']:
            saved_state.update({k: state[k].state_dict()})
        else:
            saved_state.update({k: state[k]})
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(saved_state, os.path.join(ckpt_dir, f'checkpoint_{epoch}.pth.tar'))
    if is_best:
        shutil.copy(os.path.join(ckpt_dir, f'checkpoint_{epoch}.pth.tar'),
                    os.path.join(ckpt_dir, 'model_best.pth.tar'))
    # Remove the intermediate checkpoints, keeping only the best state
    for ckpt_file in sorted(os.listdir(ckpt_dir)):
        if not ckpt_file.startswith('checkpoint'):
            continue
        if os.path.join(ckpt_dir, ckpt_file) != os.path.join(ckpt_dir, 'model_best.pth.tar'):
            os.remove(os.path.join(ckpt_dir, ckpt_file))


def restore_checkpoint(ckpt_dir, state, device, resume=False):
    if not resume:
        os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
        return state
    elif not os.path.exists(ckpt_dir):
        if not os.path.exists(os.path.dirname(ckpt_dir)):
            os.makedirs(os.path.dirname(ckpt_dir))
        logging.warning(f"No checkpoint found at {ckpt_dir}. "
                        f"Returned the same state as input")
        return state
    else:
        loaded_state = torch.load(ckpt_dir, map_location=device)
        for k in state:
            if k in ['optimizer', 'model', 'ema']:
                state[k].load_state_dict(loaded_state[k])
            else:
                state[k] = loaded_state[k]
        return state
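
# Round-trip note (from the two helpers above): save_checkpoint() stores state_dicts for
# the 'optimizer'/'model'/'ema' entries and raw values for everything else, and
# restore_checkpoint() reverses this with load_state_dict on the same keys, so both must
# be used with state dicts built the same way (see get_score_model and
# noise_aware_meta_predictor_fit).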