Add simple baseline for LFNA
@@ -11,270 +11,109 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
 if str(lib_dir) not in sys.path:
     sys.path.insert(0, str(lib_dir))
 from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
-from log_utils import AverageMeter, time_string, convert_secs2time
+from log_utils import time_string
+
+from procedures.advanced_main import basic_train_fn, basic_eval_fn
+from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
+from datasets.synthetic_core import get_synthetic_env
+from models.xcore import get_model
 
 
 def main(args):
     torch.set_num_threads(args.workers)
     prepare_seed(args.rand_seed)
     logger = prepare_logger(args)
 
     dynamic_env = get_synthetic_env()
     historical_x, historical_y = None, None
     for idx, (timestamp, (allx, ally)) in enumerate(dynamic_env):
-        import pdb
-        pdb.set_trace()
-
-    train_data, valid_data, xshape, class_num = get_datasets(
-        args.dataset, args.data_path, args.cutout_length
-    )
-    train_loader = torch.utils.data.DataLoader(
-        train_data,
-        batch_size=args.batch_size,
-        shuffle=True,
-        num_workers=args.workers,
-        pin_memory=True,
-    )
-    valid_loader = torch.utils.data.DataLoader(
-        valid_data,
-        batch_size=args.batch_size,
-        shuffle=False,
-        num_workers=args.workers,
-        pin_memory=True,
-    )
-    # get configures
-    model_config = load_config(args.model_config, {"class_num": class_num}, logger)
-    optim_config = load_config(args.optim_config, {"class_num": class_num}, logger)
-
-    if args.model_source == "normal":
-        base_model = obtain_model(model_config)
-    elif args.model_source == "nas":
-        base_model = obtain_nas_infer_model(model_config, args.extra_model_path)
-    elif args.model_source == "autodl-searched":
-        base_model = obtain_model(model_config, args.extra_model_path)
-    else:
-        raise ValueError("invalid model-source : {:}".format(args.model_source))
-    flop, param = get_model_infos(base_model, xshape)
-    logger.log("model ====>>>>:\n{:}".format(base_model))
-    logger.log("model information : {:}".format(base_model.get_message()))
-    logger.log("-" * 50)
-    logger.log(
-        "Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G".format(
-            param, flop, flop / 1e3
-        )
-    )
-    logger.log("-" * 50)
-    logger.log("train_data : {:}".format(train_data))
-    logger.log("valid_data : {:}".format(valid_data))
-    optimizer, scheduler, criterion = get_optim_scheduler(
-        base_model.parameters(), optim_config
-    )
-    logger.log("optimizer  : {:}".format(optimizer))
-    logger.log("scheduler  : {:}".format(scheduler))
-    logger.log("criterion  : {:}".format(criterion))
-
-    last_info, model_base_path, model_best_path = (
-        logger.path("info"),
-        logger.path("model"),
-        logger.path("best"),
-    )
-    network, criterion = torch.nn.DataParallel(base_model).cuda(), criterion.cuda()
-
-    if last_info.exists():  # automatically resume from previous checkpoint
-        logger.log(
-            "=> loading checkpoint of the last-info '{:}' start".format(last_info)
-        )
-        last_infox = torch.load(last_info)
-        start_epoch = last_infox["epoch"] + 1
-        last_checkpoint_path = last_infox["last_checkpoint"]
-        if not last_checkpoint_path.exists():
-            logger.log(
-                "Does not find {:}, try another path".format(last_checkpoint_path)
-            )
-            last_checkpoint_path = (
-                last_info.parent
-                / last_checkpoint_path.parent.name
-                / last_checkpoint_path.name
-            )
-        checkpoint = torch.load(last_checkpoint_path)
-        base_model.load_state_dict(checkpoint["base-model"])
-        scheduler.load_state_dict(checkpoint["scheduler"])
-        optimizer.load_state_dict(checkpoint["optimizer"])
-        valid_accuracies = checkpoint["valid_accuracies"]
-        max_bytes = checkpoint["max_bytes"]
-        logger.log(
-            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(
-                last_info, start_epoch
-            )
-        )
-    elif args.resume is not None:
-        assert Path(args.resume).exists(), "Can not find the resume file : {:}".format(
-            args.resume
-        )
-        checkpoint = torch.load(args.resume)
-        start_epoch = checkpoint["epoch"] + 1
-        base_model.load_state_dict(checkpoint["base-model"])
-        scheduler.load_state_dict(checkpoint["scheduler"])
-        optimizer.load_state_dict(checkpoint["optimizer"])
-        valid_accuracies = checkpoint["valid_accuracies"]
-        max_bytes = checkpoint["max_bytes"]
-        logger.log(
-            "=> loading checkpoint from '{:}' start with {:}-th epoch.".format(
-                args.resume, start_epoch
-            )
-        )
-    elif args.init_model is not None:
-        assert Path(
-            args.init_model
-        ).exists(), "Can not find the initialization file : {:}".format(args.init_model)
-        checkpoint = torch.load(args.init_model)
-        base_model.load_state_dict(checkpoint["base-model"])
-        start_epoch, valid_accuracies, max_bytes = 0, {"best": -1}, {}
-        logger.log("=> initialize the model from {:}".format(args.init_model))
-    else:
-        logger.log("=> do not find the last-info file : {:}".format(last_info))
-        start_epoch, valid_accuracies, max_bytes = 0, {"best": -1}, {}
-
-    train_func, valid_func = get_procedures(args.procedure)
-
-    total_epoch = optim_config.epochs + optim_config.warmup
-    # Main Training and Evaluation Loop
-    start_time = time.time()
-    epoch_time = AverageMeter()
-    for epoch in range(start_epoch, total_epoch):
-        scheduler.update(epoch, 0.0)
-        need_time = "Time Left: {:}".format(
-            convert_secs2time(epoch_time.avg * (total_epoch - epoch), True)
-        )
-        epoch_str = "epoch={:03d}/{:03d}".format(epoch, total_epoch)
-        LRs = scheduler.get_lr()
-        find_best = False
-        # set-up drop-out ratio
-        if hasattr(base_model, "update_drop_path"):
-            base_model.update_drop_path(
-                model_config.drop_path_prob * epoch / total_epoch
-            )
-        logger.log(
-            "\n***{:s}*** start {:s} {:s}, LR=[{:.6f} ~ {:.6f}], scheduler={:}".format(
-                time_string(), epoch_str, need_time, min(LRs), max(LRs), scheduler
-            )
-        )
-
-        # train for one epoch
-        train_loss, train_acc1, train_acc5 = train_func(
-            train_loader,
-            network,
-            criterion,
-            scheduler,
-            optimizer,
-            optim_config,
-            epoch_str,
-            args.print_freq,
-            logger,
-        )
-        # log the results
-        logger.log(
-            "***{:s}*** TRAIN [{:}] loss = {:.6f}, accuracy-1 = {:.2f}, accuracy-5 = {:.2f}".format(
-                time_string(), epoch_str, train_loss, train_acc1, train_acc5
-            )
-        )
-
-        # evaluate the performance
-        if (epoch % args.eval_frequency == 0) or (epoch + 1 == total_epoch):
-            logger.log("-" * 150)
-            valid_loss, valid_acc1, valid_acc5 = valid_func(
-                valid_loader,
-                network,
-                criterion,
-                optim_config,
-                epoch_str,
-                args.print_freq_eval,
-                logger,
-            )
-            valid_accuracies[epoch] = valid_acc1
-            logger.log(
-                "***{:s}*** VALID [{:}] loss = {:.6f}, accuracy@1 = {:.2f}, accuracy@5 = {:.2f} | Best-Valid-Acc@1={:.2f}, Error@1={:.2f}".format(
-                    time_string(),
-                    epoch_str,
-                    valid_loss,
-                    valid_acc1,
-                    valid_acc5,
-                    valid_accuracies["best"],
-                    100 - valid_accuracies["best"],
-                )
-            )
-            if valid_acc1 > valid_accuracies["best"]:
-                valid_accuracies["best"] = valid_acc1
-                find_best = True
-                logger.log(
-                    "Currently, the best validation accuracy found at {:03d}-epoch :: acc@1={:.2f}, acc@5={:.2f}, error@1={:.2f}, error@5={:.2f}, save into {:}.".format(
-                        epoch,
-                        valid_acc1,
-                        valid_acc5,
-                        100 - valid_acc1,
-                        100 - valid_acc5,
-                        model_best_path,
-                    )
-                )
-            num_bytes = (
-                torch.cuda.max_memory_cached(next(network.parameters()).device) * 1.0
-            )
-            logger.log(
-                "[GPU-Memory-Usage on {:} is {:} bytes, {:.2f} KB, {:.2f} MB, {:.2f} GB.]".format(
-                    next(network.parameters()).device,
-                    int(num_bytes),
-                    num_bytes / 1e3,
-                    num_bytes / 1e6,
-                    num_bytes / 1e9,
-                )
-            )
-            max_bytes[epoch] = num_bytes
-        if epoch % 10 == 0:
-            torch.cuda.empty_cache()
-
-        # save checkpoint
-        save_path = save_checkpoint(
-            {
-                "epoch": epoch,
-                "args": deepcopy(args),
-                "max_bytes": deepcopy(max_bytes),
-                "FLOP": flop,
-                "PARAM": param,
-                "valid_accuracies": deepcopy(valid_accuracies),
-                "model-config": model_config._asdict(),
-                "optim-config": optim_config._asdict(),
-                "base-model": base_model.state_dict(),
-                "scheduler": scheduler.state_dict(),
-                "optimizer": optimizer.state_dict(),
-            },
-            model_base_path,
-            logger,
-        )
-        if find_best:
-            copy_checkpoint(model_base_path, model_best_path, logger)
-        last_info = save_checkpoint(
-            {
-                "epoch": epoch,
-                "args": deepcopy(args),
-                "last_checkpoint": save_path,
-            },
-            logger.path("info"),
-            logger,
-        )
-
-        # measure elapsed time
-        epoch_time.update(time.time() - start_time)
-        start_time = time.time()
-
-    logger.log("\n" + "-" * 200)
-    logger.log(
-        "Finish training/validation in {:} with Max-GPU-Memory of {:.2f} MB, and save final checkpoint into {:}".format(
-            convert_secs2time(epoch_time.sum, True),
-            max(v for k, v in max_bytes.items()) / 1e6,
-            logger.path("info"),
-        )
-    )
-    logger.log("-" * 200 + "\n")
+        if historical_x is not None:
+            mean, std = historical_x.mean().item(), historical_x.std().item()
+        else:
+            mean, std = 0, 1
+        model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std)
+        model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
+        # create the current data loader
+        if historical_x is not None:
+            train_dataset = torch.utils.data.TensorDataset(historical_x, historical_y)
+            train_loader = torch.utils.data.DataLoader(
+                train_dataset,
+                batch_size=args.batch_size,
+                shuffle=True,
+                num_workers=args.workers,
+            )
+            optimizer = torch.optim.Adam(
+                model.parameters(), lr=args.init_lr, amsgrad=True
+            )
+            criterion = torch.nn.MSELoss()
+            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+                optimizer,
+                milestones=[
+                    int(args.epochs * 0.25),
+                    int(args.epochs * 0.5),
+                    int(args.epochs * 0.75),
+                ],
+                gamma=0.3,
+            )
+            for _iepoch in range(args.epochs):
+                results = basic_train_fn(
+                    train_loader, model, criterion, optimizer, MSEMetric(), logger
+                )
+                lr_scheduler.step()
+                if _iepoch % args.log_per_epoch == 0:
+                    log_str = (
+                        "[{:}]".format(time_string())
+                        + " [{:04d}/{:04d}][{:04d}/{:04d}]".format(
+                            idx, len(dynamic_env), _iepoch, args.epochs
+                        )
+                        + " mse: {:.5f}, lr: {:.4f}".format(
+                            results["mse"], min(lr_scheduler.get_last_lr())
+                        )
+                    )
+                    logger.log(log_str)
+            results = basic_eval_fn(train_loader, model, MSEMetric(), logger)
+            logger.log(
+                "[{:}] [{:04d}/{:04d}] train-mse: {:.5f}".format(
+                    time_string(), idx, len(dynamic_env), results["mse"]
+                )
+            )
+
+        metric = ComposeMetric(MSEMetric(), SaveMetric())
+        eval_dataset = torch.utils.data.TensorDataset(allx, ally)
+        eval_loader = torch.utils.data.DataLoader(
+            eval_dataset,
+            batch_size=args.batch_size,
+            shuffle=False,
+            num_workers=args.workers,
+        )
+        results = basic_eval_fn(eval_loader, model, metric, logger)
+        log_str = (
+            "[{:}]".format(time_string())
+            + " [{:04d}/{:04d}]".format(idx, len(dynamic_env))
+            + " eval-mse: {:.5f}".format(results["mse"])
+        )
+        logger.log(log_str)
+
+        save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
+            idx, len(dynamic_env)
+        )
+        save_checkpoint(
+            {"model": model.state_dict(), "index": idx, "timestamp": timestamp},
+            save_path,
+            logger,
+        )
+
+        # Update historical data
+        if historical_x is None:
+            historical_x, historical_y = allx, ally
+        else:
+            historical_x, historical_y = torch.cat((historical_x, allx)), torch.cat(
+                (historical_y, ally)
+            )
+        logger.log("")
 
     logger.close()
@@ -287,11 +126,35 @@ if __name__ == "__main__":
         default="./outputs/lfna-synthetic/use-all-past-data",
         help="The checkpoint directory.",
     )
+    parser.add_argument(
+        "--init_lr",
+        type=float,
+        default=0.1,
+        help="The initial learning rate for the optimizer (default is Adam)",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=256,
+        help="The batch size",
+    )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=2000,
+        help="The total number of epochs.",
+    )
+    parser.add_argument(
+        "--log_per_epoch",
+        type=int,
+        default=200,
+        help="Log the training information per __ epochs.",
+    )
     parser.add_argument(
         "--workers",
         type=int,
-        default=8,
-        help="number of data loading workers (default: 8)",
+        default=4,
+        help="The number of data loading workers (default: 4)",
    )
     # Random Seed
     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
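Taken together, the rewritten script is an intentionally simple baseline: at each timestamp of the synthetic environment it builds a fresh MLP, fits it on all data observed so far (an expanding window), evaluates on the batch that just arrived, saves a checkpoint, and appends the new batch to the history. Below is a minimal, self-contained sketch of that protocol; the drifting toy_env stream and the plain nn.Sequential model are hypothetical stand-ins for get_synthetic_env and get_model, not the project's actual implementations:

import torch

def toy_env(num_steps=10, num_samples=512):
    # Hypothetical stand-in for get_synthetic_env: a drifting 1-D regression stream.
    for t in range(num_steps):
        x = torch.randn(num_samples, 1)
        y = torch.sin(x + 0.1 * t)  # the target function drifts with the timestamp
        yield t, (x, y)

historical_x, historical_y = None, None
for t, (allx, ally) in toy_env():
    # A fresh model per timestamp, as in the baseline (stand-in for get_model).
    model = torch.nn.Sequential(
        torch.nn.Linear(1, 200), torch.nn.LeakyReLU(),
        torch.nn.Linear(200, 100), torch.nn.LeakyReLU(),
        torch.nn.Linear(100, 1),
    )
    if historical_x is not None:  # train only on the accumulated past data
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01, amsgrad=True)
        for _ in range(200):
            optimizer.zero_grad()
            loss = torch.nn.functional.mse_loss(model(historical_x), historical_y)
            loss.backward()
            optimizer.step()
    with torch.no_grad():  # evaluate on the batch that just arrived
        eval_mse = torch.nn.functional.mse_loss(model(allx), ally).item()
    print("step {:02d}: eval-mse {:.5f}".format(t, eval_mse))
    # grow the history with the current batch
    if historical_x is None:
        historical_x, historical_y = allx, ally
    else:
        historical_x = torch.cat((historical_x, allx))
        historical_y = torch.cat((historical_y, ally))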
@@ -59,8 +59,10 @@ class Logger(object):
         )
 
     def path(self, mode):
-        valids = ("model", "best", "info", "log")
-        if mode == "model":
+        valids = ("model", "best", "info", "log", None)
+        if mode is None:
+            return self.log_dir
+        elif mode == "model":
             return self.model_dir / "seed-{:}-basic.pth".format(self.seed)
         elif mode == "best":
             return self.model_dir / "seed-{:}-best.pth".format(self.seed)
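The new None mode exists so the baseline can drop its per-timestamp checkpoints straight into the log directory. Assuming self.log_dir is a pathlib.Path, which the sibling branches returning self.model_dir / "..." suggest, the main script can build filenames with the / operator:

# From the main script above: logger.path(None) yields the log directory,
# and pathlib's "/" operator appends the per-timestamp filename.
save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(idx, len(dynamic_env))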
@@ -4,12 +4,14 @@
 # Use module in xlayers to construct different models #
 #######################################################
 from typing import List, Text, Dict, Any
+import torch
 
 __all__ = ["get_model"]
 
 
-from xlayers.super_core import SuperSequential, SuperMLPv1
+from xlayers.super_core import SuperSequential
 from xlayers.super_core import SuperSimpleNorm
+from xlayers.super_core import SuperLeakyReLU
 from xlayers.super_core import SuperLinear
 
 
@@ -19,9 +21,9 @@ def get_model(config: Dict[Text, Any], **kwargs):
         model = SuperSequential(
             SuperSimpleNorm(kwargs["mean"], kwargs["std"]),
             SuperLinear(kwargs["input_dim"], 200),
-            torch.nn.LeakyReLU(),
+            SuperLeakyReLU(),
             SuperLinear(200, 100),
-            torch.nn.LeakyReLU(),
+            SuperLeakyReLU(),
             SuperLinear(100, kwargs["output_dim"]),
         )
     else:
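With this change, the "simple_mlp" returned by get_model is a 1-200-100-1 LeakyReLU MLP whose input is first standardized by SuperSimpleNorm with the mean/std supplied by the caller. A small usage sketch, assuming the lib directory is on sys.path and that SuperSequential forwards like torch.nn.Sequential:

import torch
from models.xcore import get_model

# mean/std come from the historical data in the baseline; 0 and 1 at the first timestamp.
model = get_model(dict(model_type="simple_mlp"), input_dim=1, output_dim=1, mean=0, std=1)
out = model(torch.randn(32, 1))
print(out.shape)  # expected: torch.Size([32, 1])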
@@ -12,49 +12,48 @@ from log_utils import time_string
 from .eval_funcs import obtain_accuracy
 
 
-def basic_train(
+def get_device(tensors):
+    if isinstance(tensors, (list, tuple)):
+        return get_device(tensors[0])
+    elif isinstance(tensors, dict):
+        for key, value in tensors.items():
+            return get_device(value)
+    else:
+        return tensors.device
+
+
+def basic_train_fn(
     xloader,
     network,
     criterion,
-    scheduler,
     optimizer,
-    optim_config,
-    extra_info,
-    print_freq,
+    metric,
     logger,
 ):
-    loss, acc1, acc5 = procedure(
+    results = procedure(
         xloader,
         network,
         criterion,
-        scheduler,
         optimizer,
+        metric,
         "train",
-        optim_config,
-        extra_info,
-        print_freq,
         logger,
     )
-    return loss, acc1, acc5
+    return results
 
 
-def basic_valid(
-    xloader, network, criterion, optim_config, extra_info, print_freq, logger
-):
+def basic_eval_fn(xloader, network, metric, logger):
     with torch.no_grad():
-        loss, acc1, acc5 = procedure(
+        results = procedure(
             xloader,
             network,
-            criterion,
             None,
             None,
+            metric,
             "valid",
-            None,
-            extra_info,
-            print_freq,
             logger,
         )
-    return loss, acc1, acc5
+    return results
 
 
@@ -62,12 +61,11 @@ def procedure(
     network,
     criterion,
     optimizer,
-    eval_metric,
+    metric,
     mode: Text,
     print_freq: int = 100,
     logger_fn: Callable = None,
 ):
-    data_time, batch_time, losses = AverageMeter(), AverageMeter(), AverageMeter()
+    data_time, batch_time = AverageMeter(), AverageMeter()
     if mode.lower() == "train":
         network.train()
     elif mode.lower() == "valid":
 
@@ -80,49 +78,23 @@ def procedure(
         # measure data loading time
         data_time.update(time.time() - end)
         # calculate prediction and loss
-        targets = targets.cuda(non_blocking=True)
 
         if mode == "train":
             optimizer.zero_grad()
 
         outputs = network(inputs)
-        loss = criterion(outputs, targets)
+        targets = targets.to(get_device(outputs))
 
         if mode == "train":
+            loss = criterion(outputs, targets)
             loss.backward()
             optimizer.step()
 
         # record
-        metrics = eval_metric(logits.data, targets.data)
-        prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
-        losses.update(loss.item(), inputs.size(0))
-        top1.update(prec1.item(), inputs.size(0))
-        top5.update(prec5.item(), inputs.size(0))
+        with torch.no_grad():
+            results = metric(outputs, targets)
 
         # measure elapsed time
         batch_time.update(time.time() - end)
         end = time.time()
 
-        if i % print_freq == 0 or (i + 1) == len(xloader):
-            Sstr = (
-                " {:5s} ".format(mode.upper())
-                + time_string()
-                + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
-            )
-            Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
-                loss=losses, top1=top1, top5=top5
-            )
-            Istr = "Size={:}".format(list(inputs.size()))
-            logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr)
-
-    logger.log(
-        " **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
-            mode=mode.upper(),
-            top1=top1,
-            top5=top5,
-            error1=100 - top1.avg,
-            error5=100 - top5.avg,
-            loss=losses.avg,
-        )
-    )
-    return losses.avg, top1.avg, top5.avg
+    return metric.get_info()
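The new get_device helper is what lets procedure() stay device-agnostic: instead of hard-coding targets.cuda(), it moves the targets to whatever device the network produced its outputs on, recursing into lists, tuples, and dicts and returning the device of the first tensor it finds. A quick self-contained demonstration (the helper is copied from the diff above):

import torch

def get_device(tensors):
    if isinstance(tensors, (list, tuple)):
        return get_device(tensors[0])
    elif isinstance(tensors, dict):
        for key, value in tensors.items():
            return get_device(value)
    else:
        return tensors.device

x = torch.zeros(4)
print(get_device(x))              # cpu
print(get_device([x, x * 2]))     # cpu, the device of the first element
print(get_device({"logits": x}))  # cpu, the device of the first value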
@@ -18,11 +18,3 @@ def obtain_accuracy(output, target, topk=(1,)):
         correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
         res.append(correct_k.mul_(100.0 / batch_size))
     return res
-
-
-class EvaluationMetric(abc.ABC):
-    def __init__(self):
-        self._total_metrics = 0
-
-    def __len__(self):
-        return self._total_metrics

lib/procedures/metric_utils.py (new file, 134 lines)
@@ -0,0 +1,134 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 #
#####################################################
import abc
import numpy as np
import torch


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __repr__(self):
        return "{name}(val={val}, avg={avg}, count={count})".format(
            name=self.__class__.__name__, **self.__dict__
        )


class Metric(abc.ABC):
    """The default meta metric class."""

    def __init__(self):
        self.reset()

    def reset(self):
        raise NotImplementedError

    def __call__(self, predictions, targets):
        raise NotImplementedError

    def get_info(self):
        raise NotImplementedError

    def __repr__(self):
        return "{name}({inner})".format(
            name=self.__class__.__name__, inner=self.inner_repr()
        )

    def inner_repr(self):
        return ""


class ComposeMetric(Metric):
    """The composed metric class."""

    def __init__(self, *metric_list):
        self.reset()
        for metric in metric_list:
            self.append(metric)

    def reset(self):
        self._metric_list = []

    def append(self, metric):
        if not isinstance(metric, Metric):
            raise ValueError(
                "The input metric is not correct: {:}".format(type(metric))
            )
        self._metric_list.append(metric)

    def __len__(self):
        return len(self._metric_list)

    def __call__(self, predictions, targets):
        results = list()
        for metric in self._metric_list:
            results.append(metric(predictions, targets))
        return results

    def get_info(self):
        results = dict()
        for metric in self._metric_list:
            for key, value in metric.get_info().items():
                results[key] = value
        return results

    def inner_repr(self):
        xlist = []
        for metric in self._metric_list:
            xlist.append(str(metric))
        return ",".join(xlist)


class MSEMetric(Metric):
    """The metric for mse."""

    def reset(self):
        self._mse = AverageMeter()

    def __call__(self, predictions, targets):
        if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
            batch = predictions.shape[0]
            loss = torch.nn.functional.mse_loss(predictions.data, targets.data)
            loss = loss.item()
            self._mse.update(loss, batch)
            return loss
        else:
            raise NotImplementedError

    def get_info(self):
        return {"mse": self._mse.avg}


class SaveMetric(Metric):
    """The metric that stores the raw predictions."""

    def reset(self):
        self._predicts = []

    def __call__(self, predictions, targets=None):
        if isinstance(predictions, torch.Tensor):
            predicts = predictions.cpu().numpy()
            self._predicts.append(predicts)
            return predicts
        else:
            raise NotImplementedError

    def get_info(self):
        all_predicts = np.concatenate(self._predicts)
        return {"predictions": all_predicts}
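These metrics are plain accumulators: call them once per batch, then read the aggregated values from get_info(). ComposeMetric merges the info dicts of its children, so MSEMetric contributes the running "mse" average while SaveMetric contributes the concatenated "predictions". A short usage sketch with made-up data:

import torch
from procedures.metric_utils import MSEMetric, SaveMetric, ComposeMetric

metric = ComposeMetric(MSEMetric(), SaveMetric())
for _ in range(3):  # pretend these are three mini-batches
    predictions, targets = torch.randn(8, 1), torch.randn(8, 1)
    metric(predictions, targets)

info = metric.get_info()
print(info["mse"])                # average of the per-batch MSE values
print(info["predictions"].shape)  # (24, 1): all predictions, concatenated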
@@ -17,7 +17,7 @@ from .super_module import BoolSpaceType
 class SuperReLU(SuperModule):
     """Applies the rectified linear unit function element-wise."""
 
-    def __init__(self, inplace=False) -> None:
+    def __init__(self, inplace: bool = False) -> None:
         super(SuperReLU, self).__init__()
         self._inplace = inplace
 
@@ -33,3 +33,26 @@ class SuperReLU(SuperModule):
 
     def extra_repr(self) -> str:
         return "inplace=True" if self._inplace else ""
+
+
+class SuperLeakyReLU(SuperModule):
+    """https://pytorch.org/docs/stable/_modules/torch/nn/modules/activation.html#LeakyReLU"""
+
+    def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None:
+        super(SuperLeakyReLU, self).__init__()
+        self._negative_slope = negative_slope
+        self._inplace = inplace
+
+    @property
+    def abstract_search_space(self):
+        return spaces.VirtualNode(id(self))
+
+    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
+        return self.forward_raw(input)
+
+    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
+        return F.leaky_relu(input, self._negative_slope, self._inplace)
+
+    def extra_repr(self) -> str:
+        # note: a ", " separator is added so the repr does not run the two fields together
+        inplace_str = ", inplace=True" if self._inplace else ""
+        return "negative_slope={}{}".format(self._negative_slope, inplace_str)
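For the raw forward path, SuperLeakyReLU should behave exactly like torch.nn.functional.leaky_relu; a quick sanity check, assuming SuperModule routes a plain call to forward_raw when no candidate architecture is active:

import torch
from xlayers.super_core import SuperLeakyReLU

act = SuperLeakyReLU(negative_slope=0.2)
x = torch.tensor([-1.0, 0.0, 2.0])
assert torch.allclose(act(x), torch.nn.functional.leaky_relu(x, 0.2))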
@@ -15,6 +15,7 @@ from .super_attention import SuperAttention
 from .super_transformer import SuperTransformerEncoderLayer
 
 from .super_activations import SuperReLU
+from .super_activations import SuperLeakyReLU
 
 from .super_trade_stem import SuperAlphaEBDv1
 from .super_positional_embedding import SuperPositionalEncoder