diff --git a/graph_dit/datasets/abstract_dataset.py b/graph_dit/datasets/abstract_dataset.py
index 0d0d9f9..c8e82c5 100644
--- a/graph_dit/datasets/abstract_dataset.py
+++ b/graph_dit/datasets/abstract_dataset.py
@@ -123,4 +123,8 @@ class AbstractDatasetInfos:
                            'y': example_batch['y'].size(1)}
         self.output_dims = {'X': example_batch_x.size(1),
                             'E': example_batch_edge_attr.size(1),
-                            'y': example_batch['y'].size(1)}
\ No newline at end of file
+                            'y': example_batch['y'].size(1)}
+        print('input dims')
+        print(self.input_dims)
+        print('output dims')
+        print(self.output_dims)
\ No newline at end of file
diff --git a/graph_dit/datasets/dataset.py b/graph_dit/datasets/dataset.py
index a30d139..0c12a1f 100644
--- a/graph_dit/datasets/dataset.py
+++ b/graph_dit/datasets/dataset.py
@@ -28,19 +28,38 @@ class DataModule(AbstractDataModule):
     def __init__(self, cfg):
         self.datadir = cfg.dataset.datadir
         self.task = cfg.dataset.task_name
+        print("DataModule")
+        print("task", self.task)
+        print("datadir", self.datadir)
         super().__init__(cfg)
 
     def prepare_data(self) -> None:
         target = getattr(self.cfg.dataset, 'guidance_target', None)
+        print("target", target)
         base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
         root_path = os.path.join(base_path, self.datadir)
         self.root_path = root_path
 
         batch_size = self.cfg.train.batch_size
+        num_workers = self.cfg.train.num_workers
         pin_memory = self.cfg.dataset.pin_memory
 
+        # Load the dataset into memory; it is built from the task source,
+        # the dataset root path, the target property, and an optional transform.
         dataset = Dataset(source=self.task, root=root_path, target_prop=target, transform=None)
+        print("len dataset", len(dataset))
+        def print_data(dataset):
+            print("dataset", dataset)
+            print("dataset keys", dataset.keys)
+            print("dataset x", dataset.x)
+            print("dataset edge_index", dataset.edge_index)
+            print("dataset edge_attr", dataset.edge_attr)
+            print("dataset y", dataset.y)
+            print("")
+        print_data(dataset=dataset[0])
+        print_data(dataset=dataset[1])
+
         if len(self.task.split('-')) == 2:
             train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)
@@ -53,8 +72,12 @@ class DataModule(AbstractDataModule):
             train_index = torch.cat([train_index, unlabeled_index], dim=0)
 
         train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index]
-        self.train_dataset = train_dataset
+        self.train_dataset = train_dataset
+        print('train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
+        print('train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
+        print('dataset len', len(dataset), 'train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
         self.train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=pin_memory)
+        self.val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)
         self.test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)
 
@@ -253,6 +276,9 @@ class DataInfos(AbstractDatasetInfos):
 
 
 def compute_meta(root, source_name, train_index, test_index):
+    # Initialize the periodic table: 118 elements plus 1 for the wildcard *.
+    # The arrays below count atoms per molecule, bond types, valencies,
+    # and transition probabilities between atom types.
     pt = Chem.GetPeriodicTable()
     atom_name_list = []
     atom_count_list = []
@@ -267,11 +293,13 @@ def compute_meta(root, source_name, train_index, test_index):
     valencies = [0] * 500
     tansition_E = np.zeros((118, 118, 5))
 
+    # Load the data from the source file.
     filename = f'{source_name}.csv.gz'
     df = pd.read_csv(f'{root}/{filename}')
     all_index = list(range(len(df)))
     non_test_index = list(set(all_index) - set(test_index))
     df = df.iloc[non_test_index]
+    # Extract the SMILES strings from the dataframe.
     tot_smiles = df['smiles'].tolist()
 
     n_atom_list = []
@@ -323,6 +351,11 @@ def compute_meta(root, source_name, train_index, test_index):
             bond_index = bond_type_to_index[bond_type]
             bond_count_list[bond_index] += 2
 
+            # Update the transition matrix.
+            # The transition matrix is symmetric, so we update both directions.
+            # We also update the temporary transition matrix to check for
+            # errors in the atom count.
+
             tansition_E[start_index, end_index, bond_index] += 2
             tansition_E[end_index, start_index, bond_index] += 2
             tansition_E_temp[start_index, end_index, bond_index] += 2
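Note on the compute_meta hunks above: each bond contributes a symmetric update to the (atom type, atom type, bond type) transition tensor. The standalone sketch below (not part of the patch) mirrors that update; the (118, 118, 5) shape follows the diff, while the bond-type mapping and the use of raw atomic numbers as indices are simplifying assumptions, and accumulate_transitions is a hypothetical helper name.

import numpy as np
from rdkit import Chem

# Assumed bond-type indexing; index 0 is presumed reserved for "no bond".
bond_type_to_index = {
    Chem.BondType.SINGLE: 1,
    Chem.BondType.DOUBLE: 2,
    Chem.BondType.TRIPLE: 3,
    Chem.BondType.AROMATIC: 4,
}

def accumulate_transitions(smiles, transition_E):
    """Accumulate symmetric (atom, atom, bond) counts for one molecule."""
    mol = Chem.MolFromSmiles(smiles)
    for bond in mol.GetBonds():
        start_index = bond.GetBeginAtom().GetAtomicNum() - 1
        end_index = bond.GetEndAtom().GetAtomicNum() - 1
        bond_index = bond_type_to_index[bond.GetBondType()]
        # Symmetric update, mirroring the += 2 per direction in compute_meta.
        transition_E[start_index, end_index, bond_index] += 2
        transition_E[end_index, start_index, bond_index] += 2
    return transition_E

counts = accumulate_transitions('c1ccccc1O', np.zeros((118, 118, 5)))
print(counts.sum())  # each bond contributes 2 counts in each direction

In the repo the accumulated tensor (spelled tansition_E there) is later normalized into the transition_E probabilities that the diffusion model consumes.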
diff --git a/graph_dit/diffusion_model.py b/graph_dit/diffusion_model.py
index 4a0c9a6..5595c1c 100644
--- a/graph_dit/diffusion_model.py
+++ b/graph_dit/diffusion_model.py
@@ -76,12 +76,16 @@ class Graph_DiT(pl.LightningModule):
             timesteps=cfg.model.diffusion_steps)
 
+        print("__init__")
+        print("dataset_info.node_types", self.dataset_info.node_types)
+        # e.g. dataset_info.node_types = tensor([7.4826e-01, 2.6870e-02, 9.3930e-02, 4.4959e-02, 5.2982e-03, 7.5689e-04, 5.3739e-03, 1.5138e-03, 7.5689e-05, 4.3143e-03, 6.8650e-02])
         x_marginals = self.dataset_info.node_types.float() / torch.sum(self.dataset_info.node_types.float())
         e_marginals = self.dataset_info.edge_types.float() / torch.sum(self.dataset_info.edge_types.float())
 
         x_marginals = x_marginals / (x_marginals ).sum()
         e_marginals = e_marginals / (e_marginals ).sum()
 
+        # transition_E[x1, x2, e]: probability that atom types x1 and x2 are linked by an edge of type e.
         xe_conditions = self.dataset_info.transition_E.float()
         xe_conditions = xe_conditions[self.active_index][:, self.active_index]
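Note on the Graph_DiT.__init__ hunk: the node and edge histograms are normalized into categorical marginals, and the transition tensor is restricted to the active atom types. A toy sketch with stand-in tensors (all values and the active_index here are hypothetical; in the model they come from dataset_info):

import torch

node_types = torch.tensor([748.26, 26.87, 93.93, 44.96])    # stand-in histogram
edge_types = torch.tensor([900.0, 50.0, 30.0, 15.0, 5.0])   # stand-in histogram

x_marginals = node_types.float() / node_types.float().sum()
e_marginals = edge_types.float() / edge_types.float().sum()
# Each is a categorical distribution (sums to 1), usable as the limiting
# distribution of a marginal-transition discrete diffusion process.
assert torch.isclose(x_marginals.sum(), torch.tensor(1.0))
assert torch.isclose(e_marginals.sum(), torch.tensor(1.0))

# Restrict the pairwise transition tensor to active atom types, as the
# hunk above does with self.active_index.
active_index = torch.tensor([0, 2, 3])
transition_E = torch.rand(4, 4, 5)
xe_conditions = transition_E[active_index][:, active_index]
print(xe_conditions.shape)  # torch.Size([3, 3, 5])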
diff --git a/graph_dit/main.py b/graph_dit/main.py
index fb4f4ea..2dcd97a 100644
--- a/graph_dit/main.py
+++ b/graph_dit/main.py
@@ -82,6 +82,7 @@ def main(cfg: DictConfig):
     dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg)
     train_smiles, reference_smiles = datamodule.get_train_smiles()
 
+    # Compute the model's input and output dimensions from an example batch.
     dataset_infos.compute_input_output_dims(datamodule=datamodule)
     train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
diff --git a/graph_dit/metrics/molecular_metrics_train.py b/graph_dit/metrics/molecular_metrics_train.py
index f9f8779..7b4fd0f 100644
--- a/graph_dit/metrics/molecular_metrics_train.py
+++ b/graph_dit/metrics/molecular_metrics_train.py
@@ -84,7 +84,7 @@ class BondMetricsCE(MetricCollection):
         ce_TR = TripleCE(3)
         super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR])
 
-
+#
 class TrainMolecularMetricsDiscrete(nn.Module):
     def __init__(self, dataset_infos):
         super().__init__()
diff --git a/graph_dit/models/transformer.py b/graph_dit/models/transformer.py
index a568191..4fcb2f2 100644
--- a/graph_dit/models/transformer.py
+++ b/graph_dit/models/transformer.py
@@ -75,28 +75,55 @@ class Denoiser(nn.Module):
         _constant_init(block.adaLN_modulation[0], 0)
         _constant_init(self.out_layer.adaLN_modulation[0], 0)
 
     def forward(self, x, e, node_mask, y, t, unconditioned):
+        """
+        Input parameters:
+            x: Node features.
+            e: Edge features.
+            node_mask: Mask indicating valid nodes.
+            y: Condition features.
+            t: Current timestep in the diffusion process.
+            unconditioned: Boolean flag indicating whether to ignore conditions.
+        """
+        print("Denoiser Forward")
+        print(x.shape, e.shape, y.shape, t.shape, unconditioned)
         force_drop_id = torch.zeros_like(y.sum(-1))
+        # Drop conditions whose targets are NaN.
         force_drop_id[torch.isnan(y.sum(-1))] = 1
         if unconditioned:
             force_drop_id = torch.ones_like(y[:, 0])
 
         x_in, e_in, y_in = x, e, y
+        # bs = batch size, n = number of nodes
         bs, n, _ = x.size()
         x = torch.cat([x, e.reshape(bs, n, -1)], dim=-1)
+        print("X after concat with E")
+        print(x.shape)
+        # self.x_embedder = nn.Linear(Xdim + max_n_nodes * Edim, hidden_size, bias=False)
         x = self.x_embedder(x)
+        print("X after x_embedder")
+        print(x.shape)
 
+        # self.t_embedder = TimestepEmbedder(hidden_size)
         c1 = self.t_embedder(t)
+        print("C1 after t_embedder")
+        print(c1.shape)
         for i in range(1, self.ydim):
             if i == 1:
                 c2 = self.y_embedding_list[i-1](y[:, :2], self.training, force_drop_id, t)
             else:
                 c2 = c2 + self.y_embedding_list[i-1](y[:, i:i+1], self.training, force_drop_id, t)
+        print("C2 after y_embedding_list")
+        print(c2.shape)
+        print("C1 + C2")
         c = c1 + c2
+        print(c.shape)
 
         for i, block in enumerate(self.encoders):
             x = block(x, c, node_mask)
+            print("X after block")
+            print(x.shape)
 
         # X: B * N * dx, E: B * N * N * de
         X, E, y = self.out_layer(x, x_in, e_in, c, t, node_mask)
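Note on the Denoiser.forward hunk: each node's flattened row of edge features is concatenated onto its node features so a single linear layer embeds node plus incident-edge context, and the timestep and condition embeddings are summed into one conditioning vector that the adaLN-style blocks broadcast to every node. A minimal sketch of that data flow with hypothetical dimensions (the real sizes and the t/y embedders come from the model config):

import torch
import torch.nn as nn

bs, n, xdim, edim, hidden = 2, 5, 11, 5, 16   # hypothetical sizes

x = torch.randn(bs, n, xdim)        # node features
e = torch.randn(bs, n, n, edim)     # edge features
t_emb = torch.randn(bs, hidden)     # stands in for self.t_embedder(t)
y_emb = torch.randn(bs, hidden)     # stands in for the summed y embeddings

# Flatten each node's edge row and concatenate it onto the node features,
# then embed with one bias-free linear layer, as in the hunk above.
x_embedder = nn.Linear(xdim + n * edim, hidden, bias=False)
x_cat = torch.cat([x, e.reshape(bs, n, -1)], dim=-1)   # (bs, n, xdim + n*edim)
h = x_embedder(x_cat)                                  # (bs, n, hidden)

# Conditioning vector c = timestep embedding + condition embedding.
c = t_emb + y_emb
print(h.shape, c.shape)  # torch.Size([2, 5, 16]) torch.Size([2, 16])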