diff --git a/graph_dit/diffusion_model.py b/graph_dit/diffusion_model.py
index ff53e46..6c7c5ee 100644
--- a/graph_dit/diffusion_model.py
+++ b/graph_dit/diffusion_model.py
@@ -601,7 +601,7 @@ class Graph_DiT(pl.LightningModule):
 
         assert (E == torch.transpose(E, 1, 2)).all()
 
-        total_log_probs = torch.zeros(batch_size, device=self.device)
+        total_log_probs = torch.zeros([1000,10], device=self.device)
 
         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
         for s_int in reversed(range(0, self.T)):
@@ -613,6 +613,8 @@ class Graph_DiT(pl.LightningModule):
             # Sample z_s
             sampled_s, discrete_sampled_s, log_probs = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask)
             X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
+            print(f'sampled_s.X shape: {sampled_s.X.shape}, sampled_s.E shape: {sampled_s.E.shape}')
+            print(f'log_probs shape: {log_probs.shape}')
             total_log_probs += log_probs
 
         # Sample
@@ -688,8 +690,9 @@ class Graph_DiT(pl.LightningModule):
         log_prob_X = log_prob_X.sum(dim=-1)
         log_prob_E = log_prob_E.sum(dim=(1, 2))
         print(f'log_prob_X shape: {log_prob_X.shape}, log_prob_E shape: {log_prob_E.shape}')
-        log_probs = log_prob_E + log_prob_X
-
+        # log_probs = log_prob_E + log_prob_X
+        log_probs = torch.cat([log_prob_X, log_prob_E], dim=-1)  # (batch_size, 2)
+        print(f'log_probs shape: {log_probs.shape}')
         ### Guidance
         if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1:
             uncon_prob_X, uncon_prob_E, pred = get_prob(noisy_data, unconditioned=True)
diff --git a/graph_dit/main.py b/graph_dit/main.py
index 30dd2e3..1c35d92 100644
--- a/graph_dit/main.py
+++ b/graph_dit/main.py
@@ -152,6 +152,7 @@ from accelerate.utils import set_seed, ProjectConfiguration
     version_base="1.1", config_path="../configs", config_name="config"
 )
 def test(cfg: DictConfig):
+    os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number
     accelerator_config = ProjectConfiguration(
         project_dir=os.path.join(cfg.general.log_dir, cfg.general.name),
         automatic_checkpoint_naming=True,
@@ -162,6 +163,11 @@ def test(cfg: DictConfig):
         project_config=accelerator_config,
         gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.train.n_epochs,
     )
+
+    # Debug: confirm the available devices
+    print(f"Available GPUs: {torch.cuda.device_count()}")
+    print(f"Using device: {accelerator.device}")
+
     set_seed(cfg.train.seed, device_specific=True)
 
     datamodule = dataset.DataModule(cfg)
@@ -185,13 +191,16 @@ def test(cfg: DictConfig):
         "visualization_tools": visulization_tools,
     }
 
+    # Debug: confirm the available devices
+    print(f"Available GPUs: {torch.cuda.device_count()}")
+    print(f"Using device: {accelerator.device}")
+
     if cfg.general.test_only:
         cfg, _ = get_resume(cfg, model_kwargs)
         os.chdir(cfg.general.test_only.split("checkpoints")[0])
     elif cfg.general.resume is not None:
         cfg, _ = get_resume_adaptive(cfg, model_kwargs)
         os.chdir(cfg.general.resume.split("checkpoints")[0])
-    # os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number
 
     model = Graph_DiT(cfg=cfg, **model_kwargs)
     graph_dit_model = model
@@ -201,6 +210,7 @@ def test(cfg: DictConfig):
     # optional: freeze the model
     # graph_dit_model.model.requires_grad_(True)
 
+    import torch.nn.functional as F
     optimizer = graph_dit_model.configure_optimizers()
     train_dataloader = accelerator.prepare(datamodule.train_dataloader())
@@ -256,13 +266,19 @@ def test(cfg: DictConfig):
     chains_left_to_save = cfg.general.final_model_chains_to_save
 
     samples, all_ys, batch_id = [], [], 0
+    samples_with_log_probs = []
     test_y_collection = torch.cat(graph_dit_model.test_y_collection, dim=0)
     num_examples = test_y_collection.size(0)
     if cfg.general.final_model_samples_to_generate > num_examples:
         ratio = cfg.general.final_model_samples_to_generate // num_examples
         test_y_collection = test_y_collection.repeat(ratio+1, 1)
         num_examples = test_y_collection.size(0)
-
+    def graph_reward_fn(graphs, true_graphs=None, device=None):
+        rewards = []
+        for graph in graphs:
+            reward = 1.0
+            rewards.append(reward)
+        return torch.tensor(rewards, dtype=torch.float32).unsqueeze(0).to(device)
     while samples_left_to_generate > 0:
         print(f'samples left to generate: {samples_left_to_generate}/'
               f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
@@ -273,9 +289,12 @@ def test(cfg: DictConfig):
 
         # batch_y = test_y_collection[batch_id : batch_id + to_generate]
         batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
-        cur_sample = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
-                                                  keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)[0]
+        cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
+                                                             keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
         samples = samples + cur_sample
+        reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
+
+        samples_with_log_probs.append((cur_sample, log_probs, reward))
 
         all_ys.append(batch_y)
         batch_id += to_generate
@@ -294,6 +313,35 @@ def test(cfg: DictConfig):
     print("Samples:")
     print(samples)
+    perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
+    samples, log_probs, rewards = samples_with_log_probs[perm]
+    samples = list(samples)
+    log_probs = list(log_probs)
+    print(f'log_probs: {log_probs[:5]}')
+    print(f'log_probs: {log_probs[0].shape}')  # torch.Size([1000])
+    rewards = list(rewards)
+
+    for inner_epoch in range(cfg.train.n_epochs):
+        # print(f'rewards: {rewards[0].shape}')  # torch.Size([1000])
+        rewards = torch.cat(rewards, dim=0)
+        print(f'rewards: {rewards.shape}')
+        advantages = (rewards - torch.mean(rewards)) / (torch.std(rewards) + 1e-6)
+        old_log_probs = log_probs.copy()
+        with accelerator.accumulate(graph_dit_model):
+            ratio = torch.exp(log_probs - old_log_probs)
+            unclipped_loss = -advantages * ratio
+            clipped_loss = -advantages * torch.clamp(ratio,
+                                                     1.0 - cfg.ppo.clip_param,
+                                                     1.0 + cfg.ppo.clip_param)
+            loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
+            accelerator.backward(loss)
+            optimizer.step()
+            optimizer.zero_grad()
+
+            accelerator.log({"loss": loss.item(), "epoch": inner_epoch})
+            print(f"loss: {loss.item()}, epoch: {inner_epoch}")
+
+
     # trainer = Trainer(
     #     gradient_clip_val=cfg.train.clip_grad,
     #     # accelerator="cpu",
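
The diffusion_model.py hunks accumulate the log-probabilities returned by sample_p_zs_given_zt over the reverse diffusion loop. Below is a minimal, self-contained sketch of that accumulation pattern, assuming a categorical posterior over node and edge types; the toy shapes, random stand-in posteriors, and the function name sample_trajectory_log_probs are illustrative and not code from this repository. Reducing each step to one scalar per graph keeps the running total at a fixed (batch_size,) shape across all T steps.

# Sketch: per-trajectory log-prob accumulation during reverse diffusion sampling.
# Illustrative only -- toy shapes and random "posteriors" are assumptions.
import torch
from torch.distributions import Categorical

def sample_trajectory_log_probs(batch_size=4, n_nodes=5, node_classes=3,
                                edge_classes=2, T=10, device="cpu"):
    total_log_probs = torch.zeros(batch_size, device=device)  # one scalar per graph
    for _ in reversed(range(T)):
        # Stand-ins for the model's predicted posteriors p(z_s | z_t).
        prob_X = torch.softmax(torch.randn(batch_size, n_nodes, node_classes, device=device), dim=-1)
        prob_E = torch.softmax(torch.randn(batch_size, n_nodes, n_nodes, edge_classes, device=device), dim=-1)

        # Sample discrete node / edge states and score them under the same distributions.
        dist_X, dist_E = Categorical(probs=prob_X), Categorical(probs=prob_E)
        X_s, E_s = dist_X.sample(), dist_E.sample()        # (bs, n), (bs, n, n)
        log_prob_X = dist_X.log_prob(X_s).sum(dim=-1)      # (bs,)
        log_prob_E = dist_E.log_prob(E_s).sum(dim=(1, 2))  # (bs,)

        # Accumulate one scalar per graph across all T steps.
        total_log_probs += log_prob_X + log_prob_E
    return total_log_probs                                 # (batch_size,)

print(sample_trajectory_log_probs().shape)  # torch.Size([4])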
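
The inner_epoch loop in main.py computes a PPO-style clipped surrogate loss from rewards, advantages, and the ratio of new to old log-probabilities. Below is a hedged sketch of that objective on stacked tensors (torch.exp and torch.clamp operate on tensors, so the per-batch log-probs and rewards would be stacked first); ppo_step and its argument names are assumptions, and clip_param stands in for cfg.ppo.clip_param.

# Sketch: PPO clipped-surrogate update on stacked tensors. Illustrative only.
import torch

def ppo_step(log_probs, old_log_probs, rewards, clip_param=0.2):
    # log_probs:     (N,) trajectory log-probs under the current model (requires grad)
    # old_log_probs: (N,) log-probs recorded at sampling time (treated as constants)
    # rewards:       (N,) one scalar reward per sampled graph
    advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-6)
    ratio = torch.exp(log_probs - old_log_probs.detach())
    unclipped_loss = -advantages * ratio
    clipped_loss = -advantages * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
    # Element-wise max of the two pessimistic losses, averaged over the batch.
    return torch.max(unclipped_loss, clipped_loss).mean()

# Toy usage with random data in place of sampled graphs and a reward model.
log_probs = torch.randn(8, requires_grad=True)
loss = ppo_step(log_probs, log_probs.detach() + 0.1 * torch.randn(8), torch.rand(8))
loss.backward()
print(loss.item(), log_probs.grad.shape)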