Fix black
This commit is contained in:
parent
9b895bdf2e
commit
77c250c8fc
@ -73,7 +73,7 @@ def main(save_dir):
|
||||
|
||||
additional_xaxis = np.arange(-6, 6, 0.2)
|
||||
models = dict()
|
||||
|
||||
|
||||
for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)):
|
||||
xaxis_all = dataset[:, 0].numpy()
|
||||
# xaxis_all = np.concatenate((additional_xaxis, xaxis_all))
|
||||
@ -84,15 +84,19 @@ def main(save_dir):
|
||||
# split the dataset
|
||||
indexes = list(range(xaxis_all.shape[0]))
|
||||
random.shuffle(indexes)
|
||||
train_indexes = indexes[:len(indexes)//2]
|
||||
valid_indexes = indexes[len(indexes)//2:]
|
||||
train_indexes = indexes[: len(indexes) // 2]
|
||||
valid_indexes = indexes[len(indexes) // 2 :]
|
||||
train_xs, train_ys = xaxis_all[train_indexes], yaxis_all[train_indexes]
|
||||
valid_xs, valid_ys = xaxis_all[valid_indexes], yaxis_all[valid_indexes]
|
||||
|
||||
|
||||
model, loss_fn, train_loss = optimize_fn(train_xs, train_ys)
|
||||
# model, loss_fn, train_loss = optimize_fn(xaxis_all, yaxis_all)
|
||||
pred_valid_ys, valid_loss = evaluate_fn(model, valid_xs, valid_ys, loss_fn)
|
||||
print("[{:03d}] T-{:03d}, train-loss={:.5f}, valid-loss={:.5f}".format(idx, timestamp, train_loss, valid_loss))
|
||||
print(
|
||||
"[{:03d}] T-{:03d}, train-loss={:.5f}, valid-loss={:.5f}".format(
|
||||
idx, timestamp, train_loss, valid_loss
|
||||
)
|
||||
)
|
||||
|
||||
# the first plot
|
||||
scatter_list = []
|
||||
@ -114,10 +118,10 @@ def main(save_dir):
|
||||
"color": "r",
|
||||
"s": 10,
|
||||
"alpha": 0.5,
|
||||
"label": "MLP at now"
|
||||
"label": "MLP at now",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
draw_fig(save_dir, timestamp, scatter_list)
|
||||
print("Save all figures into {:}".format(save_dir))
|
||||
save_dir = save_dir.resolve()
|
||||
|
@ -49,8 +49,8 @@ class SuperModule(abc.ABC, nn.Module):
|
||||
def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
|
||||
if not isinstance(module, SuperModule):
|
||||
warnings.warn(
|
||||
"Add {:} module, which is not SuperModule, into {:}".format(
|
||||
name, self.__class__.__name__
|
||||
"Add {:}:{:} module, which is not SuperModule, into {:}".format(
|
||||
name, module.__class__.__name__, self.__class__.__name__
|
||||
)
|
||||
+ "\n"
|
||||
+ "It may cause some functions invalid."
|
||||
|
Loading…
Reference in New Issue
Block a user