From d4b846a9717279c08f1264398972c00aa949a69f Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Thu, 27 May 2021 20:54:13 +0800 Subject: [PATCH] Update ablation for GeMOSA --- exps/GeMOSA/main.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/exps/GeMOSA/main.py b/exps/GeMOSA/main.py index 6824179..3722cf3 100644 --- a/exps/GeMOSA/main.py +++ b/exps/GeMOSA/main.py @@ -9,6 +9,9 @@ # python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda # python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda # <----> ablation commands +# python exps/GeMOSA/main.py --env_version v1 --lr 0.002 --hidden_dim 16 --meta_batch 256 --ablation old --device cuda +# python exps/GeMOSA/main.py --env_version v2 --lr 0.002 --hidden_dim 16 --meta_batch 256 --ablation old --device cuda +# python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --ablation old --device cuda # python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --ablation old --device cuda ########################################################## import sys, time, copy, torch, random, argparse @@ -267,11 +270,11 @@ def main(args): ) logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter)) """ - _, loss_adapt_v1, metric_adapt_v1 = online_evaluate( - valid_env, meta_model, base_model, criterion, metric, args, logger, False, False + w_containers_care_adapt, loss_adapt_v1, metric_adapt_v1 = online_evaluate( + valid_env, meta_model, base_model, criterion, metric, args, logger, True, False ) - _, loss_adapt_v2, metric_adapt_v2 = online_evaluate( - valid_env, meta_model, base_model, criterion, metric, args, logger, False, True + w_containers_easy_adapt, loss_adapt_v2, metric_adapt_v2 = online_evaluate( + valid_env, meta_model, base_model, criterion, metric, args, logger, True, True ) logger.log( "[Refine-Adapt] loss = {:.6f}, metric = {:.6f}".format( @@ -286,6 +289,8 @@ def main(args): save_checkpoint( { + "w_containers_care_adapt": w_containers_care_adapt, + "w_containers_easy_adapt": w_containers_easy_adapt, "test_loss_adapt_v1": loss_adapt_v1, "test_loss_adapt_v2": loss_adapt_v2, "test_metric_adapt_v1": metric_adapt_v1,