Fix bugs in test-ww
This commit is contained in:
		| @@ -6,7 +6,7 @@ | |||||||
| # python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1 | # python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1 | ||||||
| # bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1 | # bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1 | ||||||
| ############################################################################################### | ############################################################################################### | ||||||
| import os, gc, sys, argparse, psutil | import os, gc, sys, math, argparse, psutil | ||||||
| import numpy as np | import numpy as np | ||||||
| import torch | import torch | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| @@ -57,8 +57,12 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool): | |||||||
|       with torch.no_grad(): |       with torch.no_grad(): | ||||||
|         net.load_state_dict(param) |         net.load_state_dict(param) | ||||||
|         _, summary = weight_watcher.analyze(net, alphas=False) |         _, summary = weight_watcher.analyze(net, alphas=False) | ||||||
|         cur_norms.append(summary['lognorm']) |         cur_norms.append(-summary['lognorm']) | ||||||
|     norms.append( float(np.mean(cur_norms)) ) |     cur_norm = float(np.mean(cur_norms)) | ||||||
|  |     if math.isnan(cur_norm): | ||||||
|  |       print ('  IGNORE {:} due to nan.'.format(idx)) | ||||||
|  |       continue | ||||||
|  |     norms.append(cur_norm) | ||||||
|     api.clear_params(idx, None) |     api.clear_params(idx, None) | ||||||
|     if idx % 200 == 199 or idx + 1 == len(api): |     if idx % 200 == 199 or idx + 1 == len(api): | ||||||
|       head = '{:05d}/{:05d}'.format(idx, len(api)) |       head = '{:05d}/{:05d}'.format(idx, len(api)) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user