#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>

import os
from os import path
from typing import List

import torch

from dlab.config import add_config, inherit_config, fire
from dlab.data import Dictionary, LineBasedTCCorpusReader, DataSet, Sentence, LabelCorpusReader, fit_dict, Token
from dlab.embedder import BertEmbedder, BertCharEmbedder, NormalEmbedder, StackEmbedder
from dlab.log import log
from dlab.model import BertCharBaseClassifier, FeatureLSTMClassifier
from dlab.process.trainer import Trainer

from .common import OTHER_FEATURE_TAG_NAME_TEMPLATE, TARGET_TAG_NAME, TAG_DICT_FILENAME_TEMPLATE


@add_config('switch_model', str, default='bert', desc='choose from ["bert", "multi"].')
@add_config('use_bert_embedder', int, default=0, desc='[when switch_model is "multi"] use bert as a embedder.')
@add_config('other_feature_num', int, default=0,
            desc='[when switch_model="multi"], using other features. beside word token and target tag.')
@add_config('feature_dims', str, default='200',
            desc='[when switch_model is "multi"] feature dims split by ",". The number should be other_feature_num + 1')
@add_config('bert_model', str, default='bert-base-chinese',
            desc='[when switch_model="bert" or use_bert_embedder], bert_model or bert model path.')
@add_config('finetune_bert', int, default=0,
            desc='[when switch_model="multi" and use_bert_embedder], bert embedder will be fine-tuned.')
@add_config('word_base', int, default=0, desc='use word base data. not char base.')
@add_config('save_dir', str, default='[[CONFIG_DIR]]', desc='output base dir')
@add_config('target_labels', str, default='', desc='target class names, the negative one should be set at first.')
@add_config('sentence_pair', int, default=0, desc='if the sample is 2 sentences join with a spliter.')
@add_config('sentence_spliter', str, default='|||',
            desc='[when sentence_pair=1], sentence spliter when sentence_pair is set to 1.')
@add_config('hidden_size', int, default=200, desc='[when switch_model is "multi"] rnn hidden dim.')
@add_config('rnn_layers', int, default=1, desc='[when switch_model is "multi"] number of rnn layers.')
@add_config('use_f1', int, default=0, desc='if judge as f1.')
def _model(config, load_from_save=False):
    """Build the text classifier selected by ``config['switch_model']``.

    Two model kinds are supported:

    * ``"bert"``  -- a ``BertCharBaseClassifier`` initialised from
      ``config['bert_model']``, or from ``save_dir`` when ``load_from_save``
      is true.
    * ``"multi"`` -- a ``FeatureLSTMClassifier`` on top of a ``StackEmbedder``
      made of an optional BERT embedder followed by one ``NormalEmbedder`` per
      feature (the token itself plus ``other_feature_num`` extra features).
      The per-feature dictionaries are loaded from ``save_dir``.

    :param config: configuration mapping holding the options declared by the
        ``add_config`` decorators above.
    :param load_from_save: when true, restore a previously saved model from
        ``config['save_dir']`` instead of starting from fresh weights.
    :return: the classifier, moved to CUDA when available and with ``eval()``
        already called on it.
    :raises ValueError: if ``target_labels`` is empty, if ``switch_model`` is
        unknown, or if the number of entries in ``feature_dims`` does not
        equal ``other_feature_num + 1``.
    """
    save_dir = config['save_dir']

    # dictionary management: labels are registered in the order they are given
    if config['target_labels'] == '':
        raise ValueError('target labels not set')
    label_dict = Dictionary(add_pad=False, add_unk=False)
    for label in config['target_labels'].split(','):
        label_dict.add_item(label)

    # model switch
    if config['switch_model'] == 'bert':
        # a saved bert classifier is restored by pointing the init path at save_dir
        bert_init_path = save_dir if load_from_save else config['bert_model']
        tc_model = BertCharBaseClassifier(bert_init_path, label_dict,
                                          use_f1=config['use_f1'], target_label=TARGET_TAG_NAME,
                                          sentence_pair=config['sentence_pair'], spliter=config['sentence_spliter'])

    elif config['switch_model'] == 'multi':
        feature_names = [Token.TOKEN_TAG] + [OTHER_FEATURE_TAG_NAME_TEMPLATE % n for n in
                                             range(config['other_feature_num'])]
        feature_dims = [int(feature_dim) for feature_dim in config['feature_dims'].split(',')]

        # validate the configuration before any embedder is constructed, so a
        # bad feature_dims setting fails before the BERT embedder is built
        if len(feature_names) != len(feature_dims):
            raise ValueError(f'features={feature_names} and dims={feature_dims} are not match')

        embedder_list = []
        if config['use_bert_embedder']:
            BertEmbedderClass = BertEmbedder if config['word_base'] else BertCharEmbedder
            embedder_list.append(BertEmbedderClass(config['bert_model'], finetune=config['finetune_bert']))

        for feature_name, feature_dim in zip(feature_names, feature_dims):
            # one dictionary file per feature is expected under save_dir
            feature_dict = Dictionary.load_from_file(path.join(save_dir, TAG_DICT_FILENAME_TEMPLATE % feature_name))
            embedder_list.append(NormalEmbedder(feature_dim, dictionary=feature_dict, tag_name=feature_name))

        embedder = StackEmbedder(embedder_list)
        tc_model = FeatureLSTMClassifier(label_dict, embedder=embedder, rnn_hidden_size=config['hidden_size'],
                                         target_label=TARGET_TAG_NAME,
                                         rnn_layers=config['rnn_layers'])
    else:
        raise ValueError(f'unknown model `{config["switch_model"]}`')

    # the "multi" model restores its state explicitly; the "bert" model was
    # already initialised from save_dir above
    if load_from_save and config['switch_model'] == 'multi':
        tc_model.load(save_dir)
    if torch.cuda.is_available():
        tc_model = tc_model.cuda()
    # NOTE(review): the model is returned after eval() even when it is about to
    # be trained; presumably the Trainer switches modes itself -- confirm.
    tc_model.eval()
    return tc_model


@inherit_config(_model)
@add_config('data_format', str, 'conll', desc='choose from [conll, line_tab]')
@add_config('sentence_split', str, default='\t', desc='line sentence field spliter. "" is all blanks.')
def _read_data(config, file_path):
    """Read a corpus file in the format named by ``config['data_format']``.

    :param config: configuration mapping; uses ``data_format``,
        ``sentence_split``, ``other_feature_num`` and ``word_base``.
    :param file_path: path of the corpus file to read.
    :return: whatever the selected reader's ``read()`` returns.
    :raises ValueError: if ``data_format`` is neither ``conll`` nor
        ``line_tab``.
    """
    # an empty separator is handed to the reader as None
    field_sep = config['sentence_split']
    if field_sep == '':
        field_sep = None

    if config['data_format'] == 'conll':
        # the token column first, then one tag per extra feature
        tag_names = [Token.TOKEN_TAG]
        for feature_index in range(config['other_feature_num']):
            tag_names.append(OTHER_FEATURE_TAG_NAME_TEMPLATE % feature_index)
        conll_reader = LabelCorpusReader(file_path=file_path, label_name=TARGET_TAG_NAME, tag_names=tag_names)
        return conll_reader.read()

    if config['data_format'] == 'line_tab':
        use_chars = not config['word_base']
        line_reader = LineBasedTCCorpusReader(file_path=file_path, char_base=use_chars,
                                              label_name=TARGET_TAG_NAME, split=field_sep)
        return line_reader.read()

    raise ValueError(f'unknown data_format `{config["data_format"]}`')


@inherit_config(_read_data)
@inherit_config(Trainer)
@add_config('token_max_num', int, default=20000, desc='token vocab max size.')
@add_config('token_min_count', int, default=2, desc='token vocab min appearance.')
@add_config('batch_size', int, default=6, desc='mini-batch size.')
@add_config('train_file', str, default='data/pku.ner.train.conll', desc='train set file.')
@add_config('dev_file', str, default='data/pku.ner.dev.conll', desc='dev set file.')
@add_config('test_file', str, default='data/pku.ner.test.conll', desc='test set file.')
def _train(config):
    """Train a text classifier and save it under ``config['save_dir']``.

    Steps: create the output directory, read the train/dev/test corpora, fit
    and save the feature dictionaries when ``switch_model`` is ``"multi"``
    (``_model`` loads them back from the same paths), build the model, and run
    the ``Trainer``.

    NOTE(review): ``_read_data`` and ``_model`` are called here without their
    ``config`` argument; presumably the ``dlab.config`` decorators inject it --
    confirm against ``dlab.config``.

    :param config: configuration mapping holding the options declared by the
        decorators above and those inherited from ``_read_data`` and
        ``Trainer``.
    """
    save_dir = config['save_dir']
    log.info('model saving to : %s' % save_dir)
    os.makedirs(save_dir, exist_ok=True)

    train = _read_data(config['train_file'])
    # an empty dev_file means no explicit dev set is read
    dev = None if config['dev_file'] == '' else _read_data(config['dev_file'])
    test = _read_data(config['test_file'])

    if config['switch_model'] == 'multi':
        # the token dictionary is bounded by size and frequency; the extra
        # feature dictionaries are fitted with fit_dict's defaults
        word_dict = fit_dict(train, Dictionary(), max_num=config['token_max_num'], min_count=config['token_min_count'])
        name_dict_tuples = [(Token.TOKEN_TAG, word_dict)]

        for n in range(config['other_feature_num']):
            feature_name = OTHER_FEATURE_TAG_NAME_TEMPLATE % n
            name_dict_tuples.append((feature_name, fit_dict(train, Dictionary(), feature_name)))

        # must be written before _model() runs: it loads these files from save_dir
        for name, dictionary in name_dict_tuples:
            dictionary.save(path.join(save_dir, TAG_DICT_FILENAME_TEMPLATE % name))

    tc_model = _model()
    # honour the configured batch size (default 6) instead of a hard-coded value
    data_set = DataSet(train=train, dev=dev, test=test, batch_size=config['batch_size'], dev_split=0.1,
                       shuffle_train_before_split_dev=True)
    trainer = Trainer(data_set=data_set, task_model=tc_model)
    trainer.train(save_dir)
