# xautodl/lib/datasets/LandmarkDataset.py
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from os import path as osp
from copy import deepcopy as copy
from tqdm import tqdm
import warnings, time, random, numpy as np
from pts_utils import generate_label_map
from xvision import denormalize_points
from xvision import identity2affine, solve2theta, affine2image
from .dataset_utils import pil_loader
from .landmark_utils import PointMeta2V
# apply_affine2point / apply_boundary are used in __process_affine below; they are
# assumed to live next to PointMeta2V in the same point-meta utilities module.
from .landmark_utils import apply_affine2point, apply_boundary
from .augmentation_utils import CutOut
import torch
import torch.utils.data as data
class LandmarkDataset(data.Dataset):
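    """A dataset for landmark detection (e.g., facial landmarks).

    Each sample is an image path plus a point-meta label; __getitem__ applies
    `self.transform` to obtain an affine `theta`, warps the image accordingly,
    and returns the warped image together with heatmap labels, a visibility
    mask, the normalized transformed points, and the affine parameters.
    """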
def __init__(
self,
transform,
sigma,
downsample,
heatmap_type,
shape,
use_gray,
mean_file,
data_indicator,
cache_images=None,
):
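        """
        Args:
          transform      : joint transform applied to (image, label); it must return (image, label, theta).
          sigma          : Gaussian sigma used when generating the heatmap labels.
          downsample     : stride between the input image and the output heatmaps.
          heatmap_type   : the heatmap type passed to `generate_label_map`.
          shape          : the target image shape as [H, W].
          use_gray       : whether images are loaded as gray-scale.
          mean_file      : an optional torch-saved file with the mean landmark positions per box id.
          data_indicator : a name that identifies the dataset.
          cache_images   : an optional dict mapping image paths to pre-loaded tensors.
        """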
self.transform = transform
self.sigma = sigma
self.downsample = downsample
self.heatmap_type = heatmap_type
self.dataset_name = data_indicator
self.shape = shape # [H,W]
self.use_gray = use_gray
assert transform is not None, "transform : {:}".format(transform)
self.mean_file = mean_file
if mean_file is None:
self.mean_data = None
warnings.warn("LandmarkDataset initialized with mean_data = None")
else:
assert osp.isfile(mean_file), "{:} is not a file.".format(mean_file)
self.mean_data = torch.load(mean_file)
self.reset()
self.cutout = None
self.cache_images = cache_images
print("The general dataset initialization done : {:}".format(self))
warnings.simplefilter("once")
def __repr__(self):
return "{name}(point-num={NUM_PTS}, shape={shape}, sigma={sigma}, heatmap_type={heatmap_type}, length={length}, cutout={cutout}, dataset={dataset_name}, mean={mean_file})".format(
name=self.__class__.__name__, **self.__dict__
)
def set_cutout(self, length):
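        """Enable CutOut augmentation with the given patch length, or disable it when `length` is None or smaller than 1."""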
if length is not None and length >= 1:
self.cutout = CutOut(int(length))
else:
self.cutout = None
def reset(self, num_pts=-1, boxid="default", only_pts=False):
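        """Set the number of points; unless `only_pts` is True, also clear all loaded samples, set the box id, and rebuild the mean face from the loaded mean data (if any)."""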
self.NUM_PTS = num_pts
if only_pts:
return
self.length = 0
self.datas = []
self.labels = []
self.NormDistances = []
self.BOXID = boxid
if self.mean_data is None:
self.mean_face = None
else:
self.mean_face = torch.Tensor(self.mean_data[boxid].copy().T)
assert (self.mean_face >= -1).all() and (
self.mean_face <= 1
).all(), "mean-{:}-face : {:}".format(boxid, self.mean_face)
# assert self.dataset_name is not None, 'The dataset name is None'
def __len__(self):
assert len(self.datas) == self.length, "The length is not correct : {}".format(
self.length
)
return self.length
def append(self, data, label, distance):
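        """Append one sample: an image path, its point-meta label, and its normalization distance."""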
assert osp.isfile(data), "The image path is not a file : {:}".format(data)
self.datas.append(data)
self.labels.append(label)
self.NormDistances.append(distance)
self.length = self.length + 1
def load_list(self, file_lists, num_pts, boxindicator, normalizeL, reset):
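        """Load one or more torch-saved annotation lists and append their samples.

        Each annotation is a dict with a 'current_frame' image path, a 'points'
        array of shape (3, num_pts), one or more 'box-*' entries, and (optionally)
        'normalizeL-*' distances; every entry is converted into a PointMeta2V
        label before being appended.
        """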
if reset:
self.reset(num_pts, boxindicator)
else:
assert (
self.NUM_PTS == num_pts and self.BOXID == boxindicator
), "The number of point is inconsistance : {:} vs {:}".format(
self.NUM_PTS, num_pts
)
if isinstance(file_lists, str):
file_lists = [file_lists]
samples = []
for idx, file_path in enumerate(file_lists):
print(
":::: load list {:}/{:} : {:}".format(idx, len(file_lists), file_path)
)
xdata = torch.load(file_path)
if isinstance(xdata, list):
data = xdata # image or video dataset list
elif isinstance(xdata, dict):
data = xdata["datas"] # multi-view dataset list
else:
raise ValueError("Invalid Type Error : {:}".format(type(xdata)))
samples = samples + data
        # samples is a list of annotations; each annotation is a dict that contains
        # 'points' with shape (3, num_pts) and one or more box entries
print("GeneralDataset-V2 : {:} samples".format(len(samples)))
# for index, annotation in enumerate(samples):
for index in tqdm(range(len(samples))):
annotation = samples[index]
image_path = annotation["current_frame"]
points, box = (
annotation["points"],
annotation["box-{:}".format(boxindicator)],
)
label = PointMeta2V(
self.NUM_PTS, points, box, image_path, self.dataset_name
)
if normalizeL is None:
normDistance = None
else:
normDistance = annotation["normalizeL-{:}".format(normalizeL)]
self.append(image_path, label, normDistance)
        assert (
            len(self.datas) == self.length
        ), "The length and the datas are not consistent : {:} vs {:}".format(
            self.length, len(self.datas)
        )
        assert (
            len(self.labels) == self.length
        ), "The length and the labels are not consistent : {:} vs {:}".format(
            self.length, len(self.labels)
        )
        assert (
            len(self.NormDistances) == self.length
        ), "The length and the NormDistances are not consistent : {:} vs {:}".format(
            self.length, len(self.NormDistances)
        )
print(
"Load data done for LandmarkDataset, which has {:} images.".format(
self.length
)
)
def __getitem__(self, index):
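        """Fetch the image at `index` (from the cache when available) and its label, then process them into training tensors."""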
assert index >= 0 and index < self.length, "Invalid index : {:}".format(index)
if self.cache_images is not None and self.datas[index] in self.cache_images:
image = self.cache_images[self.datas[index]].clone()
else:
image = pil_loader(self.datas[index], self.use_gray)
target = self.labels[index].copy()
return self._process_(image, target, index)
def _process_(self, image, target, index):
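        """Apply `self.transform` and then warp/label the image; `theta` may be a single affine or a list/tuple of affines (the outputs are stacked in the latter case)."""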
# transform the image and points
image, target, theta = self.transform(image, target)
(C, H, W), (height, width) = image.size(), self.shape
        # obtain the visible indicator vector
        nopoints = target.is_none()
if index == -1:
__path = None
else:
__path = self.datas[index]
        if isinstance(theta, (list, tuple)):
affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = (
[],
[],
[],
[],
[],
[],
)
for _theta in theta:
(
_affineImage,
_heatmaps,
_mask,
_norm_trans_points,
_theta,
_transpose_theta,
) = self.__process_affine(
image, target, _theta, nopoints, "P[{:}]@{:}".format(index, __path)
)
affineImage.append(_affineImage)
heatmaps.append(_heatmaps)
mask.append(_mask)
norm_trans_points.append(_norm_trans_points)
THETA.append(_theta)
transpose_theta.append(_transpose_theta)
affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = (
torch.stack(affineImage),
torch.stack(heatmaps),
torch.stack(mask),
torch.stack(norm_trans_points),
torch.stack(THETA),
torch.stack(transpose_theta),
)
else:
(
affineImage,
heatmaps,
mask,
norm_trans_points,
THETA,
transpose_theta,
) = self.__process_affine(
image, target, theta, nopoints, "S[{:}]@{:}".format(index, __path)
)
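        # auxiliary tensors: the sample index, the no-points flag, and the image size before the affine warp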
torch_index = torch.IntTensor([index])
torch_nopoints = torch.ByteTensor([nopoints])
torch_shape = torch.IntTensor([H, W])
return (
affineImage,
heatmaps,
mask,
norm_trans_points,
THETA,
transpose_theta,
torch_index,
torch_nopoints,
torch_shape,
)
def __process_affine(self, image, target, theta, nopoints, aux_info=None):
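        """Warp one image with `theta`, generate the heatmap labels and mask, and solve the affine (`transpose_theta`) that maps the transformed points onto the mean face (identity when unavailable)."""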
image, target, theta = image.clone(), target.copy(), theta.clone()
(C, H, W), (height, width) = image.size(), self.shape
if nopoints: # do not have label
norm_trans_points = torch.zeros((3, self.NUM_PTS))
heatmaps = torch.zeros(
(self.NUM_PTS + 1, height // self.downsample, width // self.downsample)
)
mask = torch.ones((self.NUM_PTS + 1, 1, 1), dtype=torch.uint8)
transpose_theta = identity2affine(False)
else:
norm_trans_points = apply_affine2point(target.get_points(), theta, (H, W))
norm_trans_points = apply_boundary(norm_trans_points)
real_trans_points = norm_trans_points.clone()
real_trans_points[:2, :] = denormalize_points(
self.shape, real_trans_points[:2, :]
)
heatmaps, mask = generate_label_map(
real_trans_points.numpy(),
height // self.downsample,
width // self.downsample,
self.sigma,
self.downsample,
nopoints,
self.heatmap_type,
) # H*W*C
heatmaps = torch.from_numpy(heatmaps.transpose((2, 0, 1))).type(
torch.FloatTensor
)
mask = torch.from_numpy(mask.transpose((2, 0, 1))).type(torch.ByteTensor)
if self.mean_face is None:
# warnings.warn('In LandmarkDataset use identity2affine for transpose_theta because self.mean_face is None.')
transpose_theta = identity2affine(False)
else:
if torch.sum(norm_trans_points[2, :] == 1) < 3:
warnings.warn(
"In LandmarkDataset after transformation, no visiable point, using identity instead. Aux: {:}".format(
aux_info
)
)
transpose_theta = identity2affine(False)
else:
transpose_theta = solve2theta(
norm_trans_points, self.mean_face.clone()
)
affineImage = affine2image(image, theta, self.shape)
if self.cutout is not None:
affineImage = self.cutout(affineImage)
return affineImage, heatmaps, mask, norm_trans_points, theta, transpose_theta
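
# A minimal usage sketch (the list path, point number, and hyper-parameters below are
# hypothetical; the real training scripts build the joint transform -- which must return
# (image, label, theta) -- from their own configuration):
#
#   dataset = LandmarkDataset(transform, sigma=4, downsample=8, heatmap_type="gaussian",
#                             shape=(256, 256), use_gray=False, mean_file=None,
#                             data_indicator="demo")
#   dataset.load_list("./lists/train.pth", 68, "default", None, True)
#   (affine_image, heatmaps, mask, norm_points, theta, transpose_theta,
#    index, nopoints, shape) = dataset[0]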