Add visualize codes for Q
This commit is contained in:
		| @@ -2,9 +2,11 @@ | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import os, sys, time, torch | ||||
|  | ||||
| # modules in AutoDL | ||||
| from log_utils import AverageMeter | ||||
| from log_utils import time_string | ||||
| from utils import obtain_accuracy | ||||
| from .eval_funcs import obtain_accuracy | ||||
|  | ||||
|  | ||||
| def basic_train( | ||||
|   | ||||
							
								
								
									
										14
									
								
								lib/procedures/eval_funcs.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								lib/procedures/eval_funcs.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,14 @@ | ||||
| def obtain_accuracy(output, target, topk=(1,)): | ||||
|     """Computes the precision@k for the specified values of k""" | ||||
|     maxk = max(topk) | ||||
|     batch_size = target.size(0) | ||||
|  | ||||
|     _, pred = output.topk(maxk, 1, True, True) | ||||
|     pred = pred.t() | ||||
|     correct = pred.eq(target.view(1, -1).expand_as(pred)) | ||||
|  | ||||
|     res = [] | ||||
|     for k in topk: | ||||
|         correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) | ||||
|         res.append(correct_k.mul_(100.0 / batch_size)) | ||||
|     return res | ||||
| @@ -3,12 +3,14 @@ | ||||
| ##################################################### | ||||
| import os, time, copy, torch, pathlib | ||||
|  | ||||
| # modules in AutoDL | ||||
| import datasets | ||||
| from config_utils import load_config | ||||
| from procedures import prepare_seed, get_optim_scheduler | ||||
| from utils import get_model_infos, obtain_accuracy | ||||
| from log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from models import get_cell_based_tiny_net | ||||
| from utils import get_model_infos | ||||
| from .eval_funcs import obtain_accuracy | ||||
|  | ||||
|  | ||||
| __all__ = ["evaluate_for_seed", "pure_evaluate", "get_nas_bench_loaders"] | ||||
|   | ||||
| @@ -3,9 +3,10 @@ | ||||
| ################################################## | ||||
| import os, sys, time, torch | ||||
| from log_utils import AverageMeter, time_string | ||||
| from utils import obtain_accuracy | ||||
| from models import change_key | ||||
|  | ||||
| from .eval_funcs import obtain_accuracy | ||||
|  | ||||
|  | ||||
| def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant): | ||||
|     expected_flop = torch.mean(expected_flop) | ||||
|   | ||||
| @@ -2,9 +2,11 @@ | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import os, sys, time, torch | ||||
|  | ||||
| # modules in AutoDL | ||||
| from log_utils import AverageMeter, time_string | ||||
| from utils import obtain_accuracy | ||||
| from models import change_key | ||||
| from .eval_funcs import obtain_accuracy | ||||
|  | ||||
|  | ||||
| def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant): | ||||
|   | ||||
| @@ -4,9 +4,9 @@ | ||||
| import os, sys, time, torch | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| # our modules | ||||
| # modules in AutoDL | ||||
| from log_utils import AverageMeter, time_string | ||||
| from utils import obtain_accuracy | ||||
| from .eval_funcs import obtain_accuracy | ||||
|  | ||||
|  | ||||
| def simple_KD_train( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user