diffusionNAG/NAS-Bench-201/main_exp/transfer_nag/nag_utils.py
2024-03-15 14:38:51 +00:00

302 lines
9.9 KiB
Python

###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import os
import time
import igraph
import random
import numpy as np
import scipy.stats
import torch
import logging
def reset_seed(seed):
    """Seed every RNG in use (python random, numpy, torch CPU and all GPUs)
    and force cuDNN into deterministic mode for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # non-deterministic cuDNN kernels would defeat the seeding above
    torch.backends.cudnn.deterministic = True
def restore_checkpoint(ckpt_dir, state, device, resume=False):
    """Restore a training `state` dict from the checkpoint at `ckpt_dir`.

    When not resuming (or when the file is missing) the input state is
    returned unchanged; the checkpoint's parent directory is created so a
    later save cannot fail. 'optimizer'/'model'/'ema' entries are restored
    via load_state_dict; every other key is copied over directly.
    """
    parent = os.path.dirname(ckpt_dir)
    if not resume:
        os.makedirs(parent, exist_ok=True)
        return state
    if not os.path.exists(ckpt_dir):
        if not os.path.exists(parent):
            os.makedirs(parent)
        logging.warning(f"No checkpoint found at {ckpt_dir}. "
                        f"Returned the same state as input")
        return state
    loaded_state = torch.load(ckpt_dir, map_location=device)
    for key in state:
        if key in ('optimizer', 'model', 'ema'):
            # stateful objects restore themselves from their saved dicts
            state[key].load_state_dict(loaded_state[key])
        else:
            state[key] = loaded_state[key]
    return state
def load_graph_config(graph_data_name, nvt, data_path):
    """Scan the NAS-Bench-201 architecture file and build the decoder config.

    Args:
        graph_data_name: must be 'nasbench201'; anything else raises.
        nvt: number of vertex types (original op types + start/end types).
        data_path: path to a torch-saved dict with ['arch']['matrix'] rows.

    Returns:
        dict with 'num_vertex_type', 'max_n' (largest vertex count seen),
        'START_TYPE' (0) and 'END_TYPE' (1).

    Raises:
        NotImplementedError: for any graph_data_name other than 'nasbench201'.
    """
    # BUG FIX: the original used `is not`, which compares object identity;
    # an equal but non-interned string (e.g. from argparse) would wrongly raise.
    if graph_data_name != 'nasbench201':
        raise NotImplementedError(graph_data_name)
    max_n = 0  # maximum number of nodes across all architectures
    matrices = torch.load(data_path)['arch']['matrix']
    # (the decoded graphs themselves are not needed here — only max_n;
    # the original accumulated an unused g_list, removed as dead code)
    for row in matrices:
        g, n = decode_NAS_BENCH_201_8_to_igraph(row)
        max_n = max(max_n, n)
    graph_config = {
        'num_vertex_type': nvt,  # original types + start/end types
        'max_n': max_n,          # maximum number of nodes
        'START_TYPE': 0,         # predefined start vertex type
        'END_TYPE': 1,           # predefined end vertex type
    }
    return graph_config
def decode_NAS_BENCH_201_8_to_igraph(row):
    """Decode one NAS-Bench-201 row into a directed igraph.

    Args:
        row: list of lists (or its string repr) where row[i][0] is node i's
            type and row[i][1:] flag incoming edges from earlier nodes j.

    Returns:
        (g, n): the constructed igraph.Graph and its vertex count.
    """
    if isinstance(row, str):
        # SECURITY FIX: the original used eval() here; literal_eval parses
        # only Python literals and cannot execute code, while producing the
        # identical list-of-lists for this serialized format.
        import ast
        row = ast.literal_eval(row)
    n = len(row)
    g = igraph.Graph(directed=True)
    g.add_vertices(n)
    for i, node in enumerate(row):
        g.vs[i]['type'] = node[0]
        if 0 < i < n - 2:
            g.add_edge(i, i + 1)  # always connect from last node
        for j, edge in enumerate(node[1:]):
            if edge == 1:
                g.add_edge(j, i)
    return g, n
def is_valid_NAS201(g, START_TYPE=0, END_TYPE=1):
    """Return True iff g is a valid NAS-Bench-201 cell: a valid DAG with
    exactly 8 vertices whose interior nodes carry neither type 0 nor type 1."""
    # first need to be a valid DAG computation graph
    if not is_valid_DAG(g, START_TYPE, END_TYPE):
        return False
    types = g.vs['type']
    if len(types) != 8:
        return False
    # interior vertices must not reuse the reserved start/end types
    interior = types[1:-1]
    return 0 not in interior and 1 not in interior
def decode_igraph_to_NAS201_matrix(g):
    """Convert an 8-vertex NAS-Bench-201 igraph into its 4x4 operation matrix.

    Entry (x, y) of the lower triangle holds the op id (vertex type - 2) of
    the edge y -> x; vertices 1..6 map to the six edges in fixed order.

    Returns:
        np.ndarray of shape (4, 4), float64.
    """
    # CLEANUP: the original re-imported numpy locally although the module
    # already imports it as np at the top of the file.
    m = np.zeros((4, 4))
    # fixed edge order of the NAS-Bench-201 cell DAG
    xys = [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]
    for i, (x, y) in enumerate(xys):
        # vertex types 2..6 encode ops 0..4 (skip vertex 0, the input node)
        m[x][y] = float(g.vs[i + 1]['type']) - 2
    return m
def decode_igraph_to_NAS_BENCH_201_string(g):
    """Render g as the canonical NAS-Bench-201 architecture string, or
    return None when g is not a valid cell graph."""
    if not is_valid_NAS201(g):
        return None
    m = decode_igraph_to_NAS201_matrix(g)
    ops = ['none', 'skip_connect', 'nor_conv_1x1',
           'nor_conv_3x3', 'avg_pool_3x3']

    def op(x, y):
        # look up the op name for edge y -> x from the decoded matrix
        return ops[int(m[x][y])]

    return '|{}~0|+|{}~0|{}~1|+|{}~0|{}~1|{}~2|'.format(
        op(1, 0), op(2, 0), op(2, 1), op(3, 0), op(3, 1), op(3, 2))
def is_valid_DAG(g, START_TYPE=0, END_TYPE=1):
    """Return True iff g is acyclic, has exactly one start and one end vertex,
    every source is the start type, and every sink is the end type."""
    acyclic = g.is_dag()
    n_start = 0
    n_end = 0
    for v in g.vs:
        vtype = v['type']
        if vtype == START_TYPE:
            n_start += 1
        elif vtype == END_TYPE:
            n_end += 1
        # a vertex with no inputs must be the start; no outputs must be the end
        if v.indegree() == 0 and vtype != START_TYPE:
            return False
        if v.outdegree() == 0 and vtype != END_TYPE:
            return False
    return acyclic and n_start == 1 and n_end == 1
class Accumulator():
def __init__(self, *args):
self.args = args
self.argdict = {}
for i, arg in enumerate(args):
self.argdict[arg] = i
self.sums = [0] * len(args)
self.cnt = 0
def accum(self, val):
val = [val] if type(val) is not list else val
val = [v for v in val if v is not None]
assert (len(val) == len(self.args))
for i in range(len(val)):
if torch.is_tensor(val[i]):
val[i] = val[i].item()
self.sums[i] += val[i]
self.cnt += 1
def clear(self):
self.sums = [0] * len(self.args)
self.cnt = 0
def get(self, arg, avg=True):
i = self.argdict.get(arg, -1)
assert (i is not -1)
if avg:
return self.sums[i] / (self.cnt + 1e-8)
else:
return self.sums[i]
def print_(self, header=None, time=None,
logfile=None, do_not_print=[], as_int=[],
avg=True):
msg = '' if header is None else header + ': '
if time is not None:
msg += ('(%.3f secs), ' % time)
args = [arg for arg in self.args if arg not in do_not_print]
arg = []
for arg in args:
val = self.sums[self.argdict[arg]]
if avg:
val /= (self.cnt + 1e-8)
if arg in as_int:
msg += ('%s %d, ' % (arg, int(val)))
else:
msg += ('%s %.4f, ' % (arg, val))
print(msg)
if logfile is not None:
logfile.write(msg + '\n')
logfile.flush()
def add_scalars(self, summary, header=None, tag_scalar=None,
step=None, avg=True, args=None):
for arg in self.args:
val = self.sums[self.argdict[arg]]
if avg:
val /= (self.cnt + 1e-8)
else:
val = val
tag = f'{header}/{arg}' if header is not None else arg
if tag_scalar is not None:
summary.add_scalars(main_tag=tag,
tag_scalar_dict={tag_scalar: val},
global_step=step)
else:
summary.add_scalar(tag=tag,
scalar_value=val,
global_step=step)
class Log:
    """Run logger that mirrors messages to stdout, a log file, and an
    optional summary writer. Timestamps are measured from construction
    (self.stime) and, for per-epoch lines, from self.ep_sttime — which the
    caller is expected to set before tag='train' prints."""

    def __init__(self, args, logf, summary=None):
        self.args = args
        self.logf = logf
        self.summary = summary
        self.stime = time.time()
        self.ep_sttime = None  # set externally at epoch start

    def print(self, logger, epoch, tag=None, avg=True):
        """Flush an Accumulator's stats for this epoch, then clear it."""
        if tag == 'train':
            elapsed_ep = time.time() - self.ep_sttime
            elapsed_total = time.time() - self.stime
            msg = f'[total {elapsed_total:6.2f}s (ep {elapsed_ep:6.2f}s)] epoch {epoch:3d}'
            print(msg)
            self.logf.write(msg + '\n')
        logger.print_(header=tag, logfile=self.logf, avg=avg)
        if self.summary is not None:
            logger.add_scalars(
                self.summary, header=tag, step=epoch, avg=avg)
        logger.clear()

    def print_args(self):
        """Dump the run's argument namespace to console and log file."""
        argdict = vars(self.args)
        print(argdict)
        for key, value in argdict.items():
            self.logf.write(f'{key}: {value}\n')
        self.logf.write('\n')

    def set_time(self):
        """Reset the global start-time reference."""
        self.stime = time.time()

    def save_time_log(self):
        """Record total elapsed time for the meta-training phase."""
        elapsed = time.time() - self.stime
        msg = f'({elapsed:6.2f}s) meta-training phase done'
        print(msg)
        self.logf.write(msg + '\n')

    def print_pred_log(self, loss, corr, tag, epoch=None, max_corr_dict=None):
        """Log predictor loss/correlation, optionally with best-so-far values."""
        if tag == 'train':
            elapsed_ep = time.time() - self.ep_sttime
            elapsed_total = time.time() - self.stime
            msg = f'[total {elapsed_total:6.2f}s (ep {elapsed_ep:6.2f}s)] epoch {epoch:3d}'
            self.logf.write(msg + '\n')
            print(msg)
            self.logf.flush()
        if max_corr_dict is None:
            msg = f'{tag}: loss {loss:.6f} corr {corr:.4f}'
        else:
            best_corr = max_corr_dict['corr']
            best_loss = max_corr_dict['loss']
            msg = f'{tag}: loss {loss:.6f} ({best_loss:.6f}) '
            msg += f'corr {corr:.4f} ({best_corr:.4f})'
        self.logf.write(msg + '\n')
        print(msg)
        self.logf.flush()

    def max_corr_log(self, max_corr_dict):
        """Log the epoch/loss of the best correlation observed so far."""
        best_corr = max_corr_dict['corr']
        best_loss = max_corr_dict['loss']
        best_epoch = max_corr_dict['epoch']
        msg = f'[epoch {best_epoch}] max correlation: {best_corr:.4f}, loss: {best_loss:.6f}'
        self.logf.write(msg + '\n')
        print(msg)
        self.logf.flush()
def get_log(epoch, loss, y_pred, y, acc_std, acc_mean, tag='train'):
    """Format one epoch's predictor stats into a single log line.

    Accuracies are stored normalized; `v * 100 * acc_std + acc_mean`
    recovers the displayable raw percentage. `loss` is expected to be a
    0-d tensor (it is divided by len(y) after .item()).
    """
    def denorm(v):
        # undo the normalization applied when the dataset was built
        return v * 100.0 * acc_std + acc_mean

    # a list prediction logs only its first element, same as a scalar one
    pred = y_pred[0] if isinstance(y_pred, list) else y_pred
    msg = f'[{tag}] Ep {epoch} loss {loss.item()/len(y):0.4f} '
    msg += f'pacc {pred:0.4f}'
    msg += f'({denorm(pred):0.4f}) '
    msg += f'acc {y[0]:0.4f}({denorm(y[0]):0.4f})'
    return msg
def load_model(model, ckpt_path):
    """Load weights from `ckpt_path` into `model` (moved to CPU first so the
    checkpoint's tensors land on the host regardless of where it was saved)."""
    model.cpu()
    state_dict = torch.load(ckpt_path)
    model.load_state_dict(state_dict)
def save_model(epoch, model, model_path, max_corr=None):
    """Checkpoint the model's CPU state_dict under `model_path`.

    Writes 'ckpt_max_corr.pt' when `max_corr` is given (best-so-far
    checkpoint), otherwise an epoch-numbered 'ckpt_{epoch}.pt'.
    """
    print("==> save current model...")
    filename = 'ckpt_max_corr.pt' if max_corr is not None else f'ckpt_{epoch}.pt'
    torch.save(model.cpu().state_dict(), os.path.join(model_path, filename))
def mean_confidence_interval(data, confidence=0.95):
    """Return (mean, half_width) of the Student-t confidence interval for
    the sample `data` at the given confidence level."""
    sample = np.asarray(data, dtype=float)
    n = len(sample)
    mean = np.mean(sample)
    sem = scipy.stats.sem(sample)  # standard error of the mean
    # two-sided t critical value with n-1 degrees of freedom
    half_width = sem * scipy.stats.t.ppf((1 + confidence) / 2., n - 1)
    return mean, half_width