add sample phase and try to get log prob

mhz
2024-09-08 23:26:49 +02:00
parent 0c4b597dd2
commit 5dccf590e7
2 changed files with 76 additions and 26 deletions


@@ -286,7 +286,7 @@ class Graph_DiT(pl.LightningModule):
             samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
                                              save_final=to_save,
                                              keep_chain=chains_save,
-                                             number_chain_steps=self.number_chain_steps))
+                                             number_chain_steps=self.number_chain_steps)[0])
             ident += to_generate
         start_index += to_generate
@@ -360,7 +360,7 @@ class Graph_DiT(pl.LightningModule):
         batch_y = torch.ones(to_generate, self.ydim_output, device=self.device)
         cur_sample = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
-                                       keep_chain=chains_save, number_chain_steps=self.number_chain_steps)
+                                       keep_chain=chains_save, number_chain_steps=self.number_chain_steps)[0]
         samples = samples + cur_sample
         all_ys.append(batch_y)
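Both call sites above index [0] because sample_batch now returns a pair (graph_list, total_log_probs) instead of a bare list. A minimal stand-alone sketch of that contract (stub names and values, not the repository's code):

    def sample_batch_stub(batch_size):
        # Stand-ins for the real outputs: a list of sampled graphs and one
        # accumulated log-probability per graph.
        graph_list = [f"graph_{i}" for i in range(batch_size)]
        total_log_probs = [0.0] * batch_size
        return graph_list, total_log_probs

    samples = []
    samples.extend(sample_batch_stub(3)[0])    # keep only the graphs, as in the diff
    graphs, log_probs = sample_batch_stub(3)   # or unpack both when the log-probs are needed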
@@ -601,6 +601,8 @@ class Graph_DiT(pl.LightningModule):
         assert (E == torch.transpose(E, 1, 2)).all()

+        total_log_probs = torch.zeros(batch_size, device=self.device)
+
         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
         for s_int in reversed(range(0, self.T)):
             s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
@@ -609,21 +611,22 @@ class Graph_DiT(pl.LightningModule):
             t_norm = t_array / self.T

             # Sample z_s
-            sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask)
+            sampled_s, discrete_sampled_s, log_probs = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask)
             X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
+            total_log_probs += log_probs

         # Sample
         sampled_s = sampled_s.mask(node_mask, collapse=True)
         X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

-        molecule_list = []
+        graph_list = []
         for i in range(batch_size):
             n = n_nodes[i]
-            atom_types = X[i, :n].cpu()
+            node_types = X[i, :n].cpu()
             edge_types = E[i, :n, :n].cpu()
-            molecule_list.append([atom_types, edge_types])
+            graph_list.append([node_types, edge_types])

-        return molecule_list
+        return graph_list, total_log_probs

     def sample_p_zs_given_zt(self, s, t, X_t, E_t, y_t, node_mask):
         """Samples from zs ~ p(zs | zt). Only used during sampling.
@@ -635,6 +638,7 @@ class Graph_DiT(pl.LightningModule):
         # Neural net predictions
         noisy_data = {'X_t': X_t, 'E_t': E_t, 'y_t': y_t, 't': t, 'node_mask': node_mask}
+        print(f"sample p zs given zt X_t shape: {X_t.shape}, E_t shape: {E_t.shape}, y_t shape: {y_t.shape}, node_mask shape: {node_mask.shape}")

         def get_prob(noisy_data, unconditioned=False):
             pred = self.forward(noisy_data, unconditioned=unconditioned)
@@ -674,6 +678,17 @@ class Graph_DiT(pl.LightningModule):
         # with condition = P_t(G_{t-1} |G_t, C)
         # with condition = P_t(A_{t-1} |A_t, y)
         prob_X, prob_E, pred = get_prob(noisy_data)
+        print(f'prob_X shape: {prob_X.shape}, prob_E shape: {prob_E.shape}')
+        print(f'X_t shape: {X_t.shape}, E_t shape: {E_t.shape}, y_t shape: {y_t.shape}')
+        print(f'X_t: {X_t}')
+        log_prob_X = torch.log(torch.gather(prob_X, -1, X_t.long()).squeeze(-1))  # bs, n
+        log_prob_E = torch.log(torch.gather(prob_E, -1, E_t.long()).squeeze(-1))  # bs, n, n
+        # Sum the log_prob across dimensions for total log_prob
+        log_prob_X = log_prob_X.sum(dim=-1)
+        log_prob_E = log_prob_E.sum(dim=(1, 2))
+        print(f'log_prob_X shape: {log_prob_X.shape}, log_prob_E shape: {log_prob_E.shape}')
+        log_probs = log_prob_E + log_prob_X
+
         ### Guidance
         if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1:
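The gather/sum pattern in the added lines picks, for every node and edge, the probability the model assigned to the category that was actually sampled, then reduces to one scalar per graph. A runnable sketch under assumed shapes (index tensors carry integer class IDs in a trailing size-1 dimension; all names here are illustrative):

    import torch

    bs, n, dx, de = 2, 4, 5, 3                                 # batch, nodes, node/edge classes
    prob_X = torch.softmax(torch.randn(bs, n, dx), dim=-1)     # per-node categorical probs
    prob_E = torch.softmax(torch.randn(bs, n, n, de), dim=-1)  # per-edge categorical probs
    idx_X = torch.randint(dx, (bs, n, 1))                      # sampled node classes
    idx_E = torch.randint(de, (bs, n, n, 1))                   # sampled edge classes

    log_prob_X = torch.log(torch.gather(prob_X, -1, idx_X).squeeze(-1))  # (bs, n)
    log_prob_E = torch.log(torch.gather(prob_E, -1, idx_E).squeeze(-1))  # (bs, n, n)

    # Reduce to one scalar per graph: sum over nodes and over both edge dims.
    log_probs = log_prob_X.sum(dim=-1) + log_prob_E.sum(dim=(1, 2))      # (bs,)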
@@ -810,4 +825,4 @@ class Graph_DiT(pl.LightningModule):
         out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=y_t)
         out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=y_t)

-        return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t)
+        return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t), log_probs
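Working in log space, as this commit does, is presumably what keeps the per-trajectory quantity numerically usable: over many diffusion steps a raw product of probabilities underflows, while the sum of logs stays finite. A quick demonstration:

    import torch

    p = torch.full((500,), 0.1)   # 500 step probabilities of 0.1 each
    print(p.prod())               # tensor(0.) -- the raw product underflows
    print(p.log().sum())          # ≈ tensor(-1151.29) -- the log-sum stays finite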