#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
from typing import List, Dict

from torch import Tensor

from dlab.tasks.hierarchical_classifier import HierarchicalClassifier
from dlab.encoder import Encoder
from dlab.layers import Decoder
from dlab.data import HierarchicalDictionary, Sentence


class Seq2SeqHierarchicalClassifier(HierarchicalClassifier):
    """Hierarchical classifier whose predictions are produced by a ``Decoder``.

    The decoder is built from the supplied encoder, the hierarchy depth
    (``deep()``) and width (``width()``) of the ``HierarchicalDictionary``,
    and the target embedding size. ``forward``, ``loss`` and ``decode``
    all delegate to that decoder.
    """

    def __init__(self, dictionary: HierarchicalDictionary, label_name: str,
                 encoder: Encoder,
                 target_embedding_dim: int,
                 encoder_final_hidden_highway: bool = False,
                 ):
        """Build the classifier and its decoder.

        :param dictionary: label hierarchy; passed to the base class and used
            to size the decoder.
        :param label_name: name of the label read from each ``Sentence``.
        :param encoder: encoder stored on ``self.encoder`` and handed to the decoder.
        :param target_embedding_dim: embedding size passed to the decoder.
        :param encoder_final_hidden_highway: forwarded to the decoder as
            ``encoder_hidden_highway``.
        """
        super().__init__(dictionary, label_name)
        self.encoder = encoder
        # NOTE(review): depth comes from self.dictionary (set by the base class)
        # while width comes from the constructor argument -- presumably the
        # same object; confirm against HierarchicalClassifier.__init__.
        hierarchy_depth = self.dictionary.deep()
        hierarchy_width = dictionary.width()
        self.decoder = Decoder(
            self.encoder,
            hierarchy_depth,
            target_embedding_dim,
            hierarchy_width,
            encoder_hidden_highway=encoder_final_hidden_highway,
        )

    def forward(self, batch_sentence: List[Sentence]) -> Dict[str, Tensor]:
        """Run the decoder on a batch and return its output mapping of named tensors."""
        # Calls ``forward`` explicitly (not ``self.decoder(...)``), exactly as
        # before; if Decoder is a torch nn.Module this skips module hooks.
        decoder_output = self.decoder.forward(batch_sentence)
        return decoder_output

    def loss(self, batch_sentence: List[Sentence], logit: Dict[str, Tensor]) -> Tensor:
        """Compute the decoder loss against each sentence's gold label.

        Each sentence's label (looked up under ``self.label_name``) is mapped
        to its dictionary index before being passed to ``Decoder.loss``.
        """
        gold_labels = [sentence.get_label_name(self.label_name) for sentence in batch_sentence]
        gold_label_ids = [self.dictionary.get_idx_for_item(label) for label in gold_labels]
        return self.decoder.loss(logit, gold_label_ids)

    def decode(self, sentences: List[Sentence], logit: Dict[str, Tensor]) -> List[List[int]]:
        """Turn decoder output into predicted index sequences via ``Decoder.decode``.

        ``sentences`` is accepted for interface compatibility but is not used.
        """
        predictions = self.decoder.decode(logit)
        return predictions

