From f5d00be56e261d9818c8a375aa7679ea3cd79bd5 Mon Sep 17 00:00:00 2001 From: mhz Date: Tue, 30 Jul 2024 00:12:37 +0200 Subject: [PATCH] add train_loader and searchspace codes --- graph_dit/datasets/dataset.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/graph_dit/datasets/dataset.py b/graph_dit/datasets/dataset.py index 43a6e26..5048716 100644 --- a/graph_dit/datasets/dataset.py +++ b/graph_dit/datasets/dataset.py @@ -25,6 +25,7 @@ from sklearn.model_selection import train_test_split import utils as utils from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule from diffusion.distributions import DistributionNodes +# from naswot.score_networks import get_nasbench201_idx_score import networkx as nx @@ -679,7 +680,8 @@ class Dataset(InMemoryDataset): self.api = API(source) data_list = [] - len_data = len(self.api) + # len_data = len(self.api) + len_data = 1000 def check_valid_graph(nodes, edges): if len(nodes) != edges.shape[0] or len(nodes) != edges.shape[1]: return False @@ -726,8 +728,11 @@ class Dataset(InMemoryDataset): if rand < random_ratio: edges[i, j] = 1 return nodes, edges + + def get_nasbench_201_val(idx): + pass - def graph_to_graph_data(graph): + def graph_to_graph_data(graph, idx): ops = graph[1] adj = graph[0] nodes = [] @@ -742,13 +747,14 @@ class Dataset(InMemoryDataset): if adj[start][end] == 1: edges_list.append((start, end)) edge_type.append(1) - # edges_list.append((end, start)) - # edge_type.append(1) + edges_list.append((end, start)) + edge_type.append(1) edge_index = torch.tensor(edges_list, dtype=torch.long).t() edge_type = torch.tensor(edge_type, dtype=torch.long) edge_attr = edge_type - y = torch.tensor([0, 0], dtype=torch.float).view(1, -1) + # y = torch.tensor([0, 0], dtype=torch.float).view(1, -1) + y = get_nasbench_201_val(idx) data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i) return data graph_list = []