Update Weight Watcher in utils
This commit is contained in:
		| @@ -5,18 +5,15 @@ | ||||
| # required to install hpbandster ################################## | ||||
| # bash ./scripts-search/algos/BOHB.sh -1         ################## | ||||
| ################################################################### | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np, collections | ||||
| import os, sys, time, random, argparse | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from config_utils import load_config, dict2config, configure2str | ||||
| from config_utils import load_config | ||||
| from datasets     import get_datasets, SearchDataset | ||||
| from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | ||||
| from utils        import get_model_infos, obtain_accuracy | ||||
| from procedures   import prepare_seed, prepare_logger | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from nas_201_api  import NASBench201API as API | ||||
| from models       import CellStructure, get_search_spaces | ||||
|   | ||||
| @@ -3,11 +3,9 @@ | ||||
| ######################################################## | ||||
| # DARTS: Differentiable Architecture Search, ICLR 2019 # | ||||
| ######################################################## | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np | ||||
| import sys, time, random, argparse | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from pathlib import Path | ||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
|   | ||||
							
								
								
									
										21
									
								
								exps/experimental/test-ww.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								exps/experimental/test-ww.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | ||||
| import sys, time, random, argparse | ||||
| from copy import deepcopy | ||||
| import torchvision.models as models | ||||
| from pathlib import Path | ||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from utils import weight_watcher | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|   model = models.vgg19_bn(pretrained=True) | ||||
|   _, summary = weight_watcher.analyze(model, alphas=False) | ||||
|   # print(summary) | ||||
|   for key, value in summary.items(): | ||||
|     print('{:10s} : {:}'.format(key, value)) | ||||
|   # import pdb; pdb.set_trace() | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   main() | ||||
		Reference in New Issue
	
	Block a user