# xautodl/exps/NATS-Bench/tss-file-manager.py
#
# 106 lines
# 4.2 KiB
# Python

##############################################################################
# NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size #
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
##############################################################################
# Usage: python exps/NATS-Bench/tss-file-manager.py --mode check #
##############################################################################
import os, sys, time, torch, argparse
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
from xautodl.config_utils import dict2config, load_config
from xautodl.procedures import bench_evaluate_for_seed
from xautodl.procedures import get_machine_info
from xautodl.datasets import get_datasets
from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time
def obtain_valid_ckp(save_dir: Text, total: int, possible_seeds: List[int]):
    """Scan ``save_dir`` for per-architecture checkpoints and report coverage.

    The checkpoint for architecture index ``i`` and seed ``s`` is looked up at
    ``<save_dir>/arch-{i:06d}-seed-{s:04d}.pth``.

    Args:
        save_dir: directory that is searched for checkpoint files.
        total: number of architecture indexes to check, i.e. ``range(total)``.
        possible_seeds: seeds to look for at every architecture index.

    Returns:
        A ``(seed2ckps, miss2ckps)`` pair of plain dicts. ``seed2ckps`` maps a
        seed to the ascending list of indexes whose file exists, and
        ``miss2ckps`` maps a seed to the indexes whose file is absent. A seed
        is a key of either dict only when its list is non-empty.
    """
    seed2ckps = defaultdict(list)
    miss2ckps = defaultdict(list)
    for index in range(total):
        for seed in possible_seeds:
            path = os.path.join(
                save_dir, "arch-{:06d}-seed-{:04d}.pth".format(index, seed)
            )
            bucket = seed2ckps if os.path.exists(path) else miss2ckps
            bucket[seed].append(index)
    # Report on every requested seed, not only on those with at least one
    # checkpoint; otherwise a seed that is missing entirely would be silently
    # left out of the summary. ``.get`` avoids inserting empty entries into
    # the defaultdict, so the returned dicts are unaffected by the reporting.
    for seed in dict.fromkeys(possible_seeds):
        num_found = len(seed2ckps.get(seed, ()))
        print(
            "[{:}] [seed={:}] has {:5d}/{:5d} | miss {:5d}/{:5d}".format(
                save_dir, seed, num_found, total, total - num_found, total
            )
        )
    return dict(seed2ckps), dict(miss2ckps)
def copy_data(source_dir, target_dir, meta_path):
    """Copy checkpoints listed as missing in a meta file from one folder to another.

    The meta file at ``meta_path`` is read with ``torch.load`` and its
    ``"miss2ckps"`` entry (seed -> list of architecture indexes) determines
    which ``arch-{index:06d}-seed-{seed:04d}.pth`` files are wanted. Only the
    files that actually exist in ``source_dir`` are copied.

    Args:
        source_dir: folder that is searched for the wanted checkpoint files.
        target_dir: folder that receives the copies; created (including
            parents) when it does not exist yet.
        meta_path: path of the meta file holding the ``"miss2ckps"`` mapping.
    """
    destination = Path(target_dir)
    destination.mkdir(parents=True, exist_ok=True)
    missing_by_seed = torch.load(meta_path)["miss2ckps"]

    wanted_names = (
        "arch-{:06d}-seed-{:04d}.pth".format(index, seed)
        for seed, indexes in missing_by_seed.items()
        for index in indexes
    )
    # Collect the full copy plan first so the summary line is printed before
    # any file is copied.
    copy_plan = {}
    for name in wanted_names:
        candidate = os.path.join(source_dir, name)
        if os.path.exists(candidate):
            copy_plan[candidate] = os.path.join(destination, name)

    summary = "Map from {:} to {:}, find {:} missed ckps.".format(
        source_dir, destination, len(copy_plan)
    )
    print(summary)
    for origin, copy_to in copy_plan.items():
        copyfile(origin, copy_to)
if __name__ == "__main__":
    # Command-line entry point: "check" scans the raw-data folders and stores
    # which checkpoints exist / are missing; "copy" uses that stored meta
    # information to copy the previously-missing checkpoints into copy-* folders.
    parser = argparse.ArgumentParser(
        description="NATS-Bench (topology search space) file manager.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--mode",
        type=str,
        required=True,
        choices=["check", "copy"],
        help="The script mode.",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="output/NATS-Bench-topology",
        help="Folder to save checkpoints and log.",
    )
    parser.add_argument("--check_N", type=int, default=15625, help="For safety.")
    args = parser.parse_args()

    # argparse already restricts --mode via `choices`; this guard mirrors the
    # defensive error for any other value.
    if args.mode not in ("check", "copy"):
        raise ValueError("invalid mode : {:}".format(args.mode))

    # Each configuration name is paired with the seeds checked for it.
    config_names = ["12", "200"]
    seeds_per_config = [[111, 777], [777, 888, 999]]

    for config_name, seeds in zip(config_names, seeds_per_config):
        raw_dir = "{:}/raw-data-{:}".format(args.save_dir, config_name)
        meta_file = "{:}/meta-{:}.pth".format(args.save_dir, config_name)
        if args.mode == "check":
            found, missing = obtain_valid_ckp(raw_dir, args.check_N, seeds)
            torch.save(dict(seed2ckps=found, miss2ckps=missing), meta_file)
            continue
        # args.mode == "copy"
        if not os.path.exists(meta_file):
            print("Do not find : {:}".format(meta_file))
            continue
        copy_dir = "{:}/copy-{:}".format(args.save_dir, config_name)
        copy_data(raw_dir, copy_dir, meta_file)