Refine lib -> xautodl
This commit is contained in:
		
							
								
								
									
										29
									
								
								tests/test_loader.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								tests/test_loader.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
| # pytest tests/test_loader.py -s                    # | ||||
| ##################################################### | ||||
| import unittest | ||||
| import tempfile | ||||
| import torch | ||||
|  | ||||
| from xautodl.datasets import get_datasets | ||||
|  | ||||
|  | ||||
| def test_simple(): | ||||
|     xdir = tempfile.mkdtemp() | ||||
|     train_data, valid_data, xshape, class_num = get_datasets("cifar10", xdir, -1) | ||||
|     print(train_data) | ||||
|     print(valid_data) | ||||
|  | ||||
|     xloader = torch.utils.data.DataLoader( | ||||
|         train_data, batch_size=256, shuffle=True, num_workers=4, pin_memory=True | ||||
|     ) | ||||
|     print(xloader) | ||||
|     print(next(iter(xloader))) | ||||
|  | ||||
|     for i, data in enumerate(xloader): | ||||
|         print(i) | ||||
|  | ||||
|  | ||||
| test_simple() | ||||
							
								
								
									
										23
									
								
								tests/test_tas.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								tests/test_tas.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,23 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| from xautodl.models.shape_searchs.SoftSelect import ChannelWiseInter | ||||
|  | ||||
|  | ||||
| class TestTASFunc(unittest.TestCase): | ||||
|     """Test the TAS function.""" | ||||
|  | ||||
|     def test_channel_interplation(self): | ||||
|         tensors = torch.rand((16, 128, 7, 7)) | ||||
|  | ||||
|         for oc in range(200, 210): | ||||
|             out_v1 = ChannelWiseInter(tensors, oc, "v1") | ||||
|             out_v2 = ChannelWiseInter(tensors, oc, "v2") | ||||
|             assert (out_v1 == out_v2).any().item() == 1 | ||||
|         for oc in range(48, 160): | ||||
|             out_v1 = ChannelWiseInter(tensors, oc, "v1") | ||||
|             out_v2 = ChannelWiseInter(tensors, oc, "v2") | ||||
|             assert (out_v1 == out_v2).any().item() == 1 | ||||
		Reference in New Issue
	
	Block a user