#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
import torch
from typing import List
import os

from pytorch_pretrained_bert import BertModel, BertForTokenClassification
from torch import Tensor, nn

from dlab.data import Sentence, Dictionary, pad_batch
from dlab.eval.entity_decoder import EntityDecoder
from dlab.utils.bert import get_bert_model, save_bert_model, load_bert_model, get_bert_token_and_map

from dlab.tasks import SequenceTagger


class BertSequenceTagger(SequenceTagger):
    """Sequence tagger that feeds one BERT token per sentence token.

    Every token of a sentence is mapped to a single vocabulary entry (see
    ``tokenize_one_char``), the sequence is wrapped in ``[CLS]`` / ``[SEP]``
    and passed through the model returned by ``get_bert_model`` for
    ``BertForTokenClassification``.
    """

    def __init__(self, tag_dictionary: Dictionary, target_tag: str,
                 bert_model_name: str, do_lower_case=True,
                 use_crf=False, bio_decoders: List[EntityDecoder] = None,
                 use_f1=True,
                 ):
        """Set up the base tagger, then load the tokenizer and BERT model.

        :param tag_dictionary: forwarded to the base ``SequenceTagger``.
        :param target_tag: forwarded to the base ``SequenceTagger``.
        :param bert_model_name: model identifier handed to ``get_bert_model``.
        :param do_lower_case: lower-casing flag handed to ``get_bert_model``.
        :param use_crf: forwarded to the base ``SequenceTagger``.
        :param bio_decoders: forwarded to the base ``SequenceTagger``.
        :param use_f1: forwarded to the base ``SequenceTagger``.
        """
        super().__init__(
            tag_dictionary,
            target_tag,
            use_crf=use_crf,
            bio_decoders=bio_decoders,
            use_f1=use_f1,
        )
        self.model_name: str = bert_model_name
        self.do_lower_case = do_lower_case
        # The classification head gets one output per entry of the target
        # tag dictionary (read from ``self.target_tag_dictionary``).
        label_count = len(self.target_tag_dictionary)
        self.tokenizer, self.bert_model = get_bert_model(
            BertForTokenClassification,
            self.model_name,
            do_lower_case,
            num_labels=label_count,
        )

    def forward(self, batch_sentence: List[Sentence]) -> Tensor:
        """Run the BERT model on a batch of sentences.

        :param batch_sentence: sentences to encode; each may hold at most
            510 tokens.
        :return: the model output with the first and the last position of
            dimension 1 removed.
        :raises ValueError: if any sentence is longer than 510 tokens.
        """
        for sentence in batch_sentence:
            if len(sentence) <= 510:
                continue
            raise ValueError(
                'bert embedder can not encode the sentence longer than 510 characters as: %s' % sentence.text)

        # Build the id sequence of every sentence: [CLS] + one piece per token + [SEP].
        ids_per_sentence = []
        for sentence in batch_sentence:
            pieces = ['[CLS]']
            pieces.extend(self.tokenize_one_char(token.text) for token in sentence)
            pieces.append('[SEP]')
            ids_per_sentence.append(self.tokenizer.convert_tokens_to_ids(pieces))

        # Per the original note, each length returned here is the sentence
        # length plus 2 (the added [CLS] and [SEP]).
        padded_ids, padded_lengths = pad_batch(ids_per_sentence, return_length=True)
        input_ids = self.torch.LongTensor(padded_ids)
        longest = max(padded_lengths)
        attention_mask = self._get_mask(ids_per_sentence, longest)

        model_output = self.bert_model(input_ids, attention_mask=attention_mask)
        # Drop the first and the last column along dimension 1.
        kept_length = model_output.size(1) - 2
        return model_output.narrow(1, 1, kept_length)

    def tokenize_one_char(self, text: str) -> str:
        """Return the first piece the tokenizer yields for ``text``.

        Falls back to ``'[UNK]'`` when the tokenizer yields nothing.
        """
        pieces = self.tokenizer.tokenize(text)
        if pieces:
            return pieces[0]
        return '[UNK]'

    def save(self, path, model_name=None):
        """Persist the BERT model and tokenizer under ``path``.

        ``model_name`` is accepted but not used by this implementation.
        """
        save_bert_model(path, self.bert_model, self.tokenizer)

    def load(self, path, model_name=None):
        """Replace the tokenizer and BERT model with the ones stored at ``path``.

        ``model_name`` is accepted but not used by this implementation.
        """
        label_count = len(self.target_tag_dictionary)
        self.tokenizer, self.bert_model = load_bert_model(
            path,
            BertForTokenClassification,
            self.do_lower_case,
            num_labels=label_count,
        )


class BertWordSequenceTagger(BertSequenceTagger):
    """Variant of ``BertSequenceTagger`` that returns one output row per word.

    The BERT model is run on the token ids produced by
    ``get_bert_token_and_map``; its output is then re-indexed along
    dimension 1 with the map that helper returns.
    """

    def forward(self, batch_sentence: List[Sentence]) -> Tensor:
        """Run BERT on the batch and gather its output at the mapped positions.

        :param batch_sentence: sentences to encode.
        :return: the model output gathered along dimension 1 using the padded
            map from ``get_bert_token_and_map`` (presumably one selected
            position per word -- confirm against that helper).
        """
        piece_ids, piece_to_word = get_bert_token_and_map(self.tokenizer, batch_sentence)

        # The sequence lengths returned by pad_batch are not needed here.
        padded_ids, _, padded_mask = pad_batch(piece_ids, return_length=True, return_mask=True)
        input_ids = self.torch.LongTensor(padded_ids)
        attention_mask = self.torch.ByteTensor(padded_mask)
        piece_scores = self.bert_model(input_ids, attention_mask=attention_mask)

        # Expand the 2-D index map to 3-D so that the whole last dimension of
        # the model output is selected for every mapped position.
        gather_index = self.torch.LongTensor(pad_batch(piece_to_word))
        batch_size, word_count = gather_index.size()
        last_dim = piece_scores.size(2)
        gather_index = gather_index.view(batch_size, word_count, 1)
        gather_index = gather_index.repeat(1, 1, last_dim)

        return torch.gather(piece_scores, dim=1, index=gather_index)
