diff --git a/graph_dit/datasets/dataset.py b/graph_dit/datasets/dataset.py index 8f21f7e..1adb5ec 100644 --- a/graph_dit/datasets/dataset.py +++ b/graph_dit/datasets/dataset.py @@ -674,7 +674,7 @@ class Dataset(InMemoryDataset): edge_index = torch.tensor(edges_list, dtype=torch.long).t() edge_type = torch.tensor(edge_type, dtype=torch.long) edge_attr = edge_type - y = torch.tensor([0], dtype=torch.float).view(1, -1) + y = torch.tensor([0, 0], dtype=torch.float).view(1, -1) data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i) return data graph_list = [] @@ -898,7 +898,7 @@ class Dataset_origin(InMemoryDataset): torch.save(self.collate(data_list), self.processed_paths[0]) def parse_architecture_string(arch_str): - print(arch_str) + # print(arch_str) steps = arch_str.split('+') nodes = ['input'] # Start with input node edges = []