Correct the codes

parent 3a2af8e55a
commit 53b63d3924
@@ -9,6 +9,12 @@ from tqdm import tqdm
 from copy import deepcopy
 from pathlib import Path
 
+lib_dir = (Path(__file__).parent / ".." / "..").resolve()
+print("LIB-DIR: {:}".format(lib_dir))
+if str(lib_dir) not in sys.path:
+    sys.path.insert(0, str(lib_dir))
+
+
 from xautodl.procedures import (
     prepare_seed,
     prepare_logger,
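Note: the added preamble is the usual bootstrap that makes the in-tree xautodl package importable when the script runs from a source checkout rather than an installed wheel. A minimal standalone sketch of the same pattern (the two-level parent is specific to where this script sits in the repo):

import sys
from pathlib import Path

# Resolve the repository root two levels above this file and prepend it to
# sys.path so "import xautodl" picks up the in-tree package.
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))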
@@ -38,28 +44,30 @@ def subsample(historical_x, historical_y, maxn=10000):
 
 
 def main(args):
-    logger, env_info, model_kwargs = lfna_setup(args)
+    logger, model_kwargs = lfna_setup(args)
 
-    w_container_per_epoch = dict()
+    env = get_synthetic_env(mode=None, version=args.env_version)
+    logger.log("The total enviornment: {:}".format(env))
+    w_containers = dict()
 
     per_timestamp_time, start_time = AverageMeter(), time.time()
-    for idx in range(1, env_info["total"]):
+    for idx, (future_time, (future_x, future_y)) in enumerate(env):
 
         need_time = "Time Left: {:}".format(
-            convert_secs2time(per_timestamp_time.avg * (env_info["total"] - idx), True)
+            convert_secs2time(per_timestamp_time.avg * (len(env) - idx), True)
         )
         logger.log(
             "[{:}]".format(time_string())
-            + " [{:04d}/{:04d}]".format(idx, env_info["total"])
+            + " [{:04d}/{:04d}]".format(idx, len(env))
             + " "
             + need_time
         )
         # train the same data
-        historical_x = env_info["{:}-x".format(idx)]
-        historical_y = env_info["{:}-y".format(idx)]
+        historical_x = future_x.to(args.device)
+        historical_y = future_y.to(args.device)
         # build model
         model = get_model(**model_kwargs)
-        print(model)
+        model = model.to(args.device)
         # build optimizer
         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
         criterion = torch.nn.MSELoss()
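The rewritten loop assumes the object returned by get_synthetic_env is sized and iterable, yielding a (timestamp, (x, y)) pair per step; that protocol is what lets len(env) and enumerate(env) replace the env_info["total"] bookkeeping. A hedged stand-in that satisfies the same protocol (ToyEnv is illustrative, not the xautodl class):

import torch

class ToyEnv:
    """Illustrative environment: one (timestamp, (x, y)) pair per step."""

    def __init__(self, num=5):
        self._times = torch.linspace(0, 1, num)

    def __len__(self):
        return len(self._times)

    def __iter__(self):
        for t in self._times:
            x = torch.randn(16, 1)
            y = 2.0 * x + t  # target drifts with the timestamp
            yield t, (x, y)

for idx, (future_time, (future_x, future_y)) in enumerate(ToyEnv()):
    print(idx, float(future_time), tuple(future_x.shape), tuple(future_y.shape))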
@@ -93,7 +101,7 @@ def main(args):
 
         metric = ComposeMetric(MSEMetric(), SaveMetric())
         eval_dataset = torch.utils.data.TensorDataset(
-            env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)]
+            future_x.to(args.device), future_y.to(args.device)
         )
         eval_loader = torch.utils.data.DataLoader(
             eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
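For reference, the surrounding eval pipeline is plain torch.utils.data: wrap the tensors in a TensorDataset, then a DataLoader with shuffle=False and num_workers=0 so batches arrive in order from the main process. A self-contained example with dummy tensors:

import torch

future_x, future_y = torch.randn(100, 1), torch.randn(100, 1)
eval_dataset = torch.utils.data.TensorDataset(future_x, future_y)
eval_loader = torch.utils.data.DataLoader(
    eval_dataset, batch_size=32, shuffle=False, num_workers=0
)
for batch_x, batch_y in eval_loader:
    print(batch_x.shape, batch_y.shape)  # batches of 32, with 4 left over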
@@ -101,23 +109,21 @@ def main(args):
         results = basic_eval_fn(eval_loader, model, metric, logger)
         log_str = (
             "[{:}]".format(time_string())
-            + " [{:04d}/{:04d}]".format(idx, env_info["total"])
+            + " [{:04d}/{:04d}]".format(idx, len(env))
             + " train-mse: {:.5f}, eval-mse: {:.5f}".format(
                 train_results["mse"], results["mse"]
             )
         )
         logger.log(log_str)
 
-        save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
-            idx, env_info["total"]
-        )
-        w_container_per_epoch[idx] = model.get_w_container().no_grad_clone()
+        save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(idx, len(env))
+        w_containers[idx] = model.get_w_container().no_grad_clone()
         save_checkpoint(
             {
                 "model_state_dict": model.state_dict(),
                 "model": model,
                 "index": idx,
-                "timestamp": env_info["{:}-timestamp".format(idx)],
+                "timestamp": future_time.item(),
             },
             save_path,
             logger,
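The per-timestamp checkpoint now stores the raw float timestamp via future_time.item() rather than an env_info lookup. A sketch of the same payload with plain torch.save standing in for the repo's save_checkpoint wrapper (the function name, directory, and filename pattern below mirror the diff but are otherwise illustrative):

import torch

def save_step(model, idx, future_time, total, out_dir):
    payload = {
        "model_state_dict": model.state_dict(),
        "index": idx,
        # .item() converts the 0-d tensor to a plain Python float, so the
        # checkpoint stays loadable without the environment object.
        "timestamp": future_time.item(),
    }
    torch.save(payload, "{:}/{:04d}-{:04d}.pth".format(out_dir, idx, total))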
@@ -127,7 +133,7 @@ def main(args):
             start_time = time.time()
 
     save_checkpoint(
-        {"w_container_per_epoch": w_container_per_epoch},
+        {"w_containers": w_containers},
         logger.path(None) / "final-ckp.pth",
         logger,
     )
@@ -174,6 +180,12 @@ if __name__ == "__main__":
         default=300,
         help="The total number of epochs.",
     )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cpu",
+        help="",
+    )
     parser.add_argument(
         "--workers",
         type=int,
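The new --device flag (its help string is left empty in this commit) threads a single device string through model and tensor placement, matching the .to(args.device) calls above. A minimal sketch of the pattern:

import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, default="cpu", help="e.g., cpu or cuda:0")
args = parser.parse_args([])  # [] -> take the defaults, for demonstration

model = torch.nn.Linear(1, 1).to(args.device)
x = torch.randn(4, 1).to(args.device)
print(model(x).shape)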
@@ -225,9 +225,11 @@ def main(args):
     logger, model_kwargs = lfna_setup(args)
     train_env = get_synthetic_env(mode="train", version=args.env_version)
     valid_env = get_synthetic_env(mode="valid", version=args.env_version)
+    trainval_env = get_synthetic_env(mode="trainval", version=args.env_version)
     all_env = get_synthetic_env(mode=None, version=args.env_version)
     logger.log("The training enviornment: {:}".format(train_env))
     logger.log("The validation enviornment: {:}".format(valid_env))
+    logger.log("The trainval enviornment: {:}".format(trainval_env))
     logger.log("The total enviornment: {:}".format(all_env))
 
     base_model = get_model(**model_kwargs)
@@ -237,14 +239,14 @@ def main(args):
     shape_container = base_model.get_w_container().to_shape_container()
 
     # pre-train the hypernetwork
-    timestamps = train_env.get_timestamp(None)
+    timestamps = trainval_env.get_timestamp(None)
     meta_model = LFNA_Meta(
         shape_container,
         args.layer_dim,
         args.time_dim,
         timestamps,
         seq_length=args.seq_length,
-        interval=train_env.time_interval,
+        interval=trainval_env.time_interval,
     )
     meta_model = meta_model.to(args.device)
 
@@ -253,8 +255,7 @@ def main(args):
     logger.log("The base-model is\n{:}".format(base_model))
     logger.log("The meta-model is\n{:}".format(meta_model))
 
-    # batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge)
-    pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
+    pretrain_v2(base_model, meta_model, criterion, trainval_env, args, logger)
 
     # try to evaluate once
     # online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
@@ -22,12 +22,12 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, versio
         [mean_generator], [[std_generator]], (-2, 2)
     )
     time_generator = TimeStamp(
-        min_timestamp=0, max_timestamp=math.pi * 6, num=total_timestamp, mode=mode
+        min_timestamp=0, max_timestamp=math.pi * 8, num=total_timestamp, mode=mode
    )
     oracle_map = DynamicLinearFunc(
         params={
             0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}),
-            1: ComposedSinFunc(params={0: 1.5, 1: 0.4, 2: 2.2}),
+            1: ComposedSinFunc(params={0: 1.5, 1: 0.6, 2: 1.8}),
         }
     )
     dynamic_env = SyntheticDEnv(
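The oracle here is a linear map whose coefficients are themselves sine-like functions of time; the commit retunes one coefficient generator (0.4, 2.2 -> 0.6, 1.8) and lengthens the timeline from 6*pi to 8*pi. A hedged sketch of the idea; the a*sin(b*t + c) form and the slope/intercept roles of params 0 and 1 are assumptions, not the documented ComposedSinFunc/DynamicLinearFunc internals:

import math

def composed_sin(a, b, c, t):
    # Assumed form a * sin(b * t + c), mapping params {0: a, 1: b, 2: c}.
    return a * math.sin(b * t + c)

def oracle(x, t):
    coef0 = composed_sin(2.0, 1.0, 2.2, t)  # params {0: 2.0, 1: 1.0, 2: 2.2}
    coef1 = composed_sin(1.5, 0.6, 1.8, t)  # the generator retuned in this commit
    return coef1 * x + coef0  # which index is slope vs. intercept is a guess

for t in (0.0, math.pi, 8 * math.pi):
    print(t, oracle(1.0, t))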
@@ -28,7 +28,7 @@ class UnifiedSplit:
             self._indexes = all_indexes[num_of_train : num_of_train + num_of_valid]
         elif mode.lower() in ("test", "testing"):
             self._indexes = all_indexes[num_of_train + num_of_valid :]
-        elif mode.lower() in ("trainval", "trainvalidation"):
+        elif mode.lower() in ("trainval", "trainvalid", "trainvalidation"):
             self._indexes = all_indexes[: num_of_train + num_of_valid]
         else:
             raise ValueError("Unkonwn mode of {:}".format(mode))
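UnifiedSplit partitions one contiguous index range: train takes the first block, valid the next, test the tail, and trainval (now also accepted as "trainvalid") the union of the first two. A compact standalone sketch of that slicing; the "train"/"training" and "valid"/"validation" alias spellings are assumed by analogy with the branches shown above:

def split_indexes(total, num_of_train, num_of_valid, mode):
    all_indexes = list(range(total))
    mode = mode.lower()
    if mode in ("train", "training"):
        return all_indexes[:num_of_train]
    elif mode in ("valid", "validation"):
        return all_indexes[num_of_train : num_of_train + num_of_valid]
    elif mode in ("test", "testing"):
        return all_indexes[num_of_train + num_of_valid :]
    elif mode in ("trainval", "trainvalid", "trainvalidation"):
        return all_indexes[: num_of_train + num_of_valid]
    raise ValueError("Unknown mode of {:}".format(mode))

print(len(split_indexes(100, 60, 20, "trainvalid")))  # -> 80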