xautodl/exps/NAS-Bench-201/check.py

138 lines
5.3 KiB
Python

#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
#####################################################
# python exps/NAS-Bench-201/check.py --base_str C16-N5-LESS
#####################################################
import sys, time, argparse, collections
import torch
from pathlib import Path
from collections import defaultdict
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
def check_files(save_dir, meta_file, basestr, seeds=None):
    """Summarise which NAS-Bench-201 architecture checkpoints exist on disk.

    Sub-directories of ``save_dir`` matching ``*-*-{basestr}`` are scanned for
    ``arch-<index>-seed-<seed>.pth`` files. The function prints how many
    architectures were evaluated and how many checkpoints each has, then, per
    sub-directory, how many of the expected ``(arch, seed)`` checkpoints exist.

    Args:
      save_dir: ``pathlib.Path`` of the base output directory.
      meta_file: path of the meta file, loaded with ``torch.load``; it must
        contain an ``"archs"`` collection whose length equals ``"total"``.
      basestr: setting suffix of the sub-directories, e.g. ``"C16-N5"`` or
        ``"C16-N5-LESS"``. It also selects the default seeds.
      seeds: optional iterable of integer seeds every architecture is expected
        to have. When omitted, ``[777, 888, 999]`` is used for ``"C16-N5"``
        and ``[111, 777]`` for ``"C16-N5-LESS"``.

    Returns:
      A ``(dir2ckps, dir2ckp_exists)`` tuple keyed by ``str(sub_dir)``: the
      expected checkpoint file names and, in the same order, whether each one
      exists on disk.

    Raises:
      ValueError: if ``seeds`` is omitted and ``basestr`` has no default seeds.
      AssertionError: on an inconsistent meta file, a malformed checkpoint
        name, or an architecture index outside ``meta_archs``.
    """
    default_seeds = {"C16-N5": [777, 888, 999], "C16-N5-LESS": [111, 777]}
    if seeds is None:
        # The seed list does not depend on any sub-directory, so resolve it
        # once, up front: an unsupported basestr now fails before the slow
        # file-system scan instead of inside the per-directory loop.
        if basestr not in default_seeds:
            raise ValueError("Invalid base str : {:}".format(basestr))
        seeds = default_seeds[basestr]
    seeds = list(seeds)  # materialise so a generator can be reused per arch

    # NOTE: torch.load unpickles the file; only point it at trusted meta files.
    meta_infos = torch.load(meta_file, map_location="cpu")
    meta_archs = meta_infos["archs"]
    meta_num_archs = meta_infos["total"]
    assert meta_num_archs == len(
        meta_archs
    ), "invalid number of archs : {:} vs {:}".format(meta_num_archs, len(meta_archs))

    sub_model_dirs = sorted(save_dir.glob("*-*-{:}".format(basestr)))
    print(
        "{:} find {:} directories used to save checkpoints".format(
            time_string(), len(sub_model_dirs)
        )
    )

    # Phase 1: find which architectures have checkpoints in each sub-directory.
    subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0
    num_seeds = defaultdict(int)  # #checkpoints per arch -> #archs with that many
    for sub_dir in sub_model_dirs:
        # Count checkpoints per architecture index in a single pass, instead
        # of re-globbing the directory once for every architecture.
        ckps_per_arch = collections.Counter()
        for checkpoint in sub_dir.glob("arch-*-seed-*.pth"):
            temp_names = checkpoint.name.split("-")
            assert (
                len(temp_names) == 4
                and temp_names[0] == "arch"
                and temp_names[2] == "seed"
            ), "invalid checkpoint name : {:}".format(checkpoint.name)
            ckps_per_arch[temp_names[1]] += 1
        subdir2archs[sub_dir] = sorted(ckps_per_arch)
        num_evaluated_arch += len(ckps_per_arch)
        for count in ckps_per_arch.values():
            num_seeds[count] += 1

    print(
        "There are {:5d} architectures that have been evaluated ({:} in total, {:} ckps in total).".format(
            num_evaluated_arch, meta_num_archs, sum(k * v for k, v in num_seeds.items())
        )
    )
    for key in sorted(num_seeds):
        print(
            "There are {:5d} architectures that are evaluated {:} times.".format(
                num_seeds[key], key
            )
        )

    # Phase 2: for every found architecture, check which expected seeds exist.
    dir2ckps, dir2ckp_exists = dict(), dict()
    start_time, epoch_time = time.time(), AverageMeter()
    for IDX, (sub_dir, arch_indexes) in enumerate(subdir2archs.items()):
        numrs = defaultdict(int)  # #existing seeds -> #archs in this directory
        all_checkpoints, all_ckp_exists = [], []
        for arch_str in arch_indexes:
            # File names reuse the index string exactly as found on disk.
            checkpoints = [
                "arch-{:}-seed-{:04d}.pth".format(arch_str, seed) for seed in seeds
            ]
            ckp_exists = [(sub_dir / x).exists() for x in checkpoints]
            arch_index = int(arch_str)
            assert (
                0 <= arch_index < len(meta_archs)
            ), "invalid arch-index {:} (not found in meta_archs)".format(arch_index)
            all_checkpoints += checkpoints
            all_ckp_exists += ckp_exists
            numrs[sum(ckp_exists)] += 1
        dir2ckps[str(sub_dir)] = all_checkpoints
        dir2ckp_exists[str(sub_dir)] = all_ckp_exists
        # measure time (used to estimate the remaining time below)
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        numrstr = ", ".join(
            ["{:}: {:03d}".format(x, numrs[x]) for x in sorted(numrs)]
        )
        print(
            "{:} load [{:2d}/{:2d}] [{:03d} archs] [{:04d}->{:04d} ckps] {:} done, need {:}. {:}".format(
                time_string(),
                IDX + 1,
                len(subdir2archs),
                len(arch_indexes),
                len(all_checkpoints),
                sum(all_ckp_exists),
                sub_dir,
                convert_secs2time(epoch_time.avg * (len(subdir2archs) - IDX - 1), True),
                numrstr,
            )
        )
    return dir2ckps, dir2ckp_exists
if __name__ == "__main__":
    # Command-line entry point: parse the output/meta locations and the
    # setting suffix, validate the paths, then run the checkpoint report.
    parser = argparse.ArgumentParser(
        description="NAS Benchmark 201",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--base_save_dir",
        type=str,
        default="./output/NAS-BENCH-201-4",
        help="The base-name of folder to save checkpoints and log.",
    )
    parser.add_argument(
        "--meta_path",
        type=str,
        default="./output/NAS-BENCH-201-4/meta-node-4.pth",
        help="The meta file path.",
    )
    parser.add_argument(
        "--base_str", type=str, default="C16-N5", help="The basic string."
    )
    args = parser.parse_args()
    save_dir = Path(args.base_save_dir)
    meta_path = Path(args.meta_path)
    # Validate user-supplied paths with parser.error rather than assert:
    # asserts are stripped under `python -O`, whereas parser.error always
    # prints the usage message and exits with status 2.
    if not save_dir.exists():
        parser.error("invalid save dir path : {:}".format(save_dir))
    if not meta_path.exists():
        parser.error("invalid saved meta path : {:}".format(meta_path))
    print("check NAS-Bench-201 in {:}".format(save_dir))
    check_files(save_dir, meta_path, args.base_str)