# NOTE: file-viewer listing header (305 lines, 13 KiB, Python) -- extraction residue, not part of the program.
from __future__ import print_function

import gc
import os
import subprocess
import sys
import time

import numpy as np
import torch

from nag_utils import mean_confidence_interval
from nag_utils import restore_checkpoint
from nag_utils import load_graph_config
from nag_utils import load_model

# The imports below are resolved relative to ./main_exp, so the path must be
# extended before they run.
sys.path.append(os.path.join(os.getcwd(), 'main_exp'))
from nas_bench_201 import train_single_model
from unnoised_model import MetaSurrogateUnnoisedModel
from diffusion.run_lib import generate_archs_meta
from diffusion.run_lib import get_sampling_fn_meta
from diffusion.run_lib import get_score_model
from diffusion.run_lib import get_surrogate
from loader import MetaTestDataset
from logger import Logger
from all_path import *


class NAG:
    """Meta-test driver: generate architectures, rank them, and evaluate them.

    For a target dataset the class
      1. generates architecture strings with the diffusion sampler
         (``get_gen_arch_str``),
      2. ranks them with the un-noised meta-surrogate model
         (``forward`` / ``sort_arch``),
      3. obtains a test accuracy per architecture, either by looking it up in
         the NAS-Bench-201 table (dataset names containing ``'cifar'``) or by
         training it (``train_single_arch`` / ``run_multi_proc.py``),
      4. writes ``architecture.txt``, ``accuracy.txt`` and ``results.pt`` into
         ``args.exp_name``.
    """

    ## Seeds used for every architecture that is trained from scratch
    SEEDS = (777, 888, 999)
    ## Epoch whose test accuracy is read back (key ``f'x-test@{epoch}'``)
    FINAL_EPOCH = 199
    ## Max. architectures per run_multi_proc.py call (hard-coded for available GPUs)
    MAX_CAP = 5
    ## Datasets evaluated when ``args.data_name == 'all'``
    ALL_DATASETS = ('cifar10', 'cifar100', 'aircraft', 'pets')
    ## (section, field) pairs that must agree between the two checkpoints' configs
    _SHARED_CONFIG_FIELDS = (
        ('model', 'sigma_min'),
        ('model', 'sigma_max'),
        ('training', 'sde'),
        ('training', 'continuous'),
        ('data', 'centered'),
        ('data', 'max_node'),
        ('data', 'n_vocab'),
    )

    def __init__(self, args):
        """Load the models and the logger needed for meta-testing.

        Args:
            args: parsed command-line namespace. Attributes read by this class:
                data_name, num_class, num_sample, graph_data_name, nvt,
                exp_name, classifier_scale, eval_batch_size, predictor,
                corrector, n_gen_samples, multi_proc.
        """
        self.args = args
        self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

        ## Target dataset information
        self.raw_data_path = RAW_DATA_PATH
        self.data_path = DATA_PATH
        self.data_name = args.data_name
        self.num_class = args.num_class
        self.num_sample = args.num_sample

        ## Un-noised meta-surrogate model (used to rank generated architectures)
        graph_config = load_graph_config(args.graph_data_name, args.nvt, NASBENCH201)
        self.meta_surrogate_unnoised_model = MetaSurrogateUnnoisedModel(args, graph_config)
        load_model(model=self.meta_surrogate_unnoised_model,
                   ckpt_path=META_SURROGATE_UNNOISED_CKPT_PATH)
        self.meta_surrogate_unnoised_model.to(self.device)

        ## Pre-trained meta-surrogate model (loaded lazily in get_gen_arch_str)
        self.meta_surrogate_ckpt_path = META_SURROGATE_CKPT_PATH

        ## Load score network model (base diffusion model)
        self.load_diffusion_model(args=args)

        ## Check config
        self.check_config()

        ## Set logger
        self.logger = Logger(
            log_dir=args.exp_name,
            write_textfile=True
        )
        self.logger.update_config(args, is_args=True)
        self.logger.write_str(str(vars(args)))
        self.logger.write_str('-' * 100)

    def check_config(self):
        """
        Check if the configuration of the pre-trained score network model matches that of the meta surrogate model.

        Raises:
            AssertionError: if any field listed in ``_SHARED_CONFIG_FIELDS``
                differs. Raised explicitly (not via ``assert``) so the check
                also runs under ``python -O``.
        """
        scorenet_config = torch.load(self.config.scorenet_ckpt_path)['config']
        meta_surrogate_config = torch.load(self.meta_surrogate_ckpt_path)['config']
        for section, field in self._SHARED_CONFIG_FIELDS:
            scorenet_value = getattr(getattr(scorenet_config, section), field)
            surrogate_value = getattr(getattr(meta_surrogate_config, section), field)
            if not (scorenet_value == surrogate_value):
                raise AssertionError(
                    f'config mismatch on {section}.{field}: '
                    f'score network has {scorenet_value!r}, '
                    f'meta-surrogate has {surrogate_value!r}')

    def forward(self, x, arch):
        """Predict the performance of ``arch`` on the dataset sample ``x``.

        Args:
            x: tensor of dataset samples; moved to ``self.device`` here.
            arch: graph batch passed unchanged to the model's ``graph_encode``.

        Returns:
            The output of the un-noised meta-surrogate's ``predict``.
        """
        D_mu = self.meta_surrogate_unnoised_model.set_encode(x.to(self.device))
        G_mu = self.meta_surrogate_unnoised_model.graph_encode(arch)
        return self.meta_surrogate_unnoised_model.predict(D_mu, G_mu)

    def meta_test(self):
        """Run the meta-test on ``self.data_name`` (or on every dataset for ``'all'``)."""
        if self.data_name == 'all':
            # NOTE(review): every dataset writes to the same files inside
            # args.exp_name, so later datasets overwrite earlier outputs.
            for data_name in self.ALL_DATASETS:
                self.meta_test_per_dataset(data_name)
        else:
            self.meta_test_per_dataset(self.data_name)

    def meta_test_per_dataset(self, data_name):
        """Generate, rank and evaluate architectures for one dataset.

        Side effects: sets ``self.nasbench201`` and ``self.test_dataset``;
        writes ``architecture.txt``, ``accuracy.txt`` and ``results.pt`` into
        ``self.args.exp_name``; may train networks under
        ``META_TEST_PATH/<data_name>``.

        Args:
            data_name: name of the target dataset.
        """
        ## Load NASBench201
        self.nasbench201 = torch.load(NASBENCH201)
        bench_arch_strs = self.nasbench201['arch']['str']
        all_arch_str = np.array(bench_arch_strs)

        ## Load meta-test dataset
        self.test_dataset = MetaTestDataset(self.data_path, data_name, self.num_sample, self.num_class)

        ## Set save path
        meta_test_path = os.path.join(META_TEST_PATH, data_name)
        os.makedirs(meta_test_path, exist_ok=True)
        arch_txt_path = os.path.join(self.args.exp_name, 'architecture.txt')
        acc_txt_path = os.path.join(self.args.exp_name, 'accuracy.txt')

        ## Generate architectures
        gen_arch_str = self.get_gen_arch_str()
        gen_arch_igraph = self.get_items(
            full_target=self.nasbench201['arch']['igraph'],
            full_source=bench_arch_strs,
            source=gen_arch_str)

        ## Sort with unnoised meta-surrogate model
        y_pred_all = []
        self.meta_surrogate_unnoised_model.eval()
        self.meta_surrogate_unnoised_model.to(self.device)
        with torch.no_grad():
            for arch_igraph in gen_arch_igraph:
                x, g = self.collect_data(arch_igraph)
                y_pred_all.append(torch.mean(self.forward(x, g)).item())
        sorted_arch_lst = self.sort_arch(data_name, torch.tensor(y_pred_all), gen_arch_str)

        ## Record the information of the architecture generated in sorted order
        # The file is closed (and therefore flushed) before any long training starts.
        with open(arch_txt_path, 'w') as f_arch_str:
            f_arch_str.writelines(f'{arch_str}\n' for arch_str in sorted_arch_lst)
        arch_idx_lst = self.get_items(
            full_target=range(len(bench_arch_strs)),
            full_source=bench_arch_strs,
            source=sorted_arch_lst)
        arch_str_lst = []
        arch_acc_lst = []

        ## Get the accuracy of the architecture
        with open(acc_txt_path, 'w') as f_arch_acc:
            if 'cifar' in data_name:
                # Accuracies are looked up in the benchmark table; nothing is trained.
                sorted_acc_lst = self.get_items(
                    full_target=self.nasbench201['test-acc'][data_name],
                    full_source=bench_arch_strs,
                    source=sorted_arch_lst)
                arch_str_lst += sorted_arch_lst
                arch_acc_lst += sorted_acc_lst
                for acc in sorted_acc_lst:
                    # NOTE(review): '4f' is a field width of 4 with default
                    # precision; '.4f' was probably intended. Kept as-is so
                    # the output format does not change.
                    msg = f'Avg {acc:4f} (%)'
                    f_arch_acc.write(msg + '\n')
            elif self.args.multi_proc:
                ## Run multiple processes in parallel
                self._train_archs_multi_proc(arch_idx_lst, meta_test_path, data_name)
                acc_runs_lst = self._wait_for_seed_accuracies(
                    arch_idx_lst, meta_test_path, data_name)
                for acc_runs in acc_runs_lst:
                    print(np.mean(acc_runs))
                for arch_idx, acc_runs in zip(arch_idx_lst, acc_runs_lst):
                    self._write_run_accuracies(f_arch_acc, acc_runs)
                    arch_acc_lst.append(np.mean(acc_runs))
                    arch_str_lst.append(all_arch_str[arch_idx])
            else:
                for arch_idx in arch_idx_lst:
                    # The architecture string comes from the same
                    # ['arch']['str'] list used everywhere else in this method
                    # (the previous ['str'] lookup did not match any other access).
                    acc_runs = self.train_single_arch(
                        data_name, bench_arch_strs[arch_idx], meta_test_path)
                    self._write_run_accuracies(f_arch_acc, acc_runs)
                    arch_acc_lst.append(np.mean(acc_runs))
                    arch_str_lst.append(all_arch_str[arch_idx])

        # Save results
        results_path = os.path.join(self.args.exp_name, 'results.pt')
        torch.save({
            'arch_idx_lst': arch_idx_lst,
            'arch_str_lst': arch_str_lst,
            'arch_acc_lst': arch_acc_lst
        }, results_path)
        print(f">>> Save the results at {results_path}...")

    def _train_archs_multi_proc(self, arch_idx_lst, meta_test_path, data_name):
        """Train not-yet-trained architectures in batches of at most ``MAX_CAP``."""
        run_file = os.path.join(os.getcwd(), 'main_exp', 'transfer_nag', 'run_multi_proc.py')
        last_position = len(arch_idx_lst) - 1
        pending = []
        for position, arch_idx in enumerate(arch_idx_lst):
            # An existing result directory marks the architecture as already trained.
            if not os.path.exists(os.path.join(meta_test_path, str(arch_idx))):
                pending.append(arch_idx)
            if len(pending) == self.MAX_CAP or position == last_position:
                self._launch_training_batch(run_file, pending, meta_test_path, data_name)
                pending = []

    def _launch_training_batch(self, run_file, arch_indices, meta_test_path, data_name):
        """Run ``run_multi_proc.py`` for ``arch_indices`` and wait for it to exit."""
        if not arch_indices:
            # Nothing left to train; an empty --arch_idx_lst is not a valid call.
            return
        cmd = [
            'python', run_file,
            # one split per (architecture, seed) pair
            '--num_split', str(len(self.SEEDS) * len(arch_indices)),
            '--arch_idx_lst', ','.join(str(i) for i in arch_indices),
            '--meta_test_path', str(meta_test_path),
            '--data_name', str(data_name),
            '--raw_data_path', str(self.raw_data_path),
        ]
        # Argument list without a shell: paths containing spaces or shell
        # metacharacters are passed through verbatim.
        subprocess.run(cmd)

    def _load_seed_accuracies(self, save_dir, data_name):
        """Return the final-epoch test accuracy stored in ``save_dir`` for each seed."""
        acc_key = f'x-test@{self.FINAL_EPOCH}'
        acc_runs = []
        for seed in self.SEEDS:
            result = torch.load(os.path.join(save_dir, f'seed-0{seed}.pth'))
            acc_runs.append(result[data_name]['valid_acc1es'][acc_key])
        return acc_runs

    def _wait_for_seed_accuracies(self, arch_idx_lst, meta_test_path, data_name,
                                  poll_interval=10.0):
        """Block until the results of every architecture can be loaded.

        Keeps retrying like the original loop did, but sleeps between attempts,
        reports why the attempt failed, and stays interruptible with Ctrl-C.
        """
        while True:
            try:
                return [
                    self._load_seed_accuracies(
                        os.path.join(meta_test_path, str(arch_idx)), data_name)
                    for arch_idx in arch_idx_lst
                ]
            except Exception as err:
                print(f'>>> Results not ready ({type(err).__name__}: {err}); '
                      f'retrying in {poll_interval:.0f}s...')
                time.sleep(poll_interval)

    @staticmethod
    def _write_run_accuracies(f_arch_acc, acc_runs):
        """Write the per-run accuracies and their mean +- interval to ``f_arch_acc``."""
        for r, acc in enumerate(acc_runs):
            msg = f'run {r+1} {acc:.2f} (%)'
            f_arch_acc.write(msg + '\n')
        m, h = mean_confidence_interval(acc_runs)
        msg = f'Avg {m:.2f}+-{h.item():.2f} (%)'
        f_arch_acc.write(msg + '\n')

    def train_single_arch(self, data_name, arch_str, meta_test_path):
        """Train ``arch_str`` once per seed in this process and return the test accuracies.

        Args:
            data_name: target dataset name.
            arch_str: architecture string handed to ``train_single_model``.
            meta_test_path: directory under which ``<arch_str>/`` is created.

        Returns:
            List with one final-epoch test accuracy per seed in ``SEEDS``.
        """
        save_path = os.path.join(meta_test_path, arch_str)
        train_single_model(save_dir=save_path,
                           workers=24,
                           datasets=[data_name],
                           xpaths=[f'{self.raw_data_path}/{data_name}'],
                           splits=[0],
                           use_less=False,
                           seeds=self.SEEDS,
                           model_str=arch_str,
                           arch_config={'channel': 16, 'num_cells': 5})
        return self._load_seed_accuracies(save_path, data_name)

    def sort_arch(self, data_name, y_pred_all, gen_arch_str):
        """Return ``gen_arch_str`` ordered by descending predicted score.

        Args:
            data_name: unused; kept for interface compatibility.
            y_pred_all: 1-D tensor with one score per architecture.
            gen_arch_str: architecture strings aligned with ``y_pred_all``.
        """
        _, sorted_idx = torch.sort(y_pred_all, descending=True)
        # .tolist() yields plain ints, so any sequence type can be indexed.
        return [gen_arch_str[position] for position in sorted_idx.tolist()]

    def collect_data_only(self):
        """Return a batch holding a single draw of ``self.test_dataset[0]``."""
        return torch.stack([self.test_dataset[0]]).to(self.device)

    def collect_data(self, arch_igraph, num_repeat=10):
        """Pair ``num_repeat`` draws of ``self.test_dataset[0]`` with ``arch_igraph``.

        ``self.test_dataset[0]`` is evaluated once per repeat rather than
        cached, in case the dataset returns a different sample on each access
        (presumably it does -- TODO confirm against ``MetaTestDataset``).

        Returns:
            ``(x, g)``: the stacked draws on ``self.device`` and a list with
            ``num_repeat`` references to ``arch_igraph``.
        """
        x_batch, g_batch = [], []
        for _ in range(num_repeat):
            x_batch.append(self.test_dataset[0])
            g_batch.append(arch_igraph)
        return torch.stack(x_batch).to(self.device), g_batch

    def get_items(self, full_target, full_source, source):
        """Map each element of ``source`` to the aligned element of ``full_target``.

        ``full_source`` and ``full_target`` are parallel sequences; for every
        item of ``source`` the entry of ``full_target`` at the position of the
        item's first occurrence in ``full_source`` is returned.

        Raises:
            ValueError: if an item of ``source`` is not in ``full_source``.
        """
        try:
            # One pass builds a first-occurrence table, replacing a linear
            # list.index() scan per looked-up item.
            first_position = {}
            for position, key in enumerate(full_source):
                first_position.setdefault(key, position)
            positions = [first_position[item] for item in source]
        except TypeError:
            # Unhashable items: fall back to the linear search.
            positions = [full_source.index(item) for item in source]
        except KeyError as err:
            raise ValueError(f'{err.args[0]!r} is not in list') from None
        return [full_target[position] for position in positions]

    def load_diffusion_model(self, args):
        """Load the sampler configuration and build the sampling function and score model.

        Sets ``self.config``, ``self.sampling_fn``, ``self.sde``,
        ``self.score_model``, ``self.score_ema`` and ``self.score_config``.
        """
        self.config = torch.load('./configs/transfer_nag_config.pt')
        # Fall back to the CPU when CUDA is absent, mirroring self.device.
        self.config.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.config.data.label_list = ['meta-acc']
        self.config.scorenet_ckpt_path = SCORENET_CKPT_PATH
        self.config.sampling.classifier_scale = args.classifier_scale
        self.config.eval.batch_size = args.eval_batch_size
        self.config.sampling.predictor = args.predictor
        self.config.sampling.corrector = args.corrector
        self.config.sampling.check_dataname = self.data_name
        self.sampling_fn, self.sde = get_sampling_fn_meta(self.config)
        self.score_model, self.score_ema, self.score_config = get_score_model(self.config)

    def get_gen_arch_str(self):
        """Generate architecture strings conditioned on the current test dataset.

        Requires ``self.test_dataset`` to be set (done by
        ``meta_test_per_dataset``).

        Returns:
            The value returned by ``generate_archs_meta``.
        """
        ## Load meta-surrogate model
        meta_surrogate_config = torch.load(self.meta_surrogate_ckpt_path)['config']
        meta_surrogate_model = get_surrogate(meta_surrogate_config)
        meta_surrogate_state = dict(model=meta_surrogate_model, step=0, config=meta_surrogate_config)
        meta_surrogate_state = restore_checkpoint(
            self.meta_surrogate_ckpt_path,
            meta_surrogate_state,
            device=self.config.device,
            resume=True)

        ## Get dataset embedding, x
        with torch.no_grad():
            x = self.collect_data_only()

        ## Generate architectures
        generated_arch_str = generate_archs_meta(
            config=self.config,
            sampling_fn=self.sampling_fn,
            score_model=self.score_model,
            score_ema=self.score_ema,
            meta_surrogate_model=meta_surrogate_model,
            num_samples=self.args.n_gen_samples,
            args=self.args,
            task=x)

        ## Clean up
        # Both local references are dropped before collecting.
        del meta_surrogate_model, meta_surrogate_state
        gc.collect()

        return generated_arch_str