add train_loader and searchspace codes
This commit is contained in:
		| @@ -25,6 +25,7 @@ from sklearn.model_selection import train_test_split | ||||
| import utils as utils | ||||
| from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule | ||||
| from diffusion.distributions import DistributionNodes | ||||
| # from naswot.score_networks import get_nasbench201_idx_score | ||||
|  | ||||
| import networkx as nx | ||||
|  | ||||
| @@ -679,7 +680,8 @@ class Dataset(InMemoryDataset): | ||||
|         self.api = API(source) | ||||
|  | ||||
|         data_list = [] | ||||
|         len_data = len(self.api) | ||||
|         # len_data = len(self.api) | ||||
|         len_data = 1000 | ||||
|         def check_valid_graph(nodes, edges): | ||||
|             if len(nodes) != edges.shape[0] or len(nodes) != edges.shape[1]: | ||||
|                 return False | ||||
| @@ -727,7 +729,10 @@ class Dataset(InMemoryDataset): | ||||
|                         edges[i, j] = 1 | ||||
|             return nodes, edges | ||||
|          | ||||
|         def graph_to_graph_data(graph): | ||||
|         def get_nasbench_201_val(idx): | ||||
|             pass | ||||
|  | ||||
|         def graph_to_graph_data(graph, idx): | ||||
|             ops = graph[1] | ||||
|             adj = graph[0] | ||||
|             nodes = [] | ||||
| @@ -742,13 +747,14 @@ class Dataset(InMemoryDataset): | ||||
|                     if adj[start][end] == 1: | ||||
|                         edges_list.append((start, end)) | ||||
|                         edge_type.append(1) | ||||
|                         # edges_list.append((end, start)) | ||||
|                         # edge_type.append(1) | ||||
|                         edges_list.append((end, start)) | ||||
|                         edge_type.append(1) | ||||
|              | ||||
|             edge_index = torch.tensor(edges_list, dtype=torch.long).t() | ||||
|             edge_type = torch.tensor(edge_type, dtype=torch.long) | ||||
|             edge_attr = edge_type | ||||
|             y = torch.tensor([0, 0], dtype=torch.float).view(1, -1) | ||||
|             # y = torch.tensor([0, 0], dtype=torch.float).view(1, -1) | ||||
|             y = get_nasbench_201_val(idx) | ||||
|             data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i) | ||||
|             return data | ||||
|         graph_list = [] | ||||
|   | ||||
		Reference in New Issue
	
	Block a user