diffusionNAG/NAS-Bench-201/main_exp/transfer_nag/nag.py
2024-03-15 14:38:51 +00:00

305 lines
13 KiB
Python

from __future__ import print_function
import torch
import os
import gc
import sys
import numpy as np
import os
import subprocess
from nag_utils import mean_confidence_interval
from nag_utils import restore_checkpoint
from nag_utils import load_graph_config
from nag_utils import load_model
sys.path.append(os.path.join(os.getcwd(), 'main_exp'))
from nas_bench_201 import train_single_model
from unnoised_model import MetaSurrogateUnnoisedModel
from diffusion.run_lib import generate_archs_meta
from diffusion.run_lib import get_sampling_fn_meta
from diffusion.run_lib import get_score_model
from diffusion.run_lib import get_surrogate
from loader import MetaTestDataset
from logger import Logger
from all_path import *
class NAG:
    """Neural Architecture Generator (transfer NAG).

    Samples candidate architectures from a conditional diffusion model,
    ranks them with an unnoised meta-surrogate predictor, then evaluates
    the ranked candidates on the target dataset (NAS-Bench-201 space).
    """

    # Seeds and final epoch used by every NAS-Bench-201 training run.
    _SEEDS = (777, 888, 999)
    _FINAL_EPOCH = 199

    def __init__(self, args):
        self.args = args
        self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
        ## Target dataset information
        self.raw_data_path = RAW_DATA_PATH
        self.data_path = DATA_PATH
        self.data_name = args.data_name
        self.num_class = args.num_class
        self.num_sample = args.num_sample
        graph_config = load_graph_config(args.graph_data_name, args.nvt, NASBENCH201)
        ## Unnoised meta-surrogate used only to rank generated architectures
        self.meta_surrogate_unnoised_model = MetaSurrogateUnnoisedModel(args, graph_config)
        load_model(model=self.meta_surrogate_unnoised_model,
                   ckpt_path=META_SURROGATE_UNNOISED_CKPT_PATH)
        self.meta_surrogate_unnoised_model.to(self.device)
        ## Noised meta-surrogate checkpoint (restored later in get_gen_arch_str)
        self.meta_surrogate_ckpt_path = META_SURROGATE_CKPT_PATH
        ## Load score network model (base diffusion model) + sampling function
        self.load_diffusion_model(args=args)
        ## Verify score net and surrogate were trained with matching configs
        self.check_config()
        ## Set logger
        self.logger = Logger(
            log_dir=args.exp_name,
            write_textfile=True
        )
        self.logger.update_config(args, is_args=True)
        self.logger.write_str(str(vars(args)))
        self.logger.write_str('-' * 100)

    def check_config(self):
        """Assert that the pre-trained score network checkpoint and the
        meta-surrogate checkpoint share compatible SDE / data settings."""
        scorenet_config = torch.load(self.config.scorenet_ckpt_path)['config']
        meta_surrogate_config = torch.load(self.meta_surrogate_ckpt_path)['config']
        assert scorenet_config.model.sigma_min == meta_surrogate_config.model.sigma_min
        assert scorenet_config.model.sigma_max == meta_surrogate_config.model.sigma_max
        assert scorenet_config.training.sde == meta_surrogate_config.training.sde
        assert scorenet_config.training.continuous == meta_surrogate_config.training.continuous
        assert scorenet_config.data.centered == meta_surrogate_config.data.centered
        assert scorenet_config.data.max_node == meta_surrogate_config.data.max_node
        assert scorenet_config.data.n_vocab == meta_surrogate_config.data.n_vocab

    def forward(self, x, arch):
        """Predict the accuracy of `arch` conditioned on the dataset batch `x`
        using the unnoised meta-surrogate model."""
        D_mu = self.meta_surrogate_unnoised_model.set_encode(x.to(self.device))
        G_mu = self.meta_surrogate_unnoised_model.graph_encode(arch)
        return self.meta_surrogate_unnoised_model.predict(D_mu, G_mu)

    def meta_test(self):
        """Run the meta-test either on one dataset or on all four supported ones."""
        if self.data_name == 'all':
            for data_name in ['cifar10', 'cifar100', 'aircraft', 'pets']:
                self.meta_test_per_dataset(data_name)
        else:
            self.meta_test_per_dataset(self.data_name)

    def meta_test_per_dataset(self, data_name):
        """Generate, rank, evaluate, and record architectures for `data_name`.

        Writes the sorted architecture strings to architecture.txt, their
        accuracies to accuracy.txt, and a summary dict to results.pt under
        args.exp_name.
        """
        ## Load NASBench201 database and the meta-test dataset
        self.nasbench201 = torch.load(NASBENCH201)
        all_arch_str = np.array(self.nasbench201['arch']['str'])
        self.test_dataset = MetaTestDataset(self.data_path, data_name, self.num_sample, self.num_class)
        ## Directory where per-architecture training results are stored
        meta_test_path = os.path.join(META_TEST_PATH, data_name)
        os.makedirs(meta_test_path, exist_ok=True)
        ## Generate architectures and sort them by predicted accuracy
        sorted_arch_lst = self._generate_and_rank(data_name)
        ## `with` guarantees both report files are closed (they leaked before)
        with open(os.path.join(self.args.exp_name, 'architecture.txt'), 'w') as f_arch_str, \
             open(os.path.join(self.args.exp_name, 'accuracy.txt'), 'w') as f_arch_acc:
            for arch_str in sorted_arch_lst:
                f_arch_str.write(f'{arch_str}\n')
            arch_idx_lst = [self.nasbench201['arch']['str'].index(i) for i in sorted_arch_lst]
            if 'cifar' in data_name:
                ## Accuracies are tabulated in NAS-Bench-201; no training needed
                sorted_acc_lst = self.get_items(
                    full_target=self.nasbench201['test-acc'][data_name],
                    full_source=self.nasbench201['arch']['str'],
                    source=sorted_arch_lst)
                arch_str_lst = list(sorted_arch_lst)
                arch_acc_lst = list(sorted_acc_lst)
                for acc in sorted_acc_lst:
                    ## '.4f' fixes the original ':4f' (width-4) format typo
                    f_arch_acc.write(f'Avg {acc:.4f} (%)\n')
            elif self.args.multi_proc:
                arch_str_lst, arch_acc_lst = self._eval_multi_proc(
                    data_name, meta_test_path, arch_idx_lst, all_arch_str, f_arch_acc)
            else:
                arch_str_lst, arch_acc_lst = self._eval_sequential(
                    data_name, meta_test_path, arch_idx_lst, all_arch_str, f_arch_acc)
        # Save results
        results_path = os.path.join(self.args.exp_name, 'results.pt')
        torch.save({
            'arch_idx_lst': arch_idx_lst,
            'arch_str_lst': arch_str_lst,
            'arch_acc_lst': arch_acc_lst
        }, results_path)
        print(f">>> Save the results at {results_path}...")

    def _generate_and_rank(self, data_name):
        """Sample architectures from the diffusion model and return them
        sorted (best first) by the unnoised meta-surrogate's prediction."""
        gen_arch_str = self.get_gen_arch_str()
        gen_arch_igraph = self.get_items(
            full_target=self.nasbench201['arch']['igraph'],
            full_source=self.nasbench201['arch']['str'],
            source=gen_arch_str)
        y_pred_all = []
        self.meta_surrogate_unnoised_model.eval()
        self.meta_surrogate_unnoised_model.to(self.device)
        with torch.no_grad():
            for arch_igraph in gen_arch_igraph:
                x, g = self.collect_data(arch_igraph)
                ## Average the prediction over the 10 sampled dataset batches
                y_pred = torch.mean(self.forward(x, g))
                y_pred_all.append(y_pred.cpu().detach().item())
        return self.sort_arch(data_name, torch.tensor(y_pred_all), gen_arch_str)

    def _eval_multi_proc(self, data_name, meta_test_path, arch_idx_lst, all_arch_str, f_arch_acc):
        """Train the selected architectures with parallel worker processes,
        then poll for their result files and log per-seed accuracies.

        Returns (arch_str_lst, arch_acc_lst).
        """
        import time  # local: only this polling path needs it
        run_file = os.path.join(os.getcwd(), 'main_exp', 'transfer_nag', 'run_multi_proc.py')
        MAX_CAP = 5  # hard-coded for available GPUs

        def launch(idx_lst):
            ## Launch one worker batch; list-argv avoids shell interpolation
            if not idx_lst:
                return
            support = ','.join(str(i) for i in idx_lst)
            num_split = 3 * len(idx_lst)  # 3 = number of seeds per architecture
            subprocess.run([
                sys.executable, run_file,
                '--num_split', str(num_split),
                '--arch_idx_lst', support,
                '--meta_test_path', meta_test_path,
                '--data_name', data_name,
                '--raw_data_path', self.raw_data_path])

        ## Only train architectures whose result directory does not exist yet
        pending = [i for i in arch_idx_lst
                   if not os.path.exists(os.path.join(meta_test_path, str(i)))]
        if len(arch_idx_lst) <= MAX_CAP:
            launch(pending)
        else:
            ## Run in batches of MAX_CAP, the last batch may be smaller
            for start in range(0, len(pending), MAX_CAP):
                launch(pending[start:start + MAX_CAP])
        ## Poll until every architecture has a result file for all seeds.
        ## The original bare `except: pass` busy-looped and swallowed
        ## KeyboardInterrupt; catch only "not ready yet" errors and back off.
        while True:
            try:
                acc_runs_lst = []
                for arch_idx in arch_idx_lst:
                    save_path_ = os.path.join(meta_test_path, str(arch_idx))
                    acc_runs = []
                    for seed in self._SEEDS:
                        result = torch.load(os.path.join(save_path_, f'seed-0{seed}.pth'))
                        acc_runs.append(
                            result[data_name]['valid_acc1es'][f'x-test@{self._FINAL_EPOCH}'])
                    acc_runs_lst.append(acc_runs)
                break
            except (FileNotFoundError, EOFError, KeyError, RuntimeError):
                time.sleep(10)
        for acc_runs in acc_runs_lst:
            print(np.mean(acc_runs))
        arch_str_lst, arch_acc_lst = [], []
        for arch_idx, acc_runs in zip(arch_idx_lst, acc_runs_lst):
            self._log_acc_runs(f_arch_acc, acc_runs)
            arch_acc_lst.append(np.mean(acc_runs))
            arch_str_lst.append(all_arch_str[arch_idx])
        return arch_str_lst, arch_acc_lst

    def _eval_sequential(self, data_name, meta_test_path, arch_idx_lst, all_arch_str, f_arch_acc):
        """Train/evaluate each selected architecture one at a time in-process.

        Returns (arch_str_lst, arch_acc_lst).
        """
        arch_str_lst, arch_acc_lst = [], []
        for arch_idx in arch_idx_lst:
            ## BUGFIX: arch strings live under ['arch']['str'] (was ['str'])
            acc_runs = self.train_single_arch(
                data_name, self.nasbench201['arch']['str'][arch_idx], meta_test_path)
            self._log_acc_runs(f_arch_acc, acc_runs)
            arch_acc_lst.append(np.mean(acc_runs))
            arch_str_lst.append(all_arch_str[arch_idx])
        return arch_str_lst, arch_acc_lst

    def _log_acc_runs(self, f_arch_acc, acc_runs):
        """Write per-run accuracies plus mean +- confidence interval."""
        for r, acc in enumerate(acc_runs):
            f_arch_acc.write(f'run {r+1} {acc:.2f} (%)\n')
        m, h = mean_confidence_interval(acc_runs)
        f_arch_acc.write(f'Avg {m:.2f}+-{h.item():.2f} (%)\n')

    def train_single_arch(self, data_name, arch_str, meta_test_path):
        """Train one architecture for all seeds and return the final-epoch
        test accuracies (one per seed)."""
        save_path = os.path.join(meta_test_path, arch_str)
        train_single_model(save_dir=save_path,
                           workers=24,
                           datasets=[data_name],
                           xpaths=[f'{self.raw_data_path}/{data_name}'],
                           splits=[0],
                           use_less=False,
                           seeds=self._SEEDS,
                           model_str=arch_str,
                           arch_config={'channel': 16, 'num_cells': 5})
        test_acc_lst = []
        for seed in self._SEEDS:
            result = torch.load(os.path.join(save_path, f'seed-0{seed}.pth'))
            test_acc_lst.append(result[data_name]['valid_acc1es'][f'x-test@{self._FINAL_EPOCH}'])
        return test_acc_lst

    def sort_arch(self, data_name, y_pred_all, gen_arch_str):
        """Return `gen_arch_str` sorted by predicted accuracy, best first.

        `data_name` is unused but kept for interface compatibility.
        """
        _, sorted_idx = torch.sort(y_pred_all, descending=True)
        return [gen_arch_str[i] for i in sorted_idx]

    def collect_data_only(self):
        """Return a single dataset sample as a batch of one, on device."""
        return torch.stack([self.test_dataset[0]]).to(self.device)

    def collect_data(self, arch_igraph):
        """Return 10 dataset samples stacked on device, paired with the
        architecture igraph repeated 10 times."""
        x_batch = [self.test_dataset[0] for _ in range(10)]
        g_batch = [arch_igraph] * 10
        return torch.stack(x_batch).to(self.device), g_batch

    def get_items(self, full_target, full_source, source):
        """Map each item of `source` to its counterpart in `full_target`
        via its position in `full_source`."""
        return [full_target[full_source.index(s)] for s in source]

    def load_diffusion_model(self, args):
        """Load the sampling config, build the SDE sampling function, and
        restore the pre-trained score network (base diffusion model)."""
        self.config = torch.load('./configs/transfer_nag_config.pt')
        self.config.device = torch.device('cuda')
        self.config.data.label_list = ['meta-acc']
        self.config.scorenet_ckpt_path = SCORENET_CKPT_PATH
        self.config.sampling.classifier_scale = args.classifier_scale
        self.config.eval.batch_size = args.eval_batch_size
        self.config.sampling.predictor = args.predictor
        self.config.sampling.corrector = args.corrector
        self.config.sampling.check_dataname = self.data_name
        self.sampling_fn, self.sde = get_sampling_fn_meta(self.config)
        self.score_model, self.score_ema, self.score_config = get_score_model(self.config)

    def get_gen_arch_str(self):
        """Sample architecture strings from the diffusion model, guided by
        the (noised) meta-surrogate restored from its checkpoint."""
        meta_surrogate_config = torch.load(self.meta_surrogate_ckpt_path)['config']
        meta_surrogate_model = get_surrogate(meta_surrogate_config)
        meta_surrogate_state = dict(model=meta_surrogate_model, step=0, config=meta_surrogate_config)
        meta_surrogate_state = restore_checkpoint(
            self.meta_surrogate_ckpt_path,
            meta_surrogate_state,
            device=self.config.device,
            resume=True)
        ## Dataset embedding that conditions the sampler
        with torch.no_grad():
            task = self.collect_data_only()
        ## Generate architectures
        generated_arch_str = generate_archs_meta(
            config=self.config,
            sampling_fn=self.sampling_fn,
            score_model=self.score_model,
            score_ema=self.score_ema,
            meta_surrogate_model=meta_surrogate_model,
            num_samples=self.args.n_gen_samples,
            args=self.args,
            task=task)
        ## Drop the surrogate so its memory can be reclaimed
        meta_surrogate_model = None
        gc.collect()
        return generated_arch_str