Add get_torch_home func for NATS-Bench
This commit is contained in:
		| @@ -17,6 +17,7 @@ from typing import Dict, Optional, Text, Union, Any | ||||
|  | ||||
| from nats_bench.api_utils import ArchResults | ||||
| from nats_bench.api_utils import NASBenchMetaAPI | ||||
| from nats_bench.api_utils import get_torch_home | ||||
| from nats_bench.api_utils import nats_is_dir | ||||
| from nats_bench.api_utils import nats_is_file | ||||
| from nats_bench.api_utils import PICKLE_EXT | ||||
| @@ -88,10 +89,10 @@ class NATSsize(NASBenchMetaAPI): | ||||
|     if file_path_or_dict is None: | ||||
|       if self._fast_mode: | ||||
|         self._archive_dir = os.path.join( | ||||
|             os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1])) | ||||
|             get_torch_home(), '{:}-simple'.format(ALL_BASE_NAMES[-1])) | ||||
|       else: | ||||
|         file_path_or_dict = os.path.join( | ||||
|             os.environ['TORCH_HOME'], '{:}.{:}'.format( | ||||
|             get_torch_home(), '{:}.{:}'.format( | ||||
|                 ALL_BASE_NAMES[-1], PICKLE_EXT)) | ||||
|       print('{:} Try to use the default NATS-Bench (size) path from ' | ||||
|             'fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode, | ||||
|   | ||||
| @@ -17,6 +17,7 @@ from typing import Any, Dict, List, Optional, Text, Union | ||||
|  | ||||
| from nats_bench.api_utils import ArchResults | ||||
| from nats_bench.api_utils import NASBenchMetaAPI | ||||
| from nats_bench.api_utils import get_torch_home | ||||
| from nats_bench.api_utils import nats_is_dir | ||||
| from nats_bench.api_utils import nats_is_file | ||||
| from nats_bench.api_utils import PICKLE_EXT | ||||
| @@ -88,10 +89,10 @@ class NATStopology(NASBenchMetaAPI): | ||||
|     if file_path_or_dict is None: | ||||
|       if self._fast_mode: | ||||
|         self._archive_dir = os.path.join( | ||||
|             os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1])) | ||||
|             get_torch_home(), '{:}-simple'.format(ALL_BASE_NAMES[-1])) | ||||
|       else: | ||||
|         file_path_or_dict = os.path.join( | ||||
|             os.environ['TORCH_HOME'], '{:}.{:}'.format( | ||||
|             get_torch_home(), '{:}.{:}'.format( | ||||
|                 ALL_BASE_NAMES[-1], PICKLE_EXT)) | ||||
|       print('{:} Try to use the default NATS-Bench (topology) path from ' | ||||
|             'fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode, file_path_or_dict)) | ||||
|   | ||||
| @@ -45,6 +45,17 @@ def get_file_system(): | ||||
|   return _FILE_SYSTEM | ||||
|  | ||||
|  | ||||
| def get_torch_home(): | ||||
|   if 'TORCH_HOME' in os.environ: | ||||
|     return os.environ['TORCH_HOME'] | ||||
|   elif 'HOME' in os.environ: | ||||
|     return os.path.join(os.environ['HOME'], '.torch') | ||||
|   else: | ||||
|     raise ValueError('Did not find HOME in os.environ. ' | ||||
|       'Please at least setup the path of HOME or TORCH_HOME ' | ||||
|       'in the environment.') | ||||
|  | ||||
|  | ||||
| def nats_is_dir(file_path): | ||||
|   if _FILE_SYSTEM == 'default': | ||||
|     return os.path.isdir(file_path) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user