#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>

"""
-----------
数据格式
-----------

训练、开发、测试集合::

    正在	d	O
    执行	v	O
    第十四	m	O
    次	q	O
    南极	ns	S-Ns
    考察	v	O
    任务	n	O
    的	u	O
    中国	ns	S-Ns
    考察队员	n	O
    ，	wp	O
    目前	nt	O
    分别	d	O
    在	p	O
    长城站	ns	O
    、	wp	O
    中山站	ns	O
    和	c	O
    “	wp	O
    雪龙	nz	O
    ”	wp	O
    号	n	O
    船上	nl	O
    。	wp	O

    宋健	nh	S-Nh
    向	p	O
    考察队员	n	O
    们	k	O
    说	v	O
    ，	wp	O
    你们	r	O
    的	u	O
    工作	v	O
    环境	n	O
    很	d	O
    艰苦	a	O
    ，	wp	O
    任务	n	O
    很	d	O
    艰巨	a	O
    。	wp	O


使用类conll格式，第一列为词，第二列为词性，第三列为实体标签（采用BIESO，也兼容BIO格式）格式。

这里附加特征列（词性）为一列，``other_feature_num = 1`` ，如果特征数量发送改变，则改变这个数值。

------------
常见问题
------------

- 注意训练集需要包含所有实体列标签。测试集中出现训练集未中出现的标签会报错。
- bert模型加载错误，或下载缓慢可以设置 bert_model 到预缓存位置。

"""
from typing import Dict, Any, List

import fire as google_fire
from flask import Response, Request

from dlab.data import Sentence
from .sequence_tagging import _train, _model, _read_conll_corpus, _write_conll_corpus
from .common import ExeAbstractTask


class Ner(ExeAbstractTask):
    TrainFunction = _train

    def train(self, save_dir: str, train_file: str, test_file: str, entity_names: str,
              other_feature_num: int = 1, pure_bert: int = 0,
              dev_file: str = '', use_bert: int = 0, word_base: int = 1, bert_model: str = 'bert-base-chinese',
              ):
        """训练一个实体识别器

        :param save_dir: 模型保存位置（必要）
        :param train_file: 训练数据文件位置（必要）
        :param test_file: 测试数据文件位置（必要）
        :param entity_names: ','分割的实体名列表字符串（必要）
        :param other_feature_num: 除了实体列与词（字）一列，其他特征列数。（默认1）
        :param pure_bert: 是否只使用bert模型的特征进行训练。（默认0）
        :param dev_file: 开发集位置，默认为''，从训练集自动分割。（默认''）
        :param use_bert: 使用bert模型进行训练。（默认0）
        :param word_base: 使用基于词的数据进行训练，数据文件中使用的是分词的数据，否则第一列都为一个字符。（默认1）
        :param bert_model: bert模型初始化位置。（默认bert-base-chinese）
        """
        entity_names = ','.join(entity_names) if isinstance(entity_names, tuple) else entity_names
        local_vars = {k: v for k, v in locals().items() if k in self.train.__code__.co_varnames and k != 'self'}
        return super().train(**local_vars)

    def _manage_options(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        self.kwargs = kwargs
        config = {}
        copy_config = {'train_file', 'dev_file', 'test_file',
                       'word_base',
                       'other_feature_num'}
        for k in copy_config & set(kwargs.keys()):
            config[k] = kwargs[k]

        if 'dev_file' not in kwargs:
            config['dev_file'] = ''

        config['target_tag_names'] = kwargs['entity_names']

        config['switch_model'] = 'bert' if kwargs['pure_bert'] else 'multi'

        config['feature_dims'] = ','.join(['200'] + ['50'] * kwargs['other_feature_num'])

        if kwargs['use_bert']:
            config['max_epoch'] = '30'
            config['early_stopping'] = '10'
            config['optimizer'] = 'BertAdam'
            config['batch_size'] = '6'
        else:
            config['max_epoch'] = '300'
            config['early_stopping'] = '10'
            config['optimizer'] = 'Adam'
            config['batch_size'] = '32'

        if kwargs['word_base']:
            config['token_max_num'] = '80000'
            config['token_min_count'] = '2'
        else:
            config['token_max_num'] = '10000'
            config['token_min_count'] = '1'

        return config

    def _init__model(self):
        return _model(load_from_save=True)

    def _read_corpus(self, filename: str) -> List[Sentence]:
        return _read_conll_corpus(filename)

    def _write_corpus(self, dist_file: str, data: List[Sentence]):
        return _write_conll_corpus(dist_file, data)

    def _request_to_sentence(self, _request: Request) -> Sentence:
        return super()._request_to_sentence(_request)

    def _sentence_to_response(self, sentence: Sentence) -> Response:
        return super()._sentence_to_response(sentence)


if __name__ == '__main__':
    google_fire.Fire(Ner)

