#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
import torch

from torch import nn, Tensor
import numpy as np


class URNN(nn.Module):
    """
    Wrapper around the ``torch.nn`` recurrent layers for variable-length batches.

    The batch is sorted by length (as required for packing), packed, run
    through the recurrent layer, unpacked and restored to the caller's
    original order.

    Subclasses select the recurrent layer by overriding ``_RNN_TYPE``.
    """
    # Single leading underscore on purpose: a double-underscore attribute is name-mangled
    # per class (``_URNN__RNN_TYPE`` vs ``_ULSTM__RNN_TYPE``), so a subclass override would
    # never be visible from ``URNN.__init__`` and every subclass would build ``nn.RNN``.
    _RNN_TYPE = nn.RNN

    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=True, dropout=0,
                 bidirectional=True, only_use_last_hidden_state=False):
        """
        Recurrent layer which can hold variable length sequences, use like TensorFlow's RNN(input, length...).

        :param input_size: The number of expected features in the input x
        :param hidden_size: The number of features in the hidden state h
        :param num_layers: Number of recurrent layers.
        :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True
        :param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature). Default: True
        :param dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer
        :param bidirectional: If True, becomes a bidirectional RNN. Default: True
        :param only_use_last_hidden_state: Default for ``forward``'s ``return_last_hidden``; if True, ``forward``
            returns only the last hidden state instead of ``(out, h_n)``. Default: False
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.batch_dim, self.seq_dim = (0, 1) if self.batch_first else (1, 0)
        self.dropout = dropout
        self.bidirectional = bidirectional
        self.only_use_last_hidden_state = only_use_last_hidden_state
        self.rnn = self._RNN_TYPE(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=bias,
            batch_first=batch_first,
            dropout=dropout,
            bidirectional=bidirectional
        )
        # Kept only for backward compatibility with code that reads these attributes;
        # ``forward`` places its index tensors on the input's own device instead.
        self.gpu = torch.cuda.is_available()
        self.torch = torch.cuda if self.gpu else torch

    def forward(self, x: Tensor, x_len, hidden_word=None, return_last_hidden=None):
        """
        sequence -> sort -> pad and pack -> process using RNN -> unpack -> unsort

        :param x: padded input, (batch, seq_len, feature) if batch_first else (seq_len, batch, feature)
        :param x_len: length of every sequence in the batch, in batch order (list, numpy array or tensor)
        :param hidden_word: initial state in the caller's batch order, shaped
            (num_layers * num_directions, batch, hidden_size); a ``(h_0, c_0)`` tuple for LSTM. Default: None
        :param return_last_hidden: return only the last hidden state. None falls back to the
            ``only_use_last_hidden_state`` constructor setting
        :return: ``h_n`` if ``return_last_hidden`` else ``(out, h_n)``. ``out`` is the padded output of the
            last layer (time axis as long as the longest sequence in the batch); ``h_n`` is the last
            layer's final hidden state, shaped (batch, num_directions * hidden_size). For LSTM only the
            hidden state is returned, the cell state is dropped. Everything is in the caller's batch order.
        """
        if return_last_hidden is None:
            return_last_hidden = self.only_use_last_hidden_state

        # ---- sort ----
        if isinstance(x_len, Tensor):
            x_len = x_len.detach().cpu().numpy()
        # Signed 64-bit so that negating the lengths below cannot wrap around on unsigned dtypes.
        x_len = np.asarray(x_len).astype(np.int64)
        x_sort_idx = np.argsort(-x_len)  # longest sequence first
        # Index tensors follow the input's device rather than "CUDA if available".
        sort_idx = torch.as_tensor(x_sort_idx, dtype=torch.long, device=x.device)
        unsort_idx = torch.as_tensor(np.argsort(x_sort_idx), dtype=torch.long, device=x.device)
        sorted_len = x_len[x_sort_idx]
        x = torch.index_select(x, self.batch_dim, sort_idx)
        if hidden_word is not None:
            # The initial state is given in the caller's order; reorder it like the batch.
            # Its batch axis is always 1, independent of ``batch_first``.
            if isinstance(hidden_word, tuple):
                hidden_word = tuple(torch.index_select(h, 1, sort_idx.to(h.device)) for h in hidden_word)
            else:
                hidden_word = torch.index_select(hidden_word, 1, sort_idx.to(hidden_word.device))

        # ---- pack ----
        x_emb_p = torch.nn.utils.rnn.pack_padded_sequence(x, sorted_len.tolist(), batch_first=self.batch_first)

        # ---- process using RNN ----
        out_pack, h_n = self.rnn(x_emb_p, hidden_word)
        if isinstance(h_n, tuple):
            # LSTM returns (h_n, c_n); only the hidden state is part of this wrapper's contract.
            h_n = h_n[0]

        num_directions = 2 if self.bidirectional else 1

        h_n = h_n.view(self.num_layers, num_directions, h_n.size(1), h_n.size(2))  # (layer, direction, batch, hidden)
        h_n = h_n[-1]  # (direction, batch, hidden): keep the last layer
        if self.bidirectional:  # concat direction dim
            h_n = torch.cat((h_n[0], h_n[1]), -1)
        else:
            h_n = h_n[0]

        # ---- unsort: h ----
        # h_n is (batch, features) here whatever ``batch_first`` is, so the batch axis is 0.
        h_n = torch.index_select(h_n, 0, unsort_idx)

        if return_last_hidden:
            return h_n

        # ---- unpack: out ----
        out, _ = torch.nn.utils.rnn.pad_packed_sequence(out_pack, batch_first=self.batch_first)
        # ---- unsort: out ----
        out = torch.index_select(out, self.batch_dim, unsort_idx)
        return out, h_n


class ULSTM(URNN):
    """Variable-length wrapper around ``nn.LSTM``; ``forward`` returns the hidden state only."""
    _RNN_TYPE = nn.LSTM


class UGRU(URNN):
    """Variable-length wrapper around ``nn.GRU``."""
    _RNN_TYPE = nn.GRU

