#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
import os
from typing import Iterator, List, Type

import torch
from torch import Tensor, optim

from dlab.config import add_config, inherit_config
from dlab.data.reader import DataSet
from dlab.data.structures import Sentence
from dlab.data.utils import EarlyStopping
from dlab.log import log
from dlab.tasks import BaseTask
from dlab.optimizer import BertAdamOptimizer, InlineBertAdamOptimizer
from .common import data_runner


class BaseTrainer(object):
    """
    Wrapper around the training process; implements the basic training loop.

    Subclasses supply the outer iteration through ``_iter_epoch`` and may
    override ``_iter_train`` and ``_get_optimizer``.

    This is a configurable class.
    """

    def __init__(self, data_set: DataSet, task_model: BaseTask, early_stopping=10, check_point='',
                 optimizer_name='Adam'):
        """
        :param data_set: data set exposing ``train_set``, ``dev_set`` and ``test_set``
        :param task_model: model to train
        :param early_stopping: patience value handed to ``EarlyStopping`` in ``train``
        :param check_point: stored on the instance; not read anywhere in this class
        :param optimizer_name: optimizer name resolved by ``_get_optimizer``
        """
        super().__init__()
        self.data_set = data_set
        self.model = task_model
        self.early_stopping = early_stopping
        self.check_point = check_point
        self.optimizer_name = optimizer_name
        # NOTE: called from the constructor, so any attribute an overriding
        # ``_get_optimizer`` reads must be set by the subclass before it calls
        # ``super().__init__``.
        self.optimizer = self._get_optimizer(self.optimizer_name)

    def _get_optimizer(self, optimizer_name: str):
        """Build the optimizer for ``optimizer_name``.

        :param optimizer_name: only ``'Adam'`` is supported here
        :return: ``torch.optim.Adam`` over the model parameters
        :raises ValueError: for any other name
        """
        if optimizer_name == 'Adam':
            return optim.Adam(self.model.parameters())
        else:
            raise ValueError('unknown optimizer "%s"' % optimizer_name)

    # interface
    def _iter_train(self):
        """Return the iterable of training batches consumed in one outer iteration."""
        return self.data_set.train_set

    # interface
    def _iter_epoch(self) -> Iterator[str]:
        """Iterate the outer loop, yielding one progress message per round.

        :raises NotImplementedError: always; subclasses must override
        """
        raise NotImplementedError()

    def train(self, save_dir: str):
        """Train the model and store the best one in ``save_dir``.

        For every message yielded by ``_iter_epoch`` the model is run over
        ``_iter_train()`` and then scored on the dev set. Whenever the dev
        score improves, the model is saved and (if a test set exists) also
        scored on the test set. Training stops when ``_iter_epoch`` is
        exhausted or early stopping triggers. Afterwards the model is loaded
        back from ``save_dir`` and its evaluation is printed.

        :param save_dir: location where the model is saved
        :return: None
        """
        log.info('start training.')

        def train_step_hook(forward_output: Tensor, loss: Tensor, batch_sentence: List[Sentence]):
            # Drop gradients left over from the previous batch so this update
            # uses only the current batch's gradients; backward() accumulates
            # into .grad otherwise. Redundant but harmless if data_runner
            # already clears them.
            # NOTE(review): assumes the model is a torch.nn.Module (it is
            # already used via parameters()/train()/eval()/cuda()) - confirm.
            self.model.zero_grad()
            loss.backward()
            self.optimizer.step()

        early_stopping = EarlyStopping(patient=self.early_stopping)
        max_score = float('-inf')
        train_metrics = ['loss'] + self.model.metrics_list
        for epoch_message in self._iter_epoch():
            log.info('%s -------------' % epoch_message)
            # train step
            self.model.train()
            metrics = data_runner(self._iter_train(), self.model, metrics=self.model.metrics_list, cal_loss=True,
                                  after_batch_iter_hook=train_step_hook)

            log.info('training -- %s' % (' '.join('%s: %s' % (k, metrics[k]) for k in train_metrics)))

            with torch.no_grad():
                # develop evaluation part
                cur_score = self._eval(self.data_set.dev_set, 'developing')

                if cur_score > max_score:
                    log.info('max score update from %.4f to %.4f' % (max_score, cur_score))
                    self.model.save(save_dir)
                    max_score = cur_score
                    # The test set is only scored when the dev score improved.
                    if self.data_set.test_set is not None:
                        self._eval(self.data_set.test_set, 'testing')
                    else:
                        log.debug('test set is not set. skip evaluation on test set.')

            if early_stopping.next_score(cur_score):
                log.debug('reach early stopping patience, break training.')
                break

        log.info('finish training.')
        # Reload the best saved model before the final report.
        self.model.load(save_dir)
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        with torch.no_grad():
            self.model.eval()
            log.info('----------- evaluation on dev -------------')
            self.model.print_evaluation(self.data_set.dev_set)
            if self.data_set.test_set is not None:
                log.info('----------- evaluation on test ------------')
                self.model.print_evaluation(self.data_set.test_set)
            else:
                log.warning('No test data set. Skip evaluation on test.')

    def _eval(self, data_set, prefix='developing'):
        """Score the model on ``data_set`` and log the result.

        The score follows ``self.model.eval_metrics``: ``'nloss'`` gives the
        negated loss, a name from ``self.model.metrics_list`` gives that
        metric's value. ``train`` treats a larger score as better.

        :param data_set: data to evaluate on
        :param prefix: label used at the start of the log line
        :return: the score as a float
        :raises ValueError: if ``eval_metrics`` is neither ``'nloss'`` nor a listed metric
        """
        self.model.eval()
        if self.model.eval_metrics == 'nloss':
            metrics = data_runner(data_set, self.model, cal_loss=True)
            loss = float(metrics['loss'])
            log.info('%s -- loss: %.4f' % (prefix, loss))
            # Negate so that a smaller loss yields a larger score.
            cur_score = -loss
        elif self.model.eval_metrics in self.model.metrics_list:
            metrics = data_runner(data_set, self.model, cal_loss=False, metrics=self.model.metrics_list)
            log.info('%s -- %s' % (prefix, (' '.join('%s: %s' % (k, metrics[k]) for k in self.model.metrics_list))))
            cur_score = float(metrics[self.model.eval_metrics])
        else:
            raise ValueError('Metric function=(`%s`) is not available. choice from %s'
                             % (self.model.eval_metrics, self.model.metrics_list))
        return cur_score


class InlineTrainer(BaseTrainer):
    """Epoch-based trainer configured directly through keyword arguments."""

    def __init__(self, data_set: DataSet, task_model: BaseTask,
                 max_epoch=300, early_stopping=10, optimizer='Adam', check_point='', **kwargs):
        """

        :param data_set: data set the model is trained on
        :param task_model: model to train
        :param max_epoch: maximum number of training epochs, default=300
        :param early_stopping: early-stopping patience in rounds, default=10
        :param optimizer: optimizer to use, one of ['BertAdam', 'Adam'], default='Adam'
        :param check_point: passed through to the base class, default=''
        :param kwargs: forwarded unchanged to ``InlineBertAdamOptimizer`` when
            ``optimizer='BertAdam'``. The original notes list ``bert_lr``
            (learning rate, stated default 5e-5) and ``bert_warmup_proportion``
            (warm-up share, stated default 0.1); those defaults are not defined
            in this class.
        """
        # Both attributes are read by _get_optimizer, which the base
        # constructor invokes, so they have to exist before super().__init__.
        self.kwargs = kwargs
        self.max_epoch = max_epoch
        super().__init__(
            data_set,
            task_model,
            early_stopping=early_stopping,
            check_point=check_point,
            optimizer_name=optimizer,
        )
        reported = ('max_epoch', 'early_stopping', 'optimizer_name', 'check_point')
        arg_dict = {name: vars(self)[name] for name in reported}
        log.info(f'Trainer config: {arg_dict}')

    def _get_optimizer(self, optimizer_name: str):
        """Return the BertAdam wrapper for ``'BertAdam'``; otherwise defer to the base class."""
        if optimizer_name != 'BertAdam':
            return super()._get_optimizer(optimizer_name)

        if self.max_epoch > 15:
            log.warning('You may set BertAdam with train max_epoch = %d (> 15). Please recheck your settings '
                        'unless you know what you are doing.' % self.max_epoch)
        # Train-set length times epochs, divided by batch size (true division).
        scaled_length = len(self.data_set.train_set) * self.max_epoch
        total_train_step = scaled_length / self.data_set.batch_size
        return InlineBertAdamOptimizer(self.model, total_train_step, **self.kwargs)

    def _iter_epoch(self):
        """Yield one progress message per epoch, counting from 1 to ``max_epoch``."""
        yield from ('epoch -- %d/%d' % (epoch, self.max_epoch)
                    for epoch in range(1, self.max_epoch + 1))


@inherit_config(BertAdamOptimizer)
@add_config('max_epoch', int, 300, 'max epoch to iteration the train data.')
@add_config('early_stopping', int, 10, 'early stop the train process when no greater score get.')
@add_config('optimizer', str, 'Adam', 'optimizer function. choose from [Adam, BertAdam]')
@add_config('check_point', str, '', 'checkpoint file.')
class Trainer(BaseTrainer):
    """Epoch-based trainer whose settings are read from a ``config`` mapping."""

    def __init__(self, config, data_set: DataSet, task_model: BaseTask):
        """
        :param config: mapping providing 'max_epoch', 'early_stopping', 'check_point' and 'optimizer'
        :param data_set: data set the model is trained on
        :param task_model: model to train
        """
        # Has to be set first: the base constructor calls _get_optimizer,
        # which reads self.max_epoch.
        self.max_epoch = config['max_epoch']
        base_options = dict(
            early_stopping=config['early_stopping'],
            check_point=config['check_point'],
            optimizer_name=config['optimizer'],
        )
        super().__init__(data_set, task_model, **base_options)

    def _get_optimizer(self, optimizer_name: str):
        """Return ``BertAdamOptimizer`` for ``'BertAdam'``; otherwise defer to the base class."""
        wants_bert_adam = optimizer_name == 'BertAdam'
        if not wants_bert_adam:
            return super()._get_optimizer(optimizer_name)

        epochs = self.max_epoch
        if epochs > 15:
            log.warning('You may set BertAdam with train max_epoch = %d (> 15). Please recheck your settings '
                        'unless you know what you are doing.' % epochs)
        # Train-set length times epochs, divided by batch size (true division).
        total_train_step = len(self.data_set.train_set) * self.max_epoch / self.data_set.batch_size
        return BertAdamOptimizer(self.model, total_train_step)

    def _iter_epoch(self):
        """Yield one progress message per epoch, counting from 1 to ``max_epoch``."""
        for finished in range(self.max_epoch):
            yield 'epoch -- %d/%d' % (finished + 1, self.max_epoch)


# Alias: ``WholeEpochTrainer`` is bound to the very same class object as
# ``Trainer`` (presumably named in contrast to ``SemiEpochTrainer`` below).
WholeEpochTrainer: Type[Trainer] = Trainer


@inherit_config(BertAdamOptimizer)
@add_config('max_num_step', int, 10**5, 'max batches to train, train data will use by cycle.')
@add_config('eval_num_stop', int, 1000, 'batches to evaluate and check to save model.')
@add_config('early_stopping', int, 10, 'early stop the train process when no greater score get.')
@add_config('optimizer', str, 'Adam', 'optimizer function. choose from [Adam, BertAdam]')
@add_config('check_point', str, '', 'checkpoint file.')
class SemiEpochTrainer(BaseTrainer):
    """Step-based trainer: evaluates every ``eval_num_stop`` batches instead of once per epoch.

    The train set is cycled: when its iterator is exhausted a new one is
    started, until ``max_num_step`` batches have been consumed in total.
    """

    def __init__(self, config, data_set: DataSet, task_model: BaseTask):
        """
        :param config: mapping providing 'max_num_step', 'eval_num_stop',
            'early_stopping', 'check_point' and 'optimizer'
        :param data_set: data set the model is trained on
        :param task_model: model to train
        """
        # Set before super().__init__: the base constructor calls
        # _get_optimizer, which reads self.max_num_step.
        self.max_num_step = config['max_num_step']
        self.eval_num_stop = config['eval_num_stop']
        self.train_iter = None  # current iterator over the train set, created lazily
        self.global_step = 0  # number of batches handed out so far
        super().__init__(data_set, task_model,
                         early_stopping=config['early_stopping'],
                         check_point=config['check_point'],
                         optimizer_name=config['optimizer'])

    def _get_optimizer(self, optimizer_name: str):
        """Return ``BertAdamOptimizer`` for ``'BertAdam'``; otherwise defer to the base class."""
        if optimizer_name == 'BertAdam':
            return BertAdamOptimizer(self.model, self.max_num_step)
        return super()._get_optimizer(optimizer_name)

    def get_next_train_batch(self):
        """Return the next training batch, restarting the train set when it is exhausted.

        Increments ``global_step`` for every batch returned. As before, a
        ``None`` item from the iterator is treated as the end of the data.

        :return: the next batch from ``data_set.train_set``
        :raises ValueError: if a freshly started pass over the train set yields
            no batch (previously this recursed until ``RecursionError``)
        """
        while True:
            fresh_pass = self.train_iter is None
            if fresh_pass:
                self.train_iter = iter(self.data_set.train_set)

            batch = next(self.train_iter, None)
            if batch is not None:
                self.global_step += 1
                return batch

            # Iterator exhausted: drop it so the next attempt starts a new pass.
            self.train_iter = None
            if fresh_pass:
                # A brand-new iterator produced nothing, so retrying would loop forever.
                raise ValueError('train set is empty, no train batch available.')
            log.info('train data reach the cycle end. start next iter.')

    def _iter_epoch(self) -> Iterator[str]:
        """Yield a progress message for each evaluation round until ``max_num_step`` is reached."""
        while self.global_step < self.max_num_step:
            yield 'train step -- %d/%d' % (self.global_step, self.max_num_step)

    def _iter_train(self):
        """Yield at most ``eval_num_stop`` batches, stopping early once ``max_num_step`` is reached."""
        for batch_id in range(self.eval_num_stop):
            if self.global_step >= self.max_num_step:
                break
            yield self.get_next_train_batch()

