Update codes
This commit is contained in:
		
							
								
								
									
										206
									
								
								exps/GeMOSA/basic-his.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										206
									
								
								exps/GeMOSA/basic-his.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,206 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| # python exps/LFNA/basic-his.py --srange 1-999 --env_version v1 --hidden_dim 16 | ||||
| ##################################################### | ||||
| import sys, time, copy, torch, random, argparse | ||||
| from tqdm import tqdm | ||||
| from copy import deepcopy | ||||
|  | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
| ) | ||||
| from xautodl.log_utils import time_string | ||||
| from xautodl.log_utils import AverageMeter, convert_secs2time | ||||
|  | ||||
| from xautodl.utils import split_str2indexes | ||||
|  | ||||
| from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn | ||||
| from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric | ||||
| from xautodl.datasets.synthetic_core import get_synthetic_env | ||||
| from xautodl.models.xcore import get_model | ||||
|  | ||||
|  | ||||
| from lfna_utils import lfna_setup | ||||
|  | ||||
|  | ||||
def subsample(historical_x, historical_y, maxn=10000):
    """Cap a paired (x, y) history at *maxn* rows.

    If the history already fits within *maxn*, both tensors are returned
    untouched; otherwise *maxn* row indices are drawn uniformly at random
    (with replacement) and used to index both tensors.
    """
    num_rows = historical_x.size(0)
    if num_rows <= maxn:
        return historical_x, historical_y
    picked = torch.randint(low=0, high=num_rows, size=[maxn])
    return historical_x[picked], historical_y[picked]
|  | ||||
|  | ||||
def main(args):
    """Train one MLP per requested timestamp on *all* past data, then evaluate.

    For each timestamp index in ``args.srange``, a model is trained from
    scratch on the concatenation of every earlier (x, y) pair (subsampled to
    a cap), evaluated on the current timestamp's data, and checkpointed.

    Args:
        args: parsed CLI namespace (srange, init_lr, epochs, batch_size, ...).
    """
    logger, env_info, model_kwargs = lfna_setup(args)

    # Resolve which timestamp indexes to evaluate from the --srange string.
    to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None)
    logger.log(
        "Evaluate {:}, which has {:} timestamps in total.".format(
            args.srange, len(to_evaluate_indexes)
        )
    )

    w_container_per_epoch = dict()

    per_timestamp_time, start_time = AverageMeter(), time.time()
    for i, idx in enumerate(to_evaluate_indexes):

        need_time = "Time Left: {:}".format(
            convert_secs2time(
                per_timestamp_time.avg * (len(to_evaluate_indexes) - i), True
            )
        )
        logger.log(
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}][{:04d}]".format(i, len(to_evaluate_indexes), idx)
            + " "
            + need_time
        )
        # Gather all data strictly before this timestamp; idx == 0 has no history.
        assert idx != 0
        historical_x, historical_y = [], []
        for past_i in range(idx):
            historical_x.append(env_info["{:}-x".format(past_i)])
            historical_y.append(env_info["{:}-y".format(past_i)])
        historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y)
        historical_x, historical_y = subsample(historical_x, historical_y)
        # Build a fresh model for this timestamp.
        model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
        # Build optimizer, loss, and a 3-step decay schedule over the epochs.
        optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
        criterion = torch.nn.MSELoss()
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[
                int(args.epochs * 0.25),
                int(args.epochs * 0.5),
                int(args.epochs * 0.75),
            ],
            gamma=0.3,
        )
        train_metric = MSEMetric()
        best_loss, best_param = None, None
        for _iepoch in range(args.epochs):
            preds = model(historical_x)
            optimizer.zero_grad()
            loss = criterion(preds, historical_y)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            # Track the best weights seen across epochs.
            if best_loss is None or best_loss > loss.item():
                best_loss = loss.item()
                best_param = copy.deepcopy(model.state_dict())
        if best_param is not None:  # guard: args.epochs == 0 leaves no snapshot
            model.load_state_dict(best_param)
        with torch.no_grad():
            # BUGFIX: recompute predictions with the restored best weights.
            # Previously the metric consumed `preds` from the *last* training
            # epoch, which does not match the just-loaded best model.
            preds = model(historical_x)
            train_metric(preds, historical_y)
        train_results = train_metric.get_info()

        metric = ComposeMetric(MSEMetric(), SaveMetric())
        eval_dataset = torch.utils.data.TensorDataset(
            env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)]
        )
        eval_loader = torch.utils.data.DataLoader(
            eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
        )
        results = basic_eval_fn(eval_loader, model, metric, logger)
        log_str = (
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, env_info["total"])
            + " train-mse: {:.5f}, eval-mse: {:.5f}".format(
                train_results["mse"], results["mse"]
            )
        )
        logger.log(log_str)

        save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
            idx, env_info["total"]
        )
        w_container_per_epoch[idx] = model.get_w_container().no_grad_clone()
        save_checkpoint(
            {
                "model_state_dict": model.state_dict(),
                "model": model,
                "index": idx,
                "timestamp": env_info["{:}-timestamp".format(idx)],
            },
            save_path,
            logger,
        )
        logger.log("")
        per_timestamp_time.update(time.time() - start_time)
        start_time = time.time()

    save_checkpoint(
        {"w_container_per_epoch": w_container_per_epoch},
        logger.path(None) / "final-ckp.pth",
        logger,
    )
    logger.log("-" * 200 + "\n")
    logger.close()
|  | ||||
|  | ||||
if __name__ == "__main__":
    # CLI for the "train on all past data" baseline.
    parser = argparse.ArgumentParser("Use all the past data to train.")
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./outputs/lfna-synthetic/use-all-past-data",
        help="The checkpoint directory.",
    )
    parser.add_argument(
        "--env_version",
        type=str,
        required=True,
        # NOTE(review): "enviornment" typo in user-facing help — fix separately.
        help="The synthetic enviornment version.",
    )
    parser.add_argument(
        "--hidden_dim",
        type=int,
        required=True,
        help="The hidden dimension.",
    )
    parser.add_argument(
        "--init_lr",
        type=float,
        default=0.1,
        help="The initial learning rate for the optimizer (default is Adam)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=512,
        help="The batch size",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=1000,
        help="The total number of epochs.",
    )
    parser.add_argument(
        "--srange", type=str, required=True, help="The range of models to be evaluated"
    )
    # NOTE(review): --workers is accepted but unused — main() hard-codes
    # num_workers=0 in its DataLoader.
    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        help="The number of data loading workers (default: 4)",
    )
    # Random Seed
    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
    args = parser.parse_args()
    # A negative/absent seed means "pick one at random".
    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "The save dir argument can not be None"
    # Make the output directory unique per environment version and model width.
    args.save_dir = "{:}-{:}-d{:}".format(
        args.save_dir, args.env_version, args.hidden_dim
    )
    main(args)
							
								
								
									
										271
									
								
								exps/GeMOSA/basic-maml.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										271
									
								
								exps/GeMOSA/basic-maml.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,271 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| # python exps/LFNA/basic-maml.py --env_version v1 --inner_step 5 | ||||
| # python exps/LFNA/basic-maml.py --env_version v2 | ||||
| ##################################################### | ||||
| import sys, time, copy, torch, random, argparse | ||||
| from tqdm import tqdm | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
| from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint | ||||
| from log_utils import time_string | ||||
| from log_utils import AverageMeter, convert_secs2time | ||||
|  | ||||
| from utils import split_str2indexes | ||||
|  | ||||
| from procedures.advanced_main import basic_train_fn, basic_eval_fn | ||||
| from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric | ||||
| from datasets.synthetic_core import get_synthetic_env, EnvSampler | ||||
| from models.xcore import get_model | ||||
| from xlayers import super_core | ||||
|  | ||||
| from lfna_utils import lfna_setup, TimeData | ||||
|  | ||||
|  | ||||
class MAML:
    """MAML-style meta-learner wrapping a container-based network.

    Holds the meta-optimizer / LR schedule and performs inner-loop
    adaptation through the network's weight container.
    """

    def __init__(
        self, network, criterion, epochs, meta_lr, inner_lr=0.01, inner_step=1
    ):
        self.criterion = criterion
        self.network = network
        self.meta_optimizer = torch.optim.Adam(
            self.network.parameters(), lr=meta_lr, amsgrad=True
        )
        # Decay the meta learning rate 10x at 80% and 90% of training.
        decay_points = [int(epochs * 0.8), int(epochs * 0.9)]
        self.meta_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.meta_optimizer, milestones=decay_points, gamma=0.1
        )
        self.inner_lr = inner_lr
        self.inner_step = inner_step
        self._best_info = dict(state_dict=None, iepoch=None, score=None)
        print("There are {:} weights.".format(self.network.get_w_container().numel()))

    def adapt(self, dataset):
        """Run `inner_step` SGD updates on *dataset*; return the adapted container."""
        container = self.network.get_w_container()
        for _ in range(self.inner_step):
            prediction = self.network.forward_with_container(dataset.x, container)
            inner_loss = self.criterion(prediction, dataset.y)
            gradients = torch.autograd.grad(inner_loss, container.parameters())
            deltas = [-self.inner_lr * grad for grad in gradients]
            container = container.additive(deltas)
        return container

    def predict(self, x, container=None):
        """Forward *x* through the network, optionally via an adapted *container*."""
        if container is None:
            return self.network(x)
        return self.network.forward_with_container(x, container)

    def step(self):
        """Clip gradients, apply one meta-update, and advance the LR schedule."""
        torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0)
        self.meta_optimizer.step()
        self.meta_lr_scheduler.step()

    def zero_grad(self):
        """Reset the meta-optimizer's accumulated gradients."""
        self.meta_optimizer.zero_grad()

    def load_state_dict(self, state_dict):
        """Restore criterion, network, optimizer, and scheduler states."""
        self.criterion.load_state_dict(state_dict["criterion"])
        self.network.load_state_dict(state_dict["network"])
        self.meta_optimizer.load_state_dict(state_dict["meta_optimizer"])
        self.meta_lr_scheduler.load_state_dict(state_dict["meta_lr_scheduler"])

    def state_dict(self):
        """Snapshot criterion, network, optimizer, and scheduler states."""
        return dict(
            criterion=self.criterion.state_dict(),
            network=self.network.state_dict(),
            meta_optimizer=self.meta_optimizer.state_dict(),
            meta_lr_scheduler=self.meta_lr_scheduler.state_dict(),
        )

    def save_best(self, score):
        """Delegate best-score tracking to the wrapped network."""
        success, best_score = self.network.save_best(score)
        return success, best_score

    def load_best(self):
        """Restore the network's best recorded weights."""
        self.network.load_best()
|  | ||||
|  | ||||
def main(args):
    """Meta-train a MAML model on a synthetic dynamic environment, then meta-test.

    Meta-training samples (past, future) timestamp pairs; the model is adapted
    on the past data and its loss on the future data drives the meta-update.
    Meta-testing repeats adaptation over the held-out environment and saves the
    adapted weight containers.
    """
    logger, env_info, model_kwargs = lfna_setup(args)
    model = get_model(**model_kwargs)

    dynamic_env = get_synthetic_env(mode="train", version=args.env_version)

    criterion = torch.nn.MSELoss()

    maml = MAML(
        model, criterion, args.epochs, args.meta_lr, args.inner_lr, args.inner_step
    )

    # meta-training
    last_success_epoch = 0
    per_epoch_time, start_time = AverageMeter(), time.time()
    for iepoch in range(args.epochs):
        need_time = "Time Left: {:}".format(
            convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
        )
        head_str = (
            "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs)
            + need_time
        )

        maml.zero_grad()
        meta_losses = []
        for ibatch in range(args.meta_batch):
            # Sample a future timestamp and the timestamp `prev_time` steps earlier.
            future_timestamp = dynamic_env.random_timestamp()
            _, (future_x, future_y) = dynamic_env(future_timestamp)
            past_timestamp = (
                future_timestamp - args.prev_time * dynamic_env.timestamp_interval
            )
            _, (past_x, past_y) = dynamic_env(past_timestamp)

            # Inner loop: adapt on past data; outer loss: predict the future data.
            future_container = maml.adapt(TimeData(past_timestamp, past_x, past_y))
            future_y_hat = maml.predict(future_x, future_container)
            future_loss = maml.criterion(future_y_hat, future_y)
            meta_losses.append(future_loss)
        meta_loss = torch.stack(meta_losses).mean()
        meta_loss.backward()
        maml.step()

        logger.log(head_str + " meta-loss: {:.4f}".format(meta_loss.item()))
        # Score is negated loss: higher is better for save_best.
        success, best_score = maml.save_best(-meta_loss.item())
        if success:
            logger.log("Achieve the best with best_score = {:.3f}".format(best_score))
            save_checkpoint(maml.state_dict(), logger.path("model"), logger)
            last_success_epoch = iepoch
        # Stop early after `early_stop_thresh` epochs without improvement.
        if iepoch - last_success_epoch >= args.early_stop_thresh:
            logger.log("Early stop at {:}".format(iepoch))
            break

        per_epoch_time.update(time.time() - start_time)
        start_time = time.time()

    # meta-test
    maml.load_best()
    eval_env = env_info["dynamic_env"]
    assert eval_env.timestamp_interval == dynamic_env.timestamp_interval
    w_container_per_epoch = dict()
    for idx in range(args.prev_time, len(eval_env)):
        # NOTE(review): future_timestamp here is indexed (eval_env[idx]) and
        # appears to be a tensor (.item() below) — confirm against the env API.
        future_timestamp, (future_x, future_y) = eval_env[idx]
        past_timestamp = (
            future_timestamp.item() - args.prev_time * eval_env.timestamp_interval
        )
        _, (past_x, past_y) = eval_env(past_timestamp)
        future_container = maml.adapt(TimeData(past_timestamp, past_x, past_y))
        w_container_per_epoch[idx] = future_container.no_grad_clone()
        with torch.no_grad():
            future_y_hat = maml.predict(future_x, w_container_per_epoch[idx])
            future_loss = maml.criterion(future_y_hat, future_y)
        logger.log("meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item()))
    save_checkpoint(
        {"w_container_per_epoch": w_container_per_epoch},
        logger.path(None) / "final-ckp.pth",
        logger,
    )

    logger.log("-" * 200 + "\n")
    logger.close()
|  | ||||
|  | ||||
if __name__ == "__main__":
    # CLI for the MAML baseline on the synthetic dynamic environment.
    parser = argparse.ArgumentParser("Use the data in the past.")
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./outputs/lfna-synthetic/use-maml",
        help="The checkpoint directory.",
    )
    parser.add_argument(
        "--env_version",
        type=str,
        required=True,
        # NOTE(review): "enviornment" typo in user-facing help — fix separately.
        help="The synthetic enviornment version.",
    )
    parser.add_argument(
        "--hidden_dim",
        type=int,
        default=16,
        help="The hidden dimension.",
    )
    parser.add_argument(
        "--meta_lr",
        type=float,
        default=0.01,
        help="The learning rate for the MAML optimizer (default is Adam)",
    )
    # NOTE(review): --fail_thresh is accepted but never referenced in main().
    parser.add_argument(
        "--fail_thresh",
        type=float,
        default=1000,
        help="The threshold for the failure, which we reuse the previous best model",
    )
    parser.add_argument(
        "--inner_lr",
        type=float,
        default=0.005,
        help="The learning rate for the inner optimization",
    )
    parser.add_argument(
        "--inner_step", type=int, default=1, help="The inner loop steps for MAML."
    )
    parser.add_argument(
        "--prev_time",
        type=int,
        default=5,
        help="The gap between prev_time and current_timestamp",
    )
    parser.add_argument(
        "--meta_batch",
        type=int,
        default=64,
        help="The batch size for the meta-model",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=2000,
        help="The total number of epochs.",
    )
    parser.add_argument(
        "--early_stop_thresh",
        type=int,
        default=50,
        help="The maximum epochs for early stop.",
    )
    # NOTE(review): --workers is accepted but never referenced in main().
    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        help="The number of data loading workers (default: 4)",
    )
    # Random Seed
    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
    args = parser.parse_args()
    # A negative/absent seed means "pick one at random".
    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "The save dir argument can not be None"
    # Encode the key hyper-parameters into the output directory name.
    args.save_dir = "{:}-s{:}-mlr{:}-d{:}-prev{:}-e{:}-env{:}".format(
        args.save_dir,
        args.inner_step,
        args.meta_lr,
        args.hidden_dim,
        args.prev_time,
        args.epochs,
        args.env_version,
    )
    main(args)
							
								
								
									
										203
									
								
								exps/GeMOSA/basic-prev.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										203
									
								
								exps/GeMOSA/basic-prev.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,203 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| # python exps/LFNA/basic-prev.py --env_version v1 --prev_time 5 --hidden_dim 16 --epochs 500 --init_lr 0.1 | ||||
| # python exps/LFNA/basic-prev.py --env_version v2 --hidden_dim 16 --epochs 1000 --init_lr 0.05 | ||||
| ##################################################### | ||||
| import sys, time, copy, torch, random, argparse | ||||
| from tqdm import tqdm | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
| ) | ||||
| from xautodl.log_utils import time_string | ||||
| from xautodl.log_utils import AverageMeter, convert_secs2time | ||||
|  | ||||
| from xautodl.utils import split_str2indexes | ||||
|  | ||||
| from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn | ||||
| from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric | ||||
| from xautodl.datasets.synthetic_core import get_synthetic_env | ||||
| from xautodl.models.xcore import get_model | ||||
|  | ||||
| from lfna_utils import lfna_setup | ||||
|  | ||||
|  | ||||
def subsample(historical_x, historical_y, maxn=10000):
    """Limit a paired (x, y) history to at most *maxn* random rows.

    Histories within the cap pass through unchanged; larger ones are
    reduced by sampling *maxn* indices with replacement.
    """
    if historical_x.size(0) <= maxn:
        return historical_x, historical_y
    selection = torch.randint(low=0, high=historical_x.size(0), size=[maxn])
    return historical_x[selection], historical_y[selection]
|  | ||||
|  | ||||
def main(args):
    """Train an MLP per timestamp on the data from `prev_time` steps earlier.

    For each timestamp index, a fresh model is fitted on the single snapshot
    taken ``args.prev_time`` steps before it, evaluated on the current
    snapshot, and checkpointed.

    Args:
        args: parsed CLI namespace (prev_time, init_lr, epochs, batch_size, ...).
    """
    logger, env_info, model_kwargs = lfna_setup(args)

    w_container_per_epoch = dict()

    per_timestamp_time, start_time = AverageMeter(), time.time()
    for idx in range(args.prev_time, env_info["total"]):

        need_time = "Time Left: {:}".format(
            convert_secs2time(per_timestamp_time.avg * (env_info["total"] - idx), True)
        )
        logger.log(
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, env_info["total"])
            + " "
            + need_time
        )
        # Use only the snapshot from `prev_time` steps before this timestamp.
        historical_x = env_info["{:}-x".format(idx - args.prev_time)]
        historical_y = env_info["{:}-y".format(idx - args.prev_time)]
        # Build a fresh model for this timestamp.
        model = get_model(**model_kwargs)
        # Route the model summary through the logger (was a stray print()),
        # so it lands in the log file like every other message.
        logger.log("{:}".format(model))
        # Build optimizer, loss, and a 3-step decay schedule over the epochs.
        optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
        criterion = torch.nn.MSELoss()
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[
                int(args.epochs * 0.25),
                int(args.epochs * 0.5),
                int(args.epochs * 0.75),
            ],
            gamma=0.3,
        )
        train_metric = MSEMetric()
        best_loss, best_param = None, None
        for _iepoch in range(args.epochs):
            preds = model(historical_x)
            optimizer.zero_grad()
            loss = criterion(preds, historical_y)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            # Track the best weights seen across epochs.
            if best_loss is None or best_loss > loss.item():
                best_loss = loss.item()
                best_param = copy.deepcopy(model.state_dict())
        if best_param is not None:  # guard: args.epochs == 0 leaves no snapshot
            model.load_state_dict(best_param)
        model.analyze_weights()
        with torch.no_grad():
            # BUGFIX: recompute predictions with the restored best weights.
            # Previously the metric consumed `preds` from the *last* training
            # epoch, which does not match the just-loaded best model.
            preds = model(historical_x)
            train_metric(preds, historical_y)
        train_results = train_metric.get_info()

        metric = ComposeMetric(MSEMetric(), SaveMetric())
        eval_dataset = torch.utils.data.TensorDataset(
            env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)]
        )
        eval_loader = torch.utils.data.DataLoader(
            eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
        )
        results = basic_eval_fn(eval_loader, model, metric, logger)
        log_str = (
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, env_info["total"])
            + " train-mse: {:.5f}, eval-mse: {:.5f}".format(
                train_results["mse"], results["mse"]
            )
        )
        logger.log(log_str)

        save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
            idx, env_info["total"]
        )
        w_container_per_epoch[idx] = model.get_w_container().no_grad_clone()
        save_checkpoint(
            {
                "model_state_dict": model.state_dict(),
                "model": model,
                "index": idx,
                "timestamp": env_info["{:}-timestamp".format(idx)],
            },
            save_path,
            logger,
        )
        logger.log("")
        per_timestamp_time.update(time.time() - start_time)
        start_time = time.time()

    save_checkpoint(
        {"w_container_per_epoch": w_container_per_epoch},
        logger.path(None) / "final-ckp.pth",
        logger,
    )

    logger.log("-" * 200 + "\n")
    logger.close()
|  | ||||
|  | ||||
if __name__ == "__main__":
    # CLI for the "train on a single previous timestamp" baseline.
    parser = argparse.ArgumentParser("Use the data in the last timestamp.")
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./outputs/lfna-synthetic/use-prev-timestamp",
        help="The checkpoint directory.",
    )
    parser.add_argument(
        "--env_version",
        type=str,
        required=True,
        # NOTE(review): "enviornment" typo in user-facing help — fix separately.
        help="The synthetic enviornment version.",
    )
    parser.add_argument(
        "--hidden_dim",
        type=int,
        required=True,
        help="The hidden dimension.",
    )
    parser.add_argument(
        "--init_lr",
        type=float,
        default=0.1,
        help="The initial learning rate for the optimizer (default is Adam)",
    )
    parser.add_argument(
        "--prev_time",
        type=int,
        default=5,
        help="The gap between prev_time and current_timestamp",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=512,
        help="The batch size",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=300,
        help="The total number of epochs.",
    )
    # NOTE(review): --workers is accepted but unused — main() hard-codes
    # num_workers=0 in its DataLoader.
    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        help="The number of data loading workers (default: 4)",
    )
    # Random Seed
    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
    args = parser.parse_args()
    # A negative/absent seed means "pick one at random".
    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "The save dir argument can not be None"
    # Encode the key hyper-parameters into the output directory name.
    args.save_dir = "{:}-d{:}_e{:}_lr{:}-prev{:}-env{:}".format(
        args.save_dir,
        args.hidden_dim,
        args.epochs,
        args.init_lr,
        args.prev_time,
        args.env_version,
    )
    main(args)
							
								
								
									
										204
									
								
								exps/GeMOSA/basic-same.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										204
									
								
								exps/GeMOSA/basic-same.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,204 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| # python exps/GeMOSA/basic-same.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1 | ||||
| # python exps/GeMOSA/basic-same.py --env_version v2 --hidden_dim 16 --epochs 1000 --init_lr 0.05 | ||||
| ##################################################### | ||||
| import sys, time, copy, torch, random, argparse | ||||
| from tqdm import tqdm | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||
| print("LIB-DIR: {:}".format(lib_dir)) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
|  | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
| ) | ||||
| from xautodl.log_utils import time_string | ||||
| from xautodl.log_utils import AverageMeter, convert_secs2time | ||||
|  | ||||
| from xautodl.utils import split_str2indexes | ||||
|  | ||||
| from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn | ||||
| from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric | ||||
| from xautodl.datasets.synthetic_core import get_synthetic_env | ||||
| from xautodl.models.xcore import get_model | ||||
|  | ||||
| from lfna_utils import lfna_setup | ||||
|  | ||||
|  | ||||
def subsample(historical_x, historical_y, maxn=10000):
    """Randomly keep at most `maxn` paired samples (sampled with replacement)."""
    num_samples = historical_x.size(0)
    if num_samples > maxn:
        chosen = torch.randint(low=0, high=num_samples, size=[maxn])
        historical_x, historical_y = historical_x[chosen], historical_y[chosen]
    return historical_x, historical_y
|  | ||||
|  | ||||
def main(args):
    """Train one independent model per timestamp of the synthetic environment.

    Each timestamp's model is trained and evaluated on that same timestamp's
    data (an oracle-style baseline), then checkpointed together with its
    weight container.

    Args:
        args: parsed CLI namespace; uses env_version, device, init_lr,
            epochs, batch_size, save_dir (via lfna_setup), hidden_dim.
    """
    logger, model_kwargs = lfna_setup(args)

    env = get_synthetic_env(mode=None, version=args.env_version)
    logger.log("The total enviornment: {:}".format(env))
    w_containers = dict()

    per_timestamp_time, start_time = AverageMeter(), time.time()
    for idx, (future_time, (future_x, future_y)) in enumerate(env):

        need_time = "Time Left: {:}".format(
            convert_secs2time(per_timestamp_time.avg * (len(env) - idx), True)
        )
        logger.log(
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, len(env))
            + " "
            + need_time
        )
        # train on the same data that will be evaluated (oracle baseline)
        historical_x = future_x.to(args.device)
        historical_y = future_y.to(args.device)
        # build a fresh model for this timestamp
        model = get_model(**model_kwargs)
        model = model.to(args.device)
        # build optimizer: full-batch Adam with a 3-milestone step decay
        optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
        criterion = torch.nn.MSELoss()
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[
                int(args.epochs * 0.25),
                int(args.epochs * 0.5),
                int(args.epochs * 0.75),
            ],
            gamma=0.3,
        )
        train_metric = MSEMetric()
        best_loss, best_param = None, None
        for _iepoch in range(args.epochs):
            preds = model(historical_x)
            optimizer.zero_grad()
            loss = criterion(preds, historical_y)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            # keep a snapshot of the best-so-far parameters
            if best_loss is None or best_loss > loss.item():
                best_loss = loss.item()
                best_param = copy.deepcopy(model.state_dict())
        model.load_state_dict(best_param)
        model.analyze_weights()
        # FIX: `preds` above came from the LAST epoch, but the BEST parameters
        # were just restored -- recompute predictions so the reported train-MSE
        # matches the model that is actually checkpointed.
        with torch.no_grad():
            preds = model(historical_x)
            train_metric(preds, historical_y)
        train_results = train_metric.get_info()

        metric = ComposeMetric(MSEMetric(), SaveMetric())
        eval_dataset = torch.utils.data.TensorDataset(
            future_x.to(args.device), future_y.to(args.device)
        )
        eval_loader = torch.utils.data.DataLoader(
            eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
        )
        results = basic_eval_fn(eval_loader, model, metric, logger)
        log_str = (
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, len(env))
            + " train-mse: {:.5f}, eval-mse: {:.5f}".format(
                train_results["mse"], results["mse"]
            )
        )
        logger.log(log_str)

        save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(idx, len(env))
        w_containers[idx] = model.get_w_container().no_grad_clone()
        save_checkpoint(
            {
                "model_state_dict": model.state_dict(),
                "model": model,
                "index": idx,
                "timestamp": future_time.item(),
            },
            save_path,
            logger,
        )
        logger.log("")
        per_timestamp_time.update(time.time() - start_time)
        start_time = time.time()

    # final checkpoint bundles every per-timestamp weight container
    save_checkpoint(
        {"w_containers": w_containers},
        logger.path(None) / "final-ckp.pth",
        logger,
    )

    logger.log("-" * 200 + "\n")
    logger.close()
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("Use the data in the past.") | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./outputs/lfna-synthetic/use-same-timestamp", | ||||
|         help="The checkpoint directory.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--env_version", | ||||
|         type=str, | ||||
|         required=True, | ||||
|         help="The synthetic enviornment version.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--hidden_dim", | ||||
|         type=int, | ||||
|         required=True, | ||||
|         help="The hidden dimension.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--init_lr", | ||||
|         type=float, | ||||
|         default=0.1, | ||||
|         help="The initial learning rate for the optimizer (default is Adam)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--batch_size", | ||||
|         type=int, | ||||
|         default=512, | ||||
|         help="The batch size", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--epochs", | ||||
|         type=int, | ||||
|         default=300, | ||||
|         help="The total number of epochs.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--device", | ||||
|         type=str, | ||||
|         default="cpu", | ||||
|         help="", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=4, | ||||
|         help="The number of data loading workers (default: 4)", | ||||
|     ) | ||||
|     # Random Seed | ||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "The save dir argument can not be None" | ||||
|     args.save_dir = "{:}-d{:}_e{:}_lr{:}-env{:}".format( | ||||
|         args.save_dir, args.hidden_dim, args.epochs, args.init_lr, args.env_version | ||||
|     ) | ||||
|     main(args) | ||||
							
								
								
									
										265
									
								
								exps/GeMOSA/lfna_meta_model.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										265
									
								
								exps/GeMOSA/lfna_meta_model.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,265 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| import torch | ||||
|  | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| from xautodl.xlayers import super_core | ||||
| from xautodl.xlayers import trunc_normal_ | ||||
| from xautodl.models.xcore import get_model | ||||
|  | ||||
|  | ||||
class MetaModelV1(super_core.SuperModule):
    """Learning to Generate Models One Step Ahead (Meta Model Design).

    Holds one learnable embedding per base-model layer and one per known
    ("meta") timestamp.  Given a query timestamp, an attention module mixes
    the meta-timestamp embeddings into a time embedding, which a generator
    MLP decodes into per-layer weights for the base model.
    """

    def __init__(
        self,
        shape_container,
        layer_dim,
        time_dim,
        meta_timestamps,
        dropout: float = 0.1,
        seq_length: int = 10,
        interval: float = None,
        thresh: float = None,
    ):
        # shape_container: per-layer weight shapes of the base model
        # layer_dim / time_dim: sizes of layer and timestamp embeddings
        # meta_timestamps: the timestamps whose embeddings are trained directly
        # interval: spacing between consecutive timestamps (required)
        # thresh: attention cut-off distance; defaults to 50 * interval
        super(MetaModelV1, self).__init__()
        self._shape_container = shape_container
        self._num_layers = len(shape_container)
        self._numel_per_layer = []
        for ilayer in range(self._num_layers):
            self._numel_per_layer.append(shape_container[ilayer].numel())
        self._raw_meta_timestamps = meta_timestamps
        assert interval is not None
        self._interval = interval
        self._seq_length = seq_length
        self._thresh = interval * 50 if thresh is None else thresh

        self.register_parameter(
            "_super_layer_embed",
            torch.nn.Parameter(torch.Tensor(self._num_layers, layer_dim)),
        )
        self.register_parameter(
            "_super_meta_embed",
            torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_dim)),
        )
        self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps))
        # register a time difference buffer: offsets [-(seq_length-1)*interval, ..., 0]
        # used to expand a single query timestamp into a short history sequence
        time_interval = [-i * self._interval for i in range(self._seq_length)]
        time_interval.reverse()
        self.register_buffer("_time_interval", torch.Tensor(time_interval))
        self._time_embed_dim = time_dim
        # extra timestamps/embeddings appended at test time:
        # "fixed" = frozen after adaptation, "learnt" = currently being optimized
        self._append_meta_embed = dict(fixed=None, learnt=None)
        self._append_meta_timestamps = dict(fixed=None, learnt=None)

        # positional encoding of scalar time differences (scaled by 1/interval)
        self._tscalar_embed = super_core.SuperDynamicPositionE(
            time_dim, scale=1 / interval
        )

        # build transformer: attention from query-time encodings (Q/K) over the
        # meta-timestamp embeddings (V)
        self._trans_att = super_core.SuperQKVAttentionV2(
            qk_att_dim=time_dim,
            in_v_dim=time_dim,
            hidden_dim=time_dim,
            num_heads=4,
            proj_dim=time_dim,
            qkv_bias=True,
            attn_drop=None,
            proj_drop=dropout,
        )

        # generator MLP: (layer embed + time embed) -> flat weights for the
        # largest layer; smaller layers use a prefix via shape translation
        model_kwargs = dict(
            config=dict(model_type="dual_norm_mlp"),
            input_dim=layer_dim + time_dim,
            output_dim=max(self._numel_per_layer),
            hidden_dims=[(layer_dim + time_dim) * 2] * 3,
            act_cls="gelu",
            norm_cls="layer_norm_1d",
            dropout=dropout,
        )
        self._generator = get_model(**model_kwargs)

        # initialization
        trunc_normal_(
            [self._super_layer_embed, self._super_meta_embed],
            std=0.02,
        )

    def get_parameters(self, time_embed, attention, generator):
        """Select parameter groups by component (for building optimizers)."""
        parameters = []
        if time_embed:
            parameters.append(self._super_meta_embed)
        if attention:
            parameters.extend(list(self._trans_att.parameters()))
        if generator:
            parameters.append(self._super_layer_embed)
            parameters.extend(list(self._generator.parameters()))
        return parameters

    @property
    def meta_timestamps(self):
        """All known timestamps: the original buffer plus appended fixed/learnt ones."""
        with torch.no_grad():
            meta_timestamps = [self._meta_timestamps]
            for key in ("fixed", "learnt"):
                if self._append_meta_timestamps[key] is not None:
                    meta_timestamps.append(self._append_meta_timestamps[key])
        return torch.cat(meta_timestamps)

    @property
    def super_meta_embed(self):
        """Embeddings matching `meta_timestamps` (trained + appended), concatenated."""
        meta_embed = [self._super_meta_embed]
        for key in ("fixed", "learnt"):
            if self._append_meta_embed[key] is not None:
                meta_embed.append(self._append_meta_embed[key])
        return torch.cat(meta_embed)

    def create_meta_embed(self):
        """Create one fresh learnable time embedding (for test-time adaptation)."""
        param = torch.Tensor(1, self._time_embed_dim)
        trunc_normal_(param, std=0.02)
        param = param.to(self._super_meta_embed.device)
        param = torch.nn.Parameter(param, True)
        return param

    def get_closest_meta_distance(self, timestamp):
        """Distance from `timestamp` to the nearest known meta timestamp."""
        with torch.no_grad():
            distances = torch.abs(self.meta_timestamps - timestamp)
            return torch.min(distances).item()

    def replace_append_learnt(self, timestamp, meta_embed):
        # overwrite (not extend) the "learnt" slot; pass (None, None) to clear it
        self._append_meta_timestamps["learnt"] = timestamp
        self._append_meta_embed["learnt"] = meta_embed

    @property
    def meta_length(self):
        """Total number of known timestamps (original + appended)."""
        return self.meta_timestamps.numel()

    def clear_fixed(self):
        self._append_meta_timestamps["fixed"] = None
        self._append_meta_embed["fixed"] = None

    def clear_learnt(self):
        self.replace_append_learnt(None, None)

    def append_fixed(self, timestamp, meta_embed):
        """Permanently append a (timestamp, embedding) pair as non-trainable data."""
        with torch.no_grad():
            device = self._super_meta_embed.device
            timestamp = timestamp.detach().clone().to(device)
            meta_embed = meta_embed.detach().clone().to(device)
            if self._append_meta_timestamps["fixed"] is None:
                self._append_meta_timestamps["fixed"] = timestamp
            else:
                self._append_meta_timestamps["fixed"] = torch.cat(
                    (self._append_meta_timestamps["fixed"], timestamp), dim=0
                )
            if self._append_meta_embed["fixed"] is None:
                self._append_meta_embed["fixed"] = meta_embed
            else:
                self._append_meta_embed["fixed"] = torch.cat(
                    (self._append_meta_embed["fixed"], meta_embed), dim=0
                )

    def _obtain_time_embed(self, timestamps):
        # timestamps is a batch of sequence of timestamps, shape (batch, seq)
        batch, seq = timestamps.shape
        meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
        timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
        # encode the scalar differences query-time minus each meta-time
        timestamp_qk_att_embed = self._tscalar_embed(
            torch.unsqueeze(timestamps, dim=-1) - meta_timestamps
        )
        # create the mask: block attention to meta timestamps that are in the
        # future (>= query) or farther away than the threshold
        mask = (
            torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1)
        ) | (
            torch.abs(
                torch.unsqueeze(timestamps, dim=-1) - meta_timestamps.view(1, 1, -1)
            )
            > self._thresh
        )
        timestamp_embeds = self._trans_att(
            timestamp_qk_att_embed,
            timestamp_v_embed,
            mask,
        )
        # keep only the embedding of the last (most recent) sequence position
        return timestamp_embeds[:, -1, :]

    def forward_raw(self, timestamps, time_embeds, tembed_only=False):
        """Generate per-layer weight containers for a batch of timestamps.

        Either `timestamps` (to compute embeddings via attention) or
        `time_embeds` (precomputed embeddings) must be given.
        Returns (time_seq, list-of-weight-containers, time_embeds).
        """
        if time_embeds is None:
            # expand each query timestamp into a backward-looking sequence
            time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1)
            B, S = time_seq.shape
            time_embeds = self._obtain_time_embed(time_seq)
        else:  # use the hyper-net only
            time_seq = None
            B, _ = time_embeds.shape
        if tembed_only:
            return time_embeds
        # create joint embed
        num_layer, _ = self._super_layer_embed.shape
        # The shape of `joint_embed` is batch * num-layers * input-dim
        joint_embeds = torch.cat(
            (
                time_embeds.view(B, 1, -1).expand(-1, num_layer, -1),
                self._super_layer_embed.view(1, num_layer, -1).expand(B, -1, -1),
            ),
            dim=-1,
        )
        batch_weights = self._generator(joint_embeds)
        batch_containers = []
        for weights in torch.split(batch_weights, 1):
            batch_containers.append(
                self._shape_container.translate(torch.split(weights.squeeze(0), 1))
            )
        return time_seq, batch_containers, time_embeds

    def forward_candidate(self, input):
        raise NotImplementedError

    def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info):
        """Test-time adaptation: learn a new time embedding for `timestamp`.

        Skips adaptation (returns (False, None)) when the timestamp is already
        (nearly) covered by a known meta timestamp; otherwise optimizes a fresh
        embedding against both the task loss and an embedding-matching loss,
        then appends the best embedding as fixed.  Returns (True, best_loss).
        """
        distance = self.get_closest_meta_distance(timestamp)
        if distance + self._interval * 1e-2 <= self._interval:
            return False, None
        x, y = x.to(self._meta_timestamps.device), y.to(self._meta_timestamps.device)
        with torch.set_grad_enabled(True):
            new_param = self.create_meta_embed()

            optimizer = torch.optim.Adam(
                [new_param], lr=lr, weight_decay=1e-5, amsgrad=True
            )
            timestamp = torch.Tensor([timestamp]).to(new_param.device)
            # expose the new embedding through the "learnt" slot during optimization
            self.replace_append_learnt(timestamp, new_param)
            self.train()
            base_model.train()
            if init_info is not None:
                best_loss = init_info["loss"]
                new_param.data.copy_(init_info["param"].data)
            else:
                best_loss = 1e9
            with torch.no_grad():
                best_new_param = new_param.detach().clone()
            for iepoch in range(epochs):
                optimizer.zero_grad()
                # match_loss: keep new_param close to the attention-derived embedding
                _, [_], time_embed = self(timestamp.view(1, 1), None)
                match_loss = criterion(new_param, time_embed)

                # meta_loss: task loss of the base model with generated weights
                _, [container], time_embed = self(None, new_param.view(1, 1, -1))
                y_hat = base_model.forward_with_container(x, container)
                meta_loss = criterion(y_hat, y)
                loss = meta_loss + match_loss
                loss.backward()
                optimizer.step()
                # print("{:03d}/{:03d} : loss : {:.4f} = {:.4f} + {:.4f}".format(iepoch, epochs, loss.item(), meta_loss.item(), match_loss.item()))
                # track the embedding with the best task loss only
                if meta_loss.item() < best_loss:
                    with torch.no_grad():
                        best_loss = meta_loss.item()
                        best_new_param = new_param.detach().clone()
        with torch.no_grad():
            # promote the best embedding from "learnt" to permanently "fixed"
            self.replace_append_learnt(None, None)
            self.append_fixed(timestamp, best_new_param)
        return True, best_loss

    def extra_repr(self) -> str:
        return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format(
            list(self._super_layer_embed.shape),
            list(self._super_meta_embed.shape),
            list(self._meta_timestamps.shape),
        )
							
								
								
									
										117
									
								
								exps/GeMOSA/lfna_models.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										117
									
								
								exps/GeMOSA/lfna_models.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,117 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| import copy | ||||
| import torch | ||||
|  | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| from xlayers import super_core | ||||
| from xlayers import trunc_normal_ | ||||
| from models.xcore import get_model | ||||
|  | ||||
|  | ||||
class HyperNet(super_core.SuperModule):
    """The hyper-network: decodes (task embedding, layer embedding) pairs into
    per-layer weights for the base model."""

    def __init__(
        self,
        shape_container,
        layer_embeding,
        task_embedding,
        num_tasks,
        return_container=True,
    ):
        super(HyperNet, self).__init__()
        self._shape_container = shape_container
        self._num_layers = len(shape_container)
        # numel of every base-model layer; the generator emits the max size
        self._numel_per_layer = [
            shape_container[index].numel() for index in range(self._num_layers)
        ]

        self.register_parameter(
            "_super_layer_embed",
            torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embeding)),
        )
        self.register_parameter(
            "_super_task_embed",
            torch.nn.Parameter(torch.Tensor(num_tasks, task_embedding)),
        )
        trunc_normal_(self._super_layer_embed, std=0.02)
        trunc_normal_(self._super_task_embed, std=0.02)

        generator_kwargs = dict(
            config=dict(model_type="dual_norm_mlp"),
            input_dim=layer_embeding + task_embedding,
            output_dim=max(self._numel_per_layer),
            hidden_dims=[(layer_embeding + task_embedding) * 2] * 3,
            act_cls="gelu",
            norm_cls="layer_norm_1d",
            dropout=0.2,
        )
        self._generator = get_model(**generator_kwargs)
        self._return_container = return_container
        print("generator: {:}".format(self._generator))

    def forward_raw(self, task_embed_id):
        # broadcast the selected task embedding across every layer row
        task_embed = self._super_task_embed[task_embed_id].view(1, -1)
        task_embed = task_embed.expand(self._num_layers, -1)
        joint_embed = torch.cat((task_embed, self._super_layer_embed), dim=-1)
        weights = self._generator(joint_embed)
        if not self._return_container:
            return weights
        return self._shape_container.translate(torch.split(weights, 1))

    def forward_candidate(self, input):
        raise NotImplementedError

    def extra_repr(self) -> str:
        return "(_super_layer_embed): {:}".format(list(self._super_layer_embed.shape))
|  | ||||
|  | ||||
class HyperNet_VX(super_core.SuperModule):
    """A simpler hyper-network variant: weights come only from per-layer
    embeddings, with no task conditioning."""

    def __init__(self, shape_container, input_embeding, return_container=True):
        super(HyperNet_VX, self).__init__()
        self._shape_container = shape_container
        self._num_layers = len(shape_container)
        self._numel_per_layer = [
            shape_container[index].numel() for index in range(self._num_layers)
        ]

        self.register_parameter(
            "_super_layer_embed",
            torch.nn.Parameter(torch.Tensor(self._num_layers, input_embeding)),
        )
        trunc_normal_(self._super_layer_embed, std=0.02)

        generator_kwargs = dict(
            input_dim=input_embeding,
            output_dim=max(self._numel_per_layer),
            hidden_dim=input_embeding * 4,
            act_cls="sigmoid",
            norm_cls="identity",
        )
        self._generator = get_model(dict(model_type="simple_mlp"), **generator_kwargs)
        self._return_container = return_container
        print("generator: {:}".format(self._generator))

    def forward_raw(self, input):
        # NOTE(review): `input` is unused -- output depends only on the
        # learnable layer embeddings.
        weights = self._generator(self._super_layer_embed)
        if not self._return_container:
            return weights
        return self._shape_container.translate(torch.split(weights, 1))

    def forward_candidate(self, input):
        raise NotImplementedError

    def extra_repr(self) -> str:
        return "(_super_layer_embed): {:}".format(list(self._super_layer_embed.shape))
							
								
								
									
										64
									
								
								exps/GeMOSA/lfna_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										64
									
								
								exps/GeMOSA/lfna_utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,64 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| import copy | ||||
| import torch | ||||
| from tqdm import tqdm | ||||
| from xautodl.procedures import prepare_seed, prepare_logger | ||||
| from xautodl.datasets.synthetic_core import get_synthetic_env | ||||
|  | ||||
|  | ||||
def lfna_setup(args):
    """Seed the RNGs, build the experiment logger, and describe the base model.

    Returns a (logger, model_kwargs) pair; model_kwargs feeds xcore.get_model.
    """
    prepare_seed(args.rand_seed)
    logger = prepare_logger(args)
    model_kwargs = {
        "config": {"model_type": "norm_mlp"},
        "input_dim": 1,
        "output_dim": 1,
        "hidden_dims": [args.hidden_dim, args.hidden_dim],
        "act_cls": "gelu",
        "norm_cls": "layer_norm_1d",
    }
    return logger, model_kwargs
|  | ||||
|  | ||||
def train_model(model, dataset, lr, epochs):
    """Fit `model` on `dataset` with full-batch Adam + MSE, restoring the best weights.

    Args:
        model: a torch module mapping dataset.x to predictions.
        dataset: object exposing tensors `.x` (inputs) and `.y` (targets).
        lr: learning rate for Adam (amsgrad variant).
        epochs: number of full-batch gradient steps.

    Returns:
        The lowest loss observed, or None when epochs <= 0 (no training done).
    """
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, amsgrad=True)
    best_loss, best_param = None, None
    for _iepoch in range(epochs):
        preds = model(dataset.x)
        optimizer.zero_grad()
        loss = criterion(preds, dataset.y)
        loss.backward()
        optimizer.step()
        # keep a snapshot of the best-so-far parameters
        if best_loss is None or best_loss > loss.item():
            best_loss = loss.item()
            best_param = copy.deepcopy(model.state_dict())
    # FIX: with epochs <= 0 there is no snapshot; the old code crashed on
    # model.load_state_dict(None) -- only restore when a snapshot exists.
    if best_param is not None:
        model.load_state_dict(best_param)
    return best_loss
|  | ||||
|  | ||||
class TimeData:
    """A read-only bundle of (timestamp, inputs, targets) for one time step."""

    def __init__(self, timestamp, xs, ys):
        self._timestamp = timestamp
        self._xs = xs
        self._ys = ys

    @property
    def x(self):
        """The input samples."""
        return self._xs

    @property
    def y(self):
        """The target values."""
        return self._ys

    @property
    def timestamp(self):
        """The time step this data belongs to."""
        return self._timestamp

    def __repr__(self):
        return (
            f"{type(self).__name__}"
            f"(timestamp={self._timestamp}, with {len(self._xs)} samples)"
        )
							
								
								
									
										343
									
								
								exps/GeMOSA/main.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										343
									
								
								exps/GeMOSA/main.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,343 @@ | ||||
| ##################################################### | ||||
| # Learning to Generate Model One Step Ahead         # | ||||
| ##################################################### | ||||
| # python exps/GeMOSA/lfna.py --env_version v1 --workers 0 | ||||
| # python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.001 | ||||
| # python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 16 --meta_batch 128 | ||||
| # python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128 | ||||
| ##################################################### | ||||
| import sys, time, copy, torch, random, argparse | ||||
| from tqdm import tqdm | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
| from torch.nn import functional as F | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||
| print("LIB-DIR: {:}".format(lib_dir)) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
| ) | ||||
| from xautodl.log_utils import time_string | ||||
| from xautodl.log_utils import AverageMeter, convert_secs2time | ||||
|  | ||||
| from xautodl.utils import split_str2indexes | ||||
|  | ||||
| from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn | ||||
| from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric | ||||
| from xautodl.datasets.synthetic_core import get_synthetic_env | ||||
| from xautodl.models.xcore import get_model | ||||
| from xautodl.xlayers import super_core, trunc_normal_ | ||||
|  | ||||
| from lfna_utils import lfna_setup, train_model, TimeData | ||||
| from lfna_meta_model import MetaModelV1 | ||||
|  | ||||
|  | ||||
def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False):
    """Roll through `env` one timestamp at a time: generate weights, score the
    base model, then let the meta-model adapt on the just-seen data.

    Returns (w_containers, loss_meter); w_containers is populated only when
    `save` is True.  Appended meta state is cleared before returning.
    """
    logger.log("Online evaluate: {:}".format(env))
    meter = AverageMeter()
    w_containers = dict()
    for index, (future_time, (future_x, future_y)) in enumerate(env):
        with torch.no_grad():
            meta_model.eval()
            base_model.eval()
            query_time = future_time.to(args.device).view(1, 1)
            _, [future_container], time_embeds = meta_model(query_time, None, False)
            if save:
                w_containers[index] = future_container.no_grad_clone()
            future_x = future_x.to(args.device)
            future_y = future_y.to(args.device)
            future_y_hat = base_model.forward_with_container(future_x, future_container)
            future_loss = criterion(future_y_hat, future_y)
            meter.update(future_loss.item())
        # adaptation happens after scoring, so evaluation stays "one step ahead"
        refine, post_refine_loss = meta_model.adapt(
            base_model,
            criterion,
            future_time.item(),
            future_x,
            future_y,
            args.refine_lr,
            args.refine_epochs,
            {"param": time_embeds, "loss": future_loss.item()},
        )
        message = "[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
            index, len(env), future_loss.item()
        ) + ", post-loss={:.4f}".format(post_refine_loss if refine else -1)
        logger.log(message)
    meta_model.clear_fixed()
    meta_model.clear_learnt()
    return w_containers, meter
|  | ||||
|  | ||||
def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
    """Pre-train the meta-model (hyper-network) on environment ``xenv``.

    Each epoch samples ``args.meta_batch`` timestamps and minimizes the sum
    of three terms: the loss of base-model weights decoded from embeddings
    the meta-model *generates* from raw timestamps, the loss of weights
    decoded from the *directly learnt* embeddings (``super_meta_embed``),
    and an L1 penalty tying the two embedding sets together.  The best
    (lowest total-loss) checkpoint is tracked via ``meta_model.save_best``;
    training stops early after ``args.pretrain_early_stop_thresh`` epochs
    without improvement.  If a final pre-trained checkpoint for this seed
    already exists, it is loaded and the procedure returns immediately.
    """
    base_model.train()
    meta_model.train()
    optimizer = torch.optim.Adam(
        meta_model.get_parameters(True, True, True),
        lr=args.lr,
        weight_decay=args.weight_decay,
        amsgrad=True,
    )
    logger.log("Pre-train the meta-model")
    logger.log("Using the optimizer: {:}".format(optimizer))

    meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v2")
    final_best_name = "final-pretrain-{:}.pth".format(args.rand_seed)
    if meta_model.has_best(final_best_name):
        # Resume: a finished pre-training checkpoint exists for this seed.
        meta_model.load_best(final_best_name)
        logger.log("Directly load the best model from {:}".format(final_best_name))
        return

    total_indexes = list(range(meta_model.meta_length))
    meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
    last_success_epoch, early_stop_thresh = 0, args.pretrain_early_stop_thresh
    per_epoch_time, start_time = AverageMeter(), time.time()
    device = args.device
    for iepoch in range(args.epochs):
        left_time = "Time Left: {:}".format(
            convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
        )
        optimizer.zero_grad()

        # Embeddings generated by the meta-model from its raw timestamps.
        generated_time_embeds = meta_model(meta_model.meta_timestamps, None, True)

        # Sampling WITH replacement across the meta-timestamps.
        batch_indexes = random.choices(total_indexes, k=args.meta_batch)

        raw_time_steps = meta_model.meta_timestamps[batch_indexes]

        # Keep the generated embeddings close to the directly learnt ones.
        regularization_loss = F.l1_loss(
            generated_time_embeds, meta_model.super_meta_embed, reduction="mean"
        )
        # future loss
        total_future_losses, total_present_losses = [], []
        _, future_containers, _ = meta_model(
            None, generated_time_embeds[batch_indexes], False
        )
        _, present_containers, _ = meta_model(
            None, meta_model.super_meta_embed[batch_indexes], False
        )
        for ibatch, time_step in enumerate(raw_time_steps.cpu().tolist()):
            _, (inputs, targets) = xenv(time_step)
            inputs, targets = inputs.to(device), targets.to(device)

            predictions = base_model.forward_with_container(
                inputs, future_containers[ibatch]
            )
            total_future_losses.append(criterion(predictions, targets))

            predictions = base_model.forward_with_container(
                inputs, present_containers[ibatch]
            )
            total_present_losses.append(criterion(predictions, targets))

        with torch.no_grad():
            # Std of the per-sample losses — logged for monitoring only.
            meta_std = torch.stack(total_future_losses).std().item()
        loss_future = torch.stack(total_future_losses).mean()
        loss_present = torch.stack(total_present_losses).mean()
        total_loss = loss_future + loss_present + regularization_loss
        total_loss.backward()
        optimizer.step()
        # success
        # ``save_best`` keeps the highest score, hence the negated loss;
        # ``best_score`` is therefore -(best loss so far).
        success, best_score = meta_model.save_best(-total_loss.item())
        logger.log(
            "{:} [META {:04d}/{:}] loss : {:.4f} +- {:.4f} = {:.4f} + {:.4f} + {:.4f}".format(
                time_string(),
                iepoch,
                args.epochs,
                total_loss.item(),
                meta_std,
                loss_future.item(),
                loss_present.item(),
                regularization_loss.item(),
            )
            + ", batch={:}".format(len(total_future_losses))
            + ", success={:}, best={:.4f}".format(success, -best_score)
            + ", LS={:}/{:}".format(iepoch - last_success_epoch, early_stop_thresh)
            + ", {:}".format(left_time)
        )
        if success:
            last_success_epoch = iepoch
        if iepoch - last_success_epoch >= early_stop_thresh:
            logger.log("Early stop the pre-training at {:}".format(iepoch))
            break
        per_epoch_time.update(time.time() - start_time)
        start_time = time.time()
    meta_model.load_best()
    # save to the final model
    # NOTE(review): ``best_score`` is only bound inside the loop above, so this
    # raises NameError when ``args.epochs`` is 0 — confirm that is acceptable.
    meta_model.set_best_name(final_best_name)
    success, _ = meta_model.save_best(best_score + 1e-6)
    assert success
    logger.log("Save the best model into {:}".format(final_best_name))
|  | ||||
|  | ||||
def main(args):
    """Entry point: build the base/meta models, pre-train the meta-model on
    the trainval environment, run the online evaluation over the full
    environment, and checkpoint the per-timestamp weight containers."""
    logger, model_kwargs = lfna_setup(args)
    train_env = get_synthetic_env(mode="train", version=args.env_version)
    valid_env = get_synthetic_env(mode="valid", version=args.env_version)
    trainval_env = get_synthetic_env(mode="trainval", version=args.env_version)
    all_env = get_synthetic_env(mode=None, version=args.env_version)
    logger.log("The training enviornment: {:}".format(train_env))
    logger.log("The validation enviornment: {:}".format(valid_env))
    logger.log("The trainval enviornment: {:}".format(trainval_env))
    logger.log("The total enviornment: {:}".format(all_env))

    base_model = get_model(**model_kwargs)
    base_model = base_model.to(args.device)
    criterion = torch.nn.MSELoss()

    # The shapes of the base-model's weights tell the hyper-network what to
    # generate for each layer.
    shape_container = base_model.get_w_container().to_shape_container()

    # pre-train the hypernetwork
    timestamps = trainval_env.get_timestamp(None)
    meta_model = MetaModelV1(
        shape_container,
        args.layer_dim,
        args.time_dim,
        timestamps,
        seq_length=args.seq_length,
        interval=trainval_env.time_interval,
    )
    meta_model = meta_model.to(args.device)

    logger.log("The base-model has {:} weights.".format(base_model.numel()))
    logger.log("The meta-model has {:} weights.".format(meta_model.numel()))
    logger.log("The base-model is\n{:}".format(base_model))
    logger.log("The meta-model is\n{:}".format(meta_model))

    meta_train_procedure(base_model, meta_model, criterion, trainval_env, args, logger)

    # try to evaluate once
    # online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
    # online_evaluate(valid_env, meta_model, base_model, criterion, args, logger)
    w_containers, loss_meter = online_evaluate(
        all_env, meta_model, base_model, criterion, args, logger, True
    )
    logger.log("In this enviornment, the loss-meter is {:}".format(loss_meter))

    # Persist the generated weight containers so vis-synthetic.py can replay
    # the per-timestamp predictions.
    save_checkpoint(
        {"w_containers": w_containers},
        logger.path(None) / "final-ckp.pth",
        logger,
    )

    logger.log("-" * 200 + "\n")
    logger.close()
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser(".") | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./outputs/lfna-synthetic/lfna-battle", | ||||
|         help="The checkpoint directory.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--env_version", | ||||
|         type=str, | ||||
|         required=True, | ||||
|         help="The synthetic enviornment version.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--hidden_dim", | ||||
|         type=int, | ||||
|         default=16, | ||||
|         help="The hidden dimension.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--layer_dim", | ||||
|         type=int, | ||||
|         default=16, | ||||
|         help="The layer chunk dimension.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--time_dim", | ||||
|         type=int, | ||||
|         default=16, | ||||
|         help="The timestamp dimension.", | ||||
|     ) | ||||
|     ##### | ||||
|     parser.add_argument( | ||||
|         "--lr", | ||||
|         type=float, | ||||
|         default=0.002, | ||||
|         help="The initial learning rate for the optimizer (default is Adam)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--weight_decay", | ||||
|         type=float, | ||||
|         default=0.00001, | ||||
|         help="The weight decay for the optimizer (default is Adam)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--meta_batch", | ||||
|         type=int, | ||||
|         default=64, | ||||
|         help="The batch size for the meta-model", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--sampler_enlarge", | ||||
|         type=int, | ||||
|         default=5, | ||||
|         help="Enlarge the #iterations for an epoch", | ||||
|     ) | ||||
|     parser.add_argument("--epochs", type=int, default=10000, help="The total #epochs.") | ||||
|     parser.add_argument( | ||||
|         "--refine_lr", | ||||
|         type=float, | ||||
|         default=0.001, | ||||
|         help="The learning rate for the optimizer, during refine", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--refine_epochs", type=int, default=150, help="The final refine #epochs." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--early_stop_thresh", | ||||
|         type=int, | ||||
|         default=20, | ||||
|         help="The #epochs for early stop.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--pretrain_early_stop_thresh", | ||||
|         type=int, | ||||
|         default=300, | ||||
|         help="The #epochs for early stop.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--seq_length", type=int, default=10, help="The sequence length." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--workers", type=int, default=4, help="The number of workers in parallel." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--device", | ||||
|         type=str, | ||||
|         default="cpu", | ||||
|         help="", | ||||
|     ) | ||||
|     # Random Seed | ||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "The save dir argument can not be None" | ||||
|     args.save_dir = "{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-env{:}".format( | ||||
|         args.save_dir, | ||||
|         args.meta_batch, | ||||
|         args.hidden_dim, | ||||
|         args.layer_dim, | ||||
|         args.time_dim, | ||||
|         args.seq_length, | ||||
|         args.lr, | ||||
|         args.weight_decay, | ||||
|         args.epochs, | ||||
|         args.env_version, | ||||
|     ) | ||||
|     main(args) | ||||
							
								
								
									
										373
									
								
								exps/GeMOSA/vis-synthetic.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										373
									
								
								exps/GeMOSA/vis-synthetic.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,373 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # | ||||
| ############################################################################ | ||||
# python exps/GeMOSA/vis-synthetic.py --env_version v1                     #
# python exps/GeMOSA/vis-synthetic.py --env_version v2                     #
| ############################################################################ | ||||
| import os, sys, copy, random | ||||
| import torch | ||||
| import numpy as np | ||||
| import argparse | ||||
| from collections import OrderedDict, defaultdict | ||||
| from pathlib import Path | ||||
| from tqdm import tqdm | ||||
| from pprint import pprint | ||||
|  | ||||
| import matplotlib | ||||
| from matplotlib import cm | ||||
|  | ||||
| matplotlib.use("agg") | ||||
| import matplotlib.pyplot as plt | ||||
| import matplotlib.ticker as ticker | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from xautodl.models.xcore import get_model | ||||
| from xautodl.datasets.synthetic_core import get_synthetic_env | ||||
| from xautodl.procedures.metric_utils import MSEMetric | ||||
|  | ||||
|  | ||||
def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None):
    """Scatter ``xs``/``ys`` on ``cur_ax`` with a separate legend proxy.

    A single off-screen point at (-100, -100) carries the ``label`` and the
    requested ``linewidths`` so the legend shows a large marker; the real
    data is drawn unlabeled with a fixed linewidth of 1.5.
    """
    # Legend proxy: invisible in any sensible axis range, labeled and sized.
    proxy_kwargs = dict(color=color, linewidths=linewidths, label=label)
    cur_ax.scatter([-100], [-100], **proxy_kwargs)
    # Actual data: no legend entry; thin markers regardless of ``linewidths``.
    data_kwargs = dict(color=color, alpha=alpha, linewidths=1.5, label=None)
    cur_ax.scatter(xs, ys, **data_kwargs)
|  | ||||
|  | ||||
def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None):
    """Render a vertical stack of scatter sub-plots and save PDF + PNG.

    Args:
        save_dir: Path-like directory receiving ``{timestamp:04d}.pdf/.png``.
        timestamp: integer used to name the output files.
        scatter_list: list of dicts, one per subplot, with keys ``xaxis``,
            ``yaxis``, ``color``, ``alpha``, ``linewidths``, ``label``,
            ``xlim`` and ``ylim``.
        wh: ``(width, height)`` of the figure in pixels.
        fig_title: optional bold title drawn above all subplots.
    """
    save_path = save_dir / "{:04d}".format(timestamp)
    # print('Plot the figure at timestamp-{:} into {:}'.format(timestamp, save_path))
    dpi, width, height = 40, wh[0], wh[1]
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize, font_gap = 80, 80, 5

    fig = plt.figure(figsize=figsize)
    if fig_title is not None:
        fig.suptitle(
            fig_title, fontsize=LegendFontsize, fontweight="bold", x=0.5, y=0.92
        )

    # One subplot per scatter spec, stacked vertically.
    for idx, scatter_dict in enumerate(scatter_list):
        cur_ax = fig.add_subplot(len(scatter_list), 1, idx + 1)
        plot_scatter(
            cur_ax,
            scatter_dict["xaxis"],
            scatter_dict["yaxis"],
            scatter_dict["color"],
            scatter_dict["alpha"],
            scatter_dict["linewidths"],
            scatter_dict["label"],
        )
        cur_ax.set_xlabel("X", fontsize=LabelSize)
        cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
        cur_ax.set_xlim(scatter_dict["xlim"][0], scatter_dict["xlim"][1])
        cur_ax.set_ylim(scatter_dict["ylim"][0], scatter_dict["ylim"][1])
        # ``Tick.label`` was removed in Matplotlib 3.8; ``tick_params`` is the
        # supported equivalent for setting tick-label size and rotation.
        cur_ax.tick_params(axis="x", labelsize=LabelSize - font_gap, labelrotation=10)
        cur_ax.tick_params(axis="y", labelsize=LabelSize - font_gap)
        cur_ax.legend(loc=1, fontsize=LegendFontsize)
    fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf")
    fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png")
    plt.close("all")
|  | ||||
|  | ||||
def find_min(cur, others):
    """Fold ``others`` into a running minimum; ``cur`` is None before the
    first observation.  Always returns a plain float."""
    return float(others) if cur is None else float(min(cur, others))
|  | ||||
|  | ||||
def find_max(cur, others):
    """Fold ``others`` into a running maximum; ``cur`` is None before the
    first observation.  Always returns a plain float.

    ``others`` may be a plain number or anything exposing ``.max()`` (e.g. a
    numpy array/scalar).  The original version called ``others.max()``
    unconditionally in the ``cur is None`` branch, crashing on plain Python
    numbers and diverging from the symmetric ``find_min`` helper; this
    accepts both, backward-compatibly.
    """
    others_max = float(others.max()) if hasattr(others, "max") else float(others)
    if cur is None:
        return others_max
    return float(max(cur, others_max))
|  | ||||
|  | ||||
def compare_cl(save_dir):
    """Render per-timestamp figures contrasting the LFNA data stream with a
    progressively revealed continual-learning curve, then encode videos.

    NOTE(review): ``create_example_v1`` is not imported anywhere in this
    module, so calling this function raises NameError as-is — confirm the
    intended import before use.
    """
    save_dir = Path(str(save_dir))
    save_dir.mkdir(parents=True, exist_ok=True)
    dynamic_env, cl_function = create_example_v1(
        # timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0),
        timestamp_config=dict(num=200),
        num_per_task=1000,
    )

    # NOTE(review): ``models`` is never used below.
    models = dict()

    cl_function.set_timestamp(0)
    cl_xaxis_min = None
    cl_xaxis_max = None

    all_data = OrderedDict()

    # First pass: cache each timestamp's data and accumulate the global CL
    # x-range as (mean - std, mean + std) over all timestamps.
    for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)):
        xaxis_all = dataset[0][:, 0].numpy()
        yaxis_all = dataset[1][:, 0].numpy()
        current_data = dict()
        current_data["lfna_xaxis_all"] = xaxis_all
        current_data["lfna_yaxis_all"] = yaxis_all

        # compute cl-min
        cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all.mean() - xaxis_all.std())
        cl_xaxis_max = find_max(cl_xaxis_max, xaxis_all.mean() + xaxis_all.std())
        all_data[timestamp] = current_data

    # NOTE(review): these two globals are computed but never used below.
    global_cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.1)
    global_cl_yaxis_all = cl_function.noise_call(global_cl_xaxis_all)

    # Second pass: one two-panel figure per timestamp — the LFNA data on top,
    # the CL curve below, revealed proportionally to progress (idx+1)/N.
    for idx, (timestamp, xdata) in enumerate(tqdm(all_data.items(), ncols=50)):
        scatter_list = []
        scatter_list.append(
            {
                "xaxis": xdata["lfna_xaxis_all"],
                "yaxis": xdata["lfna_yaxis_all"],
                "color": "k",
                "linewidths": 15,
                "alpha": 0.99,
                "xlim": (-6, 6),
                "ylim": (-40, 40),
                "label": "LFNA",
            }
        )

        cur_cl_xaxis_min = cl_xaxis_min
        cur_cl_xaxis_max = cl_xaxis_min + (cl_xaxis_max - cl_xaxis_min) * (
            idx + 1
        ) / len(all_data)
        cl_xaxis_all = np.arange(cur_cl_xaxis_min, cur_cl_xaxis_max, step=0.01)
        cl_yaxis_all = cl_function.noise_call(cl_xaxis_all, std=0.2)

        scatter_list.append(
            {
                "xaxis": cl_xaxis_all,
                "yaxis": cl_yaxis_all,
                "color": "k",
                "linewidths": 15,
                "xlim": (round(cl_xaxis_min, 1), round(cl_xaxis_max, 1)),
                "ylim": (-20, 6),
                "alpha": 0.99,
                "label": "Continual Learning",
            }
        )

        draw_multi_fig(
            save_dir,
            idx,
            scatter_list,
            wh=(2200, 1800),
            fig_title="Timestamp={:03d}".format(idx),
        )
    print("Save all figures into {:}".format(save_dir))
    # Assemble the PNG sequence into .mp4/.webm with ffmpeg.
    save_dir = save_dir.resolve()
    base_cmd = (
        "ffmpeg -y -i {xdir}/%04d.png -vf fps=1 -vf scale=2200:1800 -vb 5000k".format(
            xdir=save_dir
        )
    )
    video_cmd = "{:} -pix_fmt yuv420p {xdir}/compare-cl.mp4".format(
        base_cmd, xdir=save_dir
    )
    print(video_cmd + "\n")
    os.system(video_cmd)
    os.system(
        "{:} -pix_fmt yuv420p {xdir}/compare-cl.webm".format(base_cmd, xdir=save_dir)
    )
|  | ||||
|  | ||||
def visualize_env(save_dir, version):
    """Plot every timestamp of a synthetic environment and encode a video.

    Saves one scatter figure per timestamp under ``save_dir/pdf`` and
    ``save_dir/png`` (named ``v{version}-{index:05d}``), then invokes ffmpeg
    to assemble the PNG sequence into ``env-{version}.mp4`` / ``.webm``.

    Args:
        save_dir: Path-like output directory (created if missing).
        version: synthetic-environment version string, e.g. "v1".
    """
    save_dir = Path(str(save_dir))
    for substr in ("pdf", "png"):
        sub_save_dir = save_dir / substr
        sub_save_dir.mkdir(parents=True, exist_ok=True)

    dynamic_env = get_synthetic_env(version=version)
    # First pass: gather all data so every frame shares the same axis limits.
    allxs, allys = [], []
    for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
        allxs.append(allx)
        allys.append(ally)
    allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
    # Second pass: render one figure per timestamp.
    for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
        dpi, width, height = 30, 1800, 1400
        figsize = width / float(dpi), height / float(dpi)
        LabelSize, LegendFontsize, font_gap = 80, 80, 5
        fig = plt.figure(figsize=figsize)

        cur_ax = fig.add_subplot(1, 1, 1)
        allx, ally = allx[:, 0].numpy(), ally[:, 0].numpy()
        plot_scatter(cur_ax, allx, ally, "k", 0.99, 15, "timestamp={:05d}".format(idx))
        cur_ax.set_xlabel("X", fontsize=LabelSize)
        cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
        # ``Tick.label`` was removed in Matplotlib 3.8; ``tick_params`` is the
        # supported equivalent for setting tick-label size and rotation.
        cur_ax.tick_params(axis="x", labelsize=LabelSize - font_gap, labelrotation=10)
        cur_ax.tick_params(axis="y", labelsize=LabelSize - font_gap)
        cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1))
        cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1))
        cur_ax.legend(loc=1, fontsize=LegendFontsize)

        pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx)
        fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf")
        png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx)
        fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png")
        plt.close("all")
    # Encode the PNG sequence into videos.
    save_dir = save_dir.resolve()
    base_cmd = "ffmpeg -y -i {xdir}/v{version}-%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format(
        xdir=save_dir / "png", version=version
    )
    print(base_cmd)
    os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version))
    os.system("{:} {xdir}/env-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version))
|  | ||||
|  | ||||
def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"):
    """Visually compare algorithms against the raw synthetic data.

    Loads each algorithm's saved ``final-ckp.pth`` (the per-timestamp weight
    containers produced by basic-his.py), re-runs the base model with those
    weights on every timestamp, and renders a two-panel figure per frame:
    predictions vs. raw data on top, the running MSE trajectory below.
    Finally ffmpeg assembles the frames into .mp4/.webm videos.

    NOTE(review): the ``tick.label`` attribute used below was removed in
    Matplotlib 3.8 (use ``tick_params``/``label1``); this function fails on
    modern Matplotlib — confirm the pinned version.
    """
    save_dir = Path(str(save_dir))
    for substr in ("pdf", "png"):
        sub_save_dir = save_dir / substr
        sub_save_dir.mkdir(parents=True, exist_ok=True)

    dpi, width, height = 30, 3200, 2000
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize, font_gap = 80, 80, 5

    # First pass over the environment to fix global axis limits.
    dynamic_env = get_synthetic_env(mode=None, version=version)
    allxs, allys = [], []
    for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
        allxs.append(allx)
        allys.append(ally)
    allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)

    # Map display-name -> checkpoint sub-directory under ``alg_dir``.
    alg_name2dir = OrderedDict()
    # alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data"
    # alg_name2dir["MAML"] = "use-maml-s1"
    # alg_name2dir["LFNA (fix init)"] = "lfna-fix-init"
    if version == "v1":
        # alg_name2dir["Optimal"] = "use-same-timestamp"
        alg_name2dir[
            "GMOA"
        ] = "lfna-battle-bs128-d16_16_16-s16-lr0.002-wd1e-05-e10000-envv1"
    else:
        raise ValueError("Invalid version: {:}".format(version))
    alg_name2all_containers = OrderedDict()
    for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
        ckp_path = Path(alg_dir) / str(xdir) / "final-ckp.pth"
        xdata = torch.load(ckp_path, map_location="cpu")
        alg_name2all_containers[alg] = xdata["w_containers"]
    # load the basic model
    # NOTE(review): these dims must match the training-time model_kwargs.
    model = get_model(
        dict(model_type="norm_mlp"),
        input_dim=1,
        output_dim=1,
        hidden_dims=[16] * 2,
        act_cls="gelu",
        norm_cls="layer_norm_1d",
    )

    # Per-algorithm MSE trajectories accumulated across timestamps.
    alg2xs, alg2ys = defaultdict(list), defaultdict(list)
    colors = ["r", "g", "b", "m", "y"]

    linewidths, skip = 10, 5
    for idx, (timestamp, (ori_allx, ori_ally)) in enumerate(
        tqdm(dynamic_env, ncols=50)
    ):
        # Skip the first few timestamps (warm-up frames are not plotted).
        if idx <= skip:
            continue
        fig = plt.figure(figsize=figsize)
        cur_ax = fig.add_subplot(2, 1, 1)

        # the data
        allx, ally = ori_allx[:, 0].numpy(), ori_ally[:, 0].numpy()
        plot_scatter(cur_ax, allx, ally, "k", 0.99, linewidths, "Raw Data")

        # Overlay each algorithm's predictions and record its MSE.
        for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
            with torch.no_grad():
                predicts = model.forward_with_container(
                    ori_allx, alg_name2all_containers[alg][idx]
                )
                predicts = predicts.cpu()
                # keep data
                metric = MSEMetric()
                metric(predicts, ori_ally)
                predicts = predicts.view(-1).numpy()
                alg2xs[alg].append(idx)
                alg2ys[alg].append(metric.get_info()["mse"])
            plot_scatter(cur_ax, allx, predicts, colors[idx_alg], 0.99, linewidths, alg)

        cur_ax.set_xlabel("X", fontsize=LabelSize)
        cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
        for tick in cur_ax.xaxis.get_major_ticks():
            tick.label.set_fontsize(LabelSize - font_gap)
            tick.label.set_rotation(10)
        for tick in cur_ax.yaxis.get_major_ticks():
            tick.label.set_fontsize(LabelSize - font_gap)
        cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1))
        cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1))
        cur_ax.legend(loc=1, fontsize=LegendFontsize)

        # the trajectory data
        cur_ax = fig.add_subplot(2, 1, 2)
        for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
            # plot_scatter(cur_ax, alg2xs[alg], alg2ys[alg], colors[idx_alg], 0.99, linewidths, alg)
            cur_ax.plot(
                alg2xs[alg],
                alg2ys[alg],
                color=colors[idx_alg],
                linestyle="-",
                linewidth=5,
                label=alg,
            )
        cur_ax.legend(loc=1, fontsize=LegendFontsize)

        cur_ax.set_xlabel("Timestamp", fontsize=LabelSize)
        cur_ax.set_ylabel("MSE", fontsize=LabelSize)
        for tick in cur_ax.xaxis.get_major_ticks():
            tick.label.set_fontsize(LabelSize - font_gap)
            tick.label.set_rotation(10)
        for tick in cur_ax.yaxis.get_major_ticks():
            tick.label.set_fontsize(LabelSize - font_gap)
        cur_ax.set_xlim(1, len(dynamic_env))
        cur_ax.set_ylim(0, 10)
        cur_ax.legend(loc=1, fontsize=LegendFontsize)

        pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx - skip)
        fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf")
        png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx - skip)
        fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png")
        plt.close("all")
    # Assemble the rendered PNG frames into videos.
    save_dir = save_dir.resolve()
    base_cmd = "ffmpeg -y -i {xdir}/v{ver}-%05d.png -vf scale={w}:{h} -pix_fmt yuv420p -vb 5000k".format(
        xdir=save_dir / "png", w=width, h=height, ver=version
    )
    os.system(
        "{:} {xdir}/com-alg-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version)
    )
    os.system(
        "{:} {xdir}/com-alg-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version)
    )
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|  | ||||
|     parser = argparse.ArgumentParser("Visualize synthetic data.") | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./outputs/vis-synthetic", | ||||
|         help="The save directory.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--env_version", | ||||
|         type=str, | ||||
|         required=True, | ||||
|         help="The synthetic enviornment version.", | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v1") | ||||
|     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v2") | ||||
|     compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version) | ||||
|     # compare_cl(os.path.join(args.save_dir, "compare-cl")) | ||||
		Reference in New Issue
	
	Block a user