#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
from typing import List, Tuple

from dlab.data.structures import Sentence
from torch import Tensor
from torch.nn import Module


class Embedder(Module):
    """Base class for embedders; not meant to be used directly.

    This class only records ``dim`` and routes ``forward`` to ``embed``.
    ``embed`` itself raises ``NotImplementedError`` here, so a subclass is
    expected to override it.
    """

    def __init__(self, dim, **kwargs):
        """Initialise the module and remember ``dim``.

        :param dim: stored unchanged as ``self.dim`` (presumably the
            embedding dimension -- confirm against subclasses).
        :param kwargs: accepted but not used by this base class.
        """
        super(Embedder, self).__init__()
        self.dim = dim

    def embed(self, batch_sentence: List[Sentence]) -> Tuple[Tensor, List[int]]:
        """Interface method: embed a batch of sentences.

        Per the interface contract, an implementation returns the embedded
        3-dimensional tensor together with the original length of every
        sentence in the batch.

        :param batch_sentence: batched sentences to be embedded.
        :type batch_sentence: List[Sentence]
        :return: (Tensor[B*max_len*embed_dim], list of sequence lengths)
        :rtype: Tuple[Tensor, List[int]]
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def forward(self, *args):
        """Delegate to :meth:`embed`, passing all positional arguments through.

        :param args: positional arguments forwarded unchanged to ``embed``.
        :return: whatever ``embed`` returns.
        """
        embedded = self.embed(*args)
        return embedded

