update some score function
This commit is contained in:
parent
f5d00be56e
commit
205f43291b
11
.gitignore
vendored
11
.gitignore
vendored
@ -158,4 +158,13 @@ src/analysis/orca/tmp_XMYAR426.txt
|
||||
archive.zip
|
||||
logs/
|
||||
generated/
|
||||
data/processed/
|
||||
data/processed/
|
||||
*.pdf
|
||||
*.zip
|
||||
*.pth
|
||||
*.bck
|
||||
*.pt
|
||||
cifardata/
|
||||
*.meta.json
|
||||
*.joblib
|
||||
*.gz
|
@ -2,7 +2,7 @@ general:
|
||||
name: 'graph_dit'
|
||||
wandb: 'disabled'
|
||||
gpus: 1
|
||||
gpu_number: 3
|
||||
gpu_number: 2
|
||||
resume: null
|
||||
test_only: null
|
||||
sample_every_val: 2500
|
||||
@ -31,7 +31,8 @@ model:
|
||||
lambda_train: [1, 10] # node and edge training weight
|
||||
ensure_connected: True
|
||||
train:
|
||||
n_epochs: 5000
|
||||
# n_epochs: 5000
|
||||
n_epochs: 10
|
||||
batch_size: 1200
|
||||
lr: 0.0002
|
||||
clip_grad: null
|
||||
|
@ -220,7 +220,7 @@ class Graph_DiT(pl.LightningModule):
|
||||
# self.sampling_metrics.reset()
|
||||
self.val_y_collection = []
|
||||
|
||||
@torch.no_grad()
|
||||
# @torch.no_grad()
|
||||
def validation_step(self, data, i):
|
||||
data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
|
||||
data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
|
||||
@ -313,7 +313,7 @@ class Graph_DiT(pl.LightningModule):
|
||||
self.test_E_logp.reset()
|
||||
self.test_y_collection = []
|
||||
|
||||
@torch.no_grad()
|
||||
# @torch.no_grad()
|
||||
def test_step(self, data, i):
|
||||
data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
|
||||
data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
|
||||
@ -573,7 +573,7 @@ class Graph_DiT(pl.LightningModule):
|
||||
|
||||
return nll
|
||||
|
||||
@torch.no_grad()
|
||||
# @torch.no_grad()
|
||||
def sample_batch(self, batch_id, batch_size, y, keep_chain, number_chain_steps, save_final, num_nodes=None):
|
||||
"""
|
||||
:param batch_id: int
|
||||
@ -742,19 +742,24 @@ class Graph_DiT(pl.LightningModule):
|
||||
if valid_rlt[i]:
|
||||
nodes = [num_to_op[j] for j in x_list[i].cpu().numpy()]
|
||||
# edges = e_list[i].cpu().numpy()
|
||||
score.append(get_nasbench201_nodes_score(nodes,train_loader=self.train_loader,searchspace=self.searchspace,device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu") , args=self.args))
|
||||
score.append(get_nasbench201_nodes_score(nodes,train_loader=self.train_loader,searchspace=self.searchspace,device=sampled_s.X.device , args=self.args))
|
||||
else:
|
||||
score.append(-1)
|
||||
return torch.tensor(score, dtype=torch.float32, requires_grad=True).to(x_list[0].device)
|
||||
# return torch.tensor(score, dtype=torch.float32, requires_grad=True).to(x_list[0].device)
|
||||
target_score = torch.ones(100, dtype=torch.float32, device=sampled_s.X.device, requires_grad=True) * 2000.0
|
||||
# target_score_list = [2000 for i in range(100)]
|
||||
# return torch.tensor(score, device=sampled_s.X.device ,dtype=torch.float32, requires_grad=True), torch.tensor(target_score_list, device=sampled_s.X.device, dtype=torch.float32, requires_grad=True)
|
||||
return torch.tensor(score, device=sampled_s.X.device ,dtype=torch.float32, requires_grad=True), target_score
|
||||
|
||||
sample_num = 10
|
||||
best_arch = None
|
||||
best_score_int = -1e8
|
||||
score = torch.ones(100, dtype=torch.float32, requires_grad=True) * -1e8
|
||||
print(f'score.requires_grad: {score.requires_grad}')
|
||||
|
||||
for i in range(sample_num):
|
||||
sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())
|
||||
score = get_score(sampled_s)
|
||||
score, target_score = get_score(sampled_s)
|
||||
print(f'score: {score}')
|
||||
print(f'score.shape: {score.shape}')
|
||||
print(f'torch.sum(score): {torch.sum(score)}')
|
||||
@ -779,14 +784,19 @@ class Graph_DiT(pl.LightningModule):
|
||||
print(f'X_s: {X_s}, E_s: {E_s}')
|
||||
|
||||
# NASWOT score
|
||||
target_score = torch.ones(100, requires_grad=True) * 2000.0
|
||||
target_score = target_score.to(X_s.device)
|
||||
# target_score = torch.ones(100, requires_grad=True, device=X_s.device) * 2000.0
|
||||
# target_score = torch.ones(100, requires_grad=True) * 2000.0
|
||||
print(f'best_score: {best_score.shape}, target_score: {target_score.shape}')
|
||||
print(f'best_score.requires_grad: {best_score.requires_grad}, target_score.requires_grad: {target_score.requires_grad}')
|
||||
print(f'best_score.device: {best_score.device}, target_score.device: {target_score.device}')
|
||||
# target_score = target_score.to(X_s.device)
|
||||
# print(f'best_score: {best_score.shape}, target_score: {target_score.shape}')
|
||||
# print(f'best_score.requires_grad: {best_score.requires_grad}, target_score.requires_grad: {target_score.requires_grad}')
|
||||
|
||||
# compute loss mse(cur_score - target_score)
|
||||
mse_loss = torch.nn.MSELoss()
|
||||
print(f'best_score: {best_score.shape}, target_score: {target_score.shape}')
|
||||
print(f'best_score.requires_grad: {best_score.requires_grad}, target_score.requires_grad: {target_score.requires_grad}')
|
||||
loss = mse_loss(best_score, target_score)
|
||||
loss = mse_loss(target_score, best_score)
|
||||
print(f'loss: {loss.requires_grad}')
|
||||
loss.backward(retain_graph=True)
|
||||
|
||||
# loss backward = gradient
|
||||
@ -798,8 +808,8 @@ class Graph_DiT(pl.LightningModule):
|
||||
beta_ratio = 0.5
|
||||
# x_current = pred.X - beta_ratio * x_grad
|
||||
# e_current = pred.E - beta_ratio * e_grad
|
||||
E_s = pred.X - beta_ratio * x_grad
|
||||
X_s = pred.E - beta_ratio * e_grad
|
||||
X_s = pred.X - beta_ratio * x_grad
|
||||
E_s = pred.E - beta_ratio * e_grad
|
||||
|
||||
# update prob.X prob_E with using gradient
|
||||
|
||||
|
@ -86,7 +86,7 @@ class Denoiser(nn.Module):
|
||||
"""
|
||||
def forward(self, x, e, node_mask, y, t, unconditioned):
|
||||
|
||||
print("Denoiser Forward")
|
||||
# print("Denoiser Forward")
|
||||
# print(x.shape, e.shape, y.shape, t.shape, unconditioned)
|
||||
force_drop_id = torch.zeros_like(y.sum(-1))
|
||||
# drop the nan values
|
||||
|
Loading…
Reference in New Issue
Block a user