#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
import codecs
import os
import torch
from typing import List

from pytorch_pretrained_bert import BertTokenizer, BertForSequenceClassification
from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP, CONFIG_NAME, WEIGHTS_NAME
from torch import Tensor
from torch.nn import CrossEntropyLoss

from dlab.data import Dictionary, Sentence, pad_batch
from dlab.utils.bert import get_bert_model, save_bert_model, load_bert_model
from dlab.tasks import TextClassifier


class BertCharBaseClassifier(TextClassifier):
    """
    Character-based BERT text classifier.

    Wraps a ``pytorch_pretrained_bert`` ``BertForSequenceClassification`` model
    and its ``BertTokenizer``. Each input ``Sentence`` is tokenized, wrapped in
    ``[CLS]`` / ``[SEP]`` markers and classified into one of
    ``len(class_dict)`` labels.

    Args:
        model_name: Pretrained model name or path handed to ``get_bert_model``.
        class_dict: Label dictionary. Its size sets the number of output labels,
            and it maps gold label names to indices in :meth:`loss`.
        target_label: Label type read from each sentence via
            ``Sentence.get_label_name``.
        use_f1: Forwarded to ``TextClassifier``.
        fp16: Half-precision mode. Not supported yet; ``True`` raises
            ``NotImplementedError``.
        do_lower_case: Forwarded to the tokenizer/model loader.
        sentence_pair: If True, ``sentence.text`` must contain exactly one
            ``spliter``; the two halves are encoded as BERT segments A and B.
        spliter: Separator between the two sentences of a pair.

    Raises:
        NotImplementedError: If ``fp16`` is True.
    """

    def __init__(self, model_name: str, class_dict: Dictionary,
                 target_label: str = 'category',
                 use_f1=True,
                 fp16=False,
                 do_lower_case=True,
                 sentence_pair=False, spliter='|||'):
        super().__init__(class_dict, target_label, use_f1=use_f1)
        self.model_name = model_name
        self.sentence_pair = sentence_pair
        self.spliter = spliter
        self.number_labels = len(class_dict)
        self.do_lower_case = do_lower_case
        self.fp16 = fp16
        # Fail fast: fp16 is not implemented, so do not load the pretrained
        # weights (and cast them with .half()) only to raise afterwards.
        if self.fp16:
            raise NotImplementedError('support fp16. refer from FusedAdam and  [todo]')
        self.gpu = torch.cuda.is_available()
        # Input tensors are created through this namespace, so they land on the
        # GPU whenever CUDA is available.
        # NOTE(review): moving the model itself to the GPU is presumably done by
        # get_bert_model or the caller -- not visible here, confirm.
        self.torch = torch.cuda if self.gpu else torch
        self.tokenizer, self.model = get_bert_model(BertForSequenceClassification, self.model_name, do_lower_case,
                                                    num_labels=self.number_labels)

    def get_token_type_mask(self, batch_sentence: List[Sentence]) -> List[Tensor]:
        """
        Encode a batch of sentences into padded BERT input tensors.

        Single sentences become ``[CLS] tokens [SEP]`` with every token type id
        set to 0. Sentence pairs (``self.sentence_pair``) become
        ``[CLS] tokens_a [SEP] tokens_b [SEP]``. Type id 0 covers the first
        segment, including ``[CLS]`` and its ``[SEP]``; type id 1 covers the
        second.

        Returns:
            ``[token_ids, token_type_ids, attention_mask]`` as ``LongTensor``
            objects built via ``self.torch``, padded by ``pad_batch``. They are
            presumably shaped (batch, max_len); confirm against ``pad_batch``.

        Raises:
            ValueError: In pair mode, if a sentence does not split into exactly
                two parts on ``self.spliter``.
        """
        bert_tokens_list = []
        bert_token_types_list = []
        for sentence in batch_sentence:
            if self.sentence_pair:
                sentence_pair = sentence.text.split(self.spliter)
                if len(sentence_pair) != 2:
                    raise ValueError('sentence "%s" must contain exactly one spliter "%s"'
                                     % (sentence.text, self.spliter))
                bert_token_words = ["[CLS]"]
                bert_token_types = []
                for sentence_id, segment in enumerate(sentence_pair):
                    bert_token_words += self.tokenizer.tokenize(segment) + ["[SEP]"]
                    # Give every token appended since the previous segment this segment's id.
                    bert_token_types += [sentence_id] * (len(bert_token_words) - len(bert_token_types))
            else:
                bert_token_words = ["[CLS]"] + self.tokenizer.tokenize(sentence.text) + ["[SEP]"]
                bert_token_types = [0] * len(bert_token_words)
            bert_tokens_list.append(self.tokenizer.convert_tokens_to_ids(bert_token_words))
            bert_token_types_list.append(bert_token_types)
        pad_bert_tokens_list, att_mask = pad_batch(bert_tokens_list, return_mask=True)
        pad_bert_token_types_list = pad_batch(bert_token_types_list)
        return [self.torch.LongTensor(matrix) for matrix in [pad_bert_tokens_list, pad_bert_token_types_list, att_mask]]

    def forward(self, batch_sentence: List[Sentence]) -> Tensor:
        """
        Compute classification logits for a batch of sentences.

        Returns:
            The logits that ``BertForSequenceClassification`` returns when it is
            called without labels. Per the library docs these have shape
            (batch, number_labels).
        """
        tensor_tokens, tensor_type, tensor_mask = self.get_token_type_mask(batch_sentence)
        # Invoke the module (not .forward directly) so nn.Module hooks run.
        logit = self.model(tensor_tokens, tensor_type, tensor_mask)
        return logit

    def loss(self, batch_sentence: List[Sentence], logit: Tensor) -> Tensor:
        """
        Return the cross-entropy loss of ``logit`` against the batch's gold labels.

        Mirrors the loss computed inside ``BertForSequenceClassification.forward``
        of ``pytorch_pretrained_bert``. Gold labels are read with
        ``sentence.get_label_name(self.target_label)`` and mapped to indices
        through ``self.class_dict``.
        """
        labels = self.torch.LongTensor([
            self.class_dict.get_idx_for_item(stn.get_label_name(self.target_label))
            for stn in batch_sentence
        ])
        loss_fct = CrossEntropyLoss()
        return loss_fct(logit, labels)

    def save(self, path, **kwargs):
        """Persist the model and tokenizer to ``path`` via ``save_bert_model``."""
        save_bert_model(path, self.model, self.tokenizer)

    def load(self, path, **kwargs):
        """Replace the tokenizer and model with the ones stored at ``path`` via ``load_bert_model``."""
        self.tokenizer, self.model = load_bert_model(path, BertForSequenceClassification, self.do_lower_case,
                                                     num_labels=self.number_labels)


