17 lines
		
	
	
		
			453 B
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			17 lines
		
	
	
		
			453 B
		
	
	
	
		
			Python
		
	
	
	
	
	
| import torch
 | |
| 
 | |
def obtain_accuracy(output, target, topk=(1,)):
  """Compute the precision@k (top-k accuracy, in percent) for each k in `topk`.

  Args:
    output: tensor of per-class scores; the top-k classes are selected along
      dim 1 (assumed layout is (batch, num_classes) -- confirm with callers).
    target: tensor of ground-truth class indices; it is flattened and compared
      element-wise against the predictions, and its dim-0 size is used as the
      batch size.
    topk: iterable of k values to evaluate.

  Returns:
    A list with one 1-element float tensor per entry of `topk`, holding the
    percentage of samples whose target is among the k highest-scoring classes.
  """
  maxk = max(topk)
  batch_size = target.size(0)

  # Indices of the maxk highest scores per sample, best first.
  _, pred = output.topk(maxk, 1, True, True)
  # Transpose so that row i holds every sample's i-th best prediction.
  pred = pred.t()
  correct = pred.eq(target.view(1, -1).expand_as(pred))

  res = []
  for k in topk:
    # `correct` keeps the transposed (non-contiguous) layout of `pred`, so
    # `.view(-1)` raises a RuntimeError for k > 1 on current PyTorch;
    # `.reshape(-1)` copies only when needed and yields the same values.
    correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
    res.append(correct_k.mul_(100.0 / batch_size))
  return res
 |