#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
from typing import List, Tuple

from torch import Tensor

from dlab.data import Sentence, Dictionary
from dlab.embedder import Embedder
from dlab.layers import URNN, ULSTM, UGRU
import numpy as np
from .base_encoder import Encoder


class RNNEncoder(Encoder):
    """Encode a batch of sentences with an embedder followed by a recurrent layer.

    Subclasses choose the recurrent layer class by overriding ``_RNN_TYPE``.
    """

    # Deliberately a single leading underscore. A double-underscore class
    # attribute is name-mangled per class (``_RNNEncoder__RNN_TYPE`` here,
    # ``_LSTMEncoder__RNN_TYPE`` in a subclass), so ``__init__`` would never
    # see a subclass override and every encoder would be built with URNN.
    _RNN_TYPE = URNN

    def __init__(self, dim: int, embedder: Embedder, **kwargs):
        """
        Build the embedder + recurrent-layer encoder.

        :param dim: dimension passed to the base ``Encoder``; must be even
            because the recurrent layer is created with ``dim // 2`` units
        :param embedder: embedder whose ``dim`` is used as the recurrent
            layer's input size and whose ``embed`` method is called by ``encode``
        :param kwargs: forwarded unchanged to the recurrent layer
            (per the original note: num_layers=1, dropout=0,
            bidirectional=True -- TODO confirm against ``dlab.layers``)
        :raises ValueError: if ``dim`` is odd
        """
        if dim % 2:
            raise ValueError('dim must be a 2 times value.')
        super().__init__(dim, support_time_step=True)
        self.embedder = embedder
        # NOTE(review): ``dim // 2`` presumably assumes a bidirectional layer
        # whose two directions are concatenated back to ``dim`` -- confirm
        # the behaviour when ``bidirectional=False`` is passed via kwargs.
        self.rnn = self._RNN_TYPE(
            embedder.dim,
            dim // 2,
            only_use_last_hidden_state=False,
            **kwargs,
        )

    def encode(self, batch_sentence: List[Sentence], **kwargs) -> Tuple[Tensor, np.ndarray]:
        """
        Embed the batch and run it through the recurrent layer.

        :param batch_sentence: sentences to encode
        :param kwargs: accepted for interface compatibility; not used here
        :return: ``(out, seq_length)`` where ``out`` is the first value
            returned by the recurrent layer and ``seq_length`` is the
            embedder's sequence lengths converted to a numpy array
        """
        embed_output, seq_length = self.embedder.embed(batch_sentence)
        seq_length = np.array(seq_length)
        # The second value returned by the recurrent layer is not needed.
        out, _ = self.rnn(embed_output, seq_length)
        return out, seq_length


class LSTMEncoder(RNNEncoder):
    """``RNNEncoder`` variant whose recurrent layer is ``ULSTM``."""

    _RNN_TYPE = ULSTM


class GRUEncoder(RNNEncoder):
    """``RNNEncoder`` variant whose recurrent layer is ``UGRU``."""

    _RNN_TYPE = UGRU
