Correct the code
@@ -9,6 +9,12 @@ from tqdm import tqdm
 from copy import deepcopy
 from pathlib import Path
 
+lib_dir = (Path(__file__).parent / ".." / "..").resolve()
+print("LIB-DIR: {:}".format(lib_dir))
+if str(lib_dir) not in sys.path:
+    sys.path.insert(0, str(lib_dir))
+
+
 from xautodl.procedures import (
     prepare_seed,
     prepare_logger,
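
Note: the block added above assumes `import sys` already appears before line 9 of this file (outside the shown context) and that the repository root sits two directories above the script. A small, hypothetical sanity check for that bootstrap:

    import importlib

    # After the sys.path bootstrap, "xautodl" should resolve to the in-repo copy.
    xautodl = importlib.import_module("xautodl")
    print(xautodl.__file__)  # expected to live under lib_dir
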
@@ -38,28 +44,30 @@ def subsample(historical_x, historical_y, maxn=10000):
 
 
 def main(args):
-    logger, env_info, model_kwargs = lfna_setup(args)
+    logger, model_kwargs = lfna_setup(args)
 
-    w_container_per_epoch = dict()
+    env = get_synthetic_env(mode=None, version=args.env_version)
+    logger.log("The total enviornment: {:}".format(env))
+    w_containers = dict()
 
     per_timestamp_time, start_time = AverageMeter(), time.time()
-    for idx in range(1, env_info["total"]):
+    for idx, (future_time, (future_x, future_y)) in enumerate(env):
 
         need_time = "Time Left: {:}".format(
-            convert_secs2time(per_timestamp_time.avg * (env_info["total"] - idx), True)
+            convert_secs2time(per_timestamp_time.avg * (len(env) - idx), True)
         )
         logger.log(
             "[{:}]".format(time_string())
-            + " [{:04d}/{:04d}]".format(idx, env_info["total"])
+            + " [{:04d}/{:04d}]".format(idx, len(env))
             + " "
             + need_time
         )
         # train the same data
-        historical_x = env_info["{:}-x".format(idx)]
-        historical_y = env_info["{:}-y".format(idx)]
+        historical_x = future_x.to(args.device)
+        historical_y = future_y.to(args.device)
         # build model
         model = get_model(**model_kwargs)
-        print(model)
+        model = model.to(args.device)
         # build optimizer
         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
         criterion = torch.nn.MSELoss()
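
Note: the rewritten loop no longer indexes an `env_info` dictionary; it iterates the synthetic environment directly and relies on it yielding `(timestamp, (x, y))` pairs and supporting `len()`. A minimal stand-in sketch of that protocol (the `ToyEnv` class and shapes below are illustrative, not part of xautodl):

    import torch

    class ToyEnv:
        """Hypothetical stand-in for get_synthetic_env(mode=None, ...)."""

        def __init__(self, num=5):
            self._times = torch.linspace(0, 1, num)

        def __len__(self):
            return len(self._times)

        def __iter__(self):
            for t in self._times:
                x = torch.randn(16, 1)  # toy inputs for this timestamp
                y = 2.0 * x + 1.0       # toy targets
                yield t, (x, y)

    for idx, (future_time, (future_x, future_y)) in enumerate(ToyEnv()):
        print(idx, float(future_time), tuple(future_x.shape), tuple(future_y.shape))
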
@@ -93,7 +101,7 @@ def main(args):
 
         metric = ComposeMetric(MSEMetric(), SaveMetric())
         eval_dataset = torch.utils.data.TensorDataset(
-            env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)]
+            future_x.to(args.device), future_y.to(args.device)
         )
         eval_loader = torch.utils.data.DataLoader(
             eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
@@ -101,23 +109,21 @@ def main(args):
         results = basic_eval_fn(eval_loader, model, metric, logger)
         log_str = (
             "[{:}]".format(time_string())
-            + " [{:04d}/{:04d}]".format(idx, env_info["total"])
+            + " [{:04d}/{:04d}]".format(idx, len(env))
             + " train-mse: {:.5f}, eval-mse: {:.5f}".format(
                 train_results["mse"], results["mse"]
             )
         )
         logger.log(log_str)
 
-        save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
-            idx, env_info["total"]
-        )
-        w_container_per_epoch[idx] = model.get_w_container().no_grad_clone()
+        save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(idx, len(env))
+        w_containers[idx] = model.get_w_container().no_grad_clone()
         save_checkpoint(
             {
                 "model_state_dict": model.state_dict(),
                 "model": model,
                 "index": idx,
-                "timestamp": env_info["{:}-timestamp".format(idx)],
+                "timestamp": future_time.item(),
             },
             save_path,
             logger,
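
Note: each per-timestamp checkpoint now stores the float timestamp via `future_time.item()` instead of an `env_info` lookup. A hedged sketch of reading one back, assuming `save_checkpoint` writes the dictionary with `torch.save` and using a made-up file name that follows the `{:04d}-{:04d}.pth` pattern:

    import torch

    ckp = torch.load("0001-0100.pth", map_location="cpu")  # hypothetical file name
    print(ckp["index"], ckp["timestamp"])   # step index and float timestamp
    model = ckp["model"]                    # the pickled model object
    model.load_state_dict(ckp["model_state_dict"])
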
@@ -127,7 +133,7 @@ def main(args):
         start_time = time.time()
 
     save_checkpoint(
-        {"w_container_per_epoch": w_container_per_epoch},
+        {"w_containers": w_containers},
         logger.path(None) / "final-ckp.pth",
         logger,
     )
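
Note: the aggregate checkpoint key changes from `w_container_per_epoch` to `w_containers`, so downstream readers of `final-ckp.pth` need the new key. A short sketch under the same `torch.save` assumption:

    import torch

    final = torch.load("final-ckp.pth", map_location="cpu")
    w_containers = final["w_containers"]  # dict: timestamp index -> cloned weight container
    print(len(w_containers))
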
@@ -174,6 +180,12 @@ if __name__ == "__main__":
         default=300,
         help="The total number of epochs.",
     )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cpu",
+        help="",
+    )
     parser.add_argument(
         "--workers",
         type=int,
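
Note: the new `--device` flag (its help string is left empty in the diff) is what the `.to(args.device)` calls above rely on. A minimal, hypothetical sketch of parsing it and moving a toy model and batch onto the chosen device:

    import argparse
    import torch

    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cpu",
                        help="Device to train on, e.g. cpu, cuda, cuda:0.")
    args = parser.parse_args([])               # defaults to cpu

    device = torch.device(args.device)
    model = torch.nn.Linear(1, 1).to(device)   # stands in for get_model(**model_kwargs)
    batch = torch.randn(8, 1).to(device)
    print(model(batch).shape)
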
@@ -225,9 +225,11 @@ def main(args):
     logger, model_kwargs = lfna_setup(args)
     train_env = get_synthetic_env(mode="train", version=args.env_version)
     valid_env = get_synthetic_env(mode="valid", version=args.env_version)
+    trainval_env = get_synthetic_env(mode="trainval", version=args.env_version)
     all_env = get_synthetic_env(mode=None, version=args.env_version)
     logger.log("The training enviornment: {:}".format(train_env))
     logger.log("The validation enviornment: {:}".format(valid_env))
+    logger.log("The trainval enviornment: {:}".format(trainval_env))
     logger.log("The total enviornment: {:}".format(all_env))
 
     base_model = get_model(**model_kwargs)
@@ -237,14 +239,14 @@ def main(args):
     shape_container = base_model.get_w_container().to_shape_container()
 
     # pre-train the hypernetwork
-    timestamps = train_env.get_timestamp(None)
+    timestamps = trainval_env.get_timestamp(None)
     meta_model = LFNA_Meta(
         shape_container,
         args.layer_dim,
         args.time_dim,
         timestamps,
         seq_length=args.seq_length,
-        interval=train_env.time_interval,
+        interval=trainval_env.time_interval,
     )
     meta_model = meta_model.to(args.device)
@@ -253,8 +255,7 @@ def main(args):
     logger.log("The base-model is\n{:}".format(base_model))
     logger.log("The meta-model is\n{:}".format(meta_model))
 
-    # batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge)
-    pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
+    pretrain_v2(base_model, meta_model, criterion, trainval_env, args, logger)
 
     # try to evaluate once
     # online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
@@ -22,12 +22,12 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, versio
             [mean_generator], [[std_generator]], (-2, 2)
         )
         time_generator = TimeStamp(
-            min_timestamp=0, max_timestamp=math.pi * 6, num=total_timestamp, mode=mode
+            min_timestamp=0, max_timestamp=math.pi * 8, num=total_timestamp, mode=mode
         )
         oracle_map = DynamicLinearFunc(
             params={
                 0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}),
-                1: ComposedSinFunc(params={0: 1.5, 1: 0.4, 2: 2.2}),
+                1: ComposedSinFunc(params={0: 1.5, 1: 0.6, 2: 1.8}),
             }
         )
         dynamic_env = SyntheticDEnv(
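
Note: the timestamp range widens from [0, 6π] to [0, 8π] while `total_timestamp` stays the same, so the spacing between consecutive timestamps grows. A small sketch of the resulting grid, assuming `TimeStamp` samples the interval uniformly (the real class may behave differently):

    import math

    total_timestamp = 1000
    lo, hi = 0.0, math.pi * 8          # was math.pi * 6 before this change

    step = (hi - lo) / (total_timestamp - 1)
    grid = [lo + i * step for i in range(total_timestamp)]
    print(len(grid), round(step, 5), round(grid[-1], 5))
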
@@ -28,7 +28,7 @@ class UnifiedSplit:
             self._indexes = all_indexes[num_of_train : num_of_train + num_of_valid]
         elif mode.lower() in ("test", "testing"):
             self._indexes = all_indexes[num_of_train + num_of_valid :]
-        elif mode.lower() in ("trainval", "trainvalidation"):
+        elif mode.lower() in ("trainval", "trainvalid", "trainvalidation"):
             self._indexes = all_indexes[: num_of_train + num_of_valid]
         else:
             raise ValueError("Unkonwn mode of {:}".format(mode))
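
Note: "trainvalid" now joins "trainval" and "trainvalidation" as accepted aliases for the union of the train and valid splits. An illustrative sketch of the index arithmetic, with made-up split sizes and the surrounding branches reconstructed by analogy (the real `UnifiedSplit` may differ):

    num_of_train, num_of_valid, num_of_test = 600, 200, 200   # made-up counts
    all_indexes = list(range(num_of_train + num_of_valid + num_of_test))

    def select(mode):
        if mode is None:
            return all_indexes
        mode = mode.lower()
        if mode in ("train", "training"):
            return all_indexes[:num_of_train]
        if mode in ("valid", "validation"):
            return all_indexes[num_of_train : num_of_train + num_of_valid]
        if mode in ("test", "testing"):
            return all_indexes[num_of_train + num_of_valid :]
        if mode in ("trainval", "trainvalid", "trainvalidation"):
            return all_indexes[: num_of_train + num_of_valid]
        raise ValueError("Unknown mode of {:}".format(mode))

    print(len(select("trainvalid")))   # 800 == train + valid
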