#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
from os import path
from typing import Any, Dict, List

from dlab.config import RunMockClass, fire, GlobalConfig, macro_manager
from dlab.config import MACRO_WORKING_DIR, MACRO_CONFIG_FILE, MACRO_CONFIG_DIR
from dlab.data import Sentence
from dlab.log import log
from dlab.tasks.base_task import BaseTask
from flask import Flask, Request, request, Response

# Tag-name template for additional features; the ``%d`` slot takes an integer
# (presumably the feature's index — confirm against callers).
OTHER_FEATURE_TAG_NAME_TEMPLATE = "feature_%d"
# Tag name used for the target column.
TARGET_TAG_NAME = "target"
# Tag name used for the predicted target column.
TARGET_TAG_PREDICT_NAME = "target_predict"
# File-name template for a tag dictionary; the ``%s`` slot takes a string
# (presumably a tag name).
TAG_DICT_FILENAME_TEMPLATE = "%s.dict"


class ExeAbstractTaskTrainPart(object):
    """Training half of an executable task.

    Subclasses set ``TrainFunction`` to the dlab run-mock class whose command
    performs training. ``train`` first writes the options to a config file in
    the save directory and then runs training from that file.
    """
    # Run-mock class driven through ``fire``; expected to be set by subclasses.
    TrainFunction: RunMockClass = None

    def __init__(self):
        self.model = None
        self.kwargs = None

    @staticmethod
    def _manage_options(kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Map the keyword arguments given to ``train`` to the option dict passed to
        the train command.

        At training time this class receives most of its construction options;
        prediction and evaluation do not take them and instead read them back
        from the saved config file. This hook lets subclasses turn the trimmed-down
        train arguments into the full dlab option set. It is deliberately a static
        method: it should not depend on training state, only on how the train
        configuration is chosen. The default implementation returns ``kwargs``
        unchanged.

        :param kwargs: keyword options given to ``train``.
        :return: option name -> value; each entry is passed as ``--name=value``.
        """
        return kwargs

    def _read_corpus(self, filename: str) -> List[Sentence]:
        """
        Read a corpus file holding one JSON-encoded ``Sentence`` per line.

        Blank or whitespace-only lines, including the empty "line" after a
        trailing newline, are skipped instead of being handed to
        ``Sentence.from_json``.

        :param filename: path of the corpus file.
        :return: the parsed sentences, in file order.
        """
        with open(filename, 'r') as f:
            return [Sentence.from_json(line.rstrip('\n')) for line in f if line.strip()]

    def _get_config_path_by_save_dir(self, save_dir: str):
        """Return ``<save_dir>/<lower-cased class name>.cfg``."""
        return path.join(save_dir, "%s.cfg" % self.__class__.__name__.lower())

    def train(self, save_dir: str, **kwargs):
        """
        Save the training configuration and start training.

        ``TrainFunction`` is run twice through ``fire``. The first run passes
        ``--save_cfg`` so the options are written to the config file in
        ``save_dir``. The second run passes ``--cfg`` so training runs from that
        file. During the second run a file handler for ``<save_dir>/train.log``
        is attached to ``log``.

        :param save_dir: directory receiving the config file and the training log.
        :param kwargs: training options, first passed through ``_manage_options``.
        """
        options = self._manage_options(kwargs)
        f_name = self.TrainFunction.name
        config_path = self._get_config_path_by_save_dir(save_dir)
        fire([self.TrainFunction], [f'--save_cfg={config_path}', f_name, *[f'--{k}={v}' for k, v in options.items()]])
        log.add_file_handler(path.join(save_dir, 'train.log'))
        try:
            fire([self.TrainFunction], [f'--cfg={config_path}', f_name])
        finally:
            # Detach the handler even if training raises. Otherwise later work in
            # the same process would keep logging into this train.log.
            log.remove_file_handler()


class ExeAbstractTaskEvalPart(ExeAbstractTaskTrainPart):
    """Adds evaluation and file-to-file prediction on top of the training part.

    Both operations rebuild the model from the config file that ``train`` saved
    in the model directory.
    """

    def _init_config(self, save_dir: str):
        """
        Register the path macros for ``save_dir`` and load its saved config.

        :param save_dir: model directory produced by ``train``.
        :return: the value returned by ``GlobalConfig().read_from_file``.
        """
        cfg_file = self._get_config_path_by_save_dir(save_dir)
        working_dir = path.abspath('.')
        macro_manager.add_replacement(MACRO_WORKING_DIR, working_dir)
        macro_manager.add_replacement(MACRO_CONFIG_DIR, save_dir)
        macro_manager.add_replacement(MACRO_CONFIG_FILE, cfg_file)
        global_config = GlobalConfig()
        return global_config.read_from_file(cfg_file)

    def _init__model(self) -> BaseTask:
        """Hook for subclasses: build and return the model in prediction mode."""
        raise NotImplementedError()

    def get_model(self, save_dir: str):
        """
        Load the configuration stored in ``save_dir``, then build the model.

        :param save_dir: model directory produced by ``train``.
        :return: the model returned by ``_init__model``.
        """
        self._init_config(save_dir)
        built = self._init__model()
        return built

    def evaluate(self, save_dir: str, eval_file: str):
        """
        Evaluate the model on a test set and print the results.

        :param save_dir: required. Directory where the model is stored.
        :param eval_file: required. Path of the test-set file.
        """
        evaluator = self.get_model(save_dir)
        samples = self._read_corpus(eval_file)
        evaluator.print_evaluation(samples)

    def _write_corpus(self, dist_file: str, data: List[Sentence]):
        """
        Write ``data`` to ``dist_file`` as one JSON sentence per line.

        Lines are joined with ``'\\n'``, so no newline follows the last one.
        """
        with open(dist_file, 'w') as out:
            encoded = (sentence.to_json() for sentence in data)
            out.write('\n'.join(encoded))

    def predict(self, save_dir: str, src_file: str, dist_file: str):
        """
        Run the model over a data file and write the results to a target path.

        :param save_dir: required. Directory where the model is stored.
        :param src_file: required. Path of the file to predict on.
        :param dist_file: required. Path of the output file.
        """
        predictor = self.get_model(save_dir)
        samples = self._read_corpus(src_file)
        # The return value of ``predict`` is ignored. Predictions are expected
        # to be written into ``samples`` in place before they are saved.
        predictor.predict(samples)
        self._write_corpus(dist_file, samples)


class ExeAbstractTaskServerPart(ExeAbstractTaskEvalPart):
    """Adds an HTTP prediction endpoint on top of the evaluation part."""

    def _request_to_sentence(self, _request: Request) -> Sentence:
        """
        Decode the JSON body of ``_request`` into a ``Sentence``.

        The following bodies are rejected with HTTP 400:

        - missing,
        - not valid JSON,
        - not sent with a JSON content type,
        - not a JSON object.

        With ``silent=True``, ``get_json`` returns ``None`` in the first three
        cases. Rejecting early stops that value from reaching
        ``Sentence.from_dict`` and surfacing as a 500.

        :param _request: the incoming Flask request.
        :return: the decoded sentence.
        """
        from flask import abort

        json_data = _request.get_json(silent=True)
        if not isinstance(json_data, dict):
            abort(400, description='Request body must be a JSON object.')
        return Sentence.from_dict(json_data)

    def _sentence_to_response(self, sentence: Sentence) -> Response:
        """Wrap ``sentence.to_json()`` in an ``application/json`` response."""
        return Response(sentence.to_json(), mimetype='application/json')

    def server(self, save_dir: str, ip: str, port: int):
        """
        Serve the model over HTTP.

        The model is loaded from ``save_dir`` once. Each ``POST /`` request body
        is then decoded into a sentence and passed to ``model.predict``. The
        sentence is returned as JSON.

        :param save_dir: directory where the model is stored.
        :param ip: host address to bind.
        :param port: port to listen on.
        """
        model = self.get_model(save_dir)
        app = Flask(__name__)

        @app.route('/', methods=['POST'])
        def post():
            sentence = self._request_to_sentence(request)
            # Results are expected to be filled into ``sentence`` in place.
            model.predict([sentence])
            return self._sentence_to_response(sentence)

        app.run(host=ip, port=port)


class ExeAbstractTask(ExeAbstractTaskServerPart):
    """Complete executable task: training, evaluation, prediction and serving."""
