Update LFNA
This commit is contained in:
parent
8109ed166a
commit
df9917371e
@ -9,4 +9,4 @@
|
||||
- [2020.10.15] [446262a](https://github.com/D-X-Y/AutoDL-Projects/tree/446262a) Update NATS-BENCH to version 1.0
|
||||
- [2020.12.20] [dae387a](https://github.com/D-X-Y/AutoDL-Projects/tree/dae387a) Update NATS-BENCH to version 1.1
|
||||
- [2021.05.18] [98fadf8](https://github.com/D-X-Y/AutoDL-Projects/tree/98fadf8) Before moving to `xautodl`
|
||||
- [2021.05.21] [5b09f05](https://github.com/D-X-Y/AutoDL-Projects/tree/5b09f05) `xautodl` is close to ready
|
||||
- [2021.05.21] [8109ed1](https://github.com/D-X-Y/AutoDL-Projects/tree/8109ed1) `xautodl` is close to ready
|
||||
|
@ -1,5 +1,5 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
||||
# Learning to Generate Model One Step Ahead #
|
||||
#####################################################
|
||||
# python exps/LFNA/lfna.py --env_version v1 --workers 0
|
||||
# python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.001
|
||||
@ -109,6 +109,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
|
||||
final_best_name = "final-pretrain-{:}.pth".format(args.rand_seed)
|
||||
if meta_model.has_best(final_best_name):
|
||||
meta_model.load_best(final_best_name)
|
||||
logger.log("Directly load the best model from {:}".format(final_best_name))
|
||||
return
|
||||
|
||||
meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
|
||||
@ -118,58 +119,64 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
|
||||
left_time = "Time Left: {:}".format(
|
||||
convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
|
||||
)
|
||||
total_meta_losses, total_match_losses = [], []
|
||||
total_meta_v1_losses, total_meta_v2_losses, total_match_losses = [], [], []
|
||||
optimizer.zero_grad()
|
||||
for ibatch in range(args.meta_batch):
|
||||
rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1)
|
||||
timestamps = meta_model.meta_timestamps[
|
||||
rand_index : rand_index + xenv.seq_length
|
||||
]
|
||||
meta_embeds = meta_model.super_meta_embed[
|
||||
rand_index : rand_index + xenv.seq_length
|
||||
]
|
||||
|
||||
seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps)
|
||||
[seq_containers], time_embeds = meta_model(
|
||||
torch.unsqueeze(timestamps, dim=0), None
|
||||
)
|
||||
# performance loss
|
||||
losses = []
|
||||
_, (seq_inputs, seq_targets) = xenv.seq_call(timestamps)
|
||||
seq_inputs, seq_targets = seq_inputs.to(args.device), seq_targets.to(
|
||||
args.device
|
||||
)
|
||||
# generate models one step ahead
|
||||
[seq_containers], time_embeds = meta_model(
|
||||
torch.unsqueeze(timestamps, dim=0), None
|
||||
)
|
||||
for container, inputs, targets in zip(
|
||||
seq_containers, seq_inputs, seq_targets
|
||||
):
|
||||
predictions = base_model.forward_with_container(inputs, container)
|
||||
loss = criterion(predictions, targets)
|
||||
losses.append(loss)
|
||||
meta_loss = torch.stack(losses).mean()
|
||||
match_loss = criterion(
|
||||
torch.squeeze(time_embeds, dim=0),
|
||||
meta_model.super_meta_embed[rand_index : rand_index + xenv.seq_length],
|
||||
)
|
||||
total_meta_losses.append(meta_loss)
|
||||
total_meta_v1_losses.append(criterion(predictions, targets))
|
||||
# the matching loss
|
||||
match_loss = criterion(torch.squeeze(time_embeds, dim=0), meta_embeds)
|
||||
total_match_losses.append(match_loss)
|
||||
# generate models via memory
|
||||
[seq_containers], _ = meta_model(None, torch.unsqueeze(meta_embeds, dim=0))
|
||||
for container, inputs, targets in zip(
|
||||
seq_containers, seq_inputs, seq_targets
|
||||
):
|
||||
predictions = base_model.forward_with_container(inputs, container)
|
||||
total_meta_v2_losses.append(criterion(predictions, targets))
|
||||
with torch.no_grad():
|
||||
meta_std = torch.stack(total_meta_losses).std().item()
|
||||
final_meta_loss = torch.stack(total_meta_losses).mean()
|
||||
final_match_loss = torch.stack(total_match_losses).mean()
|
||||
total_loss = final_meta_loss + final_match_loss
|
||||
meta_std = torch.stack(total_meta_v1_losses).std().item()
|
||||
meta_v1_loss = torch.stack(total_meta_v1_losses).mean()
|
||||
meta_v2_loss = torch.stack(total_meta_v2_losses).mean()
|
||||
match_loss = torch.stack(total_match_losses).mean()
|
||||
total_loss = meta_v1_loss + meta_v2_loss + match_loss
|
||||
total_loss.backward()
|
||||
optimizer.step()
|
||||
# success
|
||||
success, best_score = meta_model.save_best(-total_loss.item())
|
||||
logger.log(
|
||||
"{:} [Pre-V2 {:04d}/{:}] loss : {:.5f} +- {:.5f} = {:.5f} + {:.5f} (match)".format(
|
||||
"{:} [Pre-V2 {:04d}/{:}] loss : {:.4f} +- {:.4f} = {:.4f} + {:.4f} + {:.4f} (match)".format(
|
||||
time_string(),
|
||||
iepoch,
|
||||
args.epochs,
|
||||
total_loss.item(),
|
||||
meta_std,
|
||||
final_meta_loss.item(),
|
||||
final_match_loss.item(),
|
||||
meta_v1_loss.item(),
|
||||
meta_v2_loss.item(),
|
||||
match_loss.item(),
|
||||
)
|
||||
+ ", batch={:}".format(len(total_meta_losses))
|
||||
+ ", success={:}, best_score={:.4f}".format(success, -best_score)
|
||||
+ ", LS={:}/{:}".format(last_success_epoch, early_stop_thresh)
|
||||
+ ", batch={:}".format(len(total_meta_v1_losses))
|
||||
+ ", success={:}, best={:.4f}".format(success, -best_score)
|
||||
+ ", LS={:}/{:}".format(last_success_epoch - iepoch, early_stop_thresh)
|
||||
+ ", {:}".format(left_time)
|
||||
)
|
||||
if success:
|
||||
@ -184,6 +191,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
|
||||
meta_model.set_best_name(final_best_name)
|
||||
success, _ = meta_model.save_best(best_score + 1e-6)
|
||||
assert success
|
||||
logger.log("Save the best model into {:}".format(final_best_name))
|
||||
|
||||
|
||||
def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
|
||||
@ -243,8 +251,8 @@ def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
|
||||
final_loss.item(),
|
||||
)
|
||||
+ ", batch={:}".format(len(losses))
|
||||
+ ", success={:}, best_score={:.4f}".format(success, -best_score)
|
||||
+ ", LS={:}/{:}".format(last_success_epoch, early_stop_thresh)
|
||||
+ ", success={:}, best={:.4f}".format(success, -best_score)
|
||||
+ ", LS={:}/{:}".format(last_success_epoch - iepoch, early_stop_thresh)
|
||||
+ " {:}".format(left_time)
|
||||
)
|
||||
if success:
|
||||
@ -277,6 +285,8 @@ def main(args):
|
||||
|
||||
logger.log("The base-model has {:} weights.".format(base_model.numel()))
|
||||
logger.log("The meta-model has {:} weights.".format(meta_model.numel()))
|
||||
logger.log("The base-model is\n{:}".format(base_model))
|
||||
logger.log("The meta-model is\n{:}".format(meta_model))
|
||||
|
||||
batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge)
|
||||
train_env.reset_max_seq_length(args.seq_length)
|
||||
@ -294,9 +304,10 @@ def main(args):
|
||||
num_workers=args.workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
|
||||
|
||||
optimizer = torch.optim.Adam(
|
||||
meta_model.parameters(),
|
||||
meta_model.get_parameters(True, True, False), # fix hypernet
|
||||
lr=args.lr,
|
||||
weight_decay=args.weight_decay,
|
||||
amsgrad=True,
|
||||
@ -306,14 +317,10 @@ def main(args):
|
||||
milestones=[1, 2, 3, 4, 5],
|
||||
gamma=0.2,
|
||||
)
|
||||
logger.log("The base-model is\n{:}".format(base_model))
|
||||
logger.log("The meta-model is\n{:}".format(meta_model))
|
||||
logger.log("The optimizer is\n{:}".format(optimizer))
|
||||
logger.log("The scheduler is\n{:}".format(lr_scheduler))
|
||||
logger.log("Per epoch iterations = {:}".format(len(train_env_loader)))
|
||||
|
||||
pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
|
||||
|
||||
if logger.path("model").exists():
|
||||
ckp_data = torch.load(logger.path("model"))
|
||||
base_model.load_state_dict(ckp_data["base_model"])
|
||||
|
Loading…
Reference in New Issue
Block a user