#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
from typing import List, Optional

import torch
from torch import nn

from dlab.data import Dictionary, Sentence
from dlab.embedder import Embedder
from dlab.eval.entity_decoder import EntityDecoder
from dlab.layers import ULSTM
from dlab.tasks import SequenceTagger


class FeaturesLSTMCRFSequenceTagger(SequenceTagger):
    """Sequence tagger built as Embedder -> bidirectional LSTM -> linear layer.

    The embedder is injected by the caller, so the input features are decided
    by whichever ``Embedder`` instance is passed in. ``use_crf`` and
    ``bio_decoders`` are forwarded unchanged to ``SequenceTagger``; this class
    itself only computes the tag logits.
    """

    def __init__(self,
                 tag_dictionary: Dictionary,
                 target_tag: str,
                 embedder: Embedder,  # local options
                 use_crf: bool = True,
                 bio_decoders: Optional[List[EntityDecoder]] = None,
                 hidden_size: int = 200,  # local options
                 rnn_layers: int = 1,  # local options
                 dropout: float = 0.5  # local options
                 ):
        """Build the embedding -> LSTM -> projection stack.

        Args:
            tag_dictionary: Forwarded to ``SequenceTagger.__init__``.
            target_tag: Forwarded to ``SequenceTagger.__init__``.
            embedder: Provides ``embed(sentences)`` and ``dim``; ``dim`` is
                used as the LSTM input size.
            use_crf: Forwarded to ``SequenceTagger.__init__``.
            bio_decoders: Forwarded to ``SequenceTagger.__init__``; ``None``
                is passed through as-is.
            hidden_size: Input width of the final linear layer. Each LSTM
                direction is given ``hidden_size // 2`` units.
            rnn_layers: ``num_layers`` passed to the LSTM.
            dropout: Probability used for both ``nn.Dropout`` layers and also
                passed to the LSTM as its ``dropout`` argument.
        """
        super().__init__(tag_dictionary, target_tag, bio_decoders=bio_decoders, use_crf=use_crf)

        self.rnn_layers: int = rnn_layers
        self.hidden_size: int = hidden_size

        self.embedder = embedder

        # initialize the network architecture
        # Passed to the LSTM on every forward call as its third argument
        # (presumably the initial hidden state - confirm against ULSTM).
        self.hidden_word = None
        self.after_embed_dropout = nn.Dropout(dropout)
        # NOTE(review): each direction gets hidden_size // 2 units while
        # hidden2tag below expects hidden_size inputs; an odd hidden_size
        # would presumably produce a size mismatch - confirm against ULSTM.
        self.rnn = ULSTM(self.embedder.dim, self.hidden_size // 2, bidirectional=True, dropout=dropout,
                         num_layers=self.rnn_layers)
        self.after_rnn_dropout = nn.Dropout(dropout)
        # self.tagset_size is not set in this class; presumably provided by
        # SequenceTagger.__init__ - verify against the base class.
        self.hidden2tag = torch.nn.Linear(self.hidden_size, self.tagset_size)

    def forward(self, sentences: List[Sentence]) -> torch.Tensor:
        """Compute tag logits for a batch of sentences.

        Args:
            sentences: Batch handed to ``self.embedder.embed``.

        Returns:
            The output of ``hidden2tag`` applied to the dropout-regularised
            LSTM output; its last dimension is ``self.tagset_size``.
        """
        batch_words_present, seq_length = self.embedder.embed(sentences)
        batch_words_present = self.after_embed_dropout(batch_words_present)
        # The second element returned by the LSTM is not needed here.
        lstm_out, _ = self.rnn(batch_words_present, seq_length, self.hidden_word)
        lstm_out = self.after_rnn_dropout(lstm_out)
        return self.hidden2tag(lstm_out)

