#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>


import os
from os import path

from typing import List
import torch

from dlab.config import add_config, inherit_config
from dlab.data.reader import TagCorpusReader, DataSet
from dlab.data.structures import Dictionary, Sentence, Token
from dlab.data.utils import fit_dict, write_tag_sentence
from dlab.embedder import StackEmbedder, NormalEmbedder, BertEmbedder, BertCharEmbedder
from dlab.eval.entity_decoder import BIOEntityDecoder
from dlab.log import log
from dlab.model import FeaturesLSTMCRFSequenceTagger, BertSequenceTagger, BertWordSequenceTagger
from dlab.process.trainer import Trainer

from .common import OTHER_FEATURE_TAG_NAME_TEMPLATE, TARGET_TAG_NAME, TAG_DICT_FILENAME_TEMPLATE


@add_config('other_feature_num', int, default=0, desc='using other features. beside word token and target tag.')
@add_config('switch_model', str, default='bert', desc='choose from ["bert", "multi"].')
@add_config(  'use_bert_embedder', int, default=0, desc='[when switch_model is "multi"] use bert as a embedder.')
@add_config(  'feature_dims', str, default='200', desc='[when switch_model is "multi"] feature dims split by ",". The number should be other_feature_num + 1')
@add_config(    'finetune_bert', int, default=0, desc='[when switch_model = "multi" and use_bert_embedder = 1] finetune bert embedder.')
@add_config('word_base', int, default=0, desc='use word base data. not char base.')
@add_config('bert_model', str, default='bert-base-chinese', desc='bert_model or bert model path.')
@add_config('target_tag_names', str, default='Nh,Ns,Ni', desc='All entity names related in tags.')
@add_config('hidden_size', int, default=200, desc='hidden size')
@add_config('save_dir', str, default='[[CONFIG_DIR]]', desc='output base dir')
def _model(config, load_from_save=False):
    """Build the sequence tagger selected by ``config['switch_model']``.

    The target-tag dictionary is always loaded from ``config['save_dir']``.

    * ``"bert"``: a BERT tagger (word-level variant when ``word_base`` is
      truthy) initialised from ``save_dir`` when ``load_from_save`` is set,
      otherwise from ``config['bert_model']``.
    * ``"multi"``: an LSTM-CRF tagger over a stack of embedders -- an optional
      BERT embedder followed by one ``NormalEmbedder`` per feature, each using
      a dictionary loaded from ``save_dir``. When ``load_from_save`` is set
      the tagger's ``load(save_dir)`` is called afterwards.

    :param config: option mapping declared by the ``add_config`` decorators.
    :param load_from_save: restore a previously saved model from ``save_dir``.
    :return: the model, moved to CUDA when available and switched to eval mode.
    :raises ValueError: on an unknown ``switch_model`` or when the number of
        ``feature_dims`` does not match the number of features.
    """
    output_dir = config['save_dir']

    decoders = [BIOEntityDecoder(entity) for entity in config['target_tag_names'].split(',')]

    tag_dict_path = path.join(output_dir, TAG_DICT_FILENAME_TEMPLATE % TARGET_TAG_NAME)
    tag_dict = Dictionary.load_from_file(tag_dict_path)

    model_kind = config['switch_model']
    if model_kind == 'multi':
        # Token feature first, then the numbered extra features.
        names = [Token.TOKEN_TAG]
        names.extend(OTHER_FEATURE_TAG_NAME_TEMPLATE % idx for idx in range(config['other_feature_num']))
        dims = list(map(int, config['feature_dims'].split(',')))

        embedders = []
        if config['use_bert_embedder']:
            if config['word_base']:
                bert_embedder_cls = BertEmbedder
            else:
                bert_embedder_cls = BertCharEmbedder
            embedders.append(bert_embedder_cls(config['bert_model'], finetune=config['finetune_bert']))

        if len(names) != len(dims):
            raise ValueError(f'features={names} and dims={dims} are not match')

        for name, dim in zip(names, dims):
            dict_path = path.join(output_dir, TAG_DICT_FILENAME_TEMPLATE % name)
            vocab = Dictionary.load_from_file(dict_path)
            embedders.append(NormalEmbedder(dim, dictionary=vocab, tag_name=name))

        tagger = FeaturesLSTMCRFSequenceTagger(
            hidden_size=config['hidden_size'],
            embedder=StackEmbedder(embedders),
            tag_dictionary=tag_dict,
            target_tag=TARGET_TAG_NAME,
            bio_decoders=decoders,
        )
    elif model_kind == 'bert':
        if load_from_save:
            init_path = output_dir
        else:
            init_path = config['bert_model']
        if config['word_base']:
            tagger_cls = BertWordSequenceTagger
        else:
            tagger_cls = BertSequenceTagger
        tagger = tagger_cls(tag_dict, TARGET_TAG_NAME, init_path, bio_decoders=decoders)
    else:
        raise ValueError(f'unknown model `{model_kind}`')

    # The bert branch already restored its weights through ``init_path``;
    # only the multi-feature tagger needs an explicit load here.
    if model_kind == 'multi' and load_from_save:
        tagger.load(output_dir)
    if torch.cuda.is_available():
        tagger = tagger.cuda()
    tagger.eval()
    return tagger


@inherit_config(_model)
def _get_model_config(config):
    """Return ``config`` unchanged.

    Decorated with ``inherit_config(_model)``; presumably this exposes the
    resolved model options to callers -- confirm against ``dlab.config``.
    """
    return config


@inherit_config(_model)
def _read_conll_corpus(config, filename: str):
    """Read a tagged corpus from ``filename`` with ``TagCorpusReader``.

    The column names handed to the reader are the ``other_feature_num``
    numbered feature tags followed by the target tag name.

    :param config: option mapping; only ``other_feature_num`` is read here.
    :param filename: corpus file to read.
    :return: whatever ``TagCorpusReader.read()`` returns.
    """
    column_names = [OTHER_FEATURE_TAG_NAME_TEMPLATE % idx for idx in range(config['other_feature_num'])]
    column_names.append(TARGET_TAG_NAME)
    reader = TagCorpusReader(filename, column_names)
    return reader.read()


@inherit_config(_model)
def _write_conll_corpus(config, filename: str, data: List[Sentence]):
    """Write ``data`` to ``filename`` with ``write_tag_sentence``.

    The column names handed to the writer are the ``other_feature_num``
    numbered feature tags followed by the target tag name.

    :param config: option mapping; only ``other_feature_num`` is read here.
    :param filename: destination file.
    :param data: sentences to write.
    :return: whatever ``write_tag_sentence`` returns.
    """
    column_names = [OTHER_FEATURE_TAG_NAME_TEMPLATE % idx for idx in range(config['other_feature_num'])]
    column_names.append(TARGET_TAG_NAME)
    return write_tag_sentence(filename, data, column_names)


@inherit_config(_read_conll_corpus)
@inherit_config(Trainer)
@add_config('token_max_num', int, default=20000, desc='token vocab max size.')
@add_config('token_min_count', int, default=2, desc='token vocab min appearance.')
@add_config('batch_size', int, default=6, desc='batch size.')
@add_config('train_file', str, default='data/pku.ner.train.conll', desc='train set file.')
@add_config('dev_file', str, default='data/pku.ner.dev.conll', desc='dev set file.')
@add_config('test_file', str, default='data/pku.ner.test.conll', desc='test set file.')
def _train(config):
    """Train the NER tagger and save it under ``config['save_dir']``.

    Steps: create ``save_dir``, read the train/dev/test corpora, fit and save
    the dictionaries (only when ``switch_model`` is ``"multi"``), build the
    model and run ``Trainer.train``.

    NOTE(review): ``config`` is not passed explicitly to
    ``_read_conll_corpus`` / ``_model`` below; presumably the config
    decorators inject it -- confirm against ``dlab.config``.

    :param config: option mapping declared by the decorators above.
    """
    save_dir = config['save_dir']
    log.info('model saving to : %s' % save_dir)
    os.makedirs(save_dir, exist_ok=True)

    train = _read_conll_corpus(config['train_file'])
    # An empty ``dev_file`` means no explicit dev set is supplied.
    dev = _read_conll_corpus(config['dev_file']) if config['dev_file'] else None
    test = _read_conll_corpus(config['test_file'])

    # NOTE(review): ``_model`` loads the target-tag dictionary from save_dir
    # for every ``switch_model``, but it is only written here for "multi".
    # Confirm that the "bert" path gets that file from somewhere else.
    if config['switch_model'] == 'multi':
        # Token vocabulary is size/frequency limited; feature vocabularies are not.
        word_dict = fit_dict(train, Dictionary(), max_num=config['token_max_num'], min_count=config['token_min_count'])
        name_dict_tuples = [(Token.TOKEN_TAG, word_dict)]

        feature_names = [OTHER_FEATURE_TAG_NAME_TEMPLATE % n for n in range(config['other_feature_num'])]
        name_dict_tuples.extend(
            (feature_name, fit_dict(train, Dictionary(), feature_name))
            for feature_name in feature_names
        )

        # The target tag dictionary is built without unk/pad entries.
        target_dict = fit_dict(train, Dictionary(add_unk=False, add_pad=False), TARGET_TAG_NAME)
        name_dict_tuples.append((TARGET_TAG_NAME, target_dict))

        for name, dictionary in name_dict_tuples:
            dictionary.save(path.join(save_dir, TAG_DICT_FILENAME_TEMPLATE % name))

    ner_model = _model()
    # Honour the configured batch size (previously hard-coded to 6, which made
    # the ``batch_size`` option a no-op).
    pku_ner_set = DataSet(train=train, dev=dev, test=test, batch_size=config['batch_size'],
                          shuffle_train_before_split_dev=True)
    trainer = Trainer(data_set=pku_ner_set, task_model=ner_model)
    trainer.train(save_dir)

