Update NAS-Bench-201
This commit is contained in:
		| @@ -7,7 +7,7 @@ | |||||||
| ############################################################### | ############################################################### | ||||||
| import os, sys, time, torch, argparse | import os, sys, time, torch, argparse | ||||||
| from typing import List, Text, Dict, Any | from typing import List, Text, Dict, Any | ||||||
| from tqdm import tqdm | from shutil import copyfile | ||||||
| from collections import defaultdict | from collections import defaultdict | ||||||
| from copy    import deepcopy | from copy    import deepcopy | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| @@ -32,17 +32,28 @@ def obtain_valid_ckp(save_dir: Text, total: int): | |||||||
|         seed2ckps[seed].append(i) |         seed2ckps[seed].append(i) | ||||||
|       else: |       else: | ||||||
|         miss2ckps[seed].append(i) |         miss2ckps[seed].append(i) | ||||||
|     """ |  | ||||||
|     ckps = [x for x in save_dir.glob('arch-{:06d}-seed-*.pth'.format(i))] |  | ||||||
|     for ckp in ckps: |  | ||||||
|       seed = ckp.name.split('-seed-')[-1].split('.pth')[0] |  | ||||||
|       seed2ckps[int(seed)].append(i) |  | ||||||
|     """ |  | ||||||
|   for seed, xlist in seed2ckps.items(): |   for seed, xlist in seed2ckps.items(): | ||||||
|     print('[{:}] [seed={:}] has {:}/{:}'.format(save_dir, seed, len(xlist), total)) |     print('[{:}] [seed={:}] has {:}/{:}'.format(save_dir, seed, len(xlist), total)) | ||||||
|   return dict(seed2ckps), dict(miss2ckps) |   return dict(seed2ckps), dict(miss2ckps) | ||||||
|      |      | ||||||
|  |  | ||||||
|  | def copy_data(source_dir, target_dir, meta_path): | ||||||
|  |   target_dir = Path(target_dir) | ||||||
|  |   target_dir.mkdir(parents=True, exist_ok=True) | ||||||
|  |   miss2ckps = torch.load(meta_path)['miss2ckps'] | ||||||
|  |   s2t = {} | ||||||
|  |   for seed, xlist in miss2ckps.items(): | ||||||
|  |     for i in xlist: | ||||||
|  |       file_name = 'arch-{:06d}-seed-{:04d}.pth'.format(i, seed) | ||||||
|  |       source_path = os.path.join(source_dir, file_name) | ||||||
|  |       target_path = os.path.join(target_dir, file_name) | ||||||
|  |       if os.path.exists(source_path): | ||||||
|  |         s2t[source_path] = target_path | ||||||
|  |   print('Map from {:} to {:}, find {:} missed ckps.'.format(source_dir, target_dir, len(s2t))) | ||||||
|  |   for s, t in s2t.items(): | ||||||
|  |     copyfile(s, t) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|   parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter) |   parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||||
|   parser.add_argument('--mode',        type=str, required=True, choices=['check', 'copy'], help='The script mode.') |   parser.add_argument('--mode',        type=str, required=True, choices=['check', 'copy'], help='The script mode.') | ||||||
| @@ -56,4 +67,14 @@ if __name__ == '__main__': | |||||||
|       cur_save_dir = '{:}/raw-data-{:}'.format(args.save_dir, config) |       cur_save_dir = '{:}/raw-data-{:}'.format(args.save_dir, config) | ||||||
|       seed2ckps, miss2ckps = obtain_valid_ckp(cur_save_dir, args.check_N) |       seed2ckps, miss2ckps = obtain_valid_ckp(cur_save_dir, args.check_N) | ||||||
|       torch.save(dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), '{:}/meta-{:}.pth'.format(args.save_dir, config)) |       torch.save(dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), '{:}/meta-{:}.pth'.format(args.save_dir, config)) | ||||||
|    |   elif args.mode == 'copy': | ||||||
|  |     for config in possible_configs: | ||||||
|  |       cur_save_dir = '{:}/raw-data-{:}'.format(args.save_dir, config) | ||||||
|  |       cur_copy_dir = '{:}/copy-{:}'.format(args.save_dir, config) | ||||||
|  |       cur_meta_path = '{:}/meta-{:}.pth'.format(args.save_dir, config) | ||||||
|  |       if os.path.exists(cur_meta_path): | ||||||
|  |         copy_data(cur_save_dir, cur_copy_dir, cur_meta_path) | ||||||
|  |       else: | ||||||
|  |         print('Do not find : {:}'.format(cur_meta_path)) | ||||||
|  |   else: | ||||||
|  |     raise ValueError('invalid mode : {:}'.format(args.mode)) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user