need to sync with compute_meta
This commit is contained in:
		| @@ -2,7 +2,7 @@ | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 12, | ||||
|    "execution_count": 1, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
| @@ -11,7 +11,7 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 13, | ||||
|    "execution_count": 2, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
| @@ -20,7 +20,7 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 14, | ||||
|    "execution_count": 3, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
| @@ -15663,7 +15663,7 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 15, | ||||
|    "execution_count": 4, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
| @@ -15730,7 +15730,7 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 16, | ||||
|    "execution_count": 5, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
| @@ -15762,7 +15762,7 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 17, | ||||
|    "execution_count": 6, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
| @@ -15788,7 +15788,7 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 18, | ||||
|    "execution_count": 7, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
| @@ -15933,7 +15933,7 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 19, | ||||
|    "execution_count": 8, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
| @@ -16025,7 +16025,7 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 20, | ||||
|    "execution_count": 9, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
| @@ -16043,7 +16043,7 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 21, | ||||
|    "execution_count": 10, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
| @@ -16056,7 +16056,7 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 22, | ||||
|    "execution_count": 11, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
| @@ -16073,7 +16073,7 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 23, | ||||
|    "execution_count": 12, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
| @@ -16109,7 +16109,7 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 24, | ||||
|    "execution_count": 13, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
| @@ -16126,7 +16126,7 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 25, | ||||
|    "execution_count": 14, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
| @@ -16147,7 +16147,7 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 26, | ||||
|    "execution_count": 15, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
| @@ -16161,7 +16161,7 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 27, | ||||
|    "execution_count": 16, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
| @@ -16179,7 +16179,7 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 28, | ||||
|    "execution_count": 17, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
| @@ -16240,7 +16240,7 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 29, | ||||
|    "execution_count": 18, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
| @@ -16370,7 +16370,13 @@ | ||||
|       "|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|nor_conv_1x1~0|nor_conv_3x3~1|nor_conv_1x1~2|\n", | ||||
|       "|none~0|+|none~0|skip_connect~1|+|avg_pool_3x3~0|avg_pool_3x3~1|avg_pool_3x3~2|\n", | ||||
|       "|nor_conv_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_1x1~2|\n", | ||||
|       "|nor_conv_3x3~0|+|skip_connect~0|skip_connect~1|+|nor_conv_1x1~0|avg_pool_3x3~1|none~2|\n", | ||||
|       "|nor_conv_3x3~0|+|skip_connect~0|skip_connect~1|+|nor_conv_1x1~0|avg_pool_3x3~1|none~2|\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "|none~0|+|nor_conv_3x3~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|\n", | ||||
|       "|none~0|+|skip_connect~0|nor_conv_3x3~1|+|nor_conv_3x3~0|none~1|none~2|\n", | ||||
|       "|skip_connect~0|+|avg_pool_3x3~0|none~1|+|nor_conv_1x1~0|nor_conv_3x3~1|avg_pool_3x3~2|\n", | ||||
| @@ -31953,7 +31959,7 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 30, | ||||
|    "execution_count": 19, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
| @@ -31976,7 +31982,7 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 31, | ||||
|    "execution_count": 20, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
| @@ -31988,16 +31994,27 @@ | ||||
|     "    'skip_connect': 'P',   # Phosphorus for skip connection\n", | ||||
|     "    'none': 'S',           # Sulfur for no operation\n", | ||||
|     "    'output': 'He'         # Helium for output\n", | ||||
|     "}\n" | ||||
|     "}\n", | ||||
|     "\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 32, | ||||
|    "execution_count": 21, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "def graphs_to_json(graphs, filename):\n", | ||||
|     "    bonds = {\n", | ||||
|     "        'nor_conv_1x1': 1,\n", | ||||
|     "        'nor_conv_3x3': 2,\n", | ||||
|     "        'avg_pool_3x3': 3,\n", | ||||
|     "        'skip_connect': 4,\n", | ||||
|     "        'input': 0,\n", | ||||
|     "        'output': 5,\n", | ||||
|     "        'none': 6\n", | ||||
|     "    }\n", | ||||
|     "\n", | ||||
|     "    source_name = \"nas-bench-201\"\n", | ||||
|     "    num_graph = len(graphs)\n", | ||||
|     "    pt = Chem.GetPeriodicTable()\n", | ||||
| @@ -32020,7 +32037,7 @@ | ||||
|     "    for graph in graphs:\n", | ||||
|     "        ops = graph[1]\n", | ||||
|     "        n_atom = len(ops)\n", | ||||
|     "        n_bond = 1\n", | ||||
|     "        n_bond = len(ops)\n", | ||||
|     "        n_atom_list.append(n_atom)\n", | ||||
|     "        n_bond_list.append(n_bond)\n", | ||||
|     "\n", | ||||
| @@ -32097,7 +32114,7 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 33, | ||||
|    "execution_count": 22, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
| @@ -46742,7 +46759,7 @@ | ||||
|        "   [1.0, 0.0, 0.0, 0.0, 0.0]]]}" | ||||
|       ] | ||||
|      }, | ||||
|      "execution_count": 33, | ||||
|      "execution_count": 22, | ||||
|      "metadata": {}, | ||||
|      "output_type": "execute_result" | ||||
|     } | ||||
| @@ -46753,7 +46770,7 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 34, | ||||
|    "execution_count": 23, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
| @@ -46767,14 +46784,7 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 35, | ||||
|    "execution_count": 38, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
| @@ -46789,6 +46799,7 @@ | ||||
|     "class Dataset(InMemoryDataset):\n", | ||||
|     "    def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):\n", | ||||
|     "        self.target_prop = target_prop\n", | ||||
|     "        source = './NAS-Bench-201-v1_1-096897.pth'\n", | ||||
|     "        self.source = source\n", | ||||
|     "        self.api = API(source)  # Initialize NAS-Bench-201 API\n", | ||||
|     "        super().__init__(root, transform, pre_transform, pre_filter)\n", | ||||
| @@ -46878,24 +46889,16 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 36, | ||||
|    "execution_count": 39, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "try to create the NAS-Bench-201 api from ./NAS-Bench-201-v1_1-096897.pth\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "dataset = Dataset(source='./NAS-Bench-201-v1_1-096897.pth', root='./data')" | ||||
|     "# dataset = Dataset(source='./NAS-Bench-201-v1_1-096897.pth', root='./data')" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 38, | ||||
|    "execution_count": 40, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
| @@ -46924,7 +46927,11 @@ | ||||
|     "    def prepare_data(self) -> None:\n", | ||||
|     "        target = getattr(self.cfg.dataset, 'guidance_target', None)\n", | ||||
|     "        print(\"target\", target)\n", | ||||
|     "        base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]\n", | ||||
|     "        # try:\n", | ||||
|     "        #     base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]\n", | ||||
|     "        # except NameError:\n", | ||||
|     "        # base_path = pathlib.Path(os.getcwd()).parent[2]\n", | ||||
|     "        base_path = '/home/stud/hanzhang/Graph-Dit'\n", | ||||
|     "        root_path = os.path.join(base_path, self.datadir)\n", | ||||
|     "        self.root_path = root_path\n", | ||||
|     "\n", | ||||
| @@ -46935,12 +46942,13 @@ | ||||
|     "\n", | ||||
|     "        # Load the dataset to the memory\n", | ||||
|     "        # Dataset has target property, root path, and transform\n", | ||||
|     "        dataset = Dataset(source=self.task, root=root_path, target_prop=target, transform=None)\n", | ||||
|     "        source = './NAS-Bench-201-v1_1-096897.pth'\n", | ||||
|     "        dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None)\n", | ||||
|     "\n", | ||||
|     "        if len(self.task.split('-')) == 2:\n", | ||||
|     "            train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)\n", | ||||
|     "        else:\n", | ||||
|     "            train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset)\n", | ||||
|     "        # if len(self.task.split('-')) == 2:\n", | ||||
|     "        #     train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)\n", | ||||
|     "        # else:\n", | ||||
|     "        train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset)\n", | ||||
|     "\n", | ||||
|     "        self.train_index, self.val_index, self.test_index, self.unlabeled_index = train_index, val_index, test_index, unlabeled_index\n", | ||||
|     "        train_index, val_index, test_index, unlabeled_index = torch.LongTensor(train_index), torch.LongTensor(val_index), torch.LongTensor(test_index), torch.LongTensor(unlabeled_index)\n", | ||||
| @@ -47004,12 +47012,117 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "execution_count": 41, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "cfg = utils.load_config()\n", | ||||
|     "datamodule = DataModule(cfg=cfg)" | ||||
|     "from omegaconf import DictConfig, OmegaConf\n", | ||||
|     "import argparse\n", | ||||
|     "import hydra\n", | ||||
|     "\n", | ||||
|     "def parse_arg():\n", | ||||
|     "    parser = argparse.ArgumentParser(description='Diffusion')\n", | ||||
|     "    parser.add_argument('--config', type=str, default='config.yaml', help='config file')\n", | ||||
|     "    return parser.parse_args()\n", | ||||
|     "\n", | ||||
|     "def task1(cfg: DictConfig):\n", | ||||
|     "    datamodule = DataModule(cfg=cfg)\n", | ||||
|     "    datamodule.prepare_data()\n", | ||||
|     "\n", | ||||
|     "cfg = {\n", | ||||
|     "    'general':{\n", | ||||
|     "        'name': 'graph_dit',\n", | ||||
|     "        'wandb': 'disabled' ,\n", | ||||
|     "        'gpus': 1,\n", | ||||
|     "        'resume': 'null',\n", | ||||
|     "        'test_only': 'null',\n", | ||||
|     "        'sample_every_val': 2500,\n", | ||||
|     "        'samples_to_generate': 512,\n", | ||||
|     "        'samples_to_save': 3,\n", | ||||
|     "        'chains_to_save': 1,\n", | ||||
|     "        'log_every_steps': 50,\n", | ||||
|     "        'number_chain_steps': 8,\n", | ||||
|     "        'final_model_samples_to_generate': 10000,\n", | ||||
|     "        'final_model_samples_to_save': 20,\n", | ||||
|     "        'final_model_chains_to_save': 1,\n", | ||||
|     "        'enable_progress_bar': False,\n", | ||||
|     "        'save_model': True,\n", | ||||
|     "    },\n", | ||||
|     "    'model':{\n", | ||||
|     "            'type': 'discrete',\n", | ||||
|     "            'transition': 'marginal',\n", | ||||
|     "            'model': 'graph_dit',\n", | ||||
|     "            'diffusion_steps': 500,\n", | ||||
|     "            'diffusion_noise_schedule': 'cosine',\n", | ||||
|     "            'guide_scale': 2,\n", | ||||
|     "            'hidden_size': 1152,\n", | ||||
|     "            'depth': 6,\n", | ||||
|     "            'num_heads': 16,\n", | ||||
|     "            'mlp_ratio': 4,\n", | ||||
|     "            'drop_condition': 0.01,\n", | ||||
|     "            'lambda_train': [1, 10],  # node and edge training weight \n", | ||||
|     "            'ensure_connected': True,\n", | ||||
|     "    },\n", | ||||
|     "    'train':{\n", | ||||
|     "            'n_epochs': 10000,\n", | ||||
|     "            'batch_size': 1200,\n", | ||||
|     "            'lr': 0.0002,\n", | ||||
|     "            'clip_grad': 'null',\n", | ||||
|     "            'num_workers': 0,\n", | ||||
|     "            'weight_decay': 0,\n", | ||||
|     "            'seed': 0,\n", | ||||
|     "            'val_check_interval': 'null',\n", | ||||
|     "            'check_val_every_n_epoch': 1,\n", | ||||
|     "    },\n", | ||||
|     "    'dataset':{\n", | ||||
|     "            'datadir': 'data',\n", | ||||
|     "            'task_name': 'nasbench-201',\n", | ||||
|     "            'guidance_target': 'nasbench-201',\n", | ||||
|     "            'pin_memory': False,\n", | ||||
|     "    },\n", | ||||
|     "}\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 42, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "DataModule\n", | ||||
|       "task nasbench-201\n", | ||||
|       "datadir data\n", | ||||
|       "target nasbench-201\n", | ||||
|       "try to create the NAS-Bench-201 api from ./NAS-Bench-201-v1_1-096897.pth\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stderr", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "/home/stud/hanzhang/anaconda3/envs/graphdit/lib/python3.9/site-packages/torch_geometric/data/in_memory_dataset.py:157: UserWarning: It is not recommended to directly access the internal storage format `data` of an 'InMemoryDataset'. If you are absolutely certain what you are doing, access the internal storage via `InMemoryDataset._data` instead to suppress this warning. Alternatively, you can access stacked individual attributes of every graph via `dataset.{attr_name}`.\n", | ||||
|       "  warnings.warn(msg)\n", | ||||
|       "/home/stud/hanzhang/anaconda3/envs/graphdit/lib/python3.9/site-packages/torch_geometric/deprecation.py:22: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n", | ||||
|       "  warnings.warn(out)\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "nasbench-201  dataset len 15625 train len 9375 val len 3125 test len 3125 unlabeled len 0\n", | ||||
|       "train len 9375 val len 3125 test len 3125\n", | ||||
|       "train len 9375 val len 3125 test len 3125\n", | ||||
|       "dataset len 15625 train len 9375 val len 3125 test len 3125\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "cfg = OmegaConf.create(cfg)\n", | ||||
|     "task1(cfg)" | ||||
|    ] | ||||
|   } | ||||
|  ], | ||||
|   | ||||
		Reference in New Issue
	
	Block a user