In [None]:
from nats_bench import create

# Create the API for the size search space ('sss').
api = create(None, 'sss', fast_mode=True, verbose=True)

# Create the API for the topology search space ('tss').
# NOTE: this rebinds `api`, so the size-search-space handle created above is discarded and
# every call below goes to the topology search space.
api = create(None, 'tss', fast_mode=True, verbose=True)

# Query the loss / accuracy / time for the 1234-th candidate architecture on CIFAR-10.
# info is a dict, where you can easily figure out the meaning by key.
info = api.get_more_info(1234, 'cifar10')

# Query the flops, params, latency. info is a dict.
# NOTE: this overwrites the `info` obtained for candidate 1234 above.
info = api.get_cost_info(12, 'cifar10')

# Simulate the training of the 1224-th candidate.
# NOTE(review): hp='12' presumably selects the 12-epoch training setting — confirm against the API docs.
validation_accuracy, latency, time_cost, current_total_time_cost = api.simulate_train_eval(1224, dataset='cifar10', hp='12')

# Clear the parameters of the 12-th candidate.
api.clear_params(12)

# Reload all information of the 12-th candidate.
api.reload(index=12)

# Create the instance of the 12-th candidate for CIFAR-10.
from models import get_cell_based_tiny_net
config = api.get_net_config(12, 'cifar10')
network = get_cell_based_tiny_net(config)

# Load the pre-trained weights: params is a dict, where the key is the seed and value is the weights.
# Only the first entry of that dict is loaded into the network.
params = api.get_net_param(12, 'cifar10', None)
network.load_state_dict(next(iter(params.values())))


In [None]:
import os
from nas_201_api import NASBench201API as API

# The benchmark file is expected to sit in the directory the notebook is running from.
# Resolve that directory once, report it, and reuse it to build the file path.
cur_path = os.path.abspath(os.path.curdir)
print(cur_path)
data_path = os.path.join(cur_path, 'NAS-Bench-201-v1_1-096897.pth')

# Load the benchmark; `api` is the handle the following cells query.
api = API(data_path)

In [None]:
# Find the architecture with the highest training accuracy on CIFAR-10.
# The architecture count lives in `num_archs` rather than `len`: rebinding `len` at notebook
# scope would break every later cell that calls the builtin `len(...)`.
num_archs = 15625
accs = []
# Start at 0 so that the first architecture is part of the search as well.
for i in range(num_archs):
    results = api.query_by_index(i, 'cifar10')
    # `results` is a mapping; only its first entry is inspected.
    train_info = list(results.values())[0].get_train()
    accs.append((i, train_info['accuracy']))
# Ties resolve to the lowest index because max() keeps the first maximum it sees.
best_index, best_acc = max(accs, key=lambda x: x[1])
print((best_index, best_acc))


In [None]:
def find_best_index(dataset, num_archs=15625):
    """Return ``(index, accuracy)`` of the architecture with the highest training accuracy.

    Every index in ``range(num_archs)`` is queried through the notebook-level ``api``
    object; for each one, the ``'accuracy'`` field of ``get_train()`` on the first entry
    of the query result is used as the score.

    Args:
        dataset: Dataset name forwarded to ``api.query_by_index``
            (this notebook uses 'cifar10', 'cifar100' and 'ImageNet16-120').
        num_archs: Number of architecture indices to scan, starting at index 0.
            Defaults to 15625, the count this notebook has always used.

    Returns:
        A ``(index, accuracy)`` tuple. On ties the lowest index wins.

    Raises:
        ValueError: If ``num_archs`` is 0, because there is nothing to take the maximum of.
    """
    accs = []
    # Index 0 is a valid architecture, so the scan must start there (not at 1).
    for i in range(num_archs):
        results = api.query_by_index(i, dataset)
        # `results` is a mapping; only its first entry is inspected.
        first_entry = list(results.values())[0]
        accs.append((i, first_entry.get_train()['accuracy']))
    return max(accs, key=lambda x: x[1])
# Run the search once per dataset, keeping each (index, accuracy) pair under its own name.
best_cifar_10_index, best_cifar_10_acc = find_best_index('cifar10')
best_cifar_100_index, best_cifar_100_acc = find_best_index('cifar100')
best_ImageNet16_index, best_ImageNet16_acc = find_best_index('ImageNet16-120')

# Report the winners only after all three searches have finished, one "index accuracy" line each.
for _best_pair in (
    (best_cifar_10_index, best_cifar_10_acc),
    (best_cifar_100_index, best_cifar_100_acc),
    (best_ImageNet16_index, best_ImageNet16_acc),
):
    print(*_best_pair)




In [None]:
# Show the benchmark entry for architecture 5374, then build and print a network from the
# CIFAR-10 config of `best_index`.
# NOTE(review): the index passed to api.show is hard-coded (5374) while the config is fetched
# for `best_index` (set by the earlier CIFAR-10 search cell) — confirm both are meant to refer
# to the same architecture.
api.show(5374)
config = api.get_net_config(best_index, 'cifar10')
from models import get_cell_based_tiny_net
network = get_cell_based_tiny_net(config)
print(network)

In [None]:
# Fetch the stored parameters of architecture 5374 on CIFAR-10; as the cell's last expression
# the returned value is displayed by the notebook.
# NOTE(review): the third argument is None — presumably "no specific seed"; confirm against the API.
api.get_net_param(5374, 'cifar10', None)

In [None]:
import os, sys, time, torch, random, argparse
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True
from copy import deepcopy
from pathlib import Path

from config_utils import load_config
from procedures.starts import get_machine_info
from datasets.get_dataset_with_transform import get_datasets
from log_utils import Logger, AverageMeter, time_string, convert_secs2time
from models import CellStructure, CellArchitectures, get_search_spaces

In [None]:
def _make_loader(data, batch_size, workers, **sampling):
    """Build a DataLoader with the worker / pinned-memory settings every loader here shares.

    ``sampling`` carries exactly one of ``shuffle=...`` or ``sampler=...`` and is forwarded
    untouched to ``torch.utils.data.DataLoader``.
    """
    return torch.utils.data.DataLoader(
        data,
        batch_size=batch_size,
        num_workers=workers,
        pin_memory=True,
        **sampling,
    )


def evaluate_all_datasets(
    arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger
):
    """Train and evaluate one architecture on each requested dataset for a single seed.

    ``datasets``, ``xpaths`` and ``splits`` are iterated in lockstep with ``zip``, so they
    must be parallel sequences (one entry per dataset); extra entries in the longer ones are
    ignored. For every dataset the function loads the data and the training config, builds
    the train loader plus a dict of evaluation loaders, and hands them to
    ``evaluate_for_seed``.

    NOTE(review): ``evaluate_for_seed`` is neither defined nor imported in the visible
    notebook cells; it must be brought into scope before this function is called.

    Args:
        arch: Architecture object, forwarded to ``evaluate_for_seed``.
        datasets: Sequence of dataset names ('cifar10', 'cifar100' or 'ImageNet16-120').
        xpaths: Sequence of dataset root paths, parallel to ``datasets``.
        splits: Sequence of flags, parallel to ``datasets``. A truthy flag (only allowed
            for 'cifar10') trains on ``split_info.train`` and validates on
            ``split_info.valid`` of the training set.
        use_less: If true, use ``LESS.config`` instead of the full per-dataset config.
        seed: Seed forwarded to ``evaluate_for_seed``.
        arch_config: Architecture hyper-parameters; deep-copied before use.
        workers: Number of DataLoader worker processes.
        logger: Object with a ``log(str)`` method.

    Returns:
        dict with the machine info under ``"info"``, one entry per evaluated dataset
        (key is the dataset name, suffixed with ``"-valid"`` when the split is used), and
        the ordered list of those keys under ``"all_dataset_keys"``.

    Raises:
        ValueError: For a dataset name that is not supported.
        AssertionError: If a split is requested for a dataset other than 'cifar10', or if
            the split sizes do not add up to the training-set size.
    """
    subset_sampler = torch.utils.data.sampler.SubsetRandomSampler
    # Files that divide the original test set into the "x-valid" / "x-test" subsets.
    test_split_files = {
        "cifar100": "configs/nas-benchmark/cifar100-test-split.txt",
        "ImageNet16-120": "configs/nas-benchmark/imagenet-16-120-test-split.txt",
    }

    machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
    all_infos = {"info": machine_info}
    all_dataset_keys = []
    for dataset, xpath, split in zip(datasets, xpaths, splits):
        # NOTE(review): the meaning of the trailing -1 is defined by get_datasets — confirm.
        train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)

        # Pick the training config and the train/valid split file for this dataset family.
        if dataset == "cifar10" or dataset == "cifar100":
            full_config_path = "configs/nas-benchmark/CIFAR.config"
            split_file = "configs/nas-benchmark/cifar-split.txt"
        elif dataset.startswith("ImageNet16"):
            full_config_path = "configs/nas-benchmark/ImageNet-16.config"
            split_file = "configs/nas-benchmark/{:}-split.txt".format(dataset)
        else:
            raise ValueError("invalid dataset : {:}".format(dataset))
        config_path = "configs/nas-benchmark/LESS.config" if use_less else full_config_path
        split_info = load_config(split_file, None, None)
        config = load_config(
            config_path, {"class_num": class_num, "xshape": xshape}, logger
        )
        batch_size = config.batch_size

        use_split = bool(split)
        if use_split:
            # Train on one part of the training set, validate on the held-out part.
            assert dataset == "cifar10"
            ValLoaders = {
                "ori-test": _make_loader(valid_data, batch_size, workers, shuffle=False)
            }
            assert len(train_data) == len(split_info.train) + len(
                split_info.valid
            ), "invalid length : {:} vs {:} + {:}".format(
                len(train_data), len(split_info.train), len(split_info.valid)
            )
            # The held-out part comes from the training set but must use the evaluation
            # transform, hence the copy with the transform swapped.
            train_data_v2 = deepcopy(train_data)
            train_data_v2.transform = valid_data.transform
            valid_data = train_data_v2
            train_loader = _make_loader(
                train_data, batch_size, workers, sampler=subset_sampler(split_info.train)
            )
            valid_loader = _make_loader(
                valid_data, batch_size, workers, sampler=subset_sampler(split_info.valid)
            )
            ValLoaders["x-valid"] = valid_loader
        else:
            # Train on the full training set, evaluate on the original test set.
            train_loader = _make_loader(train_data, batch_size, workers, shuffle=True)
            valid_loader = _make_loader(valid_data, batch_size, workers, shuffle=False)
            if dataset == "cifar10":
                ValLoaders = {"ori-test": valid_loader}
            elif dataset in test_split_files:
                test_splits = load_config(test_split_files[dataset], None, None)
                ValLoaders = {
                    "ori-test": valid_loader,
                    "x-valid": _make_loader(
                        valid_data,
                        batch_size,
                        workers,
                        sampler=subset_sampler(test_splits.xvalid),
                    ),
                    "x-test": _make_loader(
                        valid_data,
                        batch_size,
                        workers,
                        sampler=subset_sampler(test_splits.xtest),
                    ),
                }
            else:
                raise ValueError("invalid dataset : {:}".format(dataset))

        dataset_key = "{:}".format(dataset)
        if use_split:
            dataset_key = dataset_key + "-valid"
        logger.log(
            "Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
                dataset_key,
                len(train_data),
                len(valid_data),
                len(train_loader),
                len(valid_loader),
                config.batch_size,
            )
        )
        logger.log(
            "Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config)
        )
        for key, value in ValLoaders.items():
            logger.log(
                "Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value))
            )
        results = evaluate_for_seed(
            arch_config, config, arch, train_loader, ValLoaders, seed, logger
        )
        all_infos[dataset_key] = results
        all_dataset_keys.append(dataset_key)
    all_infos["all_dataset_keys"] = all_dataset_keys
    return all_infos


In [None]:
def train_single_model(
    save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config
):
    """Train and evaluate one architecture for every seed, caching one checkpoint per seed.

    Results are stored under
    ``<save_dir>/specifics/<LESS|FULL>-<model_str>-<channel>-<num_cells>/seed-XXXX.pth``.
    A seed whose checkpoint file already exists is loaded instead of being retrained.
    After each seed, a summary of the checkpoint is written to the logger.

    Args:
        save_dir: Root output directory.
        workers: Number of torch threads; also forwarded as the DataLoader worker count.
        datasets: Sequence of dataset names, forwarded to ``evaluate_all_datasets``.
        xpaths: Sequence of dataset root paths, parallel to ``datasets``.
        splits: Sequence of split flags, parallel to ``datasets``.
        use_less: Selects the "LESS" configuration (and directory prefix) when true.
        seeds: Sequence of integer seeds to run.
        model_str: Either a key of ``CellArchitectures`` or an architecture string that
            ``CellStructure.str2structure`` can parse.
        arch_config: Dict with at least the keys ``"channel"`` and ``"num_cells"``.

    Raises:
        AssertionError: If CUDA is unavailable or the architecture uses an invalid op.
        ValueError: If ``model_str`` is neither pre-defined nor parseable.
    """
    assert torch.cuda.is_available(), "CUDA is not available."
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = True
    torch.set_num_threads(workers)

    save_dir = (
        Path(save_dir)
        / "specifics"
        / "{:}-{:}-{:}-{:}".format(
            "LESS" if use_less else "FULL",
            model_str,
            arch_config["channel"],
            arch_config["num_cells"],
        )
    )
    logger = Logger(str(save_dir), 0, False)
    # Everything after the logger is created runs under try/finally so the logger is
    # closed even when parsing, training or checkpoint loading fails.
    try:
        if model_str in CellArchitectures:
            arch = CellArchitectures[model_str]
            logger.log(
                "The model string is found in pre-defined architecture dict : {:}".format(
                    model_str
                )
            )
        else:
            try:
                arch = CellStructure.str2structure(model_str)
            except Exception as err:
                # Catch Exception (not a bare except) so Ctrl-C / SystemExit still
                # propagate, and chain the cause so the parse failure stays visible.
                raise ValueError(
                    "Invalid model string : {:}. It can not be found or parsed.".format(
                        model_str
                    )
                ) from err
        assert arch.check_valid_op(
            get_search_spaces("cell", "full")
        ), "{:} has the invalid op.".format(arch)
        logger.log("Start train-evaluate {:}".format(arch.tostr()))
        logger.log("arch_config : {:}".format(arch_config))

        start_time, seed_time = time.time(), AverageMeter()
        for _is, seed in enumerate(seeds):
            logger.log(
                "\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------".format(
                    _is, len(seeds), seed
                )
            )
            to_save_name = save_dir / "seed-{:04d}.pth".format(seed)
            if to_save_name.exists():
                logger.log(
                    "Find the existing file {:}, directly load!".format(to_save_name)
                )
                # NOTE: torch.load unpickles the file; only checkpoints written by this
                # function are expected to be found here.
                checkpoint = torch.load(to_save_name)
            else:
                logger.log(
                    "Does not find the existing file {:}, train and evaluate!".format(
                        to_save_name
                    )
                )
                checkpoint = evaluate_all_datasets(
                    arch,
                    datasets,
                    xpaths,
                    splits,
                    use_less,
                    seed,
                    arch_config,
                    workers,
                    logger,
                )
                torch.save(checkpoint, to_save_name)
            # log information
            logger.log("{:}".format(checkpoint["info"]))
            all_dataset_keys = checkpoint["all_dataset_keys"]
            for dataset_key in all_dataset_keys:
                logger.log(
                    "\n{:} dataset : {:} {:}".format("-" * 15, dataset_key, "-" * 15)
                )
                dataset_info = checkpoint[dataset_key]
                # logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
                logger.log(
                    "Flops = {:} MB, Params = {:} MB".format(
                        dataset_info["flop"], dataset_info["param"]
                    )
                )
                logger.log("config : {:}".format(dataset_info["config"]))
                logger.log(
                    "Training State (finish) = {:}".format(dataset_info["finish-train"])
                )
                # The per-epoch accuracy containers are indexed by epoch number; the
                # final epoch is total_epoch - 1.
                last_epoch = dataset_info["total_epoch"] - 1
                train_acc1es, train_acc5es = (
                    dataset_info["train_acc1es"],
                    dataset_info["train_acc5es"],
                )
                valid_acc1es, valid_acc5es = (
                    dataset_info["valid_acc1es"],
                    dataset_info["valid_acc5es"],
                )
                logger.log(
                    "Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%".format(
                        train_acc1es[last_epoch],
                        train_acc5es[last_epoch],
                        100 - train_acc1es[last_epoch],
                        valid_acc1es[last_epoch],
                        valid_acc5es[last_epoch],
                        100 - valid_acc1es[last_epoch],
                    )
                )
            # measure elapsed time; the estimate for the remaining seeds is the running
            # average per-seed duration times the number of seeds still to go
            seed_time.update(time.time() - start_time)
            start_time = time.time()
            need_time = "Time Left: {:}".format(
                convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True)
            )
            logger.log(
                "\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}".format(
                    _is, len(seeds), seed, need_time
                )
            )
    finally:
        logger.close()


In [None]:
# Train / evaluate one hand-written cell architecture on CIFAR-10 with a single seed.
# `datasets`, `xpaths` and `splits` are zipped together entry by entry inside
# evaluate_all_datasets, so each must be a list with one item per dataset. Passing bare
# strings would make the zip iterate over individual characters ('c', '/', ...).
train_single_model(
    save_dir="./outputs",
    workers=8,
    datasets=["cifar10"],
    xpaths=["/root/cifardata/cifar-10-batches-py"],
    splits=[0],
    use_less=False,
    seeds=[777],
    model_str="|nor_conv_3x3~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|",
    arch_config={"channel": 16, "num_cells": 8},
)