set batch_y to 1 and want to test 15625
This commit is contained in:
parent
1fa2d49c11
commit
3950a8438d
@ -781,7 +781,7 @@ class Dataset(InMemoryDataset):
|
|||||||
print(f'idx={idx}, y={y}')
|
print(f'idx={idx}, y={y}')
|
||||||
y = torch.tensor([0, 0], dtype=torch.float).view(1, -1)
|
y = torch.tensor([0, 0], dtype=torch.float).view(1, -1)
|
||||||
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i)
|
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i)
|
||||||
return None
|
# return None
|
||||||
return data
|
return data
|
||||||
graph_list = []
|
graph_list = []
|
||||||
class Args:
|
class Args:
|
||||||
|
@ -356,7 +356,8 @@ class Graph_DiT(pl.LightningModule):
|
|||||||
to_generate = min(samples_left_to_generate, bs)
|
to_generate = min(samples_left_to_generate, bs)
|
||||||
to_save = min(samples_left_to_save, bs)
|
to_save = min(samples_left_to_save, bs)
|
||||||
chains_save = min(chains_left_to_save, bs)
|
chains_save = min(chains_left_to_save, bs)
|
||||||
batch_y = test_y_collection[batch_id : batch_id + to_generate]
|
# batch_y = test_y_collection[batch_id : batch_id + to_generate]
|
||||||
|
batch_y = torch.ones(to_generate, self.ydim_output, device=self.device)
|
||||||
|
|
||||||
cur_sample = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
|
cur_sample = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
|
||||||
keep_chain=chains_save, number_chain_steps=self.number_chain_steps)
|
keep_chain=chains_save, number_chain_steps=self.number_chain_steps)
|
||||||
|
Loading…
Reference in New Issue
Block a user