##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 ##########################
##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
##############################################################################
"""This file is used to quickly test the API."""
import random

from nats_bench.api_size import NATSsize
from nats_bench.api_topology import NATStopology


def test_nats_bench_tss(benchmark_dir):
  return test_nats_bench(benchmark_dir, True)


def test_nats_bench_sss(benchmark_dir):
  return test_nats_bench(benchmark_dir, False)


def test_nats_bench(benchmark_dir, is_tss, verbose=False):
  if is_tss:
    api = NATStopology(benchmark_dir, True, verbose)
  else:
    api = NATSsize(benchmark_dir, True, verbose)

  test_indexes = [random.randint(0, len(api) - 1) for _ in range(10)]
  key2dataset = {'cifar10': 'CIFAR-10',
                 'cifar100': 'CIFAR-100',
                 'ImageNet16-120': 'ImageNet16-120'}

  for index in test_indexes:
    print('\n\nEvaluate the {:5d}-th architecture.'.format(index))

    for key, dataset in key2dataset.items():
      # Query the loss / accuracy / time for the `index`-th candidate
      #   architecture on CIFAR-10
      # info is a dict, where you can easily figure out the meaning by key
      info = api.get_more_info(index, key)
      print('  -->> The performance on {:}: {:}'.format(dataset, info))

      # Query the flops, params, latency. info is a dict.
      info = api.get_cost_info(index, key)
      print('  -->> The cost info on {:}: {:}'.format(dataset, info))

      # Simulate the training of the `index`-th candidate:
      validation_accuracy, latency, time_cost, current_total_time_cost = api.simulate_train_eval(
          index, dataset=key, hp='12')
      print('  -->> The validation accuracy={:}, latency={:}, '
            'the current time cost={:} s, accumulated time cost={:} s'
            .format(validation_accuracy, latency, time_cost,
                    current_total_time_cost))

      # Print the configuration of the `index`-th architecture on CIFAR-10
      config = api.get_net_config(index, key)
      print('  -->> The configuration on {:} is {:}'.format(dataset, config))

    # Show the information of the `index`-th architecture
    api.show(index)