Move to xautodl
This commit is contained in:
		
							
								
								
									
										1
									
								
								xautodl/datasets/landmark_utils/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								xautodl/datasets/landmark_utils/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | ||||
| from .point_meta import PointMeta2V, apply_affine2point, apply_boundary | ||||
							
								
								
									
										219
									
								
								xautodl/datasets/landmark_utils/point_meta.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										219
									
								
								xautodl/datasets/landmark_utils/point_meta.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,219 @@ | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
| # | ||||
| import copy, math, torch, numpy as np | ||||
| from xvision import normalize_points | ||||
| from xvision import denormalize_points | ||||
|  | ||||
|  | ||||
| class PointMeta: | ||||
|     # points    : 3 x num_pts (x, y, oculusion) | ||||
|     # image_size: original [width, height] | ||||
|     def __init__(self, num_point, points, box, image_path, dataset_name): | ||||
|  | ||||
|         self.num_point = num_point | ||||
|         if box is not None: | ||||
|             assert (isinstance(box, tuple) or isinstance(box, list)) and len(box) == 4 | ||||
|             self.box = torch.Tensor(box) | ||||
|         else: | ||||
|             self.box = None | ||||
|         if points is None: | ||||
|             self.points = points | ||||
|         else: | ||||
|             assert ( | ||||
|                 len(points.shape) == 2 | ||||
|                 and points.shape[0] == 3 | ||||
|                 and points.shape[1] == self.num_point | ||||
|             ), "The shape of point is not right : {}".format(points) | ||||
|             self.points = torch.Tensor(points.copy()) | ||||
|         self.image_path = image_path | ||||
|         self.datasets = dataset_name | ||||
|  | ||||
|     def __repr__(self): | ||||
|         if self.box is None: | ||||
|             boxstr = "None" | ||||
|         else: | ||||
|             boxstr = "box=[{:.1f}, {:.1f}, {:.1f}, {:.1f}]".format(*self.box.tolist()) | ||||
|         return ( | ||||
|             "{name}(points={num_point}, ".format( | ||||
|                 name=self.__class__.__name__, **self.__dict__ | ||||
|             ) | ||||
|             + boxstr | ||||
|             + ")" | ||||
|         ) | ||||
|  | ||||
|     def get_box(self, return_diagonal=False): | ||||
|         if self.box is None: | ||||
|             return None | ||||
|         if not return_diagonal: | ||||
|             return self.box.clone() | ||||
|         else: | ||||
|             W = (self.box[2] - self.box[0]).item() | ||||
|             H = (self.box[3] - self.box[1]).item() | ||||
|             return math.sqrt(H * H + W * W) | ||||
|  | ||||
|     def get_points(self, ignore_indicator=False): | ||||
|         if ignore_indicator: | ||||
|             last = 2 | ||||
|         else: | ||||
|             last = 3 | ||||
|         if self.points is not None: | ||||
|             return self.points.clone()[:last, :] | ||||
|         else: | ||||
|             return torch.zeros((last, self.num_point)) | ||||
|  | ||||
|     def is_none(self): | ||||
|         # assert self.box is not None, 'The box should not be None' | ||||
|         return self.points is None | ||||
|         # if self.box is None: return True | ||||
|         # else               : return self.points is None | ||||
|  | ||||
|     def copy(self): | ||||
|         return copy.deepcopy(self) | ||||
|  | ||||
|     def visiable_pts_num(self): | ||||
|         with torch.no_grad(): | ||||
|             ans = self.points[2, :] > 0 | ||||
|             ans = torch.sum(ans) | ||||
|             ans = ans.item() | ||||
|         return ans | ||||
|  | ||||
|     def special_fun(self, indicator): | ||||
|         if ( | ||||
|             indicator == "68to49" | ||||
|         ):  # For 300W or 300VW, convert the default 68 points to 49 points. | ||||
|             assert self.num_point == 68, "num-point must be 68 vs. {:}".format( | ||||
|                 self.num_point | ||||
|             ) | ||||
|             self.num_point = 49 | ||||
|             out = torch.ones((68), dtype=torch.uint8) | ||||
|             out[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 60, 64]] = 0 | ||||
|             if self.points is not None: | ||||
|                 self.points = self.points.clone()[:, out] | ||||
|         else: | ||||
|             raise ValueError("Invalid indicator : {:}".format(indicator)) | ||||
|  | ||||
|     def apply_horizontal_flip(self): | ||||
|         # self.points[0, :] = width - self.points[0, :] - 1 | ||||
|         # Mugsy spefic or Synthetic | ||||
|         if self.datasets.startswith("HandsyROT"): | ||||
|             ori = np.array(list(range(0, 42))) | ||||
|             pos = np.array(list(range(21, 42)) + list(range(0, 21))) | ||||
|             self.points[:, pos] = self.points[:, ori] | ||||
|         elif self.datasets.startswith("face68"): | ||||
|             ori = np.array(list(range(0, 68))) | ||||
|             pos = ( | ||||
|                 np.array( | ||||
|                     [ | ||||
|                         17, | ||||
|                         16, | ||||
|                         15, | ||||
|                         14, | ||||
|                         13, | ||||
|                         12, | ||||
|                         11, | ||||
|                         10, | ||||
|                         9, | ||||
|                         8, | ||||
|                         7, | ||||
|                         6, | ||||
|                         5, | ||||
|                         4, | ||||
|                         3, | ||||
|                         2, | ||||
|                         1, | ||||
|                         27, | ||||
|                         26, | ||||
|                         25, | ||||
|                         24, | ||||
|                         23, | ||||
|                         22, | ||||
|                         21, | ||||
|                         20, | ||||
|                         19, | ||||
|                         18, | ||||
|                         28, | ||||
|                         29, | ||||
|                         30, | ||||
|                         31, | ||||
|                         36, | ||||
|                         35, | ||||
|                         34, | ||||
|                         33, | ||||
|                         32, | ||||
|                         46, | ||||
|                         45, | ||||
|                         44, | ||||
|                         43, | ||||
|                         48, | ||||
|                         47, | ||||
|                         40, | ||||
|                         39, | ||||
|                         38, | ||||
|                         37, | ||||
|                         42, | ||||
|                         41, | ||||
|                         55, | ||||
|                         54, | ||||
|                         53, | ||||
|                         52, | ||||
|                         51, | ||||
|                         50, | ||||
|                         49, | ||||
|                         60, | ||||
|                         59, | ||||
|                         58, | ||||
|                         57, | ||||
|                         56, | ||||
|                         65, | ||||
|                         64, | ||||
|                         63, | ||||
|                         62, | ||||
|                         61, | ||||
|                         68, | ||||
|                         67, | ||||
|                         66, | ||||
|                     ] | ||||
|                 ) | ||||
|                 - 1 | ||||
|             ) | ||||
|             self.points[:, ori] = self.points[:, pos] | ||||
|         else: | ||||
|             raise ValueError("Does not support {:}".format(self.datasets)) | ||||
|  | ||||
|  | ||||
| # shape = (H,W) | ||||
| def apply_affine2point(points, theta, shape): | ||||
|     assert points.size(0) == 3, "invalid points shape : {:}".format(points.size()) | ||||
|     with torch.no_grad(): | ||||
|         ok_points = points[2, :] == 1 | ||||
|         assert torch.sum(ok_points).item() > 0, "there is no visiable point" | ||||
|         points[:2, :] = normalize_points(shape, points[:2, :]) | ||||
|  | ||||
|         norm_trans_points = ok_points.unsqueeze(0).repeat(3, 1).float() | ||||
|  | ||||
|         trans_points, ___ = torch.gesv(points[:, ok_points], theta) | ||||
|  | ||||
|         norm_trans_points[:, ok_points] = trans_points | ||||
|  | ||||
|     return norm_trans_points | ||||
|  | ||||
|  | ||||
| def apply_boundary(norm_trans_points): | ||||
|     with torch.no_grad(): | ||||
|         norm_trans_points = norm_trans_points.clone() | ||||
|         oks = torch.stack( | ||||
|             ( | ||||
|                 norm_trans_points[0] > -1, | ||||
|                 norm_trans_points[0] < 1, | ||||
|                 norm_trans_points[1] > -1, | ||||
|                 norm_trans_points[1] < 1, | ||||
|                 norm_trans_points[2] > 0, | ||||
|             ) | ||||
|         ) | ||||
|         oks = torch.sum(oks, dim=0) == 5 | ||||
|         norm_trans_points[2, :] = oks | ||||
|     return norm_trans_points | ||||
		Reference in New Issue
	
	Block a user