#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
from typing import Dict, List, Tuple

import numpy

import torch
import torch.nn.functional as F
from torch.nn import Embedding
from torch.nn.modules.linear import Linear
from torch.nn.modules.rnn import LSTMCell

from dlab.const import START_SYMBOL, END_SYMBOL
from dlab.data import Dictionary, Sentence
from dlab.embedder import Embedder
from dlab.data.utils import get_mask, pad_batch
from dlab.encoder import Encoder
from .utils import BeamSearch, get_final_encoder_states, get_text_field_mask, weighted_sum, \
    sequence_cross_entropy_with_logits


class Decoder(torch.nn.Module):
    """
        This ``Decoder`` is a :class:`torch.nn.Module` which wraps an encoder, encodes a batch of
        sentences with it, and then uses the encoded representations to decode another sequence.
        You can use this as the basis for a neural machine translation system, an abstractive
        summarization system, or any other common seq2seq problem.  The model here is simple, but
        should be a decent starting place for implementing recent models for these tasks.
        Parameters
        ----------
        encoder : ``Encoder``, required
            The encoder of the "encoder/decoder" model.  Calling it on a batch of sentences must
            return ``(encoder_outputs, seq_length)``, and it must expose a ``dim`` attribute, which
            is used both as the encoder output size and as the decoder hidden size.
        max_decoding_steps : ``int``
            Maximum length of decoded sequences.  It is also passed as ``max_length`` when the
            target sequences are padded in ``loss``.
        target_embedding_dim : ``int``, required
            Embedding dimensionality for the target side.
        target_softmax_dim : ``int``, required
            Number of real target classes.  Two extra classes are appended after them: the start
            symbol (index ``target_softmax_dim``) and the end symbol
            (index ``target_softmax_dim + 1``).
        encoder_hidden_highway : ``bool``, optional (default = False)
            If True, the final encoder hidden state is concatenated to the decoder input at every
            decoding step.
        attention : optional (default = None)
            Meant to be the function used to compute similarity between the decoder hidden state
            and the encoder outputs, giving a dynamic summary of the encoder outputs at each step
            of decoding.  It is currently ignored: the constructor always stores ``None`` as the
            attention module (see the todo notes in ``__init__``).
        beam_size : ``int``, optional (default = None)
            Width of the beam for beam search. If not specified, a beam of width 1 is used.
        scheduled_sampling_ratio : ``float``, optional (default = 0.)
            At each timestep during training, we sample a random number between 0 and 1, and if it is
            not less than this value, we use the ground truth labels for the whole batch. Else, we use
            the predictions from the previous time step for the whole batch. If this value is 0.0
            (default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not
            using target side ground truth labels.  See the following paper for more information:
            `Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al.,
            2015 <https://arxiv.org/abs/1506.03099>`_.
        use_bleu : ``bool``, optional (default = False)
            Meant to enable the BLEU metric during validation.  BLEU is not implemented yet (todo),
            so no BLEU metric object is created.
        """

    def __init__(self, encoder: Encoder,
                 max_decoding_steps: int,
                 target_embedding_dim: int,
                 target_softmax_dim: int,
                 encoder_hidden_highway: bool=False,
                 attention=None,
                 # todo attention: Attention = None,
                 # todo attention_function: SimilarityFunction = None,
                 beam_size: int = None,
                 scheduled_sampling_ratio: float = 0.,
                 use_bleu: bool = False) -> None:
        super().__init__()
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = target_softmax_dim + 0  # START_SYMBOL
        self._end_index = target_softmax_dim + 1  # END_SYMBOL

        if use_bleu:
            # todo
            pass
            # pad_index = self.vocab.get_token_index(self.vocab._padding_token,
            #                                        self._target_namespace)  # pylint: disable=protected-access
            # self._bleu = BLEU(exclude_indices={pad_index, self._end_index, self._start_index})
        else:
            self._bleu = None

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size)

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._encoder = encoder

        num_classes = target_softmax_dim + 2

        self._encoder_hidden_highway = encoder_hidden_highway

        # Attention mechanism applied to the encoder output for each step.
        # if attention:
        #     if attention_function:
        #         raise ValueError("You can only specify an attention module or an "
        #                          "attention function, but not both.")
        #     self._attention = attention
        # elif attention_function:
        #     # todo
        #     pass
        #     # self._attention = LegacyAttention(attention_function)
        # else:
        #     self._attention = None
        self._attention = None

        # Dense embedding of vocab words in the target space.
        target_embedding_dim = target_embedding_dim
        self._target_embedder = Embedding(num_classes, target_embedding_dim)

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        self._encoder_output_dim = self._encoder.dim
        self._decoder_output_dim = self._encoder_output_dim

        if self._attention:
            # If using attention, a weighted average over encoder outputs will be concatenated
            # to the previous target embedding to form the input to the decoder at each
            # time step.
            self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
        else:
            # Otherwise, the input to the decoder is just the previous target embedding.
            self._decoder_input_dim = target_embedding_dim

        if self._encoder_hidden_highway:
            self._decoder_input_dim += self._encoder_output_dim

        # We'll use an LSTM cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        # TODO (pradeep): Do not hardcode decoder cell type.
        self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)

        self.dropout_layer = torch.nn.Dropout()

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)

        self.gpu = torch.cuda.is_available()
        self.torch = torch.cuda if self.gpu else torch

    def take_step(self,
                  last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.
        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(group_size, *)``, where ``*`` can be any other number
            of dimensions.
        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.
        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        # shape: (group_size, num_classes)
        output_projections, state = self._prepare_output_projections(last_predictions, state)

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    def forward(self, batch_sentence: List[Sentence]) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Make foward pass with decoder logic for producing the entire target sequence.
        Parameters
        ----------
        batch_sentence :
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
           Output of `Textfield.as_array()` applied on target `TextField`. We assume that the
           target tokens are also represented as a `TextField`.
        Returns
        -------
        Dict[str, torch.Tensor]
        """
        state = self._encode(batch_sentence)
        state['source_mask'] = self.torch.LongTensor(get_mask(state['seq_length']))
        return state

    def add_start_end_tokens(self, target_tokens: List[List[int]]):
        return [[self._start_index] + target_token + [self._end_index] for target_token in target_tokens]

    def loss(self, state: Dict[str, torch.Tensor], target_tokens: List[List[int]]):
        state = self._init_decoder_state(state)
        target_tokens = self.add_start_end_tokens(target_tokens)
        # The `_forward_loop` decodes the input sequence and computes the loss during training
        # and validation.
        pad_target_tokens, target_mask = pad_batch(target_tokens, max_length=self._max_decoding_steps, return_mask=True)
        tensor_target_tokens = self.torch.LongTensor(pad_target_tokens)
        tensor_target_mask = self.torch.ByteTensor(target_mask)
        output_dict = self._forward_loop(state, tensor_target_tokens, tensor_target_mask)

        # if not self.training:
        #     state = self._init_decoder_state(state)
        #     predictions = self._forward_beam_search(state)
        #     output_dict.update(predictions)
        #     if target_tokens and self._bleu:
        #         # shape: (batch_size, beam_size, max_sequence_length)
        #         top_k_predictions = output_dict["predictions"]
        #         # shape: (batch_size, max_predicted_sequence_length)
        #         best_predictions = top_k_predictions[:, 0, :]
        #         self._bleu(best_predictions, target_tokens["tokens"])

        return output_dict['loss']

    def decode(self, state: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Finalize predictions.
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.
        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        """
        state = self._init_decoder_state(state)
        predictions = self._forward_beam_search(state)
        state.update(predictions)
        predicted_indices = state["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            all_predicted_tokens.append(indices)
        state["predicted_tokens"] = all_predicted_tokens
        return state["predicted_tokens"]

    def _encode(self, sentences: List[Sentence]) -> Dict[str, torch.Tensor]:
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs, length = self._encoder(sentences)
        return {
            "seq_length": length,
            "encoder_outputs": encoder_outputs,
        }

    def _init_decoder_state(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        batch_size = state["source_mask"].size(0)
        # shape: (batch_size, encoder_output_dim)
        final_encoder_output = get_final_encoder_states(
            state["encoder_outputs"],
            state["source_mask"],
            True)  # todo fix this to support un bidirectional
        # Initialize the decoder hidden state with the final output of the encoder.
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = self.dropout_layer(final_encoder_output)
        state["encoder_final_hidden"] = final_encoder_output
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = state["encoder_outputs"].new_zeros(batch_size, self._decoder_output_dim)
        return state

    def _forward_loop(self,
                      state: Dict[str, torch.Tensor],
                      target_tokens: torch.LongTensor,
                      target_mask: torch.ByteTensor) -> Dict[str, torch.Tensor]:
        """
        Make forward pass during training or do greedy search during prediction.
        :param target_tokens: torch.LongTensor [B*decoder_max_length]
        Notes
        -----
        We really only use the predictions from the method to test that beam search
        with a beam size of 1 gives the same results.
        """
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        batch_size = source_mask.size()[0]

        if target_tokens is not None:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens

            _, target_sequence_length = targets.size()

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1
        else:
            num_decoding_steps = self._max_decoding_steps

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        last_predictions = source_mask.new_full((batch_size,), fill_value=self._start_index)

        step_logits: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio:
                # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio
                # during training.
                # shape: (batch_size,)
                input_choices = last_predictions
            # elif timestep == 0:  # my inserting to fill self._start_index at 0 step.
            #     input_choices = last_predictions
            elif target_tokens is None:
                # shape: (batch_size,)
                input_choices = last_predictions
            else:
                # shape: (batch_size,)
                input_choices = targets[:, timestep]

            # shape: (batch_size, num_classes)
            output_projections, state = self._prepare_output_projections(input_choices, state)

            # list of tensors, shape: (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))

            # shape: (batch_size, num_classes)
            class_probabilities = F.softmax(output_projections, dim=-1)

            # shape (predicted_classes): (batch_size,)
            _, predicted_classes = torch.max(class_probabilities, 1)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes

            step_predictions.append(last_predictions.unsqueeze(1))

        # shape: (batch_size, num_decoding_steps)
        predictions = torch.cat(step_predictions, 1)

        output_dict = {"predictions": predictions}

        if target_tokens is not None:
            # shape: (batch_size, num_decoding_steps, num_classes)
            logits = torch.cat(step_logits, 1)

            # Compute loss.
            # target_mask = get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss

        return output_dict

    def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full((batch_size,), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[
        torch.Tensor, Dict[str, torch.Tensor]]:  # pylint: disable=line-too-long
        """
        Decode current state and last prediction to produce produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.
        Inputs are the same as for `take_step()`.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (group_size, decoder_output_dim)
        decoder_hidden = state["decoder_hidden"]

        # shape: (group_size, decoder_output_dim)
        decoder_context = state["decoder_context"]

        # shape: (group_size, target_embedding_dim)
        embedded_input = self._target_embedder(last_predictions)

        if self._attention:
            # todo
            pass
            # shape: (group_size, encoder_output_dim)
            # attended_input = self._prepare_attended_input(decoder_hidden, encoder_outputs, source_mask)

            # shape: (group_size, decoder_output_dim + target_embedding_dim)
            # decoder_input = torch.cat((attended_input, embedded_input), -1)
        else:
            # shape: (group_size, target_embedding_dim)
            decoder_input = embedded_input

        if self._encoder_hidden_highway:
            decoder_input = torch.cat((decoder_input, state['encoder_final_hidden']), -1)

        # shape (decoder_hidden): (batch_size, decoder_output_dim)
        # shape (decoder_context): (batch_size, decoder_output_dim)
        decoder_hidden, decoder_context = self._decoder_cell(
            decoder_input,
            (decoder_hidden, decoder_context))

        state["decoder_hidden"] = decoder_hidden
        state["decoder_context"] = decoder_context

        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(decoder_hidden)

        return output_projections, state

    def _prepare_attended_input(self,
                                decoder_hidden_state: torch.LongTensor = None,
                                encoder_outputs: torch.LongTensor = None,
                                encoder_outputs_mask: torch.LongTensor = None) -> torch.Tensor:
        """Apply attention over encoder outputs and decoder state."""
        # Ensure mask is also a FloatTensor. Or else the multiplication within
        # attention will complain.
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs_mask = encoder_outputs_mask.float()

        # shape: (batch_size, max_input_sequence_length)
        input_weights = self._attention(
            decoder_hidden_state, encoder_outputs, encoder_outputs_mask)

        # shape: (batch_size, encoder_output_dim)
        attended_input = weighted_sum(encoder_outputs, input_weights)

        return attended_input

    @staticmethod
    def _get_loss(logits: torch.LongTensor,
                  targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.Tensor:
        """
        Masked sequence cross entropy between the decoder logits and the shifted targets.
        ``logits`` has size (batch_size, num_decoding_steps, num_classes); ``targets`` and
        ``target_mask`` have size (batch_size, num_decoding_steps + 1). They are one step
        longer because the decoder never produces an output for the last target position.
        The logit produced at timestep i has to predict the target token at timestep i + 1,
        so the first column of both the targets and the mask is dropped before comparing.
        Example with a 3-word target padded to 7 tokens:
           full sequence          <S> w1  w2  w3  <E> <P> <P>
           full mask              1   1   1   1   1   0   0
           logits                 l1  l2  l3  l4  l5  l6
        What is actually compared:
           targets                w1  w2  w3  <E> <P> <P>
           mask                   1   1   1   1   0   0
           logits                 l1  l2  l3  l4  l5  l6
           (decoder input was)    <S> w1  w2  w3  <E> <P>
        """
        def drop_first_step(tensor):
            # shape: (batch_size, num_decoding_steps)
            return tensor[:, 1:].contiguous()

        shifted_targets = drop_first_step(targets)
        shifted_mask = drop_first_step(target_mask)

        return sequence_cross_entropy_with_logits(logits, shifted_targets, shifted_mask)

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if self._bleu and not self.training:
            all_metrics.update(self._bleu.get_metric(reset=reset))
        return all_metrics
