import torch
import os, sys
import os.path as osp
import numpy as np

def tensor2np(x):
  """Convert a torch tensor to a numpy array; numpy arrays pass through unchanged.

  The tensor is detached from the autograd graph first, so tensors with
  ``requires_grad=True`` (e.g. raw model outputs) are accepted instead of making
  ``.numpy()`` raise. ``.cpu()`` is a no-op for tensors already on the CPU, so
  CUDA and CPU tensors share one code path. For CPU tensors the returned array
  shares memory with the tensor, as ``Tensor.numpy()`` does.
  """
  if isinstance(x, np.ndarray):
    return x
  return x.detach().cpu().numpy()

class Save_Meta():
  """Accumulate per-batch classifier outputs and labels, then persist them with ``torch.save``.

  Each ``append`` call stores one batch: a 2-D prediction array with one row per
  sample, and a 1-D label array with the same number of rows. The column meaning
  (presumably per-class scores such as logits or probabilities) is not checked.
  Inputs may be numpy arrays or torch tensors; tensors are converted with ``tensor2np``.
  """

  def __init__(self):
    self.reset()

  def __repr__(self):
    # len(self) counts appended batches, not individual samples.
    return '{}(number of data = {})'.format(self.__class__.__name__, len(self))

  def reset(self):
    """Drop all stored batches."""
    self.predictions = []  # list of 2-D numpy arrays, one per appended batch
    self.groundtruth = []  # list of 1-D numpy arrays, aligned with self.predictions

  def __len__(self):
    """Return the number of appended batches."""
    return len(self.predictions)

  def append(self, _pred, _ground):
    """Store one batch of predictions (2-D) and ground-truth labels (1-D).

    Raises:
      AssertionError: if ranks are not 2/1 or the leading dimensions differ.
    """
    _pred, _ground = tensor2np(_pred), tensor2np(_ground)
    assert _ground.shape[0] == _pred.shape[0] and len(_pred.shape) == 2 and len(_ground.shape) == 1, 'The shapes are wrong : {} & {}'.format(_pred.shape, _ground.shape)
    self.predictions.append(_pred)
    self.groundtruth.append(_ground)

  def save(self, save_dir, filename, test=True):
    """Write the stored batches to ``save_dir/filename`` and report top-1 accuracy.

    Args:
      save_dir: existing directory to write into.
      filename: file name inside ``save_dir``.
      test: if True, compute the percentage of samples whose argmax prediction
        equals the ground-truth label.

    Returns:
      The accuracy in percent, or None when ``test`` is False or there are no
      samples to score.
    """
    meta = {'predictions': self.predictions,
            'groundtruth': self.groundtruth}
    filename = osp.join(save_dir, filename)
    torch.save(meta, filename)
    accuracy = None
    # np.concatenate raises on an empty list, so skip scoring when nothing was appended.
    if test and self.predictions:
      predictions = np.argmax(np.concatenate(self.predictions), axis=1)
      groundtruth = np.concatenate(self.groundtruth)
      # Guard against batches that were all empty (would be a 0/0 division).
      if predictions.size > 0:
        accuracy = np.sum(groundtruth == predictions) * 100.0 / predictions.size
    print('save save_meta into {} with accuracy = {}'.format(filename, accuracy))
    return accuracy

  def load(self, filename):
    """Replace the stored batches with the contents of a file written by ``save``.

    Raises:
      AssertionError: if ``filename`` is not an existing file.
    """
    assert osp.isfile(filename), '{} is not a file'.format(filename)
    # The file holds lists of numpy arrays. torch>=2.6 defaults to weights_only=True,
    # whose restricted unpickler rejects them. weights_only=False runs full pickle,
    # which can execute arbitrary code: only load files from trusted sources.
    try:
      checkpoint = torch.load(filename, weights_only=False)
    except TypeError:
      # torch releases before 1.13 do not accept the weights_only argument.
      checkpoint = torch.load(filename)
    self.predictions = checkpoint['predictions']
    self.groundtruth = checkpoint['groundtruth']