##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 ##########################
##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
##############################################################################
"""The official Application Programming Interface (API) for NATS-Bench."""
from nats_bench.api_size import NATSsize
from nats_bench.api_topology import NATStopology
from nats_bench.api_utils import ArchResults
from nats_bench.api_utils import pickle_load
from nats_bench.api_utils import pickle_save
from nats_bench.api_utils import ResultsCount


# Released API versions, oldest first; version() reports the last entry.
NATS_BENCH_API_VERSIONs = [
    'v1.0',  # [2020.08.31]
    'v1.1',  # [2020.12.20] adding unit tests
]
# Accepted aliases for the size search space (served by NATSsize).
NATS_BENCH_SSS_NAMEs = ('sss', 'size')
# Accepted aliases for the topology search space (served by NATStopology).
NATS_BENCH_TSS_NAMEs = ('tss', 'topology')


def version():
  """Return the most recent NATS-Bench API version string.

  Returns:
    The last (newest) entry of NATS_BENCH_API_VERSIONs.
  """
  known_versions = NATS_BENCH_API_VERSIONs
  newest = known_versions[-1]
  return newest


def create(file_path_or_dict, search_space, fast_mode=False, verbose=True):
  """Create an instance of the NATS-Bench API for the given search space.

  Args:
    file_path_or_dict: None, a file path, or a directory path. It is passed
      unchanged to the API constructor. (Given the name, a pre-loaded dict is
      presumably accepted too; confirm against NATStopology / NATSsize.)
    search_space: A string naming the search space. Use one of
      NATS_BENCH_TSS_NAMEs ('tss', 'topology') for the topology search space,
      or one of NATS_BENCH_SSS_NAMEs ('sss', 'size') for the size search space.
    fast_mode: If True, the data is not all loaded at initialization. Instead,
      the data for each candidate architecture is loaded when it is queried.
      If False, all the data is loaded during initialization.
    verbose: Whether to log additional information.

  Raises:
    ValueError: If search_space matches none of the known search space names.

  Returns:
    A NATStopology instance (topology search space) or a NATSsize instance
    (size search space).
  """
  ctor_args = (file_path_or_dict, fast_mode, verbose)
  # Resolve the API class first, then construct it with a single call.
  if search_space in NATS_BENCH_TSS_NAMEs:
    api_cls = NATStopology
  elif search_space in NATS_BENCH_SSS_NAMEs:
    api_cls = NATSsize
  else:
    raise ValueError('invalid search space : {:}'.format(search_space))
  return api_cls(*ctor_args)


def search_space_info(main_tag, aux_tag):
  """Obtain the search space information.

  Args:
    main_tag: The benchmark name, either 'nats-bench' or 'nas-bench-201'.
    aux_tag: For 'nats-bench', a name from NATS_BENCH_SSS_NAMEs (size search
      space) or NATS_BENCH_TSS_NAMEs (topology search space). For
      'nas-bench-201', it must be None.

  Returns:
    A dict that is built fresh on every call, so callers may mutate it:
      * size search space: {'candidates': [8, 16, ..., 64], 'num_layers': 5};
      * topology search space (also used for 'nas-bench-201'):
        {'op_names': [five operation names], 'num_nodes': 4}.

  Raises:
    ValueError: If main_tag is unknown, if aux_tag is unknown for
      'nats-bench', or if aux_tag is not None for 'nas-bench-201'.
  """
  size_space = {
      'candidates': list(range(8, 65, 8)),  # 8, 16, 24, ..., 64
      'num_layers': 5,
  }
  topology_space = {
      'op_names': ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3',
                   'avg_pool_3x3'],
      'num_nodes': 4,
  }

  if main_tag == 'nats-bench':
    if aux_tag in NATS_BENCH_SSS_NAMEs:
      return size_space
    if aux_tag in NATS_BENCH_TSS_NAMEs:
      return topology_space
    raise ValueError('Unknown auxiliary tag: {:}'.format(aux_tag))

  if main_tag == 'nas-bench-201':
    # NAS-Bench-201 reuses the topology description and takes no aux tag.
    if aux_tag is not None:
      raise ValueError('For NAS-Bench-201, the auxiliary tag should be None.')
    return topology_space

  raise ValueError('Unknown main tag: {:}'.format(main_tag))