need to update the model
@@ -601,7 +601,7 @@ class Graph_DiT(pl.LightningModule):
         assert (E == torch.transpose(E, 1, 2)).all()
 
-        total_log_probs = torch.zeros([1000,10], device=self.device)
+        total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)
 
         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
         for s_int in reversed(range(0, self.T)):
 
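Sizing the accumulator from cfg.general.final_model_samples_to_generate removes the assumption that exactly 1000 samples are requested; the buffer now always matches the first dimension of the per-step log-probabilities. Below is a minimal shape-only sketch of the accumulation pattern, where n_samples, n_props, and the random per-step term are illustrative placeholders rather than the repository's actual code.

import torch

# Stand-ins for cfg.general.final_model_samples_to_generate, the 10 tracked
# quantities, and the number of diffusion steps (illustrative values only).
n_samples, n_props, T = 8, 10, 5

total_log_probs = torch.zeros([n_samples, n_props])
# Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1,
# accumulating one log-probability term per sample and tracked quantity.
for s_int in reversed(range(0, T)):
    step_log_probs = torch.randn(n_samples, n_props)  # placeholder for the real per-step term
    total_log_probs += step_log_probs

assert total_log_probs.shape == (n_samples, n_props)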
@@ -273,12 +273,14 @@ def test(cfg: DictConfig):
         ratio = cfg.general.final_model_samples_to_generate // num_examples
         test_y_collection = test_y_collection.repeat(ratio+1, 1)
         num_examples = test_y_collection.size(0)
+
+    # Normal reward function
     def graph_reward_fn(graphs, true_graphs=None, device=None):
         rewards = []
         for graph in graphs:
             reward = 1.0
             rewards.append(reward)
-        return torch.tensor(rewards, dtype=torch.float32).unsqueeze(0).to(device)
+        return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
     while samples_left_to_generate > 0:
         print(f'samples left to generate: {samples_left_to_generate}/'
             f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
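For reference, the reward function after this change runs on its own; the sketch below feeds it a dummy list of graphs (purely illustrative, since every graph currently receives the constant reward 1.0) to show the returned shape and the effect of requires_grad=True.

import torch

def graph_reward_fn(graphs, true_graphs=None, device=None):
    rewards = []
    for graph in graphs:
        reward = 1.0            # constant placeholder reward per graph
        rewards.append(reward)
    return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)

dummy_graphs = [object() for _ in range(4)]   # stand-ins for sampled graphs
r = graph_reward_fn(dummy_graphs, device='cpu')
print(r.shape)          # torch.Size([1, 4]); unsqueeze(0) adds a leading batch dimension
print(r.requires_grad)  # True, so the reward tensor participates in autograd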
@@ -317,19 +319,33 @@ def test(cfg: DictConfig):
     samples, log_probs, rewards = samples_with_log_probs[perm]
     samples = list(samples)
     log_probs = list(log_probs)
+    for i in range(len(log_probs)):
+        log_probs[i] = torch.sum(log_probs[i], dim=-1).unsqueeze(0)
     print(f'log_probs: {log_probs[:5]}')
-    print(f'log_probs: {log_probs[0].shape}') # torch.Size([1000])
+    print(f'log_probs: {log_probs[0].shape}') # torch.Size([1])
     rewards = list(rewards)
+    log_probs = torch.cat(log_probs, dim=0)
+    print(f'log_probs: {log_probs.shape}') # torch.Size([1000, 1])
+    old_log_probs = log_probs.clone()
+    # multi metrics range
+    # reward hacking hiking
     for inner_epoch in range(cfg.train.n_epochs):
-        # print(f'rewards: {rewards[0].shape}') # torch.Size([1000])
-        rewards = torch.cat(rewards, dim=0)
-        print(f'rewards: {rewards.shape}')
+        # print(f'rewards: {rewards.shape}') # torch.Size([1000])
+        print(f'rewards: {rewards[:5]}')
+        print(f'len rewards: {len(rewards)}')
+        print(f'type rewards: {type(rewards)}')
+        if len(rewards) > 1 and isinstance(rewards, list):
+            rewards = torch.cat(rewards, dim=0)
+        elif len(rewards) == 1 and isinstance(rewards, list):
+            rewards = rewards[0]
+        # print(f'rewards: {rewards.shape}')
         advantages = (rewards - torch.mean(rewards)) / (torch.std(rewards) + 1e-6)
-        old_log_probs = log_probs.copy()
+        print(f'advantages: {advantages.shape}')
         with accelerator.accumulate(graph_dit_model):
             ratio = torch.exp(log_probs - old_log_probs)
             unclipped_loss = -advantages * ratio
+            # z-score normalization
             clipped_loss = -advantages * torch.clamp(ratio,
                             1.0 - cfg.ppo.clip_param,
                             1.0 + cfg.ppo.clip_param)
 
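Taken together, the inner loop follows the usual PPO clipped-surrogate pattern: rewards are z-score normalized into advantages, the importance ratio against the frozen old_log_probs is clamped to [1 - clip_param, 1 + clip_param], and the more pessimistic of the two surrogate losses is minimized. The sketch below reproduces that pattern with made-up tensors; the final torch.max / torch.mean reduction is the conventional PPO step and is assumed here, since the hunk ends before the two losses are combined.

import torch

clip_param = 0.2                      # stand-in for cfg.ppo.clip_param
log_probs = torch.randn(16)           # current per-sample log-probabilities
old_log_probs = log_probs.clone()     # frozen copy taken before the update epochs
rewards = torch.rand(16)              # one scalar reward per sample

# z-score normalization of rewards into advantages
advantages = (rewards - torch.mean(rewards)) / (torch.std(rewards) + 1e-6)

ratio = torch.exp(log_probs - old_log_probs)   # all ones on the first pass
unclipped_loss = -advantages * ratio
clipped_loss = -advantages * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)

# Conventional PPO reduction (assumed; not shown in the hunk above):
loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
print(loss)   # scalar surrogate loss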