From 0c7c525680409f7fcf11678b5897997323e86942 Mon Sep 17 00:00:00 2001 From: mhz Date: Wed, 26 Jun 2024 22:09:46 +0200 Subject: [PATCH] try update the api in DataInfo --- graph_dit/datasets/dataset.py | 27 ++++++++++++++++----------- graph_dit/workingdoc.md | 0 2 files changed, 16 insertions(+), 11 deletions(-) create mode 100644 graph_dit/workingdoc.md diff --git a/graph_dit/datasets/dataset.py b/graph_dit/datasets/dataset.py index 225969d..defe757 100644 --- a/graph_dit/datasets/dataset.py +++ b/graph_dit/datasets/dataset.py @@ -50,12 +50,12 @@ class DataModule(AbstractDataModule): def prepare_data(self) -> None: target = getattr(self.cfg.dataset, 'guidance_target', None) - print("target", target) + print("target", target) # nasbench-201 # try: # base_path = pathlib.Path(os.path.realpath(__file__)).parents[2] # except NameError: # base_path = pathlib.Path(os.getcwd()).parent[2] - base_path = '/home/stud/hanzhang/Graph-Dit' + base_path = '/home/stud/hanzhang/nasbenchDiT' root_path = os.path.join(base_path, self.datadir) self.root_path = root_path @@ -68,13 +68,15 @@ class DataModule(AbstractDataModule): # Dataset has target property, root path, and transform source = './NAS-Bench-201-v1_1-096897.pth' dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None) + self.dataset = dataset # if len(self.task.split('-')) == 2: # train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset) # else: train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset) - self.train_index, self.val_index, self.test_index, self.unlabeled_index = train_index, val_index, test_index, unlabeled_index + self.train_index, self.val_index, self.test_index, self.unlabeled_index = ( + train_index, val_index, test_index, unlabeled_index) train_index, val_index, test_index, unlabeled_index = torch.LongTensor(train_index), torch.LongTensor(val_index), torch.LongTensor(test_index), torch.LongTensor(unlabeled_index) if len(unlabeled_index) > 0: train_index = torch.cat([train_index, unlabeled_index], dim=0) @@ -477,14 +479,17 @@ def graphs_to_json(graphs, filename): class Dataset(InMemoryDataset): def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None): self.target_prop = target_prop - source = '/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' + source = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' self.source = source + super().__init__(root, transform, pre_transform, pre_filter) + print(self.processed_paths[0]) #/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth.pt self.api = API(source) # Initialize NAS-Bench-201 API print('API loaded') - super().__init__(root, transform, pre_transform, pre_filter) print('Dataset initialized') - print(self.processed_paths[0]) self.data, self.slices = torch.load(self.processed_paths[0]) + self.data.edge_attr = self.data.edge_attr.squeeze() + self.data.idx = torch.arange(len(self.data.y)) + print(f"self.data={self.data}, self.slices={self.slices}") @property def raw_file_names(self): @@ -676,7 +681,7 @@ def create_adj_matrix_and_ops(nodes, edges): adj_matrix[src][dst] = 1 return adj_matrix, nodes class DataInfos(AbstractDatasetInfos): - def __init__(self, datamodule, cfg): + def __init__(self, datamodule, cfg, dataset): tasktype_dict = { 'hiv_b': 'classification', 'bace_b': 'classification', @@ -689,6 +694,7 @@ class DataInfos(AbstractDatasetInfos): self.task = task_name self.task_type = tasktype_dict.get(task_name, "regression") self.ensure_connected = cfg.model.ensure_connected + self.api = dataset.api datadir = cfg.dataset.datadir @@ -699,9 +705,9 @@ class DataInfos(AbstractDatasetInfos): length = 15625 ops_type = {} len_ops = set() - api = API('/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth') + # api = API('/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth') for i in range(length): - arch_info = api.query_meta_info_by_index(i) + arch_info = self.api.query_meta_info_by_index(i) nodes, edges = parse_architecture_string(arch_info.arch_str) adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges) if i < 5: @@ -716,7 +722,6 @@ class DataInfos(AbstractDatasetInfos): graphs.append((adj_matrix, ops)) meta_dict = graphs_to_json(graphs, 'nasbench-201') - self.base_path = base_path self.active_atoms = meta_dict['active_atoms'] self.max_n_nodes = meta_dict['max_node'] @@ -930,4 +935,4 @@ def compute_meta(root, source_name, train_index, test_index): if __name__ == "__main__": - pass \ No newline at end of file + dataset = Dataset(source='nasbench', root='/home/stud/hanzhang/nasbenchDiT/graph-dit', target_prop='Class', transform=None) diff --git a/graph_dit/workingdoc.md b/graph_dit/workingdoc.md new file mode 100644 index 0000000..e69de29