{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from nats_bench import create\n", "\n", "# Create the API for size search space\n", "api = create(None, 'sss', fast_mode=True, verbose=True)\n", "\n", "# Create the API for tologoy search space\n", "api = create(None, 'tss', fast_mode=True, verbose=True)\n", "\n", "# Query the loss / accuracy / time for 1234-th candidate architecture on CIFAR-10\n", "# info is a dict, where you can easily figure out the meaning by key\n", "info = api.get_more_info(1234, 'cifar10')\n", "\n", "# Query the flops, params, latency. info is a dict.\n", "info = api.get_cost_info(12, 'cifar10')\n", "\n", "# Simulate the training of the 1224-th candidate:\n", "validation_accuracy, latency, time_cost, current_total_time_cost = api.simulate_train_eval(1224, dataset='cifar10', hp='12')\n", "\n", "# Clear the parameters of the 12-th candidate.\n", "api.clear_params(12)\n", "\n", "# Reload all information of the 12-th candidate.\n", "api.reload(index=12)\n", "\n", "# Create the instance of th 12-th candidate for CIFAR-10.\n", "from models import get_cell_based_tiny_net\n", "config = api.get_net_config(12, 'cifar10')\n", "network = get_cell_based_tiny_net(config)\n", "\n", "# Load the pre-trained weights: params is a dict, where the key is the seed and value is the weights.\n", "params = api.get_net_param(12, 'cifar10', None)\n", "network.load_state_dict(next(iter(params.values())))\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from nas_201_api import NASBench201API as API\n", "import os\n", "# api = API('./NAS-Bench-201-v1_1_096897.pth')\n", "# get the current path\n", "print(os.path.abspath(os.path.curdir))\n", "cur_path = os.path.abspath(os.path.curdir)\n", "data_path = os.path.join(cur_path, 'NAS-Bench-201-v1_1-096897.pth')\n", "api = API(data_path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# get the best performance on CIFAR-10\n", "len = 15625\n", "accs = []\n", "for i in range(1, len):\n", " results = api.query_by_index(i, 'cifar10')\n", " dict_items = list(results.items())\n", " train_info = dict_items[0][1].get_train()\n", " acc = train_info['accuracy']\n", " accs.append((i, acc))\n", "print(max(accs, key=lambda x: x[1]))\n", "best_index, best_acc = max(accs, key=lambda x: x[1])\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def find_best_index(dataset):\n", " len = 15625\n", " accs = []\n", " for i in range(1, len):\n", " results = api.query_by_index(i, dataset)\n", " dict_items = list(results.items())\n", " train_info = dict_items[0][1].get_train()\n", " acc = train_info['accuracy']\n", " accs.append((i, acc))\n", " return max(accs, key=lambda x: x[1])\n", "best_cifar_10_index, best_cifar_10_acc = find_best_index('cifar10')\n", "best_cifar_100_index, best_cifar_100_acc = find_best_index('cifar100')\n", "best_ImageNet16_index, best_ImageNet16_acc= find_best_index('ImageNet16-120')\n", "print(best_cifar_10_index, best_cifar_10_acc)\n", "print(best_cifar_100_index, best_cifar_100_acc)\n", "print(best_ImageNet16_index, best_ImageNet16_acc)\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "api.show(5374)\n", "config = api.get_net_config(best_index, 'cifar10')\n", "from models import 
get_cell_based_tiny_net\n", "network = get_cell_based_tiny_net(config)\n", "print(network)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "api.get_net_param(5374, 'cifar10', None)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os, sys, time, torch, random, argparse\n", "from PIL import ImageFile\n", "\n", "ImageFile.LOAD_TRUNCATED_IMAGES = True\n", "from copy import deepcopy\n", "from pathlib import Path\n", "\n", "from config_utils import load_config\n", "from procedures.starts import get_machine_info\n", "from datasets.get_dataset_with_transform import get_datasets\n", "from log_utils import Logger, AverageMeter, time_string, convert_secs2time\n", "from models import CellStructure, CellArchitectures, get_search_spaces" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def evaluate_all_datasets(\n", " arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger\n", "):\n", " machine_info, arch_config = get_machine_info(), deepcopy(arch_config)\n", " all_infos = {\"info\": machine_info}\n", " all_dataset_keys = []\n", " # look all the datasets\n", " for dataset, xpath, split in zip(datasets, xpaths, splits):\n", " # train valid data\n", " train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)\n", " # load the configuration\n", " if dataset == \"cifar10\" or dataset == \"cifar100\":\n", " if use_less:\n", " config_path = \"configs/nas-benchmark/LESS.config\"\n", " else:\n", " config_path = \"configs/nas-benchmark/CIFAR.config\"\n", " split_info = load_config(\n", " \"configs/nas-benchmark/cifar-split.txt\", None, None\n", " )\n", " elif dataset.startswith(\"ImageNet16\"):\n", " if use_less:\n", " config_path = \"configs/nas-benchmark/LESS.config\"\n", " else:\n", " config_path = \"configs/nas-benchmark/ImageNet-16.config\"\n", " split_info = load_config(\n", " \"configs/nas-benchmark/{:}-split.txt\".format(dataset), None, None\n", " )\n", " else:\n", " raise ValueError(\"invalid dataset : {:}\".format(dataset))\n", " config = load_config(\n", " config_path, {\"class_num\": class_num, \"xshape\": xshape}, logger\n", " )\n", " # check whether use splited validation set\n", " if bool(split):\n", " assert dataset == \"cifar10\"\n", " ValLoaders = {\n", " \"ori-test\": torch.utils.data.DataLoader(\n", " valid_data,\n", " batch_size=config.batch_size,\n", " shuffle=False,\n", " num_workers=workers,\n", " pin_memory=True,\n", " )\n", " }\n", " assert len(train_data) == len(split_info.train) + len(\n", " split_info.valid\n", " ), \"invalid length : {:} vs {:} + {:}\".format(\n", " len(train_data), len(split_info.train), len(split_info.valid)\n", " )\n", " train_data_v2 = deepcopy(train_data)\n", " train_data_v2.transform = valid_data.transform\n", " valid_data = train_data_v2\n", " # data loader\n", " train_loader = torch.utils.data.DataLoader(\n", " train_data,\n", " batch_size=config.batch_size,\n", " sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),\n", " num_workers=workers,\n", " pin_memory=True,\n", " )\n", " valid_loader = torch.utils.data.DataLoader(\n", " valid_data,\n", " batch_size=config.batch_size,\n", " sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),\n", " num_workers=workers,\n", " pin_memory=True,\n", " )\n", " ValLoaders[\"x-valid\"] = valid_loader\n", " else:\n", " # data loader\n", " train_loader = torch.utils.data.DataLoader(\n", " 
train_data,\n", " batch_size=config.batch_size,\n", " shuffle=True,\n", " num_workers=workers,\n", " pin_memory=True,\n", " )\n", " valid_loader = torch.utils.data.DataLoader(\n", " valid_data,\n", " batch_size=config.batch_size,\n", " shuffle=False,\n", " num_workers=workers,\n", " pin_memory=True,\n", " )\n", " if dataset == \"cifar10\":\n", " ValLoaders = {\"ori-test\": valid_loader}\n", " elif dataset == \"cifar100\":\n", " cifar100_splits = load_config(\n", " \"configs/nas-benchmark/cifar100-test-split.txt\", None, None\n", " )\n", " ValLoaders = {\n", " \"ori-test\": valid_loader,\n", " \"x-valid\": torch.utils.data.DataLoader(\n", " valid_data,\n", " batch_size=config.batch_size,\n", " sampler=torch.utils.data.sampler.SubsetRandomSampler(\n", " cifar100_splits.xvalid\n", " ),\n", " num_workers=workers,\n", " pin_memory=True,\n", " ),\n", " \"x-test\": torch.utils.data.DataLoader(\n", " valid_data,\n", " batch_size=config.batch_size,\n", " sampler=torch.utils.data.sampler.SubsetRandomSampler(\n", " cifar100_splits.xtest\n", " ),\n", " num_workers=workers,\n", " pin_memory=True,\n", " ),\n", " }\n", " elif dataset == \"ImageNet16-120\":\n", " imagenet16_splits = load_config(\n", " \"configs/nas-benchmark/imagenet-16-120-test-split.txt\", None, None\n", " )\n", " ValLoaders = {\n", " \"ori-test\": valid_loader,\n", " \"x-valid\": torch.utils.data.DataLoader(\n", " valid_data,\n", " batch_size=config.batch_size,\n", " sampler=torch.utils.data.sampler.SubsetRandomSampler(\n", " imagenet16_splits.xvalid\n", " ),\n", " num_workers=workers,\n", " pin_memory=True,\n", " ),\n", " \"x-test\": torch.utils.data.DataLoader(\n", " valid_data,\n", " batch_size=config.batch_size,\n", " sampler=torch.utils.data.sampler.SubsetRandomSampler(\n", " imagenet16_splits.xtest\n", " ),\n", " num_workers=workers,\n", " pin_memory=True,\n", " ),\n", " }\n", " else:\n", " raise ValueError(\"invalid dataset : {:}\".format(dataset))\n", "\n", " dataset_key = \"{:}\".format(dataset)\n", " if bool(split):\n", " dataset_key = dataset_key + \"-valid\"\n", " logger.log(\n", " \"Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}\".format(\n", " dataset_key,\n", " len(train_data),\n", " len(valid_data),\n", " len(train_loader),\n", " len(valid_loader),\n", " config.batch_size,\n", " )\n", " )\n", " logger.log(\n", " \"Evaluate ||||||| {:10s} ||||||| Config={:}\".format(dataset_key, config)\n", " )\n", " for key, value in ValLoaders.items():\n", " logger.log(\n", " \"Evaluate ---->>>> {:10s} with {:} batchs\".format(key, len(value))\n", " )\n", " results = evaluate_for_seed(\n", " arch_config, config, arch, train_loader, ValLoaders, seed, logger\n", " )\n", " all_infos[dataset_key] = results\n", " all_dataset_keys.append(dataset_key)\n", " all_infos[\"all_dataset_keys\"] = all_dataset_keys\n", " return all_infos\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def train_single_model(\n", " save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config\n", "):\n", " assert torch.cuda.is_available(), \"CUDA is not available.\"\n", " torch.backends.cudnn.enabled = True\n", " torch.backends.cudnn.deterministic = True\n", " # torch.backends.cudnn.benchmark = True\n", " torch.set_num_threads(workers)\n", "\n", " save_dir = (\n", " Path(save_dir)\n", " / \"specifics\"\n", " / \"{:}-{:}-{:}-{:}\".format(\n", " \"LESS\" if use_less else \"FULL\",\n", " model_str,\n", " 
arch_config[\"channel\"],\n", " arch_config[\"num_cells\"],\n", " )\n", " )\n", " logger = Logger(str(save_dir), 0, False)\n", " if model_str in CellArchitectures:\n", " arch = CellArchitectures[model_str]\n", " logger.log(\n", " \"The model string is found in pre-defined architecture dict : {:}\".format(\n", " model_str\n", " )\n", " )\n", " else:\n", " try:\n", " arch = CellStructure.str2structure(model_str)\n", " except:\n", " raise ValueError(\n", " \"Invalid model string : {:}. It can not be found or parsed.\".format(\n", " model_str\n", " )\n", " )\n", " assert arch.check_valid_op(\n", " get_search_spaces(\"cell\", \"full\")\n", " ), \"{:} has the invalid op.\".format(arch)\n", " logger.log(\"Start train-evaluate {:}\".format(arch.tostr()))\n", " logger.log(\"arch_config : {:}\".format(arch_config))\n", "\n", " start_time, seed_time = time.time(), AverageMeter()\n", " for _is, seed in enumerate(seeds):\n", " logger.log(\n", " \"\\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------\".format(\n", " _is, len(seeds), seed\n", " )\n", " )\n", " to_save_name = save_dir / \"seed-{:04d}.pth\".format(seed)\n", " if to_save_name.exists():\n", " logger.log(\n", " \"Find the existing file {:}, directly load!\".format(to_save_name)\n", " )\n", " checkpoint = torch.load(to_save_name)\n", " else:\n", " logger.log(\n", " \"Does not find the existing file {:}, train and evaluate!\".format(\n", " to_save_name\n", " )\n", " )\n", " checkpoint = evaluate_all_datasets(\n", " arch,\n", " datasets,\n", " xpaths,\n", " splits,\n", " use_less,\n", " seed,\n", " arch_config,\n", " workers,\n", " logger,\n", " )\n", " torch.save(checkpoint, to_save_name)\n", " # log information\n", " logger.log(\"{:}\".format(checkpoint[\"info\"]))\n", " all_dataset_keys = checkpoint[\"all_dataset_keys\"]\n", " for dataset_key in all_dataset_keys:\n", " logger.log(\n", " \"\\n{:} dataset : {:} {:}\".format(\"-\" * 15, dataset_key, \"-\" * 15)\n", " )\n", " dataset_info = checkpoint[dataset_key]\n", " # logger.log('Network ==>\\n{:}'.format( dataset_info['net_string'] ))\n", " logger.log(\n", " \"Flops = {:} MB, Params = {:} MB\".format(\n", " dataset_info[\"flop\"], dataset_info[\"param\"]\n", " )\n", " )\n", " logger.log(\"config : {:}\".format(dataset_info[\"config\"]))\n", " logger.log(\n", " \"Training State (finish) = {:}\".format(dataset_info[\"finish-train\"])\n", " )\n", " last_epoch = dataset_info[\"total_epoch\"] - 1\n", " train_acc1es, train_acc5es = (\n", " dataset_info[\"train_acc1es\"],\n", " dataset_info[\"train_acc5es\"],\n", " )\n", " valid_acc1es, valid_acc5es = (\n", " dataset_info[\"valid_acc1es\"],\n", " dataset_info[\"valid_acc5es\"],\n", " )\n", " logger.log(\n", " \"Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%\".format(\n", " train_acc1es[last_epoch],\n", " train_acc5es[last_epoch],\n", " 100 - train_acc1es[last_epoch],\n", " valid_acc1es[last_epoch],\n", " valid_acc5es[last_epoch],\n", " 100 - valid_acc1es[last_epoch],\n", " )\n", " )\n", " # measure elapsed time\n", " seed_time.update(time.time() - start_time)\n", " start_time = time.time()\n", " need_time = \"Time Left: {:}\".format(\n", " convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True)\n", " )\n", " logger.log(\n", " \"\\n<<<***>>> The {:02d}/{:02d}-th seed is {:} other procedures need {:}\".format(\n", " _is, len(seeds), seed, need_time\n", " )\n", " )\n", " logger.close()\n" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": {}, "outputs": [], "source": [ "train_single_model(\n", " save_dir=\"./outputs\",\n", " workers=8,\n", " datasets=\"cifar10\", \n", " xpaths=\"/root/cifardata/cifar-10-batches-py\",\n", " splits=[0, 0, 0],\n", " use_less=False,\n", " seeds=[777],\n", " model_str=\"|nor_conv_3x3~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|\",\n", " arch_config={\"channel\": 16, \"num_cells\": 8},)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] } ], "metadata": { "kernelspec": { "display_name": "natsbench", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 2 }