#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>

"""
----------
数据格式
----------

训练、开发、测试集合（基于字，word_base=0）::

    中性       我们的纯真与失落
    积极       男人要受打击才行啊
    中性       开始喜欢威尔斯密斯~
    中性       除了借鉴它的片头形式拍过很多作业情节真的不太记得

训练、开发、测试集合（基于词，word_base=1）::

    中性       Laughing 死 了
    消极       真 想 打 0 分
    中性       雅诗兰黛 纸巾 头 100 句 经典 台词 烈火 战车 为 戴茜 小姐 开车 大白鲨 毕业生 教会 战火 浮生 淑女 伊芙易居 圣诞节
    消极       lenovo 出没 慎入 弱智 幽默
    中性       不谙世事 的 年代 啊
    中性       现代 文艺 小资 电影


每行一条语料，可以分词或不分词。注意在参数中指明你的选择。

如果你的语料中有明显的表示 ``空``、``NA``、``无`` 等特征的分类，这种分类应该被作排在第一传给 ``label_names`` 参数。

----------
常见问题
----------

- 注意训练集需要包含所有实体列标签。测试集中出现训练集未中出现的标签会报错。
- bert模型加载错误，或下载缓慢可以设置 bert_model 到预缓存位置。

"""
from typing import Dict, Any, List

import fire as google_fire
from flask import Response, Request

from dlab.data import Sentence
from .text_classify import _train, _model, _read_data
from .common import ExeAbstractTask, TARGET_TAG_NAME


class Tc(ExeAbstractTask):
    TrainFunction = _train

    def train(self, save_dir: str, train_file: str, test_file: str, label_names: str,
              use_f1: int = 0, pure_bert: int = 0,
              dev_file: str = '', word_base: int = 0, bert_model: str = 'bert-base-chinese',
              ):
        """训练一个文本分类器（此处限制只能使用字或词信息）

        :param save_dir: 模型保存位置（必要）
        :param train_file: 训练数据文件位置（必要）
        :param test_file: 测试数据文件位置（必要）
        :param label_names: ','分割的分类名列表字符串，如果use_f1=1则需要将空类设置在第一位，注意数据中不能出现列表中没有的类别（必要）
        :param use_f1: 是否使用f1值评判保存模型，否则使用ACC（默认0）
        :param pure_bert: 是否只使用bert模型的特征进行训练。这种设置下不能使用其他特征。（默认0）
        :param dev_file: 开发集位置，默认为''，从训练集自动分割。（默认''）
        :param word_base: 使用基于词的数据进行训练，数据文件中使用的是分词的数据，否则第一列都为一个字符。（默认0）
        :param bert_model: bert模型初始化位置。（默认bert-base-chinese）
        """
        label_names = ','.join(str(l) for l in label_names) if isinstance(label_names, tuple) else label_names
        local_vars = {k: v for k, v in locals().items() if k in self.train.__code__.co_varnames and k != 'self'}
        return super().train(**local_vars)

    def _manage_options(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        self.kwargs = kwargs
        config = {}
        copy_config = {'train_file', 'dev_file', 'test_file',
                       'word_base', 'bert_model', 'use_f1'}
        for k in copy_config & set(kwargs.keys()):
            config[k] = kwargs[k]

        if 'dev_file' not in kwargs:
            config['dev_file'] = ''

        config['data_format'] = 'line_tab'
        config['other_feature_num'] = '0'
        config['sentence_split'] = ''
        config['sentence_pair'] = '0'
        config['target_labels'] = kwargs['label_names']

        if kwargs['pure_bert'] and kwargs['word_base']:
            raise ValueError('pure_bert model only support char base classification.')

        config['switch_model'] = 'bert' if kwargs['pure_bert'] else 'multi'

        if config['switch_model'] == 'bert':
            config['max_epoch'] = '30'
            config['early_stopping'] = '10'
            config['optimizer'] = 'BertAdam'
            config['batch_size'] = '6'
        else:
            config['max_epoch'] = '300'
            config['early_stopping'] = '10'
            config['optimizer'] = 'Adam'
            config['batch_size'] = '32'
            if kwargs['word_base']:
                config['feature_dims'] = '200'
                config['token_max_num'] = '80000'
                config['token_min_count'] = '2'
            else:
                config['feature_dims'] = '50'
                config['token_max_num'] = '10000'
                config['token_min_count'] = '1'

        return config

    def _init__model(self):
        return _model(load_from_save=True)

    def _read_corpus(self, filename: str) -> List[Sentence]:
        return _read_data(filename)

    def _write_corpus(self, dist_file: str, data: List[Sentence]):
        with open(dist_file, 'w') as fw:
            fw.write('\n'.join([stn.get_label_name(TARGET_TAG_NAME) for stn in data]))

    def _request_to_sentence(self, _request: Request) -> Sentence:
        return super()._request_to_sentence(_request)

    def _sentence_to_response(self, sentence: Sentence) -> Response:
        return super()._sentence_to_response(sentence)


if __name__ == '__main__':
    google_fire.Fire(Tc)

