update the score function

mhz 2024-08-05 21:45:15 +02:00
parent f5d00be56e
commit 205f43291b
4 changed files with 37 additions and 17 deletions

.gitignore

@@ -159,3 +159,12 @@ archive.zip
 logs/
 generated/
 data/processed/
+*.pdf
+*.zip
+*.pth
+*.bck
+*.pt
+cifardata/
+*.meta.json
+*.joblib
+*.gz


@@ -2,7 +2,7 @@ general:
   name: 'graph_dit'
   wandb: 'disabled'
   gpus: 1
-  gpu_number: 3
+  gpu_number: 2
   resume: null
   test_only: null
   sample_every_val: 2500
@@ -31,7 +31,8 @@ model:
   lambda_train: [1, 10] # node and edge training weight
   ensure_connected: True
 train:
-  n_epochs: 5000
+  # n_epochs: 5000
+  n_epochs: 10
   batch_size: 1200
   lr: 0.0002
   clip_grad: null
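
The config changes pin the run to GPU 2 instead of GPU 3 and cut n_epochs from 5000 to 10, which looks like a smoke-test budget rather than a real training run. A minimal sketch of how these two fields would typically reach the Lightning trainer; the entry-point wiring and the "config.yaml" path are assumptions, not shown in this commit:

    # Sketch only: assumes the repo's entry point loads this YAML and
    # forwards the fields to pytorch_lightning.Trainer.
    import yaml
    import pytorch_lightning as pl

    with open("config.yaml") as f:        # hypothetical path
        cfg = yaml.safe_load(f)

    trainer = pl.Trainer(
        max_epochs=cfg["train"]["n_epochs"],      # 10 after this commit
        accelerator="gpu",
        devices=[cfg["general"]["gpu_number"]],   # GPU index 2 after this commit
    )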


@@ -220,7 +220,7 @@ class Graph_DiT(pl.LightningModule):
         # self.sampling_metrics.reset()
         self.val_y_collection = []
 
-    @torch.no_grad()
+    # @torch.no_grad()
     def validation_step(self, data, i):
         data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
         data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
@@ -313,7 +313,7 @@ class Graph_DiT(pl.LightningModule):
         self.test_E_logp.reset()
         self.test_y_collection = []
 
-    @torch.no_grad()
+    # @torch.no_grad()
     def test_step(self, data, i):
         data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
         data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
@@ -573,7 +573,7 @@ class Graph_DiT(pl.LightningModule):
         return nll
 
-    @torch.no_grad()
+    # @torch.no_grad()
     def sample_batch(self, batch_id, batch_size, y, keep_chain, number_chain_steps, save_final, num_nodes=None):
         """
         :param batch_id: int
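
All three @torch.no_grad() decorators are commented out for the same reason: the score-guided sampling added below calls loss.backward(retain_graph=True) inside sample_batch, and under no_grad no autograd graph is recorded, so backward() would fail. The trade-off is that every validation and test forward pass now builds a graph and holds the extra memory. A narrower alternative is sketched here, assuming only the guidance step needs gradients; model and guidance_loss are placeholders, not names from this repo:

    import torch

    @torch.no_grad()                      # keep ordinary inference grad-free
    def sample_batch_sketch(model, x):
        pred = model(x)                   # no graph is recorded here
        with torch.enable_grad():         # re-enable autograd just for guidance
            x_g = pred.detach().requires_grad_(True)
            loss = guidance_loss(x_g)     # placeholder scoring loss
            loss.backward()               # fills x_g.grad and nothing upstream
        return x_g - 0.5 * x_g.grad       # guided step, beta = 0.5 as in the diff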
@@ -742,19 +742,24 @@ class Graph_DiT(pl.LightningModule):
                 if valid_rlt[i]:
                     nodes = [num_to_op[j] for j in x_list[i].cpu().numpy()]
                     # edges = e_list[i].cpu().numpy()
-                    score.append(get_nasbench201_nodes_score(nodes, train_loader=self.train_loader, searchspace=self.searchspace, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), args=self.args))
+                    score.append(get_nasbench201_nodes_score(nodes, train_loader=self.train_loader, searchspace=self.searchspace, device=sampled_s.X.device, args=self.args))
                 else:
                     score.append(-1)
-            return torch.tensor(score, dtype=torch.float32, requires_grad=True).to(x_list[0].device)
+            # return torch.tensor(score, dtype=torch.float32, requires_grad=True).to(x_list[0].device)
+            target_score = torch.ones(100, dtype=torch.float32, device=sampled_s.X.device, requires_grad=True) * 2000.0
+            # target_score_list = [2000 for i in range(100)]
+            # return torch.tensor(score, device=sampled_s.X.device, dtype=torch.float32, requires_grad=True), torch.tensor(target_score_list, device=sampled_s.X.device, dtype=torch.float32, requires_grad=True)
+            return torch.tensor(score, device=sampled_s.X.device, dtype=torch.float32, requires_grad=True), target_score
 
         sample_num = 10
         best_arch = None
-        best_score_int = -1e8
+        score = torch.ones(100, dtype=torch.float32, requires_grad=True) * -1e8
+        print(f'score.requires_grad: {score.requires_grad}')
         for i in range(sample_num):
             sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())
-            score = get_score(sampled_s)
+            score, target_score = get_score(sampled_s)
             print(f'score: {score}')
             print(f'score.shape: {score.shape}')
             print(f'torch.sum(score): {torch.sum(score)}')
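
Two changes land in get_score: the NASWOT scorer now takes its device from sampled_s.X instead of hard-coding cuda:0, and the function returns a (score, target_score) pair, scoring invalid architectures as -1 against a fixed target of 2000 per sample. One caveat: torch.tensor(score, ..., requires_grad=True) builds a fresh leaf tensor from a Python list, so the autograd graph starts there; gradients from the MSE loss below stop at this leaf and never reach the sampling probabilities that produced the architectures. A minimal sketch of that boundary, with made-up scores:

    import torch

    raw = [812.0, -1.0, 905.5]                      # made-up NASWOT-style scores
    score = torch.tensor(raw, requires_grad=True)   # fresh leaf, no history
    target = torch.full_like(score, 2000.0)
    loss = torch.nn.functional.mse_loss(score, target)
    loss.backward()
    print(score.grad)     # populated
    print(score.is_leaf)  # True: the graph ends at this tensor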
@@ -779,14 +784,19 @@ class Graph_DiT(pl.LightningModule):
             print(f'X_s: {X_s}, E_s: {E_s}')
             # NASWOT score
-            target_score = torch.ones(100, requires_grad=True) * 2000.0
-            target_score = target_score.to(X_s.device)
+            # target_score = torch.ones(100, requires_grad=True, device=X_s.device) * 2000.0
+            # target_score = torch.ones(100, requires_grad=True) * 2000.0
-            print(f'best_score: {best_score.shape}, target_score: {target_score.shape}')
-            print(f'best_score.requires_grad: {best_score.requires_grad}, target_score.requires_grad: {target_score.requires_grad}')
+            print(f'best_score.device: {best_score.device}, target_score.device: {target_score.device}')
+            # target_score = target_score.to(X_s.device)
+            # print(f'best_score: {best_score.shape}, target_score: {target_score.shape}')
+            # print(f'best_score.requires_grad: {best_score.requires_grad}, target_score.requires_grad: {target_score.requires_grad}')
             # compute loss mse(cur_score - target_score)
             mse_loss = torch.nn.MSELoss()
+            print(f'best_score: {best_score.shape}, target_score: {target_score.shape}')
+            print(f'best_score.requires_grad: {best_score.requires_grad}, target_score.requires_grad: {target_score.requires_grad}')
-            loss = mse_loss(best_score, target_score)
+            loss = mse_loss(target_score, best_score)
             print(f'loss: {loss.requires_grad}')
             loss.backward(retain_graph=True)
             # loss backward = gradient
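
Swapping the arguments to mse_loss(target_score, best_score) changes nothing numerically: mean squared error is symmetric in (input, target), so both the loss value and the gradient reaching best_score are identical either way. A quick check of that symmetry:

    import torch

    a = torch.randn(100, requires_grad=True)  # stand-in for best_score
    b = torch.full((100,), 2000.0)            # stand-in for target_score
    assert torch.allclose(torch.nn.MSELoss()(a, b), torch.nn.MSELoss()(b, a))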
@@ -798,8 +808,8 @@ class Graph_DiT(pl.LightningModule):
             beta_ratio = 0.5
             # x_current = pred.X - beta_ratio * x_grad
             # e_current = pred.E - beta_ratio * e_grad
-            E_s = pred.X - beta_ratio * x_grad
-            X_s = pred.E - beta_ratio * e_grad
+            X_s = pred.X - beta_ratio * x_grad
+            E_s = pred.E - beta_ratio * e_grad
             # update prob.X prob_E with using gradient
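
This hunk fixes a genuine swap bug: the node prediction pred.X was written into E_s and the edge prediction pred.E into X_s, so each tensor was stepped with the other's gradient. The corrected update is a plain descent step on the predictions, X_s = pred.X - beta * x_grad with beta = 0.5. Because the additive step leaves the rows unnormalized, the trailing "update prob.X prob_E" comment suggests projecting back to distributions; the softmax below is an assumption, not part of the commit:

    import torch.nn.functional as F

    # Names mirror the diff; only the softmax re-normalization is assumed.
    def guided_step(pred_X, pred_E, x_grad, e_grad, beta_ratio=0.5):
        X_s = pred_X - beta_ratio * x_grad   # node logits, node gradient
        E_s = pred_E - beta_ratio * e_grad   # edge logits, edge gradient
        return F.softmax(X_s, dim=-1), F.softmax(E_s, dim=-1)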


@@ -86,7 +86,7 @@ class Denoiser(nn.Module):
     """
     def forward(self, x, e, node_mask, y, t, unconditioned):
-        print("Denoiser Forward")
+        # print("Denoiser Forward")
         # print(x.shape, e.shape, y.shape, t.shape, unconditioned)
         force_drop_id = torch.zeros_like(y.sum(-1))
         # drop the nan values
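
The per-call print("Denoiser Forward") is silenced by commenting it out, as with the other debug prints in this commit. A small sketch of routing such output through the standard logging module instead, so verbosity can be toggled without further commits; the logger name and helper are illustrative:

    import logging

    logger = logging.getLogger("denoiser")   # illustrative name

    def log_forward_shapes(x, e, y):
        # emitted only when the logger's effective level is DEBUG
        logger.debug("Denoiser forward: x=%s e=%s y=%s",
                     tuple(x.shape), tuple(e.shape), tuple(y.shape))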