#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
import os
import sys
from typing import List, Tuple

import torch
from torch import Tensor
from pytorch_pretrained_bert import BertModel, BertTokenizer

from dlab.data.structures import Sentence
from dlab.data.utils import pad_batch
from dlab.utils.bert import get_bert_model
from .base_embedder import Embedder


class BertAddEmbedder(Embedder):
    """Word-level embedder that averages BERT word-piece vectors per word.

    The sentence text is tokenized by the BERT tokenizer and encoded by
    ``BertModel``; a (batch, word, char) weight matrix is then multiplied with
    the encoder output so that every word receives the mean of the vectors of
    its own word pieces.
    """

    # Output dimension for each supported pretrained model name.
    __bert_pretrained_dims = {
        'bert-base-chinese': 768
    }

    def __init__(self, model_name: str, **kwargs):
        """
        :param model_name: pretrained BERT model name; must be a key of the
                           dimension table above.
        :param kwargs: forwarded unchanged to the ``Embedder`` constructor.
        :raises KeyError: if ``model_name`` is not in the dimension table.
        """
        hidden_dim = self.__bert_pretrained_dims[model_name]
        super().__init__(hidden_dim, **kwargs)
        self.trainable = False
        self.model_name = model_name
        self.tokenizer = BertTokenizer.from_pretrained(self.model_name)
        self.model = BertModel.from_pretrained(self.model_name)

    def char_2_word_transform_matrix(self, batch_sentence: List[Sentence], seq_char_length: List[int]):
        """Build the weights that pool word-piece vectors into word vectors.

        Entry ``[b][w][c]`` is ``1 / k`` when word piece ``c`` belongs to word
        ``w`` of sentence ``b`` (``k`` being the number of word pieces of that
        word) and ``0.0`` otherwise.

        :param batch_sentence: sentences of the batch.
        :param seq_char_length: word-piece length of every sentence; only its
                                maximum is used, as the size of the last axis.
        :return: FloatTensor of size [batch, max words, max word pieces],
                 created as a CUDA tensor when CUDA is available.
        """
        n_sentences = len(batch_sentence)
        widest_sentence = max(len(sentence) for sentence in batch_sentence)
        longest_char_seq = max(seq_char_length)
        weights = [
            [[0.0] * longest_char_seq for _ in range(widest_sentence)]
            for _ in range(n_sentences)
        ]

        for sentence_id, sentence in enumerate(batch_sentence):
            offset = 0  # index of the first word piece of the current word
            for word_id, word in enumerate(sentence):
                piece_count = len(self.tokenizer.tokenize(word.text))
                for piece_pos in range(offset, offset + piece_count):
                    # Best effort: per-word tokenization may yield more pieces
                    # than the matrix has columns; such positions are reported
                    # and skipped instead of aborting the batch.
                    try:
                        weights[sentence_id][word_id][piece_pos] = 1.0 / piece_count
                    except IndexError as e:
                        print(e)
                offset += piece_count

        float_tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
        return float_tensor(weights)

    def embed(self, batch_sentence: List[Sentence]) -> Tuple[Tensor, List[int]]:
        """Embed a batch of sentences, one vector per word.

        :param batch_sentence: sentences of the batch.
        :return: tuple of (word representations produced by the pooling
                 matmul, number of words of every sentence).
        """
        tokenizer = self.tokenizer
        batch_tokens = []
        for sentence in batch_sentence:
            pieces = tokenizer.tokenize(sentence.text)
            batch_tokens.append(tokenizer.convert_tokens_to_ids(pieces))
        seq_word_length = [len(sentence) for sentence in batch_sentence]

        padded_ids, seq_char_length = pad_batch(batch_tokens, return_length=True)
        long_tensor = torch.cuda.LongTensor if torch.cuda.is_available() else torch.LongTensor
        # NOTE(review): no attention mask is passed, so padded positions take
        # part in self-attention -- confirm this is intended.
        encoder_output, _ = self.model(long_tensor(padded_ids))

        transform = self.char_2_word_transform_matrix(batch_sentence, seq_char_length)
        # NOTE(review): with the library default the model presumably returns
        # a list with one tensor per encoder layer, so [0] would be the FIRST
        # layer rather than the last -- TODO confirm.
        word_vectors: Tensor = torch.matmul(transform, encoder_output[0])
        # todo check here, no update
        return word_vectors, seq_word_length


class BertEmbedderBase(Embedder):
    """Common base of the BERT embedders; owns the storage of the BERT model.

    When ``finetune`` is true the model is stored directly on the instance.
    Otherwise it is wrapped in a one-element list so that it is not a direct
    attribute of the class and PyTorch will not track it.
    """

    def __init__(self, dim, finetune=False, **kwargs):
        """
        :param dim: forwarded as the first argument of the ``Embedder``
                    constructor.
        :param finetune: selects how the model is stored, see the class
                         docstring.
        :param kwargs: forwarded unchanged to the ``Embedder`` constructor.
        """
        super().__init__(dim, **kwargs)
        self.finetune = finetune
        self.model = None

    def _set_model(self, model: BertModel):
        """Store ``model``, bare when finetuning and list-wrapped otherwise."""
        # The list wrapper keeps the model beside the field of the class, so
        # that the pytorch model will not track it.
        self.model = model if self.finetune else [model]

    def _get_model(self):
        """Return the stored model, unwrapping the list when not finetuning."""
        stored = self.model
        return stored if self.finetune else stored[0]

    def embed(self, batch_sentence: List[Sentence]) -> Tuple[Tensor, List[int]]:
        """Embed a batch of sentences; must be provided by subclasses.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError()


class BertEmbedder(BertEmbedderBase):
    """
    Word-based BERT embedder.

    Sentences may use words as tokens. While tokenizing, the position of the
    first word piece of every word is recorded, and the output consists of the
    BERT hidden states at those recorded positions, so the output has exactly
    one timestep per input token.
    """
    def __init__(self, model_name: str, do_lower_case=True, finetune=False, **kwargs):
        """

        :param model_name: bert pretrained model name OR bert model dir
                           IF bert model dir, file 'bert_config.json', 'pytorch_model.bin' and 'vocab.txt'
                            must in the directory.
        :param do_lower_case: forwarded to ``get_bert_model``.
        :param finetune: True for finetune. False for no finetune and will not save.
        :param kwargs: forwarded to both ``get_bert_model`` and the base class
                       constructor.
        """

        self.trainable = False
        self.model_name = model_name
        # The model has to be loaded before the base constructor runs, because
        # the embedding dimension is read from the model configuration.
        self.tokenizer, model = get_bert_model(BertModel, self.model_name, do_lower_case, **kwargs)
        super().__init__(model.config.hidden_size, finetune=finetune, **kwargs)
        self._set_model(model)
        self._get_model().eval()
        self.gpu = torch.cuda.is_available()
        if self.gpu:
            self._set_model(self._get_model().cuda())
        # Namespace whose LongTensor/ByteTensor constructors create tensors on
        # the device the model lives on.
        self.torch = torch.cuda if self.gpu else torch

    def _get_bert_token_and_map(self, batch_sentence: List[Sentence]):
        """Tokenize every sentence word by word.

        :param batch_sentence: sentences of the batch.
        :return: tuple of (token ids per sentence, wrapped in [CLS] ... [SEP];
                 per sentence, the index of each word's first word piece
                 within those ids).
        """
        tokens, orig_to_tok_map = [], []
        for sentence in batch_sentence:
            bert_tokens = ['[CLS]']
            bert_tokens_map = []
            for token in sentence:
                # NOTE(review): a token that tokenizes to nothing is mapped to
                # the position of whatever comes next (the following word or
                # [SEP]) -- confirm this is acceptable.
                bert_tokens_map.append(len(bert_tokens))
                bert_tokens.extend(self.tokenizer.tokenize(token.text))
            bert_tokens.append('[SEP]')
            tokens.append(self.tokenizer.convert_tokens_to_ids(bert_tokens))
            orig_to_tok_map.append(bert_tokens_map)
        return tokens, orig_to_tok_map

    def _get_mask(self, seq_len: List[int], max_len: int):
        """Build the attention mask of a padded batch.

        :param seq_len: real (unpadded) length of every sequence.
        :param max_len: padded length, i.e. number of mask columns.
        :return: ByteTensor [len(seq_len), max_len]; position ``i`` of a row is
                 set when ``i`` is smaller than that row's length.
        """
        mask = [[i < length for i in range(max_len)] for length in seq_len]
        return self.torch.ByteTensor(mask)

    def _check_sequence_length(self, batch_sentence: List[Sentence], batch_tokens_idx: List[List[int]]):
        """Reject sentences with more word pieces than the model has positions.

        :raises ValueError: if a tokenized sentence (including [CLS] and [SEP])
                            is longer than ``config.max_position_embeddings``.
        """
        limit = getattr(self._get_model().config, 'max_position_embeddings', None)
        if limit is None:
            # Unknown limit: keep the previous behaviour and let the model decide.
            return
        for sentence, token_ids in zip(batch_sentence, batch_tokens_idx):
            if len(token_ids) > limit:
                raise ValueError(
                    'bert embedder can not encode the sentence longer than %d word pieces '
                    '(including [CLS] and [SEP]), got %d as: %s' % (limit, len(token_ids), sentence.text))

    def embed(self, batch_sentence: List[Sentence]) -> Tuple[Tensor, List[int]]:
        """Embed a batch of sentences, one vector per token.

        :param batch_sentence: sentences of the batch.
        :return: tuple of (hidden states gathered at the first word piece of
                 every token, number of tokens of every sentence).
        :raises ValueError: if a sentence yields more word pieces than the
                            model supports.
        """
        batch_tokens_idx, bert_char2word_map = self._get_bert_token_and_map(batch_sentence)
        # Fail early with a readable message instead of an index error raised
        # from inside the position embedding lookup.
        self._check_sequence_length(batch_sentence, batch_tokens_idx)
        pad_batch_char_tokens, seq_char_length = pad_batch(batch_tokens_idx, return_length=True)
        pad_batch_char_tokens_tensor = self.torch.LongTensor(pad_batch_char_tokens)
        mask = self._get_mask(seq_char_length, max(seq_char_length))
        pad_batch_char_present, _ = self._get_model()(
            pad_batch_char_tokens_tensor, attention_mask=mask, output_all_encoded_layers=False)
        # Expand the per-token positions over the hidden axis so that gather
        # picks the whole hidden vector at each recorded position.
        bert_char2word_map_tensor = self.torch.LongTensor(pad_batch(bert_char2word_map))
        b_size, s_size = bert_char2word_map_tensor.size()
        bert_char2word_map_tensor = bert_char2word_map_tensor.view(b_size, s_size, 1).repeat(
            1, 1, pad_batch_char_present.size(2))
        pad_batch_word_present = torch.gather(pad_batch_char_present, 1, bert_char2word_map_tensor)
        return pad_batch_word_present, [len(sentence) for sentence in batch_sentence]


class BertCharEmbedder(BertEmbedder):
    """Character-based BERT embedder.

    Sentences must use single characters as tokens. Every character is turned
    into exactly one BERT token; characters that the tokenizer maps to nothing
    become [UNK]. The output therefore has one timestep per input character.
    """
    def tokenize_one_char(self, text: str) -> str:
        """Return the first word piece of ``text``, or '[UNK]' if there is none."""
        pieces = self.tokenizer.tokenize(text)
        if not pieces:
            return '[UNK]'
        return pieces[0]

    def embed(self, batch_sentence: List[Sentence]) -> Tuple[Tensor, List[int]]:
        """Embed a batch of sentences, one vector per character.

        :param batch_sentence: sentences of the batch.
        :return: tuple of (hidden states with the first and the last position
                 of the padded sequence removed, number of characters of every
                 sentence).
        :raises ValueError: if a sentence has more than 510 characters.
        """
        for s in batch_sentence:
            if len(s) <= 510:
                continue
            raise ValueError('bert embedder can not encode the sentence longer than 510 characters as: %s' % s.text)

        batch_token_idx = []
        for sentence in batch_sentence:
            pieces = ['[CLS]']
            pieces += [self.tokenize_one_char(tok.text) for tok in sentence]
            pieces += ['[SEP]']
            batch_token_idx.append(self.tokenizer.convert_tokens_to_ids(pieces))

        # Every length in seq_char_length includes the 2 extra tokens.
        padded_ids, seq_char_length = pad_batch(batch_token_idx, return_length=True)
        ids_tensor = self.torch.LongTensor(padded_ids)
        mask = self._get_mask(seq_char_length, max(seq_char_length))
        bert = self._get_model()
        hidden, _ = bert(ids_tensor, attention_mask=mask, output_all_encoded_layers=False)

        # Drop position 0 ([CLS]) and the last padded position.
        kept_steps = hidden.size(1) - 2
        hidden = hidden.narrow(1, 1, kept_steps)
        return hidden, [len(sentence) for sentence in batch_sentence]


