#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
###########################################################################################################################################################
# Before running these commands, the benchmark files (<base_path>.pth and the <base_path>-archive weight directory) must be in place.
#
# python exps/experimental/test-ww-bench.py --search_space tss --base_path $HOME/.torch/NAS-Bench-201-v1_0-e61699 --dataset cifar10
# python exps/experimental/test-ww-bench.py --search_space tss --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space tss --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset cifar10
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset cifar100
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset ImageNet16-120
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space tss --base_path $HOME/.torch/NAS-Bench-201-v1_1 --dataset cifar10
###########################################################################################################################################################
import os, gc, sys, math, argparse, psutil
import numpy as np
import torch
from pathlib import Path
from collections import OrderedDict
import matplotlib
import seaborn as sns
matplotlib.use('agg')
import matplotlib.pyplot as plt
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from nas_201_api import NASBench201API, NASBench301API
from log_utils import time_string
from models import get_cell_based_tiny_net
from utils import weight_watcher


# NOTE(review): the triple-quoted string below is a no-op expression statement, not
# code -- it only keeps two disabled helpers for reference: `get_cor` (Pearson
# correlation of two sequences via np.corrcoef) and `tostr` (formats one correlation
# per key of `accdict`). Nothing else visible in this file calls them.
"""
def get_cor(A, B):
  return float(np.corrcoef(A, B)[0,1])


def tostr(accdict, norms):
  xstr = []
  for key, accs in accdict.items():
    cor = get_cor(accs, norms)
    xstr.append('{:}: {:.3f}'.format(key, cor))
  return ' '.join(xstr)
"""

def _score_architecture(api, weight_dir, data, arch_index, hp, seed):
  """Return (negated weight-watcher lognorm, 'ori-test' accuracy) for one architecture.

  Returns None when the weight-watcher summary has no 'lognorm' entry or the
  value is NaN. The weights reloaded from ``weight_dir`` are released via
  ``api.clear_params`` before returning; the network and meta info are locals,
  so they are dropped when this helper returns.
  """
  api.reload(weight_dir, arch_index)
  config = api.get_net_config(arch_index, data)
  net = get_cell_based_tiny_net(config)
  meta_info = api.query_meta_info_by_index(arch_index, hp=hp)
  params = meta_info.get_net_param(data, seed)
  with torch.no_grad():
    net.load_state_dict(params)
    _, summary = weight_watcher.analyze(net, alphas=False)
  api.clear_params(arch_index, None)
  if 'lognorm' not in summary:
    return None
  # Negated so that a larger value is intended to mean a "better" architecture.
  cur_norm = -summary['lognorm']
  if math.isnan(cur_norm):
    return None
  info = meta_info.get_metrics(data, 'ori-test', iepoch=None, is_random=seed)
  return cur_norm, info['accuracy']


def evaluate(api, weight_dir, data: str, total: int = 5000):
  """Pair weight-watcher scores with test accuracy for randomly sampled architectures.

  Each of the ``total`` draws picks an index with ``api.random()`` (draws may
  repeat). It then reloads that architecture's weights from ``weight_dir`` and runs
  ``weight_watcher.analyze`` on the network. Draws whose summary lacks a
  usable 'lognorm' are skipped.

  Args:
    api: a NASBench201API or NASBench301API instance.
    weight_dir: directory holding the per-architecture weight archives.
    data: dataset name, e.g. 'cifar10'.
    total: number of random draws (default 5000).

  Returns:
    (norms, accuracies): equal-length lists where norms[i] is the negated
    lognorm and accuracies[i] the 'ori-test' accuracy of the same architecture.
  """
  print('\nEvaluate dataset={:}'.format(data))
  process = psutil.Process(os.getpid())
  # The two benchmarks are queried with different hp / trial values
  # (presumably training epochs and trial seed -- confirm against the API).
  if isinstance(api, NASBench201API):
    hp, seed = '200', 888
  else:
    hp, seed = '90', 777
  norms, accuracies = [], []
  ok = 0
  for idx in range(total):
    arch_index = api.random()
    scored = _score_architecture(api, weight_dir, data, arch_index, hp, seed)
    if scored is not None:
      ok += 1
      norms.append(scored[0])
      accuracies.append(scored[1])
    # Report (and collect garbage) every 20 draws, including draws that were skipped.
    if idx % 20 == 0:
      gc.collect()
      print('{:} {:04d}_{:04d}/{:04d} ({:.2f} MB memory)'.format(time_string(), ok, idx, total, process.memory_info().rss / 1e6))
  return norms, accuracies


def _ascending_ranks(values):
  """Return a list whose i-th entry is the 0-based ascending rank of ``values[i]``.

  Ties keep their original relative order because ``sorted`` is stable. This
  gives the same ranks as repeated ``list.index`` lookups into the sorted order,
  in O(n log n) instead of O(n^2).
  """
  ranks = [0] * len(values)
  order = sorted(range(len(values)), key=lambda i: values[i])
  for rank, index in enumerate(order):
    ranks[index] = rank
  return ranks


def main(search_space, meta_file: str, weight_dir, save_dir, xdata):
  """Print benchmark statistics, score sampled architectures, and plot both rankings.

  Args:
    search_space: 'tss' selects NASBench201API; any other value selects NASBench301API.
    meta_file: path to the benchmark meta file.
    weight_dir: directory with the per-architecture weight archives.
    save_dir: ``Path`` where the PDF and PNG figures are written (created if missing).
    xdata: dataset to evaluate, e.g. 'cifar10'.
  """
  API = NASBench201API if search_space == 'tss' else NASBench301API
  save_dir.mkdir(parents=True, exist_ok=True)
  api = API(meta_file, verbose=False)
  datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
  print(time_string() + ' ' + '='*50)
  for data in datasets:
    for hp in api.avaliable_hps:  # (sic) attribute name used as-is on the API object
      nums = api.statistics(data, hp=hp)
      total = sum(k * v for k, v in nums.items())
      print('Using {:3s} epochs, trained on {:20s} : {:} trials in total ({:}).'.format(hp, data, total, nums))
  print(time_string() + ' ' + '='*50)

  norms, accuracies = evaluate(api, weight_dir, xdata)
  if not norms:
    # min()/max() and the tick computation below need at least one point.
    print('{:} no architecture yielded a valid weight-watcher score on {:}; nothing to plot.'.format(time_string(), xdata))
    return

  # x = rank by test accuracy; red y = the same architecture's weight-watcher rank.
  # The blue y = x line marks perfect agreement. This matches the axis labels and
  # legend below.
  indexes = list(range(len(norms)))
  accy_order = sorted(indexes, key=lambda i: accuracies[i])
  norm_ranks = _ascending_ranks(norms)
  labels = [norm_ranks[i] for i in accy_order]

  dpi, width, height = 200, 1400, 800
  figsize = width / float(dpi), height / float(dpi)
  LabelSize, LegendFontsize = 18, 12
  lo, hi = min(indexes), max(indexes)

  fig = plt.figure(figsize=figsize)
  ax = fig.add_subplot(111)
  plt.xlim(lo, hi)
  plt.ylim(lo, hi)
  # max(1, ...) avoids a zero step in np.arange when only a few architectures survive.
  plt.yticks(np.arange(lo, hi, max(1, hi // 3)), fontsize=LegendFontsize, rotation='vertical')
  plt.xticks(np.arange(lo, hi, max(1, hi // 5)), fontsize=LegendFontsize)
  ax.scatter(indexes, labels, marker='*', s=0.5, c='tab:red', alpha=0.8)
  ax.scatter(indexes, indexes, marker='o', s=0.5, c='tab:blue', alpha=0.8)
  # Large off-screen markers at (-1, -1) only exist to give the legend readable symbols.
  ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue', label='Test accuracy')
  ax.scatter([-1], [-1], marker='*', s=100, c='tab:red', label='Weight watcher')
  plt.grid(zorder=0)
  ax.set_axisbelow(True)
  plt.legend(loc=0, fontsize=LegendFontsize)
  ax.set_xlabel('architecture ranking sorted by the test accuracy ', fontsize=LabelSize)
  ax.set_ylabel('architecture ranking computed by weight watcher', fontsize=LabelSize)
  for ext in ('pdf', 'png'):
    save_path = (save_dir / '{:}-{:}-test-ww.{:}'.format(search_space, xdata, ext)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format=ext)
    print('{:} save into {:}'.format(time_string(), save_path))
  plt.close(fig)

  print('{:} finish this test.'.format(time_string()))


if __name__ == '__main__':
  parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
  parser.add_argument('--save_dir',     type=str, default='./output/vis-nas-bench/', help='The base-name of folder to save checkpoints and log.')
  # main() uses NASBench201API for 'tss' and NASBench301API for anything else, including None.
  parser.add_argument('--search_space', type=str, default=None, choices=['tss', 'sss'], help='The search space.')
  # Required: both file paths below are derived from --base_path, and evaluate() needs a
  # concrete dataset name. Without these, the run fails late with an unhelpful error.
  parser.add_argument('--base_path',    type=str, required=True, help='The path prefix of the benchmark file (<base_path>.pth) and weight dir (<base_path>-archive).')
  parser.add_argument('--dataset',      type=str, required=True, choices=['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'], help='The dataset to evaluate on.')
  args = parser.parse_args()

  save_dir = Path(args.save_dir)
  save_dir.mkdir(parents=True, exist_ok=True)
  meta_file = Path(args.base_path + '.pth')
  weight_dir = Path(args.base_path + '-archive')
  # parser.error (instead of assert) so the checks still run under `python -O`.
  if not meta_file.exists():
    parser.error('invalid path for api : {:}'.format(meta_file))
  if not weight_dir.is_dir():
    parser.error('invalid path for weight dir : {:}'.format(weight_dir))

  main(args.search_space, str(meta_file), weight_dir, save_dir, args.dataset)