#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
#####################################################
import random


class BatchSampler:
    """A batch sampler used for single-machine training."""

    def __init__(self, dataset, batch, steps):
        self._num_per_epoch = len(dataset)
        self._iter_per_epoch = self._num_per_epoch // batch
        self._steps = steps
        self._batch = batch
        if self._num_per_epoch < self._batch:
            raise ValueError(
                "The dataset size must be no smaller than batch={:}".format(batch)
            )
        self._indexes = list(range(self._num_per_epoch))

    def __iter__(self):
        """Yield a batch of indexes using random sampling."""
        for i in range(self._steps):
            # Reshuffle the index pool at the start of each epoch.
            if i % self._iter_per_epoch == 0:
                random.shuffle(self._indexes)
            # Slice out the j-th batch within the current epoch.
            j = i % self._iter_per_epoch
            yield self._indexes[j * self._batch : (j + 1) * self._batch]

    def __len__(self):
        return self._steps
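

# Minimal usage sketch, assuming any object with a length works as `dataset`;
# the plain list and the batch/steps values below are illustrative only.
if __name__ == "__main__":
    dataset = list(range(10))  # hypothetical stand-in for a real dataset
    sampler = BatchSampler(dataset, batch=4, steps=5)
    for step, batch_indexes in enumerate(sampler):
        print("step {:}: indexes {:}".format(step, batch_indexes))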