add the idea of guidance
parent fcdd8efc4f
commit 55ff19421d
@@ -8,6 +8,7 @@ import os
 import os.path as osp
 import pathlib
 import json
+import random

 import torch
 import torch.nn.functional as F
@@ -49,6 +50,9 @@ op_type = {
     'none': 5,
     'output': 6,
 }
+
+num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
+
 class DataModule(AbstractDataModule):
     def __init__(self, cfg):
         self.datadir = cfg.dataset.datadir
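For orientation (not part of the commit): the new num_to_op list is intended as the inverse of the op_type mapping, which is consistent with the visible entries 'none': 5 and 'output': 6. A one-line sanity sketch, assuming both names are importable from this module:

    for i, op in enumerate(num_to_op):
        assert op_type[op] == i  # the list index doubles as the op_type id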
@@ -676,6 +680,52 @@ class Dataset(InMemoryDataset):

         data_list = []
         len_data = len(self.api)
+
+        def check_valid_graph(nodes, edges):
+            # the adjacency matrix must be square and match the node list
+            if len(nodes) != edges.shape[0] or len(nodes) != edges.shape[1]:
+                return False
+            # the graph must start at 'input' and end at 'output'
+            if nodes[0] != 'input' or nodes[-1] != 'output':
+                return False
+            # no self-loops
+            for i in range(0, len(nodes)):
+                if edges[i, i] == 1:
+                    return False
+            # every intermediate node must be a known, non-terminal operation
+            for i in range(1, len(nodes) - 1):
+                if nodes[i] not in op_type or nodes[i] == 'input' or nodes[i] == 'output':
+                    return False
+            # no edge may point into the input node
+            for i in range(0, len(nodes)):
+                for j in range(i, len(nodes)):
+                    if edges[i, j] == 1 and nodes[j] == 'input':
+                        return False
+            # no edge may leave the output node
+            for i in range(0, len(nodes)):
+                for j in range(i, len(nodes)):
+                    if edges[i, j] == 1 and nodes[i] == 'output':
+                        return False
+            # at least one edge must reach the output node
+            flag = 0
+            for i in range(0, len(nodes)):
+                if edges[i, -1] == 1:
+                    flag = 1
+                    break
+            if flag == 0:
+                return False
+            return True
+
+        def generate_flex_adj_mat(ori_nodes, ori_edges, max_nodes=12, min_nodes=8, random_ratio=0.5):
+            # NAS-Bench-201 cells always have 8 nodes
+            nasbench_201_node_num = 8
+            # random.seed(random_seed)
+            nodes_num = random.randint(min_nodes, max_nodes)
+            # print(f'arch_str: {arch_str}, \nmax_nodes: {max_nodes}, min_nodes: {min_nodes}, nodes_num: {nodes_num}, random_seed: {random_seed}, random_ratio: {random_ratio}')
+            add_num = nodes_num - nasbench_201_node_num
+            # ori_nodes, ori_edges = parse_architecture_string(arch_str)
+            add_nodes = random.choices(num_to_op[1:-1], k=add_num)
+            # print(add_nodes)
+            nodes = ori_nodes[:-1] + add_nodes + ['output']
+            edges = np.zeros((nodes_num, nodes_num))
+            # keep the original cell wiring and its edges into the output node
+            edges[:6, :6] = ori_edges[:6, :6]
+            edges[0:8, -1] = ori_edges[0:8, -1]
+            # randomly wire the inserted nodes; requiring j > i keeps the DAG acyclic
+            for i in range(0, nodes_num):
+                for j in range(max(7, i + 1), nodes_num):
+                    if random.random() < random_ratio:
+                        edges[i, j] = 1
+            return nodes, edges
+
         def graph_to_graph_data(graph):
             ops = graph[1]
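A hedged usage sketch of the two helpers added above (not part of the commit; it assumes check_valid_graph and generate_flex_adj_mat are reachable at module scope, whereas the diff nests them inside the dataset-processing method):

    import random
    import numpy as np

    random.seed(0)  # only for reproducibility of this sketch

    # a stock 8-node NAS-Bench-201-style cell with a path into the output
    nodes = ['input', 'nor_conv_3x3', 'skip_connect', 'nor_conv_1x1',
             'avg_pool_3x3', 'none', 'nor_conv_3x3', 'output']
    edges = np.zeros((8, 8))
    edges[0, 1] = edges[1, 2] = edges[2, 7] = 1
    assert check_valid_graph(nodes, edges)       # passes every rule

    bad = edges.copy()
    bad[3, 3] = 1                                # introduce a self-loop
    assert not check_valid_graph(nodes, bad)     # rejected

    # grow the cell to 8-12 nodes with randomly wired extra operations
    new_nodes, new_edges = generate_flex_adj_mat(nodes, edges, max_nodes=12,
                                                 min_nodes=8, random_ratio=0.5)
    assert len(new_nodes) == new_edges.shape[0]
    # not every random draw is valid, so re-check before keeping the graph
    print(check_valid_graph(new_nodes, new_edges))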
@@ -746,6 +796,9 @@ class Dataset(InMemoryDataset):
                 })
                 data = graph_to_graph_data((adj_matrix, ops))
                 data_list.append(data)
+
+                # new_adj, new_ops = generate_flex_adj_mat(ori_nodes=ops, ori_edges=adj_matrix, max_nodes=12, min_nodes=8, random_ratio=0.5)
+                # data_list.append(graph_to_graph_data((new_adj, new_ops)))
                 pbar.update(1)

         for graph in graph_list:
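If the commented-out augmentation above is enabled later, note that generate_flex_adj_mat returns (nodes, edges) while graph_to_graph_data takes (adjacency, ops), so the unpacking in the comment has the names swapped. A corrected, hedged wiring (the check_valid_graph filter is an assumption, not in the commit):

    new_nodes, new_edges = generate_flex_adj_mat(ori_nodes=ops, ori_edges=adj_matrix,
                                                 max_nodes=12, min_nodes=8, random_ratio=0.5)
    if check_valid_graph(new_nodes, new_edges):  # keep only valid augmented graphs
        data_list.append(graph_to_graph_data((new_edges, new_nodes)))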
@@ -134,7 +134,7 @@ class Graph_DiT(pl.LightningModule):
         loss = self.train_loss(masked_pred_X=pred.X, masked_pred_E=pred.E, pred_y=pred.y,
                                true_X=X, true_E=E, true_y=data.y, node_mask=node_mask,
                                log=i % self.log_every_steps == 0)
+        # print(f'training loss: {loss}, epoch: {self.current_epoch}, batch: {i}\n, pred type: {type(pred)}, pred.X shape: {type(pred.X)}, {pred.X.shape}, pred.E shape: {type(pred.E)}, {pred.E.shape}')
         self.train_metrics(masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E,
                            log=i % self.log_every_steps == 0)
         self.log(f'loss', loss, batch_size=X.size(0), sync_dist=True)
@@ -602,6 +602,7 @@ class Graph_DiT(pl.LightningModule):
         # Normalize predictions
         pred_X = F.softmax(pred.X, dim=-1)  # bs, n, d0
         pred_E = F.softmax(pred.E, dim=-1)  # bs, n, n, d0
+        # gradient: placeholder for the gradient-based guidance sketched below

         # Retrieve transitions matrix
         Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device)
@@ -629,7 +630,9 @@ class Graph_DiT(pl.LightningModule):
             prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])

             return prob_X, prob_E
+        # DiffusionNAG: P_t(G_{t-1} | G_t, C) ∝ P_t(G_{t-1} | G_t) · P_t(C | G_{t-1}, G_t)
+        # (the factors combine multiplicatively; additively in log-space)
+        # with condition C = target property y; for the adjacency part, P_t(A_{t-1} | A_t, y)
         prob_X, prob_E = get_prob(noisy_data)

         ### Guidance
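A minimal sketch of the multiplicative factorization named in the comments, applied to one categorical posterior. predictor_prob is a hypothetical stand-in for a trained property predictor's per-class likelihood P(C | G_{t-1}, G_t); nothing here is from the commit:

    import torch

    def apply_guidance(prob, predictor_prob):
        # P(G_{t-1} | G_t, C) ∝ P(G_{t-1} | G_t) * P(C | G_{t-1}, G_t)
        guided = prob * predictor_prob
        return guided / guided.sum(dim=-1, keepdim=True).clamp_min(1e-10)

    prob_X = torch.tensor([[0.7, 0.2, 0.1]])  # unconditional posterior over classes
    pred_c = torch.tensor([[0.1, 0.8, 0.1]])  # hypothetical predictor likelihood
    print(apply_guidance(prob_X, pred_c))     # mass shifts toward the class favored by C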
@@ -640,14 +643,39 @@ class Graph_DiT(pl.LightningModule):
         prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-10)
         prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-10)

         assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all()
         assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-4).all()

         sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item())
+
+        # Sample multiple times and keep the best-scoring architecture.
+        # NOTE: get_score is not defined anywhere in this commit; it is meant
+        # to rank sampled architectures (e.g. by their NASWOT score).
+        sample_num = 100
+        best_arch = None
+        best_score = -1e8
+        for i in range(sample_num):
+            sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item())
+            score = get_score(sampled_s)
+            if score > best_score:
+                best_score = score
+                best_arch = sampled_s
+        # NOTE: best_arch is selected but unused below; X_s / E_s should
+        # eventually be built from best_arch rather than from the last sample.
+
         X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float()
         E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float()
+
+        # Planned NASWOT-guided update (idea only, not implemented yet):
+        #   1. compute the NASWOT score of the current sample
+        #   2. loss = mse(cur_score, target_score)
+        #   3. loss backward to obtain gradients w.r.t. prob_X, prob_E
+        #   4. update prob_X, prob_E using those gradients
+        target_score = torch.tensor([3000.0])
+
         assert (E_s == torch.transpose(E_s, 1, 2)).all()
         assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)
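The comment block in the last hunk only states the plan; below is one hedged way the gradient step could look, assuming score_fn is a differentiable surrogate for the NASWOT score evaluated on the relaxed probabilities (guide_probs, score_fn, and step_size are illustrative names, none defined by the commit):

    import torch
    import torch.nn.functional as F

    def guide_probs(prob_X, prob_E, score_fn, target_score=3000.0, step_size=0.1):
        # one gradient step pushing the relaxed sample toward the target score
        prob_X = prob_X.detach().requires_grad_(True)
        prob_E = prob_E.detach().requires_grad_(True)
        loss = F.mse_loss(score_fn(prob_X, prob_E), torch.tensor(target_score))
        loss.backward()
        with torch.no_grad():
            new_X = (prob_X - step_size * prob_X.grad).clamp_min(1e-10)
            new_E = (prob_E - step_size * prob_E.grad).clamp_min(1e-10)
            # renormalize so every categorical still sums to one
            new_X = new_X / new_X.sum(dim=-1, keepdim=True)
            new_E = new_E / new_E.sum(dim=-1, keepdim=True)
        return new_X, new_E

    # toy usage with a stand-in score; a real score_fn would be NASWOT-based
    score_fn = lambda X, E: X.sum() + E.sum()
    pX = torch.full((1, 4, 3), 1 / 3)
    pE = torch.full((1, 4, 4, 2), 1 / 2)
    gX, gE = guide_probs(pX, pE, score_fn)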