Refine lib -> xautodl
This commit is contained in:
		
							
								
								
									
										5
									
								
								.github/workflows/super_model_test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										5
									
								
								.github/workflows/super_model_test.yml
									
									
									
									
										vendored
									
									
								
							| @@ -35,3 +35,8 @@ jobs: | |||||||
|           python -m pip install torch torchvision |           python -m pip install torch torchvision | ||||||
|           python -m pytest ./tests/test_super_*.py |           python -m pytest ./tests/test_super_*.py | ||||||
|         shell: bash |         shell: bash | ||||||
|  |  | ||||||
|  |       - name: Test TAS (NeurIPS 2019) | ||||||
|  |         run: | | ||||||
|  |           python -m pytest ./tests/test_tas.py | ||||||
|  |         shell: bash | ||||||
|   | |||||||
| @@ -25,7 +25,7 @@ def main(args): | |||||||
|     torch.backends.cudnn.enabled = True |     torch.backends.cudnn.enabled = True | ||||||
|     torch.backends.cudnn.benchmark = True |     torch.backends.cudnn.benchmark = True | ||||||
|     # torch.backends.cudnn.deterministic = True |     # torch.backends.cudnn.deterministic = True | ||||||
|     torch.set_num_threads(args.workers) |     # torch.set_num_threads(args.workers) | ||||||
|  |  | ||||||
|     prepare_seed(args.rand_seed) |     prepare_seed(args.rand_seed) | ||||||
|     logger = prepare_logger(args) |     logger = prepare_logger(args) | ||||||
|   | |||||||
| @@ -470,7 +470,7 @@ if __name__ == "__main__": | |||||||
|     assert torch.cuda.is_available(), "CUDA is not available." |     assert torch.cuda.is_available(), "CUDA is not available." | ||||||
|     torch.backends.cudnn.enabled = True |     torch.backends.cudnn.enabled = True | ||||||
|     torch.backends.cudnn.deterministic = True |     torch.backends.cudnn.deterministic = True | ||||||
|     torch.set_num_threads(args.workers) |     # torch.set_num_threads(args.workers) | ||||||
|  |  | ||||||
|     main( |     main( | ||||||
|         save_dir, |         save_dir, | ||||||
|   | |||||||
| @@ -340,7 +340,7 @@ def train_single_model( | |||||||
|     torch.backends.cudnn.enabled = True |     torch.backends.cudnn.enabled = True | ||||||
|     torch.backends.cudnn.deterministic = True |     torch.backends.cudnn.deterministic = True | ||||||
|     # torch.backends.cudnn.benchmark = True |     # torch.backends.cudnn.benchmark = True | ||||||
|     torch.set_num_threads(workers) |     # torch.set_num_threads(workers) | ||||||
|  |  | ||||||
|     save_dir = ( |     save_dir = ( | ||||||
|         Path(save_dir) |         Path(save_dir) | ||||||
| @@ -675,7 +675,7 @@ if __name__ == "__main__": | |||||||
|         assert torch.cuda.is_available(), "CUDA is not available." |         assert torch.cuda.is_available(), "CUDA is not available." | ||||||
|         torch.backends.cudnn.enabled = True |         torch.backends.cudnn.enabled = True | ||||||
|         torch.backends.cudnn.deterministic = True |         torch.backends.cudnn.deterministic = True | ||||||
|         torch.set_num_threads(args.workers if args.workers > 0 else 1) |         # torch.set_num_threads(args.workers if args.workers > 0 else 1) | ||||||
|  |  | ||||||
|         main( |         main( | ||||||
|             save_dir, |             save_dir, | ||||||
|   | |||||||
| @@ -132,7 +132,7 @@ def select_action(policy): | |||||||
|  |  | ||||||
|  |  | ||||||
| def main(xargs, api): | def main(xargs, api): | ||||||
|     torch.set_num_threads(4) |     # torch.set_num_threads(4) | ||||||
|     prepare_seed(xargs.rand_seed) |     prepare_seed(xargs.rand_seed) | ||||||
|     logger = prepare_logger(args) |     logger = prepare_logger(args) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -204,7 +204,7 @@ def main(xargs): | |||||||
|     torch.backends.cudnn.enabled = True |     torch.backends.cudnn.enabled = True | ||||||
|     torch.backends.cudnn.benchmark = False |     torch.backends.cudnn.benchmark = False | ||||||
|     torch.backends.cudnn.deterministic = True |     torch.backends.cudnn.deterministic = True | ||||||
|     torch.set_num_threads(xargs.workers) |     # torch.set_num_threads(xargs.workers) | ||||||
|     prepare_seed(xargs.rand_seed) |     prepare_seed(xargs.rand_seed) | ||||||
|     logger = prepare_logger(args) |     logger = prepare_logger(args) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -8,17 +8,14 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True | |||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
|  |  | ||||||
| lib_dir = (Path(__file__).parent / ".." / "lib").resolve() | from xautodl.datasets import get_datasets | ||||||
| if str(lib_dir) not in sys.path: | from xautodl.config_utils import load_config, obtain_basic_args as obtain_args | ||||||
|     sys.path.insert(0, str(lib_dir)) | from xautodl.procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint | ||||||
| from config_utils import load_config, obtain_basic_args as obtain_args | from xautodl.procedures import get_optim_scheduler, get_procedures | ||||||
| from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint | from xautodl.models import obtain_model | ||||||
| from procedures import get_optim_scheduler, get_procedures | from xautodl.nas_infer_model import obtain_nas_infer_model | ||||||
| from datasets import get_datasets | from xautodl.utils import get_model_infos | ||||||
| from models import obtain_model | from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||||
| from nas_infer_model import obtain_nas_infer_model |  | ||||||
| from utils import get_model_infos |  | ||||||
| from log_utils import AverageMeter, time_string, convert_secs2time |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(args): | def main(args): | ||||||
| @@ -26,7 +23,7 @@ def main(args): | |||||||
|     torch.backends.cudnn.enabled = True |     torch.backends.cudnn.enabled = True | ||||||
|     torch.backends.cudnn.benchmark = True |     torch.backends.cudnn.benchmark = True | ||||||
|     # torch.backends.cudnn.deterministic = True |     # torch.backends.cudnn.deterministic = True | ||||||
|     torch.set_num_threads(args.workers) |     # torch.set_num_threads(args.workers) | ||||||
|  |  | ||||||
|     prepare_seed(args.rand_seed) |     prepare_seed(args.rand_seed) | ||||||
|     logger = prepare_logger(args) |     logger = prepare_logger(args) | ||||||
|   | |||||||
| @@ -10,21 +10,17 @@ import numpy as np | |||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
|  |  | ||||||
| lib_dir = (Path(__file__).parent / ".." / "lib").resolve() | from xautodl.config_utils import ( | ||||||
| print("lib_dir : {:}".format(lib_dir)) |  | ||||||
| if str(lib_dir) not in sys.path: |  | ||||||
|     sys.path.insert(0, str(lib_dir)) |  | ||||||
| from config_utils import ( |  | ||||||
|     load_config, |     load_config, | ||||||
|     configure2str, |     configure2str, | ||||||
|     obtain_search_single_args as obtain_args, |     obtain_search_single_args as obtain_args, | ||||||
| ) | ) | ||||||
| from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint | from xautodl.procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint | ||||||
| from procedures import get_optim_scheduler, get_procedures | from xautodl.procedures import get_optim_scheduler, get_procedures | ||||||
| from datasets import get_datasets, SearchDataset | from xautodl.datasets import get_datasets, SearchDataset | ||||||
| from models import obtain_search_model, obtain_model, change_key | from xautodl.models import obtain_search_model, obtain_model, change_key | ||||||
| from utils import get_model_infos | from xautodl.utils import get_model_infos | ||||||
| from log_utils import AverageMeter, time_string, convert_secs2time | from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(args): | def main(args): | ||||||
| @@ -32,7 +28,7 @@ def main(args): | |||||||
|     torch.backends.cudnn.enabled = True |     torch.backends.cudnn.enabled = True | ||||||
|     torch.backends.cudnn.benchmark = True |     torch.backends.cudnn.benchmark = True | ||||||
|     # torch.backends.cudnn.deterministic = True |     # torch.backends.cudnn.deterministic = True | ||||||
|     torch.set_num_threads(args.workers) |     # torch.set_num_threads(args.workers) | ||||||
|  |  | ||||||
|     prepare_seed(args.rand_seed) |     prepare_seed(args.rand_seed) | ||||||
|     logger = prepare_logger(args) |     logger = prepare_logger(args) | ||||||
|   | |||||||
							
								
								
									
										29
									
								
								tests/test_loader.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								tests/test_loader.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||||
|  | ##################################################### | ||||||
|  | # pytest tests/test_loader.py -s                    # | ||||||
|  | ##################################################### | ||||||
|  | import unittest | ||||||
|  | import tempfile | ||||||
|  | import torch | ||||||
|  |  | ||||||
|  | from xautodl.datasets import get_datasets | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_simple(): | ||||||
|  |     xdir = tempfile.mkdtemp() | ||||||
|  |     train_data, valid_data, xshape, class_num = get_datasets("cifar10", xdir, -1) | ||||||
|  |     print(train_data) | ||||||
|  |     print(valid_data) | ||||||
|  |  | ||||||
|  |     xloader = torch.utils.data.DataLoader( | ||||||
|  |         train_data, batch_size=256, shuffle=True, num_workers=4, pin_memory=True | ||||||
|  |     ) | ||||||
|  |     print(xloader) | ||||||
|  |     print(next(iter(xloader))) | ||||||
|  |  | ||||||
|  |     for i, data in enumerate(xloader): | ||||||
|  |         print(i) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | test_simple() | ||||||
							
								
								
									
										23
									
								
								tests/test_tas.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								tests/test_tas.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,23 @@ | |||||||
|  | ################################################## | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||||
|  | ################################################## | ||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  |  | ||||||
|  | from xautodl.models.shape_searchs.SoftSelect import ChannelWiseInter | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TestTASFunc(unittest.TestCase): | ||||||
|  |     """Test the TAS function.""" | ||||||
|  |  | ||||||
|  |     def test_channel_interplation(self): | ||||||
|  |         tensors = torch.rand((16, 128, 7, 7)) | ||||||
|  |  | ||||||
|  |         for oc in range(200, 210): | ||||||
|  |             out_v1 = ChannelWiseInter(tensors, oc, "v1") | ||||||
|  |             out_v2 = ChannelWiseInter(tensors, oc, "v2") | ||||||
|  |             assert (out_v1 == out_v2).any().item() == 1 | ||||||
|  |         for oc in range(48, 160): | ||||||
|  |             out_v1 = ChannelWiseInter(tensors, oc, "v1") | ||||||
|  |             out_v2 = ChannelWiseInter(tensors, oc, "v2") | ||||||
|  |             assert (out_v1 == out_v2).any().item() == 1 | ||||||
| @@ -9,9 +9,9 @@ from typing import List, Text, Any | |||||||
| import random, torch | import random, torch | ||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
|  |  | ||||||
| from models.cell_operations import ResNetBasicblock | from ..cell_operations import ResNetBasicblock | ||||||
| from models.cell_infers.cells import InferCell | from ..cell_infers.cells import InferCell | ||||||
| from models.shape_searchs.SoftSelect import select2withP, ChannelWiseInter | from .shape_searchs.SoftSelect import select2withP, ChannelWiseInter | ||||||
|  |  | ||||||
|  |  | ||||||
| class GenericNAS301Model(nn.Module): | class GenericNAS301Model(nn.Module): | ||||||
|   | |||||||
| @@ -1,20 +0,0 @@ | |||||||
| ################################################## |  | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # |  | ||||||
| ################################################## |  | ||||||
| import torch |  | ||||||
| import torch.nn as nn |  | ||||||
| from SoftSelect import ChannelWiseInter |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|  |  | ||||||
|     tensors = torch.rand((16, 128, 7, 7)) |  | ||||||
|  |  | ||||||
|     for oc in range(200, 210): |  | ||||||
|         out_v1 = ChannelWiseInter(tensors, oc, "v1") |  | ||||||
|         out_v2 = ChannelWiseInter(tensors, oc, "v2") |  | ||||||
|         assert (out_v1 == out_v2).any().item() == 1 |  | ||||||
|     for oc in range(48, 160): |  | ||||||
|         out_v1 = ChannelWiseInter(tensors, oc, "v1") |  | ||||||
|         out_v2 = ChannelWiseInter(tensors, oc, "v2") |  | ||||||
|         assert (out_v1 == out_v2).any().item() == 1 |  | ||||||
		Reference in New Issue
	
	Block a user