diff --git a/configs/config.yaml b/configs/config.yaml
index ce858e4..19ec8df 100644
--- a/configs/config.yaml
+++ b/configs/config.yaml
@@ -16,9 +16,12 @@ general:
   final_model_chains_to_save: 1
   enable_progress_bar: False
   save_model: True
-  log_dir: '/nfs/data3/hanzhang/nasbenchDiT'
+  log_dir: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT'
   number_checkpoint_limit: 3
   type: 'Trainer'
+  nas_201: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+  swap_result: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/swap_results.csv'
+  root: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/'
 model:
   type: 'discrete'
   transition: 'marginal'
diff --git a/graph_dit/datasets/dataset.py b/graph_dit/datasets/dataset.py
index 5e50bc2..2bd19a2 100644
--- a/graph_dit/datasets/dataset.py
+++ b/graph_dit/datasets/dataset.py
@@ -25,7 +25,6 @@ from sklearn.model_selection import train_test_split
 import utils as utils
 from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
 from diffusion.distributions import DistributionNodes
-from naswot.score_networks import get_nasbench201_idx_score
 from naswot import nasspace
 from naswot import datasets as dt
@@ -72,7 +71,9 @@ class DataModule(AbstractDataModule):
         #     base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
         # except NameError:
         #     base_path = pathlib.Path(os.getcwd()).parent[2]
-        base_path = '/nfs/data3/hanzhang/nasbenchDiT'
+        # base_path = '/nfs/data3/hanzhang/nasbenchDiT'
+        base_path = os.path.join(self.cfg.general.root, "..")
+
         root_path = os.path.join(base_path, self.datadir)
         self.root_path = root_path
@@ -84,7 +85,7 @@ class DataModule(AbstractDataModule):
         # Load the dataset to the memory
         # Dataset has target property, root path, and transform
         source = './NAS-Bench-201-v1_1-096897.pth'
-        dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None)
+        dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None, cfg=self.cfg)
         self.dataset = dataset
         # self.api = dataset.api
@@ -384,7 +385,7 @@ class DataModule_original(AbstractDataModule):
     def test_dataloader(self):
         return self.test_loader
-def new_graphs_to_json(graphs, filename):
+def new_graphs_to_json(graphs, filename, cfg):
     source_name = "nasbench-201"
     num_graph = len(graphs)
@@ -491,8 +492,9 @@ def new_graphs_to_json(graphs, filename):
         'num_active_nodes': len(active_nodes),
         'transition_E': transition_E.tolist(),
     }
-
-    with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f:
+    import os
+    # with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f:
+    with open(os.path.join(cfg.general.root,'nasbench-201-meta.json'), 'w') as f:
         json.dump(meta_dict, f)
     return meta_dict
@@ -656,9 +658,11 @@ def graphs_to_json(graphs, filename):
         json.dump(meta_dict, f)
     return meta_dict
 class Dataset(InMemoryDataset):
-    def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
+    def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None, cfg=None):
         self.target_prop = target_prop
-        source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        self.cfg = cfg
+        # source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        source = os.path.join(self.cfg.general.root, 'NAS-Bench-201-v1_1-096897.pth')
         self.source = source
         # self.api = API(source)  # Initialize NAS-Bench-201 API
         # print('API loaded')
@@ -679,7 +683,8 @@ class Dataset(InMemoryDataset):
         return [f'{self.source}.pt']
     def process(self):
-        source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        # source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        source = self.cfg.general.nas_201
         # self.api = API(source)
         data_list = []
@@ -748,7 +753,8 @@ class Dataset(InMemoryDataset):
             return edges,nodes
-        def graph_to_graph_data(graph, idx, train_loader, searchspace, args, device):
+        # def graph_to_graph_data(graph, idx, train_loader, searchspace, args, device):
+        def graph_to_graph_data(graph, idx, args, device):
        # def graph_to_graph_data(graph):
             ops = graph[1]
             adj = graph[0]
@@ -797,7 +803,7 @@ class Dataset(InMemoryDataset):
         args.batch_size = 128
         args.GPU = '0'
         args.dataset = 'cifar10'
-        args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        args.api_loc = self.cfg.general.nas_201
         args.data_loc = '../cifardata/'
         args.seed = 777
         args.init = ''
@@ -812,10 +818,11 @@ class Dataset(InMemoryDataset):
         args.num_modules_per_stack = 3
         args.num_labels = 1
         searchspace = nasspace.get_search_space(args)
-        train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
+        # train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
         self.swap_scores = []
         import csv
-        with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f:
+        # with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f:
+        with open(self.cfg.general.swap_result, 'r') as f:
         # with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results_cifar100.csv', 'r') as f:
             reader = csv.reader(f)
             header = next(reader)
@@ -824,12 +831,15 @@ class Dataset(InMemoryDataset):
         device = torch.device('cuda:2')
         with tqdm(total = len_data) as pbar:
             active_nodes = set()
-            file_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
+            import os
+            # file_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
+            file_path = os.path.join(self.cfg.general.root, 'nasbench-201-graph.json')
             with open(file_path, 'r') as f:
                 graph_list = json.load(f)
             i = 0
             flex_graph_list = []
-            flex_graph_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json'
+            # flex_graph_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json'
+            flex_graph_path = os.path.join(self.cfg.general.root,'flex-nasbench201-graph.json')
             for graph in graph_list:
                 print(f'iterate every graph in graph_list, here is {i}')
                 arch_info = graph['arch_str']
@@ -837,7 +847,8 @@ class Dataset(InMemoryDataset):
                 for op in ops:
                     if op not in active_nodes:
                         active_nodes.add(op)
-                data = graph_to_graph_data((adj_matrix, ops),idx=i, train_loader=train_loader, searchspace=searchspace, args=args, device=device)
+                # data = graph_to_graph_data((adj_matrix, ops),idx=i, train_loader=train_loader, searchspace=searchspace, args=args, device=device)
+                data = graph_to_graph_data((adj_matrix, ops),idx=i, args=args, device=device)
                 i += 1
                 if data is None:
                     pbar.update(1)
@@ -1140,6 +1151,7 @@ class DataInfos(AbstractDatasetInfos):
         self.task = task_name
         self.task_type = tasktype_dict.get(task_name, "regression")
         self.ensure_connected = cfg.model.ensure_connected
+        self.cfg = cfg
         # self.api = dataset.api
         datadir = cfg.dataset.datadir
@@ -1182,14 +1194,15 @@ class DataInfos(AbstractDatasetInfos):
         # len_ops.add(len(ops))
         # graphs.append((adj_matrix, ops))
         # graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json')
-        graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json')
+        # graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json')
+        graphs = read_adj_ops_from_json(os.path.join(self.cfg.general.root, 'nasbench-201-graph.json'))
         # check first five graphs
         for i in range(5):
             print(f'graph {i} : {graphs[i]}')
         # print(f'ops_type: {ops_type}')
-        meta_dict = new_graphs_to_json(graphs, 'nasbench-201')
+        meta_dict = new_graphs_to_json(graphs, 'nasbench-201', self.cfg)
         self.base_path = base_path
         self.active_nodes = meta_dict['active_nodes']
         self.max_n_nodes = meta_dict['max_n_nodes']
@@ -1396,11 +1409,12 @@ def compute_meta(root, source_name, train_index, test_index):
         'transition_E': tansition_E.tolist(),
     }
-    with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f:
+    # with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f:
+    with open(os.path.join(self.cfg.general.root, 'nasbench201.meta.json'), "w") as f:
         json.dump(meta_dict, f)
     return meta_dict
 if __name__ == "__main__":
-    dataset = Dataset(source='nasbench', root='/nfs/data3/hanzhang/nasbenchDiT/graph-dit', target_prop='Class', transform=None)
+    dataset = Dataset(source='nasbench', root='/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/', target_prop='Class', transform=None)
diff --git a/graph_dit/diffusion_model.py b/graph_dit/diffusion_model.py
index b8dfc33..6c38a9c 100644
--- a/graph_dit/diffusion_model.py
+++ b/graph_dit/diffusion_model.py
@@ -24,7 +24,7 @@ class Graph_DiT(pl.LightningModule):
         self.guidance_target = getattr(cfg.dataset, 'guidance_target', None)
         from nas_201_api import NASBench201API as API
-        self.api = API('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')
+        self.api = API(cfg.general.nas_201)
         input_dims = dataset_infos.input_dims
         output_dims = dataset_infos.output_dims
@@ -44,7 +44,7 @@ class Graph_DiT(pl.LightningModule):
         self.args.batch_size = 128
         self.args.GPU = '0'
         self.args.dataset = 'cifar10-valid'
-        self.args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        self.args.api_loc = cfg.general.nas_201
         self.args.data_loc = '../cifardata/'
         self.args.seed = 777
         self.args.init = ''
@@ -177,7 +177,7 @@ class Graph_DiT(pl.LightningModule):
         rewards = []
         if reward_model == 'swap':
             import csv
-            with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f:
+            with open(self.cfg.general.swap_result, 'r') as f:
                 reader = csv.reader(f)
                 header = next(reader)
                 data = [row for row in reader]
@@ -345,10 +345,15 @@ class Graph_DiT(pl.LightningModule):
             num_examples = self.val_y_collection.size(0)
             batch_y = self.val_y_collection[start_index:start_index + to_generate]
             all_ys.append(batch_y)
-            samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
+            cur_sample, logprobs = self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
                                         save_final=to_save,
                                         keep_chain=chains_save,
-                                        number_chain_steps=self.number_chain_steps))
+                                        number_chain_steps=self.number_chain_steps)
+            samples.extend(cur_sample)
+            # samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
+            #                             save_final=to_save,
+            #                             keep_chain=chains_save,
+            #                             number_chain_steps=self.number_chain_steps))
             ident += to_generate
             start_index += to_generate
@@ -423,7 +428,7 @@ class Graph_DiT(pl.LightningModule):
             cur_sample, log_probs = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
                                             keep_chain=chains_save,
                                             number_chain_steps=self.number_chain_steps)
-            samples.append(cur_sample)
+            samples.extend(cur_sample)
             all_ys.append(batch_y)
             batch_id += to_generate
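
Note on the new config keys: the three entries added under general in configs/config.yaml (nas_201, swap_result, root) replace the hard-coded '/nfs/data3/hanzhang/...' locations that previously appeared throughout dataset.py and diffusion_model.py. The sketch below shows how these keys resolve; it assumes the YAML is loaded with OmegaConf (consistent with the cfg.general.* attribute access in the patch), and the variable names are illustrative only, not part of the patch:

    # Sketch only: assumes OmegaConf-style config loading; names are illustrative.
    import os
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("configs/config.yaml")

    nas201_path = cfg.general.nas_201      # NAS-Bench-201 checkpoint (.pth)
    swap_csv = cfg.general.swap_result     # pre-computed SWAP score table (CSV)
    meta_json = os.path.join(cfg.general.root, "nasbench-201-meta.json")
    print(nas201_path, swap_csv, meta_json)

Separately, the samples.append(cur_sample) -> samples.extend(cur_sample) change in diffusion_model.py keeps samples a flat list of generated graphs instead of a list of per-batch lists, matching how the validation path above consumes it.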