94 lines
4.9 KiB
Python
94 lines
4.9 KiB
Python
import time
|
|
import torch
|
|
import torch.nn as nn
|
|
from metrics.abstract_metrics import CrossEntropyMetric
|
|
from torchmetrics import Metric, MeanSquaredError
|
|
|
|
# from 2:He to 119:*
# Per-element lookup tables with one entry per element from Z=2 (He) through
# Z=119 (the last entry is a placeholder, marked "*"), i.e. list index i
# corresponds to atomic number i + 2.

# Per-element valence table. Not referenced by the code in this section.
# NOTE(review): several entries exceed common valences (e.g. Al, Si, P, S ->
# 6, 6, 7, 6) — presumably intentional (maximum valence?); confirm.
valencies_check = [0, 1, 2, 3, 4, 3, 2, 1, 0, 1, 2, 6, 6, 7, 6, 1, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 1, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 5, 6, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 5, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
# Python ints -> int64 tensor.
valencies_check = torch.tensor(valencies_check)

# Atomic weights, same ordering as above (e.g. index 0 He 4.003, index 4 C 12.011).
# Whole-number entries (e.g. 98.0 for Tc) appear for elements without a standard
# atomic weight. Indexed by predicted/true atom-type channel in AtomWeightMetric.
weight_check = [4.003, 6.941, 9.012, 10.812, 12.011, 14.007, 15.999, 18.998, 20.18, 22.99, 24.305, 26.982, 28.086, 30.974, 32.067, 35.453, 39.948, 39.098, 40.078, 44.956, 47.867, 50.942, 51.996, 54.938, 55.845, 58.933, 58.693, 63.546, 65.39, 69.723, 72.61, 74.922, 78.96, 79.904, 83.8, 85.468, 87.62, 88.906, 91.224, 92.906, 95.94, 98.0, 101.07, 102.906, 106.42, 107.868, 112.412, 114.818, 118.711, 121.76, 127.6, 126.904, 131.29, 132.905, 137.328, 138.906, 140.116, 140.908, 144.24, 145.0, 150.36, 151.964, 157.25, 158.925, 162.5, 164.93, 167.26, 168.934, 173.04, 174.967, 178.49, 180.948, 183.84, 186.207, 190.23, 192.217, 195.078, 196.967, 200.59, 204.383, 207.2, 208.98, 209.0, 210.0, 222.0, 223.0, 226.0, 227.0, 232.038, 231.036, 238.029, 237.0, 244.0, 243.0, 247.0, 247.0, 251.0, 252.0, 257.0, 258.0, 259.0, 262.0, 267.0, 268.0, 269.0, 270.0, 269.0, 278.0, 281.0, 281.0, 285.0, 284.0, 289.0, 288.0, 293.0, 292.0, 294.0, 294.0]
# Python floats -> tensor of torch's default float dtype.
weight_check = torch.tensor(weight_check)
|
|
|
|
class AtomWeightMetric(Metric):
    """Mean absolute error between predicted and true summed atom weights.

    For every graph in a batch, each node's atom type is taken as the argmax
    over the last dimension of the input and mapped to a weight through the
    module-level ``weight_check`` table. The weights are summed over the node
    dimension, and the absolute difference between the predicted and true sums
    is accumulated. :meth:`compute` returns that total divided by the number
    of graphs seen.

    NOTE(review): assumes channel ``k`` of the node features corresponds to
    ``weight_check[k]`` (channel 0 -> He) — confirm against the dataset's atom
    encoding.
    NOTE(review): the argmax + table lookup is not differentiable, so this value
    carries no gradient when it is added to a training loss.
    """

    def __init__(self):
        super().__init__()
        self.add_state('total_loss', default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
        # Reading a module-level name needs no ``global`` statement.
        self.weight_check = weight_check

    def update(self, X, Y):
        """Accumulate the summed-weight error for one batch.

        Args:
            X: predicted node scores; presumably (bs, n, dx) — TODO confirm.
            Y: target node one-hots with the same shape as ``X``.
        """
        atom_pred_num = X.argmax(dim=-1)
        atom_real_num = Y.argmax(dim=-1)
        # Move the lookup table to X's device and dtype before indexing.
        self.weight_check = self.weight_check.type_as(X)

        pred_weight = self.weight_check[atom_pred_num]
        real_weight = self.weight_check[atom_real_num]

        # Totals over the last remaining (node) dimension, then summed over the batch.
        batch_error = torch.abs(pred_weight.sum(dim=-1) - real_weight.sum(dim=-1)).sum()
        self.total_loss += batch_error
        # The leading dimension of X is counted as the number of graphs.
        self.total_samples += X.size(0)

    def compute(self):
        """Return the mean error per graph (nan if nothing was accumulated)."""
        return self.total_loss / self.total_samples
|
|
|
|
|
|
class TrainLossDiscrete(nn.Module):
    """Train with cross-entropy on nodes and edges, plus an atom-weight term.

    The returned loss is
    ``lambda_train[0] * CE(nodes) + lambda_train[1] * CE(edges) + weight_term``,
    where ``weight_term`` is the per-batch value returned by ``AtomWeightMetric``.
    Each term is also accumulated in a metric object so that per-epoch averages
    can be reported with :meth:`log_epoch_metrics`.
    """

    def __init__(self, lambda_train, weight_node=None, weight_edge=None):
        """Create the loss.

        Args:
            lambda_train: indexable of at least two weights. ``[0]`` scales the
                node cross-entropy and ``[1]`` scales the edge cross-entropy.
            weight_node: accepted for interface compatibility; not used here.
            weight_edge: accepted for interface compatibility; not used here.
        """
        super().__init__()
        self.node_loss = CrossEntropyMetric()
        self.edge_loss = CrossEntropyMetric()
        self.weight_loss = AtomWeightMetric()
        # NOTE(review): y_loss is created and reset but never updated in this class.
        self.y_loss = MeanSquaredError()
        self.lambda_train = lambda_train

    def forward(self, masked_pred_X, masked_pred_E, pred_y, true_X, true_E, true_y, node_mask, log: bool):
        """Compute the training loss for one batch and update the running metrics.

        Args:
            masked_pred_X: tensor -- (bs, n, dx)
            masked_pred_E: tensor -- (bs, n, n, de)
            pred_y: tensor -- (bs, ); unused.
            true_X: tensor -- (bs, n, dx)
            true_E: tensor -- (bs, n, n, de)
            true_y: tensor -- (bs, ); unused.
            node_mask: unused.
            log: unused.

        Returns:
            Weighted sum of the node CE, edge CE and atom-weight terms. A CE term
            is 0.0 when no unmasked rows remain for it.
        """
        # Atom-weight term on the full (unflattened, unfiltered) node tensors.
        # NOTE(review): AtomWeightMetric uses argmax + table lookup, so this
        # term presumably contributes no gradient — confirm it is meant only
        # as a monitored offset.
        loss_weight = self.weight_loss(masked_pred_X, true_X)

        # Flatten to one row per node / per edge.
        true_X = torch.reshape(true_X, (-1, true_X.size(-1)))  # (bs * n, dx)
        true_E = torch.reshape(true_E, (-1, true_E.size(-1)))  # (bs * n * n, de)
        masked_pred_X = torch.reshape(masked_pred_X, (-1, masked_pred_X.size(-1)))  # (bs * n, dx)
        masked_pred_E = torch.reshape(masked_pred_E, (-1, masked_pred_E.size(-1)))  # (bs * n * n, de)

        # Remove masked rows: a row whose target is all zeros is treated as padding.
        mask_X = (true_X != 0.).any(dim=-1)
        mask_E = (true_E != 0.).any(dim=-1)

        flat_true_X = true_X[mask_X, :]
        flat_pred_X = masked_pred_X[mask_X, :]

        flat_true_E = true_E[mask_E, :]
        flat_pred_E = masked_pred_E[mask_E, :]

        # Guard on the *filtered* tensors. If every row was masked out, the CE
        # metric would otherwise be called with no samples to average over.
        loss_X = self.node_loss(flat_pred_X, flat_true_X) if flat_true_X.numel() > 0 else 0.0
        loss_E = self.edge_loss(flat_pred_E, flat_true_E) if flat_true_E.numel() > 0 else 0.0

        return self.lambda_train[0] * loss_X + self.lambda_train[1] * loss_E + loss_weight

    def reset(self):
        """Reset every accumulated metric, including the atom-weight metric."""
        # weight_loss must be reset too. Otherwise its epoch average keeps
        # accumulating across epochs.
        for metric in [self.node_loss, self.edge_loss, self.weight_loss, self.y_loss]:
            metric.reset()

    def log_epoch_metrics(self, current_epoch, start_epoch_time, log=True):
        """Print the epoch-averaged losses.

        Each average is -1 when its metric saw no samples since the last reset.

        Args:
            current_epoch: epoch number shown in the message.
            start_epoch_time: ``time.time()`` value taken at the start of the epoch.
            log: print the summary when true.
        """
        epoch_node_loss = self.node_loss.compute() if self.node_loss.total_samples > 0 else -1
        epoch_edge_loss = self.edge_loss.compute() if self.edge_loss.total_samples > 0 else -1
        epoch_weight_loss = self.weight_loss.compute() if self.weight_loss.total_samples > 0 else -1

        if log:
            print(f"Epoch {current_epoch} finished: X_CE: {epoch_node_loss :.4f} -- E_CE: {epoch_edge_loss :.4f} "
                  f"Weight: {epoch_weight_loss :.4f} "
                  f"-- Time taken {time.time() - start_epoch_time:.1f}s ")