diff --git a/.github/workflows/basic_test.yml b/.github/workflows/basic_test.yml index 0887b61..3053656 100644 --- a/.github/workflows/basic_test.yml +++ b/.github/workflows/basic_test.yml @@ -35,6 +35,7 @@ jobs: python -m black ./lib/xlayers -l 88 --check --diff --verbose python -m black ./lib/spaces -l 88 --check --diff --verbose python -m black ./lib/trade_models -l 88 --check --diff --verbose + python -m black ./lib/procedures -l 88 --check --diff --verbose - name: Test Search Space run: | diff --git a/lib/procedures/basic_main.py b/lib/procedures/basic_main.py index 60c3264..50f0a33 100644 --- a/lib/procedures/basic_main.py +++ b/lib/procedures/basic_main.py @@ -7,22 +7,63 @@ from log_utils import time_string from utils import obtain_accuracy -def basic_train(xloader, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger): +def basic_train( + xloader, + network, + criterion, + scheduler, + optimizer, + optim_config, + extra_info, + print_freq, + logger, +): loss, acc1, acc5 = procedure( - xloader, network, criterion, scheduler, optimizer, "train", optim_config, extra_info, print_freq, logger + xloader, + network, + criterion, + scheduler, + optimizer, + "train", + optim_config, + extra_info, + print_freq, + logger, ) return loss, acc1, acc5 -def basic_valid(xloader, network, criterion, optim_config, extra_info, print_freq, logger): +def basic_valid( + xloader, network, criterion, optim_config, extra_info, print_freq, logger +): with torch.no_grad(): loss, acc1, acc5 = procedure( - xloader, network, criterion, None, None, "valid", None, extra_info, print_freq, logger + xloader, + network, + criterion, + None, + None, + "valid", + None, + extra_info, + print_freq, + logger, ) return loss, acc1, acc5 -def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger): +def procedure( + xloader, + network, + criterion, + scheduler, + optimizer, + mode, + config, + extra_info, + print_freq, + logger, +): data_time, batch_time, losses, top1, top5 = ( AverageMeter(), AverageMeter(), @@ -39,7 +80,9 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, e # logger.log('[{:5s}] config :: auxiliary={:}, message={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, network.module.get_message())) logger.log( - "[{:5s}] config :: auxiliary={:}".format(mode, config.auxiliary if hasattr(config, "auxiliary") else -1) + "[{:5s}] config :: auxiliary={:}".format( + mode, config.auxiliary if hasattr(config, "auxiliary") else -1 + ) ) end = time.time() for i, (inputs, targets) in enumerate(xloader): @@ -55,7 +98,9 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, e features, logits = network(inputs) if isinstance(logits, list): - assert len(logits) == 2, "logits must has {:} items instead of {:}".format(2, len(logits)) + assert len(logits) == 2, "logits must has {:} items instead of {:}".format( + 2, len(logits) + ) logits, logits_aux = logits else: logits, logits_aux = logits, None @@ -97,7 +142,12 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, e logger.log( " **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format( - mode=mode.upper(), top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg + mode=mode.upper(), + top1=top1, + top5=top5, + error1=100 - top1.avg, + error5=100 - top5.avg, + loss=losses.avg, ) ) return 
losses.avg, top1.avg, top5.avg diff --git a/lib/procedures/funcs_nasbench.py b/lib/procedures/funcs_nasbench.py index 681d11d..0cd7103 100644 --- a/lib/procedures/funcs_nasbench.py +++ b/lib/procedures/funcs_nasbench.py @@ -78,23 +78,42 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode: str): return losses.avg, top1.avg, top5.avg, batch_time.sum -def evaluate_for_seed(arch_config, opt_config, train_loader, valid_loaders, seed: int, logger): +def evaluate_for_seed( + arch_config, opt_config, train_loader, valid_loaders, seed: int, logger +): prepare_seed(seed) # random seed net = get_cell_based_tiny_net(arch_config) # net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num) flop, param = get_model_infos(net, opt_config.xshape) logger.log("Network : {:}".format(net.get_message()), False) - logger.log("{:} Seed-------------------------- {:} --------------------------".format(time_string(), seed)) + logger.log( + "{:} Seed-------------------------- {:} --------------------------".format( + time_string(), seed + ) + ) logger.log("FLOP = {:} MB, Param = {:} MB".format(flop, param)) # train and valid optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), opt_config) default_device = torch.cuda.current_device() - network = torch.nn.DataParallel(net, device_ids=[default_device]).cuda(device=default_device) + network = torch.nn.DataParallel(net, device_ids=[default_device]).cuda( + device=default_device + ) criterion = criterion.cuda(device=default_device) # start training - start_time, epoch_time, total_epoch = time.time(), AverageMeter(), opt_config.epochs + opt_config.warmup - train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {} + start_time, epoch_time, total_epoch = ( + time.time(), + AverageMeter(), + opt_config.epochs + opt_config.warmup, + ) + ( + train_losses, + train_acc1es, + train_acc5es, + valid_losses, + valid_acc1es, + valid_acc5es, + ) = ({}, {}, {}, {}, {}, {}) train_times, valid_times, lrs = {}, {}, {} for epoch in range(total_epoch): scheduler.update(epoch, 0.0) @@ -120,7 +139,9 @@ def evaluate_for_seed(arch_config, opt_config, train_loader, valid_loaders, seed # measure elapsed time epoch_time.update(time.time() - start_time) start_time = time.time() - need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True)) + need_time = "Time Left: {:}".format( + convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True) + ) logger.log( "{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%], lr={:}".format( time_string(), @@ -171,14 +192,29 @@ def get_nas_bench_loaders(workers): break_line = "-" * 150 print("{:} Create data-loader for all datasets".format(time_string())) print(break_line) - TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets("cifar10", str(torch_dir / "cifar.python"), -1) + TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets( + "cifar10", str(torch_dir / "cifar.python"), -1 + ) print( "original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes".format( len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num ) ) - cifar10_splits = load_config(root_dir / "configs" / "nas-benchmark" / "cifar-split.txt", None, None) - assert cifar10_splits.train[:10] == [0, 5, 7, 11, 13, 15, 16, 17, 20, 24] and cifar10_splits.valid[:10] == [ + cifar10_splits = load_config( 
+ root_dir / "configs" / "nas-benchmark" / "cifar-split.txt", None, None + ) + assert cifar10_splits.train[:10] == [ + 0, + 5, + 7, + 11, + 13, + 15, + 16, + 17, + 20, + 24, + ] and cifar10_splits.valid[:10] == [ 1, 2, 3, @@ -194,7 +230,11 @@ def get_nas_bench_loaders(workers): temp_dataset.transform = VALID_CIFAR10.transform # data loader trainval_cifar10_loader = torch.utils.data.DataLoader( - TRAIN_CIFAR10, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True + TRAIN_CIFAR10, + batch_size=cifar_config.batch_size, + shuffle=True, + num_workers=workers, + pin_memory=True, ) train_cifar10_loader = torch.utils.data.DataLoader( TRAIN_CIFAR10, @@ -211,7 +251,11 @@ def get_nas_bench_loaders(workers): pin_memory=True, ) test__cifar10_loader = torch.utils.data.DataLoader( - VALID_CIFAR10, batch_size=cifar_config.batch_size, shuffle=False, num_workers=workers, pin_memory=True + VALID_CIFAR10, + batch_size=cifar_config.batch_size, + shuffle=False, + num_workers=workers, + pin_memory=True, ) print( "CIFAR-10 : trval-loader has {:3d} batch with {:} per batch".format( @@ -235,14 +279,29 @@ def get_nas_bench_loaders(workers): ) print(break_line) # CIFAR-100 - TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets("cifar100", str(torch_dir / "cifar.python"), -1) + TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets( + "cifar100", str(torch_dir / "cifar.python"), -1 + ) print( "original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes".format( len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num ) ) - cifar100_splits = load_config(root_dir / "configs" / "nas-benchmark" / "cifar100-test-split.txt", None, None) - assert cifar100_splits.xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16] and cifar100_splits.xtest[:10] == [ + cifar100_splits = load_config( + root_dir / "configs" / "nas-benchmark" / "cifar100-test-split.txt", None, None + ) + assert cifar100_splits.xvalid[:10] == [ + 1, + 3, + 4, + 5, + 8, + 10, + 13, + 14, + 15, + 16, + ] and cifar100_splits.xtest[:10] == [ 0, 2, 6, @@ -255,7 +314,11 @@ def get_nas_bench_loaders(workers): 24, ] train_cifar100_loader = torch.utils.data.DataLoader( - TRAIN_CIFAR100, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True + TRAIN_CIFAR100, + batch_size=cifar_config.batch_size, + shuffle=True, + num_workers=workers, + pin_memory=True, ) valid_cifar100_loader = torch.utils.data.DataLoader( VALID_CIFAR100, @@ -271,9 +334,15 @@ def get_nas_bench_loaders(workers): num_workers=workers, pin_memory=True, ) - print("CIFAR-100 : train-loader has {:3d} batch".format(len(train_cifar100_loader))) - print("CIFAR-100 : valid-loader has {:3d} batch".format(len(valid_cifar100_loader))) - print("CIFAR-100 : test--loader has {:3d} batch".format(len(test__cifar100_loader))) + print( + "CIFAR-100 : train-loader has {:3d} batch".format(len(train_cifar100_loader)) + ) + print( + "CIFAR-100 : valid-loader has {:3d} batch".format(len(valid_cifar100_loader)) + ) + print( + "CIFAR-100 : test--loader has {:3d} batch".format(len(test__cifar100_loader)) + ) print(break_line) imagenet16_config_path = "configs/nas-benchmark/ImageNet-16.config" @@ -286,8 +355,23 @@ def get_nas_bench_loaders(workers): len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num ) ) - imagenet_splits = load_config(root_dir / "configs" / "nas-benchmark" / "imagenet-16-120-test-split.txt", None, None) - assert imagenet_splits.xvalid[:10] == [1, 2, 3, 6, 7, 8, 9, 12, 
16, 18] and imagenet_splits.xtest[:10] == [ + imagenet_splits = load_config( + root_dir / "configs" / "nas-benchmark" / "imagenet-16-120-test-split.txt", + None, + None, + ) + assert imagenet_splits.xvalid[:10] == [ + 1, + 2, + 3, + 6, + 7, + 8, + 9, + 12, + 16, + 18, + ] and imagenet_splits.xtest[:10] == [ 0, 4, 5, diff --git a/lib/procedures/optimizers.py b/lib/procedures/optimizers.py index 9f3143c..3dd1f00 100644 --- a/lib/procedures/optimizers.py +++ b/lib/procedures/optimizers.py @@ -14,7 +14,9 @@ class _LRScheduler(object): self.optimizer = optimizer for group in optimizer.param_groups: group.setdefault("initial_lr", group["lr"]) - self.base_lrs = list(map(lambda group: group["initial_lr"], optimizer.param_groups)) + self.base_lrs = list( + map(lambda group: group["initial_lr"], optimizer.param_groups) + ) self.max_epochs = epochs self.warmup_epochs = warmup_epochs self.current_epoch = 0 @@ -31,7 +33,9 @@ class _LRScheduler(object): ) def state_dict(self): - return {key: value for key, value in self.__dict__.items() if key != "optimizer"} + return { + key: value for key, value in self.__dict__.items() if key != "optimizer" + } def load_state_dict(self, state_dict): self.__dict__.update(state_dict) @@ -50,10 +54,14 @@ class _LRScheduler(object): def update(self, cur_epoch, cur_iter): if cur_epoch is not None: - assert isinstance(cur_epoch, int) and cur_epoch >= 0, "invalid cur-epoch : {:}".format(cur_epoch) + assert ( + isinstance(cur_epoch, int) and cur_epoch >= 0 + ), "invalid cur-epoch : {:}".format(cur_epoch) self.current_epoch = cur_epoch if cur_iter is not None: - assert isinstance(cur_iter, float) and cur_iter >= 0, "invalid cur-iter : {:}".format(cur_iter) + assert ( + isinstance(cur_iter, float) and cur_iter >= 0 + ), "invalid cur-iter : {:}".format(cur_iter) self.current_iter = cur_iter for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): param_group["lr"] = lr @@ -66,29 +74,44 @@ class CosineAnnealingLR(_LRScheduler): super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs) def extra_repr(self): - return "type={:}, T-max={:}, eta-min={:}".format("cosine", self.T_max, self.eta_min) + return "type={:}, T-max={:}, eta-min={:}".format( + "cosine", self.T_max, self.eta_min + ) def get_lr(self): lrs = [] for base_lr in self.base_lrs: - if self.current_epoch >= self.warmup_epochs and self.current_epoch < self.max_epochs: + if ( + self.current_epoch >= self.warmup_epochs + and self.current_epoch < self.max_epochs + ): last_epoch = self.current_epoch - self.warmup_epochs # if last_epoch < self.T_max: # if last_epoch < self.max_epochs: - lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * last_epoch / self.T_max)) / 2 + lr = ( + self.eta_min + + (base_lr - self.eta_min) + * (1 + math.cos(math.pi * last_epoch / self.T_max)) + / 2 + ) # else: # lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2 elif self.current_epoch >= self.max_epochs: lr = self.eta_min else: - lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr + lr = ( + self.current_epoch / self.warmup_epochs + + self.current_iter / self.warmup_epochs + ) * base_lr lrs.append(lr) return lrs class MultiStepLR(_LRScheduler): def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas): - assert len(milestones) == len(gammas), "invalid {:} vs {:}".format(len(milestones), len(gammas)) + assert len(milestones) == len(gammas), "invalid {:} vs {:}".format( + 
len(milestones), len(gammas) + ) self.milestones = milestones self.gammas = gammas super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs) @@ -108,7 +131,10 @@ class MultiStepLR(_LRScheduler): for x in self.gammas[:idx]: lr *= x else: - lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr + lr = ( + self.current_epoch / self.warmup_epochs + + self.current_iter / self.warmup_epochs + ) * base_lr lrs.append(lr) return lrs @@ -119,7 +145,9 @@ class ExponentialLR(_LRScheduler): super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs) def extra_repr(self): - return "type={:}, gamma={:}, base-lrs={:}".format("exponential", self.gamma, self.base_lrs) + return "type={:}, gamma={:}, base-lrs={:}".format( + "exponential", self.gamma, self.base_lrs + ) def get_lr(self): lrs = [] @@ -129,7 +157,10 @@ class ExponentialLR(_LRScheduler): assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch) lr = base_lr * (self.gamma ** last_epoch) else: - lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr + lr = ( + self.current_epoch / self.warmup_epochs + + self.current_iter / self.warmup_epochs + ) * base_lr lrs.append(lr) return lrs @@ -151,10 +182,18 @@ class LinearLR(_LRScheduler): if self.current_epoch >= self.warmup_epochs: last_epoch = self.current_epoch - self.warmup_epochs assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch) - ratio = (self.max_LR - self.min_LR) * last_epoch / self.max_epochs / self.max_LR + ratio = ( + (self.max_LR - self.min_LR) + * last_epoch + / self.max_epochs + / self.max_LR + ) lr = base_lr * (1 - ratio) else: - lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr + lr = ( + self.current_epoch / self.warmup_epochs + + self.current_iter / self.warmup_epochs + ) * base_lr lrs.append(lr) return lrs @@ -176,26 +215,42 @@ class CrossEntropyLabelSmooth(nn.Module): def get_optim_scheduler(parameters, config): assert ( - hasattr(config, "optim") and hasattr(config, "scheduler") and hasattr(config, "criterion") - ), "config must have optim / scheduler / criterion keys instead of {:}".format(config) + hasattr(config, "optim") + and hasattr(config, "scheduler") + and hasattr(config, "criterion") + ), "config must have optim / scheduler / criterion keys instead of {:}".format( + config + ) if config.optim == "SGD": optim = torch.optim.SGD( - parameters, config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=config.nesterov + parameters, + config.LR, + momentum=config.momentum, + weight_decay=config.decay, + nesterov=config.nesterov, ) elif config.optim == "RMSprop": - optim = torch.optim.RMSprop(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay) + optim = torch.optim.RMSprop( + parameters, config.LR, momentum=config.momentum, weight_decay=config.decay + ) else: raise ValueError("invalid optim : {:}".format(config.optim)) if config.scheduler == "cos": T_max = getattr(config, "T_max", config.epochs) - scheduler = CosineAnnealingLR(optim, config.warmup, config.epochs, T_max, config.eta_min) + scheduler = CosineAnnealingLR( + optim, config.warmup, config.epochs, T_max, config.eta_min + ) elif config.scheduler == "multistep": - scheduler = MultiStepLR(optim, config.warmup, config.epochs, config.milestones, config.gammas) + scheduler = MultiStepLR( + optim, config.warmup, config.epochs, config.milestones, config.gammas + ) elif config.scheduler == "exponential": 
scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma) elif config.scheduler == "linear": - scheduler = LinearLR(optim, config.warmup, config.epochs, config.LR, config.LR_min) + scheduler = LinearLR( + optim, config.warmup, config.epochs, config.LR, config.LR_min + ) else: raise ValueError("invalid scheduler : {:}".format(config.scheduler)) diff --git a/lib/procedures/q_exps.py b/lib/procedures/q_exps.py index c31efde..5887cfb 100644 --- a/lib/procedures/q_exps.py +++ b/lib/procedures/q_exps.py @@ -41,7 +41,10 @@ def update_gpu(config, gpu): if "task" in config and "model" in config["task"]: if "GPU" in config["task"]["model"]: config["task"]["model"]["GPU"] = gpu - elif "kwargs" in config["task"]["model"] and "GPU" in config["task"]["model"]["kwargs"]: + elif ( + "kwargs" in config["task"]["model"] + and "GPU" in config["task"]["model"]["kwargs"] + ): config["task"]["model"]["kwargs"]["GPU"] = gpu elif "model" in config: if "GPU" in config["model"]: @@ -68,7 +71,12 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri): model_fit_kwargs = dict(dataset=dataset) # Let's start the experiment. - with R.start(experiment_name=experiment_name, recorder_name=recorder_name, uri=uri, resume=True): + with R.start( + experiment_name=experiment_name, + recorder_name=recorder_name, + uri=uri, + resume=True, + ): # Setup log recorder_root_dir = R.get_recorder().get_local_dir() log_file = os.path.join(recorder_root_dir, "{:}.log".format(experiment_name)) diff --git a/lib/procedures/search_main.py b/lib/procedures/search_main.py index ecdea00..9ad67cf 100644 --- a/lib/procedures/search_main.py +++ b/lib/procedures/search_main.py @@ -36,7 +36,12 @@ def search_train( logger, ): data_time, batch_time = AverageMeter(), AverageMeter() - base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() + base_losses, arch_losses, top1, top5 = ( + AverageMeter(), + AverageMeter(), + AverageMeter(), + AverageMeter(), + ) arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter() epoch_str, flop_need, flop_weight, flop_tolerant = ( extra_info["epoch-str"], @@ -46,10 +51,16 @@ def search_train( ) network.train() - logger.log("[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(epoch_str, flop_need, flop_weight)) + logger.log( + "[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format( + epoch_str, flop_need, flop_weight + ) + ) end = time.time() network.apply(change_key("search_mode", "search")) - for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader): + for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate( + search_loader + ): scheduler.update(None, 1.0 * step / len(search_loader)) # calculate prediction and loss base_targets = base_targets.cuda(non_blocking=True) @@ -75,7 +86,9 @@ def search_train( arch_optimizer.zero_grad() logits, expected_flop = network(arch_inputs) flop_cur = network.module.get_flop("genotype", None, None) - flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant) + flop_loss, flop_loss_scale = get_flop_loss( + expected_flop, flop_cur, flop_need, flop_tolerant + ) acls_loss = criterion(logits, arch_targets) arch_loss = acls_loss + flop_loss * flop_weight arch_loss.backward() @@ -90,7 +103,11 @@ def search_train( batch_time.update(time.time() - end) end = time.time() if step % print_freq == 0 or (step + 1) == len(search_loader): - Sstr = "**TRAIN** " + time_string() + " 
[{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader)) + Sstr = ( + "**TRAIN** " + + time_string() + + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader)) + ) Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( batch_time=batch_time, data_time=data_time ) @@ -153,7 +170,11 @@ def search_valid(xloader, network, criterion, extra_info, print_freq, logger): end = time.time() if i % print_freq == 0 or (i + 1) == len(xloader): - Sstr = "**VALID** " + time_string() + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader)) + Sstr = ( + "**VALID** " + + time_string() + + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader)) + ) Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( batch_time=batch_time, data_time=data_time ) @@ -165,7 +186,11 @@ def search_valid(xloader, network, criterion, extra_info, print_freq, logger): logger.log( " **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format( - top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg + top1=top1, + top5=top5, + error1=100 - top1.avg, + error5=100 - top5.avg, + loss=losses.avg, ) ) diff --git a/lib/procedures/search_main_v2.py b/lib/procedures/search_main_v2.py index 0b1fbca..ce3cbed 100644 --- a/lib/procedures/search_main_v2.py +++ b/lib/procedures/search_main_v2.py @@ -36,7 +36,12 @@ def search_train_v2( logger, ): data_time, batch_time = AverageMeter(), AverageMeter() - base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() + base_losses, arch_losses, top1, top5 = ( + AverageMeter(), + AverageMeter(), + AverageMeter(), + AverageMeter(), + ) arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter() epoch_str, flop_need, flop_weight, flop_tolerant = ( extra_info["epoch-str"], @@ -46,10 +51,16 @@ def search_train_v2( ) network.train() - logger.log("[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(epoch_str, flop_need, flop_weight)) + logger.log( + "[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format( + epoch_str, flop_need, flop_weight + ) + ) end = time.time() network.apply(change_key("search_mode", "search")) - for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader): + for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate( + search_loader + ): scheduler.update(None, 1.0 * step / len(search_loader)) # calculate prediction and loss base_targets = base_targets.cuda(non_blocking=True) @@ -73,7 +84,9 @@ def search_train_v2( arch_optimizer.zero_grad() logits, expected_flop = network(arch_inputs) flop_cur = network.module.get_flop("genotype", None, None) - flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant) + flop_loss, flop_loss_scale = get_flop_loss( + expected_flop, flop_cur, flop_need, flop_tolerant + ) acls_loss = criterion(logits, arch_targets) arch_loss = acls_loss + flop_loss * flop_weight arch_loss.backward() @@ -88,7 +101,11 @@ def search_train_v2( batch_time.update(time.time() - end) end = time.time() if step % print_freq == 0 or (step + 1) == len(search_loader): - Sstr = "**TRAIN** " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader)) + Sstr = ( + "**TRAIN** " + + time_string() + + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader)) + ) Tstr = 
"Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( batch_time=batch_time, data_time=data_time ) diff --git a/lib/procedures/simple_KD_main.py b/lib/procedures/simple_KD_main.py index cc7b3da..025f6e6 100644 --- a/lib/procedures/simple_KD_main.py +++ b/lib/procedures/simple_KD_main.py @@ -10,7 +10,16 @@ from utils import obtain_accuracy def simple_KD_train( - xloader, teacher, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger + xloader, + teacher, + network, + criterion, + scheduler, + optimizer, + optim_config, + extra_info, + print_freq, + logger, ): loss, acc1, acc5 = procedure( xloader, @@ -28,25 +37,58 @@ def simple_KD_train( return loss, acc1, acc5 -def simple_KD_valid(xloader, teacher, network, criterion, optim_config, extra_info, print_freq, logger): +def simple_KD_valid( + xloader, teacher, network, criterion, optim_config, extra_info, print_freq, logger +): with torch.no_grad(): loss, acc1, acc5 = procedure( - xloader, teacher, network, criterion, None, None, "valid", optim_config, extra_info, print_freq, logger + xloader, + teacher, + network, + criterion, + None, + None, + "valid", + optim_config, + extra_info, + print_freq, + logger, ) return loss, acc1, acc5 def loss_KD_fn( - criterion, student_logits, teacher_logits, studentFeatures, teacherFeatures, targets, alpha, temperature + criterion, + student_logits, + teacher_logits, + studentFeatures, + teacherFeatures, + targets, + alpha, + temperature, ): basic_loss = criterion(student_logits, targets) * (1.0 - alpha) log_student = F.log_softmax(student_logits / temperature, dim=1) sof_teacher = F.softmax(teacher_logits / temperature, dim=1) - KD_loss = F.kl_div(log_student, sof_teacher, reduction="batchmean") * (alpha * temperature * temperature) + KD_loss = F.kl_div(log_student, sof_teacher, reduction="batchmean") * ( + alpha * temperature * temperature + ) return basic_loss + KD_loss -def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger): +def procedure( + xloader, + teacher, + network, + criterion, + scheduler, + optimizer, + mode, + config, + extra_info, + print_freq, + logger, +): data_time, batch_time, losses, top1, top5 = ( AverageMeter(), AverageMeter(), @@ -65,7 +107,10 @@ def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode, logger.log( "[{:5s}] config :: auxiliary={:}, KD :: [alpha={:.2f}, temperature={:.2f}]".format( - mode, config.auxiliary if hasattr(config, "auxiliary") else -1, config.KD_alpha, config.KD_temperature + mode, + config.auxiliary if hasattr(config, "auxiliary") else -1, + config.KD_alpha, + config.KD_temperature, ) ) end = time.time() @@ -82,7 +127,9 @@ def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode, student_f, logits = network(inputs) if isinstance(logits, list): - assert len(logits) == 2, "logits must has {:} items instead of {:}".format(2, len(logits)) + assert len(logits) == 2, "logits must has {:} items instead of {:}".format( + 2, len(logits) + ) logits, logits_aux = logits else: logits, logits_aux = logits, None @@ -90,7 +137,14 @@ def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode, teacher_f, teacher_logits = teacher(inputs) loss = loss_KD_fn( - criterion, logits, teacher_logits, student_f, teacher_f, targets, config.KD_alpha, config.KD_temperature + criterion, + logits, + teacher_logits, + student_f, + teacher_f, + targets, + config.KD_alpha, + 
config.KD_temperature, ) if config is not None and hasattr(config, "auxiliary") and config.auxiliary > 0: loss_aux = criterion(logits_aux, targets) @@ -139,7 +193,12 @@ def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode, ) logger.log( " **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format( - mode=mode.upper(), top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg + mode=mode.upper(), + top1=top1, + top5=top5, + error1=100 - top1.avg, + error5=100 - top5.avg, + loss=losses.avg, ) ) return losses.avg, top1.avg, top5.avg diff --git a/lib/procedures/starts.py b/lib/procedures/starts.py index c93b968..1ae19c5 100644 --- a/lib/procedures/starts.py +++ b/lib/procedures/starts.py @@ -31,7 +31,9 @@ def prepare_logger(xargs): logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count())) logger.log( "CUDA_VISIBLE_DEVICES : {:}".format( - os.environ["CUDA_VISIBLE_DEVICES"] if "CUDA_VISIBLE_DEVICES" in os.environ else "None" + os.environ["CUDA_VISIBLE_DEVICES"] + if "CUDA_VISIBLE_DEVICES" in os.environ + else "None" ) ) return logger @@ -54,10 +56,14 @@ def get_machine_info(): def save_checkpoint(state, filename, logger): if osp.isfile(filename): if hasattr(logger, "log"): - logger.log("Find {:} exist, delete is at first before saving".format(filename)) + logger.log( + "Find {:} exist, delete is at first before saving".format(filename) + ) os.remove(filename) torch.save(state, filename) - assert osp.isfile(filename), "save filename : {:} failed, which is not found.".format(filename) + assert osp.isfile( + filename + ), "save filename : {:} failed, which is not found.".format(filename) if hasattr(logger, "log"): logger.log("save checkpoint into {:}".format(filename)) return filename diff --git a/lib/spaces/basic_space.py b/lib/spaces/basic_space.py index fe7ee7d..d3a1571 100644 --- a/lib/spaces/basic_space.py +++ b/lib/spaces/basic_space.py @@ -50,6 +50,10 @@ class Space(metaclass=abc.ABCMeta): def clean_last_abstract(self): raise NotImplementedError + def clean_last(self): + self.clean_last_sample() + self.clean_last_abstract() + @abc.abstractproperty def determined(self) -> bool: raise NotImplementedError diff --git a/lib/xlayers/super_linear.py b/lib/xlayers/super_linear.py index 30420d3..d287ee9 100644 --- a/lib/xlayers/super_linear.py +++ b/lib/xlayers/super_linear.py @@ -147,11 +147,18 @@ class SuperMLP(SuperModule): root_node.append("fc2", space_fc2) return root_node + def apply_candidate(self, abstract_child: spaces.VirtualNode): + super(SuperMLP, self).apply_candidate(abstract_child) + if "fc1" in abstract_child: + self.fc1.apply_candidate(abstract_child["fc1"]) + if "fc2" in abstract_child: + self.fc2.apply_candidate(abstract_child["fc2"]) + def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: - return self._unified_forward(x) + return self._unified_forward(input) def forward_raw(self, input: torch.Tensor) -> torch.Tensor: - return self._unified_forward(x) + return self._unified_forward(input) def _unified_forward(self, x): x = self.fc1(x) diff --git a/lib/xlayers/super_module.py b/lib/xlayers/super_module.py index 1e0702c..119a65f 100644 --- a/lib/xlayers/super_module.py +++ b/lib/xlayers/super_module.py @@ -32,7 +32,7 @@ class SuperModule(abc.ABC, nn.Module): self.apply(_reset_super_run) - def apply_candiate(self, abstract_child): + def apply_candidate(self, abstract_child): if not isinstance(abstract_child, spaces.VirtualNode): raise 
ValueError( "Invalid abstract child program: {:}".format(abstract_child) diff --git a/tests/test_super_model.py b/tests/test_super_model.py index dfb75b7..946264c 100644 --- a/tests/test_super_model.py +++ b/tests/test_super_model.py @@ -30,32 +30,37 @@ class TestSuperLinear(unittest.TestCase): print(model.super_run_type) self.assertTrue(model.bias) - inputs = torch.rand(32, 10) + inputs = torch.rand(20, 10) print("Input shape: {:}".format(inputs.shape)) print("Weight shape: {:}".format(model._super_weight.shape)) print("Bias shape: {:}".format(model._super_bias.shape)) outputs = model(inputs) - self.assertEqual(tuple(outputs.shape), (32, 36)) + self.assertEqual(tuple(outputs.shape), (20, 36)) abstract_space = model.abstract_search_space + abstract_space.clean_last() abstract_child = abstract_space.random() print("The abstract searc space:\n{:}".format(abstract_space)) print("The abstract child program:\n{:}".format(abstract_child)) model.set_super_run_type(super_core.SuperRunMode.Candidate) - model.apply_candiate(abstract_child) + model.apply_candidate(abstract_child) - output_shape = (32, abstract_child["_out_features"].value) + output_shape = (20, abstract_child["_out_features"].value) outputs = model(inputs) self.assertEqual(tuple(outputs.shape), output_shape) def test_super_mlp(self): hidden_features = spaces.Categorical(12, 24, 36) - out_features = spaces.Categorical(12, 24, 36) + out_features = spaces.Categorical(24, 36, 48) mlp = super_core.SuperMLP(10, hidden_features, out_features) print(mlp) self.assertTrue(mlp.fc1._out_features, mlp.fc2._in_features) + inputs = torch.rand(4, 10) + outputs = mlp(inputs) + self.assertEqual(tuple(outputs.shape), (4, 48)) + abstract_space = mlp.abstract_search_space print("The abstract search space for SuperMLP is:\n{:}".format(abstract_space)) self.assertEqual( @@ -67,10 +72,16 @@ class TestSuperLinear(unittest.TestCase): is abstract_space["fc2"]["_in_features"] ) - abstract_space.clean_last_sample() + abstract_space.clean_last() abstract_child = abstract_space.random(reuse_last=True) print("The abstract child program is:\n{:}".format(abstract_child)) self.assertEqual( abstract_child["fc1"]["_out_features"].value, abstract_child["fc2"]["_in_features"].value, ) + + mlp.set_super_run_type(super_core.SuperRunMode.Candidate) + mlp.apply_candidate(abstract_child) + outputs = mlp(inputs) + output_shape = (4, abstract_child["fc2"]["_out_features"].value) + self.assertEqual(tuple(outputs.shape), output_shape)
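
Note on the optimizers.py hunks above: they are pure black rewrapping and the learning-rate math is unchanged. For reference, a minimal standalone sketch (illustrative only, not repository code) of the warmup-then-cosine schedule that CosineAnnealingLR.get_lr computes:

    import math

    def cosine_with_warmup(base_lr, eta_min, warmup_epochs, max_epochs, T_max,
                           cur_epoch, cur_iter=0.0):
        # Linear warmup: ramp toward base_lr during the first warmup_epochs.
        if cur_epoch < warmup_epochs:
            return (cur_epoch / warmup_epochs + cur_iter / warmup_epochs) * base_lr
        # After the final epoch, hold the floor value.
        if cur_epoch >= max_epochs:
            return eta_min
        # Cosine annealing from base_lr down to eta_min over T_max epochs.
        last_epoch = cur_epoch - warmup_epochs
        return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * last_epoch / T_max)) / 2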
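
Likewise, the SuperModule rename (apply_candiate -> apply_candidate), the new SuperMLP.apply_candidate override, and Space.clean_last() compose as exercised in tests/test_super_model.py. A minimal usage sketch follows; the two import lines are assumptions (they are not shown in this diff) and only mirror how the tests reference spaces and super_core:

    import torch
    import spaces                      # assumed import path (lib/spaces on sys.path)
    from xlayers import super_core     # assumed import path (lib/xlayers on sys.path)

    # Build a super-network whose hidden and output widths are searchable.
    mlp = super_core.SuperMLP(
        10, spaces.Categorical(12, 24, 36), spaces.Categorical(24, 36, 48)
    )

    space = mlp.abstract_search_space
    space.clean_last()                       # new helper: clean_last_sample + clean_last_abstract
    child = space.random(reuse_last=True)    # sample one candidate sub-architecture

    mlp.set_super_run_type(super_core.SuperRunMode.Candidate)
    mlp.apply_candidate(child)               # renamed from apply_candiate

    outputs = mlp(torch.rand(4, 10))
    assert tuple(outputs.shape) == (4, child["fc2"]["_out_features"].value)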