From bb33ca9a68008b1882e54d403f9b4a21d671f0bf Mon Sep 17 00:00:00 2001
From: Mhrooz
Date: Thu, 11 Jul 2024 11:48:51 +0200
Subject: [PATCH] run a specific model

---
 test.ipynb                               | 507 ++++++++++++++++++++++++
 xautodl/models/cell_searchs/genotypes.py |   9 +
 xautodl/utils/evaluation_utils.py        |   5 ++++-
 3 files changed, 520 insertions(+), 1 deletion(-)
 create mode 100644 test.ipynb

diff --git a/test.ipynb b/test.ipynb
new file mode 100644
index 0000000..dcae5cb
--- /dev/null
+++ b/test.ipynb
@@ -0,0 +1,507 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from nats_bench import create\n",
+    "\n",
+    "# Create the API for the size search space\n",
+    "api = create(None, 'sss', fast_mode=True, verbose=True)\n",
+    "\n",
+    "# Create the API for the topology search space\n",
+    "api = create(None, 'tss', fast_mode=True, verbose=True)\n",
+    "\n",
+    "# Query the loss / accuracy / time for the 1234-th candidate architecture on CIFAR-10.\n",
+    "# info is a dict, where you can easily figure out the meaning by key.\n",
+    "info = api.get_more_info(1234, 'cifar10')\n",
+    "\n",
+    "# Query the flops, params, latency. info is a dict.\n",
+    "info = api.get_cost_info(12, 'cifar10')\n",
+    "\n",
+    "# Simulate the training of the 1224-th candidate:\n",
+    "validation_accuracy, latency, time_cost, current_total_time_cost = api.simulate_train_eval(1224, dataset='cifar10', hp='12')\n",
+    "\n",
+    "# Clear the parameters of the 12-th candidate.\n",
+    "api.clear_params(12)\n",
+    "\n",
+    "# Reload all information of the 12-th candidate.\n",
+    "api.reload(index=12)\n",
+    "\n",
+    "# Create the instance of the 12-th candidate for CIFAR-10.\n",
+    "from models import get_cell_based_tiny_net\n",
+    "config = api.get_net_config(12, 'cifar10')\n",
+    "network = get_cell_based_tiny_net(config)\n",
+    "\n",
+    "# Load the pre-trained weights: params is a dict, where the key is the seed and the value is the weights.\n",
+    "params = api.get_net_param(12, 'cifar10', None)\n",
+    "network.load_state_dict(next(iter(params.values())))\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from nas_201_api import NASBench201API as API\n",
+    "import os\n",
+    "# api = API('./NAS-Bench-201-v1_1-096897.pth')\n",
+    "# get the current path\n",
+    "print(os.path.abspath(os.path.curdir))\n",
+    "cur_path = os.path.abspath(os.path.curdir)\n",
+    "data_path = os.path.join(cur_path, 'NAS-Bench-201-v1_1-096897.pth')\n",
+    "api = API(data_path)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# get the best training accuracy on CIFAR-10\n",
+    "num_archs = 15625  # NAS-Bench-201 has 15625 candidates, indexed 0..15624\n",
+    "accs = []\n",
+    "for i in range(num_archs):\n",
+    "    results = api.query_by_index(i, 'cifar10')\n",
+    "    dict_items = list(results.items())\n",
+    "    train_info = dict_items[0][1].get_train()\n",
+    "    acc = train_info['accuracy']\n",
+    "    accs.append((i, acc))\n",
+    "print(max(accs, key=lambda x: x[1]))\n",
+    "best_index, best_acc = max(accs, key=lambda x: x[1])\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def find_best_index(dataset):\n",
+    "    num_archs = 15625\n",
+    "    accs = []\n",
+    "    for i in range(num_archs):\n",
+    "        results = api.query_by_index(i, dataset)\n",
+    "        dict_items = list(results.items())\n",
+    "        train_info = dict_items[0][1].get_train()\n",
+    "        acc = train_info['accuracy']\n",
+    "        accs.append((i, acc))\n",
+    "    return max(accs, key=lambda x: x[1])\n",
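+    "\n",
+    "# NOTE: query_by_index(i, dataset) returns a dict keyed by the training seed;\n",
+    "# the first entry's get_train() record supplies the final training accuracy.\n",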
"best_cifar_10_index, best_cifar_10_acc = find_best_index('cifar10')\n", + "best_cifar_100_index, best_cifar_100_acc = find_best_index('cifar100')\n", + "best_ImageNet16_index, best_ImageNet16_acc= find_best_index('ImageNet16-120')\n", + "print(best_cifar_10_index, best_cifar_10_acc)\n", + "print(best_cifar_100_index, best_cifar_100_acc)\n", + "print(best_ImageNet16_index, best_ImageNet16_acc)\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "api.show(5374)\n", + "config = api.get_net_config(best_index, 'cifar10')\n", + "from models import get_cell_based_tiny_net\n", + "network = get_cell_based_tiny_net(config)\n", + "print(network)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "api.get_net_param(5374, 'cifar10', None)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os, sys, time, torch, random, argparse\n", + "from PIL import ImageFile\n", + "\n", + "ImageFile.LOAD_TRUNCATED_IMAGES = True\n", + "from copy import deepcopy\n", + "from pathlib import Path\n", + "\n", + "from config_utils import load_config\n", + "from procedures.starts import get_machine_info\n", + "from datasets.get_dataset_with_transform import get_datasets\n", + "from log_utils import Logger, AverageMeter, time_string, convert_secs2time\n", + "from models import CellStructure, CellArchitectures, get_search_spaces" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def evaluate_all_datasets(\n", + " arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger\n", + "):\n", + " machine_info, arch_config = get_machine_info(), deepcopy(arch_config)\n", + " all_infos = {\"info\": machine_info}\n", + " all_dataset_keys = []\n", + " # look all the datasets\n", + " for dataset, xpath, split in zip(datasets, xpaths, splits):\n", + " # train valid data\n", + " train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)\n", + " # load the configuration\n", + " if dataset == \"cifar10\" or dataset == \"cifar100\":\n", + " if use_less:\n", + " config_path = \"configs/nas-benchmark/LESS.config\"\n", + " else:\n", + " config_path = \"configs/nas-benchmark/CIFAR.config\"\n", + " split_info = load_config(\n", + " \"configs/nas-benchmark/cifar-split.txt\", None, None\n", + " )\n", + " elif dataset.startswith(\"ImageNet16\"):\n", + " if use_less:\n", + " config_path = \"configs/nas-benchmark/LESS.config\"\n", + " else:\n", + " config_path = \"configs/nas-benchmark/ImageNet-16.config\"\n", + " split_info = load_config(\n", + " \"configs/nas-benchmark/{:}-split.txt\".format(dataset), None, None\n", + " )\n", + " else:\n", + " raise ValueError(\"invalid dataset : {:}\".format(dataset))\n", + " config = load_config(\n", + " config_path, {\"class_num\": class_num, \"xshape\": xshape}, logger\n", + " )\n", + " # check whether use splited validation set\n", + " if bool(split):\n", + " assert dataset == \"cifar10\"\n", + " ValLoaders = {\n", + " \"ori-test\": torch.utils.data.DataLoader(\n", + " valid_data,\n", + " batch_size=config.batch_size,\n", + " shuffle=False,\n", + " num_workers=workers,\n", + " pin_memory=True,\n", + " )\n", + " }\n", + " assert len(train_data) == len(split_info.train) + len(\n", + " 
+    "                split_info.valid\n",
+    "            ), \"invalid length : {:} vs {:} + {:}\".format(\n",
+    "                len(train_data), len(split_info.train), len(split_info.valid)\n",
+    "            )\n",
+    "            train_data_v2 = deepcopy(train_data)\n",
+    "            train_data_v2.transform = valid_data.transform\n",
+    "            valid_data = train_data_v2\n",
+    "            # data loader\n",
+    "            train_loader = torch.utils.data.DataLoader(\n",
+    "                train_data,\n",
+    "                batch_size=config.batch_size,\n",
+    "                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),\n",
+    "                num_workers=workers,\n",
+    "                pin_memory=True,\n",
+    "            )\n",
+    "            valid_loader = torch.utils.data.DataLoader(\n",
+    "                valid_data,\n",
+    "                batch_size=config.batch_size,\n",
+    "                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),\n",
+    "                num_workers=workers,\n",
+    "                pin_memory=True,\n",
+    "            )\n",
+    "            ValLoaders[\"x-valid\"] = valid_loader\n",
+    "        else:\n",
+    "            # data loader\n",
+    "            train_loader = torch.utils.data.DataLoader(\n",
+    "                train_data,\n",
+    "                batch_size=config.batch_size,\n",
+    "                shuffle=True,\n",
+    "                num_workers=workers,\n",
+    "                pin_memory=True,\n",
+    "            )\n",
+    "            valid_loader = torch.utils.data.DataLoader(\n",
+    "                valid_data,\n",
+    "                batch_size=config.batch_size,\n",
+    "                shuffle=False,\n",
+    "                num_workers=workers,\n",
+    "                pin_memory=True,\n",
+    "            )\n",
+    "            if dataset == \"cifar10\":\n",
+    "                ValLoaders = {\"ori-test\": valid_loader}\n",
+    "            elif dataset == \"cifar100\":\n",
+    "                cifar100_splits = load_config(\n",
+    "                    \"configs/nas-benchmark/cifar100-test-split.txt\", None, None\n",
+    "                )\n",
+    "                ValLoaders = {\n",
+    "                    \"ori-test\": valid_loader,\n",
+    "                    \"x-valid\": torch.utils.data.DataLoader(\n",
+    "                        valid_data,\n",
+    "                        batch_size=config.batch_size,\n",
+    "                        sampler=torch.utils.data.sampler.SubsetRandomSampler(\n",
+    "                            cifar100_splits.xvalid\n",
+    "                        ),\n",
+    "                        num_workers=workers,\n",
+    "                        pin_memory=True,\n",
+    "                    ),\n",
+    "                    \"x-test\": torch.utils.data.DataLoader(\n",
+    "                        valid_data,\n",
+    "                        batch_size=config.batch_size,\n",
+    "                        sampler=torch.utils.data.sampler.SubsetRandomSampler(\n",
+    "                            cifar100_splits.xtest\n",
+    "                        ),\n",
+    "                        num_workers=workers,\n",
+    "                        pin_memory=True,\n",
+    "                    ),\n",
+    "                }\n",
+    "            elif dataset == \"ImageNet16-120\":\n",
+    "                imagenet16_splits = load_config(\n",
+    "                    \"configs/nas-benchmark/imagenet-16-120-test-split.txt\", None, None\n",
+    "                )\n",
+    "                ValLoaders = {\n",
+    "                    \"ori-test\": valid_loader,\n",
+    "                    \"x-valid\": torch.utils.data.DataLoader(\n",
+    "                        valid_data,\n",
+    "                        batch_size=config.batch_size,\n",
+    "                        sampler=torch.utils.data.sampler.SubsetRandomSampler(\n",
+    "                            imagenet16_splits.xvalid\n",
+    "                        ),\n",
+    "                        num_workers=workers,\n",
+    "                        pin_memory=True,\n",
+    "                    ),\n",
+    "                    \"x-test\": torch.utils.data.DataLoader(\n",
+    "                        valid_data,\n",
+    "                        batch_size=config.batch_size,\n",
+    "                        sampler=torch.utils.data.sampler.SubsetRandomSampler(\n",
+    "                            imagenet16_splits.xtest\n",
+    "                        ),\n",
+    "                        num_workers=workers,\n",
+    "                        pin_memory=True,\n",
+    "                    ),\n",
+    "                }\n",
+    "            else:\n",
+    "                raise ValueError(\"invalid dataset : {:}\".format(dataset))\n",
+    "\n",
+    "        dataset_key = \"{:}\".format(dataset)\n",
+    "        if bool(split):\n",
+    "            dataset_key = dataset_key + \"-valid\"\n",
+    "        logger.log(\n",
+    "            \"Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}\".format(\n",
+    "                dataset_key,\n",
+    "                len(train_data),\n",
+    "                len(valid_data),\n",
+    "                len(train_loader),\n",
+    "                len(valid_loader),\n",
+    "                config.batch_size,\n",
+    "            )\n",
+    "        )\n",
+    "        logger.log(\n",
+    "            \"Evaluate ||||||| {:10s} ||||||| Config={:}\".format(dataset_key, config)\n",
+    "        )\n",
+    "        for key, value in ValLoaders.items():\n",
+    "            logger.log(\n",
+    "                \"Evaluate ---->>>> {:10s} with {:} batches\".format(key, len(value))\n",
+    "            )\n",
+    "        results = evaluate_for_seed(\n",
+    "            arch_config, config, arch, train_loader, ValLoaders, seed, logger\n",
+    "        )\n",
+    "        all_infos[dataset_key] = results\n",
+    "        all_dataset_keys.append(dataset_key)\n",
+    "    all_infos[\"all_dataset_keys\"] = all_dataset_keys\n",
+    "    return all_infos\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def train_single_model(\n",
+    "    save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config\n",
+    "):\n",
+    "    assert torch.cuda.is_available(), \"CUDA is not available.\"\n",
+    "    torch.backends.cudnn.enabled = True\n",
+    "    torch.backends.cudnn.deterministic = True\n",
+    "    # torch.backends.cudnn.benchmark = True\n",
+    "    torch.set_num_threads(workers)\n",
+    "\n",
+    "    save_dir = (\n",
+    "        Path(save_dir)\n",
+    "        / \"specifics\"\n",
+    "        / \"{:}-{:}-{:}-{:}\".format(\n",
+    "            \"LESS\" if use_less else \"FULL\",\n",
+    "            model_str,\n",
+    "            arch_config[\"channel\"],\n",
+    "            arch_config[\"num_cells\"],\n",
+    "        )\n",
+    "    )\n",
+    "    logger = Logger(str(save_dir), 0, False)\n",
+    "    if model_str in CellArchitectures:\n",
+    "        arch = CellArchitectures[model_str]\n",
+    "        logger.log(\n",
+    "            \"The model string is found in the pre-defined architecture dict : {:}\".format(\n",
+    "                model_str\n",
+    "            )\n",
+    "        )\n",
+    "    else:\n",
+    "        try:\n",
+    "            arch = CellStructure.str2structure(model_str)\n",
+    "        except Exception:\n",
+    "            raise ValueError(\n",
+    "                \"Invalid model string : {:}. It cannot be found or parsed.\".format(\n",
+    "                    model_str\n",
+    "                )\n",
+    "            )\n",
+    "    assert arch.check_valid_op(\n",
+    "        get_search_spaces(\"cell\", \"full\")\n",
+    "    ), \"{:} has an invalid op.\".format(arch)\n",
+    "    logger.log(\"Start train-evaluate {:}\".format(arch.tostr()))\n",
+    "    logger.log(\"arch_config : {:}\".format(arch_config))\n",
+    "\n",
+    "    start_time, seed_time = time.time(), AverageMeter()\n",
+    "    for _is, seed in enumerate(seeds):\n",
+    "        logger.log(\n",
+    "            \"\\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------\".format(\n",
+    "                _is, len(seeds), seed\n",
+    "            )\n",
+    "        )\n",
+    "        to_save_name = save_dir / \"seed-{:04d}.pth\".format(seed)\n",
+    "        if to_save_name.exists():\n",
+    "            logger.log(\n",
+    "                \"Found the existing file {:}, loading it directly!\".format(to_save_name)\n",
+    "            )\n",
+    "            checkpoint = torch.load(to_save_name)\n",
+    "        else:\n",
+    "            logger.log(\n",
+    "                \"Did not find the existing file {:}, training and evaluating!\".format(\n",
+    "                    to_save_name\n",
+    "                )\n",
+    "            )\n",
+    "            checkpoint = evaluate_all_datasets(\n",
+    "                arch,\n",
+    "                datasets,\n",
+    "                xpaths,\n",
+    "                splits,\n",
+    "                use_less,\n",
+    "                seed,\n",
+    "                arch_config,\n",
+    "                workers,\n",
+    "                logger,\n",
+    "            )\n",
+    "            torch.save(checkpoint, to_save_name)\n",
+    "        # log information\n",
+    "        logger.log(\"{:}\".format(checkpoint[\"info\"]))\n",
+    "        all_dataset_keys = checkpoint[\"all_dataset_keys\"]\n",
+    "        for dataset_key in all_dataset_keys:\n",
+    "            logger.log(\n",
+    "                \"\\n{:} dataset : {:} {:}\".format(\"-\" * 15, dataset_key, \"-\" * 15)\n",
+    "            )\n",
+    "            dataset_info = checkpoint[dataset_key]\n",
+    "            # logger.log('Network ==>\\n{:}'.format( dataset_info['net_string'] ))\n",
+    "            logger.log(\n",
+    "                \"Flops = {:} MB, Params = {:} MB\".format(\n",
+    "                    dataset_info[\"flop\"], dataset_info[\"param\"]\n",
+    "                )\n",
+    "            )\n",
+    "            logger.log(\"config : {:}\".format(dataset_info[\"config\"]))\n",
+    "            logger.log(\n",
+    "                \"Training State (finish) = {:}\".format(dataset_info[\"finish-train\"])\n",
+    "            )\n",
+    "            last_epoch = dataset_info[\"total_epoch\"] - 1\n",
+    "            train_acc1es, train_acc5es = (\n",
+    "                dataset_info[\"train_acc1es\"],\n",
+    "                dataset_info[\"train_acc5es\"],\n",
+    "            )\n",
+    "            valid_acc1es, valid_acc5es = (\n",
+    "                dataset_info[\"valid_acc1es\"],\n",
+    "                dataset_info[\"valid_acc5es\"],\n",
+    "            )\n",
+    "            logger.log(\n",
+    "                \"Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%\".format(\n",
+    "                    train_acc1es[last_epoch],\n",
+    "                    train_acc5es[last_epoch],\n",
+    "                    100 - train_acc1es[last_epoch],\n",
+    "                    valid_acc1es[last_epoch],\n",
+    "                    valid_acc5es[last_epoch],\n",
+    "                    100 - valid_acc1es[last_epoch],\n",
+    "                )\n",
+    "            )\n",
+    "        # measure elapsed time\n",
+    "        seed_time.update(time.time() - start_time)\n",
+    "        start_time = time.time()\n",
+    "        need_time = \"Time Left: {:}\".format(\n",
+    "            convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True)\n",
+    "        )\n",
+    "        logger.log(\n",
+    "            \"\\n<<<***>>> The {:02d}/{:02d}-th seed is {:}; other procedures need {:}\".format(\n",
+    "                _is, len(seeds), seed, need_time\n",
+    "            )\n",
+    "        )\n",
+    "    logger.close()\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_single_model(\n",
+    "    save_dir=\"./outputs\",\n",
+    "    workers=8,\n",
+    "    datasets=[\"cifar10\"],  # must be lists: zipped with xpaths/splits in evaluate_all_datasets\n",
+    "    xpaths=[\"/root/cifardata/cifar-10-batches-py\"],\n",
+    "    splits=[0],\n",
+    "    use_less=False,\n",
+    "    seeds=[777],\n",
+    "    model_str=\"|nor_conv_3x3~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|\",\n",
+    "    arch_config={\"channel\": 16, \"num_cells\": 8},\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "natsbench",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.19"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/xautodl/models/cell_searchs/genotypes.py b/xautodl/models/cell_searchs/genotypes.py
index f0ec8f2..30f1232 100644
--- a/xautodl/models/cell_searchs/genotypes.py
+++ b/xautodl/models/cell_searchs/genotypes.py
@@ -213,6 +213,14 @@ AllConv3x3_CODE = Structure(
         (("nor_conv_3x3", 0), ("nor_conv_3x3", 1), ("nor_conv_3x3", 2)),
     ]  # node-3
 )
+# Architecture index 5374 in NAS-Bench-201 (the model trained in test.ipynb).
+Number_5374 = Structure(
+    [
+        (("nor_conv_3x3", 0),),  # node-1
+        (("nor_conv_1x1", 0), ("nor_conv_3x3", 1)),  # node-2
+        (("skip_connect", 0), ("none", 1), ("nor_conv_3x3", 2)),  # node-3
+    ]
+)
 
 AllFull_CODE = Structure(
     [
@@ -271,4 +279,5 @@ architectures = {
     "all_c1x1": AllConv1x1_CODE,
     "all_idnt": AllIdentity_CODE,
     "all_full": AllFull_CODE,
+    "5374": Number_5374,
 }
diff --git a/xautodl/utils/evaluation_utils.py b/xautodl/utils/evaluation_utils.py
index 088f318..a0a5e74 100644
--- a/xautodl/utils/evaluation_utils.py
+++ b/xautodl/utils/evaluation_utils.py
@@ -12,6 +12,9 @@ def obtain_accuracy(output, target, topk=(1,)):
 
     res = []
     for k in topk:
-        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+        # correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
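+        # NOTE: reshape(-1) is used above because correct[:k] can be
+        # non-contiguous (e.g. after a transpose), where view(-1) raises a RuntimeError.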
         res.append(correct_k.mul_(100.0 / batch_size))
     return res