Update yaml configs
--- a/exps/basic/xmain.py
+++ b/exps/basic/xmain.py
@@ -3,7 +3,7 @@
 #####################################################
 # python exps/basic/xmain.py --save_dir outputs/x   #
 #####################################################
-import sys, time, torch, random, argparse
+import os, sys, time, torch, random, argparse
 from copy import deepcopy
 from pathlib import Path
 
@@ -12,24 +12,38 @@ print("LIB-DIR: {:}".format(lib_dir))
 if str(lib_dir) not in sys.path:
     sys.path.insert(0, str(lib_dir))
 
-from xautodl.xmisc import nested_call_by_yaml
+from xautodl import xmisc
 
 
 def main(args):
 
-    train_data = nested_call_by_yaml(args.train_data_config, args.data_path)
-    valid_data = nested_call_by_yaml(args.valid_data_config, args.data_path)
+    train_data = xmisc.nested_call_by_yaml(args.train_data_config, args.data_path)
+    valid_data = xmisc.nested_call_by_yaml(args.valid_data_config, args.data_path)
+    logger = xmisc.Logger(args.save_dir, prefix="seed-{:}-".format(args.rand_seed))
 
+    import pdb
+
+    pdb.set_trace()
+
-    prepare_seed(args.rand_seed)
-    logger = prepare_logger(args)
-
-    train_data, valid_data, xshape, class_num = get_datasets(
-        args.dataset, args.data_path, args.cutout_length
+    logger.log("Create the logger: {:}".format(logger))
+    logger.log("Arguments : -------------------------------")
+    for name, value in args._get_kwargs():
+        logger.log("{:16} : {:}".format(name, value))
+    logger.log("Python  Version  : {:}".format(sys.version.replace("\n", " ")))
+    logger.log("PyTorch Version  : {:}".format(torch.__version__))
+    logger.log("cuDNN   Version  : {:}".format(torch.backends.cudnn.version()))
+    logger.log("CUDA available   : {:}".format(torch.cuda.is_available()))
+    logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count()))
+    logger.log(
+        "CUDA_VISIBLE_DEVICES : {:}".format(
+            os.environ["CUDA_VISIBLE_DEVICES"]
+            if "CUDA_VISIBLE_DEVICES" in os.environ
+            else "None"
+        )
+    )
+    logger.log("The training data is:\n{:}".format(train_data))
+    logger.log("The validation data is:\n{:}".format(valid_data))
 
+    model = xmisc.nested_call_by_yaml(args.model_config)
+    logger.log("The model is:\n{:}".format(model))
+    logger.log("The model size is {:.4f} M".format(xmisc.count_parameters(model)))
+
     train_loader = torch.utils.data.DataLoader(
         train_data,
         batch_size=args.batch_size,
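For context on the pattern this commit moves to: a YAML-driven constructor such as `xmisc.nested_call_by_yaml` resolves a callable described in a config file and invokes it with the configured arguments plus any extras supplied at the call site. The diff does not show xautodl's actual schema or implementation, so the following is a minimal sketch under an assumed `module_path`/`class_or_func`/`args`/`kwargs` layout, not the library's real code.

```python
# Minimal sketch of a YAML-driven constructor in the spirit of
# xmisc.nested_call_by_yaml. The config keys (module_path,
# class_or_func, args, kwargs) are ASSUMPTIONS for illustration;
# the real xautodl schema may differ.
import importlib

import yaml


def call_by_yaml_sketch(path, *extra_args, **extra_kwargs):
    with open(path) as stream:
        config = yaml.safe_load(stream)
    module = importlib.import_module(config["module_path"])
    target = getattr(module, config["class_or_func"])
    # Config-supplied arguments first, call-site extras appended/merged.
    args = list(config.get("args", [])) + list(extra_args)
    kwargs = {**config.get("kwargs", {}), **extra_kwargs}
    return target(*args, **kwargs)
```

This is also why `main` can stay generic here: swapping datasets, models, or losses becomes a config change rather than a code change.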
@@ -44,100 +58,25 @@ def main(args):
         num_workers=args.workers,
         pin_memory=True,
     )
-    # get configures
-    model_config = load_config(args.model_config, {"class_num": class_num}, logger)
-    optim_config = load_config(args.optim_config, {"class_num": class_num}, logger)
-
-    if args.model_source == "normal":
-        base_model = obtain_model(model_config)
-    elif args.model_source == "nas":
-        base_model = obtain_nas_infer_model(model_config, args.extra_model_path)
-    elif args.model_source == "autodl-searched":
-        base_model = obtain_model(model_config, args.extra_model_path)
-    elif args.model_source in ("x", "xmodel"):
-        base_model = obtain_xmodel(model_config)
-    else:
-        raise ValueError("invalid model-source : {:}".format(args.model_source))
-    flop, param = get_model_infos(base_model, xshape)
-    logger.log("model ====>>>>:\n{:}".format(base_model))
-    logger.log("model information : {:}".format(base_model.get_message()))
-    logger.log("-" * 50)
-    logger.log(
-        "Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G".format(
-            param, flop, flop / 1e3
-        )
+    logger.log("The training loader: {:}".format(train_loader))
+    logger.log("The validation loader: {:}".format(valid_loader))
+    optimizer = xmisc.nested_call_by_yaml(
+        args.optim_config,
+        model.parameters(),
+        lr=args.lr,
+        weight_decay=args.weight_decay,
     )
-    logger.log("-" * 50)
-    logger.log("train_data : {:}".format(train_data))
-    logger.log("valid_data : {:}".format(valid_data))
-    optimizer, scheduler, criterion = get_optim_scheduler(
-        base_model.parameters(), optim_config
-    )
-    logger.log("optimizer  : {:}".format(optimizer))
-    logger.log("scheduler  : {:}".format(scheduler))
-    logger.log("criterion  : {:}".format(criterion))
+    loss = xmisc.nested_call_by_yaml(args.loss_config)
 
-    last_info, model_base_path, model_best_path = (
-        logger.path("info"),
-        logger.path("model"),
-        logger.path("best"),
-    )
-    network, criterion = torch.nn.DataParallel(base_model).cuda(), criterion.cuda()
+    logger.log("The optimizer is:\n{:}".format(optimizer))
+    logger.log("The loss is {:}".format(loss))
 
-    if last_info.exists():  # automatically resume from previous checkpoint
-        logger.log(
-            "=> loading checkpoint of the last-info '{:}' start".format(last_info)
-        )
-        last_infox = torch.load(last_info)
-        start_epoch = last_infox["epoch"] + 1
-        last_checkpoint_path = last_infox["last_checkpoint"]
-        if not last_checkpoint_path.exists():
-            logger.log(
-                "Does not find {:}, try another path".format(last_checkpoint_path)
-            )
-            last_checkpoint_path = (
-                last_info.parent
-                / last_checkpoint_path.parent.name
-                / last_checkpoint_path.name
-            )
-        checkpoint = torch.load(last_checkpoint_path)
-        base_model.load_state_dict(checkpoint["base-model"])
-        scheduler.load_state_dict(checkpoint["scheduler"])
-        optimizer.load_state_dict(checkpoint["optimizer"])
-        valid_accuracies = checkpoint["valid_accuracies"]
-        max_bytes = checkpoint["max_bytes"]
-        logger.log(
-            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(
-                last_info, start_epoch
-            )
-        )
-    elif args.resume is not None:
-        assert Path(args.resume).exists(), "Can not find the resume file : {:}".format(
-            args.resume
-        )
-        checkpoint = torch.load(args.resume)
-        start_epoch = checkpoint["epoch"] + 1
-        base_model.load_state_dict(checkpoint["base-model"])
-        scheduler.load_state_dict(checkpoint["scheduler"])
-        optimizer.load_state_dict(checkpoint["optimizer"])
-        valid_accuracies = checkpoint["valid_accuracies"]
-        max_bytes = checkpoint["max_bytes"]
-        logger.log(
-            "=> loading checkpoint from '{:}' start with {:}-th epoch.".format(
-                args.resume, start_epoch
-            )
-        )
-    elif args.init_model is not None:
-        assert Path(
-            args.init_model
-        ).exists(), "Can not find the initialization file : {:}".format(args.init_model)
-        checkpoint = torch.load(args.init_model)
-        base_model.load_state_dict(checkpoint["base-model"])
-        start_epoch, valid_accuracies, max_bytes = 0, {"best": -1}, {}
-        logger.log("=> initialize the model from {:}".format(args.init_model))
-    else:
-        logger.log("=> do not find the last-info file : {:}".format(last_info))
-        start_epoch, valid_accuracies, max_bytes = 0, {"best": -1}, {}
+    model, loss = torch.nn.DataParallel(model).cuda(), loss.cuda()
 
+    import pdb
+
+    pdb.set_trace()
+
     train_func, valid_func = get_procedures(args.procedure)
 
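The optimizer call above shows the new division of labor: the YAML names the optimizer class and its fixed hyperparameters, while the call site injects `model.parameters()`, `lr`, and `weight_decay`. Below is a hypothetical config and its expansion, using the same assumed schema as the sketch earlier; none of this is taken from the repository's actual config files.

```python
# Hypothetical optimizer config for the new call convention; the key
# names follow the earlier sketch, not a documented xautodl schema.
import importlib

import torch
import yaml

OPTIM_YAML = """
module_path: torch.optim
class_or_func: SGD
kwargs:
  momentum: 0.9
  nesterov: true
"""

config = yaml.safe_load(OPTIM_YAML)
optim_cls = getattr(
    importlib.import_module(config["module_path"]), config["class_or_func"]
)
model = torch.nn.Linear(8, 2)  # stand-in model for demonstration
# The caller supplies the parameters, lr, and weight_decay, as in the diff.
optimizer = optim_cls(
    model.parameters(), lr=0.1, weight_decay=5e-4, **config.get("kwargs", {})
)
print(optimizer)  # SGD with momentum/nesterov from YAML, lr/decay from the caller
```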
@@ -284,7 +223,7 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="Train a model with a loss function.",
+        description="Train a classification model with a loss function.",
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
     parser.add_argument(
@@ -293,27 +232,21 @@ if __name__ == "__main__":
     parser.add_argument("--resume", type=str, help="Resume path.")
     parser.add_argument("--init_model", type=str, help="The initialization model path.")
     parser.add_argument("--model_config", type=str, help="The path to the model config")
+    parser.add_argument("--optim_config", type=str, help="The optimizer config file.")
+    parser.add_argument("--loss_config", type=str, help="The loss config file.")
     parser.add_argument(
-        "--optim_config", type=str, help="The path to the optimizer config"
+        "--train_data_config", type=str, help="The training dataset config path."
     )
-    parser.add_argument(
-        "--train_data_config", type=str, help="The dataset config path."
-    )
-    parser.add_argument(
-        "--valid_data_config", type=str, help="The dataset config path."
-    )
     parser.add_argument(
-        "--data_path", type=str, help="The path to the dataset."
+        "--valid_data_config", type=str, help="The validation dataset config path."
     )
+    parser.add_argument("--data_path", type=str, help="The path to the dataset.")
     parser.add_argument("--algorithm", type=str, help="The algorithm.")
     # Optimization options
+    parser.add_argument("--lr", type=float, help="The learning rate")
+    parser.add_argument("--weight_decay", type=float, help="The weight decay")
     parser.add_argument("--batch_size", type=int, default=2, help="The batch size.")
-    parser.add_argument(
-        "--workers",
-        type=int,
-        default=8,
-        help="number of data loading workers (default: 8)",
-    )
+    parser.add_argument("--workers", type=int, default=4, help="The number of workers")
     # Random Seed
     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
 
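Taken together, the reworked flags make a run look roughly like the following, in the style of the script's own header comment; the YAML paths are hypothetical placeholders, since the diff does not show where the configs live. Note that this commit still contains `pdb.set_trace()` calls, so a run will stop in the debugger.

```
python exps/basic/xmain.py --save_dir outputs/x \
    --train_data_config configs/data/train.yaml \
    --valid_data_config configs/data/valid.yaml \
    --data_path ~/.torch/cifar.python \
    --model_config configs/model.yaml \
    --optim_config configs/optim.yaml \
    --loss_config configs/loss.yaml \
    --lr 0.1 --weight_decay 0.0005 --rand_seed 2021
```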