#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
from typing import List, Union, Dict, Tuple
import numpy as np
import torch
from torch import nn, Tensor

from dlab.data.structures import Sentence, Dictionary, Token
from dlab.data.utils import pad_batch, read_embedding
from dlab.log import log
from .base_embedder import Embedder


class NormalEmbedder(Embedder):
    """
    Plain embedder: a thin wrapper around a ``torch.nn.Embedding`` layer.

    Supports loading pre-trained vectors and choosing which token feature
    (the token text or a named tag) is used as the lookup key.
    """
    def __init__(self, dim: int, dictionary: Dictionary, pre_trained=None,
                 tag_name: Union[str, None]=None, trainable: bool=True, **kwargs):
        """
        Build the embedding table and optionally initialise it from disk.

        :param dim: embedding dimension.
        :param dictionary: dictionary mapping items to row indices.
        :param pre_trained: path to load pretrained embeddings, None or '' to not load.
        :param tag_name: name of the token tag to embed. None means the token text itself.
        :param trainable: whether the embedding weights receive gradients.
        :param kwargs: forwarded unchanged to the base ``Embedder`` constructor.
        """
        super().__init__(dim, **kwargs)
        self.trainable = trainable
        self.dictionary: Dictionary = dictionary
        self.tag_name = tag_name
        self.pad = self.dictionary.add_pad
        # NOTE(review): assumes the dictionary stores its pad item at index 0
        # whenever ``add_pad`` is truthy -- confirm against Dictionary.
        self.model = nn.Embedding(len(self.dictionary), self.dim, padding_idx=0 if self.pad else None)
        if pre_trained:
            self.fill_embedding(read_embedding(pre_trained))
        if not trainable:
            self.model.weight.requires_grad = False

    def fill_embedding(self, pretrained_embedding_dict: Dict[str, List[float]]):
        """
        Overwrite rows of the embedding matrix with pre-trained vectors.

        Only items present in both the dictionary and ``pretrained_embedding_dict``
        are replaced; every other row keeps its current value.

        :param pretrained_embedding_dict: mapping from item string to its vector.
        """
        weight = self.model.weight
        # Go through a CPU array so this also works after the module was moved
        # to a GPU (``numpy()`` is only available for CPU tensors).
        matrix = weight.detach().cpu().numpy()
        embedding_size = len(self.dictionary)
        pretrained_size = len(pretrained_embedding_dict)
        filled_counter = 0
        for token in self.dictionary.item2idx:
            if token in pretrained_embedding_dict:
                matrix[self.dictionary.get_idx_for_item(token)] = pretrained_embedding_dict[token]
                filled_counter += 1
        # ``copy_`` transfers the values back to whatever device the weight lives on.
        with torch.no_grad():
            weight.copy_(torch.from_numpy(matrix))
        # Guard the ratios: an empty dictionary or an empty pretrained file must
        # not crash with ZeroDivisionError just to produce a log line.
        dictionary_ratio = filled_counter / embedding_size * 100 if embedding_size else 0.0
        pretrained_ratio = filled_counter / pretrained_size * 100 if pretrained_size else 0.0
        log.debug('fill %d of embedding matrix [%dx%d] (%.2f%%) with pretrained [%dx%d] (%.2f%%).' %
                  (filled_counter, embedding_size, self.dim, dictionary_ratio,
                   pretrained_size, self.dim, pretrained_ratio))

    def _get_token_string(self, token: Token) -> str:
        """Return the lookup key for ``token``: its text, or the configured tag."""
        if self.tag_name is None:  # default for text
            return token.text
        return token.get_tag(self.tag_name)

    def _sentence2token_string(self, sentence: Sentence) -> List[str]:
        """Return the lookup key of every token in ``sentence``, in order."""
        return [self._get_token_string(token) for token in sentence]

    def embed(self, batch_sentence: List[Sentence]) -> Tuple[Tensor, List[int]]:
        """
        Embed a batch of sentences.

        :param batch_sentence: sentences to embed.
        :return: the embedded padded batch and the sequence lengths reported by
            ``pad_batch``. Presumably the tensor is (batch, max_len, dim) -- TODO confirm.
        :raises NotImplementedError: if the dictionary does not support padding.
        """
        if not self.pad:
            raise NotImplementedError('Sentence must be pad and the input dictionary did not support pad.')
        batch_token_id = [
            [self.dictionary.get_idx_for_item(token_string) for token_string in self._sentence2token_string(sentence)]
            for sentence in batch_sentence
        ]
        pad_batch_token_id, seq_len = pad_batch(batch_token_id, return_length=True)
        # Create the indices on the same device as the embedding weights. Keying
        # on ``torch.cuda.is_available()`` breaks when CUDA exists but the model
        # was left on the CPU (or lives on a non-default GPU).
        idx_tensor = torch.tensor(pad_batch_token_id, dtype=torch.long, device=self.model.weight.device)
        pad_batch_present = self.model(idx_tensor)
        return pad_batch_present, seq_len


class NGramEmbedder(NormalEmbedder):
    """
    Embedder whose lookup keys are n-grams of consecutive token strings.

    Each position in a sentence yields one n-gram ending at that position;
    the left edge is filled with ``ngram_pad`` so the output has one key per
    token.
    """
    def __init__(self, dim: int, dictionary: Dictionary, ngram=1, ngram_spliter='', ngram_pad='B',
                 pre_trained=None, tag_name: Union[str, None]=None, trainable: bool=True, **kwargs):
        """
        :param dim: embedding dimension.
        :param dictionary: dictionary mapping n-gram strings to row indices.
        :param ngram: number of token strings joined into one key.
        :param ngram_spliter: separator placed between the joined token strings.
        :param ngram_pad: filler string used left of the first token.
        :param pre_trained: path to load pretrained embeddings, None or '' to not load.
        :param tag_name: name of the token tag to embed. None means the token text itself.
        :param trainable: whether the embedding weights receive gradients.
        :param kwargs: forwarded unchanged to the parent constructor.
        """
        super().__init__(
            dim,
            dictionary,
            pre_trained=pre_trained,
            tag_name=tag_name,
            trainable=trainable,
            **kwargs
        )
        self.ngram_pad = ngram_pad
        self.ngram_spliter = ngram_spliter
        self.ngram = ngram

    def _sentence2token_string(self, sentence: Sentence) -> List[str]:
        """Return one n-gram key per token of ``sentence``, left-padded with ``ngram_pad``."""
        width = self.ngram
        # Prefix with (n - 1) pad strings so the first token already has a full window.
        padded = [self.ngram_pad] * (width - 1)
        padded.extend(self._get_token_string(token) for token in sentence)
        grams = []
        for start in range(len(sentence)):
            window = padded[start:start + width]
            grams.append(self.ngram_spliter.join(window))
        return grams

