update NAS-Bench-201 API to support str2matrix and str2lists
This commit is contained in:
		| @@ -251,15 +251,15 @@ class NASBench201API(object): | |||||||
|       else: |       else: | ||||||
|         print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs))) |         print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs))) | ||||||
|  |  | ||||||
|   # This func shows how to read the string0based architecture encoding |   # This func shows how to read the string-based architecture encoding | ||||||
|   #   the same as the `str2structure` func in `AutoDL-Projects/lib/models/cell_searchs/genotypes.py` |   #   the same as the `str2structure` func in `AutoDL-Projects/lib/models/cell_searchs/genotypes.py` | ||||||
|   # Usage: |   # Usage: | ||||||
|   #   arch = api.str2structure( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' ) |   #   arch = api.str2lists( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' ) | ||||||
|   #   print ('there are {:} nodes in this arch'.format(len(arch)+1)) # arch is a list |   #   print ('there are {:} nodes in this arch'.format(len(arch)+1)) # arch is a list | ||||||
|   #   for i, node in enumerate(arch): |   #   for i, node in enumerate(arch): | ||||||
|   #     print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node)) |   #     print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node)) | ||||||
|   @staticmethod |   @staticmethod | ||||||
|   def str2structure(xstr): |   def str2lists(xstr): | ||||||
|     assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr)) |     assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr)) | ||||||
|     nodestrs = xstr.split('+') |     nodestrs = xstr.split('+') | ||||||
|     genotypes = [] |     genotypes = [] | ||||||
| @@ -271,6 +271,37 @@ class NASBench201API(object): | |||||||
|       genotypes.append( input_infos ) |       genotypes.append( input_infos ) | ||||||
|     return genotypes |     return genotypes | ||||||
|  |  | ||||||
|  |   # This func shows how to convert the string-based architecture encoding to the encoding strategy in NAS-Bench-101 | ||||||
|  |   # Usage: | ||||||
|  |   #   # this will return a numpy matrix (2-D np.array) | ||||||
|  |   #   matrix = api.str2matrix( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' ) | ||||||
|  |   #   # This matrix is 4-by-4 matrix representing a cell with 4 nodes (only the lower left triangle is useful). | ||||||
|  |   #      [ [0, 0, 0, 0],  # the first line represents the input (0-th) node | ||||||
|  |   #        [2, 0, 0, 0],  # the second line represents the 1-st node, is calculated by 2-th-op( 0-th-node ) | ||||||
|  |   #        [0, 0, 0, 0],  # the third line represents the 2-nd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node ) | ||||||
|  |   #        [0, 0, 1, 0] ] # the fourth line represents the 3-rd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node ) + 1-th-op( 2-th-node ) | ||||||
|  |   #   In NAS-Bench-201 search space, 0-th-op is 'none', 1-th-op is 'skip_connect' | ||||||
|  |   #      2-th-op is 'nor_conv_1x1', 3-th-op is 'nor_conv_3x3', 4-th-op is 'avg_pool_3x3'. | ||||||
|  |   @staticmethod | ||||||
|  |   def str2matrix(xstr): | ||||||
|  |     assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr)) | ||||||
|  |     # this only support NAS-Bench-201 search space | ||||||
|  |     # this defination will be consistant with this line https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/models/cell_operations.py#L24 | ||||||
|  |     # If a node has two input-edges from the same node, this function does not work. One edge will be overleaped. | ||||||
|  |     NAS_BENCH_201         = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3'] | ||||||
|  |     nodestrs = xstr.split('+') | ||||||
|  |     num_nodes = len(nodestrs) + 1 | ||||||
|  |     matrix = np.zeros((num_nodes,num_nodes)) | ||||||
|  |     for i, node_str in enumerate(nodestrs): | ||||||
|  |       inputs = list(filter(lambda x: x != '', node_str.split('|'))) | ||||||
|  |       for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput) | ||||||
|  |       for xi in inputs: | ||||||
|  |         op, idx = xi.split('~') | ||||||
|  |         if op not in NAS_BENCH_201: raise ValueError('this op ({:}) is not in {:}'.format(op, NAS_BENCH_201)) | ||||||
|  |         op_idx, node_idx = NAS_BENCH_201.index(op), int(idx) | ||||||
|  |         matrix[i+1, node_idx] = op_idx | ||||||
|  |     return matrix | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ArchResults(object): | class ArchResults(object): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user