Fix black
This commit is contained in:
parent
9b895bdf2e
commit
77c250c8fc
@ -73,7 +73,7 @@ def main(save_dir):
|
|||||||
|
|
||||||
additional_xaxis = np.arange(-6, 6, 0.2)
|
additional_xaxis = np.arange(-6, 6, 0.2)
|
||||||
models = dict()
|
models = dict()
|
||||||
|
|
||||||
for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)):
|
for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)):
|
||||||
xaxis_all = dataset[:, 0].numpy()
|
xaxis_all = dataset[:, 0].numpy()
|
||||||
# xaxis_all = np.concatenate((additional_xaxis, xaxis_all))
|
# xaxis_all = np.concatenate((additional_xaxis, xaxis_all))
|
||||||
@ -84,15 +84,19 @@ def main(save_dir):
|
|||||||
# split the dataset
|
# split the dataset
|
||||||
indexes = list(range(xaxis_all.shape[0]))
|
indexes = list(range(xaxis_all.shape[0]))
|
||||||
random.shuffle(indexes)
|
random.shuffle(indexes)
|
||||||
train_indexes = indexes[:len(indexes)//2]
|
train_indexes = indexes[: len(indexes) // 2]
|
||||||
valid_indexes = indexes[len(indexes)//2:]
|
valid_indexes = indexes[len(indexes) // 2 :]
|
||||||
train_xs, train_ys = xaxis_all[train_indexes], yaxis_all[train_indexes]
|
train_xs, train_ys = xaxis_all[train_indexes], yaxis_all[train_indexes]
|
||||||
valid_xs, valid_ys = xaxis_all[valid_indexes], yaxis_all[valid_indexes]
|
valid_xs, valid_ys = xaxis_all[valid_indexes], yaxis_all[valid_indexes]
|
||||||
|
|
||||||
model, loss_fn, train_loss = optimize_fn(train_xs, train_ys)
|
model, loss_fn, train_loss = optimize_fn(train_xs, train_ys)
|
||||||
# model, loss_fn, train_loss = optimize_fn(xaxis_all, yaxis_all)
|
# model, loss_fn, train_loss = optimize_fn(xaxis_all, yaxis_all)
|
||||||
pred_valid_ys, valid_loss = evaluate_fn(model, valid_xs, valid_ys, loss_fn)
|
pred_valid_ys, valid_loss = evaluate_fn(model, valid_xs, valid_ys, loss_fn)
|
||||||
print("[{:03d}] T-{:03d}, train-loss={:.5f}, valid-loss={:.5f}".format(idx, timestamp, train_loss, valid_loss))
|
print(
|
||||||
|
"[{:03d}] T-{:03d}, train-loss={:.5f}, valid-loss={:.5f}".format(
|
||||||
|
idx, timestamp, train_loss, valid_loss
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# the first plot
|
# the first plot
|
||||||
scatter_list = []
|
scatter_list = []
|
||||||
@ -114,10 +118,10 @@ def main(save_dir):
|
|||||||
"color": "r",
|
"color": "r",
|
||||||
"s": 10,
|
"s": 10,
|
||||||
"alpha": 0.5,
|
"alpha": 0.5,
|
||||||
"label": "MLP at now"
|
"label": "MLP at now",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
draw_fig(save_dir, timestamp, scatter_list)
|
draw_fig(save_dir, timestamp, scatter_list)
|
||||||
print("Save all figures into {:}".format(save_dir))
|
print("Save all figures into {:}".format(save_dir))
|
||||||
save_dir = save_dir.resolve()
|
save_dir = save_dir.resolve()
|
||||||
|
@ -49,8 +49,8 @@ class SuperModule(abc.ABC, nn.Module):
|
|||||||
def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
|
def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
|
||||||
if not isinstance(module, SuperModule):
|
if not isinstance(module, SuperModule):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Add {:} module, which is not SuperModule, into {:}".format(
|
"Add {:}:{:} module, which is not SuperModule, into {:}".format(
|
||||||
name, self.__class__.__name__
|
name, module.__class__.__name__, self.__class__.__name__
|
||||||
)
|
)
|
||||||
+ "\n"
|
+ "\n"
|
||||||
+ "It may cause some functions invalid."
|
+ "It may cause some functions invalid."
|
||||||
|
Loading…
Reference in New Issue
Block a user