diff --git a/exps/GeMOSA/main.py b/exps/GeMOSA/main.py index 199cf8f..ac6ec23 100644 --- a/exps/GeMOSA/main.py +++ b/exps/GeMOSA/main.py @@ -9,6 +9,9 @@ # python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda # python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda # <----> ablation commands +# python exps/GeMOSA/main.py --env_version v1 --lr 0.002 --hidden_dim 16 --meta_batch 256 --ablation old --device cuda +# python exps/GeMOSA/main.py --env_version v2 --lr 0.002 --hidden_dim 16 --meta_batch 256 --ablation old --device cuda +# python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --ablation old --device cuda # python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --ablation old --device cuda ########################################################## import sys, time, copy, torch, random, argparse @@ -269,11 +272,19 @@ def main(args): ) logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter)) """ +<<<<<<< HEAD _, loss_adapt_v1, metric_adapt_v1 = online_evaluate( test_env, meta_model, base_model, criterion, metric, args, logger, False, False ) _, loss_adapt_v2, metric_adapt_v2 = online_evaluate( test_env, meta_model, base_model, criterion, metric, args, logger, False, True +======= + w_containers_care_adapt, loss_adapt_v1, metric_adapt_v1 = online_evaluate( + valid_env, meta_model, base_model, criterion, metric, args, logger, True, False + ) + w_containers_easy_adapt, loss_adapt_v2, metric_adapt_v2 = online_evaluate( + valid_env, meta_model, base_model, criterion, metric, args, logger, True, True +>>>>>>> d4b846a9717279c08f1264398972c00aa949a69f ) logger.log( "[Refine-Adapt] loss = {:.6f}, metric = {:.6f}".format( @@ -288,6 +299,8 @@ def main(args): save_checkpoint( { + "w_containers_care_adapt": w_containers_care_adapt, + "w_containers_easy_adapt": w_containers_easy_adapt, "test_loss_adapt_v1": loss_adapt_v1, "test_loss_adapt_v2": loss_adapt_v2, "test_metric_adapt_v1": metric_adapt_v1,