LFNA ok on the valid data
commit b1064e5a60
parent 63a0361152
@@ -99,18 +99,13 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
         with torch.no_grad():
             meta_model.eval()
             base_model.eval()
-            _, [future_container], _ = meta_model(
+            _, [future_container], time_embeds = meta_model(
                 future_time.to(args.device).view(1, 1), None, True
             )
             future_x, future_y = future_x.to(args.device), future_y.to(args.device)
             future_y_hat = base_model.forward_with_container(future_x, future_container)
             future_loss = criterion(future_y_hat, future_y)
-        logger.log(
-            "[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
-                idx, len(env), future_loss.item()
-            )
-        )
-        refine = meta_model.adapt(
+        refine, post_refine_loss = meta_model.adapt(
             base_model,
             criterion,
             future_time.item(),
@@ -118,6 +113,13 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
             future_y,
             args.refine_lr,
             args.refine_epochs,
+            {"param": time_embeds, "loss": future_loss.item()},
+        )
+        logger.log(
+            "[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
+                idx, len(env), future_loss.item()
+            )
+            + ", post-loss={:.4f}".format(post_refine_loss if refine else -1)
         )
     meta_model.clear_fixed()
     meta_model.clear_learnt()
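
The two hunks above reorder online_evaluate: meta_model.adapt now runs before the
log line, returns a (refine, post_refine_loss) pair instead of a single flag, and
receives the time_embeds predicted for future_time, plus the loss they produced,
as a warm start. A minimal sketch of the new call contract (names are from the
diff; the final dict is the new init_info argument defined in the LFNA_Meta.adapt
hunks further down):

    # Sketch of the new adapt() contract used above. post_refine_loss is None
    # when the timestamp is too close to an existing anchor and no refinement
    # ran; the log line prints -1 in that case.
    refine, post_refine_loss = meta_model.adapt(
        base_model, criterion, future_time.item(), future_x, future_y,
        args.refine_lr, args.refine_epochs,
        {"param": time_embeds, "loss": future_loss.item()},  # warm start
    )
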
@@ -244,21 +246,6 @@ def main(args):
     logger.log("The meta-model is\n{:}".format(meta_model))
 
     batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge)
-    # train_env.reset_max_seq_length(args.seq_length)
-    # valid_env.reset_max_seq_length(args.seq_length)
-    valid_env_loader = torch.utils.data.DataLoader(
-        valid_env,
-        batch_size=args.meta_batch,
-        shuffle=True,
-        num_workers=args.workers,
-        pin_memory=True,
-    )
-    train_env_loader = torch.utils.data.DataLoader(
-        train_env,
-        batch_sampler=batch_sampler,
-        num_workers=args.workers,
-        pin_memory=True,
-    )
     pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
 
     # try to evaluate once
@@ -507,7 +494,7 @@ if __name__ == "__main__":
         help="The learning rate for the optimizer, during refine",
     )
     parser.add_argument(
-        "--refine_epochs", type=int, default=50, help="The final refine #epochs."
+        "--refine_epochs", type=int, default=40, help="The final refine #epochs."
     )
     parser.add_argument(
         "--early_stop_thresh",
@@ -276,10 +276,10 @@ class LFNA_Meta(super_core.SuperModule):
     def forward_candidate(self, input):
         raise NotImplementedError
 
-    def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs):
+    def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info):
         distance = self.get_closest_meta_distance(timestamp)
         if distance + self._interval * 1e-2 <= self._interval:
-            return False
+            return False, None
         x, y = x.to(self._meta_timestamps.device), y.to(self._meta_timestamps.device)
         with torch.set_grad_enabled(True):
             new_param = self.create_meta_embed()
@@ -290,7 +290,11 @@ class LFNA_Meta(super_core.SuperModule):
             self.replace_append_learnt(timestamp, new_param)
             self.train()
             base_model.train()
-            best_new_param, best_loss = None, 1e9
+            if init_info is not None:
+                best_loss = init_info["loss"]
+                new_param.data.copy_(init_info["param"].data)
+            else:
+                best_new_param, best_loss = None, 1e9
             for iepoch in range(epochs):
                 optimizer.zero_grad()
                 _, [_], time_embed = self(timestamp.view(1, 1), None, True)
@@ -303,14 +307,14 @@ class LFNA_Meta(super_core.SuperModule):
                 loss.backward()
                 optimizer.step()
                 # print("{:03d}/{:03d} : loss : {:.4f} = {:.4f} + {:.4f}".format(iepoch, epochs, loss.item(), meta_loss.item(), match_loss.item()))
-                if loss.item() < best_loss:
+                if meta_loss.item() < best_loss:
                     with torch.no_grad():
-                        best_loss = loss.item()
+                        best_loss = meta_loss.item()
                         best_new_param = new_param.detach()
         with torch.no_grad():
             self.replace_append_learnt(None, None)
             self.append_fixed(timestamp, best_new_param)
-        return True
+        return True, meta_loss.item()
 
     def extra_repr(self) -> str:
         return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format(
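
Besides the new init_info warm start, two behavioral changes ride along in these
hunks: the refinement loop now checkpoints on meta_loss (the fit on the observed
data) instead of the combined loss, and the final meta_loss is reported back to
the caller. A condensed sketch of the loop's selection logic (the decomposition
loss = meta_loss + match_loss is taken from the commented print statement, not
from code visible in this diff):

    # Condensed selection logic inside LFNA_Meta.adapt after this commit.
    loss = meta_loss + match_loss      # joint objective, still what is optimized
    loss.backward()
    optimizer.step()
    if meta_loss.item() < best_loss:   # checkpointing now ignores match_loss
        with torch.no_grad():
            best_loss = meta_loss.item()
            best_new_param = new_param.detach()
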
@@ -66,11 +66,6 @@ class SyntheticDEnv(data.Dataset):
         self._cov_functors = cov_functors
 
         self._oracle_map = None
-        self._seq_length = None
-
-    @property
-    def seq_length(self):
-        return self._seq_length
 
     @property
     def min_timestamp(self):
@@ -84,14 +79,12 @@ class SyntheticDEnv(data.Dataset):
     def timestamp_interval(self):
         return self._timestamp_generator.interval
 
-    def random_timestamp(self):
-        return (
-            random.random() * (self.max_timestamp - self.min_timestamp)
-            + self.min_timestamp
-        )
-
-    def reset_max_seq_length(self, seq_length):
-        self._seq_length = seq_length
+    def random_timestamp(self, min_timestamp=None, max_timestamp=None):
+        if min_timestamp is None:
+            min_timestamp = self.min_timestamp
+        if max_timestamp is None:
+            max_timestamp = self.max_timestamp
+        return random.random() * (max_timestamp - min_timestamp) + min_timestamp
 
     def get_timestamp(self, index):
         if index is None:
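
random_timestamp keeps its zero-argument behavior but can now be restricted to a
sub-window. A hypothetical use, drawing validation timestamps from the tail of
the range (the 80/20 split is illustrative, not part of this commit):

    # Unchanged default: draw from the full [min_timestamp, max_timestamp] range.
    t = env.random_timestamp()

    # New: draw only from an (illustrative) held-out tail of the range.
    split = env.min_timestamp + 0.8 * (env.max_timestamp - env.min_timestamp)
    t_valid = env.random_timestamp(min_timestamp=split)
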
@@ -119,19 +112,7 @@ class SyntheticDEnv(data.Dataset):
     def __getitem__(self, index):
         assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
         index, timestamp = self._timestamp_generator[index]
-        if self._seq_length is None:
-            return self.__call__(timestamp)
-        else:
-            noise = (
-                random.random() * self.timestamp_interval * self._timestamp_noise_scale
-            )
-            timestamps = [
-                timestamp + i * self.timestamp_interval + noise
-                for i in range(self._seq_length)
-            ]
-            # xdata = [self.__call__(timestamp) for timestamp in timestamps]
-            # return zip_sequence(xdata)
-            return self.seq_call(timestamps)
+        return self.__call__(timestamp)
 
     def seq_call(self, timestamps):
         with torch.no_grad():
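
With _seq_length and the sequence branch removed, indexing a SyntheticDEnv always
yields a single-timestamp sample; seq_call itself is unchanged, so a caller that
still needs a window would presumably build the timestamp list on its own. A
hypothetical sketch mirroring the deleted branch (t0, seq_length, and noise_scale
are caller-chosen here, not names from the commit):

    # Hypothetical caller-side replacement for the deleted sequence branch.
    noise = random.random() * env.timestamp_interval * noise_scale
    timestamps = [t0 + i * env.timestamp_interval + noise for i in range(seq_length)]
    batch = env.seq_call(timestamps)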