Update SuperMLP
| @@ -7,22 +7,63 @@ from log_utils import time_string | ||||
| from utils import obtain_accuracy | ||||
|  | ||||
|  | ||||
| def basic_train(xloader, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger): | ||||
| def basic_train( | ||||
|     xloader, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     optimizer, | ||||
|     optim_config, | ||||
|     extra_info, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
|     loss, acc1, acc5 = procedure( | ||||
|         xloader, network, criterion, scheduler, optimizer, "train", optim_config, extra_info, print_freq, logger | ||||
|         xloader, | ||||
|         network, | ||||
|         criterion, | ||||
|         scheduler, | ||||
|         optimizer, | ||||
|         "train", | ||||
|         optim_config, | ||||
|         extra_info, | ||||
|         print_freq, | ||||
|         logger, | ||||
|     ) | ||||
|     return loss, acc1, acc5 | ||||
|  | ||||
|  | ||||
| def basic_valid(xloader, network, criterion, optim_config, extra_info, print_freq, logger): | ||||
| def basic_valid( | ||||
|     xloader, network, criterion, optim_config, extra_info, print_freq, logger | ||||
| ): | ||||
|     with torch.no_grad(): | ||||
|         loss, acc1, acc5 = procedure( | ||||
|             xloader, network, criterion, None, None, "valid", None, extra_info, print_freq, logger | ||||
|             xloader, | ||||
|             network, | ||||
|             criterion, | ||||
|             None, | ||||
|             None, | ||||
|             "valid", | ||||
|             None, | ||||
|             extra_info, | ||||
|             print_freq, | ||||
|             logger, | ||||
|         ) | ||||
|     return loss, acc1, acc5 | ||||
|  | ||||
|  | ||||
| def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger): | ||||
| def procedure( | ||||
|     xloader, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     optimizer, | ||||
|     mode, | ||||
|     config, | ||||
|     extra_info, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
|     data_time, batch_time, losses, top1, top5 = ( | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
| @@ -39,7 +80,9 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, e | ||||
|  | ||||
|     # logger.log('[{:5s}] config ::  auxiliary={:}, message={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, network.module.get_message())) | ||||
|     logger.log( | ||||
|         "[{:5s}] config ::  auxiliary={:}".format(mode, config.auxiliary if hasattr(config, "auxiliary") else -1) | ||||
|         "[{:5s}] config ::  auxiliary={:}".format( | ||||
|             mode, config.auxiliary if hasattr(config, "auxiliary") else -1 | ||||
|         ) | ||||
|     ) | ||||
|     end = time.time() | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
| @@ -55,7 +98,9 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, e | ||||
|  | ||||
|         features, logits = network(inputs) | ||||
|         if isinstance(logits, list): | ||||
|             assert len(logits) == 2, "logits must has {:} items instead of {:}".format(2, len(logits)) | ||||
|             assert len(logits) == 2, "logits must has {:} items instead of {:}".format( | ||||
|                 2, len(logits) | ||||
|             ) | ||||
|             logits, logits_aux = logits | ||||
|         else: | ||||
|             logits, logits_aux = logits, None | ||||
| @@ -97,7 +142,12 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, e | ||||
|  | ||||
|     logger.log( | ||||
|         " **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format( | ||||
|             mode=mode.upper(), top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg | ||||
|             mode=mode.upper(), | ||||
|             top1=top1, | ||||
|             top5=top5, | ||||
|             error1=100 - top1.avg, | ||||
|             error5=100 - top5.avg, | ||||
|             loss=losses.avg, | ||||
|         ) | ||||
|     ) | ||||
|     return losses.avg, top1.avg, top5.avg | ||||
|   | ||||
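Note: the hunks above only re-wrap long call signatures and format strings (a line-length reformat), so training and validation behavior is unchanged. For orientation, here is a minimal self-contained sketch of the meter-and-accuracy pattern that procedure() relies on; AverageMeter and obtain_accuracy are re-implemented from their call sites as assumptions, not copied from this repository.

import torch

class AverageMeter:
    # Running average; .avg and .val match how the meters are read in procedure().
    def __init__(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / max(self.count, 1)

def obtain_accuracy(logits, targets, topk=(1, 5)):
    # Top-k accuracy in percent, matching the (prec1, prec5) unpacking in the hunks above.
    maxk = max(topk)
    _, pred = logits.topk(maxk, dim=1, largest=True, sorted=True)
    correct = pred.t().eq(targets.view(1, -1).expand_as(pred.t()))
    return [correct[:k].reshape(-1).float().sum() * (100.0 / targets.size(0)) for k in topk]

# usage sketch with made-up shapes
top1, top5 = AverageMeter(), AverageMeter()
logits, targets = torch.randn(32, 10), torch.randint(0, 10, (32,))
prec1, prec5 = obtain_accuracy(logits, targets, topk=(1, 5))
top1.update(prec1.item(), targets.size(0))
top5.update(prec5.item(), targets.size(0))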
| @@ -78,23 +78,42 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode: str): | ||||
|     return losses.avg, top1.avg, top5.avg, batch_time.sum | ||||
|  | ||||
|  | ||||
| def evaluate_for_seed(arch_config, opt_config, train_loader, valid_loaders, seed: int, logger): | ||||
| def evaluate_for_seed( | ||||
|     arch_config, opt_config, train_loader, valid_loaders, seed: int, logger | ||||
| ): | ||||
|  | ||||
|     prepare_seed(seed)  # random seed | ||||
|     net = get_cell_based_tiny_net(arch_config) | ||||
|     # net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num) | ||||
|     flop, param = get_model_infos(net, opt_config.xshape) | ||||
|     logger.log("Network : {:}".format(net.get_message()), False) | ||||
|     logger.log("{:} Seed-------------------------- {:} --------------------------".format(time_string(), seed)) | ||||
|     logger.log( | ||||
|         "{:} Seed-------------------------- {:} --------------------------".format( | ||||
|             time_string(), seed | ||||
|         ) | ||||
|     ) | ||||
|     logger.log("FLOP = {:} MB, Param = {:} MB".format(flop, param)) | ||||
|     # train and valid | ||||
|     optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), opt_config) | ||||
|     default_device = torch.cuda.current_device() | ||||
|     network = torch.nn.DataParallel(net, device_ids=[default_device]).cuda(device=default_device) | ||||
|     network = torch.nn.DataParallel(net, device_ids=[default_device]).cuda( | ||||
|         device=default_device | ||||
|     ) | ||||
|     criterion = criterion.cuda(device=default_device) | ||||
|     # start training | ||||
|     start_time, epoch_time, total_epoch = time.time(), AverageMeter(), opt_config.epochs + opt_config.warmup | ||||
|     train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {} | ||||
|     start_time, epoch_time, total_epoch = ( | ||||
|         time.time(), | ||||
|         AverageMeter(), | ||||
|         opt_config.epochs + opt_config.warmup, | ||||
|     ) | ||||
|     ( | ||||
|         train_losses, | ||||
|         train_acc1es, | ||||
|         train_acc5es, | ||||
|         valid_losses, | ||||
|         valid_acc1es, | ||||
|         valid_acc5es, | ||||
|     ) = ({}, {}, {}, {}, {}, {}) | ||||
|     train_times, valid_times, lrs = {}, {}, {} | ||||
|     for epoch in range(total_epoch): | ||||
|         scheduler.update(epoch, 0.0) | ||||
| @@ -120,7 +139,9 @@ def evaluate_for_seed(arch_config, opt_config, train_loader, valid_loaders, seed | ||||
|         # measure elapsed time | ||||
|         epoch_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|         need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True)) | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%], lr={:}".format( | ||||
|                 time_string(), | ||||
| @@ -171,14 +192,29 @@ def get_nas_bench_loaders(workers): | ||||
|     break_line = "-" * 150 | ||||
|     print("{:} Create data-loader for all datasets".format(time_string())) | ||||
|     print(break_line) | ||||
|     TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets("cifar10", str(torch_dir / "cifar.python"), -1) | ||||
|     TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets( | ||||
|         "cifar10", str(torch_dir / "cifar.python"), -1 | ||||
|     ) | ||||
|     print( | ||||
|         "original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes".format( | ||||
|             len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num | ||||
|         ) | ||||
|     ) | ||||
|     cifar10_splits = load_config(root_dir / "configs" / "nas-benchmark" / "cifar-split.txt", None, None) | ||||
|     assert cifar10_splits.train[:10] == [0, 5, 7, 11, 13, 15, 16, 17, 20, 24] and cifar10_splits.valid[:10] == [ | ||||
|     cifar10_splits = load_config( | ||||
|         root_dir / "configs" / "nas-benchmark" / "cifar-split.txt", None, None | ||||
|     ) | ||||
|     assert cifar10_splits.train[:10] == [ | ||||
|         0, | ||||
|         5, | ||||
|         7, | ||||
|         11, | ||||
|         13, | ||||
|         15, | ||||
|         16, | ||||
|         17, | ||||
|         20, | ||||
|         24, | ||||
|     ] and cifar10_splits.valid[:10] == [ | ||||
|         1, | ||||
|         2, | ||||
|         3, | ||||
| @@ -194,7 +230,11 @@ def get_nas_bench_loaders(workers): | ||||
|     temp_dataset.transform = VALID_CIFAR10.transform | ||||
|     # data loader | ||||
|     trainval_cifar10_loader = torch.utils.data.DataLoader( | ||||
|         TRAIN_CIFAR10, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True | ||||
|         TRAIN_CIFAR10, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         shuffle=True, | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     train_cifar10_loader = torch.utils.data.DataLoader( | ||||
|         TRAIN_CIFAR10, | ||||
| @@ -211,7 +251,11 @@ def get_nas_bench_loaders(workers): | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     test__cifar10_loader = torch.utils.data.DataLoader( | ||||
|         VALID_CIFAR10, batch_size=cifar_config.batch_size, shuffle=False, num_workers=workers, pin_memory=True | ||||
|         VALID_CIFAR10, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         shuffle=False, | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-10  : trval-loader has {:3d} batch with {:} per batch".format( | ||||
| @@ -235,14 +279,29 @@ def get_nas_bench_loaders(workers): | ||||
|     ) | ||||
|     print(break_line) | ||||
|     # CIFAR-100 | ||||
|     TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets("cifar100", str(torch_dir / "cifar.python"), -1) | ||||
|     TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets( | ||||
|         "cifar100", str(torch_dir / "cifar.python"), -1 | ||||
|     ) | ||||
|     print( | ||||
|         "original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes".format( | ||||
|             len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num | ||||
|         ) | ||||
|     ) | ||||
|     cifar100_splits = load_config(root_dir / "configs" / "nas-benchmark" / "cifar100-test-split.txt", None, None) | ||||
|     assert cifar100_splits.xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16] and cifar100_splits.xtest[:10] == [ | ||||
|     cifar100_splits = load_config( | ||||
|         root_dir / "configs" / "nas-benchmark" / "cifar100-test-split.txt", None, None | ||||
|     ) | ||||
|     assert cifar100_splits.xvalid[:10] == [ | ||||
|         1, | ||||
|         3, | ||||
|         4, | ||||
|         5, | ||||
|         8, | ||||
|         10, | ||||
|         13, | ||||
|         14, | ||||
|         15, | ||||
|         16, | ||||
|     ] and cifar100_splits.xtest[:10] == [ | ||||
|         0, | ||||
|         2, | ||||
|         6, | ||||
| @@ -255,7 +314,11 @@ def get_nas_bench_loaders(workers): | ||||
|         24, | ||||
|     ] | ||||
|     train_cifar100_loader = torch.utils.data.DataLoader( | ||||
|         TRAIN_CIFAR100, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True | ||||
|         TRAIN_CIFAR100, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         shuffle=True, | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     valid_cifar100_loader = torch.utils.data.DataLoader( | ||||
|         VALID_CIFAR100, | ||||
| @@ -271,9 +334,15 @@ def get_nas_bench_loaders(workers): | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     print("CIFAR-100  : train-loader has {:3d} batch".format(len(train_cifar100_loader))) | ||||
|     print("CIFAR-100  : valid-loader has {:3d} batch".format(len(valid_cifar100_loader))) | ||||
|     print("CIFAR-100  : test--loader has {:3d} batch".format(len(test__cifar100_loader))) | ||||
|     print( | ||||
|         "CIFAR-100  : train-loader has {:3d} batch".format(len(train_cifar100_loader)) | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-100  : valid-loader has {:3d} batch".format(len(valid_cifar100_loader)) | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-100  : test--loader has {:3d} batch".format(len(test__cifar100_loader)) | ||||
|     ) | ||||
|     print(break_line) | ||||
|  | ||||
|     imagenet16_config_path = "configs/nas-benchmark/ImageNet-16.config" | ||||
| @@ -286,8 +355,23 @@ def get_nas_bench_loaders(workers): | ||||
|             len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num | ||||
|         ) | ||||
|     ) | ||||
|     imagenet_splits = load_config(root_dir / "configs" / "nas-benchmark" / "imagenet-16-120-test-split.txt", None, None) | ||||
|     assert imagenet_splits.xvalid[:10] == [1, 2, 3, 6, 7, 8, 9, 12, 16, 18] and imagenet_splits.xtest[:10] == [ | ||||
|     imagenet_splits = load_config( | ||||
|         root_dir / "configs" / "nas-benchmark" / "imagenet-16-120-test-split.txt", | ||||
|         None, | ||||
|         None, | ||||
|     ) | ||||
|     assert imagenet_splits.xvalid[:10] == [ | ||||
|         1, | ||||
|         2, | ||||
|         3, | ||||
|         6, | ||||
|         7, | ||||
|         8, | ||||
|         9, | ||||
|         12, | ||||
|         16, | ||||
|         18, | ||||
|     ] and imagenet_splits.xtest[:10] == [ | ||||
|         0, | ||||
|         4, | ||||
|         5, | ||||
|   | ||||
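Note: the asserted index lists above come from the split files under configs/nas-benchmark; they are typically handed to subset samplers so that train and valid loaders draw disjoint indices from the same dataset (the construction is only partially visible in these hunks). A generic sketch of that pattern, with the dataset, index lists, and batch size as illustrative placeholders:

import torch

def make_split_loaders(dataset, train_idx, valid_idx, batch_size, workers):
    # Two loaders over the same dataset, restricted to disjoint index lists.
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(train_idx),
        num_workers=workers,
        pin_memory=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_idx),
        num_workers=workers,
        pin_memory=True,
    )
    return train_loader, valid_loader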
| @@ -14,7 +14,9 @@ class _LRScheduler(object): | ||||
|         self.optimizer = optimizer | ||||
|         for group in optimizer.param_groups: | ||||
|             group.setdefault("initial_lr", group["lr"]) | ||||
|         self.base_lrs = list(map(lambda group: group["initial_lr"], optimizer.param_groups)) | ||||
|         self.base_lrs = list( | ||||
|             map(lambda group: group["initial_lr"], optimizer.param_groups) | ||||
|         ) | ||||
|         self.max_epochs = epochs | ||||
|         self.warmup_epochs = warmup_epochs | ||||
|         self.current_epoch = 0 | ||||
| @@ -31,7 +33,9 @@ class _LRScheduler(object): | ||||
|         ) | ||||
|  | ||||
|     def state_dict(self): | ||||
|         return {key: value for key, value in self.__dict__.items() if key != "optimizer"} | ||||
|         return { | ||||
|             key: value for key, value in self.__dict__.items() if key != "optimizer" | ||||
|         } | ||||
|  | ||||
|     def load_state_dict(self, state_dict): | ||||
|         self.__dict__.update(state_dict) | ||||
| @@ -50,10 +54,14 @@ class _LRScheduler(object): | ||||
|  | ||||
|     def update(self, cur_epoch, cur_iter): | ||||
|         if cur_epoch is not None: | ||||
|             assert isinstance(cur_epoch, int) and cur_epoch >= 0, "invalid cur-epoch : {:}".format(cur_epoch) | ||||
|             assert ( | ||||
|                 isinstance(cur_epoch, int) and cur_epoch >= 0 | ||||
|             ), "invalid cur-epoch : {:}".format(cur_epoch) | ||||
|             self.current_epoch = cur_epoch | ||||
|         if cur_iter is not None: | ||||
|             assert isinstance(cur_iter, float) and cur_iter >= 0, "invalid cur-iter : {:}".format(cur_iter) | ||||
|             assert ( | ||||
|                 isinstance(cur_iter, float) and cur_iter >= 0 | ||||
|             ), "invalid cur-iter : {:}".format(cur_iter) | ||||
|             self.current_iter = cur_iter | ||||
|         for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): | ||||
|             param_group["lr"] = lr | ||||
| @@ -66,29 +74,44 @@ class CosineAnnealingLR(_LRScheduler): | ||||
|         super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "type={:}, T-max={:}, eta-min={:}".format("cosine", self.T_max, self.eta_min) | ||||
|         return "type={:}, T-max={:}, eta-min={:}".format( | ||||
|             "cosine", self.T_max, self.eta_min | ||||
|         ) | ||||
|  | ||||
|     def get_lr(self): | ||||
|         lrs = [] | ||||
|         for base_lr in self.base_lrs: | ||||
|             if self.current_epoch >= self.warmup_epochs and self.current_epoch < self.max_epochs: | ||||
|             if ( | ||||
|                 self.current_epoch >= self.warmup_epochs | ||||
|                 and self.current_epoch < self.max_epochs | ||||
|             ): | ||||
|                 last_epoch = self.current_epoch - self.warmup_epochs | ||||
|                 # if last_epoch < self.T_max: | ||||
|                 # if last_epoch < self.max_epochs: | ||||
|                 lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * last_epoch / self.T_max)) / 2 | ||||
|                 lr = ( | ||||
|                     self.eta_min | ||||
|                     + (base_lr - self.eta_min) | ||||
|                     * (1 + math.cos(math.pi * last_epoch / self.T_max)) | ||||
|                     / 2 | ||||
|                 ) | ||||
|                 # else: | ||||
|                 #  lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2 | ||||
|             elif self.current_epoch >= self.max_epochs: | ||||
|                 lr = self.eta_min | ||||
|             else: | ||||
|                 lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr | ||||
|                 lr = ( | ||||
|                     self.current_epoch / self.warmup_epochs | ||||
|                     + self.current_iter / self.warmup_epochs | ||||
|                 ) * base_lr | ||||
|             lrs.append(lr) | ||||
|         return lrs | ||||
|  | ||||
|  | ||||
| class MultiStepLR(_LRScheduler): | ||||
|     def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas): | ||||
|         assert len(milestones) == len(gammas), "invalid {:} vs {:}".format(len(milestones), len(gammas)) | ||||
|         assert len(milestones) == len(gammas), "invalid {:} vs {:}".format( | ||||
|             len(milestones), len(gammas) | ||||
|         ) | ||||
|         self.milestones = milestones | ||||
|         self.gammas = gammas | ||||
|         super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
| @@ -108,7 +131,10 @@ class MultiStepLR(_LRScheduler): | ||||
|                 for x in self.gammas[:idx]: | ||||
|                     lr *= x | ||||
|             else: | ||||
|                 lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr | ||||
|                 lr = ( | ||||
|                     self.current_epoch / self.warmup_epochs | ||||
|                     + self.current_iter / self.warmup_epochs | ||||
|                 ) * base_lr | ||||
|             lrs.append(lr) | ||||
|         return lrs | ||||
|  | ||||
| @@ -119,7 +145,9 @@ class ExponentialLR(_LRScheduler): | ||||
|         super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "type={:}, gamma={:}, base-lrs={:}".format("exponential", self.gamma, self.base_lrs) | ||||
|         return "type={:}, gamma={:}, base-lrs={:}".format( | ||||
|             "exponential", self.gamma, self.base_lrs | ||||
|         ) | ||||
|  | ||||
|     def get_lr(self): | ||||
|         lrs = [] | ||||
| @@ -129,7 +157,10 @@ class ExponentialLR(_LRScheduler): | ||||
|                 assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch) | ||||
|                 lr = base_lr * (self.gamma ** last_epoch) | ||||
|             else: | ||||
|                 lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr | ||||
|                 lr = ( | ||||
|                     self.current_epoch / self.warmup_epochs | ||||
|                     + self.current_iter / self.warmup_epochs | ||||
|                 ) * base_lr | ||||
|             lrs.append(lr) | ||||
|         return lrs | ||||
|  | ||||
| @@ -151,10 +182,18 @@ class LinearLR(_LRScheduler): | ||||
|             if self.current_epoch >= self.warmup_epochs: | ||||
|                 last_epoch = self.current_epoch - self.warmup_epochs | ||||
|                 assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch) | ||||
|                 ratio = (self.max_LR - self.min_LR) * last_epoch / self.max_epochs / self.max_LR | ||||
|                 ratio = ( | ||||
|                     (self.max_LR - self.min_LR) | ||||
|                     * last_epoch | ||||
|                     / self.max_epochs | ||||
|                     / self.max_LR | ||||
|                 ) | ||||
|                 lr = base_lr * (1 - ratio) | ||||
|             else: | ||||
|                 lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr | ||||
|                 lr = ( | ||||
|                     self.current_epoch / self.warmup_epochs | ||||
|                     + self.current_iter / self.warmup_epochs | ||||
|                 ) * base_lr | ||||
|             lrs.append(lr) | ||||
|         return lrs | ||||
|  | ||||
| @@ -176,26 +215,42 @@ class CrossEntropyLabelSmooth(nn.Module): | ||||
|  | ||||
| def get_optim_scheduler(parameters, config): | ||||
|     assert ( | ||||
|         hasattr(config, "optim") and hasattr(config, "scheduler") and hasattr(config, "criterion") | ||||
|     ), "config must have optim / scheduler / criterion keys instead of {:}".format(config) | ||||
|         hasattr(config, "optim") | ||||
|         and hasattr(config, "scheduler") | ||||
|         and hasattr(config, "criterion") | ||||
|     ), "config must have optim / scheduler / criterion keys instead of {:}".format( | ||||
|         config | ||||
|     ) | ||||
|     if config.optim == "SGD": | ||||
|         optim = torch.optim.SGD( | ||||
|             parameters, config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=config.nesterov | ||||
|             parameters, | ||||
|             config.LR, | ||||
|             momentum=config.momentum, | ||||
|             weight_decay=config.decay, | ||||
|             nesterov=config.nesterov, | ||||
|         ) | ||||
|     elif config.optim == "RMSprop": | ||||
|         optim = torch.optim.RMSprop(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay) | ||||
|         optim = torch.optim.RMSprop( | ||||
|             parameters, config.LR, momentum=config.momentum, weight_decay=config.decay | ||||
|         ) | ||||
|     else: | ||||
|         raise ValueError("invalid optim : {:}".format(config.optim)) | ||||
|  | ||||
|     if config.scheduler == "cos": | ||||
|         T_max = getattr(config, "T_max", config.epochs) | ||||
|         scheduler = CosineAnnealingLR(optim, config.warmup, config.epochs, T_max, config.eta_min) | ||||
|         scheduler = CosineAnnealingLR( | ||||
|             optim, config.warmup, config.epochs, T_max, config.eta_min | ||||
|         ) | ||||
|     elif config.scheduler == "multistep": | ||||
|         scheduler = MultiStepLR(optim, config.warmup, config.epochs, config.milestones, config.gammas) | ||||
|         scheduler = MultiStepLR( | ||||
|             optim, config.warmup, config.epochs, config.milestones, config.gammas | ||||
|         ) | ||||
|     elif config.scheduler == "exponential": | ||||
|         scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma) | ||||
|     elif config.scheduler == "linear": | ||||
|         scheduler = LinearLR(optim, config.warmup, config.epochs, config.LR, config.LR_min) | ||||
|         scheduler = LinearLR( | ||||
|             optim, config.warmup, config.epochs, config.LR, config.LR_min | ||||
|         ) | ||||
|     else: | ||||
|         raise ValueError("invalid scheduler : {:}".format(config.scheduler)) | ||||
|  | ||||
|   | ||||
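Note: the cosine expression in CosineAnnealingLR.get_lr is the part most easily mis-read after being split across lines, so here is a standalone restatement of the same schedule (linear warmup ramp, cosine decay, clamp to eta_min after max_epochs), written as a plain function rather than against the _LRScheduler class; the argument values in the example call are illustrative assumptions.

import math

def cosine_with_warmup(base_lr, eta_min, epoch, iter_frac, warmup_epochs, max_epochs, T_max):
    if epoch < warmup_epochs:
        # warmup: ramp toward base_lr using the epoch plus the fractional iteration
        return (epoch / warmup_epochs + iter_frac / warmup_epochs) * base_lr
    if epoch >= max_epochs:
        return eta_min
    last_epoch = epoch - warmup_epochs
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * last_epoch / T_max)) / 2

# e.g. base_lr=0.1, eta_min=0.0, 5 warmup epochs, 200 total epochs, T_max=200
print([round(cosine_with_warmup(0.1, 0.0, e, 0.0, 5, 200, 200), 5) for e in (0, 4, 5, 100, 204)])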
| @@ -41,7 +41,10 @@ def update_gpu(config, gpu): | ||||
|     if "task" in config and "model" in config["task"]: | ||||
|         if "GPU" in config["task"]["model"]: | ||||
|             config["task"]["model"]["GPU"] = gpu | ||||
|         elif "kwargs" in config["task"]["model"] and "GPU" in config["task"]["model"]["kwargs"]: | ||||
|         elif ( | ||||
|             "kwargs" in config["task"]["model"] | ||||
|             and "GPU" in config["task"]["model"]["kwargs"] | ||||
|         ): | ||||
|             config["task"]["model"]["kwargs"]["GPU"] = gpu | ||||
|     elif "model" in config: | ||||
|         if "GPU" in config["model"]: | ||||
| @@ -68,7 +71,12 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri): | ||||
|     model_fit_kwargs = dict(dataset=dataset) | ||||
|  | ||||
|     # Let's start the experiment. | ||||
|     with R.start(experiment_name=experiment_name, recorder_name=recorder_name, uri=uri, resume=True): | ||||
|     with R.start( | ||||
|         experiment_name=experiment_name, | ||||
|         recorder_name=recorder_name, | ||||
|         uri=uri, | ||||
|         resume=True, | ||||
|     ): | ||||
|         # Setup log | ||||
|         recorder_root_dir = R.get_recorder().get_local_dir() | ||||
|         log_file = os.path.join(recorder_root_dir, "{:}.log".format(experiment_name)) | ||||
|   | ||||
| @@ -36,7 +36,12 @@ def search_train( | ||||
|     logger, | ||||
| ): | ||||
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     base_losses, arch_losses, top1, top5 = ( | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|     ) | ||||
|     arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter() | ||||
|     epoch_str, flop_need, flop_weight, flop_tolerant = ( | ||||
|         extra_info["epoch-str"], | ||||
| @@ -46,10 +51,16 @@ def search_train( | ||||
|     ) | ||||
|  | ||||
|     network.train() | ||||
|     logger.log("[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(epoch_str, flop_need, flop_weight)) | ||||
|     logger.log( | ||||
|         "[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format( | ||||
|             epoch_str, flop_need, flop_weight | ||||
|         ) | ||||
|     ) | ||||
|     end = time.time() | ||||
|     network.apply(change_key("search_mode", "search")) | ||||
|     for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader): | ||||
|     for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate( | ||||
|         search_loader | ||||
|     ): | ||||
|         scheduler.update(None, 1.0 * step / len(search_loader)) | ||||
|         # calculate prediction and loss | ||||
|         base_targets = base_targets.cuda(non_blocking=True) | ||||
| @@ -75,7 +86,9 @@ def search_train( | ||||
|         arch_optimizer.zero_grad() | ||||
|         logits, expected_flop = network(arch_inputs) | ||||
|         flop_cur = network.module.get_flop("genotype", None, None) | ||||
|         flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant) | ||||
|         flop_loss, flop_loss_scale = get_flop_loss( | ||||
|             expected_flop, flop_cur, flop_need, flop_tolerant | ||||
|         ) | ||||
|         acls_loss = criterion(logits, arch_targets) | ||||
|         arch_loss = acls_loss + flop_loss * flop_weight | ||||
|         arch_loss.backward() | ||||
| @@ -90,7 +103,11 @@ def search_train( | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|         if step % print_freq == 0 or (step + 1) == len(search_loader): | ||||
|             Sstr = "**TRAIN** " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader)) | ||||
|             Sstr = ( | ||||
|                 "**TRAIN** " | ||||
|                 + time_string() | ||||
|                 + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader)) | ||||
|             ) | ||||
|             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                 batch_time=batch_time, data_time=data_time | ||||
|             ) | ||||
| @@ -153,7 +170,11 @@ def search_valid(xloader, network, criterion, extra_info, print_freq, logger): | ||||
|             end = time.time() | ||||
|  | ||||
|             if i % print_freq == 0 or (i + 1) == len(xloader): | ||||
|                 Sstr = "**VALID** " + time_string() + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader)) | ||||
|                 Sstr = ( | ||||
|                     "**VALID** " | ||||
|                     + time_string() | ||||
|                     + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader)) | ||||
|                 ) | ||||
|                 Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                     batch_time=batch_time, data_time=data_time | ||||
|                 ) | ||||
| @@ -165,7 +186,11 @@ def search_valid(xloader, network, criterion, extra_info, print_freq, logger): | ||||
|  | ||||
|     logger.log( | ||||
|         " **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format( | ||||
|             top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg | ||||
|             top1=top1, | ||||
|             top5=top5, | ||||
|             error1=100 - top1.avg, | ||||
|             error5=100 - top5.avg, | ||||
|             loss=losses.avg, | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|   | ||||
| @@ -36,7 +36,12 @@ def search_train_v2( | ||||
|     logger, | ||||
| ): | ||||
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     base_losses, arch_losses, top1, top5 = ( | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|     ) | ||||
|     arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter() | ||||
|     epoch_str, flop_need, flop_weight, flop_tolerant = ( | ||||
|         extra_info["epoch-str"], | ||||
| @@ -46,10 +51,16 @@ def search_train_v2( | ||||
|     ) | ||||
|  | ||||
|     network.train() | ||||
|     logger.log("[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(epoch_str, flop_need, flop_weight)) | ||||
|     logger.log( | ||||
|         "[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format( | ||||
|             epoch_str, flop_need, flop_weight | ||||
|         ) | ||||
|     ) | ||||
|     end = time.time() | ||||
|     network.apply(change_key("search_mode", "search")) | ||||
|     for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader): | ||||
|     for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate( | ||||
|         search_loader | ||||
|     ): | ||||
|         scheduler.update(None, 1.0 * step / len(search_loader)) | ||||
|         # calculate prediction and loss | ||||
|         base_targets = base_targets.cuda(non_blocking=True) | ||||
| @@ -73,7 +84,9 @@ def search_train_v2( | ||||
|         arch_optimizer.zero_grad() | ||||
|         logits, expected_flop = network(arch_inputs) | ||||
|         flop_cur = network.module.get_flop("genotype", None, None) | ||||
|         flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant) | ||||
|         flop_loss, flop_loss_scale = get_flop_loss( | ||||
|             expected_flop, flop_cur, flop_need, flop_tolerant | ||||
|         ) | ||||
|         acls_loss = criterion(logits, arch_targets) | ||||
|         arch_loss = acls_loss + flop_loss * flop_weight | ||||
|         arch_loss.backward() | ||||
| @@ -88,7 +101,11 @@ def search_train_v2( | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|         if step % print_freq == 0 or (step + 1) == len(search_loader): | ||||
|             Sstr = "**TRAIN** " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader)) | ||||
|             Sstr = ( | ||||
|                 "**TRAIN** " | ||||
|                 + time_string() | ||||
|                 + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader)) | ||||
|             ) | ||||
|             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                 batch_time=batch_time, data_time=data_time | ||||
|             ) | ||||
|   | ||||
| @@ -10,7 +10,16 @@ from utils import obtain_accuracy | ||||
|  | ||||
|  | ||||
| def simple_KD_train( | ||||
|     xloader, teacher, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger | ||||
|     xloader, | ||||
|     teacher, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     optimizer, | ||||
|     optim_config, | ||||
|     extra_info, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
|     loss, acc1, acc5 = procedure( | ||||
|         xloader, | ||||
| @@ -28,25 +37,58 @@ def simple_KD_train( | ||||
|     return loss, acc1, acc5 | ||||
|  | ||||
|  | ||||
| def simple_KD_valid(xloader, teacher, network, criterion, optim_config, extra_info, print_freq, logger): | ||||
| def simple_KD_valid( | ||||
|     xloader, teacher, network, criterion, optim_config, extra_info, print_freq, logger | ||||
| ): | ||||
|     with torch.no_grad(): | ||||
|         loss, acc1, acc5 = procedure( | ||||
|             xloader, teacher, network, criterion, None, None, "valid", optim_config, extra_info, print_freq, logger | ||||
|             xloader, | ||||
|             teacher, | ||||
|             network, | ||||
|             criterion, | ||||
|             None, | ||||
|             None, | ||||
|             "valid", | ||||
|             optim_config, | ||||
|             extra_info, | ||||
|             print_freq, | ||||
|             logger, | ||||
|         ) | ||||
|     return loss, acc1, acc5 | ||||
|  | ||||
|  | ||||
| def loss_KD_fn( | ||||
|     criterion, student_logits, teacher_logits, studentFeatures, teacherFeatures, targets, alpha, temperature | ||||
|     criterion, | ||||
|     student_logits, | ||||
|     teacher_logits, | ||||
|     studentFeatures, | ||||
|     teacherFeatures, | ||||
|     targets, | ||||
|     alpha, | ||||
|     temperature, | ||||
| ): | ||||
|     basic_loss = criterion(student_logits, targets) * (1.0 - alpha) | ||||
|     log_student = F.log_softmax(student_logits / temperature, dim=1) | ||||
|     sof_teacher = F.softmax(teacher_logits / temperature, dim=1) | ||||
|     KD_loss = F.kl_div(log_student, sof_teacher, reduction="batchmean") * (alpha * temperature * temperature) | ||||
|     KD_loss = F.kl_div(log_student, sof_teacher, reduction="batchmean") * ( | ||||
|         alpha * temperature * temperature | ||||
|     ) | ||||
|     return basic_loss + KD_loss | ||||
|  | ||||
|  | ||||
| def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger): | ||||
| def procedure( | ||||
|     xloader, | ||||
|     teacher, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     optimizer, | ||||
|     mode, | ||||
|     config, | ||||
|     extra_info, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
|     data_time, batch_time, losses, top1, top5 = ( | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
| @@ -65,7 +107,10 @@ def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode, | ||||
|  | ||||
|     logger.log( | ||||
|         "[{:5s}] config :: auxiliary={:}, KD :: [alpha={:.2f}, temperature={:.2f}]".format( | ||||
|             mode, config.auxiliary if hasattr(config, "auxiliary") else -1, config.KD_alpha, config.KD_temperature | ||||
|             mode, | ||||
|             config.auxiliary if hasattr(config, "auxiliary") else -1, | ||||
|             config.KD_alpha, | ||||
|             config.KD_temperature, | ||||
|         ) | ||||
|     ) | ||||
|     end = time.time() | ||||
| @@ -82,7 +127,9 @@ def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode, | ||||
|  | ||||
|         student_f, logits = network(inputs) | ||||
|         if isinstance(logits, list): | ||||
|             assert len(logits) == 2, "logits must has {:} items instead of {:}".format(2, len(logits)) | ||||
|             assert len(logits) == 2, "logits must has {:} items instead of {:}".format( | ||||
|                 2, len(logits) | ||||
|             ) | ||||
|             logits, logits_aux = logits | ||||
|         else: | ||||
|             logits, logits_aux = logits, None | ||||
| @@ -90,7 +137,14 @@ def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode, | ||||
|             teacher_f, teacher_logits = teacher(inputs) | ||||
|  | ||||
|         loss = loss_KD_fn( | ||||
|             criterion, logits, teacher_logits, student_f, teacher_f, targets, config.KD_alpha, config.KD_temperature | ||||
|             criterion, | ||||
|             logits, | ||||
|             teacher_logits, | ||||
|             student_f, | ||||
|             teacher_f, | ||||
|             targets, | ||||
|             config.KD_alpha, | ||||
|             config.KD_temperature, | ||||
|         ) | ||||
|         if config is not None and hasattr(config, "auxiliary") and config.auxiliary > 0: | ||||
|             loss_aux = criterion(logits_aux, targets) | ||||
| @@ -139,7 +193,12 @@ def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode, | ||||
|     ) | ||||
|     logger.log( | ||||
|         " **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format( | ||||
|             mode=mode.upper(), top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg | ||||
|             mode=mode.upper(), | ||||
|             top1=top1, | ||||
|             top5=top5, | ||||
|             error1=100 - top1.avg, | ||||
|             error5=100 - top5.avg, | ||||
|             loss=losses.avg, | ||||
|         ) | ||||
|     ) | ||||
|     return losses.avg, top1.avg, top5.avg | ||||
|   | ||||
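Note: the distillation loss that loss_KD_fn re-wraps above is the standard Hinton formulation: hard-label loss weighted by (1 - alpha) plus a temperature-scaled KL term weighted by alpha * T^2. A minimal standalone sketch follows; it substitutes F.cross_entropy for the criterion argument and uses made-up tensor shapes, both assumptions for illustration only.

import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, targets, alpha=0.9, temperature=4.0):
    hard = F.cross_entropy(student_logits, targets) * (1.0 - alpha)
    log_student = F.log_softmax(student_logits / temperature, dim=1)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=1)
    # batchmean KL, scaled by alpha * T^2 as in loss_KD_fn
    soft = F.kl_div(log_student, soft_teacher, reduction="batchmean") * (alpha * temperature * temperature)
    return hard + soft

student, teacher = torch.randn(8, 10), torch.randn(8, 10)
targets = torch.randint(0, 10, (8,))
print(kd_loss(student, teacher, targets).item())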
| @@ -31,7 +31,9 @@ def prepare_logger(xargs): | ||||
|     logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count())) | ||||
|     logger.log( | ||||
|         "CUDA_VISIBLE_DEVICES : {:}".format( | ||||
|             os.environ["CUDA_VISIBLE_DEVICES"] if "CUDA_VISIBLE_DEVICES" in os.environ else "None" | ||||
|             os.environ["CUDA_VISIBLE_DEVICES"] | ||||
|             if "CUDA_VISIBLE_DEVICES" in os.environ | ||||
|             else "None" | ||||
|         ) | ||||
|     ) | ||||
|     return logger | ||||
| @@ -54,10 +56,14 @@ def get_machine_info(): | ||||
| def save_checkpoint(state, filename, logger): | ||||
|     if osp.isfile(filename): | ||||
|         if hasattr(logger, "log"): | ||||
|             logger.log("Find {:} exist, delete is at first before saving".format(filename)) | ||||
|             logger.log( | ||||
|                 "Find {:} exist, delete is at first before saving".format(filename) | ||||
|             ) | ||||
|         os.remove(filename) | ||||
|     torch.save(state, filename) | ||||
|     assert osp.isfile(filename), "save filename : {:} failed, which is not found.".format(filename) | ||||
|     assert osp.isfile( | ||||
|         filename | ||||
|     ), "save filename : {:} failed, which is not found.".format(filename) | ||||
|     if hasattr(logger, "log"): | ||||
|         logger.log("save checkpoint into {:}".format(filename)) | ||||
|     return filename | ||||
|   | ||||