#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
"""
--------
示例数据
--------

以下是一个例子，表示 ``AGM-131`` 与 ``B-1`` 的关系是 ``/missile/aircraft/part``。
关系名称使用 ``__`` 前导符，引导在句子前一行。之后每行一个词（字），第一列是词（字）本身，后面每一列是其对应特征。
本例中有两个额外特征，因此训练时需要设置 ``other_feature_num=2`` 。
每个例子之间需要使用空行分隔。

下面是一个例子::

    __/missile/aircraft/part
    B	B-AIRCRAFT	B-e2
    -	I-AIRCRAFT	I-e2
    1	I-AIRCRAFT	I-e2
    被	O	O
    命	O	O
    名	O	O
    为	O	O
    A	B-MISSILE	B-e1
    G	I-MISSILE	I-e1
    M	I-MISSILE	I-e1
    -	I-MISSILE	I-e1
    1	I-MISSILE	I-e1
    3	I-MISSILE	I-e1
    1	I-MISSILE	I-e1
    “	O	O
    斯	O	O
    哈	O	O
    姆	O	O
    ”	O	O
    I	O	O
    I	O	O
    。	O	O

---------
常见问题
---------

1. 关系抽取在这里被抽象成一个多特征的句子分类任务。每次仅能分类其中的一对实体上的关系。
因此，我们在生成数据和预测时需要遍历一句话的所有实体对，并标注他们的关系类型，对于没有关系的实体需要标注``无``。从而生成多个样本。


"""

from typing import Dict, Any, List

import fire as google_fire
from flask import Response, Request

from dlab.data import Sentence
from .text_classify import _train, _model, _read_data
from .common import ExeAbstractTask, TARGET_TAG_NAME


class Rel(ExeAbstractTask):
    TrainFunction = _train

    def train(self, save_dir: str, train_file: str, test_file: str, label_names: str,
              other_feature_num: int = 1,
              dev_file: str = '', word_base: int = 1,
              use_bert_embedder: int = 0, bert_model: str = 'bert-base-chinese'
              ):
        """训练一个文本分类器

        :param save_dir: 模型保存位置（必要）
        :param train_file: 训练数据文件位置（必要）
        :param test_file: 测试数据文件位置（必要）
        :param label_names: ','分割的关系名列表字符串，表示无关的标签需要设置在首位（必要）
        :param other_feature_num: 除了词（字）一列，其他特征列数。应对应传入列数-1（默认1）
        :param dev_file: 开发集位置，默认为''，从训练集自动分割。（默认''）
        :param word_base: 使用基于词的数据进行训练，数据文件中使用的是分词的数据，否则第一列都为一个字符。（默认1）
        :param use_bert_embedder: 使用bert作为一个Embedder特征（默认0）
        :param bert_model: [when use_bert_embedder=1] bert Embedder 位置，注意这里不做finetune，这个位置部署时也要拷贝。（默认bert-base-chinese）
        :return:
        """
        label_names = ','.join(str(l) for l in label_names) if isinstance(label_names, tuple) else label_names
        local_vars = {k: v for k, v in locals().items() if k in self.train.__code__.co_varnames and k != 'self'}
        return super().train(**local_vars)

    def _manage_options(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        self.kwargs = kwargs
        config = {}
        copy_config = {'train_file', 'dev_file', 'test_file',
                       'word_base', 'use_bert_embedder', 'bert_model',
                       'other_feature_num'}
        for k in copy_config & set(kwargs.keys()):
            config[k] = kwargs[k]

        if 'dev_file' not in kwargs:
            config['dev_file'] = ''

        config['data_format'] = 'conll'
        config['sentence_split'] = ''
        config['sentence_pair'] = '0'
        config['use_f1'] = '1'
        config['target_labels'] = kwargs['label_names']

        config['switch_model'] = 'multi'

        config['feature_dims'] = ','.join(['200'] + ['50'] * kwargs['other_feature_num'])

        if kwargs['use_bert_embedder']:
            config['max_epoch'] = '30'
            config['early_stopping'] = '10'
            config['optimizer'] = 'BertAdam'
            config['batch_size'] = '6'
            config['rnn_layers'] = '1'
        else:
            config['max_epoch'] = '300'
            config['early_stopping'] = '10'
            config['optimizer'] = 'Adam'
            config['batch_size'] = '32'
            config['rnn_layers'] = '2'

        if kwargs['word_base']:
            config['token_max_num'] = '80000'
            config['token_min_count'] = '2'
        else:
            config['token_max_num'] = '10000'
            config['token_min_count'] = '1'

        return config

    def _init__model(self):
        return _model(load_from_save=True)

    def _read_corpus(self, filename: str) -> List[Sentence]:
        return _read_data(filename)

    def _write_corpus(self, dist_file: str, data: List[Sentence]):
        with open(dist_file, 'w') as fw:
            fw.write('\n'.join([stn.get_label_name(TARGET_TAG_NAME) for stn in data]))

    def _request_to_sentence(self, _request: Request) -> Sentence:
        return super()._request_to_sentence(_request)

    def _sentence_to_response(self, sentence: Sentence) -> Response:
        return super()._sentence_to_response(sentence)


if __name__ == '__main__':
    google_fire.Fire(Rel)

