#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
"""

训练，预测，测试数据文件格式如下::

    1	B	nt	B-nt	_10	-	B-Time
    9	I	nt	I-nt	_10	-	I-Time
    0	I	nt	I-nt	_10	-	I-Time
    5	I	nt	I-nt	_9	-	I-Time
    年	E	nt	I-nt	_9	-	I-Time
    6	B	nt	I-nt	_9	-	I-Time
    月	E	nt	I-nt	_9	-	I-Time
    2	B	nt	I-nt	_9	-	I-Time
    7	I	nt	I-nt	_9	-	I-Time
    日	E	nt	I-nt	_9	-	I-Time
    ，	S	wp	O	_9	-	O
    “	S	wp	O	_9	-	O
    波	S	n	B-nh	_9	-	O
    将	S	d	I-nh	_9	-	O
    金	B	nz	I-nh	_9	-	O
    号	E	nz	I-nh	_9	-	O
    ”	S	wp	O	_9	-	O
    在	S	p	O	_8	-	O
    乌	B	ns	O	_8	-	B-Place
    克	I	ns	O	_8	-	I-Place
    兰	E	ns	O	_8	-	I-Place
    海	B	n	O	_8	-	I-Place
    岸	E	n	O	_8	-	I-Place
    坦	B	ns	O	_8	-	I-Place
    特	I	ns	O	_8	-	I-Place
    拉	I	ns	O	_7	-	I-Place
    岛	E	ns	O	_7	-	I-Place
    （	S	wp	O	_7	-	O
    英	B	nz	O	_7	-	O
    语	E	nz	O	_7	-	O
    ：	S	wp	O	_7	-	O
    T	B	ws	O	_6	-	B-Place
    e	I	ws	O	_6	-	I-Place
    n	I	ws	O	_6	-	I-Place
    d	I	ws	O	_6	-	I-Place
    r	I	ws	O	_5	-	I-Place
    a	I	ws	O	_5	-	I-Place
    I	I	ws	O	_5	-	I-Place
    s	I	ws	O	_4	-	I-Place
    l	I	ws	O	_4	-	I-Place
    a	I	ws	O	_4	-	I-Place
    n	I	ws	O	_3	-	I-Place
    d	E	ws	O	_3	-	I-Place
    ）	S	wp	O	_2	-	O
    附	B	nd	O	_2	-	O
    近	E	nd	O	_2	-	O
    进	B	v	O	_1	-	O
    行	E	v	O	_1	-	O
    炮	B	v	O	_1	-	B-Type
    击	E	v	O	_1	-	I-Type
    演	B	v	O	_0	B-Y	B-E_1
    练	E	v	O	_0	I-Y	I-E_1
    时	S	n	O	_1	-	O
    ，	S	wp	O	_1	-	O
    许	B	m	B-nh	_1	-	O
    多	E	m	I-nh	_1	-	O
    士	B	n	I-nh	_2	-	O
    兵	E	n	I-nh	_2	-	O
    拒	B	v	O	_2	-	O
    绝	E	v	O	_3	-	O
    吃	S	v	O	_3	-	O
    以	S	p	O	_4	-	O
    腐	B	a	O	_4	-	O
    败	E	a	O	_4	-	O
    且	S	c	O	_5	-	O
    生	B	v	O	_5	-	O
    蛆	E	v	O	_5	-	O
    的	S	u	O	_6	-	O
    肉	S	n	O	_6	-	O
    作	B	v	O	_6	-	O
    成	E	v	O	_6	-	O
    的	S	u	O	_7	-	O
    罗	B	n	O	_7	-	O
    宋	I	n	O	_7	-	O
    汤	E	n	O	_7	-	O
    。	S	wp	O	_7	-	O


每列信息分别为:

- [0] 字token信息
- [1] BIES表示的分词信息
- [2] 词性信息
- [3] BIO表示的实体信息
- [4] 与核心词距离信息（如果构造该信息需要分有限桶）
- [5] 触发词信息
- [6] 论元目标标签


这里附加特征列为5列（除首末列外），``other_feature_num = 5``，如果特征数量发送改变，则改变这个数值。

"""

from typing import Dict, Any, List

import fire as google_fire
from flask import Response, Request

from dlab.data import Sentence
from .sequence_tagging import _train, _model, _read_conll_corpus, _write_conll_corpus
from .common import ExeAbstractTask


class Event(ExeAbstractTask):
    TrainFunction = _train

    def train(self, save_dir: str, train_file: str, test_file: str, event_names: str, arguments_names: str,
              other_feature_num: int = 5,
              dev_file: str = '', use_bert: int = 0, word_base: int = 1, bert_model: str = 'bert-base-chinese',
              ):
        """
        训练一个事件论元抽取器。

        :param save_dir: 模型保存位置（必要）
        :param train_file: 训练数据文件位置（必要）
        :param test_file: 测试数据文件位置（必要）
        :param event_names: ','分割的事件名列表字符串，数据中需使用 BIO格式与其配合（必要）
        :param arguments_names: ','分割的论元名列表字符串，数据中需使用 BIO格式与其配合（必要）
        :param other_feature_num: 除了论元列与词（字）一列，其他特征列数。（默认5）
        :param dev_file: 开发集位置，默认为''，从训练集自动分割。（默认''）
        :param use_bert: 使用bert模型进行训练。（默认0）
        :param bert_model: bert模型初始化位置。（默认bert-base-chines）
        :param word_base: 使用基于词的数据进行训练，数据文件中使用的是分词的数据，否则第一列都为一个字符。（默认1）
        :return:
        """
        event_names = ','.join(event_names) if isinstance(event_names, tuple) else event_names
        arguments_names = ','.join(arguments_names) if isinstance(arguments_names, tuple) else arguments_names
        local_vars = {k: v for k, v in locals().items() if k in self.train.__code__.co_varnames and k != 'self'}
        return super().train(**local_vars)

    def _manage_options(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        self.kwargs = kwargs
        config = {}
        copy_config = {'train_file', 'dev_file', 'test_file',
                       'word_base',
                       'other_feature_num'}

        config['target_tag_names'] = kwargs['event_names'] + ',' + kwargs['arguments_names']

        for k in copy_config & set(kwargs.keys()):
            config[k] = kwargs[k]

        if 'dev_file' not in kwargs:
            config['dev_file'] = ''

        config['switch_model'] = 'multi'

        config['feature_dims'] = ','.join(['200'] + ['50'] * kwargs['other_feature_num'])

        if kwargs['use_bert']:
            config['max_epoch'] = '20'
            config['early_stopping'] = '10'
            config['optimizer'] = 'BertAdam'
            config['batch_size'] = '6'
        else:
            config['max_epoch'] = '300'
            config['early_stopping'] = '10'
            config['optimizer'] = 'Adam'
            config['batch_size'] = '32'

        if kwargs['word_base']:
            config['token_max_num'] = '80000'
            config['token_min_count'] = '2'
        else:
            config['token_max_num'] = '10000'
            config['token_min_count'] = '1'

        return config

    def _init__model(self):
        return _model(load_from_save=True)

    def _read_corpus(self, filename: str) -> List[Sentence]:
        return _read_conll_corpus(filename)

    def _write_corpus(self, dist_file: str, data: List[Sentence]):
        return _write_conll_corpus(dist_file, data)

    def _request_to_sentence(self, _request: Request) -> Sentence:
        return super()._request_to_sentence(_request)

    def _sentence_to_response(self, sentence: Sentence) -> Response:
        return super()._sentence_to_response(sentence)


if __name__ == '__main__':
    google_fire.Fire(Event)

