Updates
This commit is contained in:
parent
da4b61f3ab
commit
c5788ba19c
@ -1,7 +1,6 @@
|
|||||||
#####################################################
|
#####################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
||||||
#####################################################
|
#####################################################
|
||||||
import copy
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -294,7 +293,9 @@ class LFNA_Meta(super_core.SuperModule):
|
|||||||
best_loss = init_info["loss"]
|
best_loss = init_info["loss"]
|
||||||
new_param.data.copy_(init_info["param"].data)
|
new_param.data.copy_(init_info["param"].data)
|
||||||
else:
|
else:
|
||||||
best_new_param, best_loss = None, 1e9
|
best_loss = 1e9
|
||||||
|
with torch.no_grad():
|
||||||
|
best_new_param = new_param.detach().clone()
|
||||||
for iepoch in range(epochs):
|
for iepoch in range(epochs):
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
_, [_], time_embed = self(timestamp.view(1, 1), None, True)
|
_, [_], time_embed = self(timestamp.view(1, 1), None, True)
|
||||||
@ -310,7 +311,7 @@ class LFNA_Meta(super_core.SuperModule):
|
|||||||
if meta_loss.item() < best_loss:
|
if meta_loss.item() < best_loss:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
best_loss = meta_loss.item()
|
best_loss = meta_loss.item()
|
||||||
best_new_param = new_param.detach()
|
best_new_param = new_param.detach().clone()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.replace_append_learnt(None, None)
|
self.replace_append_learnt(None, None)
|
||||||
self.append_fixed(timestamp, best_new_param)
|
self.append_fixed(timestamp, best_new_param)
|
||||||
|
Loading…
Reference in New Issue
Block a user