##############################################################################
# NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size #
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08                          #
##############################################################################
# Usage: python exps/NATS-Bench/sss-file-manager.py --mode check             #
##############################################################################
import os, sys, time, torch, argparse
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict
from copy import deepcopy
from pathlib import Path

lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from procedures import bench_evaluate_for_seed
from procedures import get_machine_info
from datasets import get_datasets
from log_utils import Logger, AverageMeter, time_string, convert_secs2time


def obtain_valid_ckp(save_dir: Text, total: int, seeds=(777, 888, 999)):
    """Scan ``save_dir`` for per-architecture checkpoints and report coverage.

    The checkpoint for architecture ``i`` trained with ``seed`` is expected at
    ``<save_dir>/arch-{i:06d}-seed-{seed:04d}.pth``.

    Args:
      save_dir: directory holding the raw checkpoint files.
      total: number of architecture indices to check (``0 .. total - 1``).
      seeds: iterable of training seeds to look for; defaults to
        ``(777, 888, 999)``.

    Returns:
      A pair ``(seed2ckps, miss2ckps)`` of plain dicts. Each maps a seed to the
      ascending list of architecture indices whose checkpoint file exists
      (``seed2ckps``) or is absent (``miss2ckps``). A seed that has no entries
      of a given kind does not appear as a key in that dict.
    """
    seeds = tuple(seeds)
    seed2ckps = defaultdict(list)
    miss2ckps = defaultdict(list)
    for i in range(total):
        for seed in seeds:
            path = os.path.join(save_dir, "arch-{:06d}-seed-{:04d}.pth".format(i, seed))
            if os.path.exists(path):
                seed2ckps[seed].append(i)
            else:
                miss2ckps[seed].append(i)
    # Iterate the requested seeds (not seed2ckps) so that a seed with zero
    # checkpoints is still reported, as "has 0 / miss total".
    # ``.get`` avoids inserting empty lists into the defaultdict.
    for seed in seeds:
        found = len(seed2ckps.get(seed, ()))
        print(
            "[{:}] [seed={:}] has {:5d}/{:5d} | miss {:5d}/{:5d}".format(
                save_dir, seed, found, total, total - found, total
            )
        )
    return dict(seed2ckps), dict(miss2ckps)


def copy_data(source_dir, target_dir, meta_path):
    """Copy the checkpoints listed as missing in a meta file into ``target_dir``.

    The meta file is read with ``torch.load``. Its ``"miss2ckps"`` entry maps a
    seed to a list of architecture indices. For every (seed, index) pair whose
    checkpoint file ``arch-{index:06d}-seed-{seed:04d}.pth`` exists under
    ``source_dir``, that file is copied to the same name under ``target_dir``.
    ``target_dir`` (and any parents) is created before the meta file is read.
    Pairs with no file in ``source_dir`` are skipped.
    """
    target_root = Path(target_dir)
    target_root.mkdir(parents=True, exist_ok=True)
    missing = torch.load(meta_path)["miss2ckps"]

    # Candidate file names, in the same seed-then-index order as the meta file.
    candidates = (
        "arch-{:06d}-seed-{:04d}.pth".format(index, seed)
        for seed, indices in missing.items()
        for index in indices
    )
    pending = {}
    for name in candidates:
        src = os.path.join(source_dir, name)
        if not os.path.exists(src):
            continue
        pending[src] = os.path.join(target_root, name)

    message = "Map from {:} to {:}, find {:} missed ckps.".format(
        source_dir, target_root, len(pending)
    )
    print(message)
    for src, dst in pending.items():
        copyfile(src, dst)


if __name__ == "__main__":
    # Layout under --save_dir, for each config suffix <cfg>:
    #   raw-data-<cfg>/  checkpoints named arch-XXXXXX-seed-YYYY.pth
    #   meta-<cfg>.pth   written by `--mode check` (seed2ckps / miss2ckps)
    #   copy-<cfg>/      filled by `--mode copy` with the ckps meta lists as missing
    parser = argparse.ArgumentParser(
        description="NATS-Bench (size search space) file manager.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--mode",
        type=str,
        required=True,
        choices=["check", "copy"],
        help="The script mode.",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="output/NATS-Bench-size",
        help="Folder to save checkpoints and log.",
    )
    # Number of architecture indices (0 .. check_N - 1) scanned in check mode.
    parser.add_argument("--check_N", type=int, default=32768, help="For safety.")
    parser.add_argument(
        "--configs",
        type=str,
        nargs="+",
        default=["01", "12", "90"],
        help="Config suffixes to process (raw-data-<cfg>, meta-<cfg>.pth, copy-<cfg>).",
    )
    args = parser.parse_args()
    if args.mode == "check":
        for config in args.configs:
            cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config)
            seed2ckps, miss2ckps = obtain_valid_ckp(cur_save_dir, args.check_N)
            torch.save(
                dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps),
                "{:}/meta-{:}.pth".format(args.save_dir, config),
            )
    elif args.mode == "copy":
        for config in args.configs:
            cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config)
            cur_copy_dir = "{:}/copy-{:}".format(args.save_dir, config)
            cur_meta_path = "{:}/meta-{:}.pth".format(args.save_dir, config)
            if os.path.exists(cur_meta_path):
                copy_data(cur_save_dir, cur_copy_dir, cur_meta_path)
            else:
                print("Do not find : {:}".format(cur_meta_path))
    else:
        # Unreachable while `choices` restricts --mode; kept as a guard in case
        # a new mode is added to `choices` without a handler here.
        raise ValueError("invalid mode : {:}".format(args.mode))