Fix bugs
This commit is contained in:
parent
5dd75696c9
commit
418be43566
@ -204,11 +204,13 @@ def main(args):
|
|||||||
train_env = get_synthetic_env(mode="train", version=args.env_version)
|
train_env = get_synthetic_env(mode="train", version=args.env_version)
|
||||||
valid_env = get_synthetic_env(mode="valid", version=args.env_version)
|
valid_env = get_synthetic_env(mode="valid", version=args.env_version)
|
||||||
trainval_env = get_synthetic_env(mode="trainval", version=args.env_version)
|
trainval_env = get_synthetic_env(mode="trainval", version=args.env_version)
|
||||||
|
test_env = get_synthetic_env(mode="test", version=args.env_version)
|
||||||
all_env = get_synthetic_env(mode=None, version=args.env_version)
|
all_env = get_synthetic_env(mode=None, version=args.env_version)
|
||||||
logger.log("The training enviornment: {:}".format(train_env))
|
logger.log("The training enviornment: {:}".format(train_env))
|
||||||
logger.log("The validation enviornment: {:}".format(valid_env))
|
logger.log("The validation enviornment: {:}".format(valid_env))
|
||||||
logger.log("The trainval enviornment: {:}".format(trainval_env))
|
logger.log("The trainval enviornment: {:}".format(trainval_env))
|
||||||
logger.log("The total enviornment: {:}".format(all_env))
|
logger.log("The total enviornment: {:}".format(all_env))
|
||||||
|
logger.log("The test enviornment: {:}".format(test_env))
|
||||||
model_kwargs = dict(
|
model_kwargs = dict(
|
||||||
config=dict(model_type="norm_mlp"),
|
config=dict(model_type="norm_mlp"),
|
||||||
input_dim=all_env.meta_info["input_dim"],
|
input_dim=all_env.meta_info["input_dim"],
|
||||||
@ -268,10 +270,10 @@ def main(args):
|
|||||||
logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter))
|
logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter))
|
||||||
"""
|
"""
|
||||||
_, loss_adapt_v1, metric_adapt_v1 = online_evaluate(
|
_, loss_adapt_v1, metric_adapt_v1 = online_evaluate(
|
||||||
valid_env, meta_model, base_model, criterion, metric, args, logger, False, False
|
test_env, meta_model, base_model, criterion, metric, args, logger, False, False
|
||||||
)
|
)
|
||||||
_, loss_adapt_v2, metric_adapt_v2 = online_evaluate(
|
_, loss_adapt_v2, metric_adapt_v2 = online_evaluate(
|
||||||
valid_env, meta_model, base_model, criterion, metric, args, logger, False, True
|
test_env, meta_model, base_model, criterion, metric, args, logger, False, True
|
||||||
)
|
)
|
||||||
logger.log(
|
logger.log(
|
||||||
"[Refine-Adapt] loss = {:.6f}, metric = {:.6f}".format(
|
"[Refine-Adapt] loss = {:.6f}, metric = {:.6f}".format(
|
||||||
|
Loading…
Reference in New Issue
Block a user