Test super core
This commit is contained in:
		
							
								
								
									
										28
									
								
								tests/test_super_rearrange.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								tests/test_super_rearrange.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,28 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||||
|  | ##################################################### | ||||||
|  | # pytest ./tests/test_super_rearrange.py -s         # | ||||||
|  | ##################################################### | ||||||
|  | import sys | ||||||
|  | import unittest | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  | lib_dir = (Path(__file__).parent / "..").resolve() | ||||||
|  | print("LIB-DIR: {:}".format(lib_dir)) | ||||||
|  | if str(lib_dir) not in sys.path: | ||||||
|  |     sys.path.insert(0, str(lib_dir)) | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | from xautodl import xlayers | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TestSuperReArrange(unittest.TestCase): | ||||||
|  |     """Test the super re-arrange layer.""" | ||||||
|  |  | ||||||
|  |     def test_super_re_arrange(self): | ||||||
|  |         layer = xlayers.SuperReArrange( | ||||||
|  |             "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=4, p2=4 | ||||||
|  |         ) | ||||||
|  |         tensor = torch.rand((8, 4, 32, 32)) | ||||||
|  |         print("The tensor shape: {:}".format(tensor.shape)) | ||||||
|  |         print(layer) | ||||||
| @@ -47,3 +47,5 @@ super_name2activation = { | |||||||
| from .super_trade_stem import SuperAlphaEBDv1 | from .super_trade_stem import SuperAlphaEBDv1 | ||||||
| from .super_positional_embedding import SuperDynamicPositionE | from .super_positional_embedding import SuperDynamicPositionE | ||||||
| from .super_positional_embedding import SuperPositionalEncoder | from .super_positional_embedding import SuperPositionalEncoder | ||||||
|  |  | ||||||
|  | from .super_rearrange import SuperReArrange | ||||||
|   | |||||||
| @@ -16,15 +16,14 @@ from .super_module import IntSpaceType | |||||||
| from .super_module import BoolSpaceType | from .super_module import BoolSpaceType | ||||||
|  |  | ||||||
|  |  | ||||||
| class SuperRearrange(SuperModule): | class SuperReArrange(SuperModule): | ||||||
|     """Applies the rearrange operation.""" |     """Applies the rearrange operation.""" | ||||||
|  |  | ||||||
|     def __init__(self, pattern, **axes_lengths): |     def __init__(self, pattern, **axes_lengths): | ||||||
|         super(SuperRearrange, self).__init__() |         super(SuperReArrange, self).__init__() | ||||||
|  |  | ||||||
|         self._pattern = pattern |         self._pattern = pattern | ||||||
|         self._axes_lengths = axes_lengths |         self._axes_lengths = axes_lengths | ||||||
|         self.reset_parameters() |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def abstract_search_space(self): |     def abstract_search_space(self): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user