# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import copy, math, torch, numpy as np
from xvision import normalize_points
from xvision import denormalize_points


class PointMeta:
    """Landmark annotation (points plus optional box) attached to one image.

    ``points`` is stored as a ``3 x num_point`` float tensor whose rows are
    (x, y, indicator); an indicator greater than zero is counted as visible by
    ``visiable_pts_num``.  ``points`` may be ``None`` for a sample without
    annotation.  ``box`` is either ``None`` or a 4-element tensor that
    ``get_box`` reads as ``[x1, y1, x2, y2]``.
    """

    # Columns kept by the "68to49" conversion: everything except 0-16, 60, 64.
    _KEEP_68TO49 = tuple(range(17, 60)) + (61, 62, 63, 65, 66, 67)

    # Flip permutation for 42-point data: the two 21-point halves swap places.
    _HAND42_FLIP = tuple(range(21, 42)) + tuple(range(0, 21))

    # Flip permutation for 68-point data (0-based): column ``i`` of the
    # flipped annotation is taken from column ``_FACE68_FLIP[i]``.
    _FACE68_FLIP = (
        tuple(range(16, -1, -1))
        + tuple(range(26, 16, -1))
        + (27, 28, 29, 30)
        + tuple(range(35, 30, -1))
        + (45, 44, 43, 42, 47, 46)
        + (39, 38, 37, 36, 41, 40)
        + tuple(range(54, 47, -1))
        + tuple(range(59, 54, -1))
        + tuple(range(64, 59, -1))
        + (67, 66, 65)
    )

    def __init__(self, num_point, points, box, image_path, dataset_name):
        """Store the annotation, copying ``points`` and ``box`` into tensors.

        Args:
            num_point: expected number of landmarks (columns of ``points``).
            points: ``None`` or an array of shape ``3 x num_point`` exposing
                ``.shape`` and ``.copy()`` (e.g. a NumPy array).
            box: ``None`` or a tuple/list of four numbers.
            image_path: stored as-is.
            dataset_name: stored as ``self.datasets``; its prefix selects the
                permutation used by ``apply_horizontal_flip``.

        Raises:
            AssertionError: if ``box`` or ``points`` has an unexpected layout.
        """
        self.num_point = num_point
        if box is None:
            self.box = None
        else:
            assert isinstance(box, (tuple, list)) and len(box) == 4
            self.box = torch.Tensor(box)
        if points is None:
            self.points = None
        else:
            assert (
                len(points.shape) == 2
                and points.shape[0] == 3
                and points.shape[1] == self.num_point
            ), "The shape of point is not right : {}".format(points)
            # Copy first so later in-place edits never touch the caller's array.
            self.points = torch.Tensor(points.copy())
        self.image_path = image_path
        self.datasets = dataset_name

    def __repr__(self):
        if self.box is None:
            boxstr = "box=None"
        else:
            boxstr = "box=[{:.1f}, {:.1f}, {:.1f}, {:.1f}]".format(*self.box.tolist())
        return "{name}(points={num_point}, {box})".format(
            name=self.__class__.__name__, num_point=self.num_point, box=boxstr
        )

    def get_box(self, return_diagonal=False):
        """Return a copy of the box, or its diagonal length.

        Returns ``None`` when no box is stored.  With ``return_diagonal=True``
        the result is a Python float: ``sqrt(W*W + H*H)`` where
        ``W = box[2] - box[0]`` and ``H = box[3] - box[1]``.
        """
        if self.box is None:
            return None
        if not return_diagonal:
            return self.box.clone()
        width = (self.box[2] - self.box[0]).item()
        height = (self.box[3] - self.box[1]).item()
        return math.sqrt(height * height + width * width)

    def get_points(self, ignore_indicator=False):
        """Return a copy of the points tensor.

        With ``ignore_indicator=True`` only the first two rows (x, y) are
        returned, otherwise all three.  When no points are stored, a zero
        tensor of the matching shape is returned instead.
        """
        last = 2 if ignore_indicator else 3
        if self.points is None:
            return torch.zeros((last, self.num_point))
        return self.points.clone()[:last, :]

    def is_none(self):
        """Return True when this sample carries no point annotation."""
        return self.points is None

    def copy(self):
        """Return a deep copy of this object."""
        return copy.deepcopy(self)

    def visiable_pts_num(self):
        """Count the points whose indicator row is greater than zero.

        Returns 0 when no points are stored.
        """
        if self.points is None:
            return 0
        with torch.no_grad():
            return torch.sum(self.points[2, :] > 0).item()

    def special_fun(self, indicator):
        """Apply a named in-place conversion to the annotation.

        Supported indicators:
            ``"68to49"``: for 300W or 300VW, convert the default 68 points to
            49 points by dropping columns 0-16, 60 and 64.

        Raises:
            AssertionError: if ``"68to49"`` is requested while
                ``num_point != 68``.
            ValueError: for any other indicator.
        """
        if indicator != "68to49":
            raise ValueError("Invalid indicator : {:}".format(indicator))
        assert self.num_point == 68, "num-point must be 68 vs. {:}".format(
            self.num_point
        )
        self.num_point = 49
        if self.points is not None:
            # Integer indexing yields a fresh tensor and avoids the deprecated
            # uint8-mask indexing path.
            self.points = self.points[:, list(self._KEEP_68TO49)]

    def apply_horizontal_flip(self):
        """Reorder the landmark columns for a horizontally flipped image.

        Only the column order changes: the coordinate values themselves are
        not mirrored by this method.  The permutation is chosen from the
        prefix of ``self.datasets`` ("HandsyROT" -> 42 points, "face68" ->
        68 points).  The stored tensor is updated in place; a sample without
        points is left untouched.

        Raises:
            ValueError: if the dataset name matches neither prefix.
        """
        if self.datasets.startswith("HandsyROT"):
            order = self._HAND42_FLIP
        elif self.datasets.startswith("face68"):
            order = self._FACE68_FLIP
        else:
            raise ValueError("Does not support {:}".format(self.datasets))
        if self.points is None:
            return
        # The right-hand side is an advanced-indexing copy, so overwriting the
        # leading columns in place cannot read already-modified values.
        self.points[:, : len(order)] = self.points[:, list(order)]


# shape = (H,W)
def apply_affine2point(points, theta, shape):
    """Solve ``theta @ X = points`` for the visible landmarks.

    The coordinate rows of ``points`` are first normalized with
    ``normalize_points(shape, ...)``.  NOTE: this normalization is written
    back into ``points``, so the caller's tensor is modified in place.

    Args:
        points: tensor with exactly 3 rows; rows 0-1 are coordinates and
            row 2 is the indicator.  Only columns whose indicator equals 1
            take part in the solve.
        theta: square coefficient matrix of the linear system; it must be
            compatible with the 3-row right-hand side.
        shape: image shape forwarded to ``normalize_points`` as (H, W).

    Returns:
        A float tensor shaped like ``points``: solved columns where the
        indicator is 1, zeros everywhere else.

    Raises:
        AssertionError: if ``points`` does not have 3 rows or no column has
            an indicator equal to 1.
    """
    assert points.size(0) == 3, "invalid points shape : {:}".format(points.size())
    with torch.no_grad():
        ok_points = points[2, :] == 1
        assert torch.sum(ok_points).item() > 0, "there is no visiable point"
        points[:2, :] = normalize_points(shape, points[:2, :])

        # Starts as 1 on visible columns / 0 elsewhere; the visible columns
        # are overwritten with the solution below.
        norm_trans_points = ok_points.unsqueeze(0).repeat(3, 1).float()

        visible = points[:, ok_points]
        # torch.gesv was removed from PyTorch; use whichever solver exists.
        if hasattr(getattr(torch, "linalg", None), "solve"):
            trans_points = torch.linalg.solve(theta, visible)
        elif hasattr(torch, "solve"):
            trans_points, _ = torch.solve(visible, theta)
        else:
            trans_points, _ = torch.gesv(visible, theta)

        norm_trans_points[:, ok_points] = trans_points

    return norm_trans_points


def apply_boundary(norm_trans_points):
    """Recompute the indicator row from the open range (-1, 1).

    Works on a clone, so the argument is not modified.  In the returned
    tensor, row 2 is set to 1 for every column whose row-0 and row-1 values
    both lie strictly between -1 and 1 and whose incoming row-2 value is
    greater than 0; it is set to 0 for all other columns.  Rows 0 and 1 are
    returned unchanged.
    """
    with torch.no_grad():
        result = norm_trans_points.clone()
        xs, ys, flags = result[0], result[1], result[2]
        inside_x = (xs > -1) & (xs < 1)
        inside_y = (ys > -1) & (ys < 1)
        keep = inside_x & inside_y & (flags > 0)
        result[2, :] = keep
    return result