add get_train_graphs
This commit is contained in:
		| @@ -69,6 +69,7 @@ class DataModule(AbstractDataModule): | ||||
|         source = './NAS-Bench-201-v1_1-096897.pth' | ||||
|         dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None) | ||||
|         self.dataset = dataset | ||||
|         self.api = dataset.api | ||||
|  | ||||
|         # if len(self.task.split('-')) == 2: | ||||
|         #     train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset) | ||||
| @@ -177,6 +178,27 @@ class DataModule(AbstractDataModule): | ||||
|         smiles = Chem.MolToSmiles(mol) | ||||
|         return smiles | ||||
|  | ||||
|     def get_train_graphs(self): | ||||
|         train_graphs = [] | ||||
|         test_graphs = [] | ||||
|         for graph in self.train_dataset: | ||||
|             train_graphs.append(graph) | ||||
|         for graph in self.test_dataset: | ||||
|             test_graphs.append(graph) | ||||
|         return train_graphs, test_graphs | ||||
|  | ||||
|  | ||||
|     # def get_train_smiles(self): | ||||
|     #     filename = f'{self.task}.csv.gz' | ||||
|     #     df = pd.read_csv(f'{self.root_path}/raw/{filename}') | ||||
|     #     df_test = df.iloc[self.test_index] | ||||
|     #     df = df.iloc[self.train_index] | ||||
|     #     smiles_list = df['smiles'].tolist() | ||||
|     #     smiles_list_test = df_test['smiles'].tolist() | ||||
|     #     smiles_list = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list] | ||||
|     #     smiles_list_test = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list_test] | ||||
|     #     return smiles_list, smiles_list_test | ||||
|  | ||||
|     def get_train_smiles(self): | ||||
|         train_smiles = []    | ||||
|         test_smiles = [] | ||||
|   | ||||
		Reference in New Issue
	
	Block a user