diff --git a/graph_dit/datasets/dataset.py b/graph_dit/datasets/dataset.py index 9db49eb..2ba84b1 100644 --- a/graph_dit/datasets/dataset.py +++ b/graph_dit/datasets/dataset.py @@ -771,9 +771,10 @@ class Dataset(InMemoryDataset): edge_type = torch.tensor(edge_type, dtype=torch.long) edge_attr = edge_type # y = torch.tensor([0, 0], dtype=torch.float).view(1, -1) - y = get_nasbench201_idx_score(idx, train_loader, searchspace, args, device) + # y = get_nasbench201_idx_score(idx, train_loader, searchspace, args, device) + y = self.swap_scores[idx] print(y, idx) - if y > 1600: + if y > 60000: print(f'idx={idx}, y={y}') y = torch.tensor([1, 1], dtype=torch.float).view(1, -1) data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i) @@ -812,6 +813,14 @@ class Dataset(InMemoryDataset): args.num_labels = 1 searchspace = nasspace.get_search_space(args) train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) + self.swap_scores = [] + import csv + # with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f: + with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results_cifar100.csv', 'r') as f: + reader = csv.reader(f) + header = next(reader) + data = [row for row in reader] + self.swap_scores = [float(row[0]) for row in data] device = torch.device('cuda:2') with tqdm(total = len_data) as pbar: active_nodes = set() @@ -823,14 +832,8 @@ class Dataset(InMemoryDataset): flex_graph_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json' for graph in graph_list: print(f'iterate every graph in graph_list, here is {i}') - # arch_info = self.api.query_meta_info_by_index(i) - # results = self.api.query_by_index(i, 'cifar100') arch_info = graph['arch_str'] - # results = - # nodes, edges = parse_architecture_string(arch_info.arch_str) - # ops, adj_matrix = parse_architecture_string(arch_info.arch_str, padding=4) ops, adj_matrix, ori_nodes, ori_adj = parse_architecture_string(arch_info, padding=4) - # adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges) for op in ops: if op not in active_nodes: active_nodes.add(op) @@ -839,12 +842,6 @@ class Dataset(InMemoryDataset): if data is None: pbar.update(1) continue - # with open(flex_graph_path, 'a') as f: - # flex_graph = { - # 'adj_matrix': adj_matrix, - # 'ops': ops, - # } - # json.dump(flex_graph, f) flex_graph_list.append({ 'adj_matrix':adj_matrix, 'ops': ops,