Update LFNA

D-X-Y 2021-05-22 23:49:09 +08:00
parent 8109ed166a
commit df9917371e
2 changed files with 41 additions and 34 deletions


@@ -9,4 +9,4 @@
 - [2020.10.15] [446262a](https://github.com/D-X-Y/AutoDL-Projects/tree/446262a) Update NATS-BENCH to version 1.0
 - [2020.12.20] [dae387a](https://github.com/D-X-Y/AutoDL-Projects/tree/dae387a) Update NATS-BENCH to version 1.1
 - [2021.05.18] [98fadf8](https://github.com/D-X-Y/AutoDL-Projects/tree/98fadf8) Before moving to `xautodl`
-- [2021.05.21] [5b09f05](https://github.com/D-X-Y/AutoDL-Projects/tree/5b09f05) `xautodl` is close to ready
+- [2021.05.21] [8109ed1](https://github.com/D-X-Y/AutoDL-Projects/tree/8109ed1) `xautodl` is close to ready

exps/LFNA/lfna.py

@@ -1,5 +1,5 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
 # Learning to Generate Model One Step Ahead         #
 #####################################################
-# python exps/LFNA/lfna.py --env_version v1 --workers 0
+# python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.001
@@ -109,6 +109,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
     final_best_name = "final-pretrain-{:}.pth".format(args.rand_seed)
     if meta_model.has_best(final_best_name):
         meta_model.load_best(final_best_name)
+        logger.log("Directly load the best model from {:}".format(final_best_name))
         return
     meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
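The hunk above makes the resume shortcut visible in the log: if a finished pretraining checkpoint already exists, it is loaded and the function returns early. A minimal sketch of the same pattern, using plain `torch.save`/`torch.load` in place of the repo's `has_best`/`load_best` helpers (the path below is hypothetical):

```python
import os
import torch
import torch.nn as nn

def maybe_resume(model: nn.Module, ckpt_path: str) -> bool:
    """Load a finished checkpoint if present; mirrors the has_best/load_best shortcut."""
    if os.path.exists(ckpt_path):
        model.load_state_dict(torch.load(ckpt_path))
        print("Directly load the best model from {:}".format(ckpt_path))
        return True
    return False
```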
@@ -118,58 +119,64 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
         left_time = "Time Left: {:}".format(
             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
         )
-        total_meta_losses, total_match_losses = [], []
+        total_meta_v1_losses, total_meta_v2_losses, total_match_losses = [], [], []
         optimizer.zero_grad()
         for ibatch in range(args.meta_batch):
             rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1)
             timestamps = meta_model.meta_timestamps[
                 rand_index : rand_index + xenv.seq_length
             ]
+            meta_embeds = meta_model.super_meta_embed[
+                rand_index : rand_index + xenv.seq_length
+            ]
-            seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps)
-            [seq_containers], time_embeds = meta_model(
-                torch.unsqueeze(timestamps, dim=0), None
-            )
-            # performance loss
-            losses = []
+            _, (seq_inputs, seq_targets) = xenv.seq_call(timestamps)
+            seq_inputs, seq_targets = seq_inputs.to(args.device), seq_targets.to(
+                args.device
+            )
+            # generate models one step ahead
+            [seq_containers], time_embeds = meta_model(
+                torch.unsqueeze(timestamps, dim=0), None
+            )
             for container, inputs, targets in zip(
                 seq_containers, seq_inputs, seq_targets
             ):
                 predictions = base_model.forward_with_container(inputs, container)
-                loss = criterion(predictions, targets)
-                losses.append(loss)
-            meta_loss = torch.stack(losses).mean()
-            match_loss = criterion(
-                torch.squeeze(time_embeds, dim=0),
-                meta_model.super_meta_embed[rand_index : rand_index + xenv.seq_length],
-            )
-            total_meta_losses.append(meta_loss)
+                total_meta_v1_losses.append(criterion(predictions, targets))
+            # the matching loss
+            match_loss = criterion(torch.squeeze(time_embeds, dim=0), meta_embeds)
             total_match_losses.append(match_loss)
+            # generate models via memory
+            [seq_containers], _ = meta_model(None, torch.unsqueeze(meta_embeds, dim=0))
+            for container, inputs, targets in zip(
+                seq_containers, seq_inputs, seq_targets
+            ):
+                predictions = base_model.forward_with_container(inputs, container)
+                total_meta_v2_losses.append(criterion(predictions, targets))
         with torch.no_grad():
-            meta_std = torch.stack(total_meta_losses).std().item()
-        final_meta_loss = torch.stack(total_meta_losses).mean()
-        final_match_loss = torch.stack(total_match_losses).mean()
-        total_loss = final_meta_loss + final_match_loss
+            meta_std = torch.stack(total_meta_v1_losses).std().item()
+        meta_v1_loss = torch.stack(total_meta_v1_losses).mean()
+        meta_v2_loss = torch.stack(total_meta_v2_losses).mean()
+        match_loss = torch.stack(total_match_losses).mean()
+        total_loss = meta_v1_loss + meta_v2_loss + match_loss
         total_loss.backward()
         optimizer.step()
         # success
         success, best_score = meta_model.save_best(-total_loss.item())
         logger.log(
-            "{:} [Pre-V2 {:04d}/{:}] loss : {:.5f} +- {:.5f} = {:.5f} + {:.5f} (match)".format(
+            "{:} [Pre-V2 {:04d}/{:}] loss : {:.4f} +- {:.4f} = {:.4f} + {:.4f} + {:.4f} (match)".format(
                 time_string(),
                 iepoch,
                 args.epochs,
                 total_loss.item(),
                 meta_std,
-                final_meta_loss.item(),
-                final_match_loss.item(),
+                meta_v1_loss.item(),
+                meta_v2_loss.item(),
+                match_loss.item(),
             )
-            + ", batch={:}".format(len(total_meta_losses))
-            + ", success={:}, best_score={:.4f}".format(success, -best_score)
-            + ", LS={:}/{:}".format(last_success_epoch, early_stop_thresh)
+            + ", batch={:}".format(len(total_meta_v1_losses))
+            + ", success={:}, best={:.4f}".format(success, -best_score)
+            + ", LS={:}/{:}".format(last_success_epoch - iepoch, early_stop_thresh)
             + ", {:}".format(left_time)
         )
         if success:
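The rewritten loop splits the old single meta-loss into three averaged terms: a prediction loss for models generated one step ahead from raw timestamps (v1), a prediction loss for models regenerated from the stored meta-embeddings (v2), and a matching loss tying the predicted time embeddings to those stored embeddings. A minimal sketch of how the terms combine into one backward pass, using random stand-in losses:

```python
import torch

# random stand-ins for the per-batch loss lists accumulated above
v1_losses = [torch.rand((), requires_grad=True) for _ in range(4)]
v2_losses = [torch.rand((), requires_grad=True) for _ in range(4)]
match_losses = [torch.rand((), requires_grad=True) for _ in range(4)]

meta_v1_loss = torch.stack(v1_losses).mean()   # one-step-ahead prediction loss
meta_v2_loss = torch.stack(v2_losses).mean()   # memory-based prediction loss
match_loss = torch.stack(match_losses).mean()  # embedding matching loss
total_loss = meta_v1_loss + meta_v2_loss + match_loss
total_loss.backward()  # gradients flow through all three terms at once
```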
@@ -184,6 +191,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
     meta_model.set_best_name(final_best_name)
     success, _ = meta_model.save_best(best_score + 1e-6)
     assert success
+    logger.log("Save the best model into {:}".format(final_best_name))


 def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
@@ -243,8 +251,8 @@ def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
                 final_loss.item(),
             )
             + ", batch={:}".format(len(losses))
-            + ", success={:}, best_score={:.4f}".format(success, -best_score)
-            + ", LS={:}/{:}".format(last_success_epoch, early_stop_thresh)
+            + ", success={:}, best={:.4f}".format(success, -best_score)
+            + ", LS={:}/{:}".format(last_success_epoch - iepoch, early_stop_thresh)
             + " {:}".format(left_time)
         )
         if success:
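Both pretraining loops now log `LS=.../...`, the gap between the current epoch and the last epoch that improved the best score, against `early_stop_thresh`. A minimal sketch of that patience-style bookkeeping, assuming a conventional stopping rule (the epoch body below is a hypothetical stand-in, since the surrounding loop is not part of this diff):

```python
import random

def run_one_epoch(iepoch: int) -> bool:
    # hypothetical stand-in for a training epoch + meta_model.save_best(...)
    return random.random() < 0.1  # True when the best score improves

last_success_epoch, early_stop_thresh = 0, 50
for iepoch in range(1000):
    if run_one_epoch(iepoch):
        last_success_epoch = iepoch
    elif iepoch - last_success_epoch >= early_stop_thresh:
        break  # no improvement for early_stop_thresh epochs
```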
@@ -277,6 +285,8 @@ def main(args):
     logger.log("The base-model has {:} weights.".format(base_model.numel()))
     logger.log("The meta-model has {:} weights.".format(meta_model.numel()))
+    logger.log("The base-model is\n{:}".format(base_model))
+    logger.log("The meta-model is\n{:}".format(meta_model))

     batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge)
     train_env.reset_max_seq_length(args.seq_length)
@@ -294,9 +304,10 @@ def main(args):
         num_workers=args.workers,
         pin_memory=True,
     )
+    pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
     optimizer = torch.optim.Adam(
-        meta_model.parameters(),
+        meta_model.get_parameters(True, True, False),  # fix hypernet
         lr=args.lr,
         weight_decay=args.weight_decay,
         amsgrad=True,
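`get_parameters(True, True, False)` is the meta-model's own selector; per the inline comment it keeps the hypernetwork weights fixed by leaving them out of the optimizer. A minimal sketch of the general pattern with a name-based filter (the submodule names here are hypothetical):

```python
import torch
import torch.nn as nn

model = nn.ModuleDict(
    {"time_embed": nn.Linear(8, 8), "mixer": nn.Linear(8, 8), "hypernet": nn.Linear(8, 8)}
)
# optimize everything except the hypernet, which stays fixed
trainable = [p for name, p in model.named_parameters() if not name.startswith("hypernet")]
optimizer = torch.optim.Adam(trainable, lr=0.001, weight_decay=1e-5, amsgrad=True)
```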
@@ -306,14 +317,10 @@ def main(args):
         milestones=[1, 2, 3, 4, 5],
         gamma=0.2,
     )
-    logger.log("The base-model is\n{:}".format(base_model))
-    logger.log("The meta-model is\n{:}".format(meta_model))
     logger.log("The optimizer is\n{:}".format(optimizer))
     logger.log("The scheduler is\n{:}".format(lr_scheduler))
     logger.log("Per epoch iterations = {:}".format(len(train_env_loader)))
-    pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)

     if logger.path("model").exists():
         ckp_data = torch.load(logger.path("model"))
         base_model.load_state_dict(ckp_data["base_model"])
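The final hunk leaves the resume block intact: `main` reloads a checkpoint dictionary keyed by component name. A minimal sketch of the matching save/load round trip (the path stands in for `logger.path("model")`; keys other than `base_model` are not shown in this diff):

```python
import torch
import torch.nn as nn

base_model = nn.Linear(4, 1)
ckp_path = "model.pth"  # stands in for logger.path("model")

# save: one state_dict per component, keyed by name
torch.save({"base_model": base_model.state_dict()}, ckp_path)

# load: mirrors the resume block in main()
ckp_data = torch.load(ckp_path)
base_model.load_state_dict(ckp_data["base_model"])
```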