run the specific model
This commit is contained in:
parent f46486e21b
commit bb33ca9a68
test.ipynb | 502 | Normal file

@@ -0,0 +1,502 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from nats_bench import create\n",
"\n",
"# Create the API for the size search space (SSS).\n",
"api = create(None, 'sss', fast_mode=True, verbose=True)\n",
"\n",
"# Create the API for the topology search space (TSS).\n",
"api = create(None, 'tss', fast_mode=True, verbose=True)\n",
"\n",
"# Query the loss / accuracy / time for the 1234-th candidate architecture on CIFAR-10.\n",
"# info is a dict; the meaning of each entry is clear from its key.\n",
"info = api.get_more_info(1234, 'cifar10')\n",
"\n",
"# Query the FLOPs, params, and latency. info is a dict.\n",
"info = api.get_cost_info(12, 'cifar10')\n",
"\n",
"# Simulate the training of the 1224-th candidate:\n",
"validation_accuracy, latency, time_cost, current_total_time_cost = api.simulate_train_eval(1224, dataset='cifar10', hp='12')\n",
"\n",
"# Clear the parameters of the 12-th candidate.\n",
"api.clear_params(12)\n",
"\n",
"# Reload all information of the 12-th candidate.\n",
"api.reload(index=12)\n",
"\n",
"# Create the network instance of the 12-th candidate for CIFAR-10.\n",
"from models import get_cell_based_tiny_net\n",
"config = api.get_net_config(12, 'cifar10')\n",
"network = get_cell_based_tiny_net(config)\n",
"\n",
"# Load the pre-trained weights: params is a dict, where the key is the seed and the value is the weights.\n",
"params = api.get_net_param(12, 'cifar10', None)\n",
"network.load_state_dict(next(iter(params.values())))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from nas_201_api import NASBench201API as API\n",
"import os\n",
"\n",
"# api = API('./NAS-Bench-201-v1_1-096897.pth')\n",
"# Get the current path and build an absolute path to the benchmark file.\n",
"print(os.path.abspath(os.path.curdir))\n",
"cur_path = os.path.abspath(os.path.curdir)\n",
"data_path = os.path.join(cur_path, 'NAS-Bench-201-v1_1-096897.pth')\n",
"api = API(data_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Get the best-performing architecture on CIFAR-10.\n",
"num_archs = 15625  # size of the NAS-Bench-201 search space\n",
"accs = []\n",
"for i in range(num_archs):\n",
"    results = api.query_by_index(i, 'cifar10')\n",
"    dict_items = list(results.items())\n",
"    train_info = dict_items[0][1].get_train()\n",
"    acc = train_info['accuracy']\n",
"    accs.append((i, acc))\n",
"print(max(accs, key=lambda x: x[1]))\n",
"best_index, best_acc = max(accs, key=lambda x: x[1])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def find_best_index(dataset):\n",
"    num_archs = 15625  # size of the NAS-Bench-201 search space\n",
"    accs = []\n",
"    for i in range(num_archs):\n",
"        results = api.query_by_index(i, dataset)\n",
"        dict_items = list(results.items())\n",
"        train_info = dict_items[0][1].get_train()\n",
"        acc = train_info['accuracy']\n",
"        accs.append((i, acc))\n",
"    return max(accs, key=lambda x: x[1])\n",
"best_cifar_10_index, best_cifar_10_acc = find_best_index('cifar10')\n",
"best_cifar_100_index, best_cifar_100_acc = find_best_index('cifar100')\n",
"best_ImageNet16_index, best_ImageNet16_acc = find_best_index('ImageNet16-120')\n",
"print(best_cifar_10_index, best_cifar_10_acc)\n",
"print(best_cifar_100_index, best_cifar_100_acc)\n",
"print(best_ImageNet16_index, best_ImageNet16_acc)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"api.show(5374)\n",
"config = api.get_net_config(best_index, 'cifar10')\n",
"from models import get_cell_based_tiny_net\n",
"network = get_cell_based_tiny_net(config)\n",
"print(network)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"api.get_net_param(5374, 'cifar10', None)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os, sys, time, torch, random, argparse\n",
"from PIL import ImageFile\n",
"\n",
"ImageFile.LOAD_TRUNCATED_IMAGES = True\n",
"from copy import deepcopy\n",
"from pathlib import Path\n",
"\n",
"from config_utils import load_config\n",
"from procedures.starts import get_machine_info\n",
"from datasets.get_dataset_with_transform import get_datasets\n",
"from log_utils import Logger, AverageMeter, time_string, convert_secs2time\n",
"from models import CellStructure, CellArchitectures, get_search_spaces"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def evaluate_all_datasets(\n",
"    arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger\n",
"):\n",
"    machine_info, arch_config = get_machine_info(), deepcopy(arch_config)\n",
"    all_infos = {\"info\": machine_info}\n",
"    all_dataset_keys = []\n",
"    # loop over all the datasets\n",
"    for dataset, xpath, split in zip(datasets, xpaths, splits):\n",
"        # training and validation data\n",
"        train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)\n",
"        # load the configuration\n",
"        if dataset == \"cifar10\" or dataset == \"cifar100\":\n",
"            if use_less:\n",
"                config_path = \"configs/nas-benchmark/LESS.config\"\n",
"            else:\n",
"                config_path = \"configs/nas-benchmark/CIFAR.config\"\n",
"            split_info = load_config(\n",
"                \"configs/nas-benchmark/cifar-split.txt\", None, None\n",
"            )\n",
"        elif dataset.startswith(\"ImageNet16\"):\n",
"            if use_less:\n",
"                config_path = \"configs/nas-benchmark/LESS.config\"\n",
"            else:\n",
"                config_path = \"configs/nas-benchmark/ImageNet-16.config\"\n",
"            split_info = load_config(\n",
"                \"configs/nas-benchmark/{:}-split.txt\".format(dataset), None, None\n",
"            )\n",
"        else:\n",
"            raise ValueError(\"invalid dataset : {:}\".format(dataset))\n",
"        config = load_config(\n",
"            config_path, {\"class_num\": class_num, \"xshape\": xshape}, logger\n",
"        )\n",
"        # check whether to use the split validation set\n",
"        if bool(split):\n",
"            assert dataset == \"cifar10\"\n",
"            ValLoaders = {\n",
"                \"ori-test\": torch.utils.data.DataLoader(\n",
"                    valid_data,\n",
"                    batch_size=config.batch_size,\n",
"                    shuffle=False,\n",
"                    num_workers=workers,\n",
"                    pin_memory=True,\n",
"                )\n",
"            }\n",
"            assert len(train_data) == len(split_info.train) + len(\n",
"                split_info.valid\n",
"            ), \"invalid length : {:} vs {:} + {:}\".format(\n",
"                len(train_data), len(split_info.train), len(split_info.valid)\n",
"            )\n",
"            train_data_v2 = deepcopy(train_data)\n",
"            train_data_v2.transform = valid_data.transform\n",
"            valid_data = train_data_v2\n",
"            # data loaders\n",
"            train_loader = torch.utils.data.DataLoader(\n",
"                train_data,\n",
"                batch_size=config.batch_size,\n",
"                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),\n",
"                num_workers=workers,\n",
"                pin_memory=True,\n",
"            )\n",
"            valid_loader = torch.utils.data.DataLoader(\n",
"                valid_data,\n",
"                batch_size=config.batch_size,\n",
"                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),\n",
"                num_workers=workers,\n",
"                pin_memory=True,\n",
"            )\n",
"            ValLoaders[\"x-valid\"] = valid_loader\n",
"        else:\n",
"            # data loaders\n",
"            train_loader = torch.utils.data.DataLoader(\n",
"                train_data,\n",
"                batch_size=config.batch_size,\n",
"                shuffle=True,\n",
"                num_workers=workers,\n",
"                pin_memory=True,\n",
"            )\n",
"            valid_loader = torch.utils.data.DataLoader(\n",
"                valid_data,\n",
"                batch_size=config.batch_size,\n",
"                shuffle=False,\n",
"                num_workers=workers,\n",
"                pin_memory=True,\n",
"            )\n",
"            if dataset == \"cifar10\":\n",
"                ValLoaders = {\"ori-test\": valid_loader}\n",
"            elif dataset == \"cifar100\":\n",
"                cifar100_splits = load_config(\n",
"                    \"configs/nas-benchmark/cifar100-test-split.txt\", None, None\n",
"                )\n",
"                ValLoaders = {\n",
"                    \"ori-test\": valid_loader,\n",
"                    \"x-valid\": torch.utils.data.DataLoader(\n",
"                        valid_data,\n",
"                        batch_size=config.batch_size,\n",
"                        sampler=torch.utils.data.sampler.SubsetRandomSampler(\n",
"                            cifar100_splits.xvalid\n",
"                        ),\n",
"                        num_workers=workers,\n",
"                        pin_memory=True,\n",
"                    ),\n",
"                    \"x-test\": torch.utils.data.DataLoader(\n",
"                        valid_data,\n",
"                        batch_size=config.batch_size,\n",
"                        sampler=torch.utils.data.sampler.SubsetRandomSampler(\n",
"                            cifar100_splits.xtest\n",
"                        ),\n",
"                        num_workers=workers,\n",
"                        pin_memory=True,\n",
"                    ),\n",
"                }\n",
"            elif dataset == \"ImageNet16-120\":\n",
"                imagenet16_splits = load_config(\n",
"                    \"configs/nas-benchmark/imagenet-16-120-test-split.txt\", None, None\n",
"                )\n",
"                ValLoaders = {\n",
"                    \"ori-test\": valid_loader,\n",
"                    \"x-valid\": torch.utils.data.DataLoader(\n",
"                        valid_data,\n",
"                        batch_size=config.batch_size,\n",
"                        sampler=torch.utils.data.sampler.SubsetRandomSampler(\n",
"                            imagenet16_splits.xvalid\n",
"                        ),\n",
"                        num_workers=workers,\n",
"                        pin_memory=True,\n",
"                    ),\n",
"                    \"x-test\": torch.utils.data.DataLoader(\n",
"                        valid_data,\n",
"                        batch_size=config.batch_size,\n",
"                        sampler=torch.utils.data.sampler.SubsetRandomSampler(\n",
"                            imagenet16_splits.xtest\n",
"                        ),\n",
"                        num_workers=workers,\n",
"                        pin_memory=True,\n",
"                    ),\n",
"                }\n",
"            else:\n",
"                raise ValueError(\"invalid dataset : {:}\".format(dataset))\n",
"\n",
"        dataset_key = \"{:}\".format(dataset)\n",
"        if bool(split):\n",
"            dataset_key = dataset_key + \"-valid\"\n",
"        logger.log(\n",
"            \"Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}\".format(\n",
"                dataset_key,\n",
"                len(train_data),\n",
"                len(valid_data),\n",
"                len(train_loader),\n",
"                len(valid_loader),\n",
"                config.batch_size,\n",
"            )\n",
"        )\n",
"        logger.log(\n",
"            \"Evaluate ||||||| {:10s} ||||||| Config={:}\".format(dataset_key, config)\n",
"        )\n",
"        for key, value in ValLoaders.items():\n",
"            logger.log(\n",
"                \"Evaluate ---->>>> {:10s} with {:} batches\".format(key, len(value))\n",
"            )\n",
"        results = evaluate_for_seed(\n",
"            arch_config, config, arch, train_loader, ValLoaders, seed, logger\n",
"        )\n",
"        all_infos[dataset_key] = results\n",
"        all_dataset_keys.append(dataset_key)\n",
"    all_infos[\"all_dataset_keys\"] = all_dataset_keys\n",
"    return all_infos\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def train_single_model(\n",
"    save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config\n",
"):\n",
"    assert torch.cuda.is_available(), \"CUDA is not available.\"\n",
"    torch.backends.cudnn.enabled = True\n",
"    torch.backends.cudnn.deterministic = True\n",
"    # torch.backends.cudnn.benchmark = True\n",
"    torch.set_num_threads(workers)\n",
"\n",
"    save_dir = (\n",
"        Path(save_dir)\n",
"        / \"specifics\"\n",
"        / \"{:}-{:}-{:}-{:}\".format(\n",
"            \"LESS\" if use_less else \"FULL\",\n",
"            model_str,\n",
"            arch_config[\"channel\"],\n",
"            arch_config[\"num_cells\"],\n",
"        )\n",
"    )\n",
"    logger = Logger(str(save_dir), 0, False)\n",
"    if model_str in CellArchitectures:\n",
"        arch = CellArchitectures[model_str]\n",
"        logger.log(\n",
"            \"The model string is found in the pre-defined architecture dict : {:}\".format(\n",
"                model_str\n",
"            )\n",
"        )\n",
"    else:\n",
"        try:\n",
"            arch = CellStructure.str2structure(model_str)\n",
"        except Exception:\n",
"            raise ValueError(\n",
"                \"Invalid model string : {:}. It can not be found or parsed.\".format(\n",
"                    model_str\n",
"                )\n",
"            )\n",
"    assert arch.check_valid_op(\n",
"        get_search_spaces(\"cell\", \"full\")\n",
"    ), \"{:} has an invalid op.\".format(arch)\n",
"    logger.log(\"Start train-evaluate {:}\".format(arch.tostr()))\n",
"    logger.log(\"arch_config : {:}\".format(arch_config))\n",
"\n",
"    start_time, seed_time = time.time(), AverageMeter()\n",
"    for _is, seed in enumerate(seeds):\n",
"        logger.log(\n",
"            \"\\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------\".format(\n",
"                _is, len(seeds), seed\n",
"            )\n",
"        )\n",
"        to_save_name = save_dir / \"seed-{:04d}.pth\".format(seed)\n",
"        if to_save_name.exists():\n",
"            logger.log(\n",
"                \"Found the existing file {:}; loading it directly!\".format(to_save_name)\n",
"            )\n",
"            checkpoint = torch.load(to_save_name)\n",
"        else:\n",
"            logger.log(\n",
"                \"Did not find the existing file {:}; training and evaluating!\".format(\n",
"                    to_save_name\n",
"                )\n",
"            )\n",
"            checkpoint = evaluate_all_datasets(\n",
"                arch,\n",
"                datasets,\n",
"                xpaths,\n",
"                splits,\n",
"                use_less,\n",
"                seed,\n",
"                arch_config,\n",
"                workers,\n",
"                logger,\n",
"            )\n",
"            torch.save(checkpoint, to_save_name)\n",
"        # log information\n",
"        logger.log(\"{:}\".format(checkpoint[\"info\"]))\n",
"        all_dataset_keys = checkpoint[\"all_dataset_keys\"]\n",
"        for dataset_key in all_dataset_keys:\n",
"            logger.log(\n",
"                \"\\n{:} dataset : {:} {:}\".format(\"-\" * 15, dataset_key, \"-\" * 15)\n",
"            )\n",
"            dataset_info = checkpoint[dataset_key]\n",
"            # logger.log('Network ==>\\n{:}'.format( dataset_info['net_string'] ))\n",
"            logger.log(\n",
"                \"Flops = {:} MB, Params = {:} MB\".format(\n",
"                    dataset_info[\"flop\"], dataset_info[\"param\"]\n",
"                )\n",
"            )\n",
"            logger.log(\"config : {:}\".format(dataset_info[\"config\"]))\n",
"            logger.log(\n",
"                \"Training State (finish) = {:}\".format(dataset_info[\"finish-train\"])\n",
"            )\n",
"            last_epoch = dataset_info[\"total_epoch\"] - 1\n",
"            train_acc1es, train_acc5es = (\n",
"                dataset_info[\"train_acc1es\"],\n",
"                dataset_info[\"train_acc5es\"],\n",
"            )\n",
"            valid_acc1es, valid_acc5es = (\n",
"                dataset_info[\"valid_acc1es\"],\n",
"                dataset_info[\"valid_acc5es\"],\n",
"            )\n",
"            logger.log(\n",
"                \"Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%\".format(\n",
"                    train_acc1es[last_epoch],\n",
"                    train_acc5es[last_epoch],\n",
"                    100 - train_acc1es[last_epoch],\n",
"                    valid_acc1es[last_epoch],\n",
"                    valid_acc5es[last_epoch],\n",
"                    100 - valid_acc1es[last_epoch],\n",
"                )\n",
"            )\n",
"        # measure elapsed time\n",
"        seed_time.update(time.time() - start_time)\n",
"        start_time = time.time()\n",
"        need_time = \"Time Left: {:}\".format(\n",
"            convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True)\n",
"        )\n",
"        logger.log(\n",
"            \"\\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}\".format(\n",
"                _is, len(seeds), seed, need_time\n",
"            )\n",
"        )\n",
"    logger.close()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_single_model(\n",
"    save_dir=\"./outputs\",\n",
"    workers=8,\n",
"    datasets=[\"cifar10\"],\n",
"    xpaths=[\"/root/cifardata/cifar-10-batches-py\"],\n",
"    splits=[0],\n",
"    use_less=False,\n",
"    seeds=[777],\n",
"    model_str=\"|nor_conv_3x3~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|\",\n",
"    arch_config={\"channel\": 16, \"num_cells\": 8},\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "natsbench",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@@ -213,6 +213,13 @@ AllConv3x3_CODE = Structure(
        (("nor_conv_3x3", 0), ("nor_conv_3x3", 1), ("nor_conv_3x3", 2)),
    ]  # node-3
)

Number_5374 = Structure(
    [
        (("nor_conv_3x3", 0),),  # node-1
        (("nor_conv_1x1", 0), ("nor_conv_3x3", 1)),  # node-2
        (("skip_connect", 0), ("none", 1), ("nor_conv_3x3", 2)),  # node-3
    ]
)

AllFull_CODE = Structure(
    [
@@ -271,4 +278,5 @@ architectures = {
    "all_c1x1": AllConv1x1_CODE,
    "all_idnt": AllIdentity_CODE,
    "all_full": AllFull_CODE,
    "5374": Number_5374,
}
@@ -12,6 +12,7 @@ def obtain_accuracy(output, target, topk=(1,)):
    res = []
    for k in topk:
        # correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
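        # reshape, unlike view, also works on non-contiguous tensors (e.g. after a transpose),
        # which is presumably why this commit swaps it in (hedged note, not from the commit message).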
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res