Update NATS-Bench (sss version 1.0)
This commit is contained in:
		| @@ -10,10 +10,15 @@ | ||||
| # History: | ||||
| # [2020.07.31] The first version, where most content reused nas_201_api/api_utils.py | ||||
| # | ||||
| import os, abc, copy, random, torch, numpy as np | ||||
| from pathlib import Path | ||||
| import abc, copy, random, numpy as np | ||||
| import importlib, warnings | ||||
| from typing import List, Text, Union, Dict, Optional | ||||
| from collections import OrderedDict, defaultdict | ||||
| USE_TORCH = importlib.find_loader('torch') is not None | ||||
| if USE_TORCH: | ||||
|   import torch | ||||
| else: | ||||
|   warnings.warn('Can not find PyTorch, and thus some features maybe invalid.') | ||||
|  | ||||
|  | ||||
| def remap_dataset_set_names(dataset, metric_on_set, verbose=False): | ||||
| @@ -545,6 +550,8 @@ class ArchResults(object): | ||||
|   def create_from_state_dict(state_dict_or_file): | ||||
|     x = ArchResults(-1, -1) | ||||
|     if isinstance(state_dict_or_file, str): # a file path | ||||
|       if not USE_TORCH: | ||||
|         raise ValueError('Since torch is not imported, this logic can not be used.') | ||||
|       state_dict = torch.load(state_dict_or_file, map_location='cpu') | ||||
|     elif isinstance(state_dict_or_file, dict): | ||||
|       state_dict = state_dict_or_file | ||||
|   | ||||
| @@ -3,3 +3,4 @@ from .gpu_manager      import GPUManager | ||||
| from .flop_benchmark   import get_model_infos, count_parameters_in_MB | ||||
| from .affine_utils     import normalize_points, denormalize_points | ||||
| from .affine_utils     import identity2affine, solve2theta, affine2image | ||||
| from .hash_utils       import get_md5_file | ||||
|   | ||||
							
								
								
									
										16
									
								
								lib/utils/hash_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								lib/utils/hash_utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,16 @@ | ||||
| import os, hashlib | ||||
|  | ||||
|  | ||||
| def get_md5_file(file_path, post_truncated=5): | ||||
|   md5_hash = hashlib.md5() | ||||
|   if os.path.exists(file_path): | ||||
|     xfile = open(file_path, "rb") | ||||
|     content = xfile.read() | ||||
|     md5_hash.update(content) | ||||
|     digest = md5_hash.hexdigest() | ||||
|   else: | ||||
|     raise ValueError('[get_md5_file] {:} does not exist'.format(file_path)) | ||||
|   if post_truncated is None: | ||||
|     return digest | ||||
|   else: | ||||
|     return digest[-post_truncated:] | ||||
		Reference in New Issue
	
	Block a user