Update his/same
This commit is contained in:
		
							
								
								
									
										192
									
								
								exps/LFNA/basic-his.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										192
									
								
								exps/LFNA/basic-his.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,192 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||||
|  | ##################################################### | ||||||
|  | # python exps/LFNA/basic-his.py --srange 1-999 | ||||||
|  | ##################################################### | ||||||
|  | import sys, time, copy, torch, random, argparse | ||||||
|  | from tqdm import tqdm | ||||||
|  | from copy import deepcopy | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  | lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||||
|  | if str(lib_dir) not in sys.path: | ||||||
|  |     sys.path.insert(0, str(lib_dir)) | ||||||
|  | from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint | ||||||
|  | from log_utils import time_string | ||||||
|  | from log_utils import AverageMeter, convert_secs2time | ||||||
|  |  | ||||||
|  | from utils import split_str2indexes | ||||||
|  |  | ||||||
|  | from procedures.advanced_main import basic_train_fn, basic_eval_fn | ||||||
|  | from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric | ||||||
|  | from datasets.synthetic_core import get_synthetic_env | ||||||
|  | from models.xcore import get_model | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def subsample(historical_x, historical_y, maxn=10000): | ||||||
|  |     total = historical_x.size(0) | ||||||
|  |     if total <= maxn: | ||||||
|  |         return historical_x, historical_y | ||||||
|  |     else: | ||||||
|  |         indexes = torch.randint(low=0, high=total, size=[maxn]) | ||||||
|  |         return historical_x[indexes], historical_y[indexes] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def main(args): | ||||||
|  |     prepare_seed(args.rand_seed) | ||||||
|  |     logger = prepare_logger(args) | ||||||
|  |  | ||||||
|  |     cache_path = (logger.path(None) / ".." / "env-info.pth").resolve() | ||||||
|  |     if cache_path.exists(): | ||||||
|  |         env_info = torch.load(cache_path) | ||||||
|  |     else: | ||||||
|  |         env_info = dict() | ||||||
|  |         dynamic_env = get_synthetic_env() | ||||||
|  |         env_info["total"] = len(dynamic_env) | ||||||
|  |         for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)): | ||||||
|  |             env_info["{:}-timestamp".format(idx)] = timestamp | ||||||
|  |             env_info["{:}-x".format(idx)] = _allx | ||||||
|  |             env_info["{:}-y".format(idx)] = _ally | ||||||
|  |         env_info["dynamic_env"] = dynamic_env | ||||||
|  |         torch.save(env_info, cache_path) | ||||||
|  |  | ||||||
|  |     # check indexes to be evaluated | ||||||
|  |     to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None) | ||||||
|  |     logger.log( | ||||||
|  |         "Evaluate {:}, which has {:} timestamps in total.".format( | ||||||
|  |             args.srange, len(to_evaluate_indexes) | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     per_timestamp_time, start_time = AverageMeter(), time.time() | ||||||
|  |     for i, idx in enumerate(to_evaluate_indexes): | ||||||
|  |  | ||||||
|  |         need_time = "Time Left: {:}".format( | ||||||
|  |             convert_secs2time( | ||||||
|  |                 per_timestamp_time.avg * (len(to_evaluate_indexes) - i), True | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         logger.log( | ||||||
|  |             "[{:}]".format(time_string()) | ||||||
|  |             + " [{:04d}/{:04d}][{:04d}]".format(i, len(to_evaluate_indexes), idx) | ||||||
|  |             + " " | ||||||
|  |             + need_time | ||||||
|  |         ) | ||||||
|  |         # train the same data | ||||||
|  |         assert idx != 0 | ||||||
|  |         historical_x = env_info["{:}-x".format(idx)] | ||||||
|  |         historical_y = env_info["{:}-y".format(idx)] | ||||||
|  |         # build model | ||||||
|  |         mean, std = historical_x.mean().item(), historical_x.std().item() | ||||||
|  |         model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std) | ||||||
|  |         model = get_model(dict(model_type="simple_mlp"), **model_kwargs) | ||||||
|  |         # build optimizer | ||||||
|  |         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) | ||||||
|  |         criterion = torch.nn.MSELoss() | ||||||
|  |         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||||
|  |             optimizer, | ||||||
|  |             milestones=[ | ||||||
|  |                 int(args.epochs * 0.25), | ||||||
|  |                 int(args.epochs * 0.5), | ||||||
|  |                 int(args.epochs * 0.75), | ||||||
|  |             ], | ||||||
|  |             gamma=0.3, | ||||||
|  |         ) | ||||||
|  |         train_metric = MSEMetric() | ||||||
|  |         best_loss, best_param = None, None | ||||||
|  |         for _iepoch in range(args.epochs): | ||||||
|  |             preds = model(historical_x) | ||||||
|  |             optimizer.zero_grad() | ||||||
|  |             loss = criterion(preds, historical_y) | ||||||
|  |             loss.backward() | ||||||
|  |             optimizer.step() | ||||||
|  |             lr_scheduler.step() | ||||||
|  |             # save best | ||||||
|  |             if best_loss is None or best_loss > loss.item(): | ||||||
|  |                 best_loss = loss.item() | ||||||
|  |                 best_param = copy.deepcopy(model.state_dict()) | ||||||
|  |         model.load_state_dict(best_param) | ||||||
|  |         with torch.no_grad(): | ||||||
|  |             train_metric(preds, historical_y) | ||||||
|  |         train_results = train_metric.get_info() | ||||||
|  |  | ||||||
|  |         metric = ComposeMetric(MSEMetric(), SaveMetric()) | ||||||
|  |         eval_dataset = torch.utils.data.TensorDataset( | ||||||
|  |             env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)] | ||||||
|  |         ) | ||||||
|  |         eval_loader = torch.utils.data.DataLoader( | ||||||
|  |             eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0 | ||||||
|  |         ) | ||||||
|  |         results = basic_eval_fn(eval_loader, model, metric, logger) | ||||||
|  |         log_str = ( | ||||||
|  |             "[{:}]".format(time_string()) | ||||||
|  |             + " [{:04d}/{:04d}]".format(idx, env_info["total"]) | ||||||
|  |             + " train-mse: {:.5f}, eval-mse: {:.5f}".format( | ||||||
|  |                 train_results["mse"], results["mse"] | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         logger.log(log_str) | ||||||
|  |  | ||||||
|  |         save_path = logger.path(None) / "{:04d}-{:04d}.pth".format( | ||||||
|  |             idx, env_info["total"] | ||||||
|  |         ) | ||||||
|  |         save_checkpoint( | ||||||
|  |             { | ||||||
|  |                 "model": model.state_dict(), | ||||||
|  |                 "index": idx, | ||||||
|  |                 "timestamp": env_info["{:}-timestamp".format(idx)], | ||||||
|  |             }, | ||||||
|  |             save_path, | ||||||
|  |             logger, | ||||||
|  |         ) | ||||||
|  |         logger.log("") | ||||||
|  |  | ||||||
|  |         per_timestamp_time.update(time.time() - start_time) | ||||||
|  |         start_time = time.time() | ||||||
|  |  | ||||||
|  |     logger.log("-" * 200 + "\n") | ||||||
|  |     logger.close() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     parser = argparse.ArgumentParser("Use all the past data to train.") | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--save_dir", | ||||||
|  |         type=str, | ||||||
|  |         default="./outputs/lfna-synthetic/use-same-timestamp", | ||||||
|  |         help="The checkpoint directory.", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--init_lr", | ||||||
|  |         type=float, | ||||||
|  |         default=0.1, | ||||||
|  |         help="The initial learning rate for the optimizer (default is Adam)", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--batch_size", | ||||||
|  |         type=int, | ||||||
|  |         default=512, | ||||||
|  |         help="The batch size", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--epochs", | ||||||
|  |         type=int, | ||||||
|  |         default=1000, | ||||||
|  |         help="The total number of epochs.", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--srange", type=str, required=True, help="The range of models to be evaluated" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--workers", | ||||||
|  |         type=int, | ||||||
|  |         default=4, | ||||||
|  |         help="The number of data loading workers (default: 4)", | ||||||
|  |     ) | ||||||
|  |     # Random Seed | ||||||
|  |     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||||
|  |     args = parser.parse_args() | ||||||
|  |     if args.rand_seed is None or args.rand_seed < 0: | ||||||
|  |         args.rand_seed = random.randint(1, 100000) | ||||||
|  |     assert args.save_dir is not None, "The save dir argument can not be None" | ||||||
|  |     main(args) | ||||||
							
								
								
									
										196
									
								
								exps/LFNA/basic-same.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										196
									
								
								exps/LFNA/basic-same.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,196 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||||
|  | ##################################################### | ||||||
|  | # python exps/LFNA/basic-same.py --srange 1-999 | ||||||
|  | ##################################################### | ||||||
|  | import sys, time, copy, torch, random, argparse | ||||||
|  | from tqdm import tqdm | ||||||
|  | from copy import deepcopy | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  | lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||||
|  | if str(lib_dir) not in sys.path: | ||||||
|  |     sys.path.insert(0, str(lib_dir)) | ||||||
|  | from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint | ||||||
|  | from log_utils import time_string | ||||||
|  | from log_utils import AverageMeter, convert_secs2time | ||||||
|  |  | ||||||
|  | from utils import split_str2indexes | ||||||
|  |  | ||||||
|  | from procedures.advanced_main import basic_train_fn, basic_eval_fn | ||||||
|  | from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric | ||||||
|  | from datasets.synthetic_core import get_synthetic_env | ||||||
|  | from models.xcore import get_model | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def subsample(historical_x, historical_y, maxn=10000): | ||||||
|  |     total = historical_x.size(0) | ||||||
|  |     if total <= maxn: | ||||||
|  |         return historical_x, historical_y | ||||||
|  |     else: | ||||||
|  |         indexes = torch.randint(low=0, high=total, size=[maxn]) | ||||||
|  |         return historical_x[indexes], historical_y[indexes] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def main(args): | ||||||
|  |     prepare_seed(args.rand_seed) | ||||||
|  |     logger = prepare_logger(args) | ||||||
|  |  | ||||||
|  |     cache_path = (logger.path(None) / ".." / "env-info.pth").resolve() | ||||||
|  |     if cache_path.exists(): | ||||||
|  |         env_info = torch.load(cache_path) | ||||||
|  |     else: | ||||||
|  |         env_info = dict() | ||||||
|  |         dynamic_env = get_synthetic_env() | ||||||
|  |         env_info["total"] = len(dynamic_env) | ||||||
|  |         for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)): | ||||||
|  |             env_info["{:}-timestamp".format(idx)] = timestamp | ||||||
|  |             env_info["{:}-x".format(idx)] = _allx | ||||||
|  |             env_info["{:}-y".format(idx)] = _ally | ||||||
|  |         env_info["dynamic_env"] = dynamic_env | ||||||
|  |         torch.save(env_info, cache_path) | ||||||
|  |  | ||||||
|  |     # check indexes to be evaluated | ||||||
|  |     to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None) | ||||||
|  |     logger.log( | ||||||
|  |         "Evaluate {:}, which has {:} timestamps in total.".format( | ||||||
|  |             args.srange, len(to_evaluate_indexes) | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     per_timestamp_time, start_time = AverageMeter(), time.time() | ||||||
|  |     for i, idx in enumerate(to_evaluate_indexes): | ||||||
|  |  | ||||||
|  |         need_time = "Time Left: {:}".format( | ||||||
|  |             convert_secs2time( | ||||||
|  |                 per_timestamp_time.avg * (len(to_evaluate_indexes) - i), True | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         logger.log( | ||||||
|  |             "[{:}]".format(time_string()) | ||||||
|  |             + " [{:04d}/{:04d}][{:04d}]".format(i, len(to_evaluate_indexes), idx) | ||||||
|  |             + " " | ||||||
|  |             + need_time | ||||||
|  |         ) | ||||||
|  |         # train the same data | ||||||
|  |         assert idx != 0 | ||||||
|  |         historical_x, historical_y = [], [] | ||||||
|  |         for past_i in range(idx): | ||||||
|  |             historical_x.append(env_info["{:}-x".format(past_i)]) | ||||||
|  |             historical_y.append(env_info["{:}-y".format(past_i)]) | ||||||
|  |         historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y) | ||||||
|  |         historical_x, historical_y = subsample(historical_x, historical_y) | ||||||
|  |         # build model | ||||||
|  |         mean, std = historical_x.mean().item(), historical_x.std().item() | ||||||
|  |         model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std) | ||||||
|  |         model = get_model(dict(model_type="simple_mlp"), **model_kwargs) | ||||||
|  |         # build optimizer | ||||||
|  |         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) | ||||||
|  |         criterion = torch.nn.MSELoss() | ||||||
|  |         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||||
|  |             optimizer, | ||||||
|  |             milestones=[ | ||||||
|  |                 int(args.epochs * 0.25), | ||||||
|  |                 int(args.epochs * 0.5), | ||||||
|  |                 int(args.epochs * 0.75), | ||||||
|  |             ], | ||||||
|  |             gamma=0.3, | ||||||
|  |         ) | ||||||
|  |         train_metric = MSEMetric() | ||||||
|  |         best_loss, best_param = None, None | ||||||
|  |         for _iepoch in range(args.epochs): | ||||||
|  |             preds = model(historical_x) | ||||||
|  |             optimizer.zero_grad() | ||||||
|  |             loss = criterion(preds, historical_y) | ||||||
|  |             loss.backward() | ||||||
|  |             optimizer.step() | ||||||
|  |             lr_scheduler.step() | ||||||
|  |             # save best | ||||||
|  |             if best_loss is None or best_loss > loss.item(): | ||||||
|  |                 best_loss = loss.item() | ||||||
|  |                 best_param = copy.deepcopy(model.state_dict()) | ||||||
|  |         model.load_state_dict(best_param) | ||||||
|  |         with torch.no_grad(): | ||||||
|  |             train_metric(preds, historical_y) | ||||||
|  |         train_results = train_metric.get_info() | ||||||
|  |  | ||||||
|  |         metric = ComposeMetric(MSEMetric(), SaveMetric()) | ||||||
|  |         eval_dataset = torch.utils.data.TensorDataset( | ||||||
|  |             env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)] | ||||||
|  |         ) | ||||||
|  |         eval_loader = torch.utils.data.DataLoader( | ||||||
|  |             eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0 | ||||||
|  |         ) | ||||||
|  |         results = basic_eval_fn(eval_loader, model, metric, logger) | ||||||
|  |         log_str = ( | ||||||
|  |             "[{:}]".format(time_string()) | ||||||
|  |             + " [{:04d}/{:04d}]".format(idx, env_info["total"]) | ||||||
|  |             + " train-mse: {:.5f}, eval-mse: {:.5f}".format( | ||||||
|  |                 train_results["mse"], results["mse"] | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         logger.log(log_str) | ||||||
|  |  | ||||||
|  |         save_path = logger.path(None) / "{:04d}-{:04d}.pth".format( | ||||||
|  |             idx, env_info["total"] | ||||||
|  |         ) | ||||||
|  |         save_checkpoint( | ||||||
|  |             { | ||||||
|  |                 "model": model.state_dict(), | ||||||
|  |                 "index": idx, | ||||||
|  |                 "timestamp": env_info["{:}-timestamp".format(idx)], | ||||||
|  |             }, | ||||||
|  |             save_path, | ||||||
|  |             logger, | ||||||
|  |         ) | ||||||
|  |         logger.log("") | ||||||
|  |  | ||||||
|  |         per_timestamp_time.update(time.time() - start_time) | ||||||
|  |         start_time = time.time() | ||||||
|  |  | ||||||
|  |     logger.log("-" * 200 + "\n") | ||||||
|  |     logger.close() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     parser = argparse.ArgumentParser("Use data at the same timestamp.") | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--save_dir", | ||||||
|  |         type=str, | ||||||
|  |         default="./outputs/lfna-synthetic/use-all-past-data", | ||||||
|  |         help="The checkpoint directory.", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--init_lr", | ||||||
|  |         type=float, | ||||||
|  |         default=0.1, | ||||||
|  |         help="The initial learning rate for the optimizer (default is Adam)", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--batch_size", | ||||||
|  |         type=int, | ||||||
|  |         default=512, | ||||||
|  |         help="The batch size", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--epochs", | ||||||
|  |         type=int, | ||||||
|  |         default=1000, | ||||||
|  |         help="The total number of epochs.", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--srange", type=str, required=True, help="The range of models to be evaluated" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--workers", | ||||||
|  |         type=int, | ||||||
|  |         default=4, | ||||||
|  |         help="The number of data loading workers (default: 4)", | ||||||
|  |     ) | ||||||
|  |     # Random Seed | ||||||
|  |     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||||
|  |     args = parser.parse_args() | ||||||
|  |     if args.rand_seed is None or args.rand_seed < 0: | ||||||
|  |         args.rand_seed = random.randint(1, 100000) | ||||||
|  |     assert args.save_dir is not None, "The save dir argument can not be None" | ||||||
|  |     main(args) | ||||||
| @@ -1,165 +0,0 @@ | |||||||
| ##################################################### |  | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # |  | ||||||
| ##################################################### |  | ||||||
| # python exps/LFNA/basic.py |  | ||||||
| ##################################################### |  | ||||||
| import sys, time, torch, random, argparse |  | ||||||
| from copy import deepcopy |  | ||||||
| from pathlib import Path |  | ||||||
|  |  | ||||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() |  | ||||||
| if str(lib_dir) not in sys.path: |  | ||||||
|     sys.path.insert(0, str(lib_dir)) |  | ||||||
| from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint |  | ||||||
| from log_utils import time_string |  | ||||||
|  |  | ||||||
| from procedures.advanced_main import basic_train_fn, basic_eval_fn |  | ||||||
| from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric |  | ||||||
| from datasets.synthetic_core import get_synthetic_env |  | ||||||
| from models.xcore import get_model |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(args): |  | ||||||
|     torch.set_num_threads(args.workers) |  | ||||||
|     prepare_seed(args.rand_seed) |  | ||||||
|     logger = prepare_logger(args) |  | ||||||
|  |  | ||||||
|     dynamic_env = get_synthetic_env() |  | ||||||
|     historical_x, historical_y = None, None |  | ||||||
|     for idx, (timestamp, (allx, ally)) in enumerate(dynamic_env): |  | ||||||
|  |  | ||||||
|         if historical_x is not None: |  | ||||||
|             mean, std = historical_x.mean().item(), historical_x.std().item() |  | ||||||
|         else: |  | ||||||
|             mean, std = 0, 1 |  | ||||||
|         model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std) |  | ||||||
|         model = get_model(dict(model_type="simple_mlp"), **model_kwargs) |  | ||||||
|  |  | ||||||
|         # create the current data loader |  | ||||||
|         if historical_x is not None: |  | ||||||
|             train_dataset = torch.utils.data.TensorDataset(historical_x, historical_y) |  | ||||||
|             train_loader = torch.utils.data.DataLoader( |  | ||||||
|                 train_dataset, |  | ||||||
|                 batch_size=args.batch_size, |  | ||||||
|                 shuffle=True, |  | ||||||
|                 num_workers=args.workers, |  | ||||||
|             ) |  | ||||||
|             optimizer = torch.optim.Adam( |  | ||||||
|                 model.parameters(), lr=args.init_lr, amsgrad=True |  | ||||||
|             ) |  | ||||||
|             criterion = torch.nn.MSELoss() |  | ||||||
|             lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( |  | ||||||
|                 optimizer, |  | ||||||
|                 milestones=[ |  | ||||||
|                     int(args.epochs * 0.25), |  | ||||||
|                     int(args.epochs * 0.5), |  | ||||||
|                     int(args.epochs * 0.75), |  | ||||||
|                 ], |  | ||||||
|                 gamma=0.3, |  | ||||||
|             ) |  | ||||||
|             for _iepoch in range(args.epochs): |  | ||||||
|                 results = basic_train_fn( |  | ||||||
|                     train_loader, model, criterion, optimizer, MSEMetric(), logger |  | ||||||
|                 ) |  | ||||||
|                 lr_scheduler.step() |  | ||||||
|                 if _iepoch % args.log_per_epoch == 0: |  | ||||||
|                     log_str = ( |  | ||||||
|                         "[{:}]".format(time_string()) |  | ||||||
|                         + " [{:04d}/{:04d}][{:04d}/{:04d}]".format( |  | ||||||
|                             idx, len(dynamic_env), _iepoch, args.epochs |  | ||||||
|                         ) |  | ||||||
|                         + " mse: {:.5f}, lr: {:.4f}".format( |  | ||||||
|                             results["mse"], min(lr_scheduler.get_last_lr()) |  | ||||||
|                         ) |  | ||||||
|                     ) |  | ||||||
|                     logger.log(log_str) |  | ||||||
|             results = basic_eval_fn(train_loader, model, MSEMetric(), logger) |  | ||||||
|             logger.log( |  | ||||||
|                 "[{:}] [{:04d}/{:04d}] train-mse: {:.5f}".format( |  | ||||||
|                     time_string(), idx, len(dynamic_env), results["mse"] |  | ||||||
|                 ) |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|         metric = ComposeMetric(MSEMetric(), SaveMetric()) |  | ||||||
|         eval_dataset = torch.utils.data.TensorDataset(allx, ally) |  | ||||||
|         eval_loader = torch.utils.data.DataLoader( |  | ||||||
|             eval_dataset, |  | ||||||
|             batch_size=args.batch_size, |  | ||||||
|             shuffle=False, |  | ||||||
|             num_workers=args.workers, |  | ||||||
|         ) |  | ||||||
|         results = basic_eval_fn(eval_loader, model, metric, logger) |  | ||||||
|         log_str = ( |  | ||||||
|             "[{:}]".format(time_string()) |  | ||||||
|             + " [{:04d}/{:04d}]".format(idx, len(dynamic_env)) |  | ||||||
|             + " eval-mse: {:.5f}".format(results["mse"]) |  | ||||||
|         ) |  | ||||||
|         logger.log(log_str) |  | ||||||
|  |  | ||||||
|         save_path = logger.path(None) / "{:04d}-{:04d}.pth".format( |  | ||||||
|             idx, len(dynamic_env) |  | ||||||
|         ) |  | ||||||
|         save_checkpoint( |  | ||||||
|             {"model": model.state_dict(), "index": idx, "timestamp": timestamp}, |  | ||||||
|             save_path, |  | ||||||
|             logger, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         # Update historical data |  | ||||||
|         if historical_x is None: |  | ||||||
|             historical_x, historical_y = allx, ally |  | ||||||
|         else: |  | ||||||
|             historical_x, historical_y = torch.cat((historical_x, allx)), torch.cat( |  | ||||||
|                 (historical_y, ally) |  | ||||||
|             ) |  | ||||||
|         logger.log("") |  | ||||||
|  |  | ||||||
|     logger.log("-" * 200 + "\n") |  | ||||||
|     logger.close() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser("Use all the past data to train.") |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--save_dir", |  | ||||||
|         type=str, |  | ||||||
|         default="./outputs/lfna-synthetic/use-all-past-data", |  | ||||||
|         help="The checkpoint directory.", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--init_lr", |  | ||||||
|         type=float, |  | ||||||
|         default=0.1, |  | ||||||
|         help="The initial learning rate for the optimizer (default is Adam)", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--batch_size", |  | ||||||
|         type=int, |  | ||||||
|         default=256, |  | ||||||
|         help="The batch size", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--epochs", |  | ||||||
|         type=int, |  | ||||||
|         default=2000, |  | ||||||
|         help="The total number of epochs.", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--log_per_epoch", |  | ||||||
|         type=int, |  | ||||||
|         default=200, |  | ||||||
|         help="Log the training information per __ epochs.", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--workers", |  | ||||||
|         type=int, |  | ||||||
|         default=4, |  | ||||||
|         help="The number of data loading workers (default: 4)", |  | ||||||
|     ) |  | ||||||
|     # Random Seed |  | ||||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     if args.rand_seed is None or args.rand_seed < 0: |  | ||||||
|         args.rand_seed = random.randint(1, 100000) |  | ||||||
|     assert args.save_dir is not None, "The save dir argument can not be None" |  | ||||||
|     main(args) |  | ||||||
		Reference in New Issue
	
	Block a user