autodl-projects/lib/datasets/landmark_utils/point_meta.py
2021-04-22 19:12:21 +08:00

220 lines
6.8 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import copy, math, torch, numpy as np
from xvision import normalize_points
from xvision import denormalize_points
class PointMeta:
    """Landmark annotation of a single image: points, bounding box and origin.

    ``points`` is kept as a ``3 x num_point`` float tensor whose rows are
    ``x``, ``y`` and an occlusion/visibility indicator (values ``> 0`` are
    counted as visible by :meth:`visiable_pts_num`).  ``box`` is kept as a
    4-element tensor; :meth:`get_box` reads it as ``[x1, y1, x2, y2]``.
    Either of them may be ``None`` when the annotation is missing.
    """

    # Column indices removed by the "68to49" conversion in `special_fun`.
    _DROP_68TO49 = frozenset(list(range(17)) + [60, 64])

    # Horizontal-flip partner of every landmark in the 68-point layout,
    # written 1-based (entry i is the partner of point i + 1).  The mapping is
    # its own inverse: flipping twice restores the original order.
    _FACE68_MIRROR_1BASED = (
        17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1,
        27, 26, 25, 24, 23, 22, 21, 20, 19, 18,
        28, 29, 30, 31,
        36, 35, 34, 33, 32,
        46, 45, 44, 43, 48, 47,
        40, 39, 38, 37, 42, 41,
        55, 54, 53, 52, 51, 50, 49,
        60, 59, 58, 57, 56,
        65, 64, 63, 62, 61,
        68, 67, 66,
    )

    def __init__(self, num_point, points, box, image_path, dataset_name):
        """Store the annotation, copying ``points`` and ``box`` into tensors.

        Args:
            num_point: expected number of landmarks (columns of ``points``).
            points: ``None`` or an array of shape ``(3, num_point)`` that
                offers ``.shape`` and ``.copy()`` (e.g. ``numpy.ndarray``).
            box: ``None`` or a tuple/list of four numbers.
            image_path: path of the image the annotation belongs to.
            dataset_name: dataset identifier; its prefix selects the landmark
                layout used by :meth:`apply_horizontal_flip`.

        Raises:
            AssertionError: if ``box`` or ``points`` has the wrong type/shape.
        """
        self.num_point = num_point
        if box is None:
            self.box = None
        else:
            assert isinstance(box, (tuple, list)) and len(box) == 4
            self.box = torch.Tensor(box)
        if points is None:
            self.points = None
        else:
            # Report the offending shape, not the whole array contents.
            assert (
                len(points.shape) == 2
                and points.shape[0] == 3
                and points.shape[1] == self.num_point
            ), "The shape of point is not right : {}".format(points.shape)
            self.points = torch.Tensor(points.copy())
        self.image_path = image_path
        self.datasets = dataset_name

    def __repr__(self):
        """Return ``ClassName(points=<num_point>, <box summary or None>)``."""
        if self.box is None:
            boxstr = "None"
        else:
            boxstr = "box=[{:.1f}, {:.1f}, {:.1f}, {:.1f}]".format(*self.box.tolist())
        return "{}(points={}, ".format(type(self).__name__, self.num_point) + boxstr + ")"

    def get_box(self, return_diagonal=False):
        """Return a copy of the box, or the length of its diagonal.

        Returns:
            ``None`` when no box is stored; otherwise a cloned 4-element
            tensor, or (with ``return_diagonal=True``) the diagonal length as
            a Python float computed from ``box[2] - box[0]`` and
            ``box[3] - box[1]``.
        """
        if self.box is None:
            return None
        if not return_diagonal:
            return self.box.clone()
        width = (self.box[2] - self.box[0]).item()
        height = (self.box[3] - self.box[1]).item()
        return math.sqrt(height * height + width * width)

    def get_points(self, ignore_indicator=False):
        """Return a copy of the points, optionally without the indicator row.

        Returns:
            A ``3 x num_point`` tensor (``2 x num_point`` when
            ``ignore_indicator`` is true).  All zeros when no points are
            stored.
        """
        rows = 2 if ignore_indicator else 3
        if self.points is None:
            return torch.zeros((rows, self.num_point))
        return self.points.clone()[:rows, :]

    def is_none(self):
        """Return ``True`` when no points are stored (the box is not checked)."""
        return self.points is None

    def copy(self):
        """Return a deep copy of this annotation."""
        return copy.deepcopy(self)

    def visiable_pts_num(self):
        """Return how many points have an indicator ``> 0`` (0 without points)."""
        if self.points is None:
            # Consistent with get_points(), which reports all-zero indicators.
            return 0
        with torch.no_grad():
            return torch.sum(self.points[2, :] > 0).item()

    def special_fun(self, indicator):
        """Apply a named in-place conversion of the landmark layout.

        The only supported ``indicator`` is ``"68to49"`` (for 300W or 300VW:
        convert the default 68 points to 49 points), which drops columns
        0-16, 60 and 64 and sets ``num_point`` to 49.

        Raises:
            ValueError: for any other ``indicator``.
            AssertionError: if ``num_point`` is not 68.
        """
        if indicator != "68to49":
            raise ValueError("Invalid indicator : {:}".format(indicator))
        assert self.num_point == 68, "num-point must be 68 vs. {:}".format(
            self.num_point
        )
        self.num_point = 49
        if self.points is not None:
            # Integer-list indexing instead of a uint8 mask, which recent
            # PyTorch versions deprecate; the kept columns stay in order.
            keep = [idx for idx in range(68) if idx not in self._DROP_68TO49]
            self.points = self.points[:, keep]

    def apply_horizontal_flip(self):
        """Swap left/right landmark columns in place for a mirrored image.

        Only the column order changes; the coordinate values themselves are
        not mirrored here.  Supported layouts are chosen by the prefix of the
        dataset name: ``"HandsyROT"`` (columns 0-20 and 21-41 are exchanged)
        and ``"face68"`` (68-point table).  Does nothing when no points are
        stored.

        Raises:
            ValueError: if the dataset name matches no supported layout.
        """
        if self.datasets.startswith("HandsyROT"):
            mirror = list(range(21, 42)) + list(range(0, 21))
        elif self.datasets.startswith("face68"):
            mirror = [idx - 1 for idx in self._FACE68_MIRROR_1BASED]
        else:
            raise ValueError("Does not support {:}".format(self.datasets))
        if self.points is None:
            return
        targets = list(range(len(mirror)))
        # The right-hand side is a gathered copy, so the in-place scatter
        # cannot read columns that were already overwritten.
        self.points[:, targets] = self.points[:, mirror]
# shape = (H,W)
def apply_affine2point(points, theta, shape):
    """Transform the visible points by solving ``theta @ X = points``.

    Args:
        points: ``3 x N`` tensor.  Columns whose third row equals 1 are
            treated as visible; that row also acts as the homogeneous
            coordinate in the linear solve.  NOTE: rows 0-1 are overwritten
            in place with the output of ``normalize_points``.
        theta: square ``3 x 3`` matrix of the linear system.
        shape: image shape ``(H, W)`` forwarded to ``normalize_points``
            (presumably mapping pixel coordinates to [-1, 1] -- confirm
            against ``xvision.normalize_points``).

    Returns:
        ``3 x N`` float tensor: solved coordinates in the visible columns,
        zeros in all other columns.

    Raises:
        AssertionError: if ``points`` does not have 3 rows or has no visible
            column.
    """
    assert points.size(0) == 3, "invalid points shape : {:}".format(points.size())
    with torch.no_grad():
        ok_points = points[2, :] == 1
        assert torch.sum(ok_points).item() > 0, "there is no visiable point"
        points[:2, :] = normalize_points(shape, points[:2, :])
        # Starts as the 0/1 visibility mask in every row, so invisible
        # columns stay all-zero after the visible ones are filled in.
        norm_trans_points = ok_points.unsqueeze(0).repeat(3, 1).float()
        visible = points[:, ok_points]
        # torch.gesv(B, A) solved A @ X = B but has been removed from
        # PyTorch; pick whichever equivalent solver this build provides.
        linalg = getattr(torch, "linalg", None)
        if linalg is not None and hasattr(linalg, "solve"):
            trans_points = linalg.solve(theta, visible)
        elif hasattr(torch, "solve"):
            trans_points, _ = torch.solve(visible, theta)
        else:
            trans_points, _ = torch.gesv(visible, theta)
        norm_trans_points[:, ok_points] = trans_points
    return norm_trans_points
def apply_boundary(norm_trans_points):
    """Recompute the indicator row from a bounds check on normalized points.

    A column stays flagged (indicator 1) only when its x and y are both
    strictly between -1 and 1 and its current indicator is greater than 0;
    every other column gets indicator 0.

    Args:
        norm_trans_points: tensor with rows ``x``, ``y``, indicator.  It is
            not modified.

    Returns:
        A clone of the input with row 2 replaced by the new 0/1 indicator.
    """
    with torch.no_grad():
        result = norm_trans_points.clone()
        xs, ys, flags = result[0], result[1], result[2]
        # All five conditions must hold at once.
        inside = (xs > -1) & (xs < 1) & (ys > -1) & (ys < 1) & (flags > 0)
        result[2, :] = inside
    return result