#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
from typing import List, Dict

from torch import Tensor

from dlab.data.structures import Sentence
from dlab.evaluate import Ratio, F1beta, MetricsValue
from dlab.tasks import BaseTask


def do_nothing(batch_sentence: List[Sentence], forward_output: Tensor, loss: Tensor=None):
    """No-op per-batch hook: accepts the batch, the forward output and the loss, and ignores them.

    Used as the default ``after_batch_iter_hook`` of ``data_runner``.
    """
    return None


def data_runner(loader, model: BaseTask,
                cal_loss=True, metrics=None, after_batch_iter_hook=do_nothing) -> Dict[str, MetricsValue]:
    """
    Run one pass over ``loader``: the batch loop shared by training, validation and prediction.

    For every batch the model is run forward, the loss (optional) and the
    requested metrics are accumulated, ``after_batch_iter_hook`` is invoked,
    and a one-line progress report is rewritten in place on stdout.

    :param loader: data loader; iterating it yields batches of Sentence.
    :param model: task model. It is called as ``model(batch)`` and must provide
        ``zero_grad()``, ``loss(batch, forward_output)`` and
        ``metrics(name, batch, forward_output)``.
    :param cal_loss: whether to compute the loss. Defaults to True for safety.
    :param metrics: names of additional task-specific metrics; each one is
        computed as ``model.metrics(name, batch, forward_output)``.
        ``None`` means no extra metrics.
    :param after_batch_iter_hook: callable invoked once per batch with the
        keyword arguments ``forward_output``, ``loss`` and ``batch_sentence``.
        ``loss`` is None when ``cal_loss`` is false.
    :return: dict holding the accumulated loss under ``'loss'`` (only when
        ``cal_loss`` is true) and the accumulated value of every requested
        metric under its own name.
    """
    if metrics is None:
        metrics = []

    ret_dict = {}
    # Stays None for the hook when the loss is not computed.
    loss = None
    if cal_loss:
        ret_dict['loss'] = Ratio()

    for batch_id, batch_sentence in enumerate(loader):
        batch_size = len(batch_sentence)
        progress_parts = ['\rbatch: %d ' % batch_id]

        # Reset gradients before the forward pass so the hook sees only this
        # batch's gradients (assumes torch.nn.Module-style zero_grad semantics).
        model.zero_grad()
        forward_output = model(batch_sentence)

        if cal_loss:
            loss = model.loss(batch_sentence, forward_output)
            # The batch size is passed along with the scalar loss; presumably
            # Ratio keeps a sample-weighted running value -- confirm against
            # dlab.evaluate.Ratio.
            ret_dict['loss'].update(loss.item(), float(batch_size))
            progress_parts.append('%s ' % ret_dict['loss'])

        for metrics_name in metrics:
            metrics_value = model.metrics(metrics_name, batch_sentence, forward_output)
            # The first batch seeds the accumulator; later batches are merged with +=.
            if metrics_name in ret_dict:
                ret_dict[metrics_name] += metrics_value
            else:
                ret_dict[metrics_name] = metrics_value
            progress_parts.append('%s: %s ' % (metrics_name, ret_dict[metrics_name]))

        after_batch_iter_hook(forward_output=forward_output, loss=loss, batch_sentence=batch_sentence)
        # The line has no trailing newline, so flush explicitly; otherwise a
        # line-buffered stdout would hold the progress text back.
        print(''.join(progress_parts), end='', flush=True)

    # Return the cursor to the start of the line so later output overwrites the progress text.
    print('\r', end='', flush=True)

    return ret_dict
