Initial commit
This commit is contained in:
		
							
								
								
									
										1
									
								
								datasets/landmark_utils/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								datasets/landmark_utils/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | ||||
| from .point_meta import PointMeta2V, apply_affine2point, apply_boundary | ||||
							
								
								
									
										116
									
								
								datasets/landmark_utils/point_meta.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										116
									
								
								datasets/landmark_utils/point_meta.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,116 @@ | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
| # | ||||
| import copy, math, torch, numpy as np | ||||
| from xvision import normalize_points | ||||
| from xvision import denormalize_points | ||||
|  | ||||
|  | ||||
| class PointMeta(): | ||||
|   # points    : 3 x num_pts (x, y, oculusion) | ||||
|   # image_size: original [width, height] | ||||
|   def __init__(self, num_point, points, box, image_path, dataset_name): | ||||
|  | ||||
|     self.num_point = num_point | ||||
|     if box is not None: | ||||
|       assert (isinstance(box, tuple) or isinstance(box, list)) and len(box) == 4 | ||||
|       self.box = torch.Tensor(box) | ||||
|     else: self.box = None | ||||
|     if points is None: | ||||
|       self.points = points | ||||
|     else: | ||||
|       assert len(points.shape) == 2 and points.shape[0] == 3 and points.shape[1] == self.num_point, 'The shape of point is not right : {}'.format( points ) | ||||
|       self.points = torch.Tensor(points.copy()) | ||||
|     self.image_path = image_path | ||||
|     self.datasets = dataset_name | ||||
|  | ||||
|   def __repr__(self): | ||||
|     if self.box is None: boxstr = 'None' | ||||
|     else               : boxstr = 'box=[{:.1f}, {:.1f}, {:.1f}, {:.1f}]'.format(*self.box.tolist()) | ||||
|     return ('{name}(points={num_point}, '.format(name=self.__class__.__name__, **self.__dict__) + boxstr + ')') | ||||
|  | ||||
|   def get_box(self, return_diagonal=False): | ||||
|     if self.box is None: return None | ||||
|     if not return_diagonal: | ||||
|       return self.box.clone() | ||||
|     else: | ||||
|       W = (self.box[2]-self.box[0]).item() | ||||
|       H = (self.box[3]-self.box[1]).item() | ||||
|       return math.sqrt(H*H+W*W) | ||||
|  | ||||
|   def get_points(self, ignore_indicator=False): | ||||
|     if ignore_indicator: last = 2 | ||||
|     else               : last = 3 | ||||
|     if self.points is not None: return self.points.clone()[:last, :] | ||||
|     else                      : return torch.zeros((last, self.num_point)) | ||||
|  | ||||
|   def is_none(self): | ||||
|     #assert self.box is not None, 'The box should not be None' | ||||
|     return self.points is None | ||||
|     #if self.box is None: return True | ||||
|     #else               : return self.points is None | ||||
|  | ||||
|   def copy(self): | ||||
|     return copy.deepcopy(self) | ||||
|  | ||||
|   def visiable_pts_num(self): | ||||
|     with torch.no_grad(): | ||||
|       ans = self.points[2,:] > 0 | ||||
|       ans = torch.sum(ans) | ||||
|       ans = ans.item() | ||||
|     return ans | ||||
|    | ||||
|   def special_fun(self, indicator): | ||||
|     if indicator == '68to49': # For 300W or 300VW, convert the default 68 points to 49 points. | ||||
|       assert self.num_point == 68, 'num-point must be 68 vs. {:}'.format(self.num_point) | ||||
|       self.num_point = 49 | ||||
|       out = torch.ones((68), dtype=torch.uint8) | ||||
|       out[[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,60,64]] = 0 | ||||
|       if self.points is not None: self.points = self.points.clone()[:, out] | ||||
|     else: | ||||
|       raise ValueError('Invalid indicator : {:}'.format( indicator )) | ||||
|  | ||||
|   def apply_horizontal_flip(self): | ||||
|     #self.points[0, :] = width - self.points[0, :] - 1 | ||||
|     # Mugsy spefic or Synthetic | ||||
|     if self.datasets.startswith('HandsyROT'): | ||||
|       ori = np.array(list(range(0, 42))) | ||||
|       pos = np.array(list(range(21,42)) + list(range(0,21))) | ||||
|       self.points[:, pos] = self.points[:, ori] | ||||
|     elif self.datasets.startswith('face68'): | ||||
|       ori = np.array(list(range(0, 68))) | ||||
|       pos = np.array([17,16,15,14,13,12,11,10, 9, 8,7,6,5,4,3,2,1, 27,26,25,24,23,22,21,20,19,18, 28,29,30,31, 36,35,34,33,32, 46,45,44,43,48,47, 40,39,38,37,42,41, 55,54,53,52,51,50,49,60,59,58,57,56,65,64,63,62,61,68,67,66])-1 | ||||
|       self.points[:, ori] = self.points[:, pos] | ||||
|     else: | ||||
|       raise ValueError('Does not support {:}'.format(self.datasets)) | ||||
|  | ||||
|  | ||||
|  | ||||
| # shape = (H,W) | ||||
| def apply_affine2point(points, theta, shape): | ||||
|   assert points.size(0) == 3, 'invalid points shape : {:}'.format(points.size()) | ||||
|   with torch.no_grad(): | ||||
|     ok_points = points[2,:] == 1 | ||||
|     assert torch.sum(ok_points).item() > 0, 'there is no visiable point' | ||||
|     points[:2,:] = normalize_points(shape, points[:2,:]) | ||||
|  | ||||
|     norm_trans_points = ok_points.unsqueeze(0).repeat(3, 1).float() | ||||
|  | ||||
|     trans_points, ___ = torch.gesv(points[:, ok_points], theta) | ||||
|  | ||||
|     norm_trans_points[:, ok_points] = trans_points | ||||
|      | ||||
|   return norm_trans_points | ||||
|  | ||||
|  | ||||
|  | ||||
| def apply_boundary(norm_trans_points): | ||||
|   with torch.no_grad(): | ||||
|     norm_trans_points = norm_trans_points.clone() | ||||
|     oks = torch.stack((norm_trans_points[0]>-1, norm_trans_points[0]<1, norm_trans_points[1]>-1, norm_trans_points[1]<1, norm_trans_points[2]>0)) | ||||
|     oks = torch.sum(oks, dim=0) == 5 | ||||
|     norm_trans_points[2, :] = oks | ||||
|   return norm_trans_points | ||||
		Reference in New Issue
	
	Block a user