Update NATS-Bench (sss version 1.0)
This commit is contained in:
		| @@ -10,10 +10,15 @@ | ||||
| # History: | ||||
| # [2020.07.31] The first version, where most content reused nas_201_api/api_utils.py | ||||
| # | ||||
| import os, abc, copy, random, torch, numpy as np | ||||
| from pathlib import Path | ||||
| import abc, copy, random, numpy as np | ||||
| import importlib, warnings | ||||
| from typing import List, Text, Union, Dict, Optional | ||||
| from collections import OrderedDict, defaultdict | ||||
| USE_TORCH = importlib.find_loader('torch') is not None | ||||
| if USE_TORCH: | ||||
|   import torch | ||||
| else: | ||||
|   warnings.warn('Can not find PyTorch, and thus some features maybe invalid.') | ||||
|  | ||||
|  | ||||
| def remap_dataset_set_names(dataset, metric_on_set, verbose=False): | ||||
| @@ -545,6 +550,8 @@ class ArchResults(object): | ||||
|   def create_from_state_dict(state_dict_or_file): | ||||
|     x = ArchResults(-1, -1) | ||||
|     if isinstance(state_dict_or_file, str): # a file path | ||||
|       if not USE_TORCH: | ||||
|         raise ValueError('Since torch is not imported, this logic can not be used.') | ||||
|       state_dict = torch.load(state_dict_or_file, map_location='cpu') | ||||
|     elif isinstance(state_dict_or_file, dict): | ||||
|       state_dict = state_dict_or_file | ||||
|   | ||||
		Reference in New Issue
	
	Block a user