Update ablation for GeMOSA
This commit is contained in:
		| @@ -9,6 +9,9 @@ | ||||
| # python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda | ||||
| # python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda | ||||
| # <----> ablation commands | ||||
| # python exps/GeMOSA/main.py --env_version v1 --lr 0.002 --hidden_dim 16 --meta_batch 256 --ablation old --device cuda | ||||
| # python exps/GeMOSA/main.py --env_version v2 --lr 0.002 --hidden_dim 16 --meta_batch 256 --ablation old --device cuda | ||||
| # python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --ablation old --device cuda | ||||
| # python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --ablation old --device cuda | ||||
| ########################################################## | ||||
| import sys, time, copy, torch, random, argparse | ||||
| @@ -267,11 +270,11 @@ def main(args): | ||||
|     ) | ||||
|     logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter)) | ||||
|     """ | ||||
|     _, loss_adapt_v1, metric_adapt_v1 = online_evaluate( | ||||
|         valid_env, meta_model, base_model, criterion, metric, args, logger, False, False | ||||
|     w_containers_care_adapt, loss_adapt_v1, metric_adapt_v1 = online_evaluate( | ||||
|         valid_env, meta_model, base_model, criterion, metric, args, logger, True, False | ||||
|     ) | ||||
|     _, loss_adapt_v2, metric_adapt_v2 = online_evaluate( | ||||
|         valid_env, meta_model, base_model, criterion, metric, args, logger, False, True | ||||
|     w_containers_easy_adapt, loss_adapt_v2, metric_adapt_v2 = online_evaluate( | ||||
|         valid_env, meta_model, base_model, criterion, metric, args, logger, True, True | ||||
|     ) | ||||
|     logger.log( | ||||
|         "[Refine-Adapt] loss = {:.6f}, metric = {:.6f}".format( | ||||
| @@ -286,6 +289,8 @@ def main(args): | ||||
|  | ||||
|     save_checkpoint( | ||||
|         { | ||||
|             "w_containers_care_adapt": w_containers_care_adapt, | ||||
|             "w_containers_easy_adapt": w_containers_easy_adapt, | ||||
|             "test_loss_adapt_v1": loss_adapt_v1, | ||||
|             "test_loss_adapt_v2": loss_adapt_v2, | ||||
|             "test_metric_adapt_v1": metric_adapt_v1, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user