add some comments
This commit is contained in:
		| @@ -123,4 +123,8 @@ class AbstractDatasetInfos: | ||||
|                            'y': example_batch['y'].size(1)} | ||||
|         self.output_dims = {'X': example_batch_x.size(1), | ||||
|                             'E': example_batch_edge_attr.size(1), | ||||
|                             'y': example_batch['y'].size(1)} | ||||
|                             'y': example_batch['y'].size(1)} | ||||
|         print('input dims') | ||||
|         print(self.input_dims) | ||||
|         print('output dims') | ||||
|         print(self.output_dims) | ||||
| @@ -28,19 +28,38 @@ class DataModule(AbstractDataModule): | ||||
|     def __init__(self, cfg): | ||||
|         self.datadir = cfg.dataset.datadir | ||||
|         self.task = cfg.dataset.task_name | ||||
|         print("DataModule") | ||||
|         print("task", self.task) | ||||
|         print("datadir`",self.datadir) | ||||
|         super().__init__(cfg) | ||||
|  | ||||
|     def prepare_data(self) -> None: | ||||
|         target = getattr(self.cfg.dataset, 'guidance_target', None) | ||||
|         print("target", target) | ||||
|         base_path = pathlib.Path(os.path.realpath(__file__)).parents[2] | ||||
|         root_path = os.path.join(base_path, self.datadir) | ||||
|         self.root_path = root_path | ||||
|  | ||||
|         batch_size = self.cfg.train.batch_size | ||||
|          | ||||
|         num_workers = self.cfg.train.num_workers | ||||
|         pin_memory = self.cfg.dataset.pin_memory | ||||
|  | ||||
|         # Load the dataset into memory | ||||
|         # Dataset has target property, root path, and transform | ||||
|         dataset = Dataset(source=self.task, root=root_path, target_prop=target, transform=None) | ||||
|         print("len dataset", len(dataset)) | ||||
|         def print_data(dataset): | ||||
|             print("dataset", dataset) | ||||
|             print("dataset keys", dataset.keys) | ||||
|             print("dataset x", dataset.x) | ||||
|             print("dataset edge_index", dataset.edge_index) | ||||
|             print("dataset edge_attr", dataset.edge_attr) | ||||
|             print("dataset y", dataset.y) | ||||
|             print("") | ||||
|         print_data(dataset=dataset[0]) | ||||
|         print_data(dataset=dataset[1]) | ||||
|  | ||||
|  | ||||
|         if len(self.task.split('-')) == 2: | ||||
|             train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset) | ||||
| @@ -53,8 +72,12 @@ class DataModule(AbstractDataModule): | ||||
|             train_index = torch.cat([train_index, unlabeled_index], dim=0) | ||||
|          | ||||
|         train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index] | ||||
|         self.train_dataset = train_dataset | ||||
|         self.train_dataset = train_dataset   | ||||
|         print('train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset)) | ||||
|         print('train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index)) | ||||
|         print('dataset len', len(dataset) , 'train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset)) | ||||
|         self.train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=pin_memory) | ||||
|  | ||||
|         self.val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False) | ||||
|         self.test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False) | ||||
|  | ||||
| @@ -253,6 +276,9 @@ class DataInfos(AbstractDatasetInfos): | ||||
|  | ||||
|  | ||||
| def compute_meta(root, source_name, train_index, test_index): | ||||
|     # initialize the periodic table | ||||
|     # 118 elements + 1 for * | ||||
|     # Initializes arrays to count the number of atoms per molecule, bond types, valencies, and transition probabilities between atom types. | ||||
|     pt = Chem.GetPeriodicTable() | ||||
|     atom_name_list = [] | ||||
|     atom_count_list = [] | ||||
| @@ -267,11 +293,13 @@ def compute_meta(root, source_name, train_index, test_index): | ||||
|     valencies = [0] * 500 | ||||
|     tansition_E = np.zeros((118, 118, 5)) | ||||
|      | ||||
|     # Load the data from the source file | ||||
|     filename = f'{source_name}.csv.gz' | ||||
|     df = pd.read_csv(f'{root}/{filename}') | ||||
|     all_index = list(range(len(df))) | ||||
|     non_test_index = list(set(all_index) - set(test_index)) | ||||
|     df = df.iloc[non_test_index] | ||||
|     # extract the smiles from the dataframe | ||||
|     tot_smiles = df['smiles'].tolist() | ||||
|  | ||||
|     n_atom_list = [] | ||||
| @@ -323,6 +351,11 @@ def compute_meta(root, source_name, train_index, test_index): | ||||
|             bond_index = bond_type_to_index[bond_type] | ||||
|             bond_count_list[bond_index] += 2 | ||||
|  | ||||
|             # Update the transition matrix | ||||
|             # The transition matrix is symmetric, so we update both directions | ||||
|             # We also update the temporary transition matrix to check for errors | ||||
|             # in the atom count | ||||
|              | ||||
|             tansition_E[start_index, end_index, bond_index] += 2 | ||||
|             tansition_E[end_index, start_index, bond_index] += 2 | ||||
|             tansition_E_temp[start_index, end_index, bond_index] += 2 | ||||
|   | ||||
		Reference in New Issue
	
	Block a user