65 lines
		
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			65 lines
		
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import random, tarfile
 | 
						|
import numpy, six
 | 
						|
from six.moves import cPickle as pickle
 | 
						|
from PIL import Image, ImageOps
 | 
						|
 | 
						|
 | 
						|
def train_cifar_augmentation(image):
  """Randomly flip and crop an image, then normalize it into a float array.

  Steps, in order:
    1. With probability 0.5, mirror the image left-to-right.
    2. Pad it with a 4-pixel zero-filled border and cut out a 32x32 window
       at a random offset.
    3. Scale to [0, 1] by dividing by 255 and standardize each of the three
       channels with fixed mean / std constants.

  Args:
    image: a PIL image (it must support `transpose` and work with
      `ImageOps.expand`). NOTE(review): the crop offsets are drawn from
      [0, 8], which matches a 32x32 input padded to 40x40 -- confirm that
      callers always pass 32x32 images.

  Returns:
    A numpy array holding the normalized crop; the last axis is the channel
    axis (the constants are broadcast with shape (1, 1, 3)).
  """
  # Horizontal flip for half of the samples (first random draw).
  flipped = (image.transpose(Image.FLIP_LEFT_RIGHT)
             if random.random() < 0.5 else image)

  # Zero-pad, then pick the window's top-left corner: row offset is drawn
  # first, column offset second.
  padded = ImageOps.expand(flipped, border=4, fill=0)
  top = random.randint(0, 40 - 32)
  left = random.randint(0, 40 - 32)
  patch = padded.crop((left, top, left + 32, top + 32))

  # Convert to numpy and standardize per channel.
  channel_mean = numpy.array([125.3, 123.0, 113.9]) / 255
  channel_std = numpy.array([63.0, 62.1, 66.7]) / 255
  pixels = numpy.array(patch) / 255.0
  centered = pixels - channel_mean.reshape(1, 1, 3)
  return centered / channel_std.reshape(1, 1, 3)
 | 
						|
 | 
						|
 | 
						|
def valid_cifar_augmentation(image):
  """Normalize an image into a float array without any random augmentation.

  The input is converted with `numpy.array`, divided by 255, and then each
  of the three channels is standardized with fixed mean / std constants
  (presumably CIFAR per-channel statistics, given the function name).

  Args:
    image: anything `numpy.array` can convert whose last axis has three
      channels (the constants are broadcast with shape (1, 1, 3)).

  Returns:
    A numpy array of the same shape as the converted input containing
    `(image / 255 - mean) / std`.
  """
  channel_mean = numpy.array([125.3, 123.0, 113.9]) / 255
  channel_std = numpy.array([63.0, 62.1, 66.7]) / 255
  pixels = numpy.array(image) / 255.0
  centered = pixels - channel_mean.reshape(1, 1, 3)
  return centered / channel_std.reshape(1, 1, 3)
 | 
						|
 | 
						|
 | 
						|
def reader_creator(filename, sub_name, is_train, cycle=False):
  """Create a sample reader over pickled batches stored in a tar archive.

  Args:
    filename: path of the tar archive; opened with `tarfile.open(mode='r')`
      each time the returned reader is called.
    sub_name: substring used to select archive members by name.
    is_train: when true, samples go through `train_cifar_augmentation`;
      otherwise through `valid_cifar_augmentation`.
    cycle: when true, the reader starts over from the first selected member
      after the last one instead of stopping.

  Returns:
    A zero-argument generator function. Each yielded item is a tuple
    `(image, label)` where `image` is a float32 numpy array in
    channel-first layout and `label` is an int.
  """

  def read_batch(batch):
    """Yield (image, label) pairs from one unpickled batch dictionary."""
    data = batch[six.b('data')]
    # Batches store their labels under either 'labels' or 'fine_labels'.
    labels = batch.get(
      six.b('labels'), batch.get(six.b('fine_labels'), None))
    assert labels is not None, \
      "batch has neither 'labels' nor 'fine_labels'"
    for sample, label in six.moves.zip(data, labels):
      # Each row is a flat 3*32*32 vector in channel-first order; PIL wants
      # height x width x channel.
      sample = sample.reshape(3, 32, 32)
      sample = sample.transpose((1, 2, 0))
      image  = Image.fromarray(sample)
      if is_train:
        ximage = train_cifar_augmentation(image)
      else:
        ximage = valid_cifar_augmentation(image)
      # Back to channel-first for the consumer.
      ximage = ximage.transpose((2, 0, 1))
      yield ximage.astype(numpy.float32), int(label)

  def reader():
    """Iterate over every sample of every selected archive member."""
    with tarfile.open(filename, mode='r') as f:
      # Materialize the names: a generator expression would be exhausted
      # after the first pass, making `cycle=True` spin forever without
      # yielding anything.
      names = [each_item.name for each_item in f
               if sub_name in each_item.name]
      if not names:
        # Nothing matched; stop instead of busy-looping when cycling.
        return

      while True:
        for name in names:
          # NOTE: pickle.load can execute arbitrary code embedded in the
          # archive -- only use archives from a trusted source.
          if six.PY2:
            batch = pickle.load(f.extractfile(name))
          else:
            batch = pickle.load(
              f.extractfile(name), encoding='bytes')
          for item in read_batch(batch):
            yield item
        if not cycle:
          break

  return reader
 |