add sample phase and try to get log prob
@@ -286,7 +286,7 @@ class Graph_DiT(pl.LightningModule):
             samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
                                              save_final=to_save,
                                              keep_chain=chains_save,
-                                             number_chain_steps=self.number_chain_steps))
+                                             number_chain_steps=self.number_chain_steps)[0])
             ident += to_generate
             start_index += to_generate
 
@@ -360,7 +360,7 @@ class Graph_DiT(pl.LightningModule):
             batch_y = torch.ones(to_generate, self.ydim_output, device=self.device)
 
             cur_sample = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
-                                           keep_chain=chains_save, number_chain_steps=self.number_chain_steps)
+                                           keep_chain=chains_save, number_chain_steps=self.number_chain_steps)[0]
             samples = samples + cur_sample
 
             all_ys.append(batch_y)
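Both call sites above now index the result with [0]: as the later hunks show, sample_batch returns a (graph_list, total_log_probs) pair after this commit, so callers that only need the generated graphs keep the first element. A minimal sketch of the assumed call-site pattern (the local names graphs and log_probs are illustrative, not from the repository):

    # sample_batch now returns a pair; [0] keeps only the sampled graphs
    graphs, log_probs = self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
                                          save_final=to_save,
                                          keep_chain=chains_save,
                                          number_chain_steps=self.number_chain_steps)
    samples.extend(graphs)   # same data the old single return value carried
    # log_probs holds one trajectory log-probability per generated graph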
@@ -601,6 +601,8 @@ class Graph_DiT(pl.LightningModule):
 
         assert (E == torch.transpose(E, 1, 2)).all()
 
+        total_log_probs = torch.zeros(batch_size, device=self.device)
+
         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
         for s_int in reversed(range(0, self.T)):
             s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
@@ -609,21 +611,22 @@ class Graph_DiT(pl.LightningModule):
             t_norm = t_array / self.T
 
             # Sample z_s
-            sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask)
+            sampled_s, discrete_sampled_s, log_probs = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask)
             X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
+            total_log_probs += log_probs
 
         # Sample
         sampled_s = sampled_s.mask(node_mask, collapse=True)
         X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
 
-        molecule_list = []
+        graph_list = []
         for i in range(batch_size):
             n = n_nodes[i]
-            atom_types = X[i, :n].cpu()
+            node_types = X[i, :n].cpu()
             edge_types = E[i, :n, :n].cpu()
-            molecule_list.append([atom_types, edge_types])
+            graph_list.append([node_types, edge_types])
 
-        return molecule_list
+        return graph_list, total_log_probs
 
     def sample_p_zs_given_zt(self, s, t, X_t, E_t, y_t, node_mask):
         """Samples from zs ~ p(zs | zt). Only used during sampling.
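This hunk carries the main change: each reverse step now reports the log-probability of the transition it sampled, the loop accumulates these into total_log_probs, and the final return exposes one trajectory-level log-probability per graph alongside the renamed graph list. A self-contained sketch of that accumulation pattern, with toy sizes and a stand-in for sample_p_zs_given_zt (all names and shapes here are illustrative assumptions):

    import torch

    batch_size, T = 4, 10
    total_log_probs = torch.zeros(batch_size)

    def fake_step(s_int):
        # stand-in for sample_p_zs_given_zt: one log p(z_s | z_t) value per sample
        return -torch.rand(batch_size)

    for s_int in reversed(range(T)):
        total_log_probs += fake_step(s_int)

    # total_log_probs[i] = sum_t log p(z_{t-1} | z_t) for sample i,
    # i.e. the log-probability of the whole sampled trajectory
    print(total_log_probs.shape)  # torch.Size([4])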
@@ -635,6 +638,7 @@ class Graph_DiT(pl.LightningModule):
 
         # Neural net predictions
         noisy_data = {'X_t': X_t, 'E_t': E_t, 'y_t': y_t, 't': t, 'node_mask': node_mask}
+        print(f"sample p zs given zt X_t shape: {X_t.shape}, E_t shape: {E_t.shape}, y_t shape: {y_t.shape}, node_mask shape: {node_mask.shape}")
 
         def get_prob(noisy_data, unconditioned=False):
             pred = self.forward(noisy_data, unconditioned=unconditioned)
@@ -674,6 +678,17 @@ class Graph_DiT(pl.LightningModule):
-        # with condition = P_t(G_{t-1} |G_t, C)
+        # with condition = P_t(A_{t-1} |A_t, y)
         prob_X, prob_E, pred = get_prob(noisy_data)
+        print(f'prob_X shape: {prob_X.shape}, prob_E shape: {prob_E.shape}')
+        print(f'X_t shape: {X_t.shape}, E_t shape: {E_t.shape}, y_t shape: {y_t.shape}')
+        print(f'X_t: {X_t}')
+        log_prob_X = torch.log(torch.gather(prob_X, -1, X_t.long()).squeeze(-1)) # bs, n
+        log_prob_E = torch.log(torch.gather(prob_E, -1, E_t.long()).squeeze(-1)) # bs, n, n
+
+        # Sum the log_prob across dimensions for total log_prob
+        log_prob_X = log_prob_X.sum(dim=-1)
+        log_prob_E = log_prob_E.sum(dim=(1, 2))
+        print(f'log_prob_X shape: {log_prob_X.shape}, log_prob_E shape: {log_prob_E.shape}')
+        log_probs = log_prob_E + log_prob_X
 
         ### Guidance
         if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1:
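The added lines compute, for every node and edge, the predicted probability of the category recorded in X_t / E_t, take the log, and sum over the graph to get one value per batch element. A minimal self-contained sketch of that gather pattern for the node case (the shapes follow the bs, n comments in the diff; everything else is illustrative). Note that torch.gather needs integer class indices with a trailing singleton dimension for the squeeze(-1) to leave a (bs, n) tensor, so one-hot states would first be converted with argmax, and padded nodes are normally masked out before the sum to avoid log(0):

    import torch

    bs, n, dx = 2, 3, 4                                       # toy sizes: batch, nodes, node classes
    prob_X = torch.softmax(torch.randn(bs, n, dx), dim=-1)    # predicted categorical distribution per node
    x_idx = torch.randint(0, dx, (bs, n))                     # sampled class index per node (use X_t.argmax(-1) if one-hot)

    picked = torch.gather(prob_X, -1, x_idx.unsqueeze(-1)).squeeze(-1)   # (bs, n): probability of the sampled class
    log_prob_X = torch.log(picked).sum(dim=-1)                           # (bs,): per-graph node log-probability

    print(log_prob_X.shape)  # torch.Size([2])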
@@ -810,4 +825,4 @@ class Graph_DiT(pl.LightningModule):
         out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=y_t)
         out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=y_t)
 
-        return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t)
+        return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t), log_probs