Update xmisc with yaml
This commit is contained in:
		| @@ -1,35 +1,28 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 # | ||||
| ##################################################### | ||||
| # python exps/basic/xmain.py --save_dir outputs/x   # | ||||
| ##################################################### | ||||
| import sys, time, torch, random, argparse | ||||
| from PIL import ImageFile | ||||
|  | ||||
| ImageFile.LOAD_TRUNCATED_IMAGES = True | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| from xautodl.datasets import get_datasets | ||||
| from xautodl.config_utils import load_config, obtain_basic_args as obtain_args | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
| ) | ||||
| from xautodl.procedures import get_optim_scheduler, get_procedures | ||||
| from xautodl.models import obtain_model | ||||
| from xautodl.xmodels import obtain_model as obtain_xmodel | ||||
| from xautodl.nas_infer_model import obtain_nas_infer_model | ||||
| from xautodl.utils import get_model_infos | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
| lib_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||
| print("LIB-DIR: {:}".format(lib_dir)) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from xautodl.xmisc import nested_call_by_yaml | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|     assert torch.cuda.is_available(), "CUDA is not available." | ||||
|     torch.backends.cudnn.enabled = True | ||||
|     torch.backends.cudnn.benchmark = True | ||||
|     # torch.backends.cudnn.deterministic = True | ||||
|     # torch.set_num_threads(args.workers) | ||||
|  | ||||
|     train_data = nested_call_by_yaml(args.train_data_config, args.data_path) | ||||
|     valid_data = nested_call_by_yaml(args.valid_data_config, args.data_path) | ||||
|  | ||||
|     import pdb | ||||
|  | ||||
|     pdb.set_trace() | ||||
|  | ||||
|     prepare_seed(args.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
| @@ -290,5 +283,44 @@ def main(args): | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     args = obtain_args() | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a model with a loss function.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", type=str, help="Folder to save checkpoints and log." | ||||
|     ) | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||
|     parser.add_argument("--model_config", type=str, help="The path to the model config") | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer config" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--train_data_config", type=str, help="The dataset config path." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--valid_data_config", type=str, help="The dataset config path." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--data_path", type=str, help="The path to the dataset." | ||||
|     ) | ||||
|     parser.add_argument("--algorithm", type=str, help="The algorithm.") | ||||
|     # Optimization options | ||||
|     parser.add_argument("--batch_size", type=int, default=2, help="The batch size.") | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=8, | ||||
|         help="number of data loading workers (default: 8)", | ||||
|     ) | ||||
|     # Random Seed | ||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||
|  | ||||
|     args = parser.parse_args() | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     if args.save_dir is None: | ||||
|         raise ValueError("The save-path argument can not be None") | ||||
|  | ||||
|     main(args) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user