import os import torch import numpy as np import random import logging from absl import flags from scipy.stats import pearsonr, spearmanr import torch from models import cate from models import digcn from models import digcn_meta import losses import sampling from models import utils as mutils from models.ema import ExponentialMovingAverage import datasets_nas import sde_lib from utils import * from logger import Logger from analysis.arch_metrics import SamplingArchMetrics, SamplingArchMetricsMeta FLAGS = flags.FLAGS def set_exp_name(config): if config.task == 'tr_scorenet': exp_name = f'./results/{config.task}/{config.folder_name}' data = config.data elif config.task == 'tr_meta_surrogate': exp_name = f'./results/{config.task}/{config.folder_name}' os.makedirs(exp_name, exist_ok=True) config.exp_name = exp_name set_random_seed(config) return exp_name def set_random_seed(config): seed = config.seed os.environ['PYTHONHASHSEED'] = str(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) random.seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def scorenet_train(config): """Runs the score network training pipeline. Args: config: Configuration to use. """ ## Set logger exp_name = set_exp_name(config) logger = Logger( log_dir=exp_name, write_textfile=True) logger.update_config(config, is_args=True) logger.write_str(str(vars(config))) logger.write_str('-' * 100) ## Create directories for experimental logs sample_dir = os.path.join(exp_name, "samples") os.makedirs(sample_dir, exist_ok=True) ## Initialize model and optimizer score_model = mutils.create_model(config) ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate) optimizer = losses.get_optimizer(config, score_model.parameters()) state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0, config=config) ## Create checkpoints directory checkpoint_dir = os.path.join(exp_name, "checkpoints") ## Intermediate checkpoints to resume training checkpoint_meta_dir = os.path.join(exp_name, "checkpoints-meta", "checkpoint.pth") os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(os.path.dirname(checkpoint_meta_dir), exist_ok=True) ## Resume training when intermediate checkpoints are detected if config.resume: state = restore_checkpoint(config.resume_ckpt_path, state, config.device, resume=config.resume) initial_step = int(state['step']) ## Build dataloader and iterators train_ds, eval_ds, test_ds = datasets_nas.get_dataset(config) train_loader, eval_loader, test_loader = datasets_nas.get_dataloader(config, train_ds, eval_ds, test_ds) train_iter = iter(train_loader) # Create data normalizer and its inverse scaler = datasets_nas.get_data_scaler(config) inverse_scaler = datasets_nas.get_data_inverse_scaler(config) ## Setup SDEs if config.training.sde.lower() == 'vpsde': sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) sampling_eps = 1e-3 elif config.training.sde.lower() == 'vesde': sde = sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales) sampling_eps = 1e-5 else: raise NotImplementedError(f"SDE {config.training.sde} unknown.") # Build one-step training and evaluation functions optimize_fn = losses.optimization_manager(config) continuous = config.training.continuous reduce_mean = config.training.reduce_mean likelihood_weighting = config.training.likelihood_weighting train_step_fn = losses.get_step_fn(sde=sde, train=True, optimize_fn=optimize_fn, reduce_mean=reduce_mean, continuous=continuous, likelihood_weighting=likelihood_weighting, data=config.data.name) eval_step_fn = losses.get_step_fn(sde=sde, train=False, optimize_fn=optimize_fn, reduce_mean=reduce_mean, continuous=continuous, likelihood_weighting=likelihood_weighting, data=config.data.name) ## Build sampling functions if config.training.snapshot_sampling: sampling_shape = (config.training.eval_batch_size, config.data.max_node, config.data.n_vocab) sampling_fn = sampling.get_sampling_fn(config=config, sde=sde, shape=sampling_shape, inverse_scaler=inverse_scaler, eps=sampling_eps) ## Build analysis tools sampling_metrics = SamplingArchMetrics(config, train_ds, exp_name) ## Start training the score network logging.info("Starting training loop at step %d." % (initial_step,)) element = {'train': ['training_loss'], 'eval': ['eval_loss'], 'test': ['test_loss'], 'sample': ['r_valid', 'r_unique', 'r_novel']} num_train_steps = config.training.n_iters is_best = False min_test_loss = 1e05 for step in range(initial_step, num_train_steps+1): try: x, adj, extra = next(train_iter) except StopIteration: train_iter = train_loader.__iter__() x, adj, extra = next(train_iter) mask = aug_mask(adj, algo=config.data.aug_mask_algo, data=config.data.name) x, adj, mask = scaler(x.to(config.device)), adj.to(config.device), mask.to(config.device) batch = (x, adj, mask) ## Execute one training step loss = train_step_fn(state, batch) logger.update(key="training_loss", v=loss.item()) if step % config.training.log_freq == 0: logging.info("step: %d, training_loss: %.5e" % (step, loss.item())) ## Report the loss on evaluation dataset periodically if step % config.training.eval_freq == 0: for eval_x, eval_adj, eval_extra in eval_loader: eval_mask = aug_mask(eval_adj, algo=config.data.aug_mask_algo, data=config.data.name) eval_x, eval_adj, eval_mask = scaler(eval_x.to(config.device)), eval_adj.to(config.device), eval_mask.to(config.device) eval_batch = (eval_x, eval_adj, eval_mask) eval_loss = eval_step_fn(state, eval_batch) logging.info("step: %d, eval_loss: %.5e" % (step, eval_loss.item())) logger.update(key="eval_loss", v=eval_loss.item()) for test_x, test_adj, test_extra in test_loader: test_mask = aug_mask(test_adj, algo=config.data.aug_mask_algo, data=config.data.name) test_x, test_adj, test_mask = scaler(test_x.to(config.device)), test_adj.to(config.device), test_mask.to(config.device) test_batch = (test_x, test_adj, test_mask) test_loss = eval_step_fn(state, test_batch) logging.info("step: %d, test_loss: %.5e" % (step, test_loss.item())) logger.update(key="test_loss", v=test_loss.item()) if logger.logs['test_loss'].avg < min_test_loss: is_best = True ## Save the checkpoint if step != 0 and step % config.training.snapshot_freq == 0 or step == num_train_steps: save_step = step // config.training.snapshot_freq save_checkpoint(checkpoint_dir, state, step, save_step, is_best) ## Generate samples if config.training.snapshot_sampling: ema.store(score_model.parameters()) ema.copy_to(score_model.parameters()) sample, sample_steps, _ = sampling_fn(score_model, mask) quantized_sample = quantize(sample) this_sample_dir = os.path.join(sample_dir, "iter_{}".format(step)) os.makedirs(this_sample_dir, exist_ok=True) ## Evaluate samples arch_metric = sampling_metrics(arch_list=quantized_sample, this_sample_dir=this_sample_dir) r_valid, r_unique, r_novel = arch_metric[0][0], arch_metric[0][1], arch_metric[0][2] logger.update(key="r_valid", v=r_valid) logger.update(key="r_unique", v=r_unique) logger.update(key="r_novel", v=r_novel) logging.info("r_valid: %.5e" % (r_valid)) logging.info("r_unique: %.5e" % (r_unique)) logging.info("r_novel: %.5e" % (r_novel)) if step % config.training.eval_freq == 0: logger.write_log(element=element, step=step) else: logger.write_log(element={'train': ['training_loss']}, step=step) logger.reset() logger.save_log() def scorenet_evaluate(config): """Evaluate trained score network. Args: config: Configuration to use. """ ## Set logger exp_name = set_exp_name(config) logger = Logger( log_dir=exp_name, write_textfile=True) logger.update_config(config, is_args=True) logger.write_str(str(vars(config))) logger.write_str('-' * 100) ## Load the config of pre-trained score network score_config = torch.load(config.scorenet_ckpt_path)['config'] ## Setup SDEs if score_config.training.sde.lower() == 'vpsde': sde = sde_lib.VPSDE(beta_min=score_config.model.beta_min, beta_max=score_config.model.beta_max, N=score_config.model.num_scales) sampling_eps = 1e-3 elif score_config.training.sde.lower() == 'vesde': sde = sde_lib.VESDE(sigma_min=score_config.model.sigma_min, sigma_max=score_config.model.sigma_max, N=score_config.model.num_scales) sampling_eps = 1e-5 else: raise NotImplementedError(f"SDE {config.training.sde} unknown.") ## Creat data normalizer and its inverse inverse_scaler = datasets_nas.get_data_inverse_scaler(config) # Build the sampling function sampling_shape = (config.eval.batch_size, score_config.data.max_node, score_config.data.n_vocab) sampling_fn = sampling.get_sampling_fn(config=config, sde=sde, shape=sampling_shape, inverse_scaler=inverse_scaler, eps=sampling_eps) ## Load pre-trained score network score_model = mutils.create_model(score_config) ema = ExponentialMovingAverage(score_model.parameters(), decay=score_config.model.ema_rate) state = dict(model=score_model, ema=ema, step=0, config=score_config) state = restore_checkpoint(config.scorenet_ckpt_path, state, device=config.device, resume=True) ema.store(score_model.parameters()) ema.copy_to(score_model.parameters()) ## Build dataset train_ds, eval_ds, test_ds = datasets_nas.get_dataset(score_config) ## Build analysis tools sampling_metrics = SamplingArchMetrics(config, train_ds, exp_name) ## Create directories for experimental logs sample_dir = os.path.join(exp_name, "samples") os.makedirs(sample_dir, exist_ok=True) ## Start sampling logging.info("Starting sampling") element = {'sample': ['r_valid', 'r_unique', 'r_novel']} num_sampling_rounds = int(np.ceil(config.eval.num_samples / config.eval.batch_size)) print(f'>>> Sampling for {num_sampling_rounds} rounds...') all_samples = [] adj = train_ds.adj.to(config.device) mask = train_ds.mask(algo=score_config.data.aug_mask_algo).to(config.device) if len(adj.shape) == 2: adj = adj.unsqueeze(0) if len(mask.shape) == 2: mask = mask.unsqueeze(0) for _ in range(num_sampling_rounds): sample, sample_steps, _ = sampling_fn(score_model, mask) quantized_sample = quantize(sample) all_samples += quantized_sample ## Evaluate samples all_samples = all_samples[:config.eval.num_samples] arch_metric = sampling_metrics(arch_list=all_samples, this_sample_dir=sample_dir) r_valid, r_unique, r_novel = arch_metric[0][0], arch_metric[0][1], arch_metric[0][2] logger.update(key="r_valid", v=r_valid) logger.update(key="r_unique", v=r_unique) logger.update(key="r_novel", v=r_novel) logger.write_log(element=element, step=1) logger.save_log() def meta_surrogate_train(config): """Runs the meta-predictor model training pipeline. Args: config: Configuration to use. """ ## Set logger exp_name = set_exp_name(config) logger = Logger( log_dir=exp_name, write_textfile=True) logger.update_config(config, is_args=True) logger.write_str(str(vars(config))) logger.write_str('-' * 100) ## Create directories for experimental logs sample_dir = os.path.join(exp_name, "samples") os.makedirs(sample_dir, exist_ok=True) ## Initialize model and optimizer surrogate_model = mutils.create_model(config) optimizer = losses.get_optimizer(config, surrogate_model.parameters()) state = dict(optimizer=optimizer, model=surrogate_model, step=0, config=config) ## Create checkpoints directory checkpoint_dir = os.path.join(exp_name, "checkpoints") ## Intermediate checkpoints to resume training checkpoint_meta_dir = os.path.join(exp_name, "checkpoints-meta", "checkpoint.pth") os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(os.path.dirname(checkpoint_meta_dir), exist_ok=True) ## Resume training when intermediate checkpoints are detected and resume=True state = restore_checkpoint(checkpoint_meta_dir, state, config.device, resume=config.resume) initial_step = int(state['step']) ## Build dataloader and iterators train_ds, eval_ds, test_ds = datasets_nas.get_meta_dataset(config) train_loader, eval_loader, _ = datasets_nas.get_dataloader(config, train_ds, eval_ds, test_ds) train_iter = iter(train_loader) ## Create data normalizer and its inverse scaler = datasets_nas.get_data_scaler(config) inverse_scaler = datasets_nas.get_data_inverse_scaler(config) ## Setup SDEs if config.training.sde.lower() == 'vpsde': sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) sampling_eps = 1e-3 elif config.training.sde.lower() == 'vesde': sde = sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales) sampling_eps = 1e-5 else: raise NotImplementedError(f"SDE {config.training.sde} unknown.") ## Build one-step training and evaluation functions optimize_fn = losses.optimization_manager(config) continuous = config.training.continuous reduce_mean = config.training.reduce_mean likelihood_weighting = config.training.likelihood_weighting train_step_fn = losses.get_step_fn_predictor(sde=sde, train=True, optimize_fn=optimize_fn, reduce_mean=reduce_mean, continuous=continuous, likelihood_weighting=likelihood_weighting, data=config.data.name, label_list=config.data.label_list, noised=config.training.noised) eval_step_fn = losses.get_step_fn_predictor(sde, train=False, optimize_fn=optimize_fn, reduce_mean=reduce_mean, continuous=continuous, likelihood_weighting=likelihood_weighting, data=config.data.name, label_list=config.data.label_list, noised=config.training.noised) ## Build sampling functions if config.training.snapshot_sampling: sampling_shape = (config.training.eval_batch_size, config.data.max_node, config.data.n_vocab) sampling_fn = sampling.get_sampling_fn(config=config, sde=sde, shape=sampling_shape, inverse_scaler=inverse_scaler, eps=sampling_eps, conditional=True, data_name=config.sampling.check_dataname, # for sanity check num_sample=config.model.num_sample) ## Load pre-trained score network score_config = torch.load(config.scorenet_ckpt_path)['config'] check_config(score_config, config) score_model = mutils.create_model(score_config) score_ema = ExponentialMovingAverage(score_model.parameters(), decay=score_config.model.ema_rate) score_state = dict(model=score_model, ema=score_ema, step=0, config=score_config) score_state = restore_checkpoint(config.scorenet_ckpt_path, score_state, device=config.device, resume=True) score_ema.copy_to(score_model.parameters()) ## Build analysis tools sampling_metrics = SamplingArchMetricsMeta(config, train_ds, exp_name) ## Start training logging.info("Starting training loop at step %d." % (initial_step,)) element = {'train': ['training_loss'], 'eval': ['eval_loss', 'eval_p_corr', 'eval_s_corr'], 'sample': ['r_valid', 'r_unique', 'r_novel']} num_train_steps = config.training.n_iters is_best = False max_eval_p_corr = -1 for step in range(initial_step, num_train_steps + 1): try: x, adj, extra, task = next(train_iter) except StopIteration: train_iter = train_loader.__iter__() x, adj, extra, task = next(train_iter) mask = aug_mask(adj, algo=config.data.aug_mask_algo, data=config.data.name) x, adj, mask, task = scaler(x.to(config.device)), adj.to(config.device), mask.to(config.device), task.to(config.device) batch = (x, adj, mask, extra, task) ## Execute one training step loss, pred, labels = train_step_fn(state, batch) logger.update(key="training_loss", v=loss.item()) if step % config.training.log_freq == 0: logging.info("step: %d, training_loss: %.5e" % (step, loss.item())) ## Report the loss on evaluation dataset periodically if step % config.training.eval_freq == 0: eval_pred_list, eval_labels_list = list(), list() for eval_x, eval_adj, eval_extra, eval_task in eval_loader: eval_mask = aug_mask(eval_adj, algo=config.data.aug_mask_algo, data=config.data.name) eval_x, eval_adj, eval_mask, eval_task = scaler(eval_x.to(config.device)), eval_adj.to(config.device), eval_mask.to(config.device), eval_task.to(config.device) eval_batch = (eval_x, eval_adj, eval_mask, eval_extra, eval_task) eval_loss, eval_pred, eval_labels = eval_step_fn(state, eval_batch) eval_pred_list += [v.detach().item() for v in eval_pred.squeeze()] eval_labels_list += [v.detach().item() for v in eval_labels.squeeze()] logging.info("step: %d, eval_loss: %.5e" % (step, eval_loss.item())) logger.update(key="eval_loss", v=eval_loss.item()) eval_p_corr = pearsonr(np.array(eval_pred_list), np.array(eval_labels_list))[0] eval_s_corr = spearmanr(np.array(eval_pred_list), np.array(eval_labels_list))[0] logging.info("step: %d, eval_p_corr: %.5e" % (step, eval_p_corr)) logging.info("step: %d, eval_s_corr: %.5e" % (step, eval_s_corr)) logger.update(key="eval_p_corr", v=eval_p_corr) logger.update(key="eval_s_corr", v=eval_s_corr) if eval_p_corr > max_eval_p_corr: is_best = True max_eval_p_corr = eval_p_corr ## Save a checkpoint periodically and generate samples if step != 0 and step % config.training.snapshot_freq == 0 or step == num_train_steps: ## Save the checkpoint. save_step = step // config.training.snapshot_freq save_checkpoint(checkpoint_dir, state, step, save_step, is_best) ## Generate and save samples if config.training.snapshot_sampling: score_ema.store(score_model.parameters()) score_ema.copy_to(score_model.parameters()) sample = sampling_fn(score_model=score_model, mask=mask, classifier=surrogate_model, classifier_scale=config.sampling.classifier_scale) quantized_sample = quantize(sample) # quantization this_sample_dir = os.path.join(sample_dir, "iter_{}".format(step)) os.makedirs(this_sample_dir, exist_ok=True) ## Evaluate samples arch_metric = sampling_metrics(arch_list=quantized_sample, this_sample_dir=this_sample_dir, check_dataname=config.sampling.check_dataname) r_valid, r_unique, r_novel = arch_metric[0][0], arch_metric[0][1], arch_metric[0][2] logging.info("step: %d, r_valid: %.5e" % (step, r_valid)) logging.info("step: %d, r_unique: %.5e" % (step, r_unique)) logging.info("step: %d, r_novel: %.5e" % (step, r_novel)) logger.update(key="r_valid", v=r_valid) logger.update(key="r_unique", v=r_unique) logger.update(key="r_novel", v=r_novel) if step % config.training.eval_freq == 0: logger.write_log(element=element, step=step) else: logger.write_log(element={'train': ['training_loss']}, step=step) logger.reset() def check_config(config1, config2): assert config1.model.sigma_min == config2.model.sigma_min assert config1.model.sigma_max == config2.model.sigma_max assert config1.training.sde == config2.training.sde assert config1.training.continuous == config2.training.continuous assert config1.data.centered == config2.data.centered assert config1.data.max_node == config2.data.max_node assert config1.data.n_vocab == config2.data.n_vocab run_train_dict = { 'scorenet': scorenet_train, 'meta_surrogate': meta_surrogate_train } run_eval_dict = { 'scorenet': scorenet_evaluate, } def train(config): run_train_dict[config.model_type](config) def evaluate(config): run_eval_dict[config.model_type](config)