add train_loader and searchspace codes
This commit is contained in:
parent
5e66aa74e7
commit
f5d00be56e
@ -25,6 +25,7 @@ from sklearn.model_selection import train_test_split
|
|||||||
import utils as utils
|
import utils as utils
|
||||||
from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
|
from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
|
||||||
from diffusion.distributions import DistributionNodes
|
from diffusion.distributions import DistributionNodes
|
||||||
|
# from naswot.score_networks import get_nasbench201_idx_score
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
|
|
||||||
@ -679,7 +680,8 @@ class Dataset(InMemoryDataset):
|
|||||||
self.api = API(source)
|
self.api = API(source)
|
||||||
|
|
||||||
data_list = []
|
data_list = []
|
||||||
len_data = len(self.api)
|
# len_data = len(self.api)
|
||||||
|
len_data = 1000
|
||||||
def check_valid_graph(nodes, edges):
|
def check_valid_graph(nodes, edges):
|
||||||
if len(nodes) != edges.shape[0] or len(nodes) != edges.shape[1]:
|
if len(nodes) != edges.shape[0] or len(nodes) != edges.shape[1]:
|
||||||
return False
|
return False
|
||||||
@ -727,7 +729,10 @@ class Dataset(InMemoryDataset):
|
|||||||
edges[i, j] = 1
|
edges[i, j] = 1
|
||||||
return nodes, edges
|
return nodes, edges
|
||||||
|
|
||||||
def graph_to_graph_data(graph):
|
def get_nasbench_201_val(idx):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def graph_to_graph_data(graph, idx):
|
||||||
ops = graph[1]
|
ops = graph[1]
|
||||||
adj = graph[0]
|
adj = graph[0]
|
||||||
nodes = []
|
nodes = []
|
||||||
@ -742,13 +747,14 @@ class Dataset(InMemoryDataset):
|
|||||||
if adj[start][end] == 1:
|
if adj[start][end] == 1:
|
||||||
edges_list.append((start, end))
|
edges_list.append((start, end))
|
||||||
edge_type.append(1)
|
edge_type.append(1)
|
||||||
# edges_list.append((end, start))
|
edges_list.append((end, start))
|
||||||
# edge_type.append(1)
|
edge_type.append(1)
|
||||||
|
|
||||||
edge_index = torch.tensor(edges_list, dtype=torch.long).t()
|
edge_index = torch.tensor(edges_list, dtype=torch.long).t()
|
||||||
edge_type = torch.tensor(edge_type, dtype=torch.long)
|
edge_type = torch.tensor(edge_type, dtype=torch.long)
|
||||||
edge_attr = edge_type
|
edge_attr = edge_type
|
||||||
y = torch.tensor([0, 0], dtype=torch.float).view(1, -1)
|
# y = torch.tensor([0, 0], dtype=torch.float).view(1, -1)
|
||||||
|
y = get_nasbench_201_val(idx)
|
||||||
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i)
|
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i)
|
||||||
return data
|
return data
|
||||||
graph_list = []
|
graph_list = []
|
||||||
|
Loading…
Reference in New Issue
Block a user