diff --git a/exps/LFNA/basic-same.py b/exps/LFNA/basic-same.py
index 83c4592..1ca25a6 100644
--- a/exps/LFNA/basic-same.py
+++ b/exps/LFNA/basic-same.py
@@ -9,6 +9,12 @@ from tqdm import tqdm
 from copy import deepcopy
 from pathlib import Path
 
+lib_dir = (Path(__file__).parent / ".." / "..").resolve()
+print("LIB-DIR: {:}".format(lib_dir))
+if str(lib_dir) not in sys.path:
+    sys.path.insert(0, str(lib_dir))
+
+
 from xautodl.procedures import (
     prepare_seed,
     prepare_logger,
@@ -38,28 +44,30 @@ def subsample(historical_x, historical_y, maxn=10000):
 
 
 def main(args):
-    logger, env_info, model_kwargs = lfna_setup(args)
+    logger, model_kwargs = lfna_setup(args)
 
-    w_container_per_epoch = dict()
+    env = get_synthetic_env(mode=None, version=args.env_version)
+    logger.log("The total enviornment: {:}".format(env))
+    w_containers = dict()
 
     per_timestamp_time, start_time = AverageMeter(), time.time()
-    for idx in range(1, env_info["total"]):
+    for idx, (future_time, (future_x, future_y)) in enumerate(env):
 
         need_time = "Time Left: {:}".format(
-            convert_secs2time(per_timestamp_time.avg * (env_info["total"] - idx), True)
+            convert_secs2time(per_timestamp_time.avg * (len(env) - idx), True)
         )
         logger.log(
             "[{:}]".format(time_string())
-            + " [{:04d}/{:04d}]".format(idx, env_info["total"])
+            + " [{:04d}/{:04d}]".format(idx, len(env))
             + " "
             + need_time
         )
         # train the same data
-        historical_x = env_info["{:}-x".format(idx)]
-        historical_y = env_info["{:}-y".format(idx)]
+        historical_x = future_x.to(args.device)
+        historical_y = future_y.to(args.device)
         # build model
         model = get_model(**model_kwargs)
-        print(model)
+        model = model.to(args.device)
         # build optimizer
         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
         criterion = torch.nn.MSELoss()
@@ -93,7 +101,7 @@ def main(args):
 
         metric = ComposeMetric(MSEMetric(), SaveMetric())
         eval_dataset = torch.utils.data.TensorDataset(
-            env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)]
+            future_x.to(args.device), future_y.to(args.device)
         )
         eval_loader = torch.utils.data.DataLoader(
             eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
@@ -101,23 +109,21 @@ def main(args):
         )
         results = basic_eval_fn(eval_loader, model, metric, logger)
         log_str = (
             "[{:}]".format(time_string())
-            + " [{:04d}/{:04d}]".format(idx, env_info["total"])
+            + " [{:04d}/{:04d}]".format(idx, len(env))
             + " train-mse: {:.5f}, eval-mse: {:.5f}".format(
                 train_results["mse"], results["mse"]
             )
         )
         logger.log(log_str)
-        save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
-            idx, env_info["total"]
-        )
-        w_container_per_epoch[idx] = model.get_w_container().no_grad_clone()
+        save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(idx, len(env))
+        w_containers[idx] = model.get_w_container().no_grad_clone()
         save_checkpoint(
             {
                 "model_state_dict": model.state_dict(),
                 "model": model,
                 "index": idx,
-                "timestamp": env_info["{:}-timestamp".format(idx)],
+                "timestamp": future_time.item(),
             },
             save_path,
             logger,
@@ -127,7 +133,7 @@ def main(args):
         start_time = time.time()
 
     save_checkpoint(
-        {"w_container_per_epoch": w_container_per_epoch},
+        {"w_containers": w_containers},
        logger.path(None) / "final-ckp.pth",
         logger,
     )
@@ -174,6 +180,12 @@ if __name__ == "__main__":
         default=300,
         help="The total number of epochs.",
     )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cpu",
+        help="",
+    )
     parser.add_argument(
         "--workers",
         type=int,
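The basic-same.py rewrite above replaces lookups on a precomputed env_info dictionary with direct iteration over the synthetic environment. A minimal sketch of the new access pattern, assuming only what the diff itself shows: the environment supports len() and yields (timestamp, (x, y)) pairs. The version string "v1" is a hypothetical placeholder for args.env_version.

from xautodl.datasets.synthetic_core import get_synthetic_env

# Iterate the environment directly; no per-index dictionary keys are needed.
env = get_synthetic_env(mode=None, version="v1")  # "v1" is a hypothetical version tag
for idx, (future_time, (future_x, future_y)) in enumerate(env):
    # Each item pairs a scalar timestamp tensor with the (x, y) data of one task.
    print(
        "[{:04d}/{:04d}] t={:.3f} x={:} y={:}".format(
            idx,
            len(env),
            future_time.item(),
            tuple(future_x.shape),
            tuple(future_y.shape),
        )
    )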
diff --git a/exps/LFNA/lfna.py b/exps/LFNA/lfna.py
index 8ab37f0..ddfcee3 100644
--- a/exps/LFNA/lfna.py
+++ b/exps/LFNA/lfna.py
@@ -225,9 +225,11 @@ def main(args):
     logger, model_kwargs = lfna_setup(args)
     train_env = get_synthetic_env(mode="train", version=args.env_version)
     valid_env = get_synthetic_env(mode="valid", version=args.env_version)
+    trainval_env = get_synthetic_env(mode="trainval", version=args.env_version)
     all_env = get_synthetic_env(mode=None, version=args.env_version)
     logger.log("The training enviornment: {:}".format(train_env))
     logger.log("The validation enviornment: {:}".format(valid_env))
+    logger.log("The trainval enviornment: {:}".format(trainval_env))
     logger.log("The total enviornment: {:}".format(all_env))
 
     base_model = get_model(**model_kwargs)
@@ -237,14 +239,14 @@ def main(args):
     shape_container = base_model.get_w_container().to_shape_container()
 
     # pre-train the hypernetwork
-    timestamps = train_env.get_timestamp(None)
+    timestamps = trainval_env.get_timestamp(None)
     meta_model = LFNA_Meta(
         shape_container,
         args.layer_dim,
         args.time_dim,
         timestamps,
         seq_length=args.seq_length,
-        interval=train_env.time_interval,
+        interval=trainval_env.time_interval,
     )
     meta_model = meta_model.to(args.device)
 
@@ -253,8 +255,7 @@ def main(args):
     logger.log("The base-model is\n{:}".format(base_model))
     logger.log("The meta-model is\n{:}".format(meta_model))
 
-    # batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge)
-    pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
+    pretrain_v2(base_model, meta_model, criterion, trainval_env, args, logger)
 
     # try to evaluate once
     # online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
diff --git a/xautodl/datasets/synthetic_core.py b/xautodl/datasets/synthetic_core.py
index 9a3eb2e..b22819e 100644
--- a/xautodl/datasets/synthetic_core.py
+++ b/xautodl/datasets/synthetic_core.py
@@ -22,12 +22,12 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, versio
         [mean_generator], [[std_generator]], (-2, 2)
     )
     time_generator = TimeStamp(
-        min_timestamp=0, max_timestamp=math.pi * 6, num=total_timestamp, mode=mode
+        min_timestamp=0, max_timestamp=math.pi * 8, num=total_timestamp, mode=mode
     )
     oracle_map = DynamicLinearFunc(
         params={
             0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}),
-            1: ComposedSinFunc(params={0: 1.5, 1: 0.4, 2: 2.2}),
+            1: ComposedSinFunc(params={0: 1.5, 1: 0.6, 2: 1.8}),
         }
     )
     dynamic_env = SyntheticDEnv(
diff --git a/xautodl/datasets/synthetic_utils.py b/xautodl/datasets/synthetic_utils.py
index 9c70e6b..af353c3 100644
--- a/xautodl/datasets/synthetic_utils.py
+++ b/xautodl/datasets/synthetic_utils.py
@@ -28,7 +28,7 @@ class UnifiedSplit:
             self._indexes = all_indexes[num_of_train : num_of_train + num_of_valid]
         elif mode.lower() in ("test", "testing"):
             self._indexes = all_indexes[num_of_train + num_of_valid :]
-        elif mode.lower() in ("trainval", "trainvalidation"):
+        elif mode.lower() in ("trainval", "trainvalid", "trainvalidation"):
             self._indexes = all_indexes[: num_of_train + num_of_valid]
         else:
             raise ValueError("Unkonwn mode of {:}".format(mode))
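The synthetic_utils.py hunk only adds "trainvalid" as an accepted alias for the train-plus-validation split that lfna.py now uses for pre-training. Below is a self-contained sketch of the splitting scheme implied by the visible branches; the split ratios are assumptions, since the actual num_of_train / num_of_valid computation lies outside this diff.

def split_indexes(all_indexes, mode, train_ratio=0.6, valid_ratio=0.2):
    # Mirrors UnifiedSplit's mode handling; the ratios here are hypothetical.
    num_of_train = int(len(all_indexes) * train_ratio)
    num_of_valid = int(len(all_indexes) * valid_ratio)
    if mode is None:
        return list(all_indexes)
    elif mode.lower() in ("train", "training"):
        return all_indexes[:num_of_train]
    elif mode.lower() in ("valid", "validation"):
        return all_indexes[num_of_train : num_of_train + num_of_valid]
    elif mode.lower() in ("test", "testing"):
        return all_indexes[num_of_train + num_of_valid :]
    elif mode.lower() in ("trainval", "trainvalid", "trainvalidation"):
        # "trainvalid" is the new alias accepted after this change.
        return all_indexes[: num_of_train + num_of_valid]
    else:
        raise ValueError("Unknown mode of {:}".format(mode))


# Example: with the assumed 60/20 ratios, "trainval" covers the first 80% of indexes.
assert split_indexes(list(range(1000)), "trainval") == list(range(800))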