#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
from typing import List, Tuple
import numpy as np

from dlab.data.structures import Sentence
from torch import Tensor
from torch.nn import Module


class Encoder(Module):
    """Base class for sentence encoders; not meant to be used directly.

    Subclasses must implement :meth:`encode`. Calling the module (which goes
    through :meth:`forward`) simply delegates to :meth:`encode`.
    """

    def __init__(self, dim, support_time_step=True, **kwargs):
        """Store the encoder configuration.

        :param dim: size of the last axis of the matrix returned by
            :meth:`encode` (``Tensor[B*max_seq_len, dim]``).
        :param support_time_step: flag stored as ``self.support_time_step``;
            its meaning is defined by subclasses/callers (presumably whether
            the encoder yields per-time-step outputs -- TODO confirm).
        :param kwargs: accepted and ignored by this base class (presumably so
            callers can pass subclass-specific options uniformly -- confirm).
        """
        super().__init__()
        self.dim = dim
        self.support_time_step = support_time_step

    def encode(self, batch_sentence: List[Sentence], **kwargs) -> Tuple[Tensor, np.ndarray]:
        """Encode a batch of sentences into a 2-D matrix.

        Interface method: every concrete encoder must override it.

        :param batch_sentence: batched sentences to be encoded.
        :type batch_sentence: List[Sentence]
        :return: ``(Tensor[B*max_seq_len, dim], np.ndarray)``; the array is
            presumably the per-sentence lengths -- TODO confirm with subclasses.
        :rtype: Tuple[Tensor, np.ndarray]
        :raises NotImplementedError: always, in this base class.
        """
        # Name the concrete class so a missing override is easy to locate.
        raise NotImplementedError(
            f"{type(self).__name__} must implement encode()"
        )

    def forward(self, batch_sentence: List[Sentence], **kwargs):
        """Delegate to :meth:`encode` so the module can be called directly.

        :param batch_sentence: batched sentences to be encoded.
        :param kwargs: forwarded unchanged to :meth:`encode`.
        :return: whatever :meth:`encode` returns.
        """
        return self.encode(batch_sentence, **kwargs)

