Fix black

This commit is contained in:
D-X-Y 2021-04-23 02:14:49 -07:00
parent 9b895bdf2e
commit 77c250c8fc
2 changed files with 13 additions and 9 deletions

View File

@ -73,7 +73,7 @@ def main(save_dir):
additional_xaxis = np.arange(-6, 6, 0.2) additional_xaxis = np.arange(-6, 6, 0.2)
models = dict() models = dict()
for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)): for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)):
xaxis_all = dataset[:, 0].numpy() xaxis_all = dataset[:, 0].numpy()
# xaxis_all = np.concatenate((additional_xaxis, xaxis_all)) # xaxis_all = np.concatenate((additional_xaxis, xaxis_all))
@ -84,15 +84,19 @@ def main(save_dir):
# split the dataset # split the dataset
indexes = list(range(xaxis_all.shape[0])) indexes = list(range(xaxis_all.shape[0]))
random.shuffle(indexes) random.shuffle(indexes)
train_indexes = indexes[:len(indexes)//2] train_indexes = indexes[: len(indexes) // 2]
valid_indexes = indexes[len(indexes)//2:] valid_indexes = indexes[len(indexes) // 2 :]
train_xs, train_ys = xaxis_all[train_indexes], yaxis_all[train_indexes] train_xs, train_ys = xaxis_all[train_indexes], yaxis_all[train_indexes]
valid_xs, valid_ys = xaxis_all[valid_indexes], yaxis_all[valid_indexes] valid_xs, valid_ys = xaxis_all[valid_indexes], yaxis_all[valid_indexes]
model, loss_fn, train_loss = optimize_fn(train_xs, train_ys) model, loss_fn, train_loss = optimize_fn(train_xs, train_ys)
# model, loss_fn, train_loss = optimize_fn(xaxis_all, yaxis_all) # model, loss_fn, train_loss = optimize_fn(xaxis_all, yaxis_all)
pred_valid_ys, valid_loss = evaluate_fn(model, valid_xs, valid_ys, loss_fn) pred_valid_ys, valid_loss = evaluate_fn(model, valid_xs, valid_ys, loss_fn)
print("[{:03d}] T-{:03d}, train-loss={:.5f}, valid-loss={:.5f}".format(idx, timestamp, train_loss, valid_loss)) print(
"[{:03d}] T-{:03d}, train-loss={:.5f}, valid-loss={:.5f}".format(
idx, timestamp, train_loss, valid_loss
)
)
# the first plot # the first plot
scatter_list = [] scatter_list = []
@ -114,10 +118,10 @@ def main(save_dir):
"color": "r", "color": "r",
"s": 10, "s": 10,
"alpha": 0.5, "alpha": 0.5,
"label": "MLP at now" "label": "MLP at now",
} }
) )
draw_fig(save_dir, timestamp, scatter_list) draw_fig(save_dir, timestamp, scatter_list)
print("Save all figures into {:}".format(save_dir)) print("Save all figures into {:}".format(save_dir))
save_dir = save_dir.resolve() save_dir = save_dir.resolve()

View File

@ -49,8 +49,8 @@ class SuperModule(abc.ABC, nn.Module):
def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None: def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
if not isinstance(module, SuperModule): if not isinstance(module, SuperModule):
warnings.warn( warnings.warn(
"Add {:} module, which is not SuperModule, into {:}".format( "Add {:}:{:} module, which is not SuperModule, into {:}".format(
name, self.__class__.__name__ name, module.__class__.__name__, self.__class__.__name__
) )
+ "\n" + "\n"
+ "It may cause some functions invalid." + "It may cause some functions invalid."