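"""Neural Architecture Generation (NAG) pipeline.

Meta-trains a dataset-conditioned performance predictor (FSBO encoder) and
meta-tests it by sampling candidate OFA-MobileNetV3 architectures from a
diffusion model, ranking them with a meta-predictor, and training the top
candidates on the target dataset. (Summary inferred from the code below.)
"""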
from __future__ import print_function

import gc
import os
import sys
import time

import numpy as np
import torch
from scipy.stats import pearsonr
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

from transfer_nag_lib.MetaD2A_mobilenetV3.metad2a_utils import (
    load_graph_config, decode_ofa_mbv3_str_to_igraph,
    get_log, save_model, mean_confidence_interval)
from transfer_nag_lib.MetaD2A_mobilenetV3.loader import get_meta_train_loader, MetaTestDataset
from transfer_nag_lib.encoder_FSBO_ofa import EncoderFSBO as PredictorModel
from transfer_nag_lib.MetaD2A_mobilenetV3.predictor import Predictor as MetaD2APredictor
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.train import train_single_model

from diffusion.run_lib import (
    generate_archs, get_sampling_fn_meta, get_score_model, get_predictor)

sys.path.append(os.getcwd())
from all_path import *
from utils import restore_checkpoint


class NAG:
    def __init__(self, args, dgp_arch=[99, 50, 179, 194], bohb=False):
        self.args = args
        self.batch_size = args.batch_size
        self.num_sample = args.num_sample
        self.max_epoch = args.max_epoch
        self.save_epoch = args.save_epoch
        self.save_path = args.save_path
        self.search_space = args.search_space
        self.model_name = 'predictor'
        self.test = args.test
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.max_corr_dict = {'corr': -1, 'epoch': -1}
        self.train_arch = args.train_arch
        self.use_metad2a_predictor_selec = args.use_metad2a_predictor_selec

        # Path constants are provided by `from all_path import *`.
        self.raw_data_path = RAW_DATA_PATH
        self.model_path = UNNOISE_META_PREDICTOR_CKPT_PATH
        self.data_path = PROCESSED_DATA_PATH
        self.classifier_ckpt_path = NOISE_META_PREDICTOR_CKPT_PATH
        self.load_diffusion_model(self.args.n_training_samples, args.pos_enc_type)

        graph_config = load_graph_config(
            args.graph_data_name, args.nvt, self.data_path)

        self.model = PredictorModel(args, graph_config, dgp_arch=dgp_arch)
        self.metad2a_model = MetaD2APredictor(args).model

        if self.test:
            self.data_name = args.data_name
            self.num_class = args.num_class
            self.load_epoch = args.load_epoch
            self.n_training_samples = self.args.n_training_samples
            self.n_gen_samples = args.n_gen_samples
            self.folder_name = args.folder_name
            self.unique = args.unique

            # Load the best-correlation checkpoint, keeping only keys that
            # exist in the current model (partial state_dict loading).
            model_state_dict = self.model.state_dict()
            load_max_pt = 'ckpt_max_corr.pt'
            ckpt_path = os.path.join(self.model_path, load_max_pt)
            ckpt = torch.load(ckpt_path, map_location='cpu')
            for k, v in ckpt.items():
                if k in model_state_dict:
                    model_state_dict[k] = v
            self.model.cpu()
            self.model.load_state_dict(model_state_dict)
            self.model.to(self.device)

        self.optimizer = optim.Adam(self.model.parameters(), lr=args.lr)
        self.scheduler = ReduceLROnPlateau(self.optimizer, 'min',
                                           factor=0.1, patience=1000, verbose=True)
        self.mtrloader = get_meta_train_loader(
            self.batch_size, self.data_path, self.num_sample, is_pred=True)

        # Accuracy normalization statistics of the meta-train dataset.
        self.acc_mean = self.mtrloader.dataset.mean
        self.acc_std = self.mtrloader.dataset.std

    def forward(self, x, arch, labels=None, train=False, matrix=False, metad2a=False):
        """Predict the accuracy of `arch` conditioned on the dataset encoding of `x`.

        With `metad2a=True` the baseline MetaD2A predictor is used and only a
        point estimate is returned; otherwise the FSBO encoder returns both a
        prediction and a predictive distribution.
        """
        if metad2a:
            D_mu = self.metad2a_model.set_encode(x.to(self.device))
            G_mu = self.metad2a_model.graph_encode(arch)
            y_pred = self.metad2a_model.predict(D_mu, G_mu)
            return y_pred
        else:
            D_mu = self.model.set_encode(x.to(self.device))
            G_mu = self.model.graph_encode(arch, matrix=matrix)
            y_pred, y_dist = self.model.predict(D_mu, G_mu, labels=labels, train=train)
            return y_pred, y_dist

    def meta_train(self):
        # NOTE: `self.mtrlog` (the training logger) is not created in
        # __init__; it is expected to be attached before meta_train() is called.
        for epoch in range(1, self.max_epoch + 1):
            self.mtrlog.ep_sttime = time.time()
            loss, corr = self.meta_train_epoch(epoch)
            self.scheduler.step(loss)
            self.mtrlog.print_pred_log(loss, corr, 'train', epoch)

            valoss, vacorr = self.meta_validation(epoch)
            if self.max_corr_dict['corr'] < vacorr or epoch == 1:
                self.max_corr_dict['corr'] = vacorr
                self.max_corr_dict['epoch'] = epoch
                self.max_corr_dict['loss'] = valoss
                save_model(epoch, self.model, self.model_path, max_corr=True)

            self.mtrlog.print_pred_log(
                valoss, vacorr, 'valid', max_corr_dict=self.max_corr_dict)

            if epoch % self.save_epoch == 0:
                save_model(epoch, self.model, self.model_path)

        self.mtrlog.save_time_log()
        self.mtrlog.max_corr_log(self.max_corr_dict)

    def meta_train_epoch(self, epoch):
        self.model.to(self.device)
        self.model.train()

        self.mtrloader.dataset.set_mode('train')

        dlen = len(self.mtrloader.dataset)
        trloss = 0
        y_all, y_pred_all = [], []
        pbar = tqdm(self.mtrloader)

        for x, g, acc in pbar:
            self.optimizer.zero_grad()
            y_pred, y_dist = self.forward(x, g, labels=acc, train=True, matrix=False)
            y = acc.to(self.device).double()
            # Minimize the negative marginal log-likelihood of the labels
            # under the predictive distribution.
            loss = -self.model.mll(y_dist, y)
            loss.backward()
            self.optimizer.step()

            y = y.tolist()
            y_pred = y_pred.squeeze().tolist()
            y_all += y
            y_pred_all += y_pred
            pbar.set_description(get_log(
                epoch, loss, y_pred, y, self.acc_std, self.acc_mean))
            trloss += float(loss)

        return trloss / dlen, pearsonr(np.array(y_all),
                                       np.array(y_pred_all))[0]

    def meta_validation(self, epoch):
        self.model.to(self.device)
        self.model.eval()

        valoss = 0
        self.mtrloader.dataset.set_mode('valid')
        dlen = len(self.mtrloader.dataset)
        y_all, y_pred_all = [], []
        pbar = tqdm(self.mtrloader)

        with torch.no_grad():
            for x, g, acc in pbar:
                y_pred, y_dist = self.forward(x, g, labels=acc, train=False, matrix=False)
                y = acc.to(self.device)
                loss = -self.model.mll(y_dist, y)

                y = y.tolist()
                y_pred = y_pred.squeeze().tolist()
                y_all += y
                y_pred_all += y_pred
                pbar.set_description(get_log(
                    epoch, loss, y_pred, y, self.acc_std, self.acc_mean, tag='val'))
                valoss += float(loss)

        try:
            pearson_corr = pearsonr(np.array(y_all), np.array(y_pred_all))[0]
        except Exception:
            # pearsonr fails on degenerate inputs (e.g. constant predictions);
            # report zero correlation in that case.
            pearson_corr = 0

        return valoss / dlen, pearson_corr

    def meta_test(self):
        # NOTE: meta_test_per_dataset currently returns None (results are
        # saved to disk), so `acc` is a placeholder; with 'all', only the
        # value from the last dataset is returned.
        if self.data_name == 'all':
            for data_name in ['cifar10', 'cifar100', 'aircraft', 'pets']:
                acc = self.meta_test_per_dataset(data_name)
        else:
            acc = self.meta_test_per_dataset(self.data_name)
        return acc

    def meta_test_per_dataset(self, data_name):
        self.test_dataset = MetaTestDataset(
            self.data_path, data_name, self.num_sample, self.num_class)

        meta_test_path = self.args.exp_name
        os.makedirs(meta_test_path, exist_ok=True)
        f_arch_str = open(os.path.join(meta_test_path, 'architecture.txt'), 'w')
        f = open(os.path.join(meta_test_path, 'accuracy.txt'), 'w')

        elapsed_time = []

        print(f'==> select top architectures for {data_name} by meta-predictor...')

        gen_arch_str = self.get_gen_arch_str()
        gen_arch_igraph = [decode_ofa_mbv3_str_to_igraph(_) for _ in gen_arch_str]

        y_pred_all = []
        self.metad2a_model.eval()
        self.metad2a_model.to(self.device)

        # Score every generated architecture with the MetaD2A predictor.
        sttime = time.time()
        with torch.no_grad():
            for i, arch_igraph in enumerate(gen_arch_igraph):
                x, g = self.collect_data(arch_igraph)
                y_pred = self.forward(x, g, metad2a=True)
                y_pred = torch.mean(y_pred)
                y_pred_all.append(y_pred.cpu().detach().item())

        if self.use_metad2a_predictor_selec:
            top_arch_lst = self.select_top_arch(
                data_name, torch.tensor(y_pred_all), gen_arch_str, self.n_training_samples)
        else:
            top_arch_lst = gen_arch_str[:self.n_training_samples]

        elapsed = time.time() - sttime
        elapsed_time.append(elapsed)

        for arch_str in top_arch_lst:
            f_arch_str.write(f'{arch_str}\n')
            print(f'neural architecture config: {arch_str}')
        f_arch_str.close()

        support = top_arch_lst
        x_support = []
        y_support = []
        seeds = [777, 888, 999]
        y_support_per_seed = {_: [] for _ in seeds}
        net_info = {
            'params': [],
            'flops': [],
        }
        best_acc = 0.0
        best_sample_num = 0

        print("Data name: %s" % data_name)
        for i, arch_str in enumerate(support):
            save_path = os.path.join(meta_test_path, arch_str)
            os.makedirs(save_path, exist_ok=True)
            acc_runs = []
            for seed in seeds:
                print(f'==> train for {data_name} {arch_str} ({seed})')
                valid_acc, max_valid_acc, params, flops = train_single_model(
                    save_path=save_path,
                    workers=8,
                    datasets=data_name,
                    xpaths=f'{self.raw_data_path}/{data_name}',
                    splits=[0],
                    use_less=False,
                    seed=seed,
                    model_str=arch_str,
                    device='cuda',
                    lr=0.01,
                    momentum=0.9,
                    weight_decay=4e-5,
                    report_freq=50,
                    epochs=20,
                    grad_clip=5,
                    cutout=True,
                    cutout_length=16,
                    autoaugment=True,
                    drop=0.2,
                    drop_path=0.2,
                    img_size=224)
                acc_runs.append(valid_acc)
                y_support_per_seed[seed].append(valid_acc)

            for r, acc in enumerate(acc_runs):
                msg = f'run {r + 1} {acc:.2f} (%)'
                f.write(msg + '\n')
                f.flush()
                print(msg)
            m, h = mean_confidence_interval(acc_runs)

            if m > best_acc:
                best_acc = m
                best_sample_num = i
            msg = f'Avg {m:.3f}+-{h.item():.2f} (%) (best acc {best_acc:.3f} - #{i})'
            f.write(msg + '\n')
            print(msg)
            y_support.append(np.mean(acc_runs))
            x_support.append(arch_str)
            net_info['params'].append(params)
            net_info['flops'].append(flops)
            # Results are (re)saved after every architecture so partial runs
            # remain recoverable.
            torch.save({'y_support': y_support, 'x_support': x_support,
                        'y_support_per_seed': y_support_per_seed,
                        'net_info': net_info,
                        'best_acc': best_acc,
                        'best_sample_num': best_sample_num},
                       meta_test_path + '/result.pt')

        f.close()
        return None

    def train_single_arch(self, data_name, arch_str, meta_test_path):
        save_path = os.path.join(meta_test_path, arch_str)
        seeds = (777, 888, 999)
        train_single_model(save_path=save_path,
                           workers=24,
                           datasets=[data_name],
                           xpaths=[f'{self.raw_data_path}/{data_name}'],
                           splits=[0],
                           use_less=False,
                           seeds=seeds,
                           model_str=arch_str,
                           arch_config={'channel': 16, 'num_cells': 5})
        # Training length differs per dataset: the last epoch is 49 for MNIST
        # and 199 otherwise, and the test accuracy is read at that epoch.
        epoch = 49 if data_name == 'mnist' else 199
        test_acc_lst = []
        for seed in seeds:
            result = torch.load(os.path.join(save_path, f'seed-0{seed}.pth'))
            test_acc_lst.append(result[data_name]['valid_acc1es'][f'x-test@{epoch}'])
        return test_acc_lst

    def select_top_arch(self, data_name, y_pred_all, gen_arch_str, N):
        # Rank generated architectures by predicted accuracy and keep the top N.
        _, sorted_idx = torch.sort(y_pred_all, descending=True)
        sorted_gen_arch_str = [gen_arch_str[_] for _ in sorted_idx]
        final_str = sorted_gen_arch_str[:N]
        return final_str

    def collect_data_only(self):
        x_batch = [self.test_dataset[0]]
        return torch.stack(x_batch).to(self.device)

    def collect_data(self, arch_igraph):
        # Pair 10 dataset samples with the same architecture so the caller
        # can average the prediction over dataset encodings.
        x_batch, g_batch = [], []
        for _ in range(10):
            x_batch.append(self.test_dataset[0])
            g_batch.append(arch_igraph)
        return torch.stack(x_batch).to(self.device), g_batch

    def load_diffusion_model(self, n_training_samples, pos_enc_type):
        # Rewrite the stored config so it points at the local score-model
        # data and checkpoint paths (constants from all_path).
        self.config = torch.load(CONFIG_PATH)
        self.config.data.root = SCORE_MODEL_DATA_PATH
        self.config.scorenet_ckpt_path = SCORE_MODEL_CKPT_PATH
        torch.save(self.config, CONFIG_PATH)

        self.sampling_fn, self.sde = get_sampling_fn_meta(self.config)
        self.sampling_fn_training_samples, _ = get_sampling_fn_meta(
            self.config, init=True, n_init=n_training_samples)
        self.score_model, self.score_ema, self.score_config = \
            get_score_model(self.config, pos_enc_type=pos_enc_type)

    def get_gen_arch_str(self):
        # Load the noise-aware meta-predictor used for classifier guidance.
        classifier_config = torch.load(self.classifier_ckpt_path)['config']
        classifier_model = get_predictor(classifier_config)
        classifier_state = dict(model=classifier_model, step=0, config=classifier_config)
        classifier_state = restore_checkpoint(self.classifier_ckpt_path,
                                              classifier_state, device=self.config.device, resume=True)
        print(f'==> load checkpoint for our predictor: {self.classifier_ckpt_path}...')

        with torch.no_grad():
            x = self.collect_data_only()

        # Sample architecture strings from the diffusion model, guided by the
        # predictor and optionally conditioned on a fixed task encoding.
        generated_arch_str = generate_archs(
            self.config,
            self.sampling_fn,
            self.score_model,
            self.score_ema,
            classifier_model,
            num_samples=self.n_gen_samples,
            patient_factor=self.args.patient_factor,
            batch_size=self.args.eval_batch_size,
            classifier_scale=self.args.classifier_scale,
            task=x if self.args.fix_task else None)

        gc.collect()
        return generated_arch_str
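

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). NAG expects an `args` namespace
# carrying the attributes referenced above; in the real pipeline these come
# from the project's argument parser, which is not part of this file. All
# values below are placeholders, not recommended settings, and the downstream
# encoder/predictor classes may read further attributes not listed here.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(
        batch_size=32, num_sample=20, max_epoch=400, save_epoch=50,
        save_path='./results', search_space='ofa_mbv3', test=False,
        train_arch=True, use_metad2a_predictor_selec=True,
        n_training_samples=10, pos_enc_type=0,
        graph_data_name='ofa_mbv3', nvt=27, lr=1e-3,
        # Attributes read only when test=True (meta-test):
        data_name='cifar10', num_class=10, load_epoch=0, n_gen_samples=100,
        folder_name='nag', unique=True, exp_name='./exp/nag',
        patient_factor=20, eval_batch_size=64, classifier_scale=1.0,
        fix_task=True,
    )
    nag = NAG(args)
    nag.meta_train()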