# Source: xautodl/exps/NATS-Bench/test-nats-api.py (107 lines, 4.2 KiB, Python)

##############################################################################
# NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size #
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
##############################################################################
# Usage: python exps/NATS-Bench/test-nats-api.py #
##############################################################################
import os, gc, sys, time, torch, argparse
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict
2021-03-17 10:25:58 +01:00
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
2021-03-17 10:25:58 +01:00
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
2021-03-17 10:25:58 +01:00
# Make the project's `lib` directory (two levels above this script) importable,
# so that the project-local imports that follow can be resolved.
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
    # Prepend so the project modules are found before same-named installed ones.
    sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
2020-07-30 15:07:11 +02:00
from nats_bench import create
from log_utils import time_string
from models import get_cell_based_tiny_net, CellStructure
def test_api(api, sss_or_tss=True):
    """Smoke-test a NATS-Bench API object by calling its main query methods.

    The function only calls the API and prints what it gets back; it makes no
    assertions, so a failure shows up as an exception raised by one of the calls.

    Args:
        api: the benchmark API object under test (as returned by
            ``nats_bench.create``).
        sss_or_tss: selects the search-space specific parts of the test.
            When True, the hyper-parameter key "90" is used for
            ``query_meta_info_by_index`` and the ``str2matrix`` check is
            skipped; when False, the key "200" is used and ``str2matrix`` is
            exercised. The script's entry point passes True for the "size"
            space and False for the "tss" space.
    """
    print("{:} start testing the api : {:}".format(time_string(), api))
    api.clear_params(12)
    api.reload(index=12)

    # Query the information string with the argument 1113; note that the
    # index-based queries below use architecture 113, not 1113.
    info_strs = api.query_info_str_by_arch(1113)
    print(info_strs)
    info = api.query_by_index(113)
    print("{:}\n".format(info))
    info = api.query_by_index(113, "cifar100")
    print("{:}\n".format(info))

    info = api.query_meta_info_by_index(115, "90" if sss_or_tss else "200")
    print("{:}\n".format(info))

    for dataset in ["cifar10", "cifar100", "ImageNet16-120"]:
        for xset in ["train", "test", "valid"]:
            # The result is not used; unpacking still checks that a pair is returned.
            _best_index, _highest_accuracy = api.find_best(dataset, xset)
        # NOTE(review): nesting reconstructed from a flattened source -- the blank
        # line is assumed to be printed once per dataset; confirm.
        print("")
    params = api.get_net_param(12, "cifar10", None)

    # Obtain the config and create the network
    config = api.get_net_config(12, "cifar10")
    print("{:}\n".format(config))
    network = get_cell_based_tiny_net(config)
    # `params` is a mapping; load the first stored set of weights into the network.
    network.load_state_dict(next(iter(params.values())))

    # Obtain the cost information
    info = api.get_cost_info(12, "cifar10")
    print("{:}\n".format(info))
    info = api.get_latency(12, "cifar10")
    print("{:}\n".format(info))
    for index in [13, 15, 19, 200]:
        # Only checks that the call succeeds; the value is not printed.
        api.get_latency(index, "cifar10")

    # Count the number of architectures
    info = api.statistics("cifar100", "12")
    print("{:} statistics results : {:}\n".format(time_string(), info))

    # Show the information of the 123-th architecture
    api.show(123)

    # Obtain both cost and performance information
    info = api.get_more_info(1234, "cifar10")
    print("{:}\n".format(info))
    print("{:} finish testing the api : {:}".format(time_string(), api))

    if not sss_or_tss:
        arch_str = "|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|"
        matrix = api.str2matrix(arch_str)
        print("Compute the adjacency matrix of {:}".format(arch_str))
        print(matrix)
    # NOTE(review): nesting reconstructed from a flattened source -- the simulated
    # train/eval query is assumed to run for both search spaces; confirm.
    info = api.simulate_train_eval(123, "cifar10")
    print("simulate_train_eval : {:}\n\n".format(info))
if __name__ == "__main__":
    # Run the smoke test for both search spaces under every combination of
    # `fast_mode` and `verbose`, releasing each API object before creating the next.

    # api201 = create('./output/NATS-Bench-topology/process-FULL', 'topology', fast_mode=True, verbose=True)
    for fast_mode in [True, False]:
        for verbose in [True, False]:
            # Pass the loop's `verbose` through: previously True was hard-coded,
            # so verbose=False was never exercised and the log line below was wrong.
            api_nats_tss = create(None, "tss", fast_mode=fast_mode, verbose=verbose)
            print("{:} create with fast_mode={:} and verbose={:}".format(time_string(), fast_mode, verbose))
            test_api(api_nats_tss, False)
            del api_nats_tss
            gc.collect()

    for fast_mode in [True, False]:
        for verbose in [True, False]:
            print("{:} create with fast_mode={:} and verbose={:}".format(time_string(), fast_mode, verbose))
            api_nats_sss = create(None, "size", fast_mode=fast_mode, verbose=verbose)
            print("{:} --->>> {:}".format(time_string(), api_nats_sss))
            test_api(api_nats_sss, True)
            del api_nats_sss
            gc.collect()