#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
import codecs
import os
import random

import torch
from typing import Tuple, List

from pytorch_pretrained_bert import BertTokenizer
from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP, WEIGHTS_NAME, BertPreTrainedModel, \
    CONFIG_NAME

from dlab.data import Sentence
from dlab.log import log
from .create_pretraining_data import create_training_instances


def get_bert_model(Cls, model_name, do_lower_case, **kwargs) -> Tuple[BertTokenizer, BertPreTrainedModel]:
    """
    Load a bert tokenizer together with a bert model of class ``Cls``.

    ``model_name`` is resolved in this order:

    1. an existing directory -> restored with ``load_bert_model`` (the directory must contain
       ``bert_config.json``, ``pytorch_model.bin`` and ``vocab.txt``, i.e. what ``save_bert_model`` writes);
    2. a key of ``PRETRAINED_MODEL_ARCHIVE_MAP`` (e.g. ``bert-base-uncased``) -> loaded through
       ``BertTokenizer.from_pretrained`` / ``Cls.from_pretrained``.

    :param Cls: choice from BertModel, BertForPreTraining, BertForMaskedLM ...
    :param model_name: pretrained model name or a directory as described above
    :param do_lower_case: do_lower_case passed to the tokenizer
    :param kwargs: other params passed to ``Cls.from_pretrained``
    :return: tokenizer, bert_model
    :raises ValueError: if ``model_name`` is neither an existing directory nor a known pretrained name
    """
    log.info('restore bert model from %s' % model_name)
    if os.path.isdir(model_name):
        return load_bert_model(model_name, Cls, do_lower_case, **kwargs)
    if model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
        tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=do_lower_case)
        model = Cls.from_pretrained(model_name, **kwargs)
        return tokenizer, model
    raise ValueError('model_name `%s` must be either a dir or a pretrained bert model name.' % model_name)


def save_bert_model(path, bert_model: BertPreTrainedModel, bert_tokenizer: BertTokenizer) -> None:
    """
    Save a bert model (config + weights) and its tokenizer vocabulary into the directory ``path``.

    Files written: ``CONFIG_NAME`` (model config as JSON), ``WEIGHTS_NAME`` (state dict saved with
    ``torch.save``) and ``vocab.txt`` (one token per line) -- the layout ``load_bert_model`` reads back.

    :param path: output directory; created (with parents) if it does not exist yet
    :param bert_model: model to save; a wrapper exposing ``.module`` (e.g. ``torch.nn.DataParallel``)
        is unwrapped first so that the bare model's config and weights are saved
    :param bert_tokenizer: tokenizer whose ``vocab`` (token -> index mapping) is written out
    """
    log.info('save bert model (including weight/tokenizer/config) to %s' % path)
    if path:
        os.makedirs(path, exist_ok=True)
    # Unwrap before touching `.config` / `.state_dict()`: a DataParallel wrapper has no `config`
    # attribute and would prefix every weight name with `module.`.
    # (adapted from pytorch-pretrained-BERT/examples/run_classifier.py)
    model_to_save = bert_model.module if hasattr(bert_model, 'module') else bert_model
    # save config
    with codecs.open(os.path.join(path, CONFIG_NAME), 'w', encoding='utf8') as fw:
        fw.write(model_to_save.config.to_json_string())
    # save weight
    torch.save(model_to_save.state_dict(), os.path.join(path, WEIGHTS_NAME))
    # save vocab: BertTokenizer assigns ids by line number when reading vocab.txt back,
    # so write tokens ordered by their index instead of relying on dict iteration order.
    ordered_vocab = sorted(bert_tokenizer.vocab.items(), key=lambda item: item[1])
    with codecs.open(os.path.join(path, 'vocab.txt'), 'w', encoding='utf8') as fw:
        for token, _ in ordered_vocab:
            fw.write(token + '\n')


def load_bert_model(path, Cls, do_lower_case: bool, **kwargs) -> Tuple[BertTokenizer, BertPreTrainedModel]:
    """
    Restore a (tokenizer, model) pair from a directory holding a config, ``WEIGHTS_NAME`` and ``vocab.txt``.

    :param path: model directory (e.g. one produced by ``save_bert_model``)
    :param Cls: model class to instantiate via ``Cls.from_pretrained`` (BertModel, BertForMaskedLM, ...)
    :param do_lower_case: forwarded to the tokenizer
    :param kwargs: extra keyword arguments forwarded to ``Cls.from_pretrained``
    :return: tokenizer, bert_model
    """
    # Without a GPU, remap any CUDA-saved tensors onto the CPU; otherwise keep torch.load's default.
    device = 'cpu' if not torch.cuda.is_available() else None
    weights = torch.load(os.path.join(path, WEIGHTS_NAME), map_location=device)
    # The config itself is picked up from `path` by `from_pretrained`; only the weights are passed in.
    bert_model = Cls.from_pretrained(path, state_dict=weights, **kwargs)
    vocab_file = os.path.join(path, 'vocab.txt')
    tokenizer = BertTokenizer.from_pretrained(vocab_file, do_lower_case=do_lower_case)
    return tokenizer, bert_model


def get_bert_token_and_map(tokenizer: BertTokenizer, batch_sentence: List[Sentence],
                           max_seq_length: int = 512) -> Tuple[List[List[int]], List[List[int]]]:
    """
    Convert word-tokenized sentences into BERT wordpiece ids plus a word -> wordpiece position map.

    Each sentence is encoded as ``[CLS] wordpieces(word_1) ... wordpieces(word_n) [SEP]``. For every
    word, the map stores the index (inside that sentence's id list) of the word's first wordpiece.
    A word that the tokenizer turns into no wordpiece at all is encoded as ``[UNK]``. It therefore
    keeps a position of its own instead of silently sharing the next word's (or ``[SEP]``'s) position.

    :param tokenizer: bert tokenizer
    :param batch_sentence: sentences whose tokens are words (each token exposes ``.text``)
    :param max_seq_length: maximum number of wordpieces per sentence, including [CLS] and [SEP];
        should not exceed the model's position-embedding size (512 for the released BERT models)
    :return: (ids[B * wordpiece_length], map[B * sentence_length])
    :raises ValueError: if a sentence needs more than ``max_seq_length`` wordpieces
    """
    tokens, orig_to_tok_map = [], []
    for sentence in batch_sentence:
        bert_tokens = ['[CLS]']
        bert_tokens_map = []
        for token in sentence:
            bert_tokens_map.append(len(bert_tokens))
            bert_tokens.extend(tokenizer.tokenize(token.text) or ['[UNK]'])
        bert_tokens.append('[SEP]')
        if len(bert_tokens) > max_seq_length:
            # max_seq_length - 2: room left once [CLS] and [SEP] are accounted for
            raise ValueError('Bert Position can not encode sentence longer than %d, which is `%s`'
                             % (max_seq_length - 2, sentence.text))
        tokens.append(tokenizer.convert_tokens_to_ids(bert_tokens))
        orig_to_tok_map.append(bert_tokens_map)
    return tokens, orig_to_tok_map


def generate_bert_pretrained_instances(all_document: List[List[str]],
                                       tokenizer_model_name: str,
                                       do_lower_case: bool = True,
                                       random_seed: int = 12345,
                                       max_seq_length: int = 128,
                                       dupe_factor: int = 10,
                                       short_seq_prob: float = 0.1,
                                       masked_lm_prob: float = 0.15,
                                       max_predictions_per_seq: int = 20,
                                       ) -> List[Sentence]:
    """
    Use the https://github.com/google-research/bert data pipeline to generate samples for pretraining
    and finetuning.

    ``all_document`` is not modified; a tokenized copy is built internally. When tokenizing, sentences
    that produce no wordpieces are skipped, and documents left empty are dropped. This mirrors the
    document reader in google-research/bert's ``create_pretraining_data.py``.

    :param all_document: List[List[str]]; each inner list is a document, each str is one sentence.
    :param tokenizer_model_name: directory containing ``vocab.txt``, or a pretrained bert model name
    :param do_lower_case: whether the tokenizer lower-cases its input
    :param random_seed: Random seed for data generation.
    :param max_seq_length: Maximum sequence length.
    :param dupe_factor: Number of times to duplicate the input data (with different masks).
    :param short_seq_prob: Probability of creating sequences which are shorter than the maximum length.
    :param masked_lm_prob: Masked LM probability.
    :param max_predictions_per_seq: Maximum number of masked LM predictions per sequence.
    :return: List[Sentence]. Each sentence carries an ``is_random_next`` label. Each of its tokens
        carries a ``segment_id`` tag, and masked tokens additionally carry a ``mask_lm_label`` tag.
    :raises ValueError: if ``tokenizer_model_name`` is neither a directory nor a known pretrained name
    """
    if os.path.isdir(tokenizer_model_name):
        tokenizer = BertTokenizer.from_pretrained(os.path.join(tokenizer_model_name, 'vocab.txt'),
                                                  do_lower_case=do_lower_case)
    elif tokenizer_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
        tokenizer = BertTokenizer.from_pretrained(tokenizer_model_name, do_lower_case=do_lower_case)
    else:
        raise ValueError(
            'model_name `%s` must be either a dir or a pretrained bert model name.' % tokenizer_model_name)
    rng = random.Random(random_seed)

    # Tokenize into a fresh structure rather than overwriting the caller's documents in place.
    # Empty sentences/documents are dropped, as upstream's reader does before building instances.
    tokenized_documents = []
    for document in all_document:
        tokenized_doc = [stn_tokens for stn_tokens in (tokenizer.tokenize(stn) for stn in document)
                         if stn_tokens]
        if tokenized_doc:
            tokenized_documents.append(tokenized_doc)

    bert_origin_train_instances = create_training_instances(tokenized_documents, tokenizer,
                                                            max_seq_length=max_seq_length,
                                                            dupe_factor=dupe_factor,
                                                            short_seq_prob=short_seq_prob,
                                                            masked_lm_prob=masked_lm_prob,
                                                            max_predictions_per_seq=max_predictions_per_seq,
                                                            rng=rng)
    ret = []
    for train_instance in bert_origin_train_instances:
        stn = Sentence(train_instance.tokens)
        # sentence-level label: whether the second segment was drawn at random
        stn.set_label('is_random_next', train_instance.is_random_next)
        # token-level answers at the masked positions
        for mask_tok_id, mask_tok_ans in zip(train_instance.masked_lm_positions, train_instance.masked_lm_labels):
            stn.get_token(mask_tok_id).add_tag('mask_lm_label', mask_tok_ans)
        # every token gets the segment id produced for its position
        for tok, seg_id in zip(stn, train_instance.segment_ids):
            tok.add_tag('segment_id', seg_id)
        ret.append(stn)

    return ret

