#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
#####################################################
import random


class BatchSampler:
    """A batch sampler used for single-machine training."""

    def __init__(self, dataset, batch, steps):
        self._num_per_epoch = len(dataset)
        # Number of full batches per pass over the dataset; indexes left over
        # after the last full batch of an epoch are dropped.
        self._iter_per_epoch = self._num_per_epoch // batch
        self._steps = steps
        self._batch = batch
        if self._num_per_epoch < self._batch:
            raise ValueError(
                "The dataset size must be larger than batch={:}".format(batch)
            )
        self._indexes = list(range(self._num_per_epoch))

    def __iter__(self):
        """Yield a batch of indexes per step using random sampling."""
        for i in range(self._steps):
            # Reshuffle the index pool at every epoch boundary.
            if i % self._iter_per_epoch == 0:
                random.shuffle(self._indexes)
            j = i % self._iter_per_epoch
            yield self._indexes[j * self._batch : (j + 1) * self._batch]

    def __len__(self):
        return self._steps
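

# Minimal usage sketch, assuming a hypothetical toy list stands in for a real
# dataset (only len() is used by the sampler). It shows that iterating the
# sampler yields `steps` batches of `batch` shuffled indexes, with the index
# pool reshuffled whenever an epoch boundary is crossed.
if __name__ == "__main__":
    toy_dataset = list(range(10))  # hypothetical stand-in for a dataset
    sampler = BatchSampler(toy_dataset, batch=4, steps=5)
    assert len(sampler) == 5  # __len__ reports the total number of steps
    for batch_indexes in sampler:
        print(batch_indexes)  # a list of 4 dataset indexes, e.g. [7, 2, 9, 0]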