update NAS-Bench-201 API to support str2structure
This commit is contained in:
		| @@ -251,6 +251,26 @@ class NASBench201API(object): | |||||||
|       else: |       else: | ||||||
|         print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs))) |         print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs))) | ||||||
|  |  | ||||||
|  |   # This func shows how to read the string0based architecture encoding | ||||||
|  |   #   the same as the `str2structure` func in `AutoDL-Projects/lib/models/cell_searchs/genotypes.py` | ||||||
|  |   # Usage: | ||||||
|  |   #   arch = api.str2structure( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' ) | ||||||
|  |   #   print ('there are {:} nodes in this arch'.format(len(arch)+1)) # arch is a list | ||||||
|  |   #   for i, node in enumerate(arch): | ||||||
|  |   #     print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node)) | ||||||
|  |   @staticmethod | ||||||
|  |   def str2structure(xstr): | ||||||
|  |     assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr)) | ||||||
|  |     nodestrs = xstr.split('+') | ||||||
|  |     genotypes = [] | ||||||
|  |     for i, node_str in enumerate(nodestrs): | ||||||
|  |       inputs = list(filter(lambda x: x != '', node_str.split('|'))) | ||||||
|  |       for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput) | ||||||
|  |       inputs = ( xi.split('~') for xi in inputs ) | ||||||
|  |       input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs) | ||||||
|  |       genotypes.append( input_infos ) | ||||||
|  |     return genotypes | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ArchResults(object): | class ArchResults(object): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user