Updates
This commit is contained in:
		| @@ -9,6 +9,9 @@ | |||||||
| # python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda | # python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda | ||||||
| # python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda | # python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda | ||||||
| # <----> ablation commands | # <----> ablation commands | ||||||
|  | # python exps/GeMOSA/main.py --env_version v1 --lr 0.002 --hidden_dim 16 --meta_batch 256 --ablation old --device cuda | ||||||
|  | # python exps/GeMOSA/main.py --env_version v2 --lr 0.002 --hidden_dim 16 --meta_batch 256 --ablation old --device cuda | ||||||
|  | # python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --ablation old --device cuda | ||||||
| # python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --ablation old --device cuda | # python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --ablation old --device cuda | ||||||
| ########################################################## | ########################################################## | ||||||
| import sys, time, copy, torch, random, argparse | import sys, time, copy, torch, random, argparse | ||||||
| @@ -269,11 +272,19 @@ def main(args): | |||||||
|     ) |     ) | ||||||
|     logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter)) |     logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter)) | ||||||
|     """ |     """ | ||||||
|  | <<<<<<< HEAD | ||||||
|     _, loss_adapt_v1, metric_adapt_v1 = online_evaluate( |     _, loss_adapt_v1, metric_adapt_v1 = online_evaluate( | ||||||
|         test_env, meta_model, base_model, criterion, metric, args, logger, False, False |         test_env, meta_model, base_model, criterion, metric, args, logger, False, False | ||||||
|     ) |     ) | ||||||
|     _, loss_adapt_v2, metric_adapt_v2 = online_evaluate( |     _, loss_adapt_v2, metric_adapt_v2 = online_evaluate( | ||||||
|         test_env, meta_model, base_model, criterion, metric, args, logger, False, True |         test_env, meta_model, base_model, criterion, metric, args, logger, False, True | ||||||
|  | ======= | ||||||
|  |     w_containers_care_adapt, loss_adapt_v1, metric_adapt_v1 = online_evaluate( | ||||||
|  |         valid_env, meta_model, base_model, criterion, metric, args, logger, True, False | ||||||
|  |     ) | ||||||
|  |     w_containers_easy_adapt, loss_adapt_v2, metric_adapt_v2 = online_evaluate( | ||||||
|  |         valid_env, meta_model, base_model, criterion, metric, args, logger, True, True | ||||||
|  | >>>>>>> d4b846a9717279c08f1264398972c00aa949a69f | ||||||
|     ) |     ) | ||||||
|     logger.log( |     logger.log( | ||||||
|         "[Refine-Adapt] loss = {:.6f}, metric = {:.6f}".format( |         "[Refine-Adapt] loss = {:.6f}, metric = {:.6f}".format( | ||||||
| @@ -288,6 +299,8 @@ def main(args): | |||||||
|  |  | ||||||
|     save_checkpoint( |     save_checkpoint( | ||||||
|         { |         { | ||||||
|  |             "w_containers_care_adapt": w_containers_care_adapt, | ||||||
|  |             "w_containers_easy_adapt": w_containers_easy_adapt, | ||||||
|             "test_loss_adapt_v1": loss_adapt_v1, |             "test_loss_adapt_v1": loss_adapt_v1, | ||||||
|             "test_loss_adapt_v2": loss_adapt_v2, |             "test_loss_adapt_v2": loss_adapt_v2, | ||||||
|             "test_metric_adapt_v1": metric_adapt_v1, |             "test_metric_adapt_v1": metric_adapt_v1, | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user