import time

import torch
import torch.nn as nn
from torchmetrics import Metric, MeanSquaredError

from metrics.abstract_metrics import CrossEntropyMetric

# Lookup tables indexed by atom class; per the original note the entries run
# "from 2:He to 119:*".  The line breaks below are for readability only.
valencies_check = [
    0, 1, 2, 3, 4, 3, 2, 1, 0, 1, 2, 6, 6, 7, 6, 1, 0, 1, 2,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 1, 0, 1, 2,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 5, 6, 1, 2,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    3, 4, 7, 6, 5, 0, 1, 2,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
]
valencies_check = torch.tensor(valencies_check)

weight_check = [
    4.003, 6.941, 9.012, 10.812, 12.011, 14.007, 15.999, 18.998,
    20.18, 22.99, 24.305, 26.982, 28.086, 30.974, 32.067, 35.453,
    39.948, 39.098, 40.078, 44.956, 47.867, 50.942, 51.996, 54.938,
    55.845, 58.933, 58.693, 63.546, 65.39, 69.723, 72.61, 74.922,
    78.96, 79.904, 83.8, 85.468, 87.62, 88.906, 91.224, 92.906,
    95.94, 98.0, 101.07, 102.906, 106.42, 107.868, 112.412, 114.818,
    118.711, 121.76, 127.6, 126.904, 131.29, 132.905, 137.328, 138.906,
    140.116, 140.908, 144.24, 145.0, 150.36, 151.964, 157.25, 158.925,
    162.5, 164.93, 167.26, 168.934, 173.04, 174.967, 178.49, 180.948,
    183.84, 186.207, 190.23, 192.217, 195.078, 196.967, 200.59, 204.383,
    207.2, 208.98, 209.0, 210.0, 222.0, 223.0, 226.0, 227.0,
    232.038, 231.036, 238.029, 237.0, 244.0, 243.0, 247.0, 247.0,
    251.0, 252.0, 257.0, 258.0, 259.0, 262.0, 267.0, 268.0,
    269.0, 270.0, 269.0, 278.0, 281.0, 281.0, 285.0, 284.0,
    289.0, 288.0, 293.0, 292.0, 294.0, 294.0,
]
weight_check = torch.tensor(weight_check)


class AtomWeightMetric(Metric):
    """Mean absolute difference between predicted and true summed atom weights.

    For every sample the atom class of each node is taken as the argmax over
    the last dimension, mapped to a weight through ``weight_check`` and summed
    over the node dimension.  The metric accumulates the absolute difference
    of those sums and divides by the number of samples seen.
    """

    def __init__(self):
        super().__init__()
        self.add_state('total_loss', default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
        # Module-level lookup table; only read here, so no ``global`` is needed.
        self.weight_check = weight_check

    def update(self, X, Y):
        """Accumulate the weight error of one batch.

        Args:
            X: predicted node scores; the last dimension indexes atom classes.
            Y: target node encoding with the same layout as ``X``.
        """
        pred_classes = X.argmax(dim=-1)
        true_classes = Y.argmax(dim=-1)

        # Keep the lookup table on the same device/dtype as the predictions;
        # the converted table is cached on the instance for later calls.
        self.weight_check = self.weight_check.type_as(X)
        pred_weight = self.weight_check[pred_classes].sum(dim=-1)
        true_weight = self.weight_check[true_classes].sum(dim=-1)

        self.total_loss += torch.abs(pred_weight - true_weight).sum()
        self.total_samples += X.size(0)

    def compute(self):
        """Return the accumulated absolute weight error divided by the sample count."""
        return self.total_loss / self.total_samples


class TrainLossDiscrete(nn.Module):
    """ Train with Cross entropy"""

    def __init__(self, lambda_train, weight_node=None, weight_edge=None):
        """Build the per-component metrics.

        Args:
            lambda_train: indexable weights; element 0 scales the node loss and
                element 1 scales the edge loss.
            weight_node: accepted for interface compatibility; unused here.
            weight_edge: accepted for interface compatibility; unused here.
        """
        super().__init__()
        self.node_loss = CrossEntropyMetric()
        self.edge_loss = CrossEntropyMetric()
        self.weight_loss = AtomWeightMetric()
        self.y_loss = MeanSquaredError()
        self.lambda_train = lambda_train

    def forward(self, masked_pred_X, masked_pred_E, pred_y, true_X, true_E, true_y, node_mask, log: bool):
        """ Compute train metrics
        masked_pred_X : tensor -- (bs, n, dx)
        masked_pred_E : tensor -- (bs, n, n, de)
        pred_y : tensor -- (bs, )
        true_X : tensor -- (bs, n, dx)
        true_E : tensor -- (bs, n, n, de)
        true_y : tensor -- (bs, )
        log : boolean.

        pred_y, true_y, node_mask and log are not used by this implementation.
        Returns lambda_train[0] * node CE + lambda_train[1] * edge CE + the
        atom-weight term.
        """
        # NOTE(review): the atom-weight term is derived from argmax indices, so
        # it cannot propagate gradients to the predictions -- confirm that it
        # is meant to be part of the returned training loss.
        loss_weight = self.weight_loss(masked_pred_X, true_X)

        true_X = torch.reshape(true_X, (-1, true_X.size(-1)))  # (bs * n, dx)
        true_E = torch.reshape(true_E, (-1, true_E.size(-1)))  # (bs * n * n, de)
        masked_pred_X = torch.reshape(masked_pred_X, (-1, masked_pred_X.size(-1)))  # (bs * n, dx)
        masked_pred_E = torch.reshape(masked_pred_E, (-1, masked_pred_E.size(-1)))  # (bs * n * n, de)

        # Remove masked rows (rows whose target is all zeros).
        mask_X = (true_X != 0.).any(dim=-1)
        mask_E = (true_E != 0.).any(dim=-1)

        flat_true_X = true_X[mask_X, :]
        flat_pred_X = masked_pred_X[mask_X, :]

        flat_true_E = true_E[mask_E, :]
        flat_pred_E = masked_pred_E[mask_E, :]

        # Guard on the tensors that are actually passed to the metrics: a batch
        # can be non-empty before masking and empty afterwards.
        loss_X = self.node_loss(flat_pred_X, flat_true_X) if flat_true_X.numel() > 0 else 0.0
        loss_E = self.edge_loss(flat_pred_E, flat_true_E) if flat_true_E.numel() > 0 else 0.0

        return self.lambda_train[0] * loss_X + self.lambda_train[1] * loss_E + loss_weight

    def reset(self):
        """Reset every accumulated metric, including the atom-weight metric."""
        # weight_loss must be reset as well, otherwise the "Weight" value
        # reported per epoch keeps averaging over all previous epochs.
        for metric in (self.node_loss, self.edge_loss, self.weight_loss, self.y_loss):
            metric.reset()

    def log_epoch_metrics(self, current_epoch, start_epoch_time, log=True):
        """Print the epoch-level losses.

        A component that has seen no samples is reported as -1.

        Args:
            current_epoch: epoch index shown in the message.
            start_epoch_time: ``time.time()`` timestamp taken at epoch start.
            log: when false, nothing is printed.
        """
        epoch_node_loss = self.node_loss.compute() if self.node_loss.total_samples > 0 else -1
        epoch_edge_loss = self.edge_loss.compute() if self.edge_loss.total_samples > 0 else -1
        epoch_weight_loss = self.weight_loss.compute() if self.weight_loss.total_samples > 0 else -1

        if log:
            print(f"Epoch {current_epoch} finished: X_CE: {epoch_node_loss :.4f} -- E_CE: {epoch_edge_loss :.4f} "
                  f"Weight: {epoch_weight_loss :.4f} "
                  f"-- Time taken {time.time() - start_epoch_time:.1f}s ")