Update NATS-Bench API to v1.1
This commit is contained in:
		| @@ -243,7 +243,10 @@ class NATSsize(NASBenchMetaAPI): | |||||||
|       except Exception as unused_e:  # pylint: disable=broad-except |       except Exception as unused_e:  # pylint: disable=broad-except | ||||||
|         test_info = None |         test_info = None | ||||||
|       valtest_info = None |       valtest_info = None | ||||||
|  |       xinfo['comment'] = 'In this dict, train-loss/accuracy/time is the metric on the train set of CIFAR-10. The test-loss/accuracy/time is the performance of the CIFAR-10 test set after training on the train set by {:} epochs. The per-time and total-time indicate the per epoch and total time costs, respectively.'.format(hp) | ||||||
|     else: |     else: | ||||||
|  |       if dataset == 'cifar10': | ||||||
|  |         xinfo['comment'] = 'In this dict, train-loss/accuracy/time is the metric on the train+valid sets of CIFAR-10. The test-loss/accuracy/time is the performance of the CIFAR-10 test set after training on the train+valid sets by {:} epochs. The per-time and total-time indicate the per epoch and total time costs, respectively.'.format(hp) | ||||||
|       try:  # collect results on the proposed test set |       try:  # collect results on the proposed test set | ||||||
|         if dataset == 'cifar10': |         if dataset == 'cifar10': | ||||||
|           test_info = archresult.get_metrics( |           test_info = archresult.get_metrics( | ||||||
|   | |||||||
| @@ -32,27 +32,47 @@ class TestNATSBench(object): | |||||||
|       benchmark_dir = os.path.join(get_fake_torch_home_dir(), tss_base_names[-1] + '-simple') |       benchmark_dir = os.path.join(get_fake_torch_home_dir(), tss_base_names[-1] + '-simple') | ||||||
|     return _test_nats_bench(benchmark_dir, False, fake_random) |     return _test_nats_bench(benchmark_dir, False, fake_random) | ||||||
|  |  | ||||||
|  |   def prepare_fake_tss(self): | ||||||
|  |     print('') | ||||||
|  |     tss_benchmark_dir = os.path.join(get_fake_torch_home_dir(), tss_base_names[-1] + '-simple') | ||||||
|  |     api = NATStopology(tss_benchmark_dir, True, False) | ||||||
|  |     return api | ||||||
|  |  | ||||||
|   def test_01_th_issue(self): |   def test_01_th_issue(self): | ||||||
|     # Link: https://github.com/D-X-Y/NATS-Bench/issues/1 |     # Link: https://github.com/D-X-Y/NATS-Bench/issues/1 | ||||||
|     print('') |     api = self.prepare_fake_tss() | ||||||
|     tss_benchmark_dir = os.path.join(get_fake_torch_home_dir(), sss_base_names[-1] + '-simple') |  | ||||||
|     api = NATStopology(tss_benchmark_dir, True, False) |  | ||||||
|     # The performance of the 0-th architecture on CIFAR-10 (trained for 12 epochs) |     # The performance of the 0-th architecture on CIFAR-10 (trained for 12 epochs) | ||||||
|     info = api.get_more_info(0, 'cifar10', hp=12) |     info = api.get_more_info(0, 'cifar10', hp=12) | ||||||
|     print('The loss on the training set of CIFAR-10: {:}'.format(info['train-loss'])) |     # First of all, the data split in NATS-Bench is different from that in the official CIFAR paper. | ||||||
|     print('The total training time for 12 epochs on CIFAR-10: {:}'.format(info['train-all-time'])) |     # In NATS-Bench, we split the original CIFAR-10 training set into two parts, i.e., a training set and a validation set. | ||||||
|  |     # In the following, we will use the splits of NATS-Bench to explain. | ||||||
|  |     print(info['comment']) | ||||||
|  |     print('The loss on the training + validation sets of CIFAR-10: {:}'.format(info['train-loss'])) | ||||||
|  |     print('The total training time for 12 epochs on the training + validation sets of CIFAR-10: {:}'.format(info['train-all-time'])) | ||||||
|     print('The per-epoch training time on CIFAR-10: {:}'.format(info['train-per-time'])) |     print('The per-epoch training time on CIFAR-10: {:}'.format(info['train-per-time'])) | ||||||
|     print('The total evaluation time on the test set of CIFAR-10 for 12 times: {:}'.format(info['test-all-time'])) |     print('The total evaluation time on the test set of CIFAR-10 for 12 times: {:}'.format(info['test-all-time'])) | ||||||
|     print('The evaluation time on the test set of CIFAR-10: {:}'.format(info['test-per-time'])) |     print('The evaluation time on the test set of CIFAR-10: {:}'.format(info['test-per-time'])) | ||||||
|     # Please note that the train/validation/test splits of CIFAR-10 in our NATS-Bench paper are different from those in the original CIFAR paper. |  | ||||||
|     cost_info = api.get_cost_info(0, 'cifar10') |     cost_info = api.get_cost_info(0, 'cifar10') | ||||||
|     xkeys = ['T-train@epoch',     # The per epoch training cost for CIFAR-10. Note that the training set of CIFAR-10 in NATS-Bench is a subset of the original training set in CIFAR paper. |     xkeys = ['T-train@epoch',     # The per epoch training time on the training + validation sets of CIFAR-10. | ||||||
|              'T-train@total', |              'T-train@total', | ||||||
|              'T-ori-test@epoch',  # The time cost for the evaluation on the original test split of CIFAR-10, which is the validation + test sets of CIFAR-10 on NATS-Bench. |              'T-ori-test@epoch',  # The time cost for the evaluation on CIFAR-10 test set. | ||||||
|              'T-ori-test@total']  # T-ori-test@epoch * 12 times. |              'T-ori-test@total']  # T-ori-test@epoch * 12 times. | ||||||
|     for xkey in xkeys: |     for xkey in xkeys: | ||||||
|       print('The cost info [{:}] for 0-th architecture on CIFAR-10 is {:}'.format(xkey, cost_info[xkey])) |       print('The cost info [{:}] for 0-th architecture on CIFAR-10 is {:}'.format(xkey, cost_info[xkey])) | ||||||
|      |      | ||||||
|  |   def test_02_th_issue(self): | ||||||
|  |     # https://github.com/D-X-Y/NATS-Bench/issues/2 | ||||||
|  |     api = self.prepare_fake_tss() | ||||||
|  |     data = api.query_by_index(284, dataname='cifar10', hp=200) | ||||||
|  |     for xkey, xvalue in data.items(): | ||||||
|  |       print('{:} : {:}'.format(xkey, xvalue)) | ||||||
|  |     xinfo = data[777].get_train() | ||||||
|  |     print(xinfo) | ||||||
|  |     print(data[777].train_acc1es) | ||||||
|  |  | ||||||
|  |     info_012_epochs = api.get_more_info(284, 'cifar10', hp=200) | ||||||
|  |     print(info_012_epochs['train-accuracy']) | ||||||
|  |   | ||||||
|  |  | ||||||
| def _test_nats_bench(benchmark_dir, is_tss, fake_random, verbose=False): | def _test_nats_bench(benchmark_dir, is_tss, fake_random, verbose=False): | ||||||
|   """The main test entry for NATS-Bench.""" |   """The main test entry for NATS-Bench.""" | ||||||
| @@ -62,7 +82,7 @@ def _test_nats_bench(benchmark_dir, is_tss, fake_random, verbose=False): | |||||||
|     api = NATSsize(benchmark_dir, True, verbose) |     api = NATSsize(benchmark_dir, True, verbose) | ||||||
|  |  | ||||||
|   if fake_random: |   if fake_random: | ||||||
|     test_indexes = [0, 11, 241] |     test_indexes = [0, 11, 284] | ||||||
|   else: |   else: | ||||||
|     test_indexes = [random.randint(0, len(api) - 1) for _ in range(10)] |     test_indexes = [random.randint(0, len(api) - 1) for _ in range(10)] | ||||||
|  |  | ||||||
|   | |||||||
| @@ -222,7 +222,10 @@ class NATStopology(NASBenchMetaAPI): | |||||||
|       except Exception as unused_e:  # pylint: disable=broad-except |       except Exception as unused_e:  # pylint: disable=broad-except | ||||||
|         test_info = None |         test_info = None | ||||||
|       valtest_info = None |       valtest_info = None | ||||||
|  |       xinfo['comment'] = 'In this dict, train-loss/accuracy/time is the metric on the train set of CIFAR-10. The test-loss/accuracy/time is the performance of the CIFAR-10 test set after training on the train set by {:} epochs. The per-time and total-time indicate the per epoch and total time costs, respectively.'.format(hp) | ||||||
|     else: |     else: | ||||||
|  |       if dataset == 'cifar10': | ||||||
|  |         xinfo['comment'] = 'In this dict, train-loss/accuracy/time is the metric on the train+valid sets of CIFAR-10. The test-loss/accuracy/time is the performance of the CIFAR-10 test set after training on the train+valid sets by {:} epochs. The per-time and total-time indicate the per epoch and total time costs, respectively.'.format(hp) | ||||||
|       try:  # collect results on the proposed test set |       try:  # collect results on the proposed test set | ||||||
|         if dataset == 'cifar10': |         if dataset == 'cifar10': | ||||||
|           test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) |           test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) | ||||||
|   | |||||||
| @@ -426,13 +426,13 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): | |||||||
|           arch_index, hp)) |           arch_index, hp)) | ||||||
|     self._prepare_info(arch_index) |     self._prepare_info(arch_index) | ||||||
|     if arch_index in self.arch2infos_dict: |     if arch_index in self.arch2infos_dict: | ||||||
|       if hp not in self.arch2infos_dict[arch_index]: |       if str(hp) not in self.arch2infos_dict[arch_index]: | ||||||
|         raise ValueError('The {:}-th architecture only has hyper-parameters of ' |         raise ValueError('The {:}-th architecture only has hyper-parameters of ' | ||||||
|                          '{:} instead of {:}.'.format( |                          '{:} instead of {:}.'.format( | ||||||
|                              arch_index, |                              arch_index, | ||||||
|                              list(self.arch2infos_dict[arch_index].keys()), |                              list(self.arch2infos_dict[arch_index].keys()), | ||||||
|                              hp)) |                              hp)) | ||||||
|       info = self.arch2infos_dict[arch_index][hp] |       info = self.arch2infos_dict[arch_index][str(hp)] | ||||||
|     else: |     else: | ||||||
|       raise ValueError('arch_index [{:}] does not in arch2infos'.format( |       raise ValueError('arch_index [{:}] does not in arch2infos'.format( | ||||||
|           arch_index)) |           arch_index)) | ||||||
| @@ -472,7 +472,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): | |||||||
|     if self.verbose: |     if self.verbose: | ||||||
|       print('{:} Call query_by_index with arch_index={:}, dataname={:}, ' |       print('{:} Call query_by_index with arch_index={:}, dataname={:}, ' | ||||||
|             'hp={:}'.format(time_string(), arch_index, dataname, hp)) |             'hp={:}'.format(time_string(), arch_index, dataname, hp)) | ||||||
|     info = self.query_meta_info_by_index(arch_index, hp) |     info = self.query_meta_info_by_index(arch_index, str(hp)) | ||||||
|     if dataname is None: |     if dataname is None: | ||||||
|       return info |       return info | ||||||
|     else: |     else: | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user