import random, tarfile
import numpy, six
from six.moves import cPickle as pickle
from PIL import Image, ImageOps


def train_cifar_augmentation(image):
  # flip
  if random.random() < 0.5: image1 = image.transpose(Image.FLIP_LEFT_RIGHT)
  else:                     image1 = image
  # random crop
  image2 = ImageOps.expand(image1, border=4, fill=0)
  i = random.randint(0, 40 - 32)
  j = random.randint(0, 40 - 32)
  image3 = image2.crop((j,i,j+32,i+32))
  # to numpy
  image3 = numpy.array(image3) / 255.0
  mean   = numpy.array([x / 255 for x in [125.3, 123.0, 113.9]]).reshape(1, 1, 3)
  std    = numpy.array([x / 255 for x in [63.0, 62.1, 66.7]]).reshape(1, 1, 3)
  return (image3 - mean) / std


def valid_cifar_augmentation(image):
  image3 = numpy.array(image) / 255.0
  mean   = numpy.array([x / 255 for x in [125.3, 123.0, 113.9]]).reshape(1, 1, 3)
  std    = numpy.array([x / 255 for x in [63.0, 62.1, 66.7]]).reshape(1, 1, 3)
  return (image3 - mean) / std


def reader_creator(filename, sub_name, is_train, cycle=False):
  def read_batch(batch):
    data = batch[six.b('data')]
    labels = batch.get(
      six.b('labels'), batch.get(six.b('fine_labels'), None))
    assert labels is not None
    for sample, label in six.moves.zip(data, labels):
      sample = sample.reshape(3, 32, 32)
      sample = sample.transpose((1, 2, 0))
      image  = Image.fromarray(sample)
      if is_train:
        ximage = train_cifar_augmentation(image)
      else:
        ximage = valid_cifar_augmentation(image)
      ximage = ximage.transpose((2, 0, 1))
      yield ximage.astype(numpy.float32), int(label)

  def reader():
    with tarfile.open(filename, mode='r') as f:
      names = (each_item.name for each_item in f
           if sub_name in each_item.name)

      while True:
        for name in names:
          if six.PY2:
            batch = pickle.load(f.extractfile(name))
          else:
            batch = pickle.load(
              f.extractfile(name), encoding='bytes')
          for item in read_batch(batch):
            yield item
        if not cycle:
          break

  return reader