From b6b7210ba7e4d5fb8232af7b2195bf09ecafc765 Mon Sep 17 00:00:00 2001 From: Hanzhang Ma Date: Tue, 11 Jun 2024 00:08:23 +0200 Subject: [PATCH] need to sync with compute_meta --- graph_dit/test_nasbench.ipynb | 225 +++++++++++++++++++++++++--------- 1 file changed, 169 insertions(+), 56 deletions(-) diff --git a/graph_dit/test_nasbench.ipynb b/graph_dit/test_nasbench.ipynb index 46fdb9c..4f131a4 100644 --- a/graph_dit/test_nasbench.ipynb +++ b/graph_dit/test_nasbench.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 12, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -15663,7 +15663,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -15730,7 +15730,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -15762,7 +15762,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -15788,7 +15788,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -15933,7 +15933,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -16025,7 +16025,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -16043,7 +16043,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -16056,7 +16056,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -16073,7 +16073,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -16109,7 +16109,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -16126,7 +16126,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -16147,7 +16147,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -16161,7 +16161,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -16179,7 +16179,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -16240,7 +16240,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -16370,7 +16370,13 @@ "|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|nor_conv_1x1~0|nor_conv_3x3~1|nor_conv_1x1~2|\n", "|none~0|+|none~0|skip_connect~1|+|avg_pool_3x3~0|avg_pool_3x3~1|avg_pool_3x3~2|\n", "|nor_conv_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_1x1~2|\n", - "|nor_conv_3x3~0|+|skip_connect~0|skip_connect~1|+|nor_conv_1x1~0|avg_pool_3x3~1|none~2|\n", + "|nor_conv_3x3~0|+|skip_connect~0|skip_connect~1|+|nor_conv_1x1~0|avg_pool_3x3~1|none~2|\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "|none~0|+|nor_conv_3x3~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|\n", "|none~0|+|skip_connect~0|nor_conv_3x3~1|+|nor_conv_3x3~0|none~1|none~2|\n", "|skip_connect~0|+|avg_pool_3x3~0|none~1|+|nor_conv_1x1~0|nor_conv_3x3~1|avg_pool_3x3~2|\n", @@ -31953,7 +31959,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -31976,7 +31982,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -31988,16 +31994,27 @@ " 'skip_connect': 'P', # Phosphorus for skip connection\n", " 'none': 'S', # Sulfur for no operation\n", " 'output': 'He' # Helium for output\n", - "}\n" + "}\n", + "\n" ] }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "def graphs_to_json(graphs, filename):\n", + " bonds = {\n", + " 'nor_conv_1x1': 1,\n", + " 'nor_conv_3x3': 2,\n", + " 'avg_pool_3x3': 3,\n", + " 'skip_connect': 4,\n", + " 'input': 0,\n", + " 'output': 5,\n", + " 'none': 6\n", + " }\n", + "\n", " source_name = \"nas-bench-201\"\n", " num_graph = len(graphs)\n", " pt = Chem.GetPeriodicTable()\n", @@ -32020,7 +32037,7 @@ " for graph in graphs:\n", " ops = graph[1]\n", " n_atom = len(ops)\n", - " n_bond = 1\n", + " n_bond = len(ops)\n", " n_atom_list.append(n_atom)\n", " n_bond_list.append(n_bond)\n", "\n", @@ -32097,7 +32114,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -46742,7 +46759,7 @@ " [1.0, 0.0, 0.0, 0.0, 0.0]]]}" ] }, - "execution_count": 33, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -46753,7 +46770,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -46767,14 +46784,7 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 35, + "execution_count": 38, "metadata": {}, "outputs": [], "source": [ @@ -46789,6 +46799,7 @@ "class Dataset(InMemoryDataset):\n", " def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):\n", " self.target_prop = target_prop\n", + " source = './NAS-Bench-201-v1_1-096897.pth'\n", " self.source = source\n", " self.api = API(source) # Initialize NAS-Bench-201 API\n", " super().__init__(root, transform, pre_transform, pre_filter)\n", @@ -46878,24 +46889,16 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 39, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "try to create the NAS-Bench-201 api from ./NAS-Bench-201-v1_1-096897.pth\n" - ] - } - ], + "outputs": [], "source": [ - "dataset = Dataset(source='./NAS-Bench-201-v1_1-096897.pth', root='./data')" + "# dataset = Dataset(source='./NAS-Bench-201-v1_1-096897.pth', root='./data')" ] }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 40, "metadata": {}, "outputs": [], "source": [ @@ -46924,7 +46927,11 @@ " def prepare_data(self) -> None:\n", " target = getattr(self.cfg.dataset, 'guidance_target', None)\n", " print(\"target\", target)\n", - " base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]\n", + " # try:\n", + " # base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]\n", + " # except NameError:\n", + " # base_path = pathlib.Path(os.getcwd()).parent[2]\n", + " base_path = '/home/stud/hanzhang/Graph-Dit'\n", " root_path = os.path.join(base_path, self.datadir)\n", " self.root_path = root_path\n", "\n", @@ -46935,12 +46942,13 @@ "\n", " # Load the dataset to the memory\n", " # Dataset has target property, root path, and transform\n", - " dataset = Dataset(source=self.task, root=root_path, target_prop=target, transform=None)\n", + " source = './NAS-Bench-201-v1_1-096897.pth'\n", + " dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None)\n", "\n", - " if len(self.task.split('-')) == 2:\n", - " train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)\n", - " else:\n", - " train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset)\n", + " # if len(self.task.split('-')) == 2:\n", + " # train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)\n", + " # else:\n", + " train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset)\n", "\n", " self.train_index, self.val_index, self.test_index, self.unlabeled_index = train_index, val_index, test_index, unlabeled_index\n", " train_index, val_index, test_index, unlabeled_index = torch.LongTensor(train_index), torch.LongTensor(val_index), torch.LongTensor(test_index), torch.LongTensor(unlabeled_index)\n", @@ -47004,12 +47012,117 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 41, "metadata": {}, "outputs": [], "source": [ - "cfg = utils.load_config()\n", - "datamodule = DataModule(cfg=cfg)" + "from omegaconf import DictConfig, OmegaConf\n", + "import argparse\n", + "import hydra\n", + "\n", + "def parse_arg():\n", + " parser = argparse.ArgumentParser(description='Diffusion')\n", + " parser.add_argument('--config', type=str, default='config.yaml', help='config file')\n", + " return parser.parse_args()\n", + "\n", + "def task1(cfg: DictConfig):\n", + " datamodule = DataModule(cfg=cfg)\n", + " datamodule.prepare_data()\n", + "\n", + "cfg = {\n", + " 'general':{\n", + " 'name': 'graph_dit',\n", + " 'wandb': 'disabled' ,\n", + " 'gpus': 1,\n", + " 'resume': 'null',\n", + " 'test_only': 'null',\n", + " 'sample_every_val': 2500,\n", + " 'samples_to_generate': 512,\n", + " 'samples_to_save': 3,\n", + " 'chains_to_save': 1,\n", + " 'log_every_steps': 50,\n", + " 'number_chain_steps': 8,\n", + " 'final_model_samples_to_generate': 10000,\n", + " 'final_model_samples_to_save': 20,\n", + " 'final_model_chains_to_save': 1,\n", + " 'enable_progress_bar': False,\n", + " 'save_model': True,\n", + " },\n", + " 'model':{\n", + " 'type': 'discrete',\n", + " 'transition': 'marginal',\n", + " 'model': 'graph_dit',\n", + " 'diffusion_steps': 500,\n", + " 'diffusion_noise_schedule': 'cosine',\n", + " 'guide_scale': 2,\n", + " 'hidden_size': 1152,\n", + " 'depth': 6,\n", + " 'num_heads': 16,\n", + " 'mlp_ratio': 4,\n", + " 'drop_condition': 0.01,\n", + " 'lambda_train': [1, 10], # node and edge training weight \n", + " 'ensure_connected': True,\n", + " },\n", + " 'train':{\n", + " 'n_epochs': 10000,\n", + " 'batch_size': 1200,\n", + " 'lr': 0.0002,\n", + " 'clip_grad': 'null',\n", + " 'num_workers': 0,\n", + " 'weight_decay': 0,\n", + " 'seed': 0,\n", + " 'val_check_interval': 'null',\n", + " 'check_val_every_n_epoch': 1,\n", + " },\n", + " 'dataset':{\n", + " 'datadir': 'data',\n", + " 'task_name': 'nasbench-201',\n", + " 'guidance_target': 'nasbench-201',\n", + " 'pin_memory': False,\n", + " },\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DataModule\n", + "task nasbench-201\n", + "datadir data\n", + "target nasbench-201\n", + "try to create the NAS-Bench-201 api from ./NAS-Bench-201-v1_1-096897.pth\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/stud/hanzhang/anaconda3/envs/graphdit/lib/python3.9/site-packages/torch_geometric/data/in_memory_dataset.py:157: UserWarning: It is not recommended to directly access the internal storage format `data` of an 'InMemoryDataset'. If you are absolutely certain what you are doing, access the internal storage via `InMemoryDataset._data` instead to suppress this warning. Alternatively, you can access stacked individual attributes of every graph via `dataset.{attr_name}`.\n", + " warnings.warn(msg)\n", + "/home/stud/hanzhang/anaconda3/envs/graphdit/lib/python3.9/site-packages/torch_geometric/deprecation.py:22: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n", + " warnings.warn(out)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "nasbench-201 dataset len 15625 train len 9375 val len 3125 test len 3125 unlabeled len 0\n", + "train len 9375 val len 3125 test len 3125\n", + "train len 9375 val len 3125 test len 3125\n", + "dataset len 15625 train len 9375 val len 3125 test len 3125\n" + ] + } + ], + "source": [ + "cfg = OmegaConf.create(cfg)\n", + "task1(cfg)" ] } ],