improve training comments
This commit is contained in:
parent
c6878420f5
commit
a362b0f9b1
109
train.py
109
train.py
@ -36,21 +36,6 @@ from cotracker.datasets.utils import collate_fn, collate_fn_train, dataclass_to_
|
||||
from cotracker.models.core.cotracker.losses import sequence_loss, balanced_ce_loss
|
||||
|
||||
|
||||
# define the handler function
|
||||
# for training on a slurm cluster
|
||||
def sig_handler(signum, frame):
|
||||
print("caught signal", signum)
|
||||
print(socket.gethostname(), "USR1 signal caught.")
|
||||
# do other stuff to cleanup here
|
||||
print("requeuing job " + os.environ["SLURM_JOB_ID"])
|
||||
os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"])
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
def term_handler(signum, frame):
|
||||
print("bypassing sigterm", flush=True)
|
||||
|
||||
|
||||
def fetch_optimizer(args, model):
|
||||
"""Create the optimizer and learning rate scheduler"""
|
||||
optimizer = optim.AdamW(
|
||||
@ -302,9 +287,7 @@ class Lite(LightningLite):
|
||||
eval_dataloaders.append(("fastcapture", eval_dataloader_fastcapture))
|
||||
|
||||
if "tapvid_davis_first" in args.eval_datasets:
|
||||
data_root = os.path.join(
|
||||
args.dataset_root, "/tapvid_davis/tapvid_davis.pkl"
|
||||
)
|
||||
data_root = os.path.join(args.dataset_root, "tapvid_davis/tapvid_davis.pkl")
|
||||
eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root)
|
||||
eval_dataloader_tapvid_davis = torch.utils.data.DataLoader(
|
||||
eval_dataset,
|
||||
@ -551,17 +534,15 @@ class Lite(LightningLite):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
signal.signal(signal.SIGUSR1, sig_handler)
|
||||
signal.signal(signal.SIGTERM, term_handler)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_name", default="cotracker", help="model name")
|
||||
parser.add_argument("--restore_ckpt", help="restore checkpoint")
|
||||
parser.add_argument("--ckpt_path", help="restore checkpoint")
|
||||
parser.add_argument("--restore_ckpt", help="path to restore a checkpoint")
|
||||
parser.add_argument("--ckpt_path", help="path to save checkpoints")
|
||||
parser.add_argument(
|
||||
"--batch_size", type=int, default=4, help="batch size used during training."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_workers", type=int, default=6, help="left right consistency loss"
|
||||
"--num_workers", type=int, default=6, help="number of dataloader workers"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -578,20 +559,34 @@ if __name__ == "__main__":
|
||||
"--evaluate_every_n_epoch",
|
||||
type=int,
|
||||
default=1,
|
||||
help="number of flow-field updates during validation forward pass",
|
||||
help="evaluate during training after every n epochs, after every epoch by default",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_every_n_epoch",
|
||||
type=int,
|
||||
default=1,
|
||||
help="number of flow-field updates during validation forward pass",
|
||||
help="save checkpoints during training after every n epochs, after every epoch by default",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validate_at_start", action="store_true", help="use mixed precision"
|
||||
"--validate_at_start",
|
||||
action="store_true",
|
||||
help="whether to run evaluation before training starts",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_freq",
|
||||
type=int,
|
||||
default=100,
|
||||
help="frequency of trajectory visualization during training",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--traj_per_sample",
|
||||
type=int,
|
||||
default=768,
|
||||
help="the number of trajectories to sample for training",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_root", type=str, help="path lo all the datasets (train and eval)"
|
||||
)
|
||||
parser.add_argument("--save_freq", type=int, default=100, help="save_freq")
|
||||
parser.add_argument("--traj_per_sample", type=int, default=768, help="save_freq")
|
||||
parser.add_argument("--dataset_root", type=str, help="path lo all the datasets")
|
||||
|
||||
parser.add_argument(
|
||||
"--train_iters",
|
||||
@ -605,49 +600,75 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--eval_datasets",
|
||||
nargs="+",
|
||||
default=["things", "badja", "fastcapture"],
|
||||
help="eval datasets.",
|
||||
default=["things", "badja"],
|
||||
help="what datasets to use for evaluation",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--remove_space_attn", action="store_true", help="use mixed precision"
|
||||
"--remove_space_attn",
|
||||
action="store_true",
|
||||
help="remove space attention from CoTracker",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dont_use_augs", action="store_true", help="use mixed precision"
|
||||
"--dont_use_augs",
|
||||
action="store_true",
|
||||
help="don't apply augmentations during training",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample_vis_1st_frame", action="store_true", help="use mixed precision"
|
||||
"--sample_vis_1st_frame",
|
||||
action="store_true",
|
||||
help="only sample trajectories with points visible on the first frame",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sliding_window_len", type=int, default=8, help="use mixed precision"
|
||||
"--sliding_window_len",
|
||||
type=int,
|
||||
default=8,
|
||||
help="length of the CoTracker sliding window",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--updateformer_hidden_size", type=int, default=384, help="use mixed precision"
|
||||
"--updateformer_hidden_size",
|
||||
type=int,
|
||||
default=384,
|
||||
help="hidden dimension of the CoTracker transformer model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--updateformer_num_heads", type=int, default=8, help="use mixed precision"
|
||||
"--updateformer_num_heads",
|
||||
type=int,
|
||||
default=8,
|
||||
help="number of heads of the CoTracker transformer model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--updateformer_space_depth", type=int, default=12, help="use mixed precision"
|
||||
"--updateformer_space_depth",
|
||||
type=int,
|
||||
default=12,
|
||||
help="number of group attention layers in the CoTracker transformer model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--updateformer_time_depth", type=int, default=12, help="use mixed precision"
|
||||
"--updateformer_time_depth",
|
||||
type=int,
|
||||
default=12,
|
||||
help="number of time attention layers in the CoTracker transformer model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_stride", type=int, default=8, help="use mixed precision"
|
||||
"--model_stride",
|
||||
type=int,
|
||||
default=8,
|
||||
help="stride of the CoTracker feature network",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crop_size",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[384, 512],
|
||||
help="use mixed precision",
|
||||
help="crop videos to this resolution during training",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_max_seq_len", type=int, default=1000, help="use mixed precision"
|
||||
"--eval_max_seq_len",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="maximum length of evaluation videos",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
|
||||
@ -661,5 +682,5 @@ if __name__ == "__main__":
|
||||
devices="auto",
|
||||
accelerator="gpu",
|
||||
precision=32,
|
||||
num_nodes=4,
|
||||
# num_nodes=4,
|
||||
).run(args)
|
||||
|
Loading…
Reference in New Issue
Block a user