192 lines
8.9 KiB
Python
192 lines
8.9 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
|
|
from os import path as osp
|
|
from copy import deepcopy as copy
|
|
from tqdm import tqdm
|
|
import warnings, time, random, numpy as np
|
|
|
|
from pts_utils import generate_label_map
|
|
from xvision import denormalize_points
|
|
from xvision import identity2affine, solve2theta, affine2image
|
|
from .dataset_utils import pil_loader
|
|
from .landmark_utils import PointMeta2V
|
|
from .augmentation_utils import CutOut
|
|
import torch
|
|
import torch.utils.data as data
|
|
|
|
|
|
class LandmarkDataset(data.Dataset):
    """Dataset of images with landmark annotations.

    Samples are registered through :meth:`load_list` (or :meth:`append`).
    ``__getitem__`` loads an image, runs ``self.transform`` on it and turns the
    transformed landmarks into heatmaps, masks and affine parameters.

    Args:
        transform: callable invoked as ``transform(image, target)`` and
            expected to return ``(image, target, theta)``; must not be None.
        sigma: forwarded to ``generate_label_map``.
        downsample: integer factor between ``shape`` and the heatmap size.
        heatmap_type: forwarded to ``generate_label_map``.
        shape: output image size as ``[H, W]``.
        use_gray: forwarded to ``pil_loader``.
        mean_file: path of a ``torch.save``-d mapping indexed by box id, or
            None to disable the mean-face alignment.
        data_indicator: dataset name, stored as ``self.dataset_name``.
        cache_images: optional mapping ``image path -> tensor`` consulted
            before reading an image from disk.
    """

    def __init__(self, transform, sigma, downsample, heatmap_type, shape, use_gray, mean_file, data_indicator, cache_images=None):
        self.transform = transform
        self.sigma = sigma
        self.downsample = downsample
        self.heatmap_type = heatmap_type
        self.dataset_name = data_indicator
        self.shape = shape  # [H,W]
        self.use_gray = use_gray
        assert transform is not None, 'transform : {:}'.format(transform)
        self.mean_file = mean_file
        if mean_file is None:
            self.mean_data = None
            warnings.warn('LandmarkDataset initialized with mean_data = None')
        else:
            assert osp.isfile(mean_file), '{:} is not a file.'.format(mean_file)
            # NOTE(review): torch.load unpickles the file; only point this at
            # trusted mean-face files.
            self.mean_data = torch.load(mean_file)
        # reset() reads self.mean_data, so it has to run after the branch above.
        self.reset()
        self.cutout = None
        self.cache_images = cache_images
        print ('The general dataset initialization done : {:}'.format(self))
        # Process-wide side effect: each distinct warning is shown only once.
        warnings.simplefilter( 'once' )

    def __repr__(self):
        # Every placeholder is resolved from the instance attributes, so this
        # only works once __init__/reset have populated them.
        return ('{name}(point-num={NUM_PTS}, shape={shape}, sigma={sigma}, heatmap_type={heatmap_type}, length={length}, cutout={cutout}, dataset={dataset_name}, mean={mean_file})'.format(name=self.__class__.__name__, **self.__dict__))

    def set_cutout(self, length):
        """Enable CutOut with the given length, or disable it.

        Any ``length`` that is None or smaller than 1 disables the augmentation.
        """
        if length is not None and length >= 1:
            self.cutout = CutOut( int(length) )
        else:
            self.cutout = None

    def reset(self, num_pts=-1, boxid='default', only_pts=False):
        """Set the number of points and, unless ``only_pts``, drop all samples.

        Args:
            num_pts: new value of ``self.NUM_PTS``.
            boxid: box identifier; also the key used to read the mean face
                from ``self.mean_data``.
            only_pts: when True only ``self.NUM_PTS`` is updated.
        """
        self.NUM_PTS = num_pts
        if only_pts:
            return
        self.length = 0
        self.datas = []
        self.labels = []
        self.NormDistances = []
        self.BOXID = boxid
        if self.mean_data is None:
            self.mean_face = None
        else:
            # presumably mean_data[boxid] is a numpy array (it must support
            # .copy() and .T) -- TODO confirm against the file producer.
            self.mean_face = torch.Tensor(self.mean_data[boxid].copy().T)
            assert (self.mean_face >= -1).all() and (self.mean_face <= 1).all(), 'mean-{:}-face : {:}'.format(boxid, self.mean_face)

    def __len__(self):
        assert len(self.datas) == self.length, 'The length is not correct : {}'.format(self.length)
        return self.length

    def append(self, data, label, distance):
        """Register one sample.

        Args:
            data: path of the image; must be an existing file.
            label: annotation object stored alongside the image.
            distance: normalization distance, or None.
        """
        assert osp.isfile(data), 'The image path is not a file : {:}'.format(data)
        self.datas.append( data )
        self.labels.append( label )
        self.NormDistances.append( distance )
        self.length = self.length + 1

    def load_list(self, file_lists, num_pts, boxindicator, normalizeL, reset):
        """Load annotation list file(s) and register every sample they hold.

        Args:
            file_lists: one path or a list of paths readable by ``torch.load``.
                Each file holds either a list of annotations or a dict whose
                ``'datas'`` entry is such a list.
            num_pts: number of landmark points.
            boxindicator: suffix selecting the ``'box-<boxindicator>'`` entry.
            normalizeL: suffix selecting the ``'normalizeL-<normalizeL>'``
                entry, or None to store no normalization distance.
            reset: when True the dataset is cleared first; otherwise
                ``num_pts``/``boxindicator`` must match the current state.

        Raises:
            ValueError: if a loaded file is neither a list nor a dict.
        """
        if reset:
            self.reset(num_pts, boxindicator)
        else:
            assert self.NUM_PTS == num_pts and self.BOXID == boxindicator, 'The number of point is inconsistance : {:} vs {:}'.format(self.NUM_PTS, num_pts)
        if isinstance(file_lists, str):
            file_lists = [file_lists]
        samples = []
        for idx, file_path in enumerate(file_lists):
            print (':::: load list {:}/{:} : {:}'.format(idx, len(file_lists), file_path))
            # NOTE(review): torch.load unpickles the file; use trusted lists only.
            xdata = torch.load(file_path)
            # `entries` (not `data`) so the torch.utils.data alias is not shadowed.
            if isinstance(xdata, list):
                entries = xdata  # image or video dataset list
            elif isinstance(xdata, dict):
                entries = xdata['datas']  # multi-view dataset list
            else:
                raise ValueError('Invalid Type Error : {:}'.format( type(xdata) ))
            samples.extend(entries)
        # samples is a list of annotations; each annotation is indexed by
        # 'current_frame' (image path), 'points' and the various 'box-*' keys.
        print ('GeneralDataset-V2 : {:} samples'.format(len(samples)))

        box_key = 'box-{:}'.format(boxindicator)
        norm_key = None if normalizeL is None else 'normalizeL-{:}'.format(normalizeL)
        for annotation in tqdm(samples):
            image_path = annotation['current_frame']
            points, box = annotation['points'], annotation[box_key]
            label = PointMeta2V(self.NUM_PTS, points, box, image_path, self.dataset_name)
            normDistance = None if norm_key is None else annotation[norm_key]
            self.append(image_path, label, normDistance)

        assert len(self.datas) == self.length, 'The length and the data is not right {} vs {}'.format(self.length, len(self.datas))
        assert len(self.labels) == self.length, 'The length and the labels is not right {} vs {}'.format(self.length, len(self.labels))
        # Fixed: the message used the non-existent attribute `NormDistance`,
        # which turned a failed check into an AttributeError.
        assert len(self.NormDistances) == self.length, 'The length and the NormDistances is not right {} vs {}'.format(self.length, len(self.NormDistances))
        print ('Load data done for LandmarkDataset, which has {:} images.'.format(self.length))

    def __getitem__(self, index):
        assert index >= 0 and index < self.length, 'Invalid index : {:}'.format(index)
        image_path = self.datas[index]
        if self.cache_images is not None and image_path in self.cache_images:
            # clone so later processing cannot modify the cached tensor
            image = self.cache_images[image_path].clone()
        else:
            image = pil_loader(image_path, self.use_gray)
        target = self.labels[index].copy()
        return self._process_(image, target, index)

    def _process_(self, image, target, index):
        """Transform one (image, target) pair and build the training tensors.

        Args:
            image: loaded image, passed to ``self.transform``.
            target: landmark annotation, passed to ``self.transform``.
            index: sample index; ``-1`` means there is no backing path.

        Returns:
            Tuple ``(affineImage, heatmaps, mask, norm_trans_points, THETA,
            transpose_theta, torch_index, torch_nopoints, torch_shape)``.
            When the transform yields a list/tuple of thetas, the first six
            entries are stacked along a new leading dimension.
        """
        # transform the image and points
        image, target, theta = self.transform(image, target)
        _, H, W = image.size()

        # whether the sample carries no landmark annotation at all
        nopoints = bool(target.is_none())
        image_path = None if index == -1 else self.datas[index]

        if isinstance(theta, (list, tuple)):
            aux_info = 'P[{:}]@{:}'.format(index, image_path)
            outputs = [self.__process_affine(image, target, _theta, nopoints, aux_info) for _theta in theta]
            affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = (
                torch.stack([output[k] for output in outputs]) for k in range(6)
            )
        else:
            affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = self.__process_affine(image, target, theta, nopoints, 'S[{:}]@{:}'.format(index, image_path))

        torch_index = torch.IntTensor([index])
        torch_nopoints = torch.ByteTensor( [ nopoints ] )
        torch_shape = torch.IntTensor([H,W])

        return affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta, torch_index, torch_nopoints, torch_shape

    def __process_affine(self, image, target, theta, nopoints, aux_info=None):
        """Apply one affine ``theta`` and build the matching labels.

        Inputs are cloned/copied first so the caller can reuse them for other
        thetas.

        Returns:
            Tuple ``(affineImage, heatmaps, mask, norm_trans_points, theta,
            transpose_theta)``.
        """
        image, target, theta = image.clone(), target.copy(), theta.clone()
        (C, H, W), (height, width) = image.size(), self.shape
        if nopoints:  # do not have label
            norm_trans_points = torch.zeros((3, self.NUM_PTS))
            heatmaps = torch.zeros((self.NUM_PTS+1, height//self.downsample, width//self.downsample))
            mask = torch.ones((self.NUM_PTS+1, 1, 1), dtype=torch.uint8)
            transpose_theta = identity2affine(False)
        else:
            # NOTE(review): apply_affine2point and apply_boundary are not among
            # the imports visible at the top of this file -- confirm where they
            # are defined and import them, otherwise this branch raises NameError.
            norm_trans_points = apply_affine2point(target.get_points(), theta, (H,W))
            norm_trans_points = apply_boundary(norm_trans_points)
            real_trans_points = norm_trans_points.clone()
            real_trans_points[:2, :] = denormalize_points(self.shape, real_trans_points[:2,:])
            heatmaps, mask = generate_label_map(real_trans_points.numpy(), height//self.downsample, width//self.downsample, self.sigma, self.downsample, nopoints, self.heatmap_type)  # H*W*C
            # move the channel axis to the front: H*W*C -> C*H*W
            heatmaps = torch.from_numpy(heatmaps.transpose((2, 0, 1))).type(torch.FloatTensor)
            mask = torch.from_numpy(mask.transpose((2, 0, 1))).type(torch.ByteTensor)
            if self.mean_face is None:
                transpose_theta = identity2affine(False)
            elif torch.sum(norm_trans_points[2,:] == 1) < 3:
                # fewer than three entries of row 2 equal 1 (presumably the
                # visibility flags), so fall back to the identity transform
                warnings.warn('In LandmarkDataset after transformation, no visiable point, using identity instead. Aux: {:}'.format(aux_info))
                transpose_theta = identity2affine(False)
            else:
                transpose_theta = solve2theta(norm_trans_points, self.mean_face.clone())

        affineImage = affine2image(image, theta, self.shape)
        if self.cutout is not None:
            affineImage = self.cutout( affineImage )

        return affineImage, heatmaps, mask, norm_trans_points, theta, transpose_theta
|