5 Commits

Author SHA1 Message Date
mhz
a7f7010da7 write graph code for the absctract dataset 2024-06-26 23:42:01 +02:00
mhz
14186fa97f write test code 2024-06-26 23:41:37 +02:00
mhz
a222c514d9 add get_train_graphs 2024-06-26 22:42:06 +02:00
mhz
062a27b83f try update the api in DataInfo 2024-06-26 22:10:07 +02:00
mhz
0c7c525680 try update the api in DataInfo 2024-06-26 22:09:46 +02:00
5 changed files with 74 additions and 18 deletions

View File

@@ -118,6 +118,21 @@ class AbstractDatasetInfos:
example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=118).float()[:, self.active_index] example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=118).float()[:, self.active_index]
example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=10).float() example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=10).float()
self.input_dims = {'X': example_batch_x.size(1),
'E': example_batch_edge_attr.size(1),
'y': example_batch['y'].size(1)}
self.output_dims = {'X': example_batch_x.size(1),
'E': example_batch_edge_attr.size(1),
'y': example_batch['y'].size(1)}
print('input dims')
print(self.input_dims)
print('output dims')
print(self.output_dims)
def compute_graph_input_output_dims(self, datamodule):
example_batch = datamodule.example_batch()
example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=8).float()[:, self.active_index]
example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=2).float()
self.input_dims = {'X': example_batch_x.size(1), self.input_dims = {'X': example_batch_x.size(1),
'E': example_batch_edge_attr.size(1), 'E': example_batch_edge_attr.size(1),
'y': example_batch['y'].size(1)} 'y': example_batch['y'].size(1)}

View File

@@ -50,12 +50,12 @@ class DataModule(AbstractDataModule):
def prepare_data(self) -> None: def prepare_data(self) -> None:
target = getattr(self.cfg.dataset, 'guidance_target', None) target = getattr(self.cfg.dataset, 'guidance_target', None)
print("target", target) print("target", target) # nasbench-201
# try: # try:
# base_path = pathlib.Path(os.path.realpath(__file__)).parents[2] # base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
# except NameError: # except NameError:
# base_path = pathlib.Path(os.getcwd()).parent[2] # base_path = pathlib.Path(os.getcwd()).parent[2]
base_path = '/home/stud/hanzhang/Graph-Dit' base_path = '/home/stud/hanzhang/nasbenchDiT'
root_path = os.path.join(base_path, self.datadir) root_path = os.path.join(base_path, self.datadir)
self.root_path = root_path self.root_path = root_path
@@ -68,13 +68,16 @@ class DataModule(AbstractDataModule):
# Dataset has target property, root path, and transform # Dataset has target property, root path, and transform
source = './NAS-Bench-201-v1_1-096897.pth' source = './NAS-Bench-201-v1_1-096897.pth'
dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None) dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None)
self.dataset = dataset
self.api = dataset.api
# if len(self.task.split('-')) == 2: # if len(self.task.split('-')) == 2:
# train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset) # train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)
# else: # else:
train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset) train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset)
self.train_index, self.val_index, self.test_index, self.unlabeled_index = train_index, val_index, test_index, unlabeled_index self.train_index, self.val_index, self.test_index, self.unlabeled_index = (
train_index, val_index, test_index, unlabeled_index)
train_index, val_index, test_index, unlabeled_index = torch.LongTensor(train_index), torch.LongTensor(val_index), torch.LongTensor(test_index), torch.LongTensor(unlabeled_index) train_index, val_index, test_index, unlabeled_index = torch.LongTensor(train_index), torch.LongTensor(val_index), torch.LongTensor(test_index), torch.LongTensor(unlabeled_index)
if len(unlabeled_index) > 0: if len(unlabeled_index) > 0:
train_index = torch.cat([train_index, unlabeled_index], dim=0) train_index = torch.cat([train_index, unlabeled_index], dim=0)
@@ -175,6 +178,27 @@ class DataModule(AbstractDataModule):
smiles = Chem.MolToSmiles(mol) smiles = Chem.MolToSmiles(mol)
return smiles return smiles
def get_train_graphs(self):
train_graphs = []
test_graphs = []
for graph in self.train_dataset:
train_graphs.append(graph)
for graph in self.test_dataset:
test_graphs.append(graph)
return train_graphs, test_graphs
# def get_train_smiles(self):
# filename = f'{self.task}.csv.gz'
# df = pd.read_csv(f'{self.root_path}/raw/{filename}')
# df_test = df.iloc[self.test_index]
# df = df.iloc[self.train_index]
# smiles_list = df['smiles'].tolist()
# smiles_list_test = df_test['smiles'].tolist()
# smiles_list = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list]
# smiles_list_test = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list_test]
# return smiles_list, smiles_list_test
def get_train_smiles(self): def get_train_smiles(self):
train_smiles = [] train_smiles = []
test_smiles = [] test_smiles = []
@@ -477,14 +501,17 @@ def graphs_to_json(graphs, filename):
class Dataset(InMemoryDataset): class Dataset(InMemoryDataset):
def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None): def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
self.target_prop = target_prop self.target_prop = target_prop
source = '/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' source = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
self.source = source self.source = source
super().__init__(root, transform, pre_transform, pre_filter)
print(self.processed_paths[0]) #/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth.pt
self.api = API(source) # Initialize NAS-Bench-201 API self.api = API(source) # Initialize NAS-Bench-201 API
print('API loaded') print('API loaded')
super().__init__(root, transform, pre_transform, pre_filter)
print('Dataset initialized') print('Dataset initialized')
print(self.processed_paths[0])
self.data, self.slices = torch.load(self.processed_paths[0]) self.data, self.slices = torch.load(self.processed_paths[0])
self.data.edge_attr = self.data.edge_attr.squeeze()
self.data.idx = torch.arange(len(self.data.y))
print(f"self.data={self.data}, self.slices={self.slices}")
@property @property
def raw_file_names(self): def raw_file_names(self):
@@ -676,7 +703,7 @@ def create_adj_matrix_and_ops(nodes, edges):
adj_matrix[src][dst] = 1 adj_matrix[src][dst] = 1
return adj_matrix, nodes return adj_matrix, nodes
class DataInfos(AbstractDatasetInfos): class DataInfos(AbstractDatasetInfos):
def __init__(self, datamodule, cfg): def __init__(self, datamodule, cfg, dataset):
tasktype_dict = { tasktype_dict = {
'hiv_b': 'classification', 'hiv_b': 'classification',
'bace_b': 'classification', 'bace_b': 'classification',
@@ -689,6 +716,7 @@ class DataInfos(AbstractDatasetInfos):
self.task = task_name self.task = task_name
self.task_type = tasktype_dict.get(task_name, "regression") self.task_type = tasktype_dict.get(task_name, "regression")
self.ensure_connected = cfg.model.ensure_connected self.ensure_connected = cfg.model.ensure_connected
self.api = dataset.api
datadir = cfg.dataset.datadir datadir = cfg.dataset.datadir
@@ -699,9 +727,9 @@ class DataInfos(AbstractDatasetInfos):
length = 15625 length = 15625
ops_type = {} ops_type = {}
len_ops = set() len_ops = set()
api = API('/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth') # api = API('/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')
for i in range(length): for i in range(length):
arch_info = api.query_meta_info_by_index(i) arch_info = self.api.query_meta_info_by_index(i)
nodes, edges = parse_architecture_string(arch_info.arch_str) nodes, edges = parse_architecture_string(arch_info.arch_str)
adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges) adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)
if i < 5: if i < 5:
@@ -929,4 +957,4 @@ def compute_meta(root, source_name, train_index, test_index):
if __name__ == "__main__": if __name__ == "__main__":
pass dataset = Dataset(source='nasbench', root='/home/stud/hanzhang/nasbenchDiT/graph-dit', target_prop='Class', transform=None)

View File

@@ -179,9 +179,9 @@ class Graph_DiT(pl.LightningModule):
@torch.no_grad() @torch.no_grad()
def validation_step(self, data, i): def validation_step(self, data, i):
data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index] data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
data_edge_attr = F.one_hot(data.edge_attr, num_classes=10).float() data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float()
dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes) dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
dense_data = dense_data.mask(node_mask, collapse=False) dense_data = dense_data.mask(node_mask)
noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask) noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
pred = self.forward(noisy_data) pred = self.forward(noisy_data)
nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=False) nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=False)
@@ -444,11 +444,9 @@ class Graph_DiT(pl.LightningModule):
beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1) beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1)
alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1) alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1)
alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1) alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1)
print(f"alpha_t_bar.shape {alpha_t_bar.shape}")
Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device) # (bs, dx_in, dx_out), (bs, de_in, de_out) Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device) # (bs, dx_in, dx_out), (bs, de_in, de_out)
print(f"E.shape {E.shape}")
print(f"X.shape {X.shape}")
bs, n, d = X.shape bs, n, d = X.shape
X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1) X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
prob_all = X_all @ Qtb.X prob_all = X_all @ Qtb.X

View File

@@ -78,16 +78,20 @@ def main(cfg: DictConfig):
datamodule = dataset.DataModule(cfg) datamodule = dataset.DataModule(cfg)
datamodule.prepare_data() datamodule.prepare_data()
dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg) dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset)
# train_smiles, reference_smiles = datamodule.get_train_smiles() # train_smiles, reference_smiles = datamodule.get_train_smiles()
train_graphs, reference_graphs = datamodule.get_train_graphs()
# get input output dimensions # get input output dimensions
dataset_infos.compute_input_output_dims(datamodule=datamodule) dataset_infos.compute_input_output_dims(datamodule=datamodule)
# train_metrics = TrainMolecularMetricsDiscrete(dataset_infos) train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
# sampling_metrics = SamplingMolecularMetrics( # sampling_metrics = SamplingMolecularMetrics(
# dataset_infos, train_smiles, reference_smiles # dataset_infos, train_smiles, reference_smiles
# ) # )
sampling_metrics = SamplingGraphMetrics(
dataset_infos, train_graphs, reference_graphs
)
visualization_tools = MolecularVisualization(dataset_infos) visualization_tools = MolecularVisualization(dataset_infos)
model_kwargs = { model_kwargs = {
@@ -135,5 +139,16 @@ def main(cfg: DictConfig):
else: else:
trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only) trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only)
@hydra.main(
version_base="1.1", config_path="../configs", config_name="config"
)
def test(cfg: DictConfig):
datamodule = dataset.DataModule(cfg)
datamodule.prepare_data()
dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset)
train_graphs, reference_graphs = datamodule.get_train_graphs()
dataset_infos.compute_input_output_dims(datamodule=datamodule)
if __name__ == "__main__": if __name__ == "__main__":
main() test()

0
graph_dit/workingdoc.md Normal file
View File