Add simple baseline for LFNA
This commit is contained in:
		| @@ -11,270 +11,109 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | |||||||
| if str(lib_dir) not in sys.path: | if str(lib_dir) not in sys.path: | ||||||
|     sys.path.insert(0, str(lib_dir)) |     sys.path.insert(0, str(lib_dir)) | ||||||
| from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint | from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint | ||||||
| from log_utils import AverageMeter, time_string, convert_secs2time | from log_utils import time_string | ||||||
|  |  | ||||||
|  | from procedures.advanced_main import basic_train_fn, basic_eval_fn | ||||||
|  | from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric | ||||||
| from datasets.synthetic_core import get_synthetic_env | from datasets.synthetic_core import get_synthetic_env | ||||||
| from models.xcore import get_model | from models.xcore import get_model | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(args): | def main(args): | ||||||
|     torch.set_num_threads(args.workers) |     torch.set_num_threads(args.workers) | ||||||
|     prepare_seed(args.rand_seed) |     prepare_seed(args.rand_seed) | ||||||
|     logger = prepare_logger(args) |     logger = prepare_logger(args) | ||||||
|  |  | ||||||
|     dynamic_env = get_synthetic_env() |     dynamic_env = get_synthetic_env() | ||||||
|  |     historical_x, historical_y = None, None | ||||||
|     for idx, (timestamp, (allx, ally)) in enumerate(dynamic_env): |     for idx, (timestamp, (allx, ally)) in enumerate(dynamic_env): | ||||||
|         import pdb |  | ||||||
|         pdb.set_trace() |  | ||||||
|  |  | ||||||
|     train_data, valid_data, xshape, class_num = get_datasets( |         if historical_x is not None: | ||||||
|         args.dataset, args.data_path, args.cutout_length |             mean, std = historical_x.mean().item(), historical_x.std().item() | ||||||
|     ) |         else: | ||||||
|     train_loader = torch.utils.data.DataLoader( |             mean, std = 0, 1 | ||||||
|         train_data, |         model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std) | ||||||
|         batch_size=args.batch_size, |         model = get_model(dict(model_type="simple_mlp"), **model_kwargs) | ||||||
|         shuffle=True, |  | ||||||
|         num_workers=args.workers, |  | ||||||
|         pin_memory=True, |  | ||||||
|     ) |  | ||||||
|     valid_loader = torch.utils.data.DataLoader( |  | ||||||
|         valid_data, |  | ||||||
|         batch_size=args.batch_size, |  | ||||||
|         shuffle=False, |  | ||||||
|         num_workers=args.workers, |  | ||||||
|         pin_memory=True, |  | ||||||
|     ) |  | ||||||
|     # get configures |  | ||||||
|     model_config = load_config(args.model_config, {"class_num": class_num}, logger) |  | ||||||
|     optim_config = load_config(args.optim_config, {"class_num": class_num}, logger) |  | ||||||
|  |  | ||||||
|     if args.model_source == "normal": |         # create the current data loader | ||||||
|         base_model = obtain_model(model_config) |         if historical_x is not None: | ||||||
|     elif args.model_source == "nas": |             train_dataset = torch.utils.data.TensorDataset(historical_x, historical_y) | ||||||
|         base_model = obtain_nas_infer_model(model_config, args.extra_model_path) |             train_loader = torch.utils.data.DataLoader( | ||||||
|     elif args.model_source == "autodl-searched": |                 train_dataset, | ||||||
|         base_model = obtain_model(model_config, args.extra_model_path) |                 batch_size=args.batch_size, | ||||||
|     else: |                 shuffle=True, | ||||||
|         raise ValueError("invalid model-source : {:}".format(args.model_source)) |                 num_workers=args.workers, | ||||||
|     flop, param = get_model_infos(base_model, xshape) |  | ||||||
|     logger.log("model ====>>>>:\n{:}".format(base_model)) |  | ||||||
|     logger.log("model information : {:}".format(base_model.get_message())) |  | ||||||
|     logger.log("-" * 50) |  | ||||||
|     logger.log( |  | ||||||
|         "Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G".format( |  | ||||||
|             param, flop, flop / 1e3 |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     logger.log("-" * 50) |  | ||||||
|     logger.log("train_data : {:}".format(train_data)) |  | ||||||
|     logger.log("valid_data : {:}".format(valid_data)) |  | ||||||
|     optimizer, scheduler, criterion = get_optim_scheduler( |  | ||||||
|         base_model.parameters(), optim_config |  | ||||||
|     ) |  | ||||||
|     logger.log("optimizer  : {:}".format(optimizer)) |  | ||||||
|     logger.log("scheduler  : {:}".format(scheduler)) |  | ||||||
|     logger.log("criterion  : {:}".format(criterion)) |  | ||||||
|  |  | ||||||
|     last_info, model_base_path, model_best_path = ( |  | ||||||
|         logger.path("info"), |  | ||||||
|         logger.path("model"), |  | ||||||
|         logger.path("best"), |  | ||||||
|     ) |  | ||||||
|     network, criterion = torch.nn.DataParallel(base_model).cuda(), criterion.cuda() |  | ||||||
|  |  | ||||||
|     if last_info.exists():  # automatically resume from previous checkpoint |  | ||||||
|         logger.log( |  | ||||||
|             "=> loading checkpoint of the last-info '{:}' start".format(last_info) |  | ||||||
|         ) |  | ||||||
|         last_infox = torch.load(last_info) |  | ||||||
|         start_epoch = last_infox["epoch"] + 1 |  | ||||||
|         last_checkpoint_path = last_infox["last_checkpoint"] |  | ||||||
|         if not last_checkpoint_path.exists(): |  | ||||||
|             logger.log( |  | ||||||
|                 "Does not find {:}, try another path".format(last_checkpoint_path) |  | ||||||
|             ) |             ) | ||||||
|             last_checkpoint_path = ( |             optimizer = torch.optim.Adam( | ||||||
|                 last_info.parent |                 model.parameters(), lr=args.init_lr, amsgrad=True | ||||||
|                 / last_checkpoint_path.parent.name |  | ||||||
|                 / last_checkpoint_path.name |  | ||||||
|             ) |             ) | ||||||
|         checkpoint = torch.load(last_checkpoint_path) |             criterion = torch.nn.MSELoss() | ||||||
|         base_model.load_state_dict(checkpoint["base-model"]) |             lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||||
|         scheduler.load_state_dict(checkpoint["scheduler"]) |                 optimizer, | ||||||
|         optimizer.load_state_dict(checkpoint["optimizer"]) |                 milestones=[ | ||||||
|         valid_accuracies = checkpoint["valid_accuracies"] |                     int(args.epochs * 0.25), | ||||||
|         max_bytes = checkpoint["max_bytes"] |                     int(args.epochs * 0.5), | ||||||
|         logger.log( |                     int(args.epochs * 0.75), | ||||||
|             "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format( |                 ], | ||||||
|                 last_info, start_epoch |                 gamma=0.3, | ||||||
|             ) |             ) | ||||||
|         ) |             for _iepoch in range(args.epochs): | ||||||
|     elif args.resume is not None: |                 results = basic_train_fn( | ||||||
|         assert Path(args.resume).exists(), "Can not find the resume file : {:}".format( |                     train_loader, model, criterion, optimizer, MSEMetric(), logger | ||||||
|             args.resume |  | ||||||
|         ) |  | ||||||
|         checkpoint = torch.load(args.resume) |  | ||||||
|         start_epoch = checkpoint["epoch"] + 1 |  | ||||||
|         base_model.load_state_dict(checkpoint["base-model"]) |  | ||||||
|         scheduler.load_state_dict(checkpoint["scheduler"]) |  | ||||||
|         optimizer.load_state_dict(checkpoint["optimizer"]) |  | ||||||
|         valid_accuracies = checkpoint["valid_accuracies"] |  | ||||||
|         max_bytes = checkpoint["max_bytes"] |  | ||||||
|         logger.log( |  | ||||||
|             "=> loading checkpoint from '{:}' start with {:}-th epoch.".format( |  | ||||||
|                 args.resume, start_epoch |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|     elif args.init_model is not None: |  | ||||||
|         assert Path( |  | ||||||
|             args.init_model |  | ||||||
|         ).exists(), "Can not find the initialization file : {:}".format(args.init_model) |  | ||||||
|         checkpoint = torch.load(args.init_model) |  | ||||||
|         base_model.load_state_dict(checkpoint["base-model"]) |  | ||||||
|         start_epoch, valid_accuracies, max_bytes = 0, {"best": -1}, {} |  | ||||||
|         logger.log("=> initialize the model from {:}".format(args.init_model)) |  | ||||||
|     else: |  | ||||||
|         logger.log("=> do not find the last-info file : {:}".format(last_info)) |  | ||||||
|         start_epoch, valid_accuracies, max_bytes = 0, {"best": -1}, {} |  | ||||||
|  |  | ||||||
|     train_func, valid_func = get_procedures(args.procedure) |  | ||||||
|  |  | ||||||
|     total_epoch = optim_config.epochs + optim_config.warmup |  | ||||||
|     # Main Training and Evaluation Loop |  | ||||||
|     start_time = time.time() |  | ||||||
|     epoch_time = AverageMeter() |  | ||||||
|     for epoch in range(start_epoch, total_epoch): |  | ||||||
|         scheduler.update(epoch, 0.0) |  | ||||||
|         need_time = "Time Left: {:}".format( |  | ||||||
|             convert_secs2time(epoch_time.avg * (total_epoch - epoch), True) |  | ||||||
|         ) |  | ||||||
|         epoch_str = "epoch={:03d}/{:03d}".format(epoch, total_epoch) |  | ||||||
|         LRs = scheduler.get_lr() |  | ||||||
|         find_best = False |  | ||||||
|         # set-up drop-out ratio |  | ||||||
|         if hasattr(base_model, "update_drop_path"): |  | ||||||
|             base_model.update_drop_path( |  | ||||||
|                 model_config.drop_path_prob * epoch / total_epoch |  | ||||||
|             ) |  | ||||||
|         logger.log( |  | ||||||
|             "\n***{:s}*** start {:s} {:s}, LR=[{:.6f} ~ {:.6f}], scheduler={:}".format( |  | ||||||
|                 time_string(), epoch_str, need_time, min(LRs), max(LRs), scheduler |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         # train for one epoch |  | ||||||
|         train_loss, train_acc1, train_acc5 = train_func( |  | ||||||
|             train_loader, |  | ||||||
|             network, |  | ||||||
|             criterion, |  | ||||||
|             scheduler, |  | ||||||
|             optimizer, |  | ||||||
|             optim_config, |  | ||||||
|             epoch_str, |  | ||||||
|             args.print_freq, |  | ||||||
|             logger, |  | ||||||
|         ) |  | ||||||
|         # log the results |  | ||||||
|         logger.log( |  | ||||||
|             "***{:s}*** TRAIN [{:}] loss = {:.6f}, accuracy-1 = {:.2f}, accuracy-5 = {:.2f}".format( |  | ||||||
|                 time_string(), epoch_str, train_loss, train_acc1, train_acc5 |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         # evaluate the performance |  | ||||||
|         if (epoch % args.eval_frequency == 0) or (epoch + 1 == total_epoch): |  | ||||||
|             logger.log("-" * 150) |  | ||||||
|             valid_loss, valid_acc1, valid_acc5 = valid_func( |  | ||||||
|                 valid_loader, |  | ||||||
|                 network, |  | ||||||
|                 criterion, |  | ||||||
|                 optim_config, |  | ||||||
|                 epoch_str, |  | ||||||
|                 args.print_freq_eval, |  | ||||||
|                 logger, |  | ||||||
|             ) |  | ||||||
|             valid_accuracies[epoch] = valid_acc1 |  | ||||||
|             logger.log( |  | ||||||
|                 "***{:s}*** VALID [{:}] loss = {:.6f}, accuracy@1 = {:.2f}, accuracy@5 = {:.2f} | Best-Valid-Acc@1={:.2f}, Error@1={:.2f}".format( |  | ||||||
|                     time_string(), |  | ||||||
|                     epoch_str, |  | ||||||
|                     valid_loss, |  | ||||||
|                     valid_acc1, |  | ||||||
|                     valid_acc5, |  | ||||||
|                     valid_accuracies["best"], |  | ||||||
|                     100 - valid_accuracies["best"], |  | ||||||
|                 ) |                 ) | ||||||
|             ) |                 lr_scheduler.step() | ||||||
|             if valid_acc1 > valid_accuracies["best"]: |                 if _iepoch % args.log_per_epoch == 0: | ||||||
|                 valid_accuracies["best"] = valid_acc1 |                     log_str = ( | ||||||
|                 find_best = True |                         "[{:}]".format(time_string()) | ||||||
|                 logger.log( |                         + " [{:04d}/{:04d}][{:04d}/{:04d}]".format( | ||||||
|                     "Currently, the best validation accuracy found at {:03d}-epoch :: acc@1={:.2f}, acc@5={:.2f}, error@1={:.2f}, error@5={:.2f}, save into {:}.".format( |                             idx, len(dynamic_env), _iepoch, args.epochs | ||||||
|                         epoch, |                         ) | ||||||
|                         valid_acc1, |                         + " mse: {:.5f}, lr: {:.4f}".format( | ||||||
|                         valid_acc5, |                             results["mse"], min(lr_scheduler.get_last_lr()) | ||||||
|                         100 - valid_acc1, |                         ) | ||||||
|                         100 - valid_acc5, |  | ||||||
|                         model_best_path, |  | ||||||
|                     ) |                     ) | ||||||
|                 ) |                     logger.log(log_str) | ||||||
|             num_bytes = ( |             results = basic_eval_fn(train_loader, model, MSEMetric(), logger) | ||||||
|                 torch.cuda.max_memory_cached(next(network.parameters()).device) * 1.0 |  | ||||||
|             ) |  | ||||||
|             logger.log( |             logger.log( | ||||||
|                 "[GPU-Memory-Usage on {:} is {:} bytes, {:.2f} KB, {:.2f} MB, {:.2f} GB.]".format( |                 "[{:}] [{:04d}/{:04d}] train-mse: {:.5f}".format( | ||||||
|                     next(network.parameters()).device, |                     time_string(), idx, len(dynamic_env), results["mse"] | ||||||
|                     int(num_bytes), |  | ||||||
|                     num_bytes / 1e3, |  | ||||||
|                     num_bytes / 1e6, |  | ||||||
|                     num_bytes / 1e9, |  | ||||||
|                 ) |                 ) | ||||||
|             ) |             ) | ||||||
|             max_bytes[epoch] = num_bytes |  | ||||||
|         if epoch % 10 == 0: |  | ||||||
|             torch.cuda.empty_cache() |  | ||||||
|  |  | ||||||
|         # save checkpoint |         metric = ComposeMetric(MSEMetric(), SaveMetric()) | ||||||
|         save_path = save_checkpoint( |         eval_dataset = torch.utils.data.TensorDataset(allx, ally) | ||||||
|             { |         eval_loader = torch.utils.data.DataLoader( | ||||||
|                 "epoch": epoch, |             eval_dataset, | ||||||
|                 "args": deepcopy(args), |             batch_size=args.batch_size, | ||||||
|                 "max_bytes": deepcopy(max_bytes), |             shuffle=False, | ||||||
|                 "FLOP": flop, |             num_workers=args.workers, | ||||||
|                 "PARAM": param, |  | ||||||
|                 "valid_accuracies": deepcopy(valid_accuracies), |  | ||||||
|                 "model-config": model_config._asdict(), |  | ||||||
|                 "optim-config": optim_config._asdict(), |  | ||||||
|                 "base-model": base_model.state_dict(), |  | ||||||
|                 "scheduler": scheduler.state_dict(), |  | ||||||
|                 "optimizer": optimizer.state_dict(), |  | ||||||
|             }, |  | ||||||
|             model_base_path, |  | ||||||
|             logger, |  | ||||||
|         ) |         ) | ||||||
|         if find_best: |         results = basic_eval_fn(eval_loader, model, metric, logger) | ||||||
|             copy_checkpoint(model_base_path, model_best_path, logger) |         log_str = ( | ||||||
|         last_info = save_checkpoint( |             "[{:}]".format(time_string()) | ||||||
|             { |             + " [{:04d}/{:04d}]".format(idx, len(dynamic_env)) | ||||||
|                 "epoch": epoch, |             + " eval-mse: {:.5f}".format(results["mse"]) | ||||||
|                 "args": deepcopy(args), |         ) | ||||||
|                 "last_checkpoint": save_path, |         logger.log(log_str) | ||||||
|             }, |  | ||||||
|             logger.path("info"), |         save_path = logger.path(None) / "{:04d}-{:04d}.pth".format( | ||||||
|  |             idx, len(dynamic_env) | ||||||
|  |         ) | ||||||
|  |         save_checkpoint( | ||||||
|  |             {"model": model.state_dict(), "index": idx, "timestamp": timestamp}, | ||||||
|  |             save_path, | ||||||
|             logger, |             logger, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         # measure elapsed time |         # Update historical data | ||||||
|         epoch_time.update(time.time() - start_time) |         if historical_x is None: | ||||||
|         start_time = time.time() |             historical_x, historical_y = allx, ally | ||||||
|  |         else: | ||||||
|  |             historical_x, historical_y = torch.cat((historical_x, allx)), torch.cat( | ||||||
|  |                 (historical_y, ally) | ||||||
|  |             ) | ||||||
|  |         logger.log("") | ||||||
|  |  | ||||||
|     logger.log("\n" + "-" * 200) |  | ||||||
|     logger.log( |  | ||||||
|         "Finish training/validation in {:} with Max-GPU-Memory of {:.2f} MB, and save final checkpoint into {:}".format( |  | ||||||
|             convert_secs2time(epoch_time.sum, True), |  | ||||||
|             max(v for k, v in max_bytes.items()) / 1e6, |  | ||||||
|             logger.path("info"), |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     logger.log("-" * 200 + "\n") |     logger.log("-" * 200 + "\n") | ||||||
|     logger.close() |     logger.close() | ||||||
|  |  | ||||||
| @@ -287,11 +126,35 @@ if __name__ == "__main__": | |||||||
|         default="./outputs/lfna-synthetic/use-all-past-data", |         default="./outputs/lfna-synthetic/use-all-past-data", | ||||||
|         help="The checkpoint directory.", |         help="The checkpoint directory.", | ||||||
|     ) |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--init_lr", | ||||||
|  |         type=float, | ||||||
|  |         default=0.1, | ||||||
|  |         help="The initial learning rate for the optimizer (default is Adam)", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--batch_size", | ||||||
|  |         type=int, | ||||||
|  |         default=256, | ||||||
|  |         help="The batch size", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--epochs", | ||||||
|  |         type=int, | ||||||
|  |         default=2000, | ||||||
|  |         help="The total number of epochs.", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--log_per_epoch", | ||||||
|  |         type=int, | ||||||
|  |         default=200, | ||||||
|  |         help="Log the training information per __ epochs.", | ||||||
|  |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--workers", |         "--workers", | ||||||
|         type=int, |         type=int, | ||||||
|         default=8, |         default=4, | ||||||
|         help="number of data loading workers (default: 8)", |         help="The number of data loading workers (default: 4)", | ||||||
|     ) |     ) | ||||||
|     # Random Seed |     # Random Seed | ||||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") |     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||||
|   | |||||||
| @@ -59,8 +59,10 @@ class Logger(object): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def path(self, mode): |     def path(self, mode): | ||||||
|         valids = ("model", "best", "info", "log") |         valids = ("model", "best", "info", "log", None) | ||||||
|         if mode == "model": |         if mode is None: | ||||||
|  |             return self.log_dir | ||||||
|  |         elif mode == "model": | ||||||
|             return self.model_dir / "seed-{:}-basic.pth".format(self.seed) |             return self.model_dir / "seed-{:}-basic.pth".format(self.seed) | ||||||
|         elif mode == "best": |         elif mode == "best": | ||||||
|             return self.model_dir / "seed-{:}-best.pth".format(self.seed) |             return self.model_dir / "seed-{:}-best.pth".format(self.seed) | ||||||
|   | |||||||
| @@ -4,12 +4,14 @@ | |||||||
| # Use module in xlayers to construct different models # | # Use module in xlayers to construct different models # | ||||||
| ####################################################### | ####################################################### | ||||||
| from typing import List, Text, Dict, Any | from typing import List, Text, Dict, Any | ||||||
|  | import torch | ||||||
|  |  | ||||||
| __all__ = ["get_model"] | __all__ = ["get_model"] | ||||||
|  |  | ||||||
|  |  | ||||||
| from xlayers.super_core import SuperSequential, SuperMLPv1 | from xlayers.super_core import SuperSequential | ||||||
| from xlayers.super_core import SuperSimpleNorm | from xlayers.super_core import SuperSimpleNorm | ||||||
|  | from xlayers.super_core import SuperLeakyReLU | ||||||
| from xlayers.super_core import SuperLinear | from xlayers.super_core import SuperLinear | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -19,9 +21,9 @@ def get_model(config: Dict[Text, Any], **kwargs): | |||||||
|         model = SuperSequential( |         model = SuperSequential( | ||||||
|             SuperSimpleNorm(kwargs["mean"], kwargs["std"]), |             SuperSimpleNorm(kwargs["mean"], kwargs["std"]), | ||||||
|             SuperLinear(kwargs["input_dim"], 200), |             SuperLinear(kwargs["input_dim"], 200), | ||||||
|             torch.nn.LeakyReLU(), |             SuperLeakyReLU(), | ||||||
|             SuperLinear(200, 100), |             SuperLinear(200, 100), | ||||||
|             torch.nn.LeakyReLU(), |             SuperLeakyReLU(), | ||||||
|             SuperLinear(100, kwargs["output_dim"]), |             SuperLinear(100, kwargs["output_dim"]), | ||||||
|         ) |         ) | ||||||
|     else: |     else: | ||||||
|   | |||||||
| @@ -12,49 +12,48 @@ from log_utils import time_string | |||||||
| from .eval_funcs import obtain_accuracy | from .eval_funcs import obtain_accuracy | ||||||
|  |  | ||||||
|  |  | ||||||
| def basic_train( | def get_device(tensors): | ||||||
|  |     if isinstance(tensors, (list, tuple)): | ||||||
|  |         return get_device(tensors[0]) | ||||||
|  |     elif isinstance(tensors, dict): | ||||||
|  |         for key, value in tensors.items(): | ||||||
|  |             return get_device(value) | ||||||
|  |     else: | ||||||
|  |         return tensors.device | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def basic_train_fn( | ||||||
|     xloader, |     xloader, | ||||||
|     network, |     network, | ||||||
|     criterion, |     criterion, | ||||||
|     scheduler, |  | ||||||
|     optimizer, |     optimizer, | ||||||
|     optim_config, |     metric, | ||||||
|     extra_info, |  | ||||||
|     print_freq, |  | ||||||
|     logger, |     logger, | ||||||
| ): | ): | ||||||
|     loss, acc1, acc5 = procedure( |     results = procedure( | ||||||
|         xloader, |         xloader, | ||||||
|         network, |         network, | ||||||
|         criterion, |         criterion, | ||||||
|         scheduler, |  | ||||||
|         optimizer, |         optimizer, | ||||||
|  |         metric, | ||||||
|         "train", |         "train", | ||||||
|         optim_config, |  | ||||||
|         extra_info, |  | ||||||
|         print_freq, |  | ||||||
|         logger, |         logger, | ||||||
|     ) |     ) | ||||||
|     return loss, acc1, acc5 |     return results | ||||||
|  |  | ||||||
|  |  | ||||||
| def basic_valid( | def basic_eval_fn(xloader, network, metric, logger): | ||||||
|     xloader, network, criterion, optim_config, extra_info, print_freq, logger |  | ||||||
| ): |  | ||||||
|     with torch.no_grad(): |     with torch.no_grad(): | ||||||
|         loss, acc1, acc5 = procedure( |         results = procedure( | ||||||
|             xloader, |             xloader, | ||||||
|             network, |             network, | ||||||
|             criterion, |  | ||||||
|             None, |             None, | ||||||
|             None, |             None, | ||||||
|  |             metric, | ||||||
|             "valid", |             "valid", | ||||||
|             None, |  | ||||||
|             extra_info, |  | ||||||
|             print_freq, |  | ||||||
|             logger, |             logger, | ||||||
|         ) |         ) | ||||||
|     return loss, acc1, acc5 |     return results | ||||||
|  |  | ||||||
|  |  | ||||||
| def procedure( | def procedure( | ||||||
| @@ -62,12 +61,11 @@ def procedure( | |||||||
|     network, |     network, | ||||||
|     criterion, |     criterion, | ||||||
|     optimizer, |     optimizer, | ||||||
|     eval_metric, |     metric, | ||||||
|     mode: Text, |     mode: Text, | ||||||
|     print_freq: int = 100, |  | ||||||
|     logger_fn: Callable = None, |     logger_fn: Callable = None, | ||||||
| ): | ): | ||||||
|     data_time, batch_time, losses = AverageMeter(), AverageMeter(), AverageMeter() |     data_time, batch_time = AverageMeter(), AverageMeter() | ||||||
|     if mode.lower() == "train": |     if mode.lower() == "train": | ||||||
|         network.train() |         network.train() | ||||||
|     elif mode.lower() == "valid": |     elif mode.lower() == "valid": | ||||||
| @@ -80,49 +78,23 @@ def procedure( | |||||||
|         # measure data loading time |         # measure data loading time | ||||||
|         data_time.update(time.time() - end) |         data_time.update(time.time() - end) | ||||||
|         # calculate prediction and loss |         # calculate prediction and loss | ||||||
|         targets = targets.cuda(non_blocking=True) |  | ||||||
|  |  | ||||||
|         if mode == "train": |         if mode == "train": | ||||||
|             optimizer.zero_grad() |             optimizer.zero_grad() | ||||||
|  |  | ||||||
|         outputs = network(inputs) |         outputs = network(inputs) | ||||||
|         loss = criterion(outputs, targets) |         targets = targets.to(get_device(outputs)) | ||||||
|  |  | ||||||
|         if mode == "train": |         if mode == "train": | ||||||
|  |             loss = criterion(outputs, targets) | ||||||
|             loss.backward() |             loss.backward() | ||||||
|             optimizer.step() |             optimizer.step() | ||||||
|  |  | ||||||
|         # record |         # record | ||||||
|         metrics = eval_metric(logits.data, targets.data) |         with torch.no_grad(): | ||||||
|         prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) |             results = metric(outputs, targets) | ||||||
|         losses.update(loss.item(), inputs.size(0)) |  | ||||||
|         top1.update(prec1.item(), inputs.size(0)) |  | ||||||
|         top5.update(prec5.item(), inputs.size(0)) |  | ||||||
|  |  | ||||||
|         # measure elapsed time |         # measure elapsed time | ||||||
|         batch_time.update(time.time() - end) |         batch_time.update(time.time() - end) | ||||||
|         end = time.time() |         end = time.time() | ||||||
|  |     return metric.get_info() | ||||||
|         if i % print_freq == 0 or (i + 1) == len(xloader): |  | ||||||
|             Sstr = ( |  | ||||||
|                 " {:5s} ".format(mode.upper()) |  | ||||||
|                 + time_string() |  | ||||||
|                 + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader)) |  | ||||||
|             ) |  | ||||||
|             Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format( |  | ||||||
|                 loss=losses, top1=top1, top5=top5 |  | ||||||
|             ) |  | ||||||
|             Istr = "Size={:}".format(list(inputs.size())) |  | ||||||
|             logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr) |  | ||||||
|  |  | ||||||
|     logger.log( |  | ||||||
|         " **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format( |  | ||||||
|             mode=mode.upper(), |  | ||||||
|             top1=top1, |  | ||||||
|             top5=top5, |  | ||||||
|             error1=100 - top1.avg, |  | ||||||
|             error5=100 - top5.avg, |  | ||||||
|             loss=losses.avg, |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     return losses.avg, top1.avg, top5.avg |  | ||||||
|   | |||||||
| @@ -18,11 +18,3 @@ def obtain_accuracy(output, target, topk=(1,)): | |||||||
|         correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) |         correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) | ||||||
|         res.append(correct_k.mul_(100.0 / batch_size)) |         res.append(correct_k.mul_(100.0 / batch_size)) | ||||||
|     return res |     return res | ||||||
|  |  | ||||||
|  |  | ||||||
| class EvaluationMetric(abc.ABC): |  | ||||||
|     def __init__(self): |  | ||||||
|         self._total_metrics = 0 |  | ||||||
|  |  | ||||||
|     def __len__(self): |  | ||||||
|         return self._total_metrics |  | ||||||
|   | |||||||
							
								
								
									
										134
									
								
								lib/procedures/metric_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										134
									
								
								lib/procedures/metric_utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,134 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 # | ||||||
|  | ##################################################### | ||||||
|  | import abc | ||||||
|  | import numpy as np | ||||||
|  | import torch | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class AverageMeter(object): | ||||||
|  |     """Computes and stores the average and current value""" | ||||||
|  |  | ||||||
|  |     def __init__(self): | ||||||
|  |         self.reset() | ||||||
|  |  | ||||||
|  |     def reset(self): | ||||||
|  |         self.val = 0.0 | ||||||
|  |         self.avg = 0.0 | ||||||
|  |         self.sum = 0.0 | ||||||
|  |         self.count = 0.0 | ||||||
|  |  | ||||||
|  |     def update(self, val, n=1): | ||||||
|  |         self.val = val | ||||||
|  |         self.sum += val * n | ||||||
|  |         self.count += n | ||||||
|  |         self.avg = self.sum / self.count | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name}(val={val}, avg={avg}, count={count})".format( | ||||||
|  |             name=self.__class__.__name__, **self.__dict__ | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Metric(abc.ABC): | ||||||
|  |     """The default meta metric class.""" | ||||||
|  |  | ||||||
|  |     def __init__(self): | ||||||
|  |         self.reset() | ||||||
|  |  | ||||||
|  |     def reset(self): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def __call__(self, predictions, targets): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def get_info(self): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name}({inner})".format( | ||||||
|  |             name=self.__class__.__name__, inner=self.inner_repr() | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def inner_repr(self): | ||||||
|  |         return "" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ComposeMetric(Metric): | ||||||
|  |     """The composed metric class.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, *metric_list): | ||||||
|  |         self.reset() | ||||||
|  |         for metric in metric_list: | ||||||
|  |             self.append(metric) | ||||||
|  |  | ||||||
|  |     def reset(self): | ||||||
|  |         self._metric_list = [] | ||||||
|  |  | ||||||
|  |     def append(self, metric): | ||||||
|  |         if not isinstance(metric, Metric): | ||||||
|  |             raise ValueError( | ||||||
|  |                 "The input metric is not correct: {:}".format(type(metric)) | ||||||
|  |             ) | ||||||
|  |         self._metric_list.append(metric) | ||||||
|  |  | ||||||
|  |     def __len__(self): | ||||||
|  |         return len(self._metric_list) | ||||||
|  |  | ||||||
|  |     def __call__(self, predictions, targets): | ||||||
|  |         results = list() | ||||||
|  |         for metric in self._metric_list: | ||||||
|  |             results.append(metric(predictions, targets)) | ||||||
|  |         return results | ||||||
|  |  | ||||||
|  |     def get_info(self): | ||||||
|  |         results = dict() | ||||||
|  |         for metric in self._metric_list: | ||||||
|  |             for key, value in metric.get_info().items(): | ||||||
|  |                 results[key] = value | ||||||
|  |         return results | ||||||
|  |  | ||||||
|  |     def inner_repr(self): | ||||||
|  |         xlist = [] | ||||||
|  |         for metric in self._metric_list: | ||||||
|  |             xlist.append(str(metric)) | ||||||
|  |         return ",".join(xlist) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class MSEMetric(Metric): | ||||||
|  |     """The metric for mse.""" | ||||||
|  |  | ||||||
|  |     def reset(self): | ||||||
|  |         self._mse = AverageMeter() | ||||||
|  |  | ||||||
|  |     def __call__(self, predictions, targets): | ||||||
|  |         if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor): | ||||||
|  |             batch = predictions.shape[0] | ||||||
|  |             loss = torch.nn.functional.mse_loss(predictions.data, targets.data) | ||||||
|  |             loss = loss.item() | ||||||
|  |             self._mse.update(loss, batch) | ||||||
|  |             return loss | ||||||
|  |         else: | ||||||
|  |             raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def get_info(self): | ||||||
|  |         return {"mse": self._mse.avg} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SaveMetric(Metric): | ||||||
|  |     """The metric for mse.""" | ||||||
|  |  | ||||||
|  |     def reset(self): | ||||||
|  |         self._predicts = [] | ||||||
|  |  | ||||||
|  |     def __call__(self, predictions, targets=None): | ||||||
|  |         if isinstance(predictions, torch.Tensor): | ||||||
|  |             predicts = predictions.cpu().numpy() | ||||||
|  |             self._predicts.append(predicts) | ||||||
|  |             return predicts | ||||||
|  |         else: | ||||||
|  |             raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def get_info(self): | ||||||
|  |         all_predicts = np.concatenate(self._predicts) | ||||||
|  |         return {"predictions": all_predicts} | ||||||
| @@ -17,7 +17,7 @@ from .super_module import BoolSpaceType | |||||||
| class SuperReLU(SuperModule): | class SuperReLU(SuperModule): | ||||||
|     """Applies a the rectified linear unit function element-wise.""" |     """Applies a the rectified linear unit function element-wise.""" | ||||||
|  |  | ||||||
|     def __init__(self, inplace=False) -> None: |     def __init__(self, inplace: bool = False) -> None: | ||||||
|         super(SuperReLU, self).__init__() |         super(SuperReLU, self).__init__() | ||||||
|         self._inplace = inplace |         self._inplace = inplace | ||||||
|  |  | ||||||
| @@ -33,3 +33,26 @@ class SuperReLU(SuperModule): | |||||||
|  |  | ||||||
|     def extra_repr(self) -> str: |     def extra_repr(self) -> str: | ||||||
|         return "inplace=True" if self._inplace else "" |         return "inplace=True" if self._inplace else "" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SuperLeakyReLU(SuperModule): | ||||||
|  |     """https://pytorch.org/docs/stable/_modules/torch/nn/modules/activation.html#LeakyReLU""" | ||||||
|  |  | ||||||
|  |     def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None: | ||||||
|  |         super(SuperLeakyReLU, self).__init__() | ||||||
|  |         self._negative_slope = negative_slope | ||||||
|  |         self._inplace = inplace | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def abstract_search_space(self): | ||||||
|  |         return spaces.VirtualNode(id(self)) | ||||||
|  |  | ||||||
|  |     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|  |         return self.forward_raw(input) | ||||||
|  |  | ||||||
|  |     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|  |         return F.leaky_relu(input, self._negative_slope, self._inplace) | ||||||
|  |  | ||||||
|  |     def extra_repr(self) -> str: | ||||||
|  |         inplace_str = "inplace=True" if self._inplace else "" | ||||||
|  |         return "negative_slope={}{}".format(self._negative_slope, inplace_str) | ||||||
|   | |||||||
| @@ -15,6 +15,7 @@ from .super_attention import SuperAttention | |||||||
| from .super_transformer import SuperTransformerEncoderLayer | from .super_transformer import SuperTransformerEncoderLayer | ||||||
|  |  | ||||||
| from .super_activations import SuperReLU | from .super_activations import SuperReLU | ||||||
|  | from .super_activations import SuperLeakyReLU | ||||||
|  |  | ||||||
| from .super_trade_stem import SuperAlphaEBDv1 | from .super_trade_stem import SuperAlphaEBDv1 | ||||||
| from .super_positional_embedding import SuperPositionalEncoder | from .super_positional_embedding import SuperPositionalEncoder | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user