need to sync with compute_meta

This commit is contained in:
Hanzhang Ma 2024-06-11 00:08:23 +02:00
parent 7831979db7
commit b6b7210ba7

View File

@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@ -20,7 +20,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 3,
"metadata": {},
"outputs": [
{
@ -15663,7 +15663,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 4,
"metadata": {},
"outputs": [
{
@ -15730,7 +15730,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 5,
"metadata": {},
"outputs": [
{
@ -15762,7 +15762,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@ -15788,7 +15788,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@ -15933,7 +15933,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@ -16025,7 +16025,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@ -16043,7 +16043,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@ -16056,7 +16056,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 11,
"metadata": {},
"outputs": [
{
@ -16073,7 +16073,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 12,
"metadata": {},
"outputs": [
{
@ -16109,7 +16109,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@ -16126,7 +16126,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@ -16147,7 +16147,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
@ -16161,7 +16161,7 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 16,
"metadata": {},
"outputs": [
{
@ -16179,7 +16179,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 17,
"metadata": {},
"outputs": [
{
@ -16240,7 +16240,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 18,
"metadata": {},
"outputs": [
{
@ -16370,7 +16370,13 @@
"|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|nor_conv_1x1~0|nor_conv_3x3~1|nor_conv_1x1~2|\n",
"|none~0|+|none~0|skip_connect~1|+|avg_pool_3x3~0|avg_pool_3x3~1|avg_pool_3x3~2|\n",
"|nor_conv_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_1x1~2|\n",
"|nor_conv_3x3~0|+|skip_connect~0|skip_connect~1|+|nor_conv_1x1~0|avg_pool_3x3~1|none~2|\n",
"|nor_conv_3x3~0|+|skip_connect~0|skip_connect~1|+|nor_conv_1x1~0|avg_pool_3x3~1|none~2|\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"|none~0|+|nor_conv_3x3~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|\n",
"|none~0|+|skip_connect~0|nor_conv_3x3~1|+|nor_conv_3x3~0|none~1|none~2|\n",
"|skip_connect~0|+|avg_pool_3x3~0|none~1|+|nor_conv_1x1~0|nor_conv_3x3~1|avg_pool_3x3~2|\n",
@ -31953,7 +31959,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 19,
"metadata": {},
"outputs": [
{
@ -31976,7 +31982,7 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
@ -31988,16 +31994,27 @@
" 'skip_connect': 'P', # Phosphorus for skip connection\n",
" 'none': 'S', # Sulfur for no operation\n",
" 'output': 'He' # Helium for output\n",
"}\n"
"}\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"def graphs_to_json(graphs, filename):\n",
" bonds = {\n",
" 'nor_conv_1x1': 1,\n",
" 'nor_conv_3x3': 2,\n",
" 'avg_pool_3x3': 3,\n",
" 'skip_connect': 4,\n",
" 'input': 0,\n",
" 'output': 5,\n",
" 'none': 6\n",
" }\n",
"\n",
" source_name = \"nas-bench-201\"\n",
" num_graph = len(graphs)\n",
" pt = Chem.GetPeriodicTable()\n",
@ -32020,7 +32037,7 @@
" for graph in graphs:\n",
" ops = graph[1]\n",
" n_atom = len(ops)\n",
" n_bond = 1\n",
" n_bond = len(ops)\n",
" n_atom_list.append(n_atom)\n",
" n_bond_list.append(n_bond)\n",
"\n",
@ -32097,7 +32114,7 @@
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 22,
"metadata": {},
"outputs": [
{
@ -46742,7 +46759,7 @@
" [1.0, 0.0, 0.0, 0.0, 0.0]]]}"
]
},
"execution_count": 33,
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
@ -46753,7 +46770,7 @@
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
@ -46767,14 +46784,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
@ -46789,6 +46799,7 @@
"class Dataset(InMemoryDataset):\n",
" def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):\n",
" self.target_prop = target_prop\n",
" source = './NAS-Bench-201-v1_1-096897.pth'\n",
" self.source = source\n",
" self.api = API(source) # Initialize NAS-Bench-201 API\n",
" super().__init__(root, transform, pre_transform, pre_filter)\n",
@ -46878,24 +46889,16 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"try to create the NAS-Bench-201 api from ./NAS-Bench-201-v1_1-096897.pth\n"
]
}
],
"outputs": [],
"source": [
"dataset = Dataset(source='./NAS-Bench-201-v1_1-096897.pth', root='./data')"
"# dataset = Dataset(source='./NAS-Bench-201-v1_1-096897.pth', root='./data')"
]
},
{
"cell_type": "code",
"execution_count": 38,
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
@ -46924,7 +46927,11 @@
" def prepare_data(self) -> None:\n",
" target = getattr(self.cfg.dataset, 'guidance_target', None)\n",
" print(\"target\", target)\n",
" base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]\n",
" # try:\n",
" # base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]\n",
" # except NameError:\n",
" # base_path = pathlib.Path(os.getcwd()).parent[2]\n",
" base_path = '/home/stud/hanzhang/Graph-Dit'\n",
" root_path = os.path.join(base_path, self.datadir)\n",
" self.root_path = root_path\n",
"\n",
@ -46935,11 +46942,12 @@
"\n",
" # Load the dataset to the memory\n",
" # Dataset has target property, root path, and transform\n",
" dataset = Dataset(source=self.task, root=root_path, target_prop=target, transform=None)\n",
" source = './NAS-Bench-201-v1_1-096897.pth'\n",
" dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None)\n",
"\n",
" if len(self.task.split('-')) == 2:\n",
" train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)\n",
" else:\n",
" # if len(self.task.split('-')) == 2:\n",
" # train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)\n",
" # else:\n",
" train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset)\n",
"\n",
" self.train_index, self.val_index, self.test_index, self.unlabeled_index = train_index, val_index, test_index, unlabeled_index\n",
@ -47004,12 +47012,117 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"cfg = utils.load_config()\n",
"datamodule = DataModule(cfg=cfg)"
"from omegaconf import DictConfig, OmegaConf\n",
"import argparse\n",
"import hydra\n",
"\n",
"def parse_arg(argv=None):\n",
"    \"\"\"Parse the command-line options of the diffusion run.\n",
"\n",
"    Args:\n",
"        argv: Optional list of argument strings. ``None`` makes argparse\n",
"            fall back to ``sys.argv[1:]``.\n",
"\n",
"    Returns:\n",
"        argparse.Namespace with a ``config`` attribute (default 'config.yaml').\n",
"    \"\"\"\n",
"    parser = argparse.ArgumentParser(description='Diffusion')\n",
"    parser.add_argument('--config', type=str, default='config.yaml', help='config file')\n",
"    # A Jupyter kernel is started with extra flags (e.g. `-f <connection file>`)\n",
"    # that parse_args() rejects with SystemExit; ignore unknown options instead.\n",
"    known_args, _unknown = parser.parse_known_args(argv)\n",
"    return known_args\n",
"\n",
"def task1(cfg: DictConfig):\n",
"    \"\"\"Build a DataModule from ``cfg`` and run its data preparation step.\"\"\"\n",
"    DataModule(cfg=cfg).prepare_data()\n",
"\n",
"# Run configuration as a plain dict, grouped into general / model / train /\n",
"# dataset sections.\n",
"# Unset options are None rather than the string 'null': OmegaConf.create keeps\n",
"# a 'null' string as a non-empty str, so `is None` / truthiness checks on\n",
"# resume, test_only, clip_grad and val_check_interval would see them as set.\n",
"cfg = {\n",
"    'general': {\n",
"        'name': 'graph_dit',\n",
"        'wandb': 'disabled',\n",
"        'gpus': 1,\n",
"        'resume': None,\n",
"        'test_only': None,\n",
"        'sample_every_val': 2500,\n",
"        'samples_to_generate': 512,\n",
"        'samples_to_save': 3,\n",
"        'chains_to_save': 1,\n",
"        'log_every_steps': 50,\n",
"        'number_chain_steps': 8,\n",
"        'final_model_samples_to_generate': 10000,\n",
"        'final_model_samples_to_save': 20,\n",
"        'final_model_chains_to_save': 1,\n",
"        'enable_progress_bar': False,\n",
"        'save_model': True,\n",
"    },\n",
"    'model': {\n",
"        'type': 'discrete',\n",
"        'transition': 'marginal',\n",
"        'model': 'graph_dit',\n",
"        'diffusion_steps': 500,\n",
"        'diffusion_noise_schedule': 'cosine',\n",
"        'guide_scale': 2,\n",
"        'hidden_size': 1152,\n",
"        'depth': 6,\n",
"        'num_heads': 16,\n",
"        'mlp_ratio': 4,\n",
"        'drop_condition': 0.01,\n",
"        'lambda_train': [1, 10],  # node and edge training weight\n",
"        'ensure_connected': True,\n",
"    },\n",
"    'train': {\n",
"        'n_epochs': 10000,\n",
"        'batch_size': 1200,\n",
"        'lr': 0.0002,\n",
"        'clip_grad': None,\n",
"        'num_workers': 0,\n",
"        'weight_decay': 0,\n",
"        'seed': 0,\n",
"        'val_check_interval': None,\n",
"        'check_val_every_n_epoch': 1,\n",
"    },\n",
"    'dataset': {\n",
"        'datadir': 'data',\n",
"        'task_name': 'nasbench-201',\n",
"        'guidance_target': 'nasbench-201',\n",
"        'pin_memory': False,\n",
"    },\n",
"}\n"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DataModule\n",
"task nasbench-201\n",
"datadir data\n",
"target nasbench-201\n",
"try to create the NAS-Bench-201 api from ./NAS-Bench-201-v1_1-096897.pth\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/stud/hanzhang/anaconda3/envs/graphdit/lib/python3.9/site-packages/torch_geometric/data/in_memory_dataset.py:157: UserWarning: It is not recommended to directly access the internal storage format `data` of an 'InMemoryDataset'. If you are absolutely certain what you are doing, access the internal storage via `InMemoryDataset._data` instead to suppress this warning. Alternatively, you can access stacked individual attributes of every graph via `dataset.{attr_name}`.\n",
" warnings.warn(msg)\n",
"/home/stud/hanzhang/anaconda3/envs/graphdit/lib/python3.9/site-packages/torch_geometric/deprecation.py:22: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n",
" warnings.warn(out)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"nasbench-201 dataset len 15625 train len 9375 val len 3125 test len 3125 unlabeled len 0\n",
"train len 9375 val len 3125 test len 3125\n",
"train len 9375 val len 3125 test len 3125\n",
"dataset len 15625 train len 9375 val len 3125 test len 3125\n"
]
}
],
"source": [
"cfg = OmegaConf.create(cfg)\n",
"task1(cfg)"
]
}
],