Fix bugs in test-ww
This commit is contained in:
		| @@ -6,7 +6,7 @@ | ||||
| # python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1 | ||||
| # bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1 | ||||
| ############################################################################################### | ||||
| import os, gc, sys, argparse, psutil | ||||
| import os, gc, sys, math, argparse, psutil | ||||
| import numpy as np | ||||
| import torch | ||||
| from pathlib import Path | ||||
| @@ -57,8 +57,12 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool): | ||||
|       with torch.no_grad(): | ||||
|         net.load_state_dict(param) | ||||
|         _, summary = weight_watcher.analyze(net, alphas=False) | ||||
|         cur_norms.append(summary['lognorm']) | ||||
|     norms.append( float(np.mean(cur_norms)) ) | ||||
|         cur_norms.append(-summary['lognorm']) | ||||
|     cur_norm = float(np.mean(cur_norms)) | ||||
|     if math.isnan(cur_norm): | ||||
|       print ('  IGNORE {:} due to nan.'.format(idx)) | ||||
|       continue | ||||
|     norms.append(cur_norm) | ||||
|     api.clear_params(idx, None) | ||||
|     if idx % 200 == 199 or idx + 1 == len(api): | ||||
|       head = '{:05d}/{:05d}'.format(idx, len(api)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user