from nas_201_api import NASBench201API as API
import os, sys, time, torch, random, argparse
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True
from copy import deepcopy
from pathlib import Path

from xautodl.config_utils import load_config, dict2config
from xautodl.procedures import save_checkpoint, copy_checkpoint
from xautodl.procedures import get_machine_info, prepare_seed, get_optim_scheduler
from xautodl.datasets import get_datasets
from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time
from xautodl.models import CellStructure, CellArchitectures, get_search_spaces
from xautodl.models import get_cell_based_tiny_net
from xautodl.utils import get_model_infos, obtain_accuracy

cur_path = os.path.abspath(os.path.curdir)
data_path = os.path.join(cur_path, 'NAS-Bench-201-v1_1-096897.pth')
print(f'loading data from {data_path}')
api = API(data_path)
print('loaded')


def find_best_index(dataset):
    """Return (index, accuracy) of the architecture with the best training accuracy on `dataset`."""
    num_archs = 15625  # NAS-Bench-201 contains 15625 architectures, indexed 0..15624
    accs = []
    for i in range(num_archs):
        results = api.query_by_index(i, dataset)
        # take the training accuracy of the first available seed for this architecture
        dict_items = list(results.items())
        train_info = dict_items[0][1].get_train()
        acc = train_info['accuracy']
        accs.append((i, acc))
    return max(accs, key=lambda x: x[1])


best_cifar_10_index, best_cifar_10_acc = find_best_index('cifar10')
best_cifar_100_index, best_cifar_100_acc = find_best_index('cifar100')
best_ImageNet16_index, best_ImageNet16_acc = find_best_index('ImageNet16-120')
print(f'find best cifar10 index: {best_cifar_10_index}, acc: {best_cifar_10_acc}')
print(f'find best cifar100 index: {best_cifar_100_index}, acc: {best_cifar_100_acc}')
print(f'find best ImageNet16 index: {best_ImageNet16_index}, acc: {best_ImageNet16_acc}')


def get_network_str_by_id(id, dataset):
    config = api.get_net_config(id, dataset)
    return config['arch_str']


best_cifar_10_str = get_network_str_by_id(best_cifar_10_index, 'cifar10')
best_cifar_100_str = get_network_str_by_id(best_cifar_100_index, 'cifar100')
best_ImageNet16_str = get_network_str_by_id(best_ImageNet16_index, 'ImageNet16-120')
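
# Optional sanity check (a small sketch, not part of the original flow): the arch
# strings above use the NAS-Bench-201 cell notation; they can be parsed back into
# CellStructure objects the same way train_single_model() does further below.
for _name, _s in [
    ("cifar10", best_cifar_10_str),
    ("cifar100", best_cifar_100_str),
    ("ImageNet16-120", best_ImageNet16_str),
]:
    print(_name, CellStructure.str2structure(_s).tostr())
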

def evaluate_all_datasets(
    arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger
):
    machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
    all_infos = {"info": machine_info}
    all_dataset_keys = []
    # loop over all the datasets
    for dataset, xpath, split in zip(datasets, xpaths, splits):
        # train / valid data
        train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
        # load the configuration
        if dataset == "cifar10" or dataset == "cifar100":
            if use_less:
                config_path = "configs/nas-benchmark/LESS.config"
            else:
                config_path = "configs/nas-benchmark/CIFAR.config"
            split_info = load_config(
                "configs/nas-benchmark/cifar-split.txt", None, None
            )
        elif dataset.startswith("ImageNet16"):
            if use_less:
                config_path = "configs/nas-benchmark/LESS.config"
            else:
                config_path = "configs/nas-benchmark/ImageNet-16.config"
            split_info = load_config(
                "configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
            )
        else:
            raise ValueError("invalid dataset : {:}".format(dataset))
        config = load_config(
            config_path, {"class_num": class_num, "xshape": xshape}, logger
        )
        # check whether to use the split validation set
        if bool(split):
            assert dataset == "cifar10"
            ValLoaders = {
                "ori-test": torch.utils.data.DataLoader(
                    valid_data,
                    batch_size=config.batch_size,
                    shuffle=False,
                    num_workers=workers,
                    pin_memory=True,
                )
            }
            assert len(train_data) == len(split_info.train) + len(
                split_info.valid
            ), "invalid length : {:} vs {:} + {:}".format(
                len(train_data), len(split_info.train), len(split_info.valid)
            )
            train_data_v2 = deepcopy(train_data)
            train_data_v2.transform = valid_data.transform
            valid_data = train_data_v2
            # data loaders
            train_loader = torch.utils.data.DataLoader(
                train_data,
                batch_size=config.batch_size,
                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
                num_workers=workers,
                pin_memory=True,
            )
            valid_loader = torch.utils.data.DataLoader(
                valid_data,
                batch_size=config.batch_size,
                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
                num_workers=workers,
                pin_memory=True,
            )
            ValLoaders["x-valid"] = valid_loader
        else:
            # data loaders
            train_loader = torch.utils.data.DataLoader(
                train_data,
                batch_size=config.batch_size,
                shuffle=True,
                num_workers=workers,
                pin_memory=True,
            )
            valid_loader = torch.utils.data.DataLoader(
                valid_data,
                batch_size=config.batch_size,
                shuffle=False,
                num_workers=workers,
                pin_memory=True,
            )
            if dataset == "cifar10":
                ValLoaders = {"ori-test": valid_loader}
            elif dataset == "cifar100":
                cifar100_splits = load_config(
                    "configs/nas-benchmark/cifar100-test-split.txt", None, None
                )
                ValLoaders = {
                    "ori-test": valid_loader,
                    "x-valid": torch.utils.data.DataLoader(
                        valid_data,
                        batch_size=config.batch_size,
                        sampler=torch.utils.data.sampler.SubsetRandomSampler(
                            cifar100_splits.xvalid
                        ),
                        num_workers=workers,
                        pin_memory=True,
                    ),
                    "x-test": torch.utils.data.DataLoader(
                        valid_data,
                        batch_size=config.batch_size,
                        sampler=torch.utils.data.sampler.SubsetRandomSampler(
                            cifar100_splits.xtest
                        ),
                        num_workers=workers,
                        pin_memory=True,
                    ),
                }
            elif dataset == "ImageNet16-120":
                imagenet16_splits = load_config(
                    "configs/nas-benchmark/imagenet-16-120-test-split.txt", None, None
                )
                ValLoaders = {
                    "ori-test": valid_loader,
                    "x-valid": torch.utils.data.DataLoader(
                        valid_data,
                        batch_size=config.batch_size,
                        sampler=torch.utils.data.sampler.SubsetRandomSampler(
                            imagenet16_splits.xvalid
                        ),
                        num_workers=workers,
                        pin_memory=True,
                    ),
                    "x-test": torch.utils.data.DataLoader(
                        valid_data,
                        batch_size=config.batch_size,
                        sampler=torch.utils.data.sampler.SubsetRandomSampler(
                            imagenet16_splits.xtest
                        ),
                        num_workers=workers,
                        pin_memory=True,
                    ),
                }
            else:
                raise ValueError("invalid dataset : {:}".format(dataset))
        dataset_key = "{:}".format(dataset)
        if bool(split):
            dataset_key = dataset_key + "-valid"
        logger.log(
            "Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
                dataset_key,
                len(train_data),
                len(valid_data),
                len(train_loader),
                len(valid_loader),
                config.batch_size,
            )
        )
        logger.log(
            "Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config)
        )
        for key, value in ValLoaders.items():
            logger.log(
                "Evaluate ---->>>> {:10s} with {:} batches".format(key, len(value))
            )
        results = evaluate_for_seed(
            arch_config, config, arch, train_loader, ValLoaders, seed, logger
        )
        all_infos[dataset_key] = results
        all_dataset_keys.append(dataset_key)
    all_infos["all_dataset_keys"] = all_dataset_keys
    return all_infos
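
# Note on the dict returned by evaluate_all_datasets(): it holds the machine info under
# "info", one results entry per evaluated dataset (keyed by the dataset name, with a
# "-valid" suffix when a split is used), and the list of those keys under
# "all_dataset_keys". train_single_model() below saves this dict as the per-seed
# checkpoint and reads it back using the same keys.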

def evaluate_for_seed(
    arch_config, config, arch, train_loader, valid_loaders, seed, logger
):
    prepare_seed(seed)  # random seed
    net = get_cell_based_tiny_net(
        dict2config(
            {
                "name": "infer.tiny",
                "C": arch_config["channel"],
                "N": arch_config["num_cells"],
                "genotype": arch,
                "num_classes": config.class_num,
            },
            None,
        )
    )
    # net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
    flop, param = get_model_infos(net, config.xshape)
    logger.log("Network : {:}".format(net.get_message()), False)
    logger.log(
        "{:} Seed-------------------------- {:} --------------------------".format(
            time_string(), seed
        )
    )
    logger.log("FLOP = {:} MB, Param = {:} MB".format(flop, param))
    # train and valid
    optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), config)
    network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda()
    # start training
    start_time, epoch_time, total_epoch = (
        time.time(),
        AverageMeter(),
        config.epochs + config.warmup,
    )
    (
        train_losses,
        train_acc1es,
        train_acc5es,
        valid_losses,
        valid_acc1es,
        valid_acc5es,
    ) = ({}, {}, {}, {}, {}, {})
    train_times, valid_times = {}, {}
    for epoch in range(total_epoch):
        scheduler.update(epoch, 0.0)
        train_loss, train_acc1, train_acc5, train_tm = procedure(
            train_loader, network, criterion, scheduler, optimizer, "train"
        )
        train_losses[epoch] = train_loss
        train_acc1es[epoch] = train_acc1
        train_acc5es[epoch] = train_acc5
        train_times[epoch] = train_tm
        with torch.no_grad():
            for key, xloader in valid_loaders.items():
                valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(
                    xloader, network, criterion, None, None, "valid"
                )
                valid_losses["{:}@{:}".format(key, epoch)] = valid_loss
                valid_acc1es["{:}@{:}".format(key, epoch)] = valid_acc1
                valid_acc5es["{:}@{:}".format(key, epoch)] = valid_acc5
                valid_times["{:}@{:}".format(key, epoch)] = valid_tm
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        need_time = "Time Left: {:}".format(
            convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True)
        )
        logger.log(
            "{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%]".format(
                time_string(),
                need_time,
                epoch,
                total_epoch,
                train_loss,
                train_acc1,
                train_acc5,
                valid_loss,
                valid_acc1,
                valid_acc5,
            )
        )
    info_seed = {
        "flop": flop,
        "param": param,
        "channel": arch_config["channel"],
        "num_cells": arch_config["num_cells"],
        "config": config._asdict(),
        "total_epoch": total_epoch,
        "train_losses": train_losses,
        "train_acc1es": train_acc1es,
        "train_acc5es": train_acc5es,
        "train_times": train_times,
        "valid_losses": valid_losses,
        "valid_acc1es": valid_acc1es,
        "valid_acc5es": valid_acc5es,
        "valid_times": valid_times,
        "net_state_dict": net.state_dict(),
        "net_string": "{:}".format(net),
        "finish-train": True,
    }
    return info_seed


def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
    data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    latencies = []
    network.eval()
    with torch.no_grad():
        end = time.time()
        for i, (inputs, targets) in enumerate(xloader):
            targets = targets.cuda(non_blocking=True)
            inputs = inputs.cuda(non_blocking=True)
            data_time.update(time.time() - end)
            # forward
            features, logits = network(inputs)
            loss = criterion(logits, targets)
            batch_time.update(time.time() - end)
            if batch is None or batch == inputs.size(0):
                batch = inputs.size(0)
                latencies.append(batch_time.val - data_time.val)
            # record loss and accuracy
            prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))
            end = time.time()
    if len(latencies) > 2:
        latencies = latencies[1:]
    return losses.avg, top1.avg, top5.avg, latencies
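
# pure_evaluate() is defined but never called in this script. A minimal usage sketch
# (assuming `network` is a trained CUDA model and `valid_loader` is one of the loaders
# built in evaluate_all_datasets()):
#
#   loss, acc1, acc5, latencies = pure_evaluate(valid_loader, network)
#   print("top-1 {:.2f}%, first-batch latency {:.4f}s".format(acc1, latencies[0]))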

def procedure(xloader, network, criterion, scheduler, optimizer, mode):
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    if mode == "train":
        network.train()
    elif mode == "valid":
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))
    data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
    for i, (inputs, targets) in enumerate(xloader):
        if mode == "train":
            scheduler.update(None, 1.0 * i / len(xloader))
        targets = targets.cuda(non_blocking=True)
        if mode == "train":
            optimizer.zero_grad()
        # forward
        features, logits = network(inputs)
        loss = criterion(logits, targets)
        # backward
        if mode == "train":
            loss.backward()
            optimizer.step()
        # record loss and accuracy
        prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))
        # count time
        batch_time.update(time.time() - end)
        end = time.time()
    return losses.avg, top1.avg, top5.avg, batch_time.sum
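
# Note: the learning-rate schedule is driven by two kinds of calls, as the code above
# shows: evaluate_for_seed() calls scheduler.update(epoch, 0.0) once per epoch, and
# procedure() calls scheduler.update(None, i / len(xloader)) once per mini-batch to
# interpolate within the epoch.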

def train_single_model(
    save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config
):
    assert torch.cuda.is_available(), "CUDA is not available."
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = True
    torch.set_num_threads(workers)
    save_dir = (
        Path(save_dir)
        / "specifics"
        / "{:}-{:}-{:}-{:}".format(
            "LESS" if use_less else "FULL",
            model_str,
            arch_config["channel"],
            arch_config["num_cells"],
        )
    )
    logger = Logger(str(save_dir), 0, False)
    print(CellArchitectures)
    if model_str in CellArchitectures:
        arch = CellArchitectures[model_str]
        logger.log(
            "The model string is found in pre-defined architecture dict : {:}".format(
                model_str
            )
        )
    else:
        try:
            arch = CellStructure.str2structure(model_str)
        except Exception:
            raise ValueError(
                "Invalid model string : {:}. It can not be found or parsed.".format(
                    model_str
                )
            )
    assert arch.check_valid_op(
        get_search_spaces("cell", "nas-bench-201")
    ), "{:} has the invalid op.".format(arch)
    logger.log("Start train-evaluate {:}".format(arch.tostr()))
    logger.log("arch_config : {:}".format(arch_config))
    start_time, seed_time = time.time(), AverageMeter()
    for _is, seed in enumerate(seeds):
        logger.log(
            "\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------".format(
                _is, len(seeds), seed
            )
        )
        to_save_name = save_dir / "seed-{:04d}.pth".format(seed)
        if to_save_name.exists():
            logger.log(
                "Find the existing file {:}, directly load!".format(to_save_name)
            )
            checkpoint = torch.load(to_save_name)
        else:
            logger.log(
                "Does not find the existing file {:}, train and evaluate!".format(
                    to_save_name
                )
            )
            checkpoint = evaluate_all_datasets(
                arch,
                datasets,
                xpaths,
                splits,
                use_less,
                seed,
                arch_config,
                workers,
                logger,
            )
            torch.save(checkpoint, to_save_name)
        # log information
        logger.log("{:}".format(checkpoint["info"]))
        all_dataset_keys = checkpoint["all_dataset_keys"]
        for dataset_key in all_dataset_keys:
            logger.log(
                "\n{:} dataset : {:} {:}".format("-" * 15, dataset_key, "-" * 15)
            )
            dataset_info = checkpoint[dataset_key]
            # logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
            logger.log(
                "Flops = {:} MB, Params = {:} MB".format(
                    dataset_info["flop"], dataset_info["param"]
                )
            )
            logger.log("config : {:}".format(dataset_info["config"]))
            logger.log(
                "Training State (finish) = {:}".format(dataset_info["finish-train"])
            )
            last_epoch = dataset_info["total_epoch"] - 1
            train_acc1es, train_acc5es = (
                dataset_info["train_acc1es"],
                dataset_info["train_acc5es"],
            )
            valid_acc1es, valid_acc5es = (
                dataset_info["valid_acc1es"],
                dataset_info["valid_acc5es"],
            )
            print(dataset_info["train_acc1es"])
            print(dataset_info["train_acc5es"])
            print(dataset_info["valid_acc1es"])
            print(dataset_info["valid_acc5es"])
            logger.log(
                "Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%".format(
                    train_acc1es[last_epoch],
                    train_acc5es[last_epoch],
                    100 - train_acc1es[last_epoch],
                    valid_acc1es["ori-test@" + str(last_epoch)],
                    valid_acc5es["ori-test@" + str(last_epoch)],
                    100 - valid_acc1es["ori-test@" + str(last_epoch)],
                )
            )
        # measure elapsed time
        seed_time.update(time.time() - start_time)
        start_time = time.time()
        need_time = "Time Left: {:}".format(
            convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True)
        )
        logger.log(
            "\n<<<***>>> The {:02d}/{:02d}-th seed is {:} other procedures need {:}".format(
                _is, len(seeds), seed, need_time
            )
        )
    logger.close()


# Example arch string format:
# |nor_conv_3x3~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|
train_strs = [best_cifar_10_str, best_cifar_100_str, best_ImageNet16_str]
# Note: xpaths must be lists, since evaluate_all_datasets() zips them with `datasets`.
train_single_model(
    save_dir="./outputs",
    workers=8,
    datasets=["ImageNet16-120"],
    xpaths=["./datasets/imagenet16-120"],
    splits=[0, 0, 0],
    use_less=False,
    seeds=[777],
    model_str=best_ImageNet16_str,
    arch_config={"channel": 16, "num_cells": 8},
)
train_single_model(
    save_dir="./outputs",
    workers=8,
    datasets=["cifar10"],
    xpaths=["./datasets/cifar10"],
    splits=[0, 0, 0],
    use_less=False,
    seeds=[777],
    model_str=best_cifar_10_str,
    arch_config={"channel": 16, "num_cells": 8},
)
train_single_model(
    save_dir="./outputs",
    workers=8,
    datasets=["cifar100"],
    xpaths=["./datasets/cifar100"],
    splits=[0, 0, 0],
    use_less=False,
    seeds=[777],
    model_str=best_cifar_100_str,
    arch_config={"channel": 16, "num_cells": 8},
)