add some comments
@@ -123,4 +123,8 @@ class AbstractDatasetInfos:
                            'y': example_batch['y'].size(1)}
         self.output_dims = {'X': example_batch_x.size(1),
                             'E': example_batch_edge_attr.size(1),
                             'y': example_batch['y'].size(1)}
+        print('input dims')
+        print(self.input_dims)
+        print('output dims')
+        print(self.output_dims)
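
Reviewer note: a minimal sketch, with hypothetical tensor sizes, of what these dims amount to; only the names `example_batch_x`, `example_batch_edge_attr`, and `example_batch['y']` come from the surrounding code.

import torch

example_batch_x = torch.zeros(30, 11)           # node one-hots: [num_nodes, num_atom_types]
example_batch_edge_attr = torch.zeros(60, 5)    # edge one-hots: [num_edges, num_bond_types]
example_batch = {'y': torch.zeros(1, 2)}        # graph-level condition vector

output_dims = {'X': example_batch_x.size(1),
               'E': example_batch_edge_attr.size(1),
               'y': example_batch['y'].size(1)}
print(output_dims)   # {'X': 11, 'E': 5, 'y': 2}
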
@@ -28,19 +28,38 @@ class DataModule(AbstractDataModule):
     def __init__(self, cfg):
         self.datadir = cfg.dataset.datadir
         self.task = cfg.dataset.task_name
+        print("DataModule")
+        print("task", self.task)
+        print("datadir", self.datadir)
         super().__init__(cfg)

     def prepare_data(self) -> None:
         target = getattr(self.cfg.dataset, 'guidance_target', None)
+        print("target", target)
         base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
         root_path = os.path.join(base_path, self.datadir)
         self.root_path = root_path

         batch_size = self.cfg.train.batch_size
+
         num_workers = self.cfg.train.num_workers
         pin_memory = self.cfg.dataset.pin_memory

+        # Load the dataset into memory. It is built from the task name,
+        # the root path, and the guidance target; no transform is applied.
         dataset = Dataset(source=self.task, root=root_path, target_prop=target, transform=None)
+        print("len dataset", len(dataset))
+
+        def print_data(data):
+            print("data", data)
+            print("data keys", data.keys)
+            print("data x", data.x)
+            print("data edge_index", data.edge_index)
+            print("data edge_attr", data.edge_attr)
+            print("data y", data.y)
+            print("")
+
+        print_data(data=dataset[0])
+        print_data(data=dataset[1])

         if len(self.task.split('-')) == 2:
             train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)

@@ -53,8 +72,12 @@ class DataModule(AbstractDataModule):
             train_index = torch.cat([train_index, unlabeled_index], dim=0)

         train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index]
         self.train_dataset = train_dataset
+        print('index lens: train', len(train_index), 'val', len(val_index), 'test', len(test_index))
+        print('dataset len', len(dataset), 'train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
         self.train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=pin_memory)

         self.val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)
         self.test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)
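
For context on what `print_data` shows: a minimal sketch, assuming the dataset yields `torch_geometric.data.Data` objects. The field names match the debug helper; the sizes are made up.

import torch
from torch_geometric.data import Data

data = Data(
    x=torch.zeros(4, 11),                         # node features: [num_nodes, num_atom_types]
    edge_index=torch.tensor([[0, 1], [1, 0]]),    # COO connectivity: [2, num_edges]
    edge_attr=torch.zeros(2, 5),                  # edge features: [num_edges, num_bond_types]
    y=torch.zeros(1, 1),                          # graph-level target
)
print(data)   # Data(x=[4, 11], edge_index=[2, 2], edge_attr=[2, 5], y=[1, 1])
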
@@ -253,6 +276,9 @@ class DataInfos(AbstractDatasetInfos):


 def compute_meta(root, source_name, train_index, test_index):
+    # Initialize the periodic table (118 elements, plus one slot for '*').
+    # The lists and arrays below count atoms per molecule, bond types,
+    # valencies, and transitions between atom types.
     pt = Chem.GetPeriodicTable()
     atom_name_list = []
     atom_count_list = []

@@ -267,11 +293,13 @@ def compute_meta(root, source_name, train_index, test_index):
     valencies = [0] * 500
     tansition_E = np.zeros((118, 118, 5))

+    # Load the data from the source file.
     filename = f'{source_name}.csv.gz'
     df = pd.read_csv(f'{root}/{filename}')
     all_index = list(range(len(df)))
     non_test_index = list(set(all_index) - set(test_index))
     df = df.iloc[non_test_index]
+    # Extract the SMILES strings from the dataframe.
     tot_smiles = df['smiles'].tolist()

     n_atom_list = []
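
A minimal sketch of this loading step; the file name is hypothetical, and pandas decompresses `.csv.gz` transparently.

import pandas as pd

df = pd.read_csv('zinc250k.csv.gz')    # hypothetical source file with a 'smiles' column
test_index = [0, 5, 9]
non_test_index = sorted(set(range(len(df))) - set(test_index))
tot_smiles = df.iloc[non_test_index]['smiles'].tolist()
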
@@ -323,6 +351,11 @@ def compute_meta(root, source_name, train_index, test_index):
             bond_index = bond_type_to_index[bond_type]
             bond_count_list[bond_index] += 2

+            # Update the transition matrix. It is symmetric in the atom
+            # indices, so both directions are updated. A temporary copy is
+            # updated as well, to cross-check the atom counts later.
             tansition_E[start_index, end_index, bond_index] += 2
             tansition_E[end_index, start_index, bond_index] += 2
             tansition_E_temp[start_index, end_index, bond_index] += 2
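
A minimal sketch of the symmetric update above, with the same array shape (118 atom types, 5 bond types); the indices are illustrative.

import numpy as np

transition_E = np.zeros((118, 118, 5))

def add_bond(start, end, bond):
    # Count each bond in both directions so the matrix stays
    # symmetric in its first two axes.
    transition_E[start, end, bond] += 2
    transition_E[end, start, bond] += 2

add_bond(5, 7, 0)   # e.g. a C-O single bond (0-indexed atomic numbers)
assert (transition_E == transition_E.transpose(1, 0, 2)).all()
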
@@ -76,12 +76,16 @@ class Graph_DiT(pl.LightningModule):
                                                               timesteps=cfg.model.diffusion_steps)


+        print("__init__")
+        print("dataset_info.node_types", self.dataset_info.node_types)
+        # e.g. tensor([7.4826e-01, 2.6870e-02, 9.3930e-02, 4.4959e-02, 5.2982e-03, 7.5689e-04, 5.3739e-03, 1.5138e-03, 7.5689e-05, 4.3143e-03, 6.8650e-02])
         x_marginals = self.dataset_info.node_types.float() / torch.sum(self.dataset_info.node_types.float())

         e_marginals = self.dataset_info.edge_types.float() / torch.sum(self.dataset_info.edge_types.float())
         x_marginals = x_marginals / (x_marginals ).sum()
         e_marginals = e_marginals / (e_marginals ).sum()

+        # transition_E[x1, x2, e] is the probability of an edge of type e
+        # between atom types x1 and x2.
         xe_conditions = self.dataset_info.transition_E.float()
         xe_conditions = xe_conditions[self.active_index][:, self.active_index]
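
A minimal sketch of the marginal computation: raw per-type counts are normalized into a probability vector that serves as the diffusion prior. The numbers are truncated from the debug output above.

import torch

node_types = torch.tensor([7.4826e-01, 2.6870e-02, 9.3930e-02, 4.4959e-02])  # truncated
x_marginals = node_types.float() / torch.sum(node_types.float())
x_marginals = x_marginals / x_marginals.sum()   # renormalize to guard against drift
print(x_marginals.sum())   # tensor(1.)
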
@@ -82,6 +82,7 @@ def main(cfg: DictConfig):
     dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg)
     train_smiles, reference_smiles = datamodule.get_train_smiles()

+    # Compute the model's input/output dimensions from an example batch.
     dataset_infos.compute_input_output_dims(datamodule=datamodule)
     train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
@@ -84,7 +84,7 @@ class BondMetricsCE(MetricCollection):
         ce_TR = TripleCE(3)
         super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR])

+#
 class TrainMolecularMetricsDiscrete(nn.Module):
     def __init__(self, dataset_infos):
         super().__init__()
@@ -75,28 +75,55 @@ class Denoiser(nn.Module):
             _constant_init(block.adaLN_modulation[0], 0)
         _constant_init(self.out_layer.adaLN_modulation[0], 0)

     def forward(self, x, e, node_mask, y, t, unconditioned):
+        """Run one denoising step.
+
+        Args:
+            x: Node features.
+            e: Edge features.
+            node_mask: Mask indicating valid nodes.
+            y: Condition features.
+            t: Current timestep in the diffusion process.
+            unconditioned: If True, ignore the condition features.
+        """
+        print("Denoiser Forward")
+        print(x.shape, e.shape, y.shape, t.shape, unconditioned)
         force_drop_id = torch.zeros_like(y.sum(-1))
+        # Drop conditions whose values are NaN.
         force_drop_id[torch.isnan(y.sum(-1))] = 1
         if unconditioned:
             force_drop_id = torch.ones_like(y[:, 0])

         x_in, e_in, y_in = x, e, y
+        # bs = batch size, n = number of nodes
         bs, n, _ = x.size()
         x = torch.cat([x, e.reshape(bs, n, -1)], dim=-1)
+        print("X after concat with E:", x.shape)
+        # self.x_embedder = nn.Linear(Xdim + max_n_nodes * Edim, hidden_size, bias=False)
         x = self.x_embedder(x)
+        print("X after x_embedder:", x.shape)

+        # self.t_embedder = TimestepEmbedder(hidden_size)
         c1 = self.t_embedder(t)
+        print("C1 after t_embedder:", c1.shape)
         for i in range(1, self.ydim):
             if i == 1:
                 c2 = self.y_embedding_list[i-1](y[:, :2], self.training, force_drop_id, t)
             else:
                 c2 = c2 + self.y_embedding_list[i-1](y[:, i:i+1], self.training, force_drop_id, t)
+        print("C2 after y_embedding_list:", c2.shape)
         c = c1 + c2
+        print("C = C1 + C2:", c.shape)

         for i, block in enumerate(self.encoders):
             x = block(x, c, node_mask)
+        print("X after encoder blocks:", x.shape)

         # X: B * N * dx, E: B * N * N * de
         X, E, y = self.out_layer(x, x_in, e_in, c, t, node_mask)
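
A minimal sketch, with hypothetical sizes, of the concat-then-embed step whose shapes the new prints report: each node's full row of edge features is flattened and appended to its node features before the linear embedding.

import torch
import torch.nn as nn

bs, n, Xdim, Edim, hidden_size = 2, 8, 11, 5, 64
x = torch.randn(bs, n, Xdim)
e = torch.randn(bs, n, n, Edim)

x_cat = torch.cat([x, e.reshape(bs, n, -1)], dim=-1)   # (2, 8, 11 + 8*5) = (2, 8, 51)
x_embedder = nn.Linear(Xdim + n * Edim, hidden_size, bias=False)
print(x_embedder(x_cat).shape)   # torch.Size([2, 8, 64])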