update some score function
parent f5d00be56e
commit 205f43291b
.gitignore (vendored): 9 additions
@@ -159,3 +159,12 @@ archive.zip
 logs/
 generated/
 data/processed/
+*.pdf
+*.zip
+*.pth
+*.bck
+*.pt
+cifardata/
+*.meta.json
+*.joblib
+*.gz
@@ -2,7 +2,7 @@ general:
   name: 'graph_dit'
   wandb: 'disabled'
   gpus: 1
-  gpu_number: 3
+  gpu_number: 2
   resume: null
   test_only: null
   sample_every_val: 2500
@@ -31,7 +31,8 @@ model:
   lambda_train: [1, 10] # node and edge training weight
   ensure_connected: True
 train:
-  n_epochs: 5000
+  # n_epochs: 5000
+  n_epochs: 10
   batch_size: 1200
   lr: 0.0002
   clip_grad: null
@@ -220,7 +220,7 @@ class Graph_DiT(pl.LightningModule):
         # self.sampling_metrics.reset()
         self.val_y_collection = []

-    @torch.no_grad()
+    # @torch.no_grad()
     def validation_step(self, data, i):
         data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
         data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
@@ -313,7 +313,7 @@ class Graph_DiT(pl.LightningModule):
         self.test_E_logp.reset()
         self.test_y_collection = []

-    @torch.no_grad()
+    # @torch.no_grad()
     def test_step(self, data, i):
         data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
         data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
@@ -573,7 +573,7 @@ class Graph_DiT(pl.LightningModule):

         return nll

-    @torch.no_grad()
+    # @torch.no_grad()
     def sample_batch(self, batch_id, batch_size, y, keep_chain, number_chain_steps, save_final, num_nodes=None):
         """
         :param batch_id: int
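Note on the three hunks above: each one comments out @torch.no_grad(). Under no_grad, PyTorch records no autograd graph, so the loss.backward() call that the score-guided sampling below performs inside sample_batch would raise a RuntimeError. A minimal standalone sketch of the difference (the names here are illustrative, not from this repo):

import torch

x = torch.ones(3, requires_grad=True)

@torch.no_grad()
def score_detached(v):
    # no_grad builds no graph, so the result is detached
    return (v * 2.0).sum()

print(score_detached(x).requires_grad)  # False: calling backward() here would fail

# with the decorator removed, the graph is kept and gradients flow
loss = (x * 2.0).sum()
loss.backward()
print(x.grad)  # tensor([2., 2., 2.])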
@@ -742,19 +742,24 @@ class Graph_DiT(pl.LightningModule):
                 if valid_rlt[i]:
                     nodes = [num_to_op[j] for j in x_list[i].cpu().numpy()]
                     # edges = e_list[i].cpu().numpy()
-                    score.append(get_nasbench201_nodes_score(nodes, train_loader=self.train_loader, searchspace=self.searchspace, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), args=self.args))
+                    score.append(get_nasbench201_nodes_score(nodes, train_loader=self.train_loader, searchspace=self.searchspace, device=sampled_s.X.device, args=self.args))
                 else:
                     score.append(-1)
-            return torch.tensor(score, dtype=torch.float32, requires_grad=True).to(x_list[0].device)
+            # return torch.tensor(score, dtype=torch.float32, requires_grad=True).to(x_list[0].device)
+            target_score = torch.ones(100, dtype=torch.float32, device=sampled_s.X.device, requires_grad=True) * 2000.0
+            # target_score_list = [2000 for i in range(100)]
+            # return torch.tensor(score, device=sampled_s.X.device, dtype=torch.float32, requires_grad=True), torch.tensor(target_score_list, device=sampled_s.X.device, dtype=torch.float32, requires_grad=True)
+            return torch.tensor(score, device=sampled_s.X.device, dtype=torch.float32, requires_grad=True), target_score

         sample_num = 10
         best_arch = None
         best_score_int = -1e8
         score = torch.ones(100, dtype=torch.float32, requires_grad=True) * -1e8
+        print(f'score.requires_grad: {score.requires_grad}')

         for i in range(sample_num):
             sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())
-            score = get_score(sampled_s)
+            score, target_score = get_score(sampled_s)
             print(f'score: {score}')
             print(f'score.shape: {score.shape}')
             print(f'torch.sum(score): {torch.sum(score)}')
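For readability, here is the revised inner helper from the hunk above, reassembled in one piece. This is a sketch built from the diff's own lines; the closure variables (x_list, valid_rlt, num_to_op, get_nasbench201_nodes_score) are assumed to be defined in the surrounding sample_batch code that the diff does not show. The two substantive changes: the NAS-Bench-201 scorer now runs on the sampled batch's own device instead of a hard-coded cuda:0, and get_score returns a (score, target_score) pair so the caller can build a guidance loss against a fixed target of 2000.0.

def get_score(sampled_s):
    # score every sampled architecture; invalid graphs get a sentinel of -1
    score = []
    for i in range(len(x_list)):  # loop bounds assumed from context
        if valid_rlt[i]:
            nodes = [num_to_op[j] for j in x_list[i].cpu().numpy()]
            score.append(get_nasbench201_nodes_score(
                nodes, train_loader=self.train_loader,
                searchspace=self.searchspace,
                device=sampled_s.X.device,  # follow the batch, not cuda:0
                args=self.args))
        else:
            score.append(-1)
    # fixed guidance target: every sample is pushed toward a score of 2000
    target_score = torch.ones(100, dtype=torch.float32,
                              device=sampled_s.X.device,
                              requires_grad=True) * 2000.0
    score = torch.tensor(score, dtype=torch.float32,
                         device=sampled_s.X.device, requires_grad=True)
    return score, target_score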
@@ -779,14 +784,19 @@ class Graph_DiT(pl.LightningModule):
             print(f'X_s: {X_s}, E_s: {E_s}')

             # NASWOT score
-            target_score = torch.ones(100, requires_grad=True) * 2000.0
-            target_score = target_score.to(X_s.device)
+            # target_score = torch.ones(100, requires_grad=True, device=X_s.device) * 2000.0
+            # target_score = torch.ones(100, requires_grad=True) * 2000.0
+            print(f'best_score: {best_score.shape}, target_score: {target_score.shape}')
+            print(f'best_score.requires_grad: {best_score.requires_grad}, target_score.requires_grad: {target_score.requires_grad}')
+            print(f'best_score.device: {best_score.device}, target_score.device: {target_score.device}')
+            # target_score = target_score.to(X_s.device)
+            # print(f'best_score: {best_score.shape}, target_score: {target_score.shape}')
+            # print(f'best_score.requires_grad: {best_score.requires_grad}, target_score.requires_grad: {target_score.requires_grad}')

             # compute loss mse(cur_score - target_score)
             mse_loss = torch.nn.MSELoss()
-            print(f'best_score: {best_score.shape}, target_score: {target_score.shape}')
-            print(f'best_score.requires_grad: {best_score.requires_grad}, target_score.requires_grad: {target_score.requires_grad}')
-            loss = mse_loss(best_score, target_score)
+            loss = mse_loss(target_score, best_score)
+            print(f'loss: {loss.requires_grad}')
             loss.backward(retain_graph=True)

             # loss backward = gradient
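On the loss hunk: swapping the arguments to mse_loss(target_score, best_score) changes nothing numerically, since nn.MSELoss is symmetric in its two inputs; the substantive part is that backward(retain_graph=True) now runs over scores that carry requires_grad=True, leaving gradients available for the update step below. A self-contained sketch of that wiring, with random stand-ins for the 100 NASWOT scores:

import torch

best_score = torch.randn(100, requires_grad=True)  # stand-in for the NASWOT scores
target_score = torch.full((100,), 2000.0)          # the fixed 2000.0 target

mse_loss = torch.nn.MSELoss()
loss = mse_loss(target_score, best_score)  # == mse_loss(best_score, target_score)

# the commit passes retain_graph=True so the graph survives repeated backwards
loss.backward()
print(best_score.grad[:3])  # d(loss)/d(score_i) = 2 * (score_i - 2000) / 100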
@@ -798,8 +808,8 @@ class Graph_DiT(pl.LightningModule):
             beta_ratio = 0.5
             # x_current = pred.X - beta_ratio * x_grad
             # e_current = pred.E - beta_ratio * e_grad
-            E_s = pred.X - beta_ratio * x_grad
-            X_s = pred.E - beta_ratio * e_grad
+            X_s = pred.X - beta_ratio * x_grad
+            E_s = pred.E - beta_ratio * e_grad

             # update prob.X prob_E with using gradient

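The hunk above is a straight bug fix: the X_s and E_s assignments had their right-hand sides crossed, so the node prediction was being stepped with the edge gradient and vice versa. The intended operation is one gradient-descent step on the denoiser's predictions, nudging each toward the guidance target. A sketch with illustrative shapes (the real dimensions come from the dataset; x_grad and e_grad are the gradients left behind by the backward call above):

import torch

beta_ratio = 0.5
pred_X = torch.randn(100, 8, 8)     # (batch, nodes, node classes), illustrative
pred_E = torch.randn(100, 8, 8, 2)  # (batch, nodes, nodes, edge classes), illustrative
x_grad = torch.randn_like(pred_X)   # guidance-loss gradient w.r.t. pred.X
e_grad = torch.randn_like(pred_E)   # guidance-loss gradient w.r.t. pred.E

# after the fix, each prediction moves against its own gradient
X_s = pred_X - beta_ratio * x_grad
E_s = pred_E - beta_ratio * e_grad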
@@ -86,7 +86,7 @@ class Denoiser(nn.Module):
     """
     def forward(self, x, e, node_mask, y, t, unconditioned):

-        print("Denoiser Forward")
+        # print("Denoiser Forward")
         # print(x.shape, e.shape, y.shape, t.shape, unconditioned)
         force_drop_id = torch.zeros_like(y.sum(-1))
         # drop the nan values