###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function

import ast
import logging
import os
import random
import time

import igraph
import numpy as np
import scipy.stats
import torch


def reset_seed(seed):
    """Seed all RNGs (torch, CUDA, numpy, random) for reproducible runs."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def restore_checkpoint(ckpt_dir, state, device, resume=False):
    if not resume:
        os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
        return state
    elif not os.path.exists(ckpt_dir):
        os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
        logging.warning(f"No checkpoint found at {ckpt_dir}. "
                        f"Returned the same state as input")
        return state
    else:
        loaded_state = torch.load(ckpt_dir, map_location=device)
        for k in state:
            # Stateful objects are restored via load_state_dict();
            # plain values (e.g. step counters) are copied over directly.
            if k in ['optimizer', 'model', 'ema']:
                state[k].load_state_dict(loaded_state[k])
            else:
                state[k] = loaded_state[k]
        return state
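
# Usage sketch (illustrative; only the 'model'/'optimizer'/'ema' keys are
# special-cased above, the rest of the state dict is hypothetical):
#   state = {'model': model, 'optimizer': optimizer, 'step': 0}
#   state = restore_checkpoint('checkpoints/ckpt.pt', state,
#                              device='cpu', resume=True)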


def load_graph_config(graph_data_name, nvt, data_path):
    if graph_data_name != 'nasbench201':
        raise NotImplementedError(graph_data_name)
    g_list = []
    max_n = 0  # maximum number of nodes
    ms = torch.load(data_path)['arch']['matrix']
    for i in range(len(ms)):
        g, n = decode_NAS_BENCH_201_8_to_igraph(ms[i])
        max_n = max(max_n, n)
        g_list.append((g, 0))
    # number of different node types, including the in/out (start/end) nodes
    graph_config = {}
    graph_config['num_vertex_type'] = nvt  # original types + start/end types
    graph_config['max_n'] = max_n  # maximum number of nodes
    graph_config['START_TYPE'] = 0  # predefined start vertex type
    graph_config['END_TYPE'] = 1  # predefined end vertex type

    return graph_config
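
# Usage sketch (illustrative; the nvt value and path are assumptions --
# NAS-Bench-201 defines 5 operations, plus the start/end types here):
#   graph_config = load_graph_config('nasbench201', nvt=7,
#                                    data_path='data/nasbench201.pt')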


def decode_NAS_BENCH_201_8_to_igraph(row):
    if isinstance(row, str):
        # Parse the string representation into a list of lists
        # (ast.literal_eval is a safe replacement for eval here).
        row = ast.literal_eval(row)
    n = len(row)
    g = igraph.Graph(directed=True)
    g.add_vertices(n)
    for i, node in enumerate(row):
        g.vs[i]['type'] = node[0]
        if 0 < i < n - 2:
            g.add_edge(i, i + 1)  # chain edge from each intermediate node to the next
        for j, edge in enumerate(node[1:]):
            if edge == 1:
                g.add_edge(j, i)
    return g, n
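
# As read off the loop above, row entry i is [vertex_type, e_0i, ..., e_(i-1)i],
# where e_ji == 1 adds the edge j -> i. Hypothetical 3-node illustration
# (START_TYPE=0, END_TYPE=1, operation types >= 2):
#   g, n = decode_NAS_BENCH_201_8_to_igraph([[0], [2, 1], [1, 0, 1]])
#   # builds start(0) -> op(2) -> end(1), with n == 3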


def is_valid_NAS201(g, START_TYPE=0, END_TYPE=1):
    # First it must be a valid DAG computation graph,
    res = is_valid_DAG(g, START_TYPE, END_TYPE)
    # and it must have exactly 8 nodes, with the start/end types
    # appearing only at the two endpoints.
    res = res and len(g.vs['type']) == 8
    res = res and START_TYPE not in g.vs['type'][1:-1]
    res = res and END_TYPE not in g.vs['type'][1:-1]
    return res


def decode_igraph_to_NAS201_matrix(g):
    m = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
         [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
    xys = [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]
    for i, xy in enumerate(xys):
        # Vertex types 0/1 are the start/end markers, so operation
        # indices are offset by 2.
        m[xy[0]][xy[1]] = float(g.vs[i + 1]['type']) - 2
    return np.array(m)


def decode_igraph_to_NAS_BENCH_201_string(g):
    if not is_valid_NAS201(g):
        return None
    m = decode_igraph_to_NAS201_matrix(g)
    types = ['none', 'skip_connect', 'nor_conv_1x1',
             'nor_conv_3x3', 'avg_pool_3x3']
    return '|{}~0|+|{}~0|{}~1|+|{}~0|{}~1|{}~2|'.\
        format(types[int(m[1][0])],
               types[int(m[2][0])], types[int(m[2][1])],
               types[int(m[3][0])], types[int(m[3][1])], types[int(m[3][2])])
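
# The returned string follows the NAS-Bench-201 cell format produced by the
# format string above, e.g. (illustrative operation choices):
#   '|nor_conv_3x3~0|+|none~0|skip_connect~1|+|nor_conv_1x1~0|none~1|avg_pool_3x3~2|'
# where 'op~k' denotes the operation applied to the output of node k.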


def is_valid_DAG(g, START_TYPE=0, END_TYPE=1):
    """Check that g is a DAG with exactly one start node (the only source)
    and exactly one end node (the only sink)."""
    res = g.is_dag()
    n_start, n_end = 0, 0
    for v in g.vs:
        if v['type'] == START_TYPE:
            n_start += 1
        elif v['type'] == END_TYPE:
            n_end += 1
        if v.indegree() == 0 and v['type'] != START_TYPE:
            return False
        if v.outdegree() == 0 and v['type'] != END_TYPE:
            return False
    return res and n_start == 1 and n_end == 1


class Accumulator:
    """Accumulate named scalar metrics and report sums or running means."""

    def __init__(self, *args):
        self.args = args
        self.argdict = {}
        for i, arg in enumerate(args):
            self.argdict[arg] = i
        self.sums = [0] * len(args)
        self.cnt = 0

    def accum(self, val):
        val = [val] if not isinstance(val, list) else val
        val = [v for v in val if v is not None]
        assert len(val) == len(self.args)
        for i in range(len(val)):
            if torch.is_tensor(val[i]):
                val[i] = val[i].item()
            self.sums[i] += val[i]
        self.cnt += 1

    def clear(self):
        self.sums = [0] * len(self.args)
        self.cnt = 0

    def get(self, arg, avg=True):
        i = self.argdict.get(arg, -1)
        assert i != -1
        if avg:
            return self.sums[i] / (self.cnt + 1e-8)
        else:
            return self.sums[i]

    def print_(self, header=None, time=None,
               logfile=None, do_not_print=(), as_int=(),
               avg=True):
        msg = '' if header is None else header + ': '
        if time is not None:
            msg += ('(%.3f secs), ' % time)

        args = [arg for arg in self.args if arg not in do_not_print]
        for arg in args:
            val = self.sums[self.argdict[arg]]
            if avg:
                val /= (self.cnt + 1e-8)
            if arg in as_int:
                msg += ('%s %d, ' % (arg, int(val)))
            else:
                msg += ('%s %.4f, ' % (arg, val))
        print(msg)

        if logfile is not None:
            logfile.write(msg + '\n')
            logfile.flush()

    def add_scalars(self, summary, header=None, tag_scalar=None,
                    step=None, avg=True, args=None):
        for arg in self.args:
            val = self.sums[self.argdict[arg]]
            if avg:
                val /= (self.cnt + 1e-8)
            tag = f'{header}/{arg}' if header is not None else arg
            if tag_scalar is not None:
                summary.add_scalars(main_tag=tag,
                                    tag_scalar_dict={tag_scalar: val},
                                    global_step=step)
            else:
                summary.add_scalar(tag=tag,
                                   scalar_value=val,
                                   global_step=step)
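
# Usage sketch (illustrative):
#   acc = Accumulator('loss', 'acc')
#   acc.accum([0.71, 0.52])
#   acc.accum([0.65, 0.58])
#   acc.get('loss')            # running mean, here 0.68
#   acc.print_(header='train')
#   acc.clear()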


class Log:
    def __init__(self, args, logf, summary=None):
        self.args = args
        self.logf = logf
        self.summary = summary
        self.stime = time.time()
        self.ep_sttime = None

    def print(self, logger, epoch, tag=None, avg=True):
        # `logger` is an Accumulator holding this epoch's metrics.
        if tag == 'train':
            ct = time.time() - self.ep_sttime
            tt = time.time() - self.stime
            msg = f'[total {tt:6.2f}s (ep {ct:6.2f}s)] epoch {epoch:3d}'
            print(msg)
            self.logf.write(msg + '\n')
        logger.print_(header=tag, logfile=self.logf, avg=avg)

        if self.summary is not None:
            logger.add_scalars(
                self.summary, header=tag, step=epoch, avg=avg)
        logger.clear()

    def print_args(self):
        argdict = vars(self.args)
        print(argdict)
        for k, v in argdict.items():
            self.logf.write(k + ': ' + str(v) + '\n')
        self.logf.write('\n')

    def set_time(self):
        self.stime = time.time()

    def save_time_log(self):
        ct = time.time() - self.stime
        msg = f'({ct:6.2f}s) meta-training phase done'
        print(msg)
        self.logf.write(msg + '\n')

    def print_pred_log(self, loss, corr, tag, epoch=None, max_corr_dict=None):
        if tag == 'train':
            ct = time.time() - self.ep_sttime
            tt = time.time() - self.stime
            msg = f'[total {tt:6.2f}s (ep {ct:6.2f}s)] epoch {epoch:3d}'
            self.logf.write(msg + '\n')
            print(msg)
            self.logf.flush()
        if max_corr_dict is not None:
            max_corr = max_corr_dict['corr']
            max_loss = max_corr_dict['loss']
            msg = f'{tag}: loss {loss:.6f} ({max_loss:.6f}) '
            msg += f'corr {corr:.4f} ({max_corr:.4f})'
        else:
            msg = f'{tag}: loss {loss:.6f} corr {corr:.4f}'
        self.logf.write(msg + '\n')
        print(msg)
        self.logf.flush()

    def max_corr_log(self, max_corr_dict):
        corr = max_corr_dict['corr']
        loss = max_corr_dict['loss']
        epoch = max_corr_dict['epoch']
        msg = f'[epoch {epoch}] max correlation: {corr:.4f}, loss: {loss:.6f}'
        self.logf.write(msg + '\n')
        print(msg)
        self.logf.flush()
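
# Usage sketch (illustrative; `args` is an argparse.Namespace, and setting
# ep_sttime directly stands in for whatever the training loop does):
#   log = Log(args, open('train.log', 'w'))
#   acc = Accumulator('loss', 'corr')
#   log.ep_sttime = time.time()   # start-of-epoch timestamp
#   acc.accum([0.71, 0.52])
#   log.print(acc, epoch=0, tag='train')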


def get_log(epoch, loss, y_pred, y, acc_std, acc_mean, tag='train'):
    # y_pred and y hold normalized accuracies; the parenthesized values
    # undo the normalization via acc * 100 * acc_std + acc_mean.
    msg = f'[{tag}] Ep {epoch} loss {loss.item()/len(y):0.4f} '
    if isinstance(y_pred, list):
        msg += f'pacc {y_pred[0]:0.4f}'
        msg += f'({y_pred[0]*100.0*acc_std+acc_mean:0.4f}) '
    else:
        msg += f'pacc {y_pred:0.4f}'
        msg += f'({y_pred*100.0*acc_std+acc_mean:0.4f}) '
    msg += f'acc {y[0]:0.4f}({y[0]*100*acc_std+acc_mean:0.4f})'
    return msg


def load_model(model, ckpt_path):
    model.cpu()
    # map_location='cpu' lets GPU-trained checkpoints load on CPU-only hosts.
    model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))


def save_model(epoch, model, model_path, max_corr=None):
    print("==> save current model...")
    # Note: model.cpu() moves the model off the GPU as a side effect;
    # callers may need to move it back afterwards.
    if max_corr is not None:
        torch.save(model.cpu().state_dict(),
                   os.path.join(model_path, 'ckpt_max_corr.pt'))
    else:
        torch.save(model.cpu().state_dict(),
                   os.path.join(model_path, f'ckpt_{epoch}.pt'))


def mean_confidence_interval(data, confidence=0.95):
    """Return the sample mean and the half-width of its t-based confidence
    interval, i.e. the interval is (m - h, m + h)."""
    a = 1.0 * np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    h = se * scipy.stats.t.ppf((1 + confidence) / 2., n - 1)
    return m, h
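

if __name__ == '__main__':
    # Minimal smoke test (illustrative, not part of the original module):
    # exercise the metric accumulator and the confidence-interval helper.
    acc = Accumulator('loss', 'corr')
    acc.accum([0.71, 0.52])
    acc.accum([0.65, 0.58])
    acc.print_(header='demo')                      # demo: loss 0.6800, corr 0.5500,
    m, h = mean_confidence_interval([0.71, 0.69, 0.73, 0.70])
    print(f'mean accuracy {m:.4f} +/- {h:.4f}')    # 95% t-based interval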