xautodl/exps/NATS-Bench/test-nats-api.py

112 lines
4.2 KiB
Python

##############################################################################
# NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size #
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
##############################################################################
# Usage: python exps/NATS-Bench/test-nats-api.py #
##############################################################################
import os, gc, sys, time, torch, argparse
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from xautodl.config_utils import dict2config, load_config
from xautodl.log_utils import time_string
from xautodl.models import get_cell_based_tiny_net, CellStructure
from nats_bench import create
def test_api(api, sss_or_tss=True):
    """Run a fixed sequence of queries against a benchmark API and print the results.

    Args:
        api: The benchmark API object under test (the script entry point
            passes the objects returned by ``nats_bench.create``).
        sss_or_tss: When truthy, "90" is used as the hyper-parameter key of
            the meta-info query; otherwise "200" is used and the
            ``str2matrix`` conversion is exercised as well. The script entry
            point passes False for the "tss" API and True for the "size" API.
    """

    def report(value):
        # Print one query result followed by an empty line.
        print("{:}\n".format(value))

    arch_index = 12

    print("{:} start testing the api : {:}".format(time_string(), api))
    api.clear_params(arch_index)
    api.reload(index=arch_index)

    # Query the information of the 1113-th architecture
    print(api.query_info_str_by_arch(1113))
    report(api.query_by_index(113))
    report(api.query_by_index(113, "cifar100"))

    hp_key = "200" if not sss_or_tss else "90"
    report(api.query_meta_info_by_index(115, hp_key))

    # Look up the best architecture for every dataset / split pair; the
    # results are only unpacked, not printed.
    for dataset in ("cifar10", "cifar100", "ImageNet16-120"):
        for split in ("train", "test", "valid"):
            _best_index, _highest_accuracy = api.find_best(dataset, split)
        # NOTE(review): assumed to run once per dataset (outer loop) — confirm
        # against the original indentation.
        print("")

    all_params = api.get_net_param(arch_index, "cifar10", None)

    # Obtain the config, build the network, and load the first entry of the
    # returned parameter mapping into it
    config = api.get_net_config(arch_index, "cifar10")
    report(config)
    network = get_cell_based_tiny_net(config)
    network.load_state_dict(next(iter(all_params.values())))

    # Obtain the cost information
    report(api.get_cost_info(arch_index, "cifar10"))
    report(api.get_latency(arch_index, "cifar10"))
    for other_index in (13, 15, 19, 200):
        # Only checks that the query succeeds; the value is discarded.
        api.get_latency(other_index, "cifar10")

    # Count the number of architectures
    stats = api.statistics("cifar100", "12")
    print("{:} statistics results : {:}\n".format(time_string(), stats))

    # Show the information of the 123-th architecture
    api.show(123)

    # Obtain both cost and performance information
    report(api.get_more_info(1234, "cifar10"))
    print("{:} finish testing the api : {:}".format(time_string(), api))

    if not sss_or_tss:
        arch_str = "|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|"
        adjacency = api.str2matrix(arch_str)
        print("Compute the adjacency matrix of {:}".format(arch_str))
        print(adjacency)

    # NOTE(review): assumed to run for both values of `sss_or_tss` (i.e. not
    # nested under the branch above) — confirm against the original indentation.
    simulated = api.simulate_train_eval(123, "cifar10")
    print("simulate_train_eval : {:}\n\n".format(simulated))
if __name__ == "__main__":
    # Smoke-test the "tss" and "size" benchmark APIs under every combination
    # of `fast_mode` and `verbose`.
    #
    # The loop's `verbose` value is forwarded to `create` so that the setting
    # reported in the log line is the one actually used (a hard-coded
    # `verbose=True` would make the inner loop repeat the same configuration).
    #
    # `create` receives None as its first argument — presumably this lets
    # nats_bench resolve its default benchmark location; confirm against the
    # nats_bench documentation.

    # api201 = create('./output/NATS-Bench-topology/process-FULL', 'topology', fast_mode=True, verbose=True)
    for fast_mode in [True, False]:
        for verbose in [True, False]:
            api_nats_tss = create(None, "tss", fast_mode=fast_mode, verbose=verbose)
            print(
                "{:} create with fast_mode={:} and verbose={:}".format(
                    time_string(), fast_mode, verbose
                )
            )
            test_api(api_nats_tss, False)
            # Drop the API object and collect garbage before building the next one.
            del api_nats_tss
            gc.collect()

    for fast_mode in [True, False]:
        for verbose in [True, False]:
            print(
                "{:} create with fast_mode={:} and verbose={:}".format(
                    time_string(), fast_mode, verbose
                )
            )
            api_nats_sss = create(None, "size", fast_mode=fast_mode, verbose=verbose)
            print("{:} --->>> {:}".format(time_string(), api_nats_sss))
            test_api(api_nats_sss, True)
            # Drop the API object and collect garbage before building the next one.
            del api_nats_sss
            gc.collect()