try to deploy PPO policy

This commit is contained in:
mhz 2024-09-09 23:50:10 +02:00
parent 297261d666
commit 97fbdf91c7
2 changed files with 58 additions and 7 deletions

View File

@ -601,7 +601,7 @@ class Graph_DiT(pl.LightningModule):
assert (E == torch.transpose(E, 1, 2)).all() assert (E == torch.transpose(E, 1, 2)).all()
total_log_probs = torch.zeros(batch_size, device=self.device) total_log_probs = torch.zeros([1000,10], device=self.device)
# Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1. # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
for s_int in reversed(range(0, self.T)): for s_int in reversed(range(0, self.T)):
@ -613,6 +613,8 @@ class Graph_DiT(pl.LightningModule):
# Sample z_s # Sample z_s
sampled_s, discrete_sampled_s, log_probs= self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask) sampled_s, discrete_sampled_s, log_probs= self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask)
X, E, y = sampled_s.X, sampled_s.E, sampled_s.y X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
print(f'sampled_s.X shape: {sampled_s.X.shape}, sampled_s.E shape: {sampled_s.E.shape}')
print(f'log_probs shape: {log_probs.shape}')
total_log_probs += log_probs total_log_probs += log_probs
# Sample # Sample
@ -688,8 +690,9 @@ class Graph_DiT(pl.LightningModule):
log_prob_X = log_prob_X.sum(dim=-1) log_prob_X = log_prob_X.sum(dim=-1)
log_prob_E = log_prob_E.sum(dim=(1, 2)) log_prob_E = log_prob_E.sum(dim=(1, 2))
print(f'log_prob_X shape: {log_prob_X.shape}, log_prob_E shape: {log_prob_E.shape}') print(f'log_prob_X shape: {log_prob_X.shape}, log_prob_E shape: {log_prob_E.shape}')
log_probs = log_prob_E + log_prob_X # log_probs = log_prob_E + log_prob_X
log_probs = torch.cat([log_prob_X, log_prob_E], dim=-1) # (batch_size, 2)
print(f'log_probs shape: {log_probs.shape}')
### Guidance ### Guidance
if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1: if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1:
uncon_prob_X, uncon_prob_E, pred = get_prob(noisy_data, unconditioned=True) uncon_prob_X, uncon_prob_E, pred = get_prob(noisy_data, unconditioned=True)

View File

@ -152,6 +152,7 @@ from accelerate.utils import set_seed, ProjectConfiguration
version_base="1.1", config_path="../configs", config_name="config" version_base="1.1", config_path="../configs", config_name="config"
) )
def test(cfg: DictConfig): def test(cfg: DictConfig):
os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number
accelerator_config = ProjectConfiguration( accelerator_config = ProjectConfiguration(
project_dir=os.path.join(cfg.general.log_dir, cfg.general.name), project_dir=os.path.join(cfg.general.log_dir, cfg.general.name),
automatic_checkpoint_naming=True, automatic_checkpoint_naming=True,
@ -162,6 +163,11 @@ def test(cfg: DictConfig):
project_config=accelerator_config, project_config=accelerator_config,
gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.train.n_epochs, gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.train.n_epochs,
) )
# Debug: 确认可用设备
print(f"Available GPUs: {torch.cuda.device_count()}")
print(f"Using device: {accelerator.device}")
set_seed(cfg.train.seed, device_specific=True) set_seed(cfg.train.seed, device_specific=True)
datamodule = dataset.DataModule(cfg) datamodule = dataset.DataModule(cfg)
@ -185,13 +191,16 @@ def test(cfg: DictConfig):
"visualization_tools": visulization_tools, "visualization_tools": visulization_tools,
} }
# Debug: 确认可用设备
print(f"Available GPUs: {torch.cuda.device_count()}")
print(f"Using device: {accelerator.device}")
if cfg.general.test_only: if cfg.general.test_only:
cfg, _ = get_resume(cfg, model_kwargs) cfg, _ = get_resume(cfg, model_kwargs)
os.chdir(cfg.general.test_only.split("checkpoints")[0]) os.chdir(cfg.general.test_only.split("checkpoints")[0])
elif cfg.general.resume is not None: elif cfg.general.resume is not None:
cfg, _ = get_resume_adaptive(cfg, model_kwargs) cfg, _ = get_resume_adaptive(cfg, model_kwargs)
os.chdir(cfg.general.resume.split("checkpoints")[0]) os.chdir(cfg.general.resume.split("checkpoints")[0])
# os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number
model = Graph_DiT(cfg=cfg, **model_kwargs) model = Graph_DiT(cfg=cfg, **model_kwargs)
graph_dit_model = model graph_dit_model = model
@ -201,6 +210,7 @@ def test(cfg: DictConfig):
# optional: freeze the model # optional: freeze the model
# graph_dit_model.model.requires_grad_(True) # graph_dit_model.model.requires_grad_(True)
import torch.nn.functional as F import torch.nn.functional as F
optimizer = graph_dit_model.configure_optimizers() optimizer = graph_dit_model.configure_optimizers()
train_dataloader = accelerator.prepare(datamodule.train_dataloader()) train_dataloader = accelerator.prepare(datamodule.train_dataloader())
@ -256,13 +266,19 @@ def test(cfg: DictConfig):
chains_left_to_save = cfg.general.final_model_chains_to_save chains_left_to_save = cfg.general.final_model_chains_to_save
samples, all_ys, batch_id = [], [], 0 samples, all_ys, batch_id = [], [], 0
samples_with_log_probs = []
test_y_collection = torch.cat(graph_dit_model.test_y_collection, dim=0) test_y_collection = torch.cat(graph_dit_model.test_y_collection, dim=0)
num_examples = test_y_collection.size(0) num_examples = test_y_collection.size(0)
if cfg.general.final_model_samples_to_generate > num_examples: if cfg.general.final_model_samples_to_generate > num_examples:
ratio = cfg.general.final_model_samples_to_generate // num_examples ratio = cfg.general.final_model_samples_to_generate // num_examples
test_y_collection = test_y_collection.repeat(ratio+1, 1) test_y_collection = test_y_collection.repeat(ratio+1, 1)
num_examples = test_y_collection.size(0) num_examples = test_y_collection.size(0)
def graph_reward_fn(graphs, true_graphs=None, device=None):
rewards = []
for graph in graphs:
reward = 1.0
rewards.append(reward)
return torch.tensor(rewards, dtype=torch.float32).unsqueeze(0).to(device)
while samples_left_to_generate > 0: while samples_left_to_generate > 0:
print(f'samples left to generate: {samples_left_to_generate}/' print(f'samples left to generate: {samples_left_to_generate}/'
f'{cfg.general.final_model_samples_to_generate}', end='', flush=True) f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
@ -273,9 +289,12 @@ def test(cfg: DictConfig):
# batch_y = test_y_collection[batch_id : batch_id + to_generate] # batch_y = test_y_collection[batch_id : batch_id + to_generate]
batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device) batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
cur_sample = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save, cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)[0] keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
samples = samples + cur_sample samples = samples + cur_sample
reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
samples_with_log_probs.append((cur_sample, log_probs, reward))
all_ys.append(batch_y) all_ys.append(batch_y)
batch_id += to_generate batch_id += to_generate
@ -294,6 +313,35 @@ def test(cfg: DictConfig):
print("Samples:") print("Samples:")
print(samples) print(samples)
perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
samples, log_probs, rewards = samples_with_log_probs[perm]
samples = list(samples)
log_probs = list(log_probs)
print(f'log_probs: {log_probs[:5]}')
print(f'log_probs: {log_probs[0].shape}') # torch.Size([1000])
rewards = list(rewards)
for inner_epoch in range(cfg.train.n_epochs):
# print(f'rewards: {rewards[0].shape}') # torch.Size([1000])
rewards = torch.cat(rewards, dim=0)
print(f'rewards: {rewards.shape}')
advantages = (rewards - torch.mean(rewards)) / (torch.std(rewards) + 1e-6)
old_log_probs = log_probs.copy()
with accelerator.accumulate(graph_dit_model):
ratio = torch.exp(log_probs - old_log_probs)
unclipped_loss = -advantages * ratio
clipped_loss = -advantages * torch.clamp(ratio,
1.0 - cfg.ppo.clip_param,
1.0 + cfg.ppo.clip_param)
loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
accelerator.log({"loss": loss.item(), "epoch": inner_epoch})
print(f"loss: {loss.item()}, epoch: {inner_epoch}")
# trainer = Trainer( # trainer = Trainer(
# gradient_clip_val=cfg.train.clip_grad, # gradient_clip_val=cfg.train.clip_grad,
# # accelerator="cpu", # # accelerator="cpu",