#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>

from typing import List
import torch
from torch import Tensor
from torch.nn import CrossEntropyLoss

from dlab.data import Dictionary, Sentence, pad_batch
from dlab.data.utils import get_mask
from dlab.embedder import Embedder
from dlab.layers.utils import get_final_encoder_states
from dlab.tasks import TextClassifier
from dlab.layers import ULSTM


class FeatureLSTMClassifier(TextClassifier):
    """Sentence classifier: embedder -> bidirectional ULSTM -> final state -> linear.

    Each batch of sentences is embedded with ``embedder``, passed through a
    bidirectional ``ULSTM``, reduced to one vector per sentence with
    ``get_final_encoder_states`` and projected onto ``len(class_dict)`` label
    logits. Dropout is applied after the embedding and after the RNN.
    """

    def __init__(self, class_dict: Dictionary,
                 embedder: Embedder,
                 rnn_hidden_size: int,
                 rnn_layers: int = 1,
                 dropout: float = 0.5,
                 **kwargs):
        """Build the embedding -> BiLSTM -> linear pipeline.

        :param class_dict: label dictionary; its length sets the number of output logits.
        :param embedder: word embedder; ``embedder.dim`` is used as the RNN input size.
        :param rnn_hidden_size: size fed to ``hidden2labels``; each RNN direction
            is given ``rnn_hidden_size // 2`` units.
        :param rnn_layers: number of stacked RNN layers passed to ``ULSTM``.
        :param dropout: dropout probability for the post-embedding and post-RNN
            dropout layers; also passed to ``ULSTM``.
        :param kwargs: forwarded unchanged to ``TextClassifier.__init__``.
        """
        super().__init__(class_dict, **kwargs)
        self.number_labels = len(class_dict)
        self.rnn_layers = rnn_layers
        # Initial state handed to the RNN; it is never updated in this class,
        # so every forward pass starts from this same value (None).
        self.hidden_word = None

        self.hidden_size = rnn_hidden_size
        self.dropout = dropout

        self.embedder = embedder
        self.after_embed_dropout = torch.nn.Dropout(self.dropout)
        # Half the units per direction so the concatenated bidirectional output
        # matches ``hidden_size``. NOTE(review): this only lines up with the
        # Linear input below when rnn_hidden_size is even, and assumes ULSTM
        # concatenates both directions — confirm.
        self.rnn = ULSTM(self.embedder.dim, self.hidden_size // 2, bidirectional=True, dropout=self.dropout,
                         num_layers=self.rnn_layers)
        self.after_rnn_dropout = torch.nn.Dropout(self.dropout)
        self.hidden2labels = torch.nn.Linear(self.hidden_size, self.number_labels)
        # Kept only for backward compatibility with code that reads ``self.torch``.
        # forward()/loss() no longer use it: choosing the tensor namespace from
        # torch.cuda.is_available() ignores where the model/data actually live.
        self.torch = torch.cuda if torch.cuda.is_available() else torch

    def forward(self, batch_sentence: List[Sentence]) -> Tensor:
        """Compute unnormalized label scores for a batch of sentences.

        :param batch_sentence: sentences to classify.
        :return: output of ``hidden2labels`` — presumably shaped
            (batch, number_labels); confirm against get_final_encoder_states.
        """
        batch_words_present, seq_length = self.embedder.embed(batch_sentence)
        # Build the mask on the same device as the embeddings so it can be
        # combined with the RNN output whether the model runs on CPU or GPU.
        input_mask = torch.as_tensor(get_mask(seq_length), dtype=torch.long,
                                     device=batch_words_present.device)
        batch_words_present = self.after_embed_dropout(batch_words_present)
        # The returned RNN state is not used; only the per-step outputs are.
        lstm_out, _ = self.rnn(batch_words_present, seq_length, self.hidden_word)
        # Third argument True: the RNN is bidirectional (see ULSTM construction).
        lstm_out = get_final_encoder_states(lstm_out, input_mask, True)
        lstm_out = self.after_rnn_dropout(lstm_out)
        return self.hidden2labels(lstm_out)

    def loss(self, batch_sentence: List[Sentence], logit: Tensor) -> Tensor:
        """Cross-entropy between ``logit`` and the gold labels of ``batch_sentence``.

        Adapted from ``BertForSequenceClassification.forward`` in the
        pytorch_pretrained_bert library.

        :param batch_sentence: the sentences ``logit`` was computed for, in the same order.
        :param logit: scores returned by :meth:`forward`.
        :return: mean cross-entropy loss (CrossEntropyLoss default reduction).
        """
        label_ids = [self.class_dict.get_idx_for_item(stn.get_label_name(self.target_label))
                     for stn in batch_sentence]
        # Place the targets on the logits' device instead of guessing from
        # torch.cuda.is_available(), which fails for CPU models on CUDA hosts.
        labels = torch.tensor(label_ids, dtype=torch.long, device=logit.device)
        loss_fct = CrossEntropyLoss()
        return loss_fct(logit, labels)