#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
import torch
from typing import List, Dict

from torch import Tensor
from torch.nn import CrossEntropyLoss

from dlab.const import PAD
from dlab.data.utils import get_mask
from dlab.layers.utils import get_final_encoder_states
from dlab.tasks.hierarchical_classifier import HierarchicalClassifier
from dlab.encoder import Encoder
from dlab.data import HierarchicalDictionary, Dictionary, Sentence


class FlatHierarchicalClassifier(HierarchicalClassifier):
    """Classifier that treats a hierarchical label set as one flat label set.

    Every full label name of the hierarchical ``dictionary`` becomes one class
    of ``flat_dictionary``, except names ending with ``PAD``. A linear layer on
    top of the final encoder state scores those classes. ``decode`` maps the
    predicted flat class back to an index of the hierarchical dictionary.
    """

    def __init__(self, dictionary: HierarchicalDictionary, label_name: str,
                 encoder: Encoder,
                 dropout: float = 0.5,
                 ):
        """Build the flat label vocabulary and the output projection.

        Args:
            dictionary: hierarchical label dictionary; its ``full_name_list``
                supplies the flat classes.
            label_name: name of the sentence label this classifier predicts.
            encoder: sentence encoder; ``encoder.dim`` is used as the input
                size of the output projection.
            dropout: dropout probability applied to the final encoder state.
        """
        super().__init__(dictionary, label_name)
        # Flatten the hierarchy: one class per full label name, skipping
        # padding entries.
        self.flat_dictionary = Dictionary(False, False)
        for class_name in dictionary.full_name_list:
            if not class_name.endswith(PAD):
                self.flat_dictionary.add_item(class_name)
        self.encoder = encoder
        self.dropout = dropout
        self.after_rnn_dropout = torch.nn.Dropout(dropout)
        self.tagset_size = len(self.flat_dictionary)
        self.hidden2tag = torch.nn.Linear(self.encoder.dim, self.tagset_size)
        # Kept only for backward compatibility with code that reads
        # ``self.torch``. This class now creates new tensors on the model's or
        # the logits' device instead. That way it still works when the model
        # runs on CPU while CUDA is available, or on a non-default GPU.
        self.torch = torch.cuda if torch.cuda.is_available() else torch

    def forward(self, batch_sentence: List[Sentence]) -> Tensor:
        """Return unnormalised class scores for a batch of sentences.

        Args:
            batch_sentence: sentences to classify.

        Returns:
            Output of ``hidden2tag``, one row of ``tagset_size`` scores per
            sentence. This presumes ``get_final_encoder_states`` returns one
            vector per sentence; TODO confirm.
        """
        encoder_output, length = self.encoder(batch_sentence)
        # Create the mask on the same device as the model parameters, not on
        # whatever device CUDA availability alone would pick.
        device = self.hidden2tag.weight.device
        mask = torch.as_tensor(get_mask(length), dtype=torch.long, device=device)
        # NOTE(review): bidirectional=True is hard-coded, which assumes a
        # bidirectional encoder. Confirm before using other encoders.
        hidden_state = get_final_encoder_states(encoder_output, mask, bidirectional=True)
        hidden_state = self.after_rnn_dropout(hidden_state)
        return self.hidden2tag(hidden_state)

    def loss(self, batch_sentence: List[Sentence], logit: Tensor) -> Tensor:
        """Return the cross-entropy between ``logit`` and the gold flat labels.

        Args:
            batch_sentence: sentences whose ``label_name`` labels are the
                targets.
            logit: scores produced by ``forward`` for the same sentences.

        Returns:
            ``CrossEntropyLoss`` with its default (mean) reduction.
        """
        target_ids = [
            self.flat_dictionary.get_idx_for_item(s.get_label_name(self.label_name))
            for s in batch_sentence
        ]
        # Targets must live on the same device as the logits.
        targets = torch.tensor(target_ids, dtype=torch.long, device=logit.device)
        return CrossEntropyLoss()(logit, targets)

    def decode(self, sentences: List[Sentence], logit: Tensor) -> List[List[int]]:
        """Map each row's arg-max flat class to a hierarchical dictionary index.

        Args:
            sentences: not used here; presumably kept for interface parity
                with other classifiers.
            logit: scores produced by ``forward``; arg-max is taken over the
                last dimension.

        Returns:
            One entry per row of ``logit``. Each entry is whatever
            ``self.dictionary.get_idx_for_item`` returns for the predicted
            label name (annotated as ``List[int]``; verify against
            HierarchicalDictionary).
        """
        # .tolist() yields plain Python ints rather than numpy scalars, so the
        # dictionary lookups receive ordinary int indices.
        predicted_ids = torch.argmax(logit, -1).detach().cpu().tolist()
        return [
            self.dictionary.get_idx_for_item(self.flat_dictionary.get_item_for_index(idx))
            for idx in predicted_ids
        ]

