diff --git a/exps/LFNA/basic.py b/exps/LFNA/basic.py index c0b2f81..354b13e 100644 --- a/exps/LFNA/basic.py +++ b/exps/LFNA/basic.py @@ -11,270 +11,109 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint -from log_utils import AverageMeter, time_string, convert_secs2time +from log_utils import time_string +from procedures.advanced_main import basic_train_fn, basic_eval_fn +from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric from datasets.synthetic_core import get_synthetic_env from models.xcore import get_model + def main(args): torch.set_num_threads(args.workers) prepare_seed(args.rand_seed) logger = prepare_logger(args) dynamic_env = get_synthetic_env() + historical_x, historical_y = None, None for idx, (timestamp, (allx, ally)) in enumerate(dynamic_env): - import pdb - pdb.set_trace() - train_data, valid_data, xshape, class_num = get_datasets( - args.dataset, args.data_path, args.cutout_length - ) - train_loader = torch.utils.data.DataLoader( - train_data, - batch_size=args.batch_size, - shuffle=True, - num_workers=args.workers, - pin_memory=True, - ) - valid_loader = torch.utils.data.DataLoader( - valid_data, - batch_size=args.batch_size, - shuffle=False, - num_workers=args.workers, - pin_memory=True, - ) - # get configures - model_config = load_config(args.model_config, {"class_num": class_num}, logger) - optim_config = load_config(args.optim_config, {"class_num": class_num}, logger) + if historical_x is not None: + mean, std = historical_x.mean().item(), historical_x.std().item() + else: + mean, std = 0, 1 + model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std) + model = get_model(dict(model_type="simple_mlp"), **model_kwargs) - if args.model_source == "normal": - base_model = obtain_model(model_config) - elif args.model_source == "nas": - base_model = obtain_nas_infer_model(model_config, args.extra_model_path) - elif args.model_source == "autodl-searched": - base_model = obtain_model(model_config, args.extra_model_path) - else: - raise ValueError("invalid model-source : {:}".format(args.model_source)) - flop, param = get_model_infos(base_model, xshape) - logger.log("model ====>>>>:\n{:}".format(base_model)) - logger.log("model information : {:}".format(base_model.get_message())) - logger.log("-" * 50) - logger.log( - "Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G".format( - param, flop, flop / 1e3 - ) - ) - logger.log("-" * 50) - logger.log("train_data : {:}".format(train_data)) - logger.log("valid_data : {:}".format(valid_data)) - optimizer, scheduler, criterion = get_optim_scheduler( - base_model.parameters(), optim_config - ) - logger.log("optimizer : {:}".format(optimizer)) - logger.log("scheduler : {:}".format(scheduler)) - logger.log("criterion : {:}".format(criterion)) - - last_info, model_base_path, model_best_path = ( - logger.path("info"), - logger.path("model"), - logger.path("best"), - ) - network, criterion = torch.nn.DataParallel(base_model).cuda(), criterion.cuda() - - if last_info.exists(): # automatically resume from previous checkpoint - logger.log( - "=> loading checkpoint of the last-info '{:}' start".format(last_info) - ) - last_infox = torch.load(last_info) - start_epoch = last_infox["epoch"] + 1 - last_checkpoint_path = last_infox["last_checkpoint"] - if not last_checkpoint_path.exists(): - logger.log( - "Does not find {:}, try another path".format(last_checkpoint_path) + # create the current data loader + if historical_x is not None: + train_dataset = torch.utils.data.TensorDataset(historical_x, historical_y) + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=args.workers, ) - last_checkpoint_path = ( - last_info.parent - / last_checkpoint_path.parent.name - / last_checkpoint_path.name + optimizer = torch.optim.Adam( + model.parameters(), lr=args.init_lr, amsgrad=True ) - checkpoint = torch.load(last_checkpoint_path) - base_model.load_state_dict(checkpoint["base-model"]) - scheduler.load_state_dict(checkpoint["scheduler"]) - optimizer.load_state_dict(checkpoint["optimizer"]) - valid_accuracies = checkpoint["valid_accuracies"] - max_bytes = checkpoint["max_bytes"] - logger.log( - "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format( - last_info, start_epoch + criterion = torch.nn.MSELoss() + lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=[ + int(args.epochs * 0.25), + int(args.epochs * 0.5), + int(args.epochs * 0.75), + ], + gamma=0.3, ) - ) - elif args.resume is not None: - assert Path(args.resume).exists(), "Can not find the resume file : {:}".format( - args.resume - ) - checkpoint = torch.load(args.resume) - start_epoch = checkpoint["epoch"] + 1 - base_model.load_state_dict(checkpoint["base-model"]) - scheduler.load_state_dict(checkpoint["scheduler"]) - optimizer.load_state_dict(checkpoint["optimizer"]) - valid_accuracies = checkpoint["valid_accuracies"] - max_bytes = checkpoint["max_bytes"] - logger.log( - "=> loading checkpoint from '{:}' start with {:}-th epoch.".format( - args.resume, start_epoch - ) - ) - elif args.init_model is not None: - assert Path( - args.init_model - ).exists(), "Can not find the initialization file : {:}".format(args.init_model) - checkpoint = torch.load(args.init_model) - base_model.load_state_dict(checkpoint["base-model"]) - start_epoch, valid_accuracies, max_bytes = 0, {"best": -1}, {} - logger.log("=> initialize the model from {:}".format(args.init_model)) - else: - logger.log("=> do not find the last-info file : {:}".format(last_info)) - start_epoch, valid_accuracies, max_bytes = 0, {"best": -1}, {} - - train_func, valid_func = get_procedures(args.procedure) - - total_epoch = optim_config.epochs + optim_config.warmup - # Main Training and Evaluation Loop - start_time = time.time() - epoch_time = AverageMeter() - for epoch in range(start_epoch, total_epoch): - scheduler.update(epoch, 0.0) - need_time = "Time Left: {:}".format( - convert_secs2time(epoch_time.avg * (total_epoch - epoch), True) - ) - epoch_str = "epoch={:03d}/{:03d}".format(epoch, total_epoch) - LRs = scheduler.get_lr() - find_best = False - # set-up drop-out ratio - if hasattr(base_model, "update_drop_path"): - base_model.update_drop_path( - model_config.drop_path_prob * epoch / total_epoch - ) - logger.log( - "\n***{:s}*** start {:s} {:s}, LR=[{:.6f} ~ {:.6f}], scheduler={:}".format( - time_string(), epoch_str, need_time, min(LRs), max(LRs), scheduler - ) - ) - - # train for one epoch - train_loss, train_acc1, train_acc5 = train_func( - train_loader, - network, - criterion, - scheduler, - optimizer, - optim_config, - epoch_str, - args.print_freq, - logger, - ) - # log the results - logger.log( - "***{:s}*** TRAIN [{:}] loss = {:.6f}, accuracy-1 = {:.2f}, accuracy-5 = {:.2f}".format( - time_string(), epoch_str, train_loss, train_acc1, train_acc5 - ) - ) - - # evaluate the performance - if (epoch % args.eval_frequency == 0) or (epoch + 1 == total_epoch): - logger.log("-" * 150) - valid_loss, valid_acc1, valid_acc5 = valid_func( - valid_loader, - network, - criterion, - optim_config, - epoch_str, - args.print_freq_eval, - logger, - ) - valid_accuracies[epoch] = valid_acc1 - logger.log( - "***{:s}*** VALID [{:}] loss = {:.6f}, accuracy@1 = {:.2f}, accuracy@5 = {:.2f} | Best-Valid-Acc@1={:.2f}, Error@1={:.2f}".format( - time_string(), - epoch_str, - valid_loss, - valid_acc1, - valid_acc5, - valid_accuracies["best"], - 100 - valid_accuracies["best"], + for _iepoch in range(args.epochs): + results = basic_train_fn( + train_loader, model, criterion, optimizer, MSEMetric(), logger ) - ) - if valid_acc1 > valid_accuracies["best"]: - valid_accuracies["best"] = valid_acc1 - find_best = True - logger.log( - "Currently, the best validation accuracy found at {:03d}-epoch :: acc@1={:.2f}, acc@5={:.2f}, error@1={:.2f}, error@5={:.2f}, save into {:}.".format( - epoch, - valid_acc1, - valid_acc5, - 100 - valid_acc1, - 100 - valid_acc5, - model_best_path, + lr_scheduler.step() + if _iepoch % args.log_per_epoch == 0: + log_str = ( + "[{:}]".format(time_string()) + + " [{:04d}/{:04d}][{:04d}/{:04d}]".format( + idx, len(dynamic_env), _iepoch, args.epochs + ) + + " mse: {:.5f}, lr: {:.4f}".format( + results["mse"], min(lr_scheduler.get_last_lr()) + ) ) - ) - num_bytes = ( - torch.cuda.max_memory_cached(next(network.parameters()).device) * 1.0 - ) + logger.log(log_str) + results = basic_eval_fn(train_loader, model, MSEMetric(), logger) logger.log( - "[GPU-Memory-Usage on {:} is {:} bytes, {:.2f} KB, {:.2f} MB, {:.2f} GB.]".format( - next(network.parameters()).device, - int(num_bytes), - num_bytes / 1e3, - num_bytes / 1e6, - num_bytes / 1e9, + "[{:}] [{:04d}/{:04d}] train-mse: {:.5f}".format( + time_string(), idx, len(dynamic_env), results["mse"] ) ) - max_bytes[epoch] = num_bytes - if epoch % 10 == 0: - torch.cuda.empty_cache() - # save checkpoint - save_path = save_checkpoint( - { - "epoch": epoch, - "args": deepcopy(args), - "max_bytes": deepcopy(max_bytes), - "FLOP": flop, - "PARAM": param, - "valid_accuracies": deepcopy(valid_accuracies), - "model-config": model_config._asdict(), - "optim-config": optim_config._asdict(), - "base-model": base_model.state_dict(), - "scheduler": scheduler.state_dict(), - "optimizer": optimizer.state_dict(), - }, - model_base_path, - logger, + metric = ComposeMetric(MSEMetric(), SaveMetric()) + eval_dataset = torch.utils.data.TensorDataset(allx, ally) + eval_loader = torch.utils.data.DataLoader( + eval_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.workers, ) - if find_best: - copy_checkpoint(model_base_path, model_best_path, logger) - last_info = save_checkpoint( - { - "epoch": epoch, - "args": deepcopy(args), - "last_checkpoint": save_path, - }, - logger.path("info"), + results = basic_eval_fn(eval_loader, model, metric, logger) + log_str = ( + "[{:}]".format(time_string()) + + " [{:04d}/{:04d}]".format(idx, len(dynamic_env)) + + " eval-mse: {:.5f}".format(results["mse"]) + ) + logger.log(log_str) + + save_path = logger.path(None) / "{:04d}-{:04d}.pth".format( + idx, len(dynamic_env) + ) + save_checkpoint( + {"model": model.state_dict(), "index": idx, "timestamp": timestamp}, + save_path, logger, ) - # measure elapsed time - epoch_time.update(time.time() - start_time) - start_time = time.time() + # Update historical data + if historical_x is None: + historical_x, historical_y = allx, ally + else: + historical_x, historical_y = torch.cat((historical_x, allx)), torch.cat( + (historical_y, ally) + ) + logger.log("") - logger.log("\n" + "-" * 200) - logger.log( - "Finish training/validation in {:} with Max-GPU-Memory of {:.2f} MB, and save final checkpoint into {:}".format( - convert_secs2time(epoch_time.sum, True), - max(v for k, v in max_bytes.items()) / 1e6, - logger.path("info"), - ) - ) logger.log("-" * 200 + "\n") logger.close() @@ -287,11 +126,35 @@ if __name__ == "__main__": default="./outputs/lfna-synthetic/use-all-past-data", help="The checkpoint directory.", ) + parser.add_argument( + "--init_lr", + type=float, + default=0.1, + help="The initial learning rate for the optimizer (default is Adam)", + ) + parser.add_argument( + "--batch_size", + type=int, + default=256, + help="The batch size", + ) + parser.add_argument( + "--epochs", + type=int, + default=2000, + help="The total number of epochs.", + ) + parser.add_argument( + "--log_per_epoch", + type=int, + default=200, + help="Log the training information per __ epochs.", + ) parser.add_argument( "--workers", type=int, - default=8, - help="number of data loading workers (default: 8)", + default=4, + help="The number of data loading workers (default: 4)", ) # Random Seed parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") diff --git a/lib/log_utils/logger.py b/lib/log_utils/logger.py index c95573f..3da517d 100644 --- a/lib/log_utils/logger.py +++ b/lib/log_utils/logger.py @@ -59,8 +59,10 @@ class Logger(object): ) def path(self, mode): - valids = ("model", "best", "info", "log") - if mode == "model": + valids = ("model", "best", "info", "log", None) + if mode is None: + return self.log_dir + elif mode == "model": return self.model_dir / "seed-{:}-basic.pth".format(self.seed) elif mode == "best": return self.model_dir / "seed-{:}-best.pth".format(self.seed) diff --git a/lib/models/xcore.py b/lib/models/xcore.py index c6a6bc6..3916222 100644 --- a/lib/models/xcore.py +++ b/lib/models/xcore.py @@ -4,12 +4,14 @@ # Use module in xlayers to construct different models # ####################################################### from typing import List, Text, Dict, Any +import torch __all__ = ["get_model"] -from xlayers.super_core import SuperSequential, SuperMLPv1 +from xlayers.super_core import SuperSequential from xlayers.super_core import SuperSimpleNorm +from xlayers.super_core import SuperLeakyReLU from xlayers.super_core import SuperLinear @@ -19,9 +21,9 @@ def get_model(config: Dict[Text, Any], **kwargs): model = SuperSequential( SuperSimpleNorm(kwargs["mean"], kwargs["std"]), SuperLinear(kwargs["input_dim"], 200), - torch.nn.LeakyReLU(), + SuperLeakyReLU(), SuperLinear(200, 100), - torch.nn.LeakyReLU(), + SuperLeakyReLU(), SuperLinear(100, kwargs["output_dim"]), ) else: diff --git a/lib/procedures/advanced_main.py b/lib/procedures/advanced_main.py index 8b6e71b..dfb32f9 100644 --- a/lib/procedures/advanced_main.py +++ b/lib/procedures/advanced_main.py @@ -12,49 +12,48 @@ from log_utils import time_string from .eval_funcs import obtain_accuracy -def basic_train( +def get_device(tensors): + if isinstance(tensors, (list, tuple)): + return get_device(tensors[0]) + elif isinstance(tensors, dict): + for key, value in tensors.items(): + return get_device(value) + else: + return tensors.device + + +def basic_train_fn( xloader, network, criterion, - scheduler, optimizer, - optim_config, - extra_info, - print_freq, + metric, logger, ): - loss, acc1, acc5 = procedure( + results = procedure( xloader, network, criterion, - scheduler, optimizer, + metric, "train", - optim_config, - extra_info, - print_freq, logger, ) - return loss, acc1, acc5 + return results -def basic_valid( - xloader, network, criterion, optim_config, extra_info, print_freq, logger -): +def basic_eval_fn(xloader, network, metric, logger): with torch.no_grad(): - loss, acc1, acc5 = procedure( + results = procedure( xloader, network, - criterion, None, None, + metric, "valid", - None, - extra_info, - print_freq, logger, ) - return loss, acc1, acc5 + return results def procedure( @@ -62,12 +61,11 @@ def procedure( network, criterion, optimizer, - eval_metric, + metric, mode: Text, - print_freq: int = 100, logger_fn: Callable = None, ): - data_time, batch_time, losses = AverageMeter(), AverageMeter(), AverageMeter() + data_time, batch_time = AverageMeter(), AverageMeter() if mode.lower() == "train": network.train() elif mode.lower() == "valid": @@ -80,49 +78,23 @@ def procedure( # measure data loading time data_time.update(time.time() - end) # calculate prediction and loss - targets = targets.cuda(non_blocking=True) if mode == "train": optimizer.zero_grad() outputs = network(inputs) - loss = criterion(outputs, targets) + targets = targets.to(get_device(outputs)) if mode == "train": + loss = criterion(outputs, targets) loss.backward() optimizer.step() # record - metrics = eval_metric(logits.data, targets.data) - prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) - losses.update(loss.item(), inputs.size(0)) - top1.update(prec1.item(), inputs.size(0)) - top5.update(prec5.item(), inputs.size(0)) + with torch.no_grad(): + results = metric(outputs, targets) # measure elapsed time batch_time.update(time.time() - end) end = time.time() - - if i % print_freq == 0 or (i + 1) == len(xloader): - Sstr = ( - " {:5s} ".format(mode.upper()) - + time_string() - + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader)) - ) - Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format( - loss=losses, top1=top1, top5=top5 - ) - Istr = "Size={:}".format(list(inputs.size())) - logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr) - - logger.log( - " **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format( - mode=mode.upper(), - top1=top1, - top5=top5, - error1=100 - top1.avg, - error5=100 - top5.avg, - loss=losses.avg, - ) - ) - return losses.avg, top1.avg, top5.avg + return metric.get_info() diff --git a/lib/procedures/eval_funcs.py b/lib/procedures/eval_funcs.py index 85b5300..99b569b 100644 --- a/lib/procedures/eval_funcs.py +++ b/lib/procedures/eval_funcs.py @@ -18,11 +18,3 @@ def obtain_accuracy(output, target, topk=(1,)): correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) res.append(correct_k.mul_(100.0 / batch_size)) return res - - -class EvaluationMetric(abc.ABC): - def __init__(self): - self._total_metrics = 0 - - def __len__(self): - return self._total_metrics diff --git a/lib/procedures/metric_utils.py b/lib/procedures/metric_utils.py new file mode 100644 index 0000000..f88c587 --- /dev/null +++ b/lib/procedures/metric_utils.py @@ -0,0 +1,134 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 # +##################################################### +import abc +import numpy as np +import torch + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0.0 + self.avg = 0.0 + self.sum = 0.0 + self.count = 0.0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __repr__(self): + return "{name}(val={val}, avg={avg}, count={count})".format( + name=self.__class__.__name__, **self.__dict__ + ) + + +class Metric(abc.ABC): + """The default meta metric class.""" + + def __init__(self): + self.reset() + + def reset(self): + raise NotImplementedError + + def __call__(self, predictions, targets): + raise NotImplementedError + + def get_info(self): + raise NotImplementedError + + def __repr__(self): + return "{name}({inner})".format( + name=self.__class__.__name__, inner=self.inner_repr() + ) + + def inner_repr(self): + return "" + + +class ComposeMetric(Metric): + """The composed metric class.""" + + def __init__(self, *metric_list): + self.reset() + for metric in metric_list: + self.append(metric) + + def reset(self): + self._metric_list = [] + + def append(self, metric): + if not isinstance(metric, Metric): + raise ValueError( + "The input metric is not correct: {:}".format(type(metric)) + ) + self._metric_list.append(metric) + + def __len__(self): + return len(self._metric_list) + + def __call__(self, predictions, targets): + results = list() + for metric in self._metric_list: + results.append(metric(predictions, targets)) + return results + + def get_info(self): + results = dict() + for metric in self._metric_list: + for key, value in metric.get_info().items(): + results[key] = value + return results + + def inner_repr(self): + xlist = [] + for metric in self._metric_list: + xlist.append(str(metric)) + return ",".join(xlist) + + +class MSEMetric(Metric): + """The metric for mse.""" + + def reset(self): + self._mse = AverageMeter() + + def __call__(self, predictions, targets): + if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor): + batch = predictions.shape[0] + loss = torch.nn.functional.mse_loss(predictions.data, targets.data) + loss = loss.item() + self._mse.update(loss, batch) + return loss + else: + raise NotImplementedError + + def get_info(self): + return {"mse": self._mse.avg} + + +class SaveMetric(Metric): + """The metric for mse.""" + + def reset(self): + self._predicts = [] + + def __call__(self, predictions, targets=None): + if isinstance(predictions, torch.Tensor): + predicts = predictions.cpu().numpy() + self._predicts.append(predicts) + return predicts + else: + raise NotImplementedError + + def get_info(self): + all_predicts = np.concatenate(self._predicts) + return {"predictions": all_predicts} diff --git a/lib/xlayers/super_activations.py b/lib/xlayers/super_activations.py index a0dac54..336dff3 100644 --- a/lib/xlayers/super_activations.py +++ b/lib/xlayers/super_activations.py @@ -17,7 +17,7 @@ from .super_module import BoolSpaceType class SuperReLU(SuperModule): """Applies a the rectified linear unit function element-wise.""" - def __init__(self, inplace=False) -> None: + def __init__(self, inplace: bool = False) -> None: super(SuperReLU, self).__init__() self._inplace = inplace @@ -33,3 +33,26 @@ class SuperReLU(SuperModule): def extra_repr(self) -> str: return "inplace=True" if self._inplace else "" + + +class SuperLeakyReLU(SuperModule): + """https://pytorch.org/docs/stable/_modules/torch/nn/modules/activation.html#LeakyReLU""" + + def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None: + super(SuperLeakyReLU, self).__init__() + self._negative_slope = negative_slope + self._inplace = inplace + + @property + def abstract_search_space(self): + return spaces.VirtualNode(id(self)) + + def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: + return self.forward_raw(input) + + def forward_raw(self, input: torch.Tensor) -> torch.Tensor: + return F.leaky_relu(input, self._negative_slope, self._inplace) + + def extra_repr(self) -> str: + inplace_str = "inplace=True" if self._inplace else "" + return "negative_slope={}{}".format(self._negative_slope, inplace_str) diff --git a/lib/xlayers/super_core.py b/lib/xlayers/super_core.py index 11d3fd2..58a0c2f 100644 --- a/lib/xlayers/super_core.py +++ b/lib/xlayers/super_core.py @@ -15,6 +15,7 @@ from .super_attention import SuperAttention from .super_transformer import SuperTransformerEncoderLayer from .super_activations import SuperReLU +from .super_activations import SuperLeakyReLU from .super_trade_stem import SuperAlphaEBDv1 from .super_positional_embedding import SuperPositionalEncoder