need to sync with compute_meta

This commit is contained in:
Hanzhang Ma 2024-06-11 00:08:23 +02:00
parent 7831979db7
commit b6b7210ba7

@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@ -20,7 +20,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 3,
"metadata": {},
"outputs": [
{
@ -15663,7 +15663,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 4,
"metadata": {},
"outputs": [
{
@ -15730,7 +15730,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 5,
"metadata": {},
"outputs": [
{
@ -15762,7 +15762,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@ -15788,7 +15788,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@ -15933,7 +15933,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@ -16025,7 +16025,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@ -16043,7 +16043,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@ -16056,7 +16056,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 11,
"metadata": {},
"outputs": [
{
@ -16073,7 +16073,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 12,
"metadata": {},
"outputs": [
{
@ -16109,7 +16109,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@ -16126,7 +16126,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@ -16147,7 +16147,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
@ -16161,7 +16161,7 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 16,
"metadata": {},
"outputs": [
{
@ -16179,7 +16179,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 17,
"metadata": {},
"outputs": [
{
@ -16240,7 +16240,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 18,
"metadata": {},
"outputs": [
{
@ -16370,7 +16370,13 @@
"|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|nor_conv_1x1~0|nor_conv_3x3~1|nor_conv_1x1~2|\n",
"|none~0|+|none~0|skip_connect~1|+|avg_pool_3x3~0|avg_pool_3x3~1|avg_pool_3x3~2|\n",
"|nor_conv_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_1x1~2|\n",
"|nor_conv_3x3~0|+|skip_connect~0|skip_connect~1|+|nor_conv_1x1~0|avg_pool_3x3~1|none~2|\n",
"|nor_conv_3x3~0|+|skip_connect~0|skip_connect~1|+|nor_conv_1x1~0|avg_pool_3x3~1|none~2|\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"|none~0|+|nor_conv_3x3~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|\n",
"|none~0|+|skip_connect~0|nor_conv_3x3~1|+|nor_conv_3x3~0|none~1|none~2|\n",
"|skip_connect~0|+|avg_pool_3x3~0|none~1|+|nor_conv_1x1~0|nor_conv_3x3~1|avg_pool_3x3~2|\n",
@ -31953,7 +31959,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 19,
"metadata": {},
"outputs": [
{
@ -31976,7 +31982,7 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
@ -31988,16 +31994,27 @@
" 'skip_connect': 'P', # Phosphorus for skip connection\n",
" 'none': 'S', # Sulfur for no operation\n",
" 'output': 'He' # Helium for output\n",
"}\n"
"}\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"def graphs_to_json(graphs, filename):\n",
" bonds = {\n",
" 'nor_conv_1x1': 1,\n",
" 'nor_conv_3x3': 2,\n",
" 'avg_pool_3x3': 3,\n",
" 'skip_connect': 4,\n",
" 'input': 0,\n",
" 'output': 5,\n",
" 'none': 6\n",
" }\n",
"\n",
" source_name = \"nas-bench-201\"\n",
" num_graph = len(graphs)\n",
" pt = Chem.GetPeriodicTable()\n",
@ -32020,7 +32037,7 @@
" for graph in graphs:\n",
" ops = graph[1]\n",
" n_atom = len(ops)\n",
" n_bond = 1\n",
" n_bond = len(ops)\n",
" n_atom_list.append(n_atom)\n",
" n_bond_list.append(n_bond)\n",
"\n",
@ -32097,7 +32114,7 @@
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 22,
"metadata": {},
"outputs": [
{
@ -46742,7 +46759,7 @@
" [1.0, 0.0, 0.0, 0.0, 0.0]]]}"
]
},
"execution_count": 33,
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
@ -46753,7 +46770,7 @@
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
@ -46767,14 +46784,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
@ -46789,6 +46799,7 @@
"class Dataset(InMemoryDataset):\n",
" def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):\n",
" self.target_prop = target_prop\n",
" source = './NAS-Bench-201-v1_1-096897.pth'\n",
" self.source = source\n",
" self.api = API(source) # Initialize NAS-Bench-201 API\n",
" super().__init__(root, transform, pre_transform, pre_filter)\n",
@ -46878,24 +46889,16 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"try to create the NAS-Bench-201 api from ./NAS-Bench-201-v1_1-096897.pth\n"
]
}
],
"outputs": [],
"source": [
"dataset = Dataset(source='./NAS-Bench-201-v1_1-096897.pth', root='./data')"
"# dataset = Dataset(source='./NAS-Bench-201-v1_1-096897.pth', root='./data')"
]
},
{
"cell_type": "code",
"execution_count": 38,
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
@ -46924,7 +46927,11 @@
" def prepare_data(self) -> None:\n",
" target = getattr(self.cfg.dataset, 'guidance_target', None)\n",
" print(\"target\", target)\n",
" base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]\n",
" # try:\n",
" # base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]\n",
" # except NameError:\n",
" # base_path = pathlib.Path(os.getcwd()).parent[2]\n",
" base_path = '/home/stud/hanzhang/Graph-Dit'\n",
" root_path = os.path.join(base_path, self.datadir)\n",
" self.root_path = root_path\n",
"\n",
@ -46935,12 +46942,13 @@
"\n",
" # Load the dataset to the memory\n",
" # Dataset has target property, root path, and transform\n",
" dataset = Dataset(source=self.task, root=root_path, target_prop=target, transform=None)\n",
" source = './NAS-Bench-201-v1_1-096897.pth'\n",
" dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None)\n",
"\n",
" if len(self.task.split('-')) == 2:\n",
" train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)\n",
" else:\n",
" train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset)\n",
" # if len(self.task.split('-')) == 2:\n",
" # train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)\n",
" # else:\n",
" train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset)\n",
"\n",
" self.train_index, self.val_index, self.test_index, self.unlabeled_index = train_index, val_index, test_index, unlabeled_index\n",
" train_index, val_index, test_index, unlabeled_index = torch.LongTensor(train_index), torch.LongTensor(val_index), torch.LongTensor(test_index), torch.LongTensor(unlabeled_index)\n",
@ -47004,12 +47012,117 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"cfg = utils.load_config()\n",
"datamodule = DataModule(cfg=cfg)"
"from omegaconf import DictConfig, OmegaConf\n",
"import argparse\n",
"import hydra\n",
"\n",
"def parse_arg():\n",
" parser = argparse.ArgumentParser(description='Diffusion')\n",
" parser.add_argument('--config', type=str, default='config.yaml', help='config file')\n",
" return parser.parse_args()\n",
"\n",
"def task1(cfg: DictConfig):\n",
" datamodule = DataModule(cfg=cfg)\n",
" datamodule.prepare_data()\n",
"\n",
"cfg = {\n",
" 'general':{\n",
" 'name': 'graph_dit',\n",
" 'wandb': 'disabled' ,\n",
" 'gpus': 1,\n",
" 'resume': 'null',\n",
" 'test_only': 'null',\n",
" 'sample_every_val': 2500,\n",
" 'samples_to_generate': 512,\n",
" 'samples_to_save': 3,\n",
" 'chains_to_save': 1,\n",
" 'log_every_steps': 50,\n",
" 'number_chain_steps': 8,\n",
" 'final_model_samples_to_generate': 10000,\n",
" 'final_model_samples_to_save': 20,\n",
" 'final_model_chains_to_save': 1,\n",
" 'enable_progress_bar': False,\n",
" 'save_model': True,\n",
" },\n",
" 'model':{\n",
" 'type': 'discrete',\n",
" 'transition': 'marginal',\n",
" 'model': 'graph_dit',\n",
" 'diffusion_steps': 500,\n",
" 'diffusion_noise_schedule': 'cosine',\n",
" 'guide_scale': 2,\n",
" 'hidden_size': 1152,\n",
" 'depth': 6,\n",
" 'num_heads': 16,\n",
" 'mlp_ratio': 4,\n",
" 'drop_condition': 0.01,\n",
" 'lambda_train': [1, 10], # node and edge training weight \n",
" 'ensure_connected': True,\n",
" },\n",
" 'train':{\n",
" 'n_epochs': 10000,\n",
" 'batch_size': 1200,\n",
" 'lr': 0.0002,\n",
" 'clip_grad': 'null',\n",
" 'num_workers': 0,\n",
" 'weight_decay': 0,\n",
" 'seed': 0,\n",
" 'val_check_interval': 'null',\n",
" 'check_val_every_n_epoch': 1,\n",
" },\n",
" 'dataset':{\n",
" 'datadir': 'data',\n",
" 'task_name': 'nasbench-201',\n",
" 'guidance_target': 'nasbench-201',\n",
" 'pin_memory': False,\n",
" },\n",
"}\n"
]
},
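The inline cfg dict above mirrors the Hydra YAML config as a plain nested dict; OmegaConf.create (used in the next cell) turns it into a DictConfig with attribute access, which is what DataModule expects. A quick illustrative check, assuming cfg is the dict defined above:

from omegaconf import OmegaConf

cfg_oc = OmegaConf.create(cfg)
print(cfg_oc.dataset.task_name)         # nasbench-201
print(cfg_oc.model.diffusion_steps)     # 500
print(OmegaConf.to_yaml(cfg_oc.train))  # dump one section back to YAML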
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DataModule\n",
"task nasbench-201\n",
"datadir data\n",
"target nasbench-201\n",
"try to create the NAS-Bench-201 api from ./NAS-Bench-201-v1_1-096897.pth\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/stud/hanzhang/anaconda3/envs/graphdit/lib/python3.9/site-packages/torch_geometric/data/in_memory_dataset.py:157: UserWarning: It is not recommended to directly access the internal storage format `data` of an 'InMemoryDataset'. If you are absolutely certain what you are doing, access the internal storage via `InMemoryDataset._data` instead to suppress this warning. Alternatively, you can access stacked individual attributes of every graph via `dataset.{attr_name}`.\n",
" warnings.warn(msg)\n",
"/home/stud/hanzhang/anaconda3/envs/graphdit/lib/python3.9/site-packages/torch_geometric/deprecation.py:22: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n",
" warnings.warn(out)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"nasbench-201 dataset len 15625 train len 9375 val len 3125 test len 3125 unlabeled len 0\n",
"train len 9375 val len 3125 test len 3125\n",
"train len 9375 val len 3125 test len 3125\n",
"dataset len 15625 train len 9375 val len 3125 test len 3125\n"
]
}
],
"source": [
"cfg = OmegaConf.create(cfg)\n",
"task1(cfg)"
]
}
],