Correct the codes

This commit is contained in:
D-X-Y 2021-05-24 05:38:02 +00:00
parent 3a2af8e55a
commit 53b63d3924
4 changed files with 36 additions and 23 deletions

View File

@ -9,6 +9,12 @@ from tqdm import tqdm
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
print("LIB-DIR: {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from xautodl.procedures import ( from xautodl.procedures import (
prepare_seed, prepare_seed,
prepare_logger, prepare_logger,
@ -38,28 +44,30 @@ def subsample(historical_x, historical_y, maxn=10000):
def main(args): def main(args):
logger, env_info, model_kwargs = lfna_setup(args) logger, model_kwargs = lfna_setup(args)
w_container_per_epoch = dict() env = get_synthetic_env(mode=None, version=args.env_version)
logger.log("The total enviornment: {:}".format(env))
w_containers = dict()
per_timestamp_time, start_time = AverageMeter(), time.time() per_timestamp_time, start_time = AverageMeter(), time.time()
for idx in range(1, env_info["total"]): for idx, (future_time, (future_x, future_y)) in enumerate(env):
need_time = "Time Left: {:}".format( need_time = "Time Left: {:}".format(
convert_secs2time(per_timestamp_time.avg * (env_info["total"] - idx), True) convert_secs2time(per_timestamp_time.avg * (len(env) - idx), True)
) )
logger.log( logger.log(
"[{:}]".format(time_string()) "[{:}]".format(time_string())
+ " [{:04d}/{:04d}]".format(idx, env_info["total"]) + " [{:04d}/{:04d}]".format(idx, len(env))
+ " " + " "
+ need_time + need_time
) )
# train the same data # train the same data
historical_x = env_info["{:}-x".format(idx)] historical_x = future_x.to(args.device)
historical_y = env_info["{:}-y".format(idx)] historical_y = future_y.to(args.device)
# build model # build model
model = get_model(**model_kwargs) model = get_model(**model_kwargs)
print(model) model = model.to(args.device)
# build optimizer # build optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
criterion = torch.nn.MSELoss() criterion = torch.nn.MSELoss()
@ -93,7 +101,7 @@ def main(args):
metric = ComposeMetric(MSEMetric(), SaveMetric()) metric = ComposeMetric(MSEMetric(), SaveMetric())
eval_dataset = torch.utils.data.TensorDataset( eval_dataset = torch.utils.data.TensorDataset(
env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)] future_x.to(args.device), future_y.to(args.device)
) )
eval_loader = torch.utils.data.DataLoader( eval_loader = torch.utils.data.DataLoader(
eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0 eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
@ -101,23 +109,21 @@ def main(args):
results = basic_eval_fn(eval_loader, model, metric, logger) results = basic_eval_fn(eval_loader, model, metric, logger)
log_str = ( log_str = (
"[{:}]".format(time_string()) "[{:}]".format(time_string())
+ " [{:04d}/{:04d}]".format(idx, env_info["total"]) + " [{:04d}/{:04d}]".format(idx, len(env))
+ " train-mse: {:.5f}, eval-mse: {:.5f}".format( + " train-mse: {:.5f}, eval-mse: {:.5f}".format(
train_results["mse"], results["mse"] train_results["mse"], results["mse"]
) )
) )
logger.log(log_str) logger.log(log_str)
save_path = logger.path(None) / "{:04d}-{:04d}.pth".format( save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(idx, len(env))
idx, env_info["total"] w_containers[idx] = model.get_w_container().no_grad_clone()
)
w_container_per_epoch[idx] = model.get_w_container().no_grad_clone()
save_checkpoint( save_checkpoint(
{ {
"model_state_dict": model.state_dict(), "model_state_dict": model.state_dict(),
"model": model, "model": model,
"index": idx, "index": idx,
"timestamp": env_info["{:}-timestamp".format(idx)], "timestamp": future_time.item(),
}, },
save_path, save_path,
logger, logger,
@ -127,7 +133,7 @@ def main(args):
start_time = time.time() start_time = time.time()
save_checkpoint( save_checkpoint(
{"w_container_per_epoch": w_container_per_epoch}, {"w_containers": w_containers},
logger.path(None) / "final-ckp.pth", logger.path(None) / "final-ckp.pth",
logger, logger,
) )
@ -174,6 +180,12 @@ if __name__ == "__main__":
default=300, default=300,
help="The total number of epochs.", help="The total number of epochs.",
) )
parser.add_argument(
"--device",
type=str,
default="cpu",
help="",
)
parser.add_argument( parser.add_argument(
"--workers", "--workers",
type=int, type=int,

View File

@ -225,9 +225,11 @@ def main(args):
logger, model_kwargs = lfna_setup(args) logger, model_kwargs = lfna_setup(args)
train_env = get_synthetic_env(mode="train", version=args.env_version) train_env = get_synthetic_env(mode="train", version=args.env_version)
valid_env = get_synthetic_env(mode="valid", version=args.env_version) valid_env = get_synthetic_env(mode="valid", version=args.env_version)
trainval_env = get_synthetic_env(mode="trainval", version=args.env_version)
all_env = get_synthetic_env(mode=None, version=args.env_version) all_env = get_synthetic_env(mode=None, version=args.env_version)
logger.log("The training enviornment: {:}".format(train_env)) logger.log("The training enviornment: {:}".format(train_env))
logger.log("The validation enviornment: {:}".format(valid_env)) logger.log("The validation enviornment: {:}".format(valid_env))
logger.log("The trainval enviornment: {:}".format(trainval_env))
logger.log("The total enviornment: {:}".format(all_env)) logger.log("The total enviornment: {:}".format(all_env))
base_model = get_model(**model_kwargs) base_model = get_model(**model_kwargs)
@ -237,14 +239,14 @@ def main(args):
shape_container = base_model.get_w_container().to_shape_container() shape_container = base_model.get_w_container().to_shape_container()
# pre-train the hypernetwork # pre-train the hypernetwork
timestamps = train_env.get_timestamp(None) timestamps = trainval_env.get_timestamp(None)
meta_model = LFNA_Meta( meta_model = LFNA_Meta(
shape_container, shape_container,
args.layer_dim, args.layer_dim,
args.time_dim, args.time_dim,
timestamps, timestamps,
seq_length=args.seq_length, seq_length=args.seq_length,
interval=train_env.time_interval, interval=trainval_env.time_interval,
) )
meta_model = meta_model.to(args.device) meta_model = meta_model.to(args.device)
@ -253,8 +255,7 @@ def main(args):
logger.log("The base-model is\n{:}".format(base_model)) logger.log("The base-model is\n{:}".format(base_model))
logger.log("The meta-model is\n{:}".format(meta_model)) logger.log("The meta-model is\n{:}".format(meta_model))
# batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge) pretrain_v2(base_model, meta_model, criterion, trainval_env, args, logger)
pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
# try to evaluate once # try to evaluate once
# online_evaluate(train_env, meta_model, base_model, criterion, args, logger) # online_evaluate(train_env, meta_model, base_model, criterion, args, logger)

View File

@ -22,12 +22,12 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, versio
[mean_generator], [[std_generator]], (-2, 2) [mean_generator], [[std_generator]], (-2, 2)
) )
time_generator = TimeStamp( time_generator = TimeStamp(
min_timestamp=0, max_timestamp=math.pi * 6, num=total_timestamp, mode=mode min_timestamp=0, max_timestamp=math.pi * 8, num=total_timestamp, mode=mode
) )
oracle_map = DynamicLinearFunc( oracle_map = DynamicLinearFunc(
params={ params={
0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}), 0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}),
1: ComposedSinFunc(params={0: 1.5, 1: 0.4, 2: 2.2}), 1: ComposedSinFunc(params={0: 1.5, 1: 0.6, 2: 1.8}),
} }
) )
dynamic_env = SyntheticDEnv( dynamic_env = SyntheticDEnv(

View File

@ -28,7 +28,7 @@ class UnifiedSplit:
self._indexes = all_indexes[num_of_train : num_of_train + num_of_valid] self._indexes = all_indexes[num_of_train : num_of_train + num_of_valid]
elif mode.lower() in ("test", "testing"): elif mode.lower() in ("test", "testing"):
self._indexes = all_indexes[num_of_train + num_of_valid :] self._indexes = all_indexes[num_of_train + num_of_valid :]
elif mode.lower() in ("trainval", "trainvalidation"): elif mode.lower() in ("trainval", "trainvalid", "trainvalidation"):
self._indexes = all_indexes[: num_of_train + num_of_valid] self._indexes = all_indexes[: num_of_train + num_of_valid]
else: else:
raise ValueError("Unkonwn mode of {:}".format(mode)) raise ValueError("Unkonwn mode of {:}".format(mode))