##############################################################################
# NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size #
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07                          #
##############################################################################
# python ./exps/NATS-Bench/Analyze-time.py                                   #
##############################################################################
|
|
import os, sys, time, tqdm, torch, random, argparse
|
|
from typing import List, Text, Dict, Any
|
|
from PIL import ImageFile
|
|
# NOTE(review): global Pillow switch, presumably so partially written /
# truncated image files can still be opened. Nothing in this script's visible
# code loads images directly; it is likely kept for the `datasets` utilities
# imported below. Confirm it is still needed.
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
from copy import deepcopy
|
|
from pathlib import Path
|
|
|
|
# Make the repository's top-level `lib` directory (two levels above this
# script's folder) importable, so the `config_utils`, `datasets` and
# `nats_bench` imports below resolve. Prepend it only once.
lib_dir = Path(__file__).parent.joinpath('..', '..', 'lib').resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))
|
|
from config_utils import dict2config, load_config
|
|
from datasets import get_datasets
|
|
from nats_bench import create
|
|
|
|
|
|
def show_time(api, epoch=12):
    """Print total training time across all architectures of a benchmark API.

    For every architecture index ``0 .. len(api) - 1`` this reads the
    ``'train-all-time'`` entry of ``api.get_more_info(index, dataset, hp=epoch)``
    for ``'ImageNet16-120'``, ``'cifar10-valid'`` and ``'cifar100'``. It sums
    the values per dataset and prints each total (in seconds, per the printed
    messages), together with its ratio to the CIFAR-10 total.

    Args:
      api: benchmark API object, e.g. the result of ``nats_bench.create``. It
        must support ``len(api)`` and ``api.get_more_info(index, dataset, hp=...)``.
      epoch: training-schedule length, forwarded as ``hp`` to ``get_more_info``.

    Returns:
      Tuple ``(all_cifar10_time, all_cifar100_time, all_imagenet_time)`` with
      the three summed times.
    """

    def _ratio(numerator, denominator):
        # An empty API (or all-zero CIFAR-10 times) would otherwise raise
        # ZeroDivisionError after the whole loop has run; report NaN instead.
        return numerator / denominator if denominator else float('nan')

    print('Show the time for {:} with {:}-epoch-training'.format(api, epoch))
    all_cifar10_time, all_cifar100_time, all_imagenet_time = 0, 0, 0
    for index in tqdm.tqdm(range(len(api))):
        info = api.get_more_info(index, 'ImageNet16-120', hp=epoch)
        imagenet_time = info['train-all-time']
        info = api.get_more_info(index, 'cifar10-valid', hp=epoch)
        cifar10_time = info['train-all-time']
        info = api.get_more_info(index, 'cifar100', hp=epoch)
        cifar100_time = info['train-all-time']
        # accumulate the time
        all_cifar10_time += cifar10_time
        all_cifar100_time += cifar100_time
        all_imagenet_time += imagenet_time
    print('The total training time for CIFAR-10 (held-out train set) is {:} seconds'.format(all_cifar10_time))
    print('The total training time for CIFAR-100 (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10'.format(all_cifar100_time, _ratio(all_cifar100_time, all_cifar10_time)))
    print('The total training time for ImageNet-16-120 (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10'.format(all_imagenet_time, _ratio(all_imagenet_time, all_cifar10_time)))
    return all_cifar10_time, all_cifar100_time, all_imagenet_time
|
|
|
|
|
|
if __name__ == '__main__':
    # Build each NATS-Bench API in turn ('tss' first, then 'sss') in fast,
    # quiet mode, and report its 12-epoch training-time totals.
    for search_space in ('tss', 'sss'):
        benchmark_api = create(None, search_space, fast_mode=True, verbose=False)
        show_time(benchmark_api, 12)
|
|
|