improve training comments

This commit is contained in:
nikitakaraevv 2023-07-19 04:44:38 -07:00
parent c6878420f5
commit a362b0f9b1

109
train.py
View File

@ -36,21 +36,6 @@ from cotracker.datasets.utils import collate_fn, collate_fn_train, dataclass_to_
from cotracker.models.core.cotracker.losses import sequence_loss, balanced_ce_loss
# define the handler function
# for training on a slurm cluster
def sig_handler(signum, frame):
    """Handle SIGUSR1 on a SLURM cluster: requeue the current job, then exit.

    The job id is read from the ``SLURM_JOB_ID`` environment variable and
    passed to ``scontrol requeue``. If the variable is missing or ``scontrol``
    cannot be launched, the problem is reported and the process still exits.

    Args:
        signum: number of the signal that was received.
        frame: stack frame supplied by the signal machinery (unused).

    Raises:
        SystemExit: always, with exit code -1.
    """
    # Function-scope import so the handler does not rely on a new
    # file-level import being added elsewhere.
    import subprocess

    print("caught signal", signum)
    print(socket.gethostname(), "USR1 signal caught.")
    # do other stuff to cleanup here
    job_id = os.environ.get("SLURM_JOB_ID")
    if job_id is None:
        # Not running under SLURM: nothing to requeue, but still terminate
        # instead of letting a KeyError escape from the signal handler.
        print("SLURM_JOB_ID is not set, skipping requeue", flush=True)
    else:
        # Flush so the message is emitted before any output of the child.
        print("requeuing job " + job_id, flush=True)
        try:
            # Argument list instead of a shell string: no shell is spawned
            # and the job id is never interpreted by a shell.
            subprocess.run(["scontrol", "requeue", job_id], check=False)
        except OSError as err:
            # e.g. scontrol is not installed; requeueing is best-effort.
            print("failed to run scontrol:", err, flush=True)
    sys.exit(-1)
def term_handler(signum, frame):
    """Handle SIGTERM by logging a notice instead of terminating.

    Args:
        signum: number of the signal that was received (unused).
        frame: stack frame supplied by the signal machinery (unused).
    """
    notice = "bypassing sigterm"
    # Flush immediately so the notice is not left sitting in a buffer.
    print(notice, flush=True)
def fetch_optimizer(args, model):
"""Create the optimizer and learning rate scheduler"""
optimizer = optim.AdamW(
@ -302,9 +287,7 @@ class Lite(LightningLite):
eval_dataloaders.append(("fastcapture", eval_dataloader_fastcapture))
if "tapvid_davis_first" in args.eval_datasets:
data_root = os.path.join(
args.dataset_root, "/tapvid_davis/tapvid_davis.pkl"
)
data_root = os.path.join(args.dataset_root, "tapvid_davis/tapvid_davis.pkl")
eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root)
eval_dataloader_tapvid_davis = torch.utils.data.DataLoader(
eval_dataset,
@ -551,17 +534,15 @@ class Lite(LightningLite):
if __name__ == "__main__":
signal.signal(signal.SIGUSR1, sig_handler)
signal.signal(signal.SIGTERM, term_handler)
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", default="cotracker", help="model name")
parser.add_argument("--restore_ckpt", help="restore checkpoint")
parser.add_argument("--ckpt_path", help="restore checkpoint")
parser.add_argument("--restore_ckpt", help="path to restore a checkpoint")
parser.add_argument("--ckpt_path", help="path to save checkpoints")
parser.add_argument(
"--batch_size", type=int, default=4, help="batch size used during training."
)
parser.add_argument(
"--num_workers", type=int, default=6, help="left right consistency loss"
"--num_workers", type=int, default=6, help="number of dataloader workers"
)
parser.add_argument(
@ -578,20 +559,34 @@ if __name__ == "__main__":
"--evaluate_every_n_epoch",
type=int,
default=1,
help="number of flow-field updates during validation forward pass",
help="evaluate during training after every n epochs, after every epoch by default",
)
parser.add_argument(
"--save_every_n_epoch",
type=int,
default=1,
help="number of flow-field updates during validation forward pass",
help="save checkpoints during training after every n epochs, after every epoch by default",
)
parser.add_argument(
"--validate_at_start", action="store_true", help="use mixed precision"
"--validate_at_start",
action="store_true",
help="whether to run evaluation before training starts",
)
parser.add_argument(
"--save_freq",
type=int,
default=100,
help="frequency of trajectory visualization during training",
)
parser.add_argument(
"--traj_per_sample",
type=int,
default=768,
help="the number of trajectories to sample for training",
)
parser.add_argument(
"--dataset_root", type=str, help="path lo all the datasets (train and eval)"
)
parser.add_argument("--save_freq", type=int, default=100, help="save_freq")
parser.add_argument("--traj_per_sample", type=int, default=768, help="save_freq")
parser.add_argument("--dataset_root", type=str, help="path lo all the datasets")
parser.add_argument(
"--train_iters",
@ -605,49 +600,75 @@ if __name__ == "__main__":
parser.add_argument(
"--eval_datasets",
nargs="+",
default=["things", "badja", "fastcapture"],
help="eval datasets.",
default=["things", "badja"],
help="what datasets to use for evaluation",
)
parser.add_argument(
"--remove_space_attn", action="store_true", help="use mixed precision"
"--remove_space_attn",
action="store_true",
help="remove space attention from CoTracker",
)
parser.add_argument(
"--dont_use_augs", action="store_true", help="use mixed precision"
"--dont_use_augs",
action="store_true",
help="don't apply augmentations during training",
)
parser.add_argument(
"--sample_vis_1st_frame", action="store_true", help="use mixed precision"
"--sample_vis_1st_frame",
action="store_true",
help="only sample trajectories with points visible on the first frame",
)
parser.add_argument(
"--sliding_window_len", type=int, default=8, help="use mixed precision"
"--sliding_window_len",
type=int,
default=8,
help="length of the CoTracker sliding window",
)
parser.add_argument(
"--updateformer_hidden_size", type=int, default=384, help="use mixed precision"
"--updateformer_hidden_size",
type=int,
default=384,
help="hidden dimension of the CoTracker transformer model",
)
parser.add_argument(
"--updateformer_num_heads", type=int, default=8, help="use mixed precision"
"--updateformer_num_heads",
type=int,
default=8,
help="number of heads of the CoTracker transformer model",
)
parser.add_argument(
"--updateformer_space_depth", type=int, default=12, help="use mixed precision"
"--updateformer_space_depth",
type=int,
default=12,
help="number of group attention layers in the CoTracker transformer model",
)
parser.add_argument(
"--updateformer_time_depth", type=int, default=12, help="use mixed precision"
"--updateformer_time_depth",
type=int,
default=12,
help="number of time attention layers in the CoTracker transformer model",
)
parser.add_argument(
"--model_stride", type=int, default=8, help="use mixed precision"
"--model_stride",
type=int,
default=8,
help="stride of the CoTracker feature network",
)
parser.add_argument(
"--crop_size",
type=int,
nargs="+",
default=[384, 512],
help="use mixed precision",
help="crop videos to this resolution during training",
)
parser.add_argument(
"--eval_max_seq_len", type=int, default=1000, help="use mixed precision"
"--eval_max_seq_len",
type=int,
default=1000,
help="maximum length of evaluation videos",
)
args = parser.parse_args()
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
@ -661,5 +682,5 @@ if __name__ == "__main__":
devices="auto",
accelerator="gpu",
precision=32,
num_nodes=4,
# num_nodes=4,
).run(args)