#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
import torch


def gold_match(predict_probs, gold_tags):
    """
    Count how many rows' highest-scoring tag equals the gold tag.

    :param predict_probs: B*Tag tensor of tag scores
    :param gold_tags: gold tag ids, compared element-wise with the per-row
        arg-max (presumably shape B — confirm with callers)
    :return: number of matching rows, as a Python number
    """
    # ``Tensor.max(dim)`` returns (values, indices); only the indices matter.
    predicted = predict_probs.max(1)[1]
    hits = (predicted == gold_tags).sum()
    return hits.item()


def acc(predict_probs, gold_tags):
    """
    Accuracy: the fraction of rows whose arg-max tag equals the gold tag.

    :param predict_probs: B*Tag tensor of tag scores
    :param gold_tags: gold tag ids; its first dimension is the number of
        examples used as the denominator
    :return: accuracy value; 0.0 for an empty batch
    """
    total = gold_tags.shape[0]
    if total == 0:
        # An empty batch has no defined accuracy. Return 0.0 instead of raising
        # ZeroDivisionError, matching Ratio / F1beta's empty-denominator handling.
        return 0.0
    return gold_match(predict_probs, gold_tags) / total


class MetricsValue:
    """
    Interface for accumulable metric values.

    Subclasses must implement ``__add__`` (combine two values), ``__str__``
    (human-readable form) and ``__float__`` (scalar used for comparisons).
    """

    def __add__(self, other):
        """Combine this value with ``other``; must be overridden."""
        raise NotImplementedError

    def __iadd__(self, other):
        """In-place addition; by default it just delegates to ``__add__``."""
        combined = self.__add__(other)
        return combined

    def __str__(self):
        """Human-readable form; must be overridden."""
        raise NotImplementedError

    def __float__(self):
        """Scalar value; a float is needed for partial-order comparisons. Must be overridden."""
        raise NotImplementedError


class Ratio(MetricsValue):
    """
    A running ratio ``n / d``, used to accumulate ratio-style metrics such as
    ``loss`` and ``ACC`` across updates.
    """
    def __init__(self, name=''):
        self.name = name  # optional label prefixed by __str__
        self._n = 0  # accumulated numerator
        self._d = 0  # accumulated denominator

    def update(self, n, d):
        """
        Accumulate into the ratio.

        :param n: amount added to the numerator
        :param d: amount added to the denominator
        """
        self._n += n
        self._d += d

    def __add__(self, other):
        """Return a new Ratio with summed numerators/denominators, keeping ``self.name``."""
        # Keep the label so that ``a = a + b`` prints the same as ``a += b``.
        r = Ratio(self.name)
        r.update(self.n + other.n, self.d + other.d)
        return r

    def __iadd__(self, other):
        """Accumulate ``other``'s numerator and denominator into this ratio in place."""
        self.update(other.n, other.d)
        return self

    @property
    def n(self):
        """Accumulated numerator."""
        return self._n

    @property
    def d(self):
        """Accumulated denominator."""
        return self._d

    def __float__(self):
        """Value of the ratio; 0.0 while the denominator is 0."""
        return float(0.0 if self._d == 0 else self._n / self._d)

    @staticmethod
    def _fmt_count(x):
        """Format a numerator/denominator: integral values with ``%d``, others with ``%.4f``."""
        # A plain ``%d`` silently truncates fractional sums (e.g. an accumulated
        # float loss of 12.7 would be shown as 12).
        return '%d' % x if float(x).is_integer() else '%.4f' % float(x)

    def __str__(self):
        leading = ('%s: ' % self.name) if self.name else ''
        return '%s%s/%s=%.4f' % (leading, self._fmt_count(self._n),
                                 self._fmt_count(self._d), float(self))


class F1beta(object):
    """
    Accumulator for precision, recall and the F-beta score, computed from raw counts.

    It keeps three counts: predicted positives (``_p_all``), gold positives
    (``_r_all``) and matched positives (``_t``). ``beta`` weights recall
    relative to precision. The default ``beta=1`` gives the usual F1 score.
    """
    def __init__(self, beta=1.0):
        self._p_all, self._r_all, self._t = 0, 0, 0
        self.beta = beta

    def add_gold(self, n=1):
        """
        Increase the count of positives in the gold data.

        :param n: amount to add (default 1)
        """
        self._r_all += n

    def add_pred(self, n=1):
        """
        Increase the count of predicted positives.

        :param n: amount to add (default 1)
        """
        self._p_all += n

    def add_match(self, n=1):
        """
        Increase the count of predictions that match the gold data.

        :param n: amount to add (default 1)
        """
        self._t += n

    @property
    def p(self):
        """Precision: matches / predicted positives, or 0 if nothing was predicted."""
        return self._t / self._p_all if self._p_all else 0

    @property
    def r(self):
        """Recall: matches / gold positives, or 0 if there are no gold positives."""
        return self._t / self._r_all if self._r_all else 0

    @property
    def f(self):
        """
        F-beta = (1 + beta^2) * p * r / (beta^2 * p + r), or 0 when there are
        no matches or the denominator is 0.
        """
        p, r = self.p, self.r
        b2 = self.beta * self.beta
        denom = b2 * p + r
        # With beta == 1 this is exactly 2 * p * r / (p + r).
        return (1 + b2) * p * r / denom if denom and self._t else 0

    def __str__(self):
        return '[p:%d/%d=%.2f%% r:%d/%d=%.2f%% f:%.2f%%]' % \
               (self._t, self._p_all, self.p * 100, self._t, self._r_all, self.r * 100, self.f * 100)

    def __float__(self):
        return float(self.f)

    def __add__(self, other):
        """Return a new accumulator with summed counts; it takes ``self.beta``."""
        ret = F1beta(self.beta)
        ret._p_all, ret._r_all, ret._t = self._p_all + other._p_all, self._r_all + other._r_all, self._t + other._t
        return ret

    def __iadd__(self, other):
        """Add ``other``'s counts into this accumulator in place."""
        self.add_gold(other._r_all)
        self.add_match(other._t)
        self.add_pred(other._p_all)
        return self

    def as_tuple(self):
        """Return ``(precision, recall, f)``."""
        return self.p, self.r, self.f


def f1(predict_probs: torch.Tensor, gold_tags: torch.Tensor, dim: int = 1) -> F1beta:
    """
    Count precision/recall/F1 statistics for tag predictions. Tag 0 is the
    negative class (NULL); tags >= 1 are positive.

    :param predict_probs: tag scores whose tag axis is ``dim``
    :param gold_tags: gold tag ids, shaped like ``predict_probs`` with the
        ``dim`` axis removed
    :param dim: axis of ``predict_probs`` that holds the tag scores. The default
        of 1 matches the previous behaviour (e.g. B*tag_set or B*tag_set*time).
        NOTE(review): for scores laid out as B*time*tag_set the tag axis is the
        last one, so pass ``dim=-1``; confirm the layout used by callers.
    :return: an F1beta object holding the gold/pred/match counts
    """
    ret = F1beta()
    predict_probs, gold_tags = predict_probs.data, gold_tags.data
    _, predict_index = torch.max(predict_probs, dim)
    # Use comparison/logical ops on masks. ``1 - bool_tensor`` raises a
    # RuntimeError on PyTorch >= 1.2.
    gold_positive = gold_tags != 0
    pred_positive = predict_index != 0
    ret.add_gold(torch.sum(gold_positive).item())
    ret.add_pred(torch.sum(pred_positive).item())
    # A match counts only when the gold tag is positive and the prediction agrees.
    ret.add_match(torch.sum((gold_tags == predict_index) & gold_positive).item())
    return ret
