#! /usr/bin/env python
# -*- coding: utf-8 -*_
# Author: Liu Yang <mkliuyang@gmail.com>
import os
from typing import List

import torch
from torch import Tensor

from dlab.data.structures import Dictionary, Sentence
from dlab.evaluate import MetricsValue
from dlab.log import log


class BaseTask(torch.nn.Module):
    """
    可训练模型，完整模型（区别于LSTM这种 模型的一部分）。

    定义接口。
     - 完整模型应有 forward, loss
     - 针对特定任务应该实现相应的（应在特定任务的通用基类上实现）
        - acc, f1（Tensor 级别的，可以在GPU上高速运算，在训练中可以大量执行来看到结果）
        - predict，print_evaluation（python级别的，较慢，但能输出详细的信息）
        - save, load
    """
    def __init__(self, metrics_list: List[str], eval_metrics: str):
        """

        :param metrics_list: 任务实现的所有评价方法（不包括loss），这些方法会在 train, evaluate 时打印。
        :param eval_metrics: 保存模型使用的重要指标。要不在 metrics_list 中，要不为 'nloss' (negative loss)
        """
        super().__init__()
        self.metrics_list = metrics_list
        self.eval_metrics = eval_metrics

    def forward(self, batch_sentence: List[Sentence]) -> Tensor:
        raise NotImplementedError()

    def loss(self, batch_sentence: List[Sentence], logit: Tensor) -> Tensor:
        raise NotImplementedError()

    def metrics(self, metrics_name: str, batch_sentence: List[Sentence], logit: Tensor) -> MetricsValue:
        if metrics_name not in self.metrics_list:
            raise ValueError('metrics=`%s` not supported. you can use %s instead.' % (metrics_name, self.metrics_list))
        if not hasattr(self, metrics_name):
            raise NotImplementedError('Method `%s:%s` not implement.' % (self.__class__.__name__, metrics_name))
        return getattr(self, metrics_name)(batch_sentence, logit)

    def print_evaluation(self, sentences: List[Sentence]) -> None:
        """ 打印更为详细的信息；python级别的，较慢； """
        pass

    def predict(self, sentences: List[Sentence], batch_size=4):
        raise NotImplementedError()

    def save(self, path, model_name='best.ml'):
        file_name = os.path.join(path, model_name)
        log.info('Model saving to %s' % file_name)
        torch.save(self.state_dict(), file_name)

    def load(self, path, model_name='best.ml'):
        file_name = os.path.join(path, model_name)
        log.info('Model restoring from %s' % file_name)
        map_location = None if torch.cuda.is_available() else 'cpu'
        self.load_state_dict(torch.load(file_name, map_location=map_location))
