###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import os
import time
import random
import logging
import ast

import igraph
import numpy as np
import scipy.stats
import torch


def reset_seed(seed):
    """Seed all RNGs (torch, CUDA, numpy, random) for reproducibility."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def restore_checkpoint(ckpt_dir, state, device, resume=False):
    if not resume:
        os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
        return state
    elif not os.path.exists(ckpt_dir):
        os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
        logging.warning(f"No checkpoint found at {ckpt_dir}. "
                        f"Returned the same state as input")
        return state
    else:
        loaded_state = torch.load(ckpt_dir, map_location=device)
        for k in state:
            if k in ['optimizer', 'model', 'ema']:
                state[k].load_state_dict(loaded_state[k])
            else:
                state[k] = loaded_state[k]
        return state


def load_graph_config(graph_data_name, nvt, data_path):
    if graph_data_name != 'nasbench201':
        raise NotImplementedError(graph_data_name)
    g_list = []
    max_n = 0  # maximum number of nodes
    ms = torch.load(data_path)['arch']['matrix']
    for i in range(len(ms)):
        g, n = decode_NAS_BENCH_201_8_to_igraph(ms[i])
        max_n = max(max_n, n)
        g_list.append((g, 0))
    # number of different node types, including the in/out (start/end) nodes
    graph_config = {}
    graph_config['num_vertex_type'] = nvt  # original types + start/end types
    graph_config['max_n'] = max_n  # maximum number of nodes
    graph_config['START_TYPE'] = 0  # predefined start vertex type
    graph_config['END_TYPE'] = 1  # predefined end vertex type
    return graph_config


def decode_NAS_BENCH_201_8_to_igraph(row):
    if isinstance(row, str):
        row = ast.literal_eval(row)  # safely convert string to list of lists
    n = len(row)
    g = igraph.Graph(directed=True)
    g.add_vertices(n)
    for i, node in enumerate(row):
        g.vs[i]['type'] = node[0]
        if 0 < i < (n - 2):
            g.add_edge(i, i + 1)  # chain edge to the next intermediate node
        # connect each incoming edge flagged in this node's adjacency row
        for j, edge in enumerate(node[1:]):
            if edge == 1:
                g.add_edge(j, i)
    return g, n


def is_valid_NAS201(g, START_TYPE=0, END_TYPE=1):
    # first, the graph must be a valid DAG computation graph
    res = is_valid_DAG(g, START_TYPE, END_TYPE)
    # in addition, it must have exactly 8 nodes, and the start/end types
    # may not appear among the intermediate nodes
    res = res and len(g.vs['type']) == 8
    res = res and not (START_TYPE in g.vs['type'][1:-1])
    res = res and not (END_TYPE in g.vs['type'][1:-1])
    return res


def decode_igraph_to_NAS201_matrix(g):
    # lower-triangular 4x4 operation matrix of the NAS-Bench-201 cell
    m = [[0.0, 0.0, 0.0, 0.0],
         [0.0, 0.0, 0.0, 0.0],
         [0.0, 0.0, 0.0, 0.0],
         [0.0, 0.0, 0.0, 0.0]]
    xys = [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]
    for i, xy in enumerate(xys):
        # node types 0/1 are start/end, so operation ids begin at type 2
        m[xy[0]][xy[1]] = float(g.vs[i + 1]['type']) - 2
    return np.array(m)


def decode_igraph_to_NAS_BENCH_201_string(g):
    if not is_valid_NAS201(g):
        return None
    m = decode_igraph_to_NAS201_matrix(g)
    types = ['none', 'skip_connect', 'nor_conv_1x1',
             'nor_conv_3x3', 'avg_pool_3x3']
    return '|{}~0|+|{}~0|{}~1|+|{}~0|{}~1|{}~2|'.format(
        types[int(m[1][0])],
        types[int(m[2][0])], types[int(m[2][1])],
        types[int(m[3][0])], types[int(m[3][1])], types[int(m[3][2])])
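
# Illustrative example (not part of the original module): round-tripping a
# hand-written 8-node row through the decoders above. The row format is
# assumed from decode_NAS_BENCH_201_8_to_igraph: each entry is
# [node_type, incoming-edge flags...], with types 0/1 reserved for the
# start/end nodes and operation types starting at 2.
def _example_decode_roundtrip():
    row = [[0],                          # start node
           [4, 1],                       # op node fed by the start node
           [4, 0, 0],                    # fed via the chain edge 1 -> 2
           [4, 0, 0, 0],
           [4, 0, 0, 0, 0],
           [4, 0, 0, 0, 0, 0],
           [4, 0, 0, 0, 0, 0, 0],
           [1, 0, 0, 0, 0, 0, 0, 1]]     # end node, fed by node 6
    g, n = decode_NAS_BENCH_201_8_to_igraph(row)
    assert n == 8 and is_valid_NAS201(g)
    # type 4 maps to operation index 2, i.e. 'nor_conv_1x1', so this prints
    # '|nor_conv_1x1~0|+|nor_conv_1x1~0|nor_conv_1x1~1|+|...'
    print(decode_igraph_to_NAS_BENCH_201_string(g))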
def is_valid_DAG(g, START_TYPE=0, END_TYPE=1):
    res = g.is_dag()
    n_start, n_end = 0, 0
    for v in g.vs:
        if v['type'] == START_TYPE:
            n_start += 1
        elif v['type'] == END_TYPE:
            n_end += 1
        # every source must be the start node; every sink must be the end node
        if v.indegree() == 0 and v['type'] != START_TYPE:
            return False
        if v.outdegree() == 0 and v['type'] != END_TYPE:
            return False
    return res and n_start == 1 and n_end == 1


class Accumulator:
    def __init__(self, *args):
        self.args = args
        self.argdict = {}
        for i, arg in enumerate(args):
            self.argdict[arg] = i
        self.sums = [0] * len(args)
        self.cnt = 0

    def accum(self, val):
        val = [val] if not isinstance(val, list) else val
        val = [v for v in val if v is not None]
        assert len(val) == len(self.args)
        for i in range(len(val)):
            if torch.is_tensor(val[i]):
                val[i] = val[i].item()
            self.sums[i] += val[i]
        self.cnt += 1

    def clear(self):
        self.sums = [0] * len(self.args)
        self.cnt = 0

    def get(self, arg, avg=True):
        i = self.argdict.get(arg, -1)
        assert i != -1
        if avg:
            return self.sums[i] / (self.cnt + 1e-8)
        else:
            return self.sums[i]

    def print_(self, header=None, time=None, logfile=None,
               do_not_print=[], as_int=[], avg=True):
        msg = '' if header is None else header + ': '
        if time is not None:
            msg += ('(%.3f secs), ' % time)
        args = [arg for arg in self.args if arg not in do_not_print]
        for arg in args:
            val = self.sums[self.argdict[arg]]
            if avg:
                val /= (self.cnt + 1e-8)
            if arg in as_int:
                msg += ('%s %d, ' % (arg, int(val)))
            else:
                msg += ('%s %.4f, ' % (arg, val))
        print(msg)
        if logfile is not None:
            logfile.write(msg + '\n')
            logfile.flush()

    def add_scalars(self, summary, header=None, tag_scalar=None,
                    step=None, avg=True, args=None):
        for arg in self.args:
            val = self.sums[self.argdict[arg]]
            if avg:
                val /= (self.cnt + 1e-8)
            tag = f'{header}/{arg}' if header is not None else arg
            if tag_scalar is not None:
                summary.add_scalars(main_tag=tag,
                                    tag_scalar_dict={tag_scalar: val},
                                    global_step=step)
            else:
                summary.add_scalar(tag=tag,
                                   scalar_value=val,
                                   global_step=step)
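
# Illustrative example (not part of the original module): tracking running
# loss/accuracy with Accumulator. The metric names and values are made up.
def _example_accumulator():
    acc = Accumulator('loss', 'acc')
    for step in range(10):
        # accum expects one value per registered metric, in order
        acc.accum([0.5 / (step + 1), 0.8])
    print(acc.get('loss'))          # running average of 'loss'
    acc.print_(header='train')      # e.g. "train: loss 0.1464, acc 0.8000, "
    acc.clear()                     # reset between epochs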
class Log:
    def __init__(self, args, logf, summary=None):
        self.args = args
        self.logf = logf
        self.summary = summary
        self.stime = time.time()
        self.ep_sttime = None

    def print(self, logger, epoch, tag=None, avg=True):
        if tag == 'train':
            ct = time.time() - self.ep_sttime
            tt = time.time() - self.stime
            msg = f'[total {tt:6.2f}s (ep {ct:6.2f}s)] epoch {epoch:3d}'
            print(msg)
            self.logf.write(msg + '\n')
        logger.print_(header=tag, logfile=self.logf, avg=avg)
        if self.summary is not None:
            logger.add_scalars(self.summary, header=tag, step=epoch, avg=avg)
        logger.clear()

    def print_args(self):
        argdict = vars(self.args)
        print(argdict)
        for k, v in argdict.items():
            self.logf.write(k + ': ' + str(v) + '\n')
        self.logf.write('\n')

    def set_time(self):
        self.stime = time.time()

    def save_time_log(self):
        ct = time.time() - self.stime
        msg = f'({ct:6.2f}s) meta-training phase done'
        print(msg)
        self.logf.write(msg + '\n')

    def print_pred_log(self, loss, corr, tag, epoch=None, max_corr_dict=None):
        if tag == 'train':
            ct = time.time() - self.ep_sttime
            tt = time.time() - self.stime
            msg = f'[total {tt:6.2f}s (ep {ct:6.2f}s)] epoch {epoch:3d}'
            self.logf.write(msg + '\n')
            print(msg)
            self.logf.flush()
        if max_corr_dict is not None:
            max_corr = max_corr_dict['corr']
            max_loss = max_corr_dict['loss']
            msg = f'{tag}: loss {loss:.6f} ({max_loss:.6f}) '
            msg += f'corr {corr:.4f} ({max_corr:.4f})'
        else:
            msg = f'{tag}: loss {loss:.6f} corr {corr:.4f}'
        self.logf.write(msg + '\n')
        print(msg)
        self.logf.flush()

    def max_corr_log(self, max_corr_dict):
        corr = max_corr_dict['corr']
        loss = max_corr_dict['loss']
        epoch = max_corr_dict['epoch']
        msg = f'[epoch {epoch}] max correlation: {corr:.4f}, loss: {loss:.6f}'
        self.logf.write(msg + '\n')
        print(msg)
        self.logf.flush()


def get_log(epoch, loss, y_pred, y, acc_std, acc_mean, tag='train'):
    msg = f'[{tag}] Ep {epoch} loss {loss.item() / len(y):0.4f} '
    if isinstance(y_pred, list):
        msg += f'pacc {y_pred[0]:0.4f}'
        msg += f'({y_pred[0] * 100.0 * acc_std + acc_mean:0.4f}) '
    else:
        msg += f'pacc {y_pred:0.4f}'
        msg += f'({y_pred * 100.0 * acc_std + acc_mean:0.4f}) '
    msg += f'acc {y[0]:0.4f}({y[0] * 100 * acc_std + acc_mean:0.4f})'
    return msg


def load_model(model, ckpt_path):
    model.cpu()
    model.load_state_dict(torch.load(ckpt_path))


def save_model(epoch, model, model_path, max_corr=None):
    print("==> save current model...")
    if max_corr is not None:
        torch.save(model.cpu().state_dict(),
                   os.path.join(model_path, 'ckpt_max_corr.pt'))
    else:
        torch.save(model.cpu().state_dict(),
                   os.path.join(model_path, f'ckpt_{epoch}.pt'))


def mean_confidence_interval(data, confidence=0.95):
    a = 1.0 * np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    # half-width of the two-sided interval, using the Student's t distribution
    h = se * scipy.stats.t.ppf((1 + confidence) / 2., n - 1)
    return m, h
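
# Illustrative example (not part of the original module): reporting a 95%
# confidence interval over a few accuracy measurements; the numbers are
# made up for demonstration.
def _example_confidence_interval():
    accs = [0.721, 0.735, 0.728, 0.741, 0.719]
    m, h = mean_confidence_interval(accs, confidence=0.95)
    print(f'accuracy: {m:.4f} +/- {h:.4f}')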