#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
from typing import List, Union, Dict, Tuple
import torch
from torch import nn, Tensor

from dlab.data.structures import Sentence, Dictionary, Token
from .base_embedder import Embedder


class StackEmbedder(Embedder):
    """Embedder that combines several embedders by concatenating their outputs.

    Every wrapped embedder is applied to the same batch of sentences and the
    resulting tensors are concatenated along the last dimension, so the
    reported ``dim`` of the stack is the sum of the wrapped embedders' ``dim``.

    The stack also behaves like a read-only sequence of its embedders:
    ``len(stack)``, ``stack[i]`` and ``for emb in stack`` are supported.
    """

    def __init__(self, embedder_list: List[Embedder], **kwargs):
        """Build the stack.

        Args:
            embedder_list: Embedders to combine, in concatenation order. Each
                must expose a ``dim`` attribute.
            **kwargs: Forwarded unchanged to the base ``Embedder`` constructor.
        """
        # Snapshot the argument instead of aliasing it: ``dim`` and the
        # registered attributes below are computed once, so later mutation of
        # the caller's list must not change what ``embed`` iterates over.
        # Copying first also keeps one-shot iterables from being exhausted by
        # the ``sum`` below.
        embedders = list(embedder_list)
        super().__init__(dim=sum(embedder.dim for embedder in embedders), **kwargs)
        self.embedders = embedders
        # A plain Python list is not tracked by PyTorch, so each embedder is
        # additionally assigned as an attribute to get it registered as a
        # submodule (assuming ``Embedder`` derives from ``nn.Module``).
        # The attribute names are kept stable because they would end up in
        # state-dict keys.
        for i, emb_layer in enumerate(self.embedders):
            setattr(self, 'the_embedder_%d' % i, emb_layer)

    def embed(self, batch_sentence: List[Sentence]) -> Tuple[Tensor, List[int]]:
        """Embed a batch with every wrapped embedder and concatenate the results.

        Args:
            batch_sentence: The batch of sentences to embed.

        Returns:
            A pair ``(tensor, seq_len)`` where ``tensor`` is the concatenation
            of all embedder outputs along the last dimension and ``seq_len``
            holds ``len(sentence)`` for every sentence in the batch.
        """
        seq_len = [len(sentence) for sentence in batch_sentence]
        word_presents: List[Tensor] = []
        for embedder in self:
            # Each embedder returns (tensor, lengths); the lengths are ignored
            # here because they are recomputed from the sentences above.
            present, _ = embedder(batch_sentence)
            word_presents.append(present)
        return torch.cat(word_presents, -1), seq_len

    def __len__(self) -> int:
        """Return the number of wrapped embedders."""
        return len(self.embedders)

    def __getitem__(self, item: Union[int, slice]) -> Union[Embedder, List[Embedder]]:
        """Return the embedder at index ``item`` (or a list for a slice)."""
        return self.embedders[item]

    def __iter__(self):
        """Iterate over the wrapped embedders in concatenation order."""
        return iter(self.embedders)