#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
import codecs
from typing import List
from collections import Counter
import numpy as np

from dlab.data.structures import Dictionary, HierarchicalDictionary, Sentence
from dlab.eval.entity_decoder import BMESWordSegDecoder, EntityDecoder
from dlab.log import log


def read_embedding(filename):
    """Load a text embedding file into a token -> vector mapping.

    Each line is stripped and split on single spaces; the first field is
    the key and the remaining fields are converted to a float32 array.
    Lines that yield two fields or fewer are skipped (this also drops a
    two-number header line, if the file has one). When a key occurs more
    than once, the last occurrence wins.

    :param filename: path of the UTF-8 encoded embedding file
    :return: Dict[str, np.array(dtype=float32)]
    """
    vectors = {}
    with codecs.open(filename, encoding='utf8') as handle:
        for raw in handle:
            if not raw:
                continue
            fields = raw.strip().split(' ')
            if len(fields) > 2:
                vectors[fields[0]] = np.asarray(fields[1:], dtype='float32')
    return vectors


def get_mask(length_list_batch, max_length=None):
    """Build a 0/1 mask for a batch of sequence lengths.

    Row ``i`` holds ``1`` for the first ``length_list_batch[i]`` positions
    and ``0`` for the rest. A length larger than ``max_length`` is clipped
    to ``max_length`` so that every row has the same width and the result
    is always a rectangular array.

    :param length_list_batch: sequence lengths, one per batch entry
    :param max_length: mask width; defaults to the largest length in the batch
    :return: np.array built from one row of 0/1 ints per batch entry
    """
    if max_length is None:
        max_length = max(length_list_batch)
    rows = []
    for length in length_list_batch:
        # Clip so an over-long sequence cannot produce a ragged row.
        valid = min(length, max_length)
        rows.append([1] * valid + [0] * (max_length - valid))
    return np.array(rows)


def pad_batch(int_list_batch, max_length=None, padding_num=0, return_length=False, return_mask=False):
    """
    Pad (or truncate) every sequence in a batch to the same length.

    :param int_list_batch: List[List[int]]
    :param max_length: target length; defaults to the longest sequence in the batch
    :param padding_num: value used to fill the padded positions
    :param return_length: whether to also return the original sequence lengths
    :param return_mask: whether to also return the attention mask
        (1 for valid positions, 0 for padded positions)
    :return: pad_int (B*max_len) alone when no extra output is requested,
        otherwise a list [pad_int, [seq_length (B)], [seq_mask (B*max_len)]]
        containing only the requested extras, in that order
    :rtype: Union[List[List[int]], List[List[List[int]]]]
    """
    lengths = [len(seq) for seq in int_list_batch]
    width = max(lengths) if max_length is None else max_length

    padded_rows = []
    mask_rows = []
    for seq, seq_len in zip(int_list_batch, lengths):
        # Number of real items that survive truncation (never negative).
        kept = max(0, min(seq_len, width))
        fill = width - kept
        padded_rows.append(list(seq[:kept]) + [padding_num] * fill)
        if return_mask:
            mask_rows.append([1] * kept + [0] * fill)

    if not (return_length or return_mask):
        return padded_rows
    result = [padded_rows]
    if return_length:
        result.append(lengths)
    if return_mask:
        result.append(mask_rows)
    return result


def fit_dict(sentences: List[Sentence], dictionary: Dictionary, tag_name: str = None, min_count=1, max_num=None,
             ngram=1, ngram_pad='B', ngram_spliter=''):
    """
    Fill a dictionary with one attribute (token text or a tag) of every
    token in a list of sentences.

    With ``ngram > 1`` each token contributes one n-gram that ends at that
    token: the token sequence is left-padded with ``ngram - 1`` copies of
    ``ngram_pad`` and each window of ``ngram`` items is joined with
    ``ngram_spliter``.

    :param sentences: list of sentences
    :param dictionary: an already initialised dictionary
    :param tag_name: name of the token tag to use; None means use the token text
    :param min_count: minimum number of occurrences for an item to be added
    :param max_num: maximum number of items to consider, most frequent first;
        0 or None means no limit
    :param ngram: n-gram order; 1 means single tokens
    :param ngram_pad: padding item placed before the first token when ngram > 1
    :param ngram_spliter: string used to join the items of one n-gram
    :return: the same dictionary, after the items were added
    """
    if max_num == 0:
        max_num = None
    log.debug('fit dict with %d sentences of tag=%s, min_count=%d, max_num=%s.' %
              (len(sentences), tag_name, min_count, max_num))

    left_padding = [ngram_pad] * (ngram - 1)
    frequencies = Counter()
    for sentence in sentences:
        if tag_name is None:
            units = [token.text for token in sentence]
        else:
            units = [token.get_tag(tag_name) for token in sentence]
        if ngram > 1:
            window = left_padding + units
            units = [ngram_spliter.join(window[start:start + ngram]) for start in range(len(sentence))]
        frequencies.update(units)

    # most_common() is ordered by descending count, so filtering on the
    # threshold keeps exactly the leading run of frequent-enough items.
    frequent = [item for item, occurrences in frequencies.most_common(max_num) if occurrences >= min_count]
    for item in frequent:
        dictionary.add_item(item)
    log.debug('fitted dict=%s... len=%s' % (dictionary.idx2item[:5], len(dictionary)))
    return dictionary


def fit_hierarchical_dict(sentences: List[Sentence], dictionary: HierarchicalDictionary, label_name: str='category',
                          min_count=1, max_num=None):
    """
    Fill a hierarchical dictionary with the sentence-level labels found in
    a list of sentences.

    :param sentences: list of sentences
    :param dictionary: an already initialised hierarchical dictionary
    :param label_name: name passed to ``sentence.get_label_name`` to fetch the label
    :param min_count: minimum number of occurrences for a label to be added
    :param max_num: maximum number of labels to consider, most frequent first;
        0 or None means no limit
    :return: the same dictionary, after the labels were added
    """
    if max_num == 0:
        max_num = None
    log.debug('fit hierarchical dict with %d sentences of label=%s, min_count=%d, max_num=%s.' %
              (len(sentences), label_name, min_count, max_num))

    label_counts = Counter([sentence.get_label_name(label_name) for sentence in sentences])
    # most_common() is ordered by descending count, so filtering on the
    # threshold keeps exactly the leading run of frequent-enough labels.
    frequent = [label for label, occurrences in label_counts.most_common(max_num) if occurrences >= min_count]
    for label in frequent:
        dictionary.add_item(label)

    log.debug('fitted dicts=%s... dicts-len=%s' % (list(dictionary.dict_set.keys())[:5], len(dictionary.dict_set.keys())))
    return dictionary


def write_tag_sentence(filename: str, sentences: List[Sentence], tags: List[str] = None):
    """
    Write sentences to a file in CoNLL-style column format.

    Every token becomes one line: the token text followed by the value of
    each requested tag, separated by tabs. An empty line is written after
    every sentence.

    :param filename: path of the UTF-8 output file (overwritten)
    :param sentences: the sentences to write
    :param tags: names of the token tags to write to the right of the token
        text, in column order; None means no tag columns
    :return: None
    """
    columns = [] if tags is None else tags
    with codecs.open(filename, 'w', encoding='utf8') as out:
        for sentence in sentences:
            # Build all lines of the sentence before writing any of them.
            rows = []
            for token in sentence:
                fields = [token.text]
                fields.extend([token.get_tag(column) for column in columns])
                rows.append('\t'.join(fields) + '\n')
            out.writelines(rows)
            out.write('\n')


def write_cws_sentence(filename: str, sentences: List[Sentence], ws_tag='ws', spliter='\t',
                       decoder: EntityDecoder = None):
    """
    Decode word-segmentation tags and write one segmented sentence per line.

    For every sentence the token texts and their ``ws_tag`` tags are passed
    to ``decoder.decode(words=..., tags=...)``; the returned items are
    joined with ``spliter`` and written as one line. A sentence without
    tokens is written as an empty line (the decoder is not called for it).

    :param filename: path of the UTF-8 output file (overwritten)
    :param sentences: the sentences to write
    :param ws_tag: name of the token tag holding the segmentation label
    :param spliter: string placed between the decoded words of a line
    :param decoder: default is BMESWordSegDecoder
    :return: None
    """
    if decoder is None:
        decoder = BMESWordSegDecoder()
    with codecs.open(filename, 'w', encoding='utf8') as f:
        for sentence in sentences:
            pairs = [(token.text, token.get_tag(ws_tag)) for token in sentence]
            if pairs:
                chars, tags = zip(*pairs)
                line = spliter.join(decoder.decode(words=chars, tags=tags))
            else:
                # zip(*[]) yields nothing, so unpacking it into two names
                # would raise ValueError; emit an empty line instead.
                line = ''
            f.write(line + '\n')


def _split_sentence(sentence: Sentence, split_char='。', max_len=None) -> List[Sentence]:
    """
    Cut one sentence into consecutive pieces.

    A new piece is started after a token whose text equals ``split_char``,
    or once the current piece holds at least ``max_len`` tokens (when
    ``max_len`` is given). No new piece is started after the last token, so
    no piece is ever empty. The pieces hold the original token objects.

    :param sentence: the sentence to cut
    :param split_char: token text that ends a piece
    :param max_len: maximum number of tokens per piece; None means no limit
    :return: list of pieces; empty when the sentence has no tokens
    """
    if len(sentence) == 0:
        return []
    pieces = [Sentence()]
    for index, token in enumerate(sentence):
        current = pieces[-1]
        current.tokens.append(token)
        if index >= len(sentence) - 1:
            # Last token: never open a trailing empty piece.
            continue
        if token.text == split_char:
            boundary = True
        else:
            boundary = max_len is not None and len(current) >= max_len
        if boundary:
            pieces.append(Sentence())
    return pieces


def split_sentences(sentences: List[Sentence], split_char='。', max_len=None) -> List[Sentence]:
    """
    Split every sentence of a list with ``_split_sentence`` and return all
    resulting pieces as one flat list, in the original order.

    :param sentences: the sentences to split
    :param split_char: token text that ends a piece (default '。')
    :param max_len: maximum number of tokens per piece; None means no limit
    :return: new list holding the pieces of all sentences
    """
    return [piece
            for stn in sentences
            for piece in _split_sentence(stn, split_char, max_len)]


class EarlyStopping(object):
    """Track a dev score across epochs and signal when to stop training.

    A score counts as an improvement only when it is strictly greater than
    the best score seen so far. Training should stop once ``patient``
    consecutive calls brought no improvement.
    """

    def __init__(self, patient=10, init_score=-1):
        """
        :param patient: number of consecutive non-improving scores to
            tolerate before signalling a stop
        :param init_score: starting value of the best score; only scores
            strictly greater than it can count as a first improvement.
            The default of -1 suits non-negative metrics; pass
            ``float('-inf')`` for metrics that can be -1 or lower
        """
        self.patient = patient
        self.best_score = init_score
        self.no_update_counter = 0

    def next_score(self, score):
        """
        Record the score of the latest evaluation.

        :param score: the dev score
        :return: True to exit training
        """
        if score > self.best_score:
            self.best_score = score
            self.no_update_counter = 0
        else:
            self.no_update_counter += 1
        log.info('early-stopping waiting epochs [%d/%d]' % (self.no_update_counter, self.patient))
        return self.no_update_counter >= self.patient
