#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
#####################################################
import random


class BatchSampler:
    """A batch sampler used for single machine training.

    Iterating over the sampler yields ``steps`` lists of dataset indexes,
    each of length ``batch``.  The index list is reshuffled at the start of
    every "epoch", where one epoch is ``len(dataset) // batch`` consecutive
    batches.  Indexes that do not fit into a full batch after shuffling are
    not yielded in that epoch.

    Args:
        dataset: Any object supporting ``len()``; only its size is used.
        batch: Number of indexes per yielded batch; must be positive and
            no larger than ``len(dataset)``.
        steps: Total number of batches produced by one pass of ``__iter__``.

    Raises:
        ValueError: If ``batch`` is not positive, or if the dataset holds
            fewer than ``batch`` items.
    """

    def __init__(self, dataset, batch, steps):
        self._num_per_epoch = len(dataset)
        # Reject non-positive batch sizes explicitly: 0 would otherwise
        # surface as a ZeroDivisionError below, and a negative value would
        # pass the size check and silently produce wrong slices.
        if batch <= 0:
            raise ValueError(
                "The batch size must be positive, but got batch={:}".format(batch)
            )
        if self._num_per_epoch < batch:
            raise ValueError(
                "The dataset size must be larger than batch={:}".format(batch)
            )
        self._batch = batch
        self._steps = steps
        # Number of full batches that fit into one pass over the dataset;
        # guaranteed >= 1 by the checks above.
        self._iter_per_epoch = self._num_per_epoch // batch
        self._indexes = list(range(self._num_per_epoch))

    def __iter__(self):
        """Yield ``steps`` batches of indexes using random sampling.

        Each yielded value is a new list (a slice), so the in-place shuffle
        performed at the next epoch boundary does not alter batches that
        were already handed out.
        """
        for step in range(self._steps):
            offset = step % self._iter_per_epoch
            if offset == 0:
                # Start of a new epoch: draw a fresh permutation.
                random.shuffle(self._indexes)
            start = offset * self._batch
            yield self._indexes[start : start + self._batch]

    def __len__(self):
        """Return the total number of batches produced per iteration."""
        return self._steps