Fix bugs
This commit is contained in:
		| @@ -204,11 +204,13 @@ def main(args): | |||||||
|     train_env = get_synthetic_env(mode="train", version=args.env_version) |     train_env = get_synthetic_env(mode="train", version=args.env_version) | ||||||
|     valid_env = get_synthetic_env(mode="valid", version=args.env_version) |     valid_env = get_synthetic_env(mode="valid", version=args.env_version) | ||||||
|     trainval_env = get_synthetic_env(mode="trainval", version=args.env_version) |     trainval_env = get_synthetic_env(mode="trainval", version=args.env_version) | ||||||
|  |     test_env = get_synthetic_env(mode="test", version=args.env_version) | ||||||
|     all_env = get_synthetic_env(mode=None, version=args.env_version) |     all_env = get_synthetic_env(mode=None, version=args.env_version) | ||||||
|     logger.log("The training enviornment: {:}".format(train_env)) |     logger.log("The training enviornment: {:}".format(train_env)) | ||||||
|     logger.log("The validation enviornment: {:}".format(valid_env)) |     logger.log("The validation enviornment: {:}".format(valid_env)) | ||||||
|     logger.log("The trainval enviornment: {:}".format(trainval_env)) |     logger.log("The trainval enviornment: {:}".format(trainval_env)) | ||||||
|     logger.log("The total enviornment: {:}".format(all_env)) |     logger.log("The total enviornment: {:}".format(all_env)) | ||||||
|  |     logger.log("The test enviornment: {:}".format(test_env)) | ||||||
|     model_kwargs = dict( |     model_kwargs = dict( | ||||||
|         config=dict(model_type="norm_mlp"), |         config=dict(model_type="norm_mlp"), | ||||||
|         input_dim=all_env.meta_info["input_dim"], |         input_dim=all_env.meta_info["input_dim"], | ||||||
| @@ -268,10 +270,10 @@ def main(args): | |||||||
|     logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter)) |     logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter)) | ||||||
|     """ |     """ | ||||||
|     _, loss_adapt_v1, metric_adapt_v1 = online_evaluate( |     _, loss_adapt_v1, metric_adapt_v1 = online_evaluate( | ||||||
|         valid_env, meta_model, base_model, criterion, metric, args, logger, False, False |         test_env, meta_model, base_model, criterion, metric, args, logger, False, False | ||||||
|     ) |     ) | ||||||
|     _, loss_adapt_v2, metric_adapt_v2 = online_evaluate( |     _, loss_adapt_v2, metric_adapt_v2 = online_evaluate( | ||||||
|         valid_env, meta_model, base_model, criterion, metric, args, logger, False, True |         test_env, meta_model, base_model, criterion, metric, args, logger, False, True | ||||||
|     ) |     ) | ||||||
|     logger.log( |     logger.log( | ||||||
|         "[Refine-Adapt] loss = {:.6f}, metric = {:.6f}".format( |         "[Refine-Adapt] loss = {:.6f}, metric = {:.6f}".format( | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user