Update ablation for GeMOSA
This commit is contained in:
parent
5dd75696c9
commit
d4b846a971
@ -9,6 +9,9 @@
|
||||
# python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda
|
||||
# python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda
|
||||
# <----> ablation commands
|
||||
# python exps/GeMOSA/main.py --env_version v1 --lr 0.002 --hidden_dim 16 --meta_batch 256 --ablation old --device cuda
|
||||
# python exps/GeMOSA/main.py --env_version v2 --lr 0.002 --hidden_dim 16 --meta_batch 256 --ablation old --device cuda
|
||||
# python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --ablation old --device cuda
|
||||
# python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --ablation old --device cuda
|
||||
##########################################################
|
||||
import sys, time, copy, torch, random, argparse
|
||||
@ -267,11 +270,11 @@ def main(args):
|
||||
)
|
||||
logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter))
|
||||
"""
|
||||
_, loss_adapt_v1, metric_adapt_v1 = online_evaluate(
|
||||
valid_env, meta_model, base_model, criterion, metric, args, logger, False, False
|
||||
w_containers_care_adapt, loss_adapt_v1, metric_adapt_v1 = online_evaluate(
|
||||
valid_env, meta_model, base_model, criterion, metric, args, logger, True, False
|
||||
)
|
||||
_, loss_adapt_v2, metric_adapt_v2 = online_evaluate(
|
||||
valid_env, meta_model, base_model, criterion, metric, args, logger, False, True
|
||||
w_containers_easy_adapt, loss_adapt_v2, metric_adapt_v2 = online_evaluate(
|
||||
valid_env, meta_model, base_model, criterion, metric, args, logger, True, True
|
||||
)
|
||||
logger.log(
|
||||
"[Refine-Adapt] loss = {:.6f}, metric = {:.6f}".format(
|
||||
@ -286,6 +289,8 @@ def main(args):
|
||||
|
||||
save_checkpoint(
|
||||
{
|
||||
"w_containers_care_adapt": w_containers_care_adapt,
|
||||
"w_containers_easy_adapt": w_containers_easy_adapt,
|
||||
"test_loss_adapt_v1": loss_adapt_v1,
|
||||
"test_loss_adapt_v2": loss_adapt_v2,
|
||||
"test_metric_adapt_v1": metric_adapt_v1,
|
||||
|
Loading…
Reference in New Issue
Block a user