run the specific model
This commit is contained in:
		
							
								
								
									
										502
									
								
								test.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										502
									
								
								test.ipynb
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,502 @@ | ||||
| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "from nats_bench import create\n", | ||||
|     "\n", | ||||
|     "# Create the API for size search space\n", | ||||
|     "api = create(None, 'sss', fast_mode=True, verbose=True)\n", | ||||
|     "\n", | ||||
|     "# Create the API for tologoy search space\n", | ||||
|     "api = create(None, 'tss', fast_mode=True, verbose=True)\n", | ||||
|     "\n", | ||||
|     "# Query the loss / accuracy / time for 1234-th candidate architecture on CIFAR-10\n", | ||||
|     "# info is a dict, where you can easily figure out the meaning by key\n", | ||||
|     "info = api.get_more_info(1234, 'cifar10')\n", | ||||
|     "\n", | ||||
|     "# Query the flops, params, latency. info is a dict.\n", | ||||
|     "info = api.get_cost_info(12, 'cifar10')\n", | ||||
|     "\n", | ||||
|     "# Simulate the training of the 1224-th candidate:\n", | ||||
|     "validation_accuracy, latency, time_cost, current_total_time_cost = api.simulate_train_eval(1224, dataset='cifar10', hp='12')\n", | ||||
|     "\n", | ||||
|     "# Clear the parameters of the 12-th candidate.\n", | ||||
|     "api.clear_params(12)\n", | ||||
|     "\n", | ||||
|     "# Reload all information of the 12-th candidate.\n", | ||||
|     "api.reload(index=12)\n", | ||||
|     "\n", | ||||
|     "# Create the instance of th 12-th candidate for CIFAR-10.\n", | ||||
|     "from models import get_cell_based_tiny_net\n", | ||||
|     "config = api.get_net_config(12, 'cifar10')\n", | ||||
|     "network = get_cell_based_tiny_net(config)\n", | ||||
|     "\n", | ||||
|     "# Load the pre-trained weights: params is a dict, where the key is the seed and value is the weights.\n", | ||||
|     "params = api.get_net_param(12, 'cifar10', None)\n", | ||||
|     "network.load_state_dict(next(iter(params.values())))\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "from nas_201_api import NASBench201API as API\n", | ||||
|     "import os\n", | ||||
|     "# api = API('./NAS-Bench-201-v1_1_096897.pth')\n", | ||||
|     "# get the current path\n", | ||||
|     "print(os.path.abspath(os.path.curdir))\n", | ||||
|     "cur_path = os.path.abspath(os.path.curdir)\n", | ||||
|     "data_path = os.path.join(cur_path, 'NAS-Bench-201-v1_1-096897.pth')\n", | ||||
|     "api = API(data_path)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "# get the best performance on CIFAR-10\n", | ||||
|     "len = 15625\n", | ||||
|     "accs = []\n", | ||||
|     "for i in range(1, len):\n", | ||||
|     "    results = api.query_by_index(i, 'cifar10')\n", | ||||
|     "    dict_items = list(results.items())\n", | ||||
|     "    train_info = dict_items[0][1].get_train()\n", | ||||
|     "    acc = train_info['accuracy']\n", | ||||
|     "    accs.append((i, acc))\n", | ||||
|     "print(max(accs, key=lambda x: x[1]))\n", | ||||
|     "best_index, best_acc = max(accs, key=lambda x: x[1])\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "def find_best_index(dataset):\n", | ||||
|     "    len = 15625\n", | ||||
|     "    accs = []\n", | ||||
|     "    for i in range(1, len):\n", | ||||
|     "        results = api.query_by_index(i, dataset)\n", | ||||
|     "        dict_items = list(results.items())\n", | ||||
|     "        train_info = dict_items[0][1].get_train()\n", | ||||
|     "        acc = train_info['accuracy']\n", | ||||
|     "        accs.append((i, acc))\n", | ||||
|     "    return max(accs, key=lambda x: x[1])\n", | ||||
|     "best_cifar_10_index, best_cifar_10_acc = find_best_index('cifar10')\n", | ||||
|     "best_cifar_100_index, best_cifar_100_acc = find_best_index('cifar100')\n", | ||||
|     "best_ImageNet16_index, best_ImageNet16_acc= find_best_index('ImageNet16-120')\n", | ||||
|     "print(best_cifar_10_index, best_cifar_10_acc)\n", | ||||
|     "print(best_cifar_100_index, best_cifar_100_acc)\n", | ||||
|     "print(best_ImageNet16_index, best_ImageNet16_acc)\n", | ||||
|     "\n", | ||||
|     "\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "api.show(5374)\n", | ||||
|     "config = api.get_net_config(best_index, 'cifar10')\n", | ||||
|     "from models import get_cell_based_tiny_net\n", | ||||
|     "network = get_cell_based_tiny_net(config)\n", | ||||
|     "print(network)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "api.get_net_param(5374, 'cifar10', None)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "import os, sys, time, torch, random, argparse\n", | ||||
|     "from PIL import ImageFile\n", | ||||
|     "\n", | ||||
|     "ImageFile.LOAD_TRUNCATED_IMAGES = True\n", | ||||
|     "from copy import deepcopy\n", | ||||
|     "from pathlib import Path\n", | ||||
|     "\n", | ||||
|     "from config_utils import load_config\n", | ||||
|     "from procedures.starts import get_machine_info\n", | ||||
|     "from datasets.get_dataset_with_transform import get_datasets\n", | ||||
|     "from log_utils import Logger, AverageMeter, time_string, convert_secs2time\n", | ||||
|     "from models import CellStructure, CellArchitectures, get_search_spaces" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "def evaluate_all_datasets(\n", | ||||
|     "    arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger\n", | ||||
|     "):\n", | ||||
|     "    machine_info, arch_config = get_machine_info(), deepcopy(arch_config)\n", | ||||
|     "    all_infos = {\"info\": machine_info}\n", | ||||
|     "    all_dataset_keys = []\n", | ||||
|     "    # look all the datasets\n", | ||||
|     "    for dataset, xpath, split in zip(datasets, xpaths, splits):\n", | ||||
|     "        # train valid data\n", | ||||
|     "        train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)\n", | ||||
|     "        # load the configuration\n", | ||||
|     "        if dataset == \"cifar10\" or dataset == \"cifar100\":\n", | ||||
|     "            if use_less:\n", | ||||
|     "                config_path = \"configs/nas-benchmark/LESS.config\"\n", | ||||
|     "            else:\n", | ||||
|     "                config_path = \"configs/nas-benchmark/CIFAR.config\"\n", | ||||
|     "            split_info = load_config(\n", | ||||
|     "                \"configs/nas-benchmark/cifar-split.txt\", None, None\n", | ||||
|     "            )\n", | ||||
|     "        elif dataset.startswith(\"ImageNet16\"):\n", | ||||
|     "            if use_less:\n", | ||||
|     "                config_path = \"configs/nas-benchmark/LESS.config\"\n", | ||||
|     "            else:\n", | ||||
|     "                config_path = \"configs/nas-benchmark/ImageNet-16.config\"\n", | ||||
|     "            split_info = load_config(\n", | ||||
|     "                \"configs/nas-benchmark/{:}-split.txt\".format(dataset), None, None\n", | ||||
|     "            )\n", | ||||
|     "        else:\n", | ||||
|     "            raise ValueError(\"invalid dataset : {:}\".format(dataset))\n", | ||||
|     "        config = load_config(\n", | ||||
|     "            config_path, {\"class_num\": class_num, \"xshape\": xshape}, logger\n", | ||||
|     "        )\n", | ||||
|     "        # check whether use splited validation set\n", | ||||
|     "        if bool(split):\n", | ||||
|     "            assert dataset == \"cifar10\"\n", | ||||
|     "            ValLoaders = {\n", | ||||
|     "                \"ori-test\": torch.utils.data.DataLoader(\n", | ||||
|     "                    valid_data,\n", | ||||
|     "                    batch_size=config.batch_size,\n", | ||||
|     "                    shuffle=False,\n", | ||||
|     "                    num_workers=workers,\n", | ||||
|     "                    pin_memory=True,\n", | ||||
|     "                )\n", | ||||
|     "            }\n", | ||||
|     "            assert len(train_data) == len(split_info.train) + len(\n", | ||||
|     "                split_info.valid\n", | ||||
|     "            ), \"invalid length : {:} vs {:} + {:}\".format(\n", | ||||
|     "                len(train_data), len(split_info.train), len(split_info.valid)\n", | ||||
|     "            )\n", | ||||
|     "            train_data_v2 = deepcopy(train_data)\n", | ||||
|     "            train_data_v2.transform = valid_data.transform\n", | ||||
|     "            valid_data = train_data_v2\n", | ||||
|     "            # data loader\n", | ||||
|     "            train_loader = torch.utils.data.DataLoader(\n", | ||||
|     "                train_data,\n", | ||||
|     "                batch_size=config.batch_size,\n", | ||||
|     "                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),\n", | ||||
|     "                num_workers=workers,\n", | ||||
|     "                pin_memory=True,\n", | ||||
|     "            )\n", | ||||
|     "            valid_loader = torch.utils.data.DataLoader(\n", | ||||
|     "                valid_data,\n", | ||||
|     "                batch_size=config.batch_size,\n", | ||||
|     "                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),\n", | ||||
|     "                num_workers=workers,\n", | ||||
|     "                pin_memory=True,\n", | ||||
|     "            )\n", | ||||
|     "            ValLoaders[\"x-valid\"] = valid_loader\n", | ||||
|     "        else:\n", | ||||
|     "            # data loader\n", | ||||
|     "            train_loader = torch.utils.data.DataLoader(\n", | ||||
|     "                train_data,\n", | ||||
|     "                batch_size=config.batch_size,\n", | ||||
|     "                shuffle=True,\n", | ||||
|     "                num_workers=workers,\n", | ||||
|     "                pin_memory=True,\n", | ||||
|     "            )\n", | ||||
|     "            valid_loader = torch.utils.data.DataLoader(\n", | ||||
|     "                valid_data,\n", | ||||
|     "                batch_size=config.batch_size,\n", | ||||
|     "                shuffle=False,\n", | ||||
|     "                num_workers=workers,\n", | ||||
|     "                pin_memory=True,\n", | ||||
|     "            )\n", | ||||
|     "            if dataset == \"cifar10\":\n", | ||||
|     "                ValLoaders = {\"ori-test\": valid_loader}\n", | ||||
|     "            elif dataset == \"cifar100\":\n", | ||||
|     "                cifar100_splits = load_config(\n", | ||||
|     "                    \"configs/nas-benchmark/cifar100-test-split.txt\", None, None\n", | ||||
|     "                )\n", | ||||
|     "                ValLoaders = {\n", | ||||
|     "                    \"ori-test\": valid_loader,\n", | ||||
|     "                    \"x-valid\": torch.utils.data.DataLoader(\n", | ||||
|     "                        valid_data,\n", | ||||
|     "                        batch_size=config.batch_size,\n", | ||||
|     "                        sampler=torch.utils.data.sampler.SubsetRandomSampler(\n", | ||||
|     "                            cifar100_splits.xvalid\n", | ||||
|     "                        ),\n", | ||||
|     "                        num_workers=workers,\n", | ||||
|     "                        pin_memory=True,\n", | ||||
|     "                    ),\n", | ||||
|     "                    \"x-test\": torch.utils.data.DataLoader(\n", | ||||
|     "                        valid_data,\n", | ||||
|     "                        batch_size=config.batch_size,\n", | ||||
|     "                        sampler=torch.utils.data.sampler.SubsetRandomSampler(\n", | ||||
|     "                            cifar100_splits.xtest\n", | ||||
|     "                        ),\n", | ||||
|     "                        num_workers=workers,\n", | ||||
|     "                        pin_memory=True,\n", | ||||
|     "                    ),\n", | ||||
|     "                }\n", | ||||
|     "            elif dataset == \"ImageNet16-120\":\n", | ||||
|     "                imagenet16_splits = load_config(\n", | ||||
|     "                    \"configs/nas-benchmark/imagenet-16-120-test-split.txt\", None, None\n", | ||||
|     "                )\n", | ||||
|     "                ValLoaders = {\n", | ||||
|     "                    \"ori-test\": valid_loader,\n", | ||||
|     "                    \"x-valid\": torch.utils.data.DataLoader(\n", | ||||
|     "                        valid_data,\n", | ||||
|     "                        batch_size=config.batch_size,\n", | ||||
|     "                        sampler=torch.utils.data.sampler.SubsetRandomSampler(\n", | ||||
|     "                            imagenet16_splits.xvalid\n", | ||||
|     "                        ),\n", | ||||
|     "                        num_workers=workers,\n", | ||||
|     "                        pin_memory=True,\n", | ||||
|     "                    ),\n", | ||||
|     "                    \"x-test\": torch.utils.data.DataLoader(\n", | ||||
|     "                        valid_data,\n", | ||||
|     "                        batch_size=config.batch_size,\n", | ||||
|     "                        sampler=torch.utils.data.sampler.SubsetRandomSampler(\n", | ||||
|     "                            imagenet16_splits.xtest\n", | ||||
|     "                        ),\n", | ||||
|     "                        num_workers=workers,\n", | ||||
|     "                        pin_memory=True,\n", | ||||
|     "                    ),\n", | ||||
|     "                }\n", | ||||
|     "            else:\n", | ||||
|     "                raise ValueError(\"invalid dataset : {:}\".format(dataset))\n", | ||||
|     "\n", | ||||
|     "        dataset_key = \"{:}\".format(dataset)\n", | ||||
|     "        if bool(split):\n", | ||||
|     "            dataset_key = dataset_key + \"-valid\"\n", | ||||
|     "        logger.log(\n", | ||||
|     "            \"Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}\".format(\n", | ||||
|     "                dataset_key,\n", | ||||
|     "                len(train_data),\n", | ||||
|     "                len(valid_data),\n", | ||||
|     "                len(train_loader),\n", | ||||
|     "                len(valid_loader),\n", | ||||
|     "                config.batch_size,\n", | ||||
|     "            )\n", | ||||
|     "        )\n", | ||||
|     "        logger.log(\n", | ||||
|     "            \"Evaluate ||||||| {:10s} ||||||| Config={:}\".format(dataset_key, config)\n", | ||||
|     "        )\n", | ||||
|     "        for key, value in ValLoaders.items():\n", | ||||
|     "            logger.log(\n", | ||||
|     "                \"Evaluate ---->>>> {:10s} with {:} batchs\".format(key, len(value))\n", | ||||
|     "            )\n", | ||||
|     "        results = evaluate_for_seed(\n", | ||||
|     "            arch_config, config, arch, train_loader, ValLoaders, seed, logger\n", | ||||
|     "        )\n", | ||||
|     "        all_infos[dataset_key] = results\n", | ||||
|     "        all_dataset_keys.append(dataset_key)\n", | ||||
|     "    all_infos[\"all_dataset_keys\"] = all_dataset_keys\n", | ||||
|     "    return all_infos\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "def train_single_model(\n", | ||||
|     "    save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config\n", | ||||
|     "):\n", | ||||
|     "    assert torch.cuda.is_available(), \"CUDA is not available.\"\n", | ||||
|     "    torch.backends.cudnn.enabled = True\n", | ||||
|     "    torch.backends.cudnn.deterministic = True\n", | ||||
|     "    # torch.backends.cudnn.benchmark = True\n", | ||||
|     "    torch.set_num_threads(workers)\n", | ||||
|     "\n", | ||||
|     "    save_dir = (\n", | ||||
|     "        Path(save_dir)\n", | ||||
|     "        / \"specifics\"\n", | ||||
|     "        / \"{:}-{:}-{:}-{:}\".format(\n", | ||||
|     "            \"LESS\" if use_less else \"FULL\",\n", | ||||
|     "            model_str,\n", | ||||
|     "            arch_config[\"channel\"],\n", | ||||
|     "            arch_config[\"num_cells\"],\n", | ||||
|     "        )\n", | ||||
|     "    )\n", | ||||
|     "    logger = Logger(str(save_dir), 0, False)\n", | ||||
|     "    if model_str in CellArchitectures:\n", | ||||
|     "        arch = CellArchitectures[model_str]\n", | ||||
|     "        logger.log(\n", | ||||
|     "            \"The model string is found in pre-defined architecture dict : {:}\".format(\n", | ||||
|     "                model_str\n", | ||||
|     "            )\n", | ||||
|     "        )\n", | ||||
|     "    else:\n", | ||||
|     "        try:\n", | ||||
|     "            arch = CellStructure.str2structure(model_str)\n", | ||||
|     "        except:\n", | ||||
|     "            raise ValueError(\n", | ||||
|     "                \"Invalid model string : {:}. It can not be found or parsed.\".format(\n", | ||||
|     "                    model_str\n", | ||||
|     "                )\n", | ||||
|     "            )\n", | ||||
|     "    assert arch.check_valid_op(\n", | ||||
|     "        get_search_spaces(\"cell\", \"full\")\n", | ||||
|     "    ), \"{:} has the invalid op.\".format(arch)\n", | ||||
|     "    logger.log(\"Start train-evaluate {:}\".format(arch.tostr()))\n", | ||||
|     "    logger.log(\"arch_config : {:}\".format(arch_config))\n", | ||||
|     "\n", | ||||
|     "    start_time, seed_time = time.time(), AverageMeter()\n", | ||||
|     "    for _is, seed in enumerate(seeds):\n", | ||||
|     "        logger.log(\n", | ||||
|     "            \"\\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------\".format(\n", | ||||
|     "                _is, len(seeds), seed\n", | ||||
|     "            )\n", | ||||
|     "        )\n", | ||||
|     "        to_save_name = save_dir / \"seed-{:04d}.pth\".format(seed)\n", | ||||
|     "        if to_save_name.exists():\n", | ||||
|     "            logger.log(\n", | ||||
|     "                \"Find the existing file {:}, directly load!\".format(to_save_name)\n", | ||||
|     "            )\n", | ||||
|     "            checkpoint = torch.load(to_save_name)\n", | ||||
|     "        else:\n", | ||||
|     "            logger.log(\n", | ||||
|     "                \"Does not find the existing file {:}, train and evaluate!\".format(\n", | ||||
|     "                    to_save_name\n", | ||||
|     "                )\n", | ||||
|     "            )\n", | ||||
|     "            checkpoint = evaluate_all_datasets(\n", | ||||
|     "                arch,\n", | ||||
|     "                datasets,\n", | ||||
|     "                xpaths,\n", | ||||
|     "                splits,\n", | ||||
|     "                use_less,\n", | ||||
|     "                seed,\n", | ||||
|     "                arch_config,\n", | ||||
|     "                workers,\n", | ||||
|     "                logger,\n", | ||||
|     "            )\n", | ||||
|     "            torch.save(checkpoint, to_save_name)\n", | ||||
|     "        # log information\n", | ||||
|     "        logger.log(\"{:}\".format(checkpoint[\"info\"]))\n", | ||||
|     "        all_dataset_keys = checkpoint[\"all_dataset_keys\"]\n", | ||||
|     "        for dataset_key in all_dataset_keys:\n", | ||||
|     "            logger.log(\n", | ||||
|     "                \"\\n{:} dataset : {:} {:}\".format(\"-\" * 15, dataset_key, \"-\" * 15)\n", | ||||
|     "            )\n", | ||||
|     "            dataset_info = checkpoint[dataset_key]\n", | ||||
|     "            # logger.log('Network ==>\\n{:}'.format( dataset_info['net_string'] ))\n", | ||||
|     "            logger.log(\n", | ||||
|     "                \"Flops = {:} MB, Params = {:} MB\".format(\n", | ||||
|     "                    dataset_info[\"flop\"], dataset_info[\"param\"]\n", | ||||
|     "                )\n", | ||||
|     "            )\n", | ||||
|     "            logger.log(\"config : {:}\".format(dataset_info[\"config\"]))\n", | ||||
|     "            logger.log(\n", | ||||
|     "                \"Training State (finish) = {:}\".format(dataset_info[\"finish-train\"])\n", | ||||
|     "            )\n", | ||||
|     "            last_epoch = dataset_info[\"total_epoch\"] - 1\n", | ||||
|     "            train_acc1es, train_acc5es = (\n", | ||||
|     "                dataset_info[\"train_acc1es\"],\n", | ||||
|     "                dataset_info[\"train_acc5es\"],\n", | ||||
|     "            )\n", | ||||
|     "            valid_acc1es, valid_acc5es = (\n", | ||||
|     "                dataset_info[\"valid_acc1es\"],\n", | ||||
|     "                dataset_info[\"valid_acc5es\"],\n", | ||||
|     "            )\n", | ||||
|     "            logger.log(\n", | ||||
|     "                \"Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%\".format(\n", | ||||
|     "                    train_acc1es[last_epoch],\n", | ||||
|     "                    train_acc5es[last_epoch],\n", | ||||
|     "                    100 - train_acc1es[last_epoch],\n", | ||||
|     "                    valid_acc1es[last_epoch],\n", | ||||
|     "                    valid_acc5es[last_epoch],\n", | ||||
|     "                    100 - valid_acc1es[last_epoch],\n", | ||||
|     "                )\n", | ||||
|     "            )\n", | ||||
|     "        # measure elapsed time\n", | ||||
|     "        seed_time.update(time.time() - start_time)\n", | ||||
|     "        start_time = time.time()\n", | ||||
|     "        need_time = \"Time Left: {:}\".format(\n", | ||||
|     "            convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True)\n", | ||||
|     "        )\n", | ||||
|     "        logger.log(\n", | ||||
|     "            \"\\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}\".format(\n", | ||||
|     "                _is, len(seeds), seed, need_time\n", | ||||
|     "            )\n", | ||||
|     "        )\n", | ||||
|     "    logger.close()\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "train_single_model(\n", | ||||
|     "    save_dir=\"./outputs\",\n", | ||||
|     "    workers=8,\n", | ||||
|     "    datasets=\"cifar10\", \n", | ||||
|     "    xpaths=\"/root/cifardata/cifar-10-batches-py\",\n", | ||||
|     "    splits=[0, 0, 0],\n", | ||||
|     "    use_less=False,\n", | ||||
|     "    seeds=[777],\n", | ||||
|     "    model_str=\"|nor_conv_3x3~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|\",\n", | ||||
|     "    arch_config={\"channel\": 16, \"num_cells\": 8},)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "metadata": {}, | ||||
|    "source": [] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": "natsbench", | ||||
|    "language": "python", | ||||
|    "name": "python3" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.9.19" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 2 | ||||
| } | ||||
| @@ -213,6 +213,13 @@ AllConv3x3_CODE = Structure( | ||||
|         (("nor_conv_3x3", 0), ("nor_conv_3x3", 1), ("nor_conv_3x3", 2)), | ||||
|     ]  # node-3 | ||||
| ) | ||||
| Number_5374 = Structure( | ||||
|     [ | ||||
|         (("nor_conv_3x3", 0),),  # node-1 | ||||
|         (("nor_conv_1x1", 0), ("nor_conv_3x3", 1)),  # node-2 | ||||
|         (("skip_connect", 0), ("none", 1), ("nor_conv_3x3", 2)),  # node-3 | ||||
|     ] | ||||
| ) | ||||
|  | ||||
| AllFull_CODE = Structure( | ||||
|     [ | ||||
| @@ -271,4 +278,5 @@ architectures = { | ||||
|     "all_c1x1": AllConv1x1_CODE, | ||||
|     "all_idnt": AllIdentity_CODE, | ||||
|     "all_full": AllFull_CODE, | ||||
|     "5374": Number_5374, | ||||
| } | ||||
|   | ||||
| @@ -12,6 +12,7 @@ def obtain_accuracy(output, target, topk=(1,)): | ||||
|  | ||||
|     res = [] | ||||
|     for k in topk: | ||||
|         correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) | ||||
|         # correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) | ||||
|         correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) | ||||
|         res.append(correct_k.mul_(100.0 / batch_size)) | ||||
|     return res | ||||
|   | ||||
		Reference in New Issue
	
	Block a user