#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
from collections import defaultdict

from dlab.evaluate import F1beta


def tag_f1(predict_tags, gold_tags, other_tag='O'):
    """
    Compute token-level precision/recall/F1 counts over non-``other_tag`` tags.

    Every token whose predicted tag differs from ``other_tag`` counts as a
    prediction, every token whose gold tag differs from ``other_tag`` counts as
    a gold item, and a token counts as a match when its predicted and gold tags
    are equal and not ``other_tag``.

    :param predict_tags: predicted tags, one list of tags per sentence
    :type predict_tags: list[list[str]]
    :param gold_tags: gold tags, one list of tags per sentence
    :type gold_tags: list[list[str]]
    :param other_tag: the "outside" tag, excluded from all counts
    :type other_tag: str
    :return: an F1beta accumulator populated with prediction, gold and match counts
    :rtype: F1beta
    :raises ValueError: if a predicted sentence and its gold sentence differ in length
    """
    f_beta = F1beta()
    # zip() pairs sentences positionally; surplus sentences in the longer list are ignored.
    for sent_idx, (predict_sentence, gold_sentence) in enumerate(zip(predict_tags, gold_tags)):
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if len(predict_sentence) != len(gold_sentence):
            raise ValueError('sentence %d: predicted length (%d) and gold length (%d) should be equal'
                             % (sent_idx, len(predict_sentence), len(gold_sentence)))
        for pred_tag, gold_tag in zip(predict_sentence, gold_sentence):
            if pred_tag != other_tag:
                f_beta.add_pred()
            if gold_tag != other_tag:
                f_beta.add_gold()
                if pred_tag == gold_tag:
                    f_beta.add_match()
    return f_beta


def entity_f1(predict_tags, gold_tags, decoders, return_entity_grouped=False):
    """
    Compute entity-level precision/recall/F1 counts of predicted tags against gold tags.

    Every sentence is decoded by every decoder in ``decoders``, once from the
    predicted tags and once from the gold tags. A decoded entity counts as a
    match when it appears in both decodings.

    :param predict_tags: predicted tags, one list of tags per sentence
    :type predict_tags: list[list[str]]
    :param gold_tags: gold tags, one list of tags per sentence
    :type gold_tags: list[list[str]]
    :param decoders: objects providing ``decode(words, tags)`` and a ``name`` attribute
    :type decoders: list[BaseDecoder]
    :param return_entity_grouped: if true, also return per-decoder counts keyed by ``decoder.name``
    :type return_entity_grouped: bool
    :return: the aggregate F1beta, or ``(aggregate, grouped)`` when ``return_entity_grouped``
        is true, where ``grouped`` is a ``defaultdict`` mapping decoder name to F1beta
    :raises ValueError: if a predicted sentence and its gold sentence differ in length
    """
    f_beta = F1beta()
    entity_grouped_f_beta = defaultdict(F1beta)
    # zip() pairs sentences positionally; surplus sentences in the longer list are ignored.
    for sent_idx, (predict_sentence, gold_sentence) in enumerate(zip(predict_tags, gold_tags)):
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if len(predict_sentence) != len(gold_sentence):
            raise ValueError('sentence %d: predicted length (%d) and gold length (%d) should be equal'
                             % (sent_idx, len(predict_sentence), len(gold_sentence)))
        if not predict_sentence:
            # An empty sentence has no tags to decode; skipping it means
            # decoders never have to handle empty input.
            continue
        # Use each token's position ('0-', '1-', ...) as its "word". Decoders that
        # concatenate an entity's words (as get_bmeso_entity does) then yield
        # strings such as '3-4-' that identify the span. Predicted and gold
        # entities therefore match only when their spans coincide.
        fake_words = [str(i) + '-' for i in range(len(predict_sentence))]
        for decoder in decoders:
            pred_entities = set(decoder.decode(fake_words, predict_sentence))
            gold_entities = set(decoder.decode(fake_words, gold_sentence))
            n_match = len(pred_entities & gold_entities)
            accumulators = [f_beta]
            if return_entity_grouped:
                accumulators.append(entity_grouped_f_beta[decoder.name])
            for accumulator in accumulators:
                accumulator.add_pred(len(pred_entities))
                accumulator.add_gold(len(gold_entities))
                accumulator.add_match(n_match)
    return (f_beta, entity_grouped_f_beta) if return_entity_grouped else f_beta


def get_bmeso_entity(word_seq, tag_seq, entity_name,
                     begin_prefix='B-', middle_prefix='M-', end_prefix='E-', single_prefix='S-'):
    """
    Extract the entities of one type from a tagged word sequence.

    A tag is matched against ``<prefix><entity_name>`` for each of the four
    prefixes. The words of one entity are joined with ``+``, so for string
    words each entity is returned as one string.

    :param word_seq: word list (or string). If its items are tuples/lists
        (multiple features per word), only the first feature of each is used.
    :param tag_seq: the tags of the words, one per word
    :param entity_name: the name of the entity that needs to be extracted
    :param begin_prefix: prefix of the tag that begins an entity
    :param middle_prefix: prefix of an inside tag; you can set this as 'I-' to use the BIO decode mode
    :param end_prefix: prefix of the tag that ends an entity; you can set this as 'I-' to use the
        BIO decode mode
    :param single_prefix: prefix of a single-word entity tag; you can set this as 'B-' to use the
        BIO decode mode
    :return: list of entities
    :rtype: list[str]
    :raises ValueError: if ``word_seq`` and ``tag_seq`` differ in length
    """
    if not len(word_seq) == len(tag_seq):
        raise ValueError('word_seq length (%d) and tag_seq length (%d) should be equal' % (len(word_seq), len(tag_seq)))
    # Only peek at the first item when there is one: empty input simply has no entities.
    # len() is used instead of truthiness so array-like sequences also work.
    if len(word_seq) > 0 and isinstance(word_seq[0], (tuple, list)):
        # if multi features, only extract the first feature.
        return get_bmeso_entity([f[0] for f in word_seq], tag_seq, entity_name,
                                begin_prefix, middle_prefix, end_prefix, single_prefix)
    ret = []
    begin = begin_prefix + entity_name
    middle = middle_prefix + entity_name
    single = single_prefix + entity_name
    end = end_prefix + entity_name
    entity_recording_flag = False
    for word, label in zip(word_seq, tag_seq):
        # Prefixes may coincide (in BIO mode middle == end and begin == single),
        # so branch order matters. Middle/end are handled first. Begin is tested
        # before single so that a shared 'B-' tag opens an entity that can be extended.
        if label == middle or label == end:
            if entity_recording_flag:
                ret[-1] += word
            else:
                # Continuation tag with no open entity (e.g. 'I-X' right after 'O'):
                # start a new entity at this word.
                ret.append(word)
            # A middle tag keeps the entity open; a distinct end tag closes it.
            entity_recording_flag = label == middle
        elif label == begin:
            ret.append(word)
            entity_recording_flag = True
        elif label == single:
            ret.append(word)
            entity_recording_flag = False
        else:
            entity_recording_flag = False

    return ret


class BaseDecoder(object):
    """
    Extensible base class for entity decoders.

    Subclasses implement :meth:`decode` to extract entities from a tag
    sequence; different decoders extract different entities. A list of decoder
    instances is passed to ``entity_f1`` to score entity-level F1 per decoder.
    """
    # Key under which entity_f1 groups this decoder's counts when return_entity_grouped is true.
    name = 'default'

    def decode(self, words, tags):
        """
        Extract entities from ``words`` according to ``tags``.

        :param words: the words to decode entities from, aligned with ``tags``
        :type words: list[str]
        :param tags: one tag per word
        :type tags: list[str]
        :return: the decoded entities
        :rtype: list
        :raises NotImplementedError: always; subclasses must override this method
        """
        raise NotImplementedError()


class EntityDecoder(BaseDecoder):
    """
    Decoder for a single entity type written with prefixed tags.

    Tags have the form ``<prefix><entity_name>``. The class attributes below
    hold the BMES prefixes; subclasses override them to decode other schemes.
    """
    # Tag prefixes forwarded to get_bmeso_entity; subclasses override these.
    begin_prefix = 'B-'
    middle_prefix = 'M-'
    end_prefix = 'E-'
    single_prefix = 'S-'

    def __init__(self, entity_name):
        """
        :param entity_name: the entity type to extract, e.g. ``PER`` for tags like ``B-PER``
        :type entity_name: str
        """
        self.entity_name = entity_name
        # entity_f1 groups per-decoder scores by `name`, so use the entity type.
        self.name = entity_name

    def decode(self, words, tags):
        """Return the ``entity_name`` entities found in ``words`` according to ``tags``."""
        prefixes = {
            'begin_prefix': self.begin_prefix,
            'middle_prefix': self.middle_prefix,
            'end_prefix': self.end_prefix,
            'single_prefix': self.single_prefix,
        }
        return get_bmeso_entity(words, tags, self.entity_name, **prefixes)


class BMESEntityDecoder(EntityDecoder):
    """Entity decoder for the BMES scheme (``B-``/``M-``/``E-``/``S-``), i.e. EntityDecoder's default prefixes."""


class BIEOEntityDecoder(EntityDecoder):
    """
    Entity decoder for the BIEO(S) scheme.

    Identical to the BMES prefixes except that inside tokens use ``I-``
    instead of ``M-``.
    """
    begin_prefix = 'B-'   # first word of a multi-word entity
    middle_prefix = 'I-'  # inside word (the only prefix that differs from EntityDecoder)
    end_prefix = 'E-'     # last word of a multi-word entity
    single_prefix = 'S-'  # single-word entity


class BIOEntityDecoder(BIEOEntityDecoder):
    """
    Entity decoder for BIO-tagged data (``B-``/``I-``/``O`` tags).

    It reuses the BIEOEntityDecoder prefixes unchanged. BIO data has no ``E-``
    or ``S-`` tags, so those prefixes never match. Entities open at ``B-`` and
    are extended by the ``I-`` tags that follow.
    """


class BMESWordSegDecoder(EntityDecoder):
    """
    Word-segmentation decoder for bare BMES tags (no entity suffix).

    With an empty entity name each tag is just its prefix (``B``, ``M``,
    ``E`` or ``S``), so every decoded "entity" is one segmented word.

    >>> from dlab.eval.entity_decoder import BMESWordSegDecoder
    >>> decoder = BMESWordSegDecoder()
    >>> decoder.decode('我爱北京天安门', 'SSBEBME')
    ['我', '爱', '北京', '天安门']
    >>> decoder.decode(['我', '爱', '北', '京', '天', '安', '门'], ['S', 'S', 'B', 'E', 'B', 'M', 'E'])
    ['我', '爱', '北京', '天安门']
    """
    # Prefixes without a '-' separator: the tag is the bare prefix itself.
    single_prefix = 'S'
    begin_prefix = 'B'
    middle_prefix = 'M'
    end_prefix = 'E'

    def __init__(self):
        # An empty entity name makes every tag equal to its prefix; `name` becomes '' as well.
        super().__init__(entity_name='')

