try update the api in DataInfo
This commit is contained in:
		| @@ -50,12 +50,12 @@ class DataModule(AbstractDataModule): | ||||
|  | ||||
|     def prepare_data(self) -> None: | ||||
|         target = getattr(self.cfg.dataset, 'guidance_target', None) | ||||
|         print("target", target) | ||||
|         print("target", target) # nasbench-201 | ||||
|         # try: | ||||
|         #     base_path = pathlib.Path(os.path.realpath(__file__)).parents[2] | ||||
|         # except NameError: | ||||
|         # base_path = pathlib.Path(os.getcwd()).parent[2] | ||||
|         base_path = '/home/stud/hanzhang/Graph-Dit' | ||||
|         base_path = '/home/stud/hanzhang/nasbenchDiT' | ||||
|         root_path = os.path.join(base_path, self.datadir) | ||||
|         self.root_path = root_path | ||||
|  | ||||
| @@ -68,13 +68,15 @@ class DataModule(AbstractDataModule): | ||||
|         # Dataset has target property, root path, and transform | ||||
|         source = './NAS-Bench-201-v1_1-096897.pth' | ||||
|         dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None) | ||||
|         self.dataset = dataset | ||||
|  | ||||
|         # if len(self.task.split('-')) == 2: | ||||
|         #     train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset) | ||||
|         # else: | ||||
|         train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset) | ||||
|  | ||||
|         self.train_index, self.val_index, self.test_index, self.unlabeled_index = train_index, val_index, test_index, unlabeled_index | ||||
|         self.train_index, self.val_index, self.test_index, self.unlabeled_index = ( | ||||
|             train_index, val_index, test_index, unlabeled_index) | ||||
|         train_index, val_index, test_index, unlabeled_index = torch.LongTensor(train_index), torch.LongTensor(val_index), torch.LongTensor(test_index), torch.LongTensor(unlabeled_index) | ||||
|         if len(unlabeled_index) > 0: | ||||
|             train_index = torch.cat([train_index, unlabeled_index], dim=0) | ||||
| @@ -477,14 +479,17 @@ def graphs_to_json(graphs, filename): | ||||
| class Dataset(InMemoryDataset): | ||||
|     def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None): | ||||
|         self.target_prop = target_prop | ||||
|         source = '/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' | ||||
|         source = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' | ||||
|         self.source = source | ||||
|         super().__init__(root, transform, pre_transform, pre_filter) | ||||
|         print(self.processed_paths[0]) #/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth.pt | ||||
|         self.api = API(source)  # Initialize NAS-Bench-201 API | ||||
|         print('API loaded') | ||||
|         super().__init__(root, transform, pre_transform, pre_filter) | ||||
|         print('Dataset initialized') | ||||
|         print(self.processed_paths[0]) | ||||
|         self.data, self.slices = torch.load(self.processed_paths[0]) | ||||
|         self.data.edge_attr = self.data.edge_attr.squeeze() | ||||
|         self.data.idx = torch.arange(len(self.data.y)) | ||||
|         print(f"self.data={self.data}, self.slices={self.slices}") | ||||
|  | ||||
|     @property | ||||
|     def raw_file_names(self): | ||||
| @@ -676,7 +681,7 @@ def create_adj_matrix_and_ops(nodes, edges): | ||||
|         adj_matrix[src][dst] = 1 | ||||
|     return adj_matrix, nodes | ||||
| class DataInfos(AbstractDatasetInfos): | ||||
|     def __init__(self, datamodule, cfg): | ||||
|     def __init__(self, datamodule, cfg, dataset): | ||||
|         tasktype_dict = { | ||||
|             'hiv_b': 'classification', | ||||
|             'bace_b': 'classification', | ||||
| @@ -689,6 +694,7 @@ class DataInfos(AbstractDatasetInfos): | ||||
|         self.task = task_name | ||||
|         self.task_type = tasktype_dict.get(task_name, "regression") | ||||
|         self.ensure_connected = cfg.model.ensure_connected | ||||
|         self.api = dataset.api | ||||
|  | ||||
|         datadir = cfg.dataset.datadir | ||||
|  | ||||
| @@ -699,9 +705,9 @@ class DataInfos(AbstractDatasetInfos): | ||||
|         length = 15625 | ||||
|         ops_type = {} | ||||
|         len_ops = set() | ||||
|         api = API('/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth') | ||||
|         # api = API('/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth') | ||||
|         for i in range(length): | ||||
|             arch_info = api.query_meta_info_by_index(i) | ||||
|             arch_info = self.api.query_meta_info_by_index(i) | ||||
|             nodes, edges = parse_architecture_string(arch_info.arch_str) | ||||
|             adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)     | ||||
|             if i < 5: | ||||
| @@ -716,7 +722,6 @@ class DataInfos(AbstractDatasetInfos): | ||||
|             graphs.append((adj_matrix, ops)) | ||||
|  | ||||
|         meta_dict = graphs_to_json(graphs, 'nasbench-201') | ||||
|  | ||||
|         self.base_path = base_path | ||||
|         self.active_atoms = meta_dict['active_atoms'] | ||||
|         self.max_n_nodes = meta_dict['max_node'] | ||||
| @@ -930,4 +935,4 @@ def compute_meta(root, source_name, train_index, test_index): | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     pass | ||||
|     dataset = Dataset(source='nasbench', root='/home/stud/hanzhang/nasbenchDiT/graph-dit', target_prop='Class', transform=None) | ||||
|   | ||||
							
								
								
									
										0
									
								
								graph_dit/workingdoc.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								graph_dit/workingdoc.md
									
									
									
									
									
										Normal file
									
								
							
		Reference in New Issue
	
	Block a user