add config path
This commit is contained in:
parent
f75657ac3b
commit
9360839a35
@ -16,9 +16,12 @@ general:
|
|||||||
final_model_chains_to_save: 1
|
final_model_chains_to_save: 1
|
||||||
enable_progress_bar: False
|
enable_progress_bar: False
|
||||||
save_model: True
|
save_model: True
|
||||||
log_dir: '/nfs/data3/hanzhang/nasbenchDiT'
|
log_dir: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT'
|
||||||
number_checkpoint_limit: 3
|
number_checkpoint_limit: 3
|
||||||
type: 'Trainer'
|
type: 'Trainer'
|
||||||
|
nas_201: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
||||||
|
swap_result: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/swap_results.csv'
|
||||||
|
root: '/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/'
|
||||||
model:
|
model:
|
||||||
type: 'discrete'
|
type: 'discrete'
|
||||||
transition: 'marginal'
|
transition: 'marginal'
|
||||||
|
@ -25,7 +25,6 @@ from sklearn.model_selection import train_test_split
|
|||||||
import utils as utils
|
import utils as utils
|
||||||
from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
|
from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
|
||||||
from diffusion.distributions import DistributionNodes
|
from diffusion.distributions import DistributionNodes
|
||||||
from naswot.score_networks import get_nasbench201_idx_score
|
|
||||||
from naswot import nasspace
|
from naswot import nasspace
|
||||||
from naswot import datasets as dt
|
from naswot import datasets as dt
|
||||||
|
|
||||||
@ -72,7 +71,9 @@ class DataModule(AbstractDataModule):
|
|||||||
# base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
|
# base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
|
||||||
# except NameError:
|
# except NameError:
|
||||||
# base_path = pathlib.Path(os.getcwd()).parent[2]
|
# base_path = pathlib.Path(os.getcwd()).parent[2]
|
||||||
base_path = '/nfs/data3/hanzhang/nasbenchDiT'
|
# base_path = '/nfs/data3/hanzhang/nasbenchDiT'
|
||||||
|
base_path = os.path.join(self.cfg.general.root, "..")
|
||||||
|
|
||||||
root_path = os.path.join(base_path, self.datadir)
|
root_path = os.path.join(base_path, self.datadir)
|
||||||
self.root_path = root_path
|
self.root_path = root_path
|
||||||
|
|
||||||
@ -84,7 +85,7 @@ class DataModule(AbstractDataModule):
|
|||||||
# Load the dataset to the memory
|
# Load the dataset to the memory
|
||||||
# Dataset has target property, root path, and transform
|
# Dataset has target property, root path, and transform
|
||||||
source = './NAS-Bench-201-v1_1-096897.pth'
|
source = './NAS-Bench-201-v1_1-096897.pth'
|
||||||
dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None)
|
dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None, cfg=self.cfg)
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
# self.api = dataset.api
|
# self.api = dataset.api
|
||||||
|
|
||||||
@ -384,7 +385,7 @@ class DataModule_original(AbstractDataModule):
|
|||||||
def test_dataloader(self):
|
def test_dataloader(self):
|
||||||
return self.test_loader
|
return self.test_loader
|
||||||
|
|
||||||
def new_graphs_to_json(graphs, filename):
|
def new_graphs_to_json(graphs, filename, cfg):
|
||||||
source_name = "nasbench-201"
|
source_name = "nasbench-201"
|
||||||
num_graph = len(graphs)
|
num_graph = len(graphs)
|
||||||
|
|
||||||
@ -491,8 +492,9 @@ def new_graphs_to_json(graphs, filename):
|
|||||||
'num_active_nodes': len(active_nodes),
|
'num_active_nodes': len(active_nodes),
|
||||||
'transition_E': transition_E.tolist(),
|
'transition_E': transition_E.tolist(),
|
||||||
}
|
}
|
||||||
|
import os
|
||||||
with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f:
|
# with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f:
|
||||||
|
with open(os.path.join(cfg.general.root,'nasbench-201-meta.json'), 'w') as f:
|
||||||
json.dump(meta_dict, f)
|
json.dump(meta_dict, f)
|
||||||
|
|
||||||
return meta_dict
|
return meta_dict
|
||||||
@ -656,9 +658,11 @@ def graphs_to_json(graphs, filename):
|
|||||||
json.dump(meta_dict, f)
|
json.dump(meta_dict, f)
|
||||||
return meta_dict
|
return meta_dict
|
||||||
class Dataset(InMemoryDataset):
|
class Dataset(InMemoryDataset):
|
||||||
def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
|
def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None, cfg=None):
|
||||||
self.target_prop = target_prop
|
self.target_prop = target_prop
|
||||||
source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
self.cfg = cfg
|
||||||
|
# source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
||||||
|
source = os.path.join(self.cfg.general.root, 'NAS-Bench-201-v1_1-096897.pth')
|
||||||
self.source = source
|
self.source = source
|
||||||
# self.api = API(source) # Initialize NAS-Bench-201 API
|
# self.api = API(source) # Initialize NAS-Bench-201 API
|
||||||
# print('API loaded')
|
# print('API loaded')
|
||||||
@ -679,7 +683,8 @@ class Dataset(InMemoryDataset):
|
|||||||
return [f'{self.source}.pt']
|
return [f'{self.source}.pt']
|
||||||
|
|
||||||
def process(self):
|
def process(self):
|
||||||
source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
# source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
||||||
|
source = self.cfg.general.nas_201
|
||||||
# self.api = API(source)
|
# self.api = API(source)
|
||||||
|
|
||||||
data_list = []
|
data_list = []
|
||||||
@ -748,7 +753,8 @@ class Dataset(InMemoryDataset):
|
|||||||
return edges,nodes
|
return edges,nodes
|
||||||
|
|
||||||
|
|
||||||
def graph_to_graph_data(graph, idx, train_loader, searchspace, args, device):
|
# def graph_to_graph_data(graph, idx, train_loader, searchspace, args, device):
|
||||||
|
def graph_to_graph_data(graph, idx, args, device):
|
||||||
# def graph_to_graph_data(graph):
|
# def graph_to_graph_data(graph):
|
||||||
ops = graph[1]
|
ops = graph[1]
|
||||||
adj = graph[0]
|
adj = graph[0]
|
||||||
@ -797,7 +803,7 @@ class Dataset(InMemoryDataset):
|
|||||||
args.batch_size = 128
|
args.batch_size = 128
|
||||||
args.GPU = '0'
|
args.GPU = '0'
|
||||||
args.dataset = 'cifar10'
|
args.dataset = 'cifar10'
|
||||||
args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
args.api_loc = self.cfg.general.nas_201
|
||||||
args.data_loc = '../cifardata/'
|
args.data_loc = '../cifardata/'
|
||||||
args.seed = 777
|
args.seed = 777
|
||||||
args.init = ''
|
args.init = ''
|
||||||
@ -812,10 +818,11 @@ class Dataset(InMemoryDataset):
|
|||||||
args.num_modules_per_stack = 3
|
args.num_modules_per_stack = 3
|
||||||
args.num_labels = 1
|
args.num_labels = 1
|
||||||
searchspace = nasspace.get_search_space(args)
|
searchspace = nasspace.get_search_space(args)
|
||||||
train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
|
# train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
|
||||||
self.swap_scores = []
|
self.swap_scores = []
|
||||||
import csv
|
import csv
|
||||||
with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f:
|
# with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f:
|
||||||
|
with open(self.cfg.general.swap_result, 'r') as f:
|
||||||
# with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results_cifar100.csv', 'r') as f:
|
# with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results_cifar100.csv', 'r') as f:
|
||||||
reader = csv.reader(f)
|
reader = csv.reader(f)
|
||||||
header = next(reader)
|
header = next(reader)
|
||||||
@ -824,12 +831,15 @@ class Dataset(InMemoryDataset):
|
|||||||
device = torch.device('cuda:2')
|
device = torch.device('cuda:2')
|
||||||
with tqdm(total = len_data) as pbar:
|
with tqdm(total = len_data) as pbar:
|
||||||
active_nodes = set()
|
active_nodes = set()
|
||||||
file_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
|
import os
|
||||||
|
# file_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
|
||||||
|
file_path = os.path.join(self.cfg.general.root, 'nasbench-201-graph.json')
|
||||||
with open(file_path, 'r') as f:
|
with open(file_path, 'r') as f:
|
||||||
graph_list = json.load(f)
|
graph_list = json.load(f)
|
||||||
i = 0
|
i = 0
|
||||||
flex_graph_list = []
|
flex_graph_list = []
|
||||||
flex_graph_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json'
|
# flex_graph_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json'
|
||||||
|
flex_graph_path = os.path.join(self.cfg.general.root,'flex-nasbench201-graph.json')
|
||||||
for graph in graph_list:
|
for graph in graph_list:
|
||||||
print(f'iterate every graph in graph_list, here is {i}')
|
print(f'iterate every graph in graph_list, here is {i}')
|
||||||
arch_info = graph['arch_str']
|
arch_info = graph['arch_str']
|
||||||
@ -837,7 +847,8 @@ class Dataset(InMemoryDataset):
|
|||||||
for op in ops:
|
for op in ops:
|
||||||
if op not in active_nodes:
|
if op not in active_nodes:
|
||||||
active_nodes.add(op)
|
active_nodes.add(op)
|
||||||
data = graph_to_graph_data((adj_matrix, ops),idx=i, train_loader=train_loader, searchspace=searchspace, args=args, device=device)
|
# data = graph_to_graph_data((adj_matrix, ops),idx=i, train_loader=train_loader, searchspace=searchspace, args=args, device=device)
|
||||||
|
data = graph_to_graph_data((adj_matrix, ops),idx=i, args=args, device=device)
|
||||||
i += 1
|
i += 1
|
||||||
if data is None:
|
if data is None:
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
@ -1140,6 +1151,7 @@ class DataInfos(AbstractDatasetInfos):
|
|||||||
self.task = task_name
|
self.task = task_name
|
||||||
self.task_type = tasktype_dict.get(task_name, "regression")
|
self.task_type = tasktype_dict.get(task_name, "regression")
|
||||||
self.ensure_connected = cfg.model.ensure_connected
|
self.ensure_connected = cfg.model.ensure_connected
|
||||||
|
self.cfg = cfg
|
||||||
# self.api = dataset.api
|
# self.api = dataset.api
|
||||||
|
|
||||||
datadir = cfg.dataset.datadir
|
datadir = cfg.dataset.datadir
|
||||||
@ -1182,14 +1194,15 @@ class DataInfos(AbstractDatasetInfos):
|
|||||||
# len_ops.add(len(ops))
|
# len_ops.add(len(ops))
|
||||||
# graphs.append((adj_matrix, ops))
|
# graphs.append((adj_matrix, ops))
|
||||||
# graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json')
|
# graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json')
|
||||||
graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json')
|
# graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json')
|
||||||
|
graphs = read_adj_ops_from_json(os.path.join(self.cfg.general.root, 'nasbench-201-graph.json'))
|
||||||
|
|
||||||
# check first five graphs
|
# check first five graphs
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
print(f'graph {i} : {graphs[i]}')
|
print(f'graph {i} : {graphs[i]}')
|
||||||
# print(f'ops_type: {ops_type}')
|
# print(f'ops_type: {ops_type}')
|
||||||
|
|
||||||
meta_dict = new_graphs_to_json(graphs, 'nasbench-201')
|
meta_dict = new_graphs_to_json(graphs, 'nasbench-201', self.cfg)
|
||||||
self.base_path = base_path
|
self.base_path = base_path
|
||||||
self.active_nodes = meta_dict['active_nodes']
|
self.active_nodes = meta_dict['active_nodes']
|
||||||
self.max_n_nodes = meta_dict['max_n_nodes']
|
self.max_n_nodes = meta_dict['max_n_nodes']
|
||||||
@ -1396,11 +1409,12 @@ def compute_meta(root, source_name, train_index, test_index):
|
|||||||
'transition_E': tansition_E.tolist(),
|
'transition_E': tansition_E.tolist(),
|
||||||
}
|
}
|
||||||
|
|
||||||
with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f:
|
# with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f:
|
||||||
|
with open(os.path.join(self.cfg.general.root, 'nasbench201.meta.json'), "w") as f:
|
||||||
json.dump(meta_dict, f)
|
json.dump(meta_dict, f)
|
||||||
|
|
||||||
return meta_dict
|
return meta_dict
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
dataset = Dataset(source='nasbench', root='/nfs/data3/hanzhang/nasbenchDiT/graph-dit', target_prop='Class', transform=None)
|
dataset = Dataset(source='nasbench', root='/zhome/academic/HLRS/xmu/xmuhanma/nasbenchDiT/graph_dit/', target_prop='Class', transform=None)
|
||||||
|
@ -24,7 +24,7 @@ class Graph_DiT(pl.LightningModule):
|
|||||||
self.guidance_target = getattr(cfg.dataset, 'guidance_target', None)
|
self.guidance_target = getattr(cfg.dataset, 'guidance_target', None)
|
||||||
|
|
||||||
from nas_201_api import NASBench201API as API
|
from nas_201_api import NASBench201API as API
|
||||||
self.api = API('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')
|
self.api = API(cfg.general.nas_201)
|
||||||
|
|
||||||
input_dims = dataset_infos.input_dims
|
input_dims = dataset_infos.input_dims
|
||||||
output_dims = dataset_infos.output_dims
|
output_dims = dataset_infos.output_dims
|
||||||
@ -44,7 +44,7 @@ class Graph_DiT(pl.LightningModule):
|
|||||||
self.args.batch_size = 128
|
self.args.batch_size = 128
|
||||||
self.args.GPU = '0'
|
self.args.GPU = '0'
|
||||||
self.args.dataset = 'cifar10-valid'
|
self.args.dataset = 'cifar10-valid'
|
||||||
self.args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
self.args.api_loc = cfg.general.nas_201
|
||||||
self.args.data_loc = '../cifardata/'
|
self.args.data_loc = '../cifardata/'
|
||||||
self.args.seed = 777
|
self.args.seed = 777
|
||||||
self.args.init = ''
|
self.args.init = ''
|
||||||
@ -177,7 +177,7 @@ class Graph_DiT(pl.LightningModule):
|
|||||||
rewards = []
|
rewards = []
|
||||||
if reward_model == 'swap':
|
if reward_model == 'swap':
|
||||||
import csv
|
import csv
|
||||||
with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f:
|
with open(self.cfg.general.swap_result, 'r') as f:
|
||||||
reader = csv.reader(f)
|
reader = csv.reader(f)
|
||||||
header = next(reader)
|
header = next(reader)
|
||||||
data = [row for row in reader]
|
data = [row for row in reader]
|
||||||
@ -345,10 +345,15 @@ class Graph_DiT(pl.LightningModule):
|
|||||||
num_examples = self.val_y_collection.size(0)
|
num_examples = self.val_y_collection.size(0)
|
||||||
batch_y = self.val_y_collection[start_index:start_index + to_generate]
|
batch_y = self.val_y_collection[start_index:start_index + to_generate]
|
||||||
all_ys.append(batch_y)
|
all_ys.append(batch_y)
|
||||||
samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
|
cur_sample, logprobs = self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
|
||||||
save_final=to_save,
|
save_final=to_save,
|
||||||
keep_chain=chains_save,
|
keep_chain=chains_save,
|
||||||
number_chain_steps=self.number_chain_steps))
|
number_chain_steps=self.number_chain_steps)
|
||||||
|
samples.extend(cur_sample)
|
||||||
|
# samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
|
||||||
|
# save_final=to_save,
|
||||||
|
# keep_chain=chains_save,
|
||||||
|
# number_chain_steps=self.number_chain_steps))
|
||||||
ident += to_generate
|
ident += to_generate
|
||||||
start_index += to_generate
|
start_index += to_generate
|
||||||
|
|
||||||
@ -423,7 +428,7 @@ class Graph_DiT(pl.LightningModule):
|
|||||||
|
|
||||||
cur_sample, log_probs = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
|
cur_sample, log_probs = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
|
||||||
keep_chain=chains_save, number_chain_steps=self.number_chain_steps)
|
keep_chain=chains_save, number_chain_steps=self.number_chain_steps)
|
||||||
samples.append(cur_sample)
|
samples.extend(cur_sample)
|
||||||
|
|
||||||
all_ys.append(batch_y)
|
all_ys.append(batch_y)
|
||||||
batch_id += to_generate
|
batch_id += to_generate
|
||||||
|
Loading…
Reference in New Issue
Block a user