54 lines
1.8 KiB
Python
54 lines
1.8 KiB
Python
##################################################
|
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
|
##################################################
|
|
import torch
|
|
import os, sys
|
|
import os.path as osp
|
|
import numpy as np
|
|
|
|
def tensor2np(x):
    """Convert a torch tensor to a numpy array; pass numpy arrays through.

    Args:
        x: a ``numpy.ndarray`` or a ``torch.Tensor`` (any device, with or
            without gradient tracking).

    Returns:
        ``x`` itself when it is already a ``numpy.ndarray``; otherwise the
        tensor's data as a ``numpy.ndarray``.
    """
    if isinstance(x, np.ndarray):
        return x
    # Tensor.numpy() raises RuntimeError on a tensor that requires grad, so
    # drop autograd tracking first. cpu() returns the tensor unchanged when it
    # already lives on the CPU and copies it over from any other device.
    return x.detach().cpu().numpy()
|
|
|
|
class Save_Meta():
    """Accumulate batches of predictions and ground-truth labels.

    Batches are stored as numpy arrays in two parallel lists,
    ``self.predictions`` and ``self.groundtruth``. The collected lists can be
    written to disk with :meth:`save` and restored with :meth:`load`.
    """

    def __init__(self):
        self.reset()

    def __repr__(self):
        return '{name}(number of data = {num})'.format(
            name=self.__class__.__name__, num=len(self))

    def reset(self):
        """Drop every stored batch."""
        self.predictions = []
        self.groundtruth = []

    def __len__(self):
        """Return the number of stored batches (not the number of samples)."""
        return len(self.predictions)

    def append(self, _pred, _ground):
        """Store one batch of predictions with its ground-truth labels.

        Args:
            _pred: 2-D tensor/array; its first dimension must match ``_ground``.
            _ground: 1-D tensor/array of labels.

        Raises:
            AssertionError: if the shapes do not satisfy the constraints above.
        """
        _pred, _ground = tensor2np(_pred), tensor2np(_ground)
        shapes_ok = (
            _ground.shape[0] == _pred.shape[0]
            and len(_pred.shape) == 2
            and len(_ground.shape) == 1
        )
        assert shapes_ok, 'The shapes are wrong : {} & {}'.format(_pred.shape, _ground.shape)
        self.predictions.append(_pred)
        self.groundtruth.append(_ground)

    def save(self, save_dir, filename, test=True):
        """Write the stored batches to ``save_dir/filename`` via ``torch.save``.

        Args:
            save_dir: directory that receives the file.
            filename: file name inside ``save_dir``.
            test: when true, also compute the accuracy (in percent) of the
                arg-max over axis 1 of the predictions against the ground
                truth and include it in the printed message.
        """
        meta = {'predictions': self.predictions,
                'groundtruth': self.groundtruth}
        save_path = osp.join(save_dir, filename)
        torch.save(meta, save_path)

        accuracy = None
        # np.concatenate raises ValueError on an empty list, so only evaluate
        # when at least one batch has been appended.
        if test and len(self) > 0:
            scores = np.concatenate(self.predictions)
            labels = np.concatenate(self.groundtruth)
            predicted = np.argmax(scores, axis=1)
            accuracy = np.sum(labels == predicted) * 100.0 / predicted.size
        print('save save_meta into {} with accuracy = {}'.format(save_path, accuracy))

    def load(self, filename):
        """Replace the stored batches with those read from ``filename``.

        Args:
            filename: path of a file previously written by :meth:`save`.

        Raises:
            AssertionError: if ``filename`` is not an existing file.
        """
        assert os.path.isfile(filename), '{} is not a file'.format(filename)
        # NOTE(security): the file holds pickled numpy arrays, so it needs the
        # full unpickler; only load files from a trusted source.
        try:
            # Newer PyTorch defaults to weights_only=True, which rejects the
            # numpy arrays written by save().
            checkpoint = torch.load(filename, weights_only=False)
        except TypeError:
            # Older PyTorch releases do not know the weights_only keyword.
            checkpoint = torch.load(filename)
        self.predictions = checkpoint['predictions']
        self.groundtruth = checkpoint['groundtruth']
|