Add MobileNetV2
This commit is contained in:
		
							
								
								
									
										24
									
								
								exps/experimental/test-flops.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								exps/experimental/test-flops.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,24 @@ | ||||
| import sys, time, random, argparse | ||||
| from copy import deepcopy | ||||
| import torchvision.models as models | ||||
| from pathlib import Path | ||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from utils import get_model_infos | ||||
| #from models.ImageNet_MobileNetV2 import MobileNetV2 | ||||
| from torchvision.models.mobilenet import MobileNetV2 | ||||
|  | ||||
| def main(width_mult): | ||||
|   # model = MobileNetV2(1001, width_mult, 32, 1280, 'InvertedResidual', 0.2) | ||||
|   model = MobileNetV2(width_mult=width_mult) | ||||
|   print(model) | ||||
|   flops, params = get_model_infos(model, (2, 3, 224, 224)) | ||||
|   print('FLOPs : {:}'.format(flops)) | ||||
|   print('Params : {:}'.format(params)) | ||||
|   print('-'*50) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   main(1.0) | ||||
|   main(1.4) | ||||
		Reference in New Issue
	
	Block a user