Update metrics
This commit is contained in:
		| @@ -69,10 +69,13 @@ def main(args): | ||||
|         weight_decay=args.weight_decay, | ||||
|     ) | ||||
|     objective = xmisc.nested_call_by_yaml(args.loss_config) | ||||
|     metric = xmisc.nested_call_by_yaml(args.metric_config) | ||||
|  | ||||
|     logger.log("The optimizer is:\n{:}".format(optimizer)) | ||||
|     logger.log("The objective is {:}".format(objective)) | ||||
|     logger.log("The iters_per_epoch={:}".format(iters_per_epoch)) | ||||
|     logger.log("The metric is {:}".format(metric)) | ||||
|     logger.log("The iters_per_epoch = {:}, estimated epochs = {:}".format( | ||||
|       iters_per_epoch, args.steps // iters_per_epoch)) | ||||
|  | ||||
|     model, objective = torch.nn.DataParallel(model).cuda(), objective.cuda() | ||||
|     scheduler = xmisc.LRMultiplier( | ||||
| @@ -99,6 +102,7 @@ def main(args): | ||||
|         loss.backward() | ||||
|         optimizer.step() | ||||
|         scheduler.step() | ||||
|  | ||||
|         if xiter % iters_per_epoch == 0: | ||||
|             logger.log("TRAIN [{:}] loss = {:.6f}".format(iter_str, loss.item())) | ||||
|  | ||||
| @@ -123,6 +127,7 @@ if __name__ == "__main__": | ||||
|     parser.add_argument("--model_config", type=str, help="The path to the model config") | ||||
|     parser.add_argument("--optim_config", type=str, help="The optimizer config file.") | ||||
|     parser.add_argument("--loss_config", type=str, help="The loss config file.") | ||||
|     parser.add_argument("--metric_config", type=str, help="The metric config file.") | ||||
|     parser.add_argument( | ||||
|         "--train_data_config", type=str, help="The training dataset config path." | ||||
|     ) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user