220 lines
6.8 KiB
Python
220 lines
6.8 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
#
|
|
import copy, math, torch, numpy as np
|
|
from xvision import normalize_points
|
|
from xvision import denormalize_points
|
|
|
|
|
|
class PointMeta:
|
|
# points : 3 x num_pts (x, y, oculusion)
|
|
# image_size: original [width, height]
|
|
def __init__(self, num_point, points, box, image_path, dataset_name):
|
|
|
|
self.num_point = num_point
|
|
if box is not None:
|
|
assert (isinstance(box, tuple) or isinstance(box, list)) and len(box) == 4
|
|
self.box = torch.Tensor(box)
|
|
else:
|
|
self.box = None
|
|
if points is None:
|
|
self.points = points
|
|
else:
|
|
assert (
|
|
len(points.shape) == 2
|
|
and points.shape[0] == 3
|
|
and points.shape[1] == self.num_point
|
|
), "The shape of point is not right : {}".format(points)
|
|
self.points = torch.Tensor(points.copy())
|
|
self.image_path = image_path
|
|
self.datasets = dataset_name
|
|
|
|
def __repr__(self):
|
|
if self.box is None:
|
|
boxstr = "None"
|
|
else:
|
|
boxstr = "box=[{:.1f}, {:.1f}, {:.1f}, {:.1f}]".format(*self.box.tolist())
|
|
return (
|
|
"{name}(points={num_point}, ".format(
|
|
name=self.__class__.__name__, **self.__dict__
|
|
)
|
|
+ boxstr
|
|
+ ")"
|
|
)
|
|
|
|
def get_box(self, return_diagonal=False):
|
|
if self.box is None:
|
|
return None
|
|
if not return_diagonal:
|
|
return self.box.clone()
|
|
else:
|
|
W = (self.box[2] - self.box[0]).item()
|
|
H = (self.box[3] - self.box[1]).item()
|
|
return math.sqrt(H * H + W * W)
|
|
|
|
def get_points(self, ignore_indicator=False):
|
|
if ignore_indicator:
|
|
last = 2
|
|
else:
|
|
last = 3
|
|
if self.points is not None:
|
|
return self.points.clone()[:last, :]
|
|
else:
|
|
return torch.zeros((last, self.num_point))
|
|
|
|
def is_none(self):
|
|
# assert self.box is not None, 'The box should not be None'
|
|
return self.points is None
|
|
# if self.box is None: return True
|
|
# else : return self.points is None
|
|
|
|
def copy(self):
|
|
return copy.deepcopy(self)
|
|
|
|
def visiable_pts_num(self):
|
|
with torch.no_grad():
|
|
ans = self.points[2, :] > 0
|
|
ans = torch.sum(ans)
|
|
ans = ans.item()
|
|
return ans
|
|
|
|
def special_fun(self, indicator):
|
|
if (
|
|
indicator == "68to49"
|
|
): # For 300W or 300VW, convert the default 68 points to 49 points.
|
|
assert self.num_point == 68, "num-point must be 68 vs. {:}".format(
|
|
self.num_point
|
|
)
|
|
self.num_point = 49
|
|
out = torch.ones((68), dtype=torch.uint8)
|
|
out[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 60, 64]] = 0
|
|
if self.points is not None:
|
|
self.points = self.points.clone()[:, out]
|
|
else:
|
|
raise ValueError("Invalid indicator : {:}".format(indicator))
|
|
|
|
def apply_horizontal_flip(self):
|
|
# self.points[0, :] = width - self.points[0, :] - 1
|
|
# Mugsy spefic or Synthetic
|
|
if self.datasets.startswith("HandsyROT"):
|
|
ori = np.array(list(range(0, 42)))
|
|
pos = np.array(list(range(21, 42)) + list(range(0, 21)))
|
|
self.points[:, pos] = self.points[:, ori]
|
|
elif self.datasets.startswith("face68"):
|
|
ori = np.array(list(range(0, 68)))
|
|
pos = (
|
|
np.array(
|
|
[
|
|
17,
|
|
16,
|
|
15,
|
|
14,
|
|
13,
|
|
12,
|
|
11,
|
|
10,
|
|
9,
|
|
8,
|
|
7,
|
|
6,
|
|
5,
|
|
4,
|
|
3,
|
|
2,
|
|
1,
|
|
27,
|
|
26,
|
|
25,
|
|
24,
|
|
23,
|
|
22,
|
|
21,
|
|
20,
|
|
19,
|
|
18,
|
|
28,
|
|
29,
|
|
30,
|
|
31,
|
|
36,
|
|
35,
|
|
34,
|
|
33,
|
|
32,
|
|
46,
|
|
45,
|
|
44,
|
|
43,
|
|
48,
|
|
47,
|
|
40,
|
|
39,
|
|
38,
|
|
37,
|
|
42,
|
|
41,
|
|
55,
|
|
54,
|
|
53,
|
|
52,
|
|
51,
|
|
50,
|
|
49,
|
|
60,
|
|
59,
|
|
58,
|
|
57,
|
|
56,
|
|
65,
|
|
64,
|
|
63,
|
|
62,
|
|
61,
|
|
68,
|
|
67,
|
|
66,
|
|
]
|
|
)
|
|
- 1
|
|
)
|
|
self.points[:, ori] = self.points[:, pos]
|
|
else:
|
|
raise ValueError("Does not support {:}".format(self.datasets))
|
|
|
|
|
|
# shape = (H,W)
|
|
def apply_affine2point(points, theta, shape):
|
|
assert points.size(0) == 3, "invalid points shape : {:}".format(points.size())
|
|
with torch.no_grad():
|
|
ok_points = points[2, :] == 1
|
|
assert torch.sum(ok_points).item() > 0, "there is no visiable point"
|
|
points[:2, :] = normalize_points(shape, points[:2, :])
|
|
|
|
norm_trans_points = ok_points.unsqueeze(0).repeat(3, 1).float()
|
|
|
|
trans_points, ___ = torch.gesv(points[:, ok_points], theta)
|
|
|
|
norm_trans_points[:, ok_points] = trans_points
|
|
|
|
return norm_trans_points
|
|
|
|
|
|
def apply_boundary(norm_trans_points):
|
|
with torch.no_grad():
|
|
norm_trans_points = norm_trans_points.clone()
|
|
oks = torch.stack(
|
|
(
|
|
norm_trans_points[0] > -1,
|
|
norm_trans_points[0] < 1,
|
|
norm_trans_points[1] > -1,
|
|
norm_trans_points[1] < 1,
|
|
norm_trans_points[2] > 0,
|
|
)
|
|
)
|
|
oks = torch.sum(oks, dim=0) == 5
|
|
norm_trans_points[2, :] = oks
|
|
return norm_trans_points
|