Update xmisc.scheduler/sampler
tests/test_misc_scheduler.py (new file, 73 lines)
@@ -0,0 +1,73 @@
####################################################
# Copyright (c) Facebook, Inc. and its affiliates. #
####################################################
# Inspired by https://github.com/facebookresearch/detectron2/blob/master/tests/test_scheduler.py
####################################################
import math
import numpy as np
from unittest import TestCase

import torch

from xautodl.xmisc.scheduler_utils import CosineParamScheduler, MultiStepParamScheduler
from xautodl.xmisc.scheduler_utils import LRMultiplier, WarmupParamScheduler


class TestScheduler(TestCase):
    """Test the scheduler."""

    def test_warmup_multistep(self):
        p = torch.nn.Parameter(torch.zeros(0))
        opt = torch.optim.SGD([p], lr=5)

        multiplier = WarmupParamScheduler(
            MultiStepParamScheduler(
                [1, 0.1, 0.01, 0.001],
                milestones=[10, 15, 20],
                num_updates=30,
            ),
            0.001,  # warmup_factor
            5 / 30,  # warmup_length: 5 of the 30 updates
        )
        sched = LRMultiplier(opt, multiplier, 30)
        # This is equivalent to:
        # sched = WarmupMultiStepLR(
        #     opt, milestones=[10, 15, 20], gamma=0.1, warmup_factor=0.001, warmup_iters=5)

        p.sum().backward()
        opt.step()

        lrs = [0.005]
        for _ in range(30):
            sched.step()
            lrs.append(opt.param_groups[0]["lr"])
        self.assertTrue(np.allclose(lrs[:5], [0.005, 1.004, 2.003, 3.002, 4.001]))
        self.assertTrue(np.allclose(lrs[5:10], 5.0))
        self.assertTrue(np.allclose(lrs[10:15], 0.5))
        self.assertTrue(np.allclose(lrs[15:20], 0.05))
        self.assertTrue(np.allclose(lrs[20:], 0.005))

    def test_warmup_cosine(self):
        p = torch.nn.Parameter(torch.zeros(0))
        opt = torch.optim.SGD([p], lr=5)
        multiplier = WarmupParamScheduler(
            CosineParamScheduler(1, 0),
            0.001,  # warmup_factor
            5 / 30,  # warmup_length
        )
        sched = LRMultiplier(opt, multiplier, 30)

        p.sum().backward()
        opt.step()
        self.assertEqual(opt.param_groups[0]["lr"], 0.005)
        lrs = [0.005]

        for _ in range(30):
            sched.step()
            lrs.append(opt.param_groups[0]["lr"])
        for idx, lr in enumerate(lrs):
            expected_cosine = 2.5 * (1.0 + math.cos(math.pi * idx / 30))
            if idx >= 5:
                self.assertAlmostEqual(lr, expected_cosine)
            else:
                self.assertNotAlmostEqual(lr, expected_cosine)
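For context, the constants asserted above follow directly from base_lr = 5, 30 total updates, and 5 warmup iterations. A minimal standalone sketch (plain Python and arithmetic only; nothing here is part of the xautodl API) that reproduces the expected values:

import math

base_lr, total = 5.0, 30

# Multistep: the multiplier steps through 1 -> 0.1 -> 0.01 -> 0.001 at
# updates 10/15/20, so the post-warmup lr is 5.0, then 0.5, 0.05, 0.005.
for mult, start in [(1.0, 5), (0.1, 10), (0.01, 15), (0.001, 20)]:
    print(f"from update {start}: lr = {base_lr * mult}")

# Cosine: lr(t) = base_lr * 0.5 * (1 + cos(pi * t / total)), which is the
# expected_cosine term checked in test_warmup_cosine once warmup ends
# (t >= 5). E.g. t=15 gives 2.5 and t=30 gives 0.0.
for t in (5, 15, 30):
    print(t, base_lr * 0.5 * (1.0 + math.cos(math.pi * t / total)))

The tests themselves run under the standard runners, e.g. python -m pytest tests/test_misc_scheduler.py or python -m unittest tests.test_misc_scheduler.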