diffusionNAG/NAS-Bench-201/analysis/arch_metrics.py
2024-03-15 14:38:51 +00:00

77 lines
3.2 KiB
Python

from analysis.arch_functions import compute_arch_metrics, compute_arch_metrics_meta
import torch.nn as nn
class SamplingArchMetrics(nn.Module):
    """Compute metrics for sampled architectures and log the valid unique ones.

    Wraps ``compute_arch_metrics`` and writes a human-readable record for each
    valid, unique architecture it reports to ``valid_unique_archs.txt``.
    """

    def __init__(self,
                 config,
                 train_ds,
                 exp_name,):
        """
        Args:
            config: Experiment config. Not read by this class (presumably kept
                for a uniform constructor signature across metric classes).
            train_ds: Training dataset; its ``arch_str_list_`` attribute is
                used as the reference list of training architectures.
            exp_name: Experiment name, stored on the instance.
        """
        super().__init__()
        self.exp_name = exp_name
        self.train_ds = train_ds
        # NOTE(review): this reads ``arch_str_list_`` (trailing underscore),
        # while SamplingArchMetricsMeta reads ``arch_str_list``; confirm the
        # dataset class used here really exposes this attribute name.
        self.train_arch_str_list = train_ds.arch_str_list_

    @staticmethod
    def _write_valid_unique_archs(valid_unique_arch, prop_dict, sample_dir):
        """Write one record per valid unique architecture to ``valid_unique_archs.txt``.

        Args:
            valid_unique_arch: Sequence of architecture strings.
            prop_dict: Mapping providing the lists ``arch_idx_list``,
                ``flops_list``, ``params_list`` and ``latency_list``, indexed
                in step with ``valid_unique_arch``.
            sample_dir: Existing directory that receives the file (``open``
                with mode ``"w"`` does not create missing directories).
        """
        # ``with`` guarantees the file is closed even if a lookup/write raises.
        with open(f'{sample_dir}/valid_unique_archs.txt', "w", encoding="utf-8") as textfile:
            for i, arch_str in enumerate(valid_unique_arch):
                textfile.write(
                    f"Arch: {arch_str} \n"
                    f"Arch Index: {prop_dict['arch_idx_list'][i]} \n"
                    f"FLOPs: {prop_dict['flops_list'][i]} \n"
                    f"#Params: {prop_dict['params_list'][i]} \n"
                    f"Latency: {prop_dict['latency_list'][i]} \n\n"
                )
            # Each arch string is written once, inside its record. A trailing
            # ``writelines(valid_unique_arch)`` would concatenate all of them
            # with no separators, so it is intentionally omitted.

    def forward(self,
                arch_list: list,
                this_sample_dir,
                check_dataname='cifar10'):
        """Evaluate ``arch_list`` and log its valid unique architectures.

        Args:
            arch_list: Sampled architectures, passed to ``compute_arch_metrics``.
            this_sample_dir: Directory where ``valid_unique_archs.txt`` is written.
            check_dataname: Forwarded unchanged to ``compute_arch_metrics``.

        Returns:
            The metrics sequence from ``compute_arch_metrics``. Index 1 is the
            list of valid unique arch strings. Index 2 is a dict of per-arch
            property lists (arch index, FLOPs, params, latency).
        """
        # compute_arch_metrics returns a pair; only the metrics part is used here.
        arch_metrics, _ = compute_arch_metrics(arch_list=arch_list,
                                               train_arch_str_list=self.train_arch_str_list,
                                               train_ds=self.train_ds,
                                               check_dataname=check_dataname)
        valid_unique_arch = arch_metrics[1]  # arch_str
        valid_unique_arch_prop_dict = arch_metrics[2]  # flops, params, latency
        self._write_valid_unique_archs(valid_unique_arch,
                                       valid_unique_arch_prop_dict,
                                       this_sample_dir)
        return arch_metrics
class SamplingArchMetricsMeta(nn.Module):
    """Compute meta-setting metrics for sampled architectures and log the valid unique ones.

    Wraps ``compute_arch_metrics_meta`` and writes a human-readable record for
    each valid, unique architecture it reports to ``valid_unique_archs.txt``.
    """

    def __init__(self,
                 config,
                 train_ds,
                 exp_name):
        """
        Args:
            config: Experiment config; ``config.data.name`` is stored as
                ``self.search_space``.
            train_ds: Dataset exposing ``arch_str_list`` and ``idx_lst``. The
                reference list is ``arch_str_list`` restricted to the indices
                in ``idx_lst['train']`` (presumably the training split).
            exp_name: Experiment name, stored on the instance.
        """
        super().__init__()
        self.exp_name = exp_name
        self.train_ds = train_ds
        self.search_space = config.data.name
        self.train_arch_str_list = [train_ds.arch_str_list[i] for i in train_ds.idx_lst['train']]

    @staticmethod
    def _write_valid_unique_archs(valid_unique_arch, prop_dict, sample_dir):
        """Write one record per valid unique architecture to ``valid_unique_archs.txt``.

        Args:
            valid_unique_arch: Sequence of architecture strings.
            prop_dict: Mapping providing the lists ``arch_idx_list``,
                ``flops_list``, ``params_list`` and ``latency_list``, indexed
                in step with ``valid_unique_arch``.
            sample_dir: Existing directory that receives the file (``open``
                with mode ``"w"`` does not create missing directories).
        """
        # ``with`` guarantees the file is closed even if a lookup/write raises.
        with open(f'{sample_dir}/valid_unique_archs.txt', "w", encoding="utf-8") as textfile:
            for i, arch_str in enumerate(valid_unique_arch):
                textfile.write(
                    f"Arch: {arch_str} \n"
                    f"Arch Index: {prop_dict['arch_idx_list'][i]} \n"
                    f"FLOPs: {prop_dict['flops_list'][i]} \n"
                    f"#Params: {prop_dict['params_list'][i]} \n"
                    f"Latency: {prop_dict['latency_list'][i]} \n\n"
                )
            # Each arch string is written once, inside its record. A trailing
            # ``writelines(valid_unique_arch)`` would concatenate all of them
            # with no separators, so it is intentionally omitted.

    def forward(self,
                arch_list: list,
                this_sample_dir,
                check_dataname='cifar10'):
        """Evaluate ``arch_list`` and log its valid unique architectures.

        Args:
            arch_list: Sampled architectures, passed to ``compute_arch_metrics_meta``.
            this_sample_dir: Directory where ``valid_unique_archs.txt`` is written.
            check_dataname: Forwarded unchanged to ``compute_arch_metrics_meta``.

        Returns:
            The metrics sequence from ``compute_arch_metrics_meta``. Index 1 is
            the list of valid unique arch strings. Index 2 is a dict of per-arch
            property lists (arch index, FLOPs, params, latency).
        """
        arch_metrics = compute_arch_metrics_meta(arch_list=arch_list,
                                                 train_arch_str_list=self.train_arch_str_list,
                                                 train_ds=self.train_ds,
                                                 check_dataname=check_dataname)
        valid_unique_arch = arch_metrics[1]  # arch_str
        valid_unique_arch_prop_dict = arch_metrics[2]  # flops, params, latency
        self._write_valid_unique_archs(valid_unique_arch,
                                       valid_unique_arch_prop_dict,
                                       this_sample_dir)
        return arch_metrics