From a222c514d9eb696958fc9233d20ce4fcb00c49c5 Mon Sep 17 00:00:00 2001 From: mhz Date: Wed, 26 Jun 2024 22:42:06 +0200 Subject: [PATCH] add get_train_graphs --- graph_dit/datasets/dataset.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/graph_dit/datasets/dataset.py b/graph_dit/datasets/dataset.py index defe757..143f65c 100644 --- a/graph_dit/datasets/dataset.py +++ b/graph_dit/datasets/dataset.py @@ -69,6 +69,7 @@ class DataModule(AbstractDataModule): source = './NAS-Bench-201-v1_1-096897.pth' dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None) self.dataset = dataset + self.api = dataset.api # if len(self.task.split('-')) == 2: # train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset) @@ -177,6 +178,27 @@ class DataModule(AbstractDataModule): smiles = Chem.MolToSmiles(mol) return smiles + def get_train_graphs(self): + train_graphs = [] + test_graphs = [] + for graph in self.train_dataset: + train_graphs.append(graph) + for graph in self.test_dataset: + test_graphs.append(graph) + return train_graphs, test_graphs + + + # def get_train_smiles(self): + # filename = f'{self.task}.csv.gz' + # df = pd.read_csv(f'{self.root_path}/raw/{filename}') + # df_test = df.iloc[self.test_index] + # df = df.iloc[self.train_index] + # smiles_list = df['smiles'].tolist() + # smiles_list_test = df_test['smiles'].tolist() + # smiles_list = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list] + # smiles_list_test = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list_test] + # return smiles_list, smiles_list_test + def get_train_smiles(self): train_smiles = [] test_smiles = []