34 lines
1003 B
Python
34 lines
1003 B
Python
import os
|
|
import logging
|
|
import torch
|
|
import numpy as np
|
|
import random
|
|
|
|
def restore_checkpoint(ckpt_dir, state, device, resume=False):
    """Restore training state from a checkpoint file, or pass it through.

    The parent directory of ``ckpt_dir`` is created if it does not exist, so
    a later save to the same path succeeds.

    Args:
        ckpt_dir: Path to the checkpoint *file* (despite the name). A bare
            filename with no directory component is allowed.
        state: Mutable mapping of training state. Values under the keys
            ``'optimizer'``, ``'model'`` and ``'ema'`` are updated via their
            ``load_state_dict`` method; every other key is overwritten with
            the value stored in the checkpoint.
        device: Forwarded to ``torch.load`` as ``map_location``.
        resume: When False, nothing is loaded and ``state`` is returned
            unchanged.

    Returns:
        The same ``state`` object, updated in place if a checkpoint was
        loaded.

    Raises:
        KeyError: If the checkpoint lacks a key that is present in ``state``.
    """
    # os.path.dirname("ckpt.pth") == "" and os.makedirs("") raises, so only
    # create a directory when the path actually has one. exist_ok avoids the
    # check-then-create race; when the checkpoint exists its parent does too,
    # so this is a no-op in the loading path.
    parent_dir = os.path.dirname(ckpt_dir)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    if not resume:
        return state

    if not os.path.exists(ckpt_dir):
        logging.warning(f"No checkpoint found at {ckpt_dir}. "
                        f"Returned the same state as input")
        return state

    loaded_state = torch.load(ckpt_dir, map_location=device)
    # Objects under these keys are restored in place; the rest are replaced.
    stateful_keys = ('optimizer', 'model', 'ema')
    for k in state:
        if k in stateful_keys:
            state[k].load_state_dict(loaded_state[k])
        else:
            state[k] = loaded_state[k]
    return state
|
|
|
|
|
|
def reset_seed(seed):
    """Seed the torch, NumPy and Python RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to ``torch`` (CPU and all CUDA devices),
            NumPy's global RNG and Python's ``random`` module. NumPy requires
            ``0 <= seed < 2**32``.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # deterministic=True restricts cuDNN to deterministic algorithms, but if
    # benchmark mode was enabled elsewhere cuDNN still auto-tunes which
    # algorithm to use, which can differ between runs. Disable it explicitly.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
|