add sample phase and try to get log prob

parent 0c4b597dd2
commit 5dccf590e7
@@ -286,7 +286,7 @@ class Graph_DiT(pl.LightningModule):
             samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
                                              save_final=to_save,
                                              keep_chain=chains_save,
-                                             number_chain_steps=self.number_chain_steps))
+                                             number_chain_steps=self.number_chain_steps)[0])
             ident += to_generate
             start_index += to_generate
 
@@ -360,7 +360,7 @@ class Graph_DiT(pl.LightningModule):
             batch_y = torch.ones(to_generate, self.ydim_output, device=self.device)
 
             cur_sample = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
-                                           keep_chain=chains_save, number_chain_steps=self.number_chain_steps)
+                                           keep_chain=chains_save, number_chain_steps=self.number_chain_steps)[0]
             samples = samples + cur_sample
 
             all_ys.append(batch_y)
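Note: both call sites above gain a trailing [0] because sample_batch now returns a pair (graph_list, total_log_probs) rather than the bare list of graphs (see the later hunks); validation and test sampling only need the graphs, so the log-probabilities are dropped here. In miniature:

result = (["g0", "g1"], [-1.2, -3.4])  # stands in for sample_batch's new (graphs, log_probs)
samples = []
samples.extend(result[0])              # the [0] keeps only the graphs, as in both hunks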
@@ -601,6 +601,8 @@ class Graph_DiT(pl.LightningModule):
 
         assert (E == torch.transpose(E, 1, 2)).all()
 
+        total_log_probs = torch.zeros(batch_size, device=self.device)
+
         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
         for s_int in reversed(range(0, self.T)):
             s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
@@ -609,21 +611,22 @@ class Graph_DiT(pl.LightningModule):
             t_norm = t_array / self.T
 
             # Sample z_s
-            sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask)
+            sampled_s, discrete_sampled_s, log_probs = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask)
             X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
+            total_log_probs += log_probs
 
         # Sample
         sampled_s = sampled_s.mask(node_mask, collapse=True)
         X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
 
-        molecule_list = []
+        graph_list = []
         for i in range(batch_size):
             n = n_nodes[i]
-            atom_types = X[i, :n].cpu()
+            node_types = X[i, :n].cpu()
             edge_types = E[i, :n, :n].cpu()
-            molecule_list.append([atom_types, edge_types])
+            graph_list.append([node_types, edge_types])
 
-        return molecule_list
+        return graph_list, total_log_probs
 
     def sample_p_zs_given_zt(self, s, t, X_t, E_t, y_t, node_mask):
         """Samples from zs ~ p(zs | zt). Only used during sampling.
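Inside the reverse loop, total_log_probs accumulates the per-step value returned by sample_p_zs_given_zt, so after the loop total_log_probs[i] = sum over t = 1..T of log p(z_{t-1} | z_t): the log-probability of graph i's whole sampling trajectory. The commit message only says "try to get log prob", so the downstream use is an assumption; a REINFORCE-style consumer is one plausible fit. A self-contained sketch with a stub in place of the real sampler:

import torch

# Hypothetical consumer of the new return value; sample_batch is stood in
# by a toy function so this sketch runs on its own.
def sample_batch_stub(batch_size):
    graphs = [object()] * batch_size
    traj_log_probs = torch.randn(batch_size)  # stands in for sum_t log p(z_{t-1} | z_t)
    return graphs, traj_log_probs

graphs, traj_log_probs = sample_batch_stub(batch_size=4)
reward = torch.rand(4)                        # hypothetical per-graph reward
loss = -(traj_log_probs * reward).mean()      # REINFORCE-style objective (assumed use)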
@@ -635,6 +638,7 @@ class Graph_DiT(pl.LightningModule):
 
         # Neural net predictions
         noisy_data = {'X_t': X_t, 'E_t': E_t, 'y_t': y_t, 't': t, 'node_mask': node_mask}
+        print(f"sample p zs given zt X_t shape: {X_t.shape}, E_t shape: {E_t.shape}, y_t shape: {y_t.shape}, node_mask shape: {node_mask.shape}")
 
         def get_prob(noisy_data, unconditioned=False):
             pred = self.forward(noisy_data, unconditioned=unconditioned)
@@ -674,6 +678,17 @@ class Graph_DiT(pl.LightningModule):
         # with condition = P_t(G_{t-1} |G_t, C)
         # with condition = P_t(A_{t-1} |A_t, y)
         prob_X, prob_E, pred = get_prob(noisy_data)
+        print(f'prob_X shape: {prob_X.shape}, prob_E shape: {prob_E.shape}')
+        print(f'X_t shape: {X_t.shape}, E_t shape: {E_t.shape}, y_t shape: {y_t.shape}')
+        print(f'X_t: {X_t}')
+        log_prob_X = torch.log(torch.gather(prob_X, -1, X_t.long()).squeeze(-1))  # bs, n
+        log_prob_E = torch.log(torch.gather(prob_E, -1, E_t.long()).squeeze(-1))  # bs, n, n
+
+        # Sum the log_prob across dimensions for total log_prob
+        log_prob_X = log_prob_X.sum(dim=-1)
+        log_prob_E = log_prob_E.sum(dim=(1, 2))
+        print(f'log_prob_X shape: {log_prob_X.shape}, log_prob_E shape: {log_prob_E.shape}')
+        log_probs = log_prob_E + log_prob_X
 
         ### Guidance
         if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1:
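The heart of the new code is the torch.gather lookup: for every node (and every node pair), pick the probability that the predicted categorical distribution assigns to the state actually present in z_t, take the log, and sum over the graph. A self-contained toy version of the node half, assuming the index tensor holds integer class ids with a trailing singleton dimension (if X_t is one-hot rather than index-valued at this point, it would need an argmax(-1, keepdim=True) before the gather):

import torch

bs, n, k = 2, 3, 4                                     # batch, nodes, node classes
prob_X = torch.softmax(torch.randn(bs, n, k), dim=-1)  # predicted p(x | z_t), (bs, n, k)
X_idx = torch.randint(0, k, (bs, n, 1))                # sampled class ids, (bs, n, 1)

picked = torch.gather(prob_X, -1, X_idx).squeeze(-1)   # prob of the sampled class, (bs, n)
log_prob_X = torch.log(picked).sum(dim=-1)             # per-graph node log-prob, (bs,)
print(log_prob_X.shape)                                # torch.Size([2])

The edge half works the same way with one extra dimension, which is why the diff sums it over dim=(1, 2).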
@@ -810,4 +825,4 @@ class Graph_DiT(pl.LightningModule):
         out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=y_t)
         out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=y_t)
 
-        return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t)
+        return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t), log_probs
@@ -1,5 +1,5 @@
 # These imports are tricky because they use c++, do not move them
-import tqdm
+from tqdm import tqdm
 import os, shutil
 import warnings
 
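The import change matters because the old sampling loop called tqdm(range(...)) directly: with import tqdm, the name is bound to the module, which is not callable; from tqdm import tqdm binds the progress-bar class. A quick contrast:

import tqdm as tqdm_module
# tqdm_module(range(3))           # TypeError: 'module' object is not callable

from tqdm import tqdm
for _ in tqdm(range(3)):          # OK: tqdm here is the progress-bar class
    pass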
@@ -232,29 +232,64 @@ def test(cfg: DictConfig):
             optimizer.step()
             optimizer.zero_grad()
             # return {'loss': loss}
 
+    # start testing
+    print("start testing")
+    graph_dit_model.eval()
+    test_dataloader = accelerator.prepare(datamodule.test_dataloader())
+    for data in test_dataloader:
+        data_x = F.one_hot(data.x, num_classes=12).float()[:, graph_dit_model.active_index]
+        data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
+
+        dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, graph_dit_model.max_n_nodes)
+        dense_data = dense_data.mask(node_mask)
+        noisy_data = graph_dit_model.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
+        pred = graph_dit_model.forward(noisy_data)
+        nll = graph_dit_model.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=True)
+        graph_dit_model.test_y_collection.append(data.y)
+        print(f'test loss: {nll}')
+
     # start sampling
 
-    samples = []
+    samples_left_to_generate = cfg.general.final_model_samples_to_generate
+    samples_left_to_save = cfg.general.final_model_samples_to_save
+    chains_left_to_save = cfg.general.final_model_chains_to_save
 
-    for i in tqdm(
-        range(cfg.general.n_samples), desc="Sampling", disable=not cfg.general.enable_progress_bar
-    ):
-        batch_size = cfg.train.batch_size
-        num_steps = cfg.model.diffusion_steps
-        y = torch.ones(batch_size, num_steps, 1, 1, device=accelerator.device, dtype=inference_dtype)
+    samples, all_ys, batch_id = [], [], 0
+    test_y_collection = torch.cat(graph_dit_model.test_y_collection, dim=0)
+    num_examples = test_y_collection.size(0)
+    if cfg.general.final_model_samples_to_generate > num_examples:
+        ratio = cfg.general.final_model_samples_to_generate // num_examples
+        test_y_collection = test_y_collection.repeat(ratio + 1, 1)
+        num_examples = test_y_collection.size(0)
 
-        # sample from the model
-        samples_batch = graph_dit_model.sample_batch(
-            batch_id=i,
-            batch_size=batch_size,
-            y=y,
-            keep_chain=1,
-            number_chain_steps=num_steps,
-            save_final=batch_size
-        )
-        samples.append(samples_batch)
+    while samples_left_to_generate > 0:
+        print(f'samples left to generate: {samples_left_to_generate}/'
+              f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
+        bs = 1 * cfg.train.batch_size
+        to_generate = min(samples_left_to_generate, bs)
+        to_save = min(samples_left_to_save, bs)
+        chains_save = min(chains_left_to_save, bs)
+        # batch_y = test_y_collection[batch_id : batch_id + to_generate]
+        batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
+
+        cur_sample = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
+                                                  keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)[0]
+        samples = samples + cur_sample
+
+        all_ys.append(batch_y)
+        batch_id += to_generate
+
+        samples_left_to_save -= to_save
+        samples_left_to_generate -= to_generate
+        chains_left_to_save -= chains_save
+
+    print("Computing final sampling metrics...")
+    graph_dit_model.sampling_metrics.reset()
+    graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True)
+    graph_dit_model.sampling_metrics.reset()
+    print("Done.")
 
     # save samples
     print("Samples:")
     print(samples)
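One detail in the new sampling setup worth spelling out: when final_model_samples_to_generate exceeds the number of collected test labels, the label tensor is tiled ratio + 1 times so there are enough conditioning rows (the + 1 covers the remainder of the floor division). Note the tiled labels are not yet consumed: the batch_y = test_y_collection[...] line is commented out in favor of all-ones conditioning. A self-contained check of the tiling arithmetic with made-up sizes:

import torch

test_y = torch.randn(10, 2)                # 10 collected test labels, ydim 2
to_generate = 25                           # made-up target
if to_generate > test_y.size(0):
    ratio = to_generate // test_y.size(0)  # 25 // 10 == 2
    test_y = test_y.repeat(ratio + 1, 1)   # (2 + 1) * 10 == 30 rows
assert test_y.size(0) >= to_generate       # enough conditioning vectors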