try to deploy PPO policy
parent 297261d666
commit 97fbdf91c7
@@ -601,7 +601,7 @@ class Graph_DiT(pl.LightningModule):
         assert (E == torch.transpose(E, 1, 2)).all()

-        total_log_probs = torch.zeros(batch_size, device=self.device)
+        total_log_probs = torch.zeros([1000,10], device=self.device)

         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
         for s_int in reversed(range(0, self.T)):
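The new accumulator hardcodes its shape to `[1000, 10]`. A minimal shape-agnostic sketch (a hypothetical rewrite, not the repo's code) would size the buffer from the first step's output instead; all shapes below are stand-ins:

import torch

# Hypothetical sketch: allocate the accumulator from the first step's
# log_probs instead of hardcoding [1000, 10].
T, batch_size = 4, 3
total_log_probs = None
for _ in range(T):
    step_log_probs = torch.randn(batch_size, 2)  # stand-in for sample_p_zs_given_zt's log_probs
    if total_log_probs is None:
        total_log_probs = torch.zeros_like(step_log_probs)
    total_log_probs += step_log_probs
print(total_log_probs.shape)  # torch.Size([3, 2])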
@@ -613,6 +613,8 @@ class Graph_DiT(pl.LightningModule):
             # Sample z_s
             sampled_s, discrete_sampled_s, log_probs = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask)
             X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
+            print(f'sampled_s.X shape: {sampled_s.X.shape}, sampled_s.E shape: {sampled_s.E.shape}')
+            print(f'log_probs shape: {log_probs.shape}')
             total_log_probs += log_probs

             # Sample
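Summing the per-step `log_probs` over the reverse chain computes the log-likelihood of the whole sampled trajectory, since the chain factorizes step by step:

    log p(z_0, ..., z_{T-1} | z_T) = Σ_{t=T..1} log p(z_{t-1} | z_t)

This accumulated quantity is what the PPO loop in the test script below treats as the policy's log-probability of a sample.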
@@ -688,8 +690,9 @@ class Graph_DiT(pl.LightningModule):
         log_prob_X = log_prob_X.sum(dim=-1)
         log_prob_E = log_prob_E.sum(dim=(1, 2))
         print(f'log_prob_X shape: {log_prob_X.shape}, log_prob_E shape: {log_prob_E.shape}')
-        log_probs = log_prob_E + log_prob_X
+        # log_probs = log_prob_E + log_prob_X
+        log_probs = torch.cat([log_prob_X, log_prob_E], dim=-1)  # (batch_size, 2)
+        print(f'log_probs shape: {log_probs.shape}')

         ### Guidance
         if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1:
             uncon_prob_X, uncon_prob_E, pred = get_prob(noisy_data, unconditioned=True)
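One note on the new concatenation: if `log_prob_X` and `log_prob_E` are 1-D tensors of length `batch_size` after the sums above (an assumption from the surrounding code), `torch.cat` along `dim=-1` yields shape `(2 * batch_size,)`, not the `(batch_size, 2)` the inline comment claims; `torch.stack` gives the latter. A runnable check:

import torch

batch_size = 5
log_prob_X = torch.randn(batch_size)  # assumed 1-D after .sum(dim=-1)
log_prob_E = torch.randn(batch_size)  # assumed 1-D after .sum(dim=(1, 2))
print(torch.cat([log_prob_X, log_prob_E], dim=-1).shape)    # torch.Size([10])
print(torch.stack([log_prob_X, log_prob_E], dim=-1).shape)  # torch.Size([5, 2])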
@@ -152,6 +152,7 @@ from accelerate.utils import set_seed, ProjectConfiguration
     version_base="1.1", config_path="../configs", config_name="config"
 )
 def test(cfg: DictConfig):
+    os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number
     accelerator_config = ProjectConfiguration(
         project_dir=os.path.join(cfg.general.log_dir, cfg.general.name),
         automatic_checkpoint_naming=True,
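Moving the `CUDA_VISIBLE_DEVICES` assignment to the top of `test()` (the commented copy further down is removed in a later hunk) matters because the variable is only read when the CUDA context is first initialized. A minimal sketch of the ordering constraint; the device string is a placeholder:

import os

# CUDA_VISIBLE_DEVICES must be set before anything initializes CUDA;
# once torch.cuda has been touched, changing it has no effect in-process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # placeholder value

import torch

print(torch.cuda.device_count())  # counts only the devices left visible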
@@ -162,6 +163,11 @@ def test(cfg: DictConfig):
         project_config=accelerator_config,
         gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.train.n_epochs,
     )

+    # Debug: confirm available devices
+    print(f"Available GPUs: {torch.cuda.device_count()}")
+    print(f"Using device: {accelerator.device}")

     set_seed(cfg.train.seed, device_specific=True)

     datamodule = dataset.DataModule(cfg)
@@ -185,13 +191,16 @@ def test(cfg: DictConfig):
         "visualization_tools": visulization_tools,
     }

+    # Debug: confirm available devices
+    print(f"Available GPUs: {torch.cuda.device_count()}")
+    print(f"Using device: {accelerator.device}")

     if cfg.general.test_only:
         cfg, _ = get_resume(cfg, model_kwargs)
         os.chdir(cfg.general.test_only.split("checkpoints")[0])
     elif cfg.general.resume is not None:
         cfg, _ = get_resume_adaptive(cfg, model_kwargs)
         os.chdir(cfg.general.resume.split("checkpoints")[0])
-    # os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number

     model = Graph_DiT(cfg=cfg, **model_kwargs)
     graph_dit_model = model
@@ -201,6 +210,7 @@ def test(cfg: DictConfig):

     # optional: freeze the model
     # graph_dit_model.model.requires_grad_(True)

     import torch.nn.functional as F
     optimizer = graph_dit_model.configure_optimizers()
     train_dataloader = accelerator.prepare(datamodule.train_dataloader())
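The `gradient_accumulation_steps = cfg.train.gradient_accumulation_steps * cfg.train.n_epochs` factor passed to the Accelerator earlier is worth flagging: in `accelerate`, this value is the number of forward/backward micro-steps between optimizer syncs, so scaling it by the epoch count postpones every real update. A minimal, self-contained sketch of the intended usage (toy model and data, not the repo's):

import torch
from accelerate import Accelerator

# Toy sketch of accelerate's gradient accumulation. The accumulation count
# is per optimizer update, so it normally should not scale with n_epochs.
accelerator = Accelerator(gradient_accumulation_steps=4)
model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
model, optimizer = accelerator.prepare(model, optimizer)

for step in range(8):
    x, y = torch.randn(16, 2), torch.randn(16, 1)
    with accelerator.accumulate(model):  # sync and update fire every 4th step
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()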
@@ -256,13 +266,19 @@ def test(cfg: DictConfig):
         chains_left_to_save = cfg.general.final_model_chains_to_save

         samples, all_ys, batch_id = [], [], 0
+        samples_with_log_probs = []
         test_y_collection = torch.cat(graph_dit_model.test_y_collection, dim=0)
         num_examples = test_y_collection.size(0)
         if cfg.general.final_model_samples_to_generate > num_examples:
             ratio = cfg.general.final_model_samples_to_generate // num_examples
             test_y_collection = test_y_collection.repeat(ratio+1, 1)
             num_examples = test_y_collection.size(0)
+
+        def graph_reward_fn(graphs, true_graphs=None, device=None):
+            rewards = []
+            for graph in graphs:
+                reward = 1.0
+                rewards.append(reward)
+            return torch.tensor(rewards, dtype=torch.float32).unsqueeze(0).to(device)
         while samples_left_to_generate > 0:
             print(f'samples left to generate: {samples_left_to_generate}/'
                   f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
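`graph_reward_fn` is clearly a placeholder: every graph scores 1.0, so the reward batch has zero variance and the normalized advantages computed later vanish, making the PPO step a no-op until a real reward is wired in. A quick check of that arithmetic:

import torch

rewards = torch.ones(1000)  # placeholder reward of 1.0 per graph
advantages = (rewards - torch.mean(rewards)) / (torch.std(rewards) + 1e-6)
print(advantages.abs().max())  # tensor(0.) -> no learning signal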
@@ -273,9 +289,12 @@ def test(cfg: DictConfig):
             # batch_y = test_y_collection[batch_id : batch_id + to_generate]
             batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)

-            cur_sample = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
-                                                      keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)[0]
+            cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
+                                                                 keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
             samples = samples + cur_sample
+            reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
+
+            samples_with_log_probs.append((cur_sample, log_probs, reward))

             all_ys.append(batch_y)
             batch_id += to_generate
@@ -294,6 +313,35 @@ def test(cfg: DictConfig):
         print("Samples:")
         print(samples)

+        perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
+        samples, log_probs, rewards = samples_with_log_probs[perm]
+        samples = list(samples)
+        log_probs = list(log_probs)
+        print(f'log_probs: {log_probs[:5]}')
+        print(f'log_probs: {log_probs[0].shape}')  # torch.Size([1000])
+        rewards = list(rewards)
+
+        for inner_epoch in range(cfg.train.n_epochs):
+            # print(f'rewards: {rewards[0].shape}')  # torch.Size([1000])
+            rewards = torch.cat(rewards, dim=0)
+            print(f'rewards: {rewards.shape}')
+            advantages = (rewards - torch.mean(rewards)) / (torch.std(rewards) + 1e-6)
+            old_log_probs = log_probs.copy()
+            with accelerator.accumulate(graph_dit_model):
+                ratio = torch.exp(log_probs - old_log_probs)
+                unclipped_loss = -advantages * ratio
+                clipped_loss = -advantages * torch.clamp(ratio,
+                                                         1.0 - cfg.ppo.clip_param,
+                                                         1.0 + cfg.ppo.clip_param)
+                loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
+                accelerator.backward(loss)
+                optimizer.step()
+                optimizer.zero_grad()
+
+                accelerator.log({"loss": loss.item(), "epoch": inner_epoch})
+                print(f"loss: {loss.item()}, epoch: {inner_epoch}")
+
         # trainer = Trainer(
         #     gradient_clip_val=cfg.train.clip_grad,
         #     # accelerator="cpu",
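The inner loop sketches PPO's clipped surrogate objective, but as committed it cannot run: `samples_with_log_probs[perm]` indexes a Python list with a tensor, `log_probs - old_log_probs` subtracts Python lists, and since `old_log_probs` is copied from `log_probs` inside the same epoch, the ratio would be identically 1 even if it did run. A minimal self-contained sketch of the objective being aimed for, with tensors throughout (toy shapes and names, not the repo's API):

import torch

def ppo_clip_loss(log_probs, old_log_probs, advantages, clip_param=0.2):
    # Standard PPO clipped surrogate: maximize min(r*A, clip(r)*A),
    # written as a loss (hence the leading minus signs).
    ratio = torch.exp(log_probs - old_log_probs.detach())
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
    return torch.mean(torch.max(unclipped, clipped))

# Toy usage: keep everything as stacked tensors so a permutation can index them.
n = 8
log_probs = torch.randn(n, requires_grad=True)              # current policy
old_log_probs = log_probs.detach() + 0.1 * torch.randn(n)   # frozen rollout policy
rewards = torch.randn(n)
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-6)

perm = torch.randperm(n)
loss = ppo_clip_loss(log_probs[perm], old_log_probs[perm], advantages[perm])
loss.backward()
print(loss.item())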