#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
from functools import reduce
from typing import List, Dict, Union, Generator, Any

import torch
from collections import Counter
from collections import defaultdict
import json
import codecs
from dlab.log import log
from dlab.const import PAD, UNK
from dlab.data.language import LANG


class Dictionary:
    """
    This class holds a dictionary that maps strings to IDs, used to generate one-hot encodings of strings.

    IDs are assigned in insertion order starting from 0, and ``idx2item[item2idx[s]] == s`` holds for
    every stored item.
    """

    def __init__(self, add_pad=True, add_unk=True):
        """
        Init a dictionary.

        :param add_pad: set to True when storing tokens which are features on variable length sequence.
            PAD is inserted first, so it receives ID 0.
        :param add_unk: set to True when some feature values may not appear in the fitting data.
            Lookups of unknown items then return the UNK ID instead of raising.
        """
        self.item2idx: Dict[str, int] = {}
        self.idx2item: List[str] = []

        self.add_pad = add_pad
        self.add_unk = add_unk
        if self.add_pad:
            self.add_item(PAD)
        if self.add_unk:
            # reserve an ID so unknown items can be mapped to <unk>
            self.add_item(UNK)

    def add_item(self, item: str) -> int:
        """
        Add a string and return its ID. An item already in the dictionary keeps its existing ID;
        a new item gets the next free ID.

        :param item: a string for which to assign an id
        :return: ID of the string
        """
        if item not in self.item2idx:
            self.idx2item.append(item)
            self.item2idx[item] = len(self.idx2item) - 1
        return self.item2idx[item]

    def get_idx_for_item(self, item: str) -> int:
        """
        Return the ID of a string.

        :param item: string for which the ID is requested
        :return: ID of the string; for an unknown string, the UNK ID if the dictionary contains UNK
        :raises ValueError: if the string is unknown and the dictionary has no UNK entry
        """
        if item in self.item2idx:
            return self.item2idx[item]
        elif UNK in self.item2idx:
            return self.item2idx[UNK]
        else:
            raise ValueError(f'Unknown Key=`{item}` in Dictionary(add_unk=False)')

    def get_items(self) -> List[str]:
        """Return all items ordered by ID (the internal list itself, not a copy)."""
        return self.idx2item

    def __len__(self) -> int:
        return len(self.idx2item)

    def get_item_for_index(self, idx: int) -> str:
        """
        Return the item stored under ``idx``.

        :param idx: item ID. Negative values index from the end, as for a Python list.
        :return: the stored item
        :raises ValueError: if ``idx`` is not smaller than the dictionary size
        """
        # `>=`: the last valid ID is len - 1; with `>` an idx equal to len escaped as an IndexError.
        if idx >= len(self.idx2item):
            raise ValueError('idx=%d out of index.' % idx)
        return self.idx2item[idx]

    def save(self, file: str):
        """
        Save this dictionary to a file (pickled ``to_dict()`` mapping).

        :param file: destination path
        """
        import pickle
        with open(file, 'wb') as f:
            pickle.dump(self.to_dict(), f)

    def show(self):
        """Print the item -> ID mapping."""
        print(self.item2idx)

    @classmethod
    def load_from_file(cls, filename: str):
        """
        Return a Dictionary object stored in the given file by :meth:`save`.

        :param filename: path to dict file.
        :return: the loaded dictionary
        :rtype: Dictionary
        """
        import pickle
        # NOTE: pickle.load can execute arbitrary code; only load files from trusted sources.
        with open(filename, 'rb') as f:
            mappings = pickle.load(f)
        return cls.from_dict(mappings)

    @classmethod
    def load(cls, file: str):
        """
        Alias of :meth:`load_from_file`.

        :param file: path to dict file.
        :return: the loaded dictionary
        """
        return cls.load_from_file(file)

    def to_dict(self) -> Dict:
        """Return the two mappings that fully describe this dictionary."""
        return {'idx2item': self.idx2item, 'item2idx': self.item2idx}

    @classmethod
    def from_dict(cls, obj: Dict):
        """
        Rebuild a dictionary from the output of :meth:`to_dict`.

        The given mappings replace the default PAD/UNK entries created by the constructor.
        """
        dictionary = cls()
        dictionary.item2idx = obj['item2idx']
        dictionary.idx2item = obj['idx2item']
        return dictionary


class HierarchicalDictionary(object):
    """
    A tree of :class:`Dictionary` objects for hierarchical labels such as ``"a;b;c"``.

    Each level is stored in its own Dictionary, keyed by the separator-prefixed path of its parent:
    the root level is keyed by ``''``, the children of ``a`` by ``';a'``, and so on. An item is
    encoded as the list of its per-level IDs.
    """

    def __init__(self, add_unk=False, sep=';', max_len=None):
        """
        Note: every level dictionary always contains PAD, which reserves ID 0 to mark the end of the
        hierarchy (decoding stops at the first 0).

        :param add_unk: whether to allow an unknown class on each level. Disabled by default.
        :param sep: separator between hierarchy levels in an item string.
        :param max_len: if given, encoded ID lists shorter than this are right-padded with 0 up to
            ``max_len``; longer lists are returned unchanged.
        """
        self.add_unk = add_unk
        self.sep = sep
        self.max_len = max_len
        self.dict_set: Dict[str, Dictionary] = {}

    def _gen_dict(self) -> Dictionary:
        """Create a new level dictionary (always with PAD at ID 0)."""
        return Dictionary(add_unk=self.add_unk, add_pad=True)

    def add_dict(self, parent_name: str) -> Dictionary:
        """Return the level dictionary registered under ``parent_name``, creating it if missing."""
        if parent_name not in self.dict_set:
            self.dict_set[parent_name] = self._gen_dict()
        return self.dict_set[parent_name]

    def get_dict(self, parent_name: str) -> Dictionary:
        """
        Return the level dictionary registered under ``parent_name``.

        If it is missing and ``add_unk`` is set, a fresh, unregistered dictionary is returned, so
        lookups in it fall back to UNK.

        :raises ValueError: if it is missing and ``add_unk`` is False
        """
        if parent_name in self.dict_set:
            return self.dict_set[parent_name]
        else:
            if self.add_unk:
                return self._gen_dict()
            else:
                raise ValueError('dict `%s` is not in the dictionary' % parent_name)

    def _str2list(self, item: str) -> List[str]:
        """Split ``item`` into levels; the first element is always ``''`` (the root key)."""
        if not item.startswith(self.sep):
            item = self.sep + item
        item_list = item.split(self.sep)
        return item_list

    def _pad(self, ids: List[int]) -> List[int]:
        """Right-pad ``ids`` in place with 0 up to ``max_len`` (no-op when max_len is None)."""
        if self.max_len is not None:
            # parenthesised: `[0] * self.max_len - len(ids)` would subtract an int from a list
            ids.extend([0] * (self.max_len - len(ids)))
        return ids

    def add_item(self, item: str) -> List[int]:
        """Add every level of ``item`` (creating level dictionaries as needed) and return its IDs."""
        item_list = self._str2list(item)
        ret = []
        for i in range(1, len(item_list)):
            parent = self.sep.join(item_list[:i])
            ret.append(self.add_dict(parent).add_item(item_list[i]))
        return self._pad(ret)

    def get_idx_for_item(self, item: str) -> List[int]:
        """Encode ``item`` as per-level IDs without modifying the dictionaries."""
        item_list = self._str2list(item)
        ret = []
        for i in range(1, len(item_list)):
            parent = self.sep.join(item_list[:i])
            ret.append(self.get_dict(parent).get_idx_for_item(item_list[i]))
        return self._pad(ret)

    def get_item_for_index(self, ids: List[int]) -> str:
        """
        Decode per-level IDs back into a separator-prefixed path such as ``';a;b'``.

        Decoding stops at the first 0 (PAD) or when the current path has no child dictionary;
        an out-of-range ID on a level is decoded as ID 1 of that level.
        """
        cur_dict_name = ''
        for idx in ids:
            if idx == 0 or cur_dict_name not in self.dict_set:
                # PAD reached, or no level below the current path: stop decoding
                break
            elif idx >= len(self.dict_set[cur_dict_name]):
                # out of vocabulary: decode as the first real class of this level
                idx = 1

            cur_dict_name += self.sep + self.dict_set[cur_dict_name].get_item_for_index(idx)
        return cur_dict_name

    def to_dict(self):
        """Return ``{parent_name: Dictionary.to_dict()}`` for every level dictionary."""
        return {k: v.to_dict() for k, v in self.dict_set.items()}

    @classmethod
    def from_dict(cls, dict_obj: Dict, **kwargs):
        """Rebuild from :meth:`to_dict` output; ``kwargs`` go to the constructor."""
        d = cls(**kwargs)
        d.dict_set = {n: Dictionary.from_dict(dictionary) for n, dictionary in dict_obj.items()}
        return d

    def save(self, file):
        """Pickle :meth:`to_dict` into ``file``."""
        import pickle
        with open(file, 'wb') as f:
            pickle.dump(self.to_dict(), f)

    def show(self):
        """Print all level dictionaries."""
        print(self.to_dict())

    @property
    def full_name_set(self):
        """Map each parent path to the set of full child paths (PAD/UNK entries included)."""
        return {k: set(k + self.sep + vv for vv in v.idx2item) for k, v in self.dict_set.items()}

    @property
    def full_name_list(self):
        """Sorted list of every full path across all levels."""
        name_sets = reduce(lambda x, y: x | y, self.full_name_set.values(), set())
        names = sorted(list(name_sets))
        return names

    def __iter__(self):
        for n in self.full_name_list:
            yield n

    def deep(self):
        """Number of levels (raises ValueError when there are no dictionaries)."""
        return max(k.count(self.sep) for k in self.dict_set.keys()) + 1

    def width(self):
        """Size of the largest level dictionary (raises ValueError when there are none)."""
        return max(len(v) for v in self.dict_set.values())

    @classmethod
    def load_from_file(cls, filename: str, **kwargs):
        """Load a file written by :meth:`save`; ``kwargs`` go to the constructor."""
        import pickle
        # NOTE: pickle.load can execute arbitrary code; only load files from trusted sources.
        with open(filename, 'rb') as f:
            mappings = pickle.load(f)
        dictionary: HierarchicalDictionary = cls.from_dict(mappings, **kwargs)
        return dictionary

    @classmethod
    def load(cls, file: str, **kwargs):
        """Alias of :meth:`load_from_file`."""
        return cls.load_from_file(file, **kwargs)


class Label:
    """
    A label value together with a confidence score.

    The score is expected to lie in [0.0, 1.0]; any score outside that range is silently replaced
    by the default of 1.0.
    """

    def __init__(self, value: str, score: float = 1.0):
        self.value = value
        self.score = score
        super().__init__()

    @property
    def value(self):
        """The label's value."""
        return self._value

    @value.setter
    def value(self, value):
        self._value = value

    @property
    def score(self):
        """Confidence score; always a float within [0.0, 1.0]."""
        return self._score

    @score.setter
    def score(self, score):
        in_range = 0.0 <= score <= 1.0
        self._score = float(score) if in_range else 1.0

    def to_dict(self):
        """Serialize as ``{'value': ..., 'confidence': ...}``."""
        return dict(value=self.value, confidence=self.score)

    @classmethod
    def from_dict(cls, label_dict):
        """
        Build a Label from a plain string (default confidence) or from a dict with a ``value``
        key and an optional ``confidence`` key.
        """
        if isinstance(label_dict, str):
            return cls(label_dict)
        return cls(label_dict['value'], label_dict['confidence'] if 'confidence' in label_dict else 1.0)

    def __str__(self):
        return f"{self._value}"

    def __repr__(self):
        return f"{self._value}"


class Token:
    """
    One word of a tokenized sentence, optionally carrying any number of named tags.

    When a start position is given, ``end_pos = start_pos + len(text)``; otherwise both stay None.
    """
    TOKEN_TAG = '__TOKEN'  # built-in tag name that resolves to token.text; it is never stored in `tags`.

    def __init__(self,
                 text: str,
                 idx: int = None,
                 start_position: int = None
                 ):
        self.text: str = text
        self.idx: int = idx

        self.start_pos = start_position
        if start_position is None:
            self.end_pos = None
        else:
            self.end_pos = start_position + len(text)

        self.tags: Dict[str, Label] = {}

    def add_tag(self, tag_type: str, tag_value: str, confidence=1.0):
        """Set (or overwrite) the tag ``tag_type``."""
        self.tags[tag_type] = Label(tag_value, confidence)

    def get_tag(self, tag_type: str) -> str:
        """
        Return the value of tag ``tag_type``; ``TOKEN_TAG`` yields the token text.

        :raises ValueError: if the token has no such tag
        """
        if tag_type not in self.tags:
            if tag_type == self.TOKEN_TAG:
                return self.text
            raise ValueError('no `%s` tag in token.' % tag_type)
        return self.tags[tag_type].value

    def has_tag(self, tag_type: str) -> bool:
        return tag_type in self.tags

    def remove_tag(self, tag_type: str) -> None:
        """Remove tag ``tag_type``; raises KeyError if it is absent."""
        del self.tags[tag_type]

    @property
    def start_position(self) -> int:
        return self.start_pos

    @property
    def end_position(self) -> int:
        return self.end_pos

    def __str__(self) -> str:
        if self.idx is None:
            head = 'Token: {}'.format(self.text)
        else:
            head = 'Token: {} {}'.format(self.idx, self.text)
        return head + ''.join(' %s="%s"' % (name, label) for name, label in self.tags.items())

    def __repr__(self) -> str:
        return self.__str__()

    def __len__(self):
        return len(self.text)

    def to_dict(self):
        """Serialize text, index, positions and tags into a plain dict."""
        serialized_tags = {name: label.to_dict() for name, label in self.tags.items()}
        return dict(idx=self.idx, text=self.text, start_pos=self.start_pos, end_pos=self.end_pos,
                    tags=serialized_tags)

    @classmethod
    def from_dict(cls, token_dict):
        """Inverse of :meth:`to_dict`; ``text``, ``idx`` and ``start_pos`` keys are required."""
        token = cls(text=token_dict['text'], idx=token_dict['idx'], start_position=token_dict['start_pos'])
        if 'end_pos' in token_dict:
            token.end_pos = token_dict['end_pos']
        if 'tags' in token_dict:
            token.tags = {name: Label.from_dict(raw) for name, raw in token_dict['tags'].items()}
        return token


class Sentence:
    """
    A Sentence is a list of Tokens and is used to represent a sentence or text fragment.

    Tokens are joined into :attr:`text` without a separator for ``LANG.ZH`` and with a single space
    otherwise. Start positions assigned by this class are character offsets into that text.
    """

    def __init__(self, words: Union[List[str], str, Generator] = None, label_name: str = None,
                 label: Union[Label, str] = None, lang: LANG = LANG.ZH):
        """
        :param words: token strings. A plain ``str`` is split into characters; any other iterable is
            consumed once. ``None`` creates an empty sentence.
        :param label_name: name of an optional sentence-level label; only used together with ``label``.
        :param label: value (or :class:`Label`) of that label.
        :param lang: sentence language; decides the separator used by :attr:`text`.
        """
        super().__init__()
        self.lang: LANG = lang

        self.tokens: List[Token] = []
        sep_len = len(self._partition())
        position = 0
        for idx, word in enumerate(words if words is not None else []):
            if idx:
                # `text` puts a separator before every token but the first; keep offsets aligned with it
                position += sep_len
            self.tokens.append(Token(word, idx, position))
            position += len(word)

        self.labels: Dict[str, Label] = {}
        if label_name is not None and label is not None:
            self.set_label(label_name, label)

    def add_token(self, token: Union[Token, str]):
        """
        Append a token.

        :param token: a str or a Token. A Token is appended as-is; a str becomes a new tag-less Token
            whose index and start position follow the current last token.
        """
        if isinstance(token, Token):
            self.tokens.append(token)
        else:
            self.tokens.append(Token(token, len(self.tokens), self._next_start_position()))

    def _next_start_position(self) -> int:
        """Offset in :attr:`text` at which a token appended now would start."""
        if not self.tokens:
            return 0
        return len(self.text) + len(self._partition())

    @property
    def text(self):
        """Plain text of the sentence; the token separator depends on :attr:`lang`."""
        return self._partition().join([t.text for t in self.tokens])

    def _partition(self) -> str:
        """Separator between tokens: empty for Chinese, a single space otherwise."""
        return '' if self.lang == LANG.ZH else ' '

    def get_token(self, token_id: int) -> Token:
        return self.tokens[token_id]

    # label part
    def set_label(self, label_name: str, label: Union[Label, str], confidence: float = 1.0):
        """
        Set a sentence-level label. These calls are equivalent:

        set_label('category', 'lit', 1.0)
        set_label('category', 'lit')
        set_label('category', Label('lit'))
        set_label('category', Label('lit', 1.0))

        :param label_name: label name
        :param label: a Label (stored as-is, ``confidence`` ignored) or a plain value
        :param confidence: score used when ``label`` is a plain value; default 1.0
        """
        if isinstance(label, Label):
            self.labels[label_name] = label
        else:
            self.labels[label_name] = Label(label, confidence)

    def get_label(self, label_name: str = 'category') -> Label:
        return self.labels[label_name]

    def get_label_name(self, label_name: str = 'category') -> str:
        """Return the value of label ``label_name``."""
        return self.labels[label_name].value

    def index(self, word: str):
        """Return the position of the first token whose text equals ``word``, or -1."""
        for tok_id, tok in enumerate(self.tokens):
            if tok.text == word:
                return tok_id
        return -1

    def __getitem__(self, idx: int) -> Token:
        return self.get_token(idx)

    def __iter__(self):
        return iter(self.tokens)

    def __repr__(self):
        return 'Sentence: "{}" - {} Tokens'.format(self.text, len(self))

    def __copy__(self):
        """Copy labels and tokens (text, index, positions and tags) into a new Sentence."""
        s = Sentence(None, lang=self.lang)
        for label_name, label in self.labels.items():
            s.set_label(label_name, Label(label.value, label.score))
        for token in self.tokens:
            # keep idx/positions: the previous Token(token.text) silently dropped them
            nt = Token(token.text, token.idx, token.start_pos)
            nt.end_pos = token.end_pos
            for tag_type, tag_label in token.tags.items():
                nt.add_tag(tag_type, tag_label.value, tag_label.score)
            s.add_token(nt)
        return s

    def __str__(self) -> str:
        return 'Sentence: "{}" - {} Tokens'.format(self.text, len(self))

    def __len__(self) -> int:
        return len(self.tokens)

    def to_dict(self):
        """Serialize language (by name), tokens and labels into a plain dict."""
        return {
            'lang': self.lang.name,
            'tokens': [t.to_dict() for t in self.tokens],
            'labels': {k: v.to_dict() for k, v in self.labels.items()},
        }

    @staticmethod
    def _lang_from_serialized(value) -> LANG:
        """
        Map a serialized language back to LANG.

        ``to_dict`` writes ``lang.name``, so look the name up first and fall back to value lookup
        for data written with the enum value. NOTE(review): assumes LANG is an ``enum.Enum``.
        """
        try:
            return LANG[value]
        except (KeyError, TypeError):
            return LANG(value)

    @classmethod
    def from_dict(cls, sentence_dict):
        """
        Inverse of :meth:`to_dict`. Missing token ``start_pos``/``idx`` values are filled in from the
        tokens already read; the input dict is not modified.
        """
        lang = cls._lang_from_serialized(sentence_dict['lang']) if 'lang' in sentence_dict else LANG.ZH
        stn = cls(lang=lang)
        if 'tokens' in sentence_dict:
            for raw_token in sentence_dict['tokens']:
                t = dict(raw_token)  # copy: filling defaults must not mutate the caller's data
                if 'start_pos' not in t:
                    t['start_pos'] = stn._next_start_position()
                if 'idx' not in t:
                    t['idx'] = len(stn)
                stn.tokens.append(Token.from_dict(t))
        if 'labels' in sentence_dict:
            # `.items()`: iterating the dict directly yields only keys and cannot be unpacked
            stn.labels = {k: Label.from_dict(v) for k, v in sentence_dict['labels'].items()}
        return stn

    def to_json(self):
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_string: str):
        return cls.from_dict(json.loads(json_string))

