#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
import codecs
from typing import List, Union, Tuple, Dict
import numpy as np

from dlab.log import log
from dlab.data.structures import Sentence, Token
from dlab.data.language import LANG


class CorpusReader(object):
    """
    Base class for corpus readers.

    Here a *corpus* is a single file; a reader describes how that file is read and
    turned into ``Sentence`` objects (or other records). Compare ``DataSet``, which is a
    collection of corpora such as the train / dev / test splits.
    Subclasses implement ``iter``.
    """
    def __init__(self, file_path, lang=LANG.ZH, strip=None, split=None):
        """

        :param file_path: path of the corpus file (decoded as UTF-8)
        :param lang: language of the corpus
        :param strip: set of characters removed from both ends of every line; None strips whitespace
        :param split: field separator passed to ``str.split``; None splits on runs of whitespace
        """
        self.file_path = file_path
        self.strip = strip
        self.split = split
        self.lang = lang

    def _iter_raw_lines(self, forever=False):
        """Yield ``(line_num, line)`` for every line of the file (1-based numbers).

        Each line has its terminator removed and is then stripped with ``self.strip``.
        After the last line one extra ``(last_line_num + 1, '')`` is yielded so that callers
        grouping lines on blank separators always flush their final group.
        With ``forever=True`` the file is re-read endlessly.
        """
        log.info('read corpus from %s' % self.file_path)
        while True:
            line_num = 0
            # Built-in open() only treats \n, \r and \r\n as line ends. codecs.open().readline()
            # also splits on \x0b, \x0c, \x1c-\x1e, \x85, \u2028 and \u2029, which can appear
            # inside corpus text and would silently break one line into several.
            with open(self.file_path, encoding='utf-8') as fr:
                for line_num, line in enumerate(fr, start=1):
                    # Drop the terminator first so a custom `strip` set still yields '' for blank lines.
                    yield (line_num, line.rstrip('\n').strip(self.strip))
            # Sentinel empty line at EOF; simplifies the block logic of the callers.
            yield (line_num + 1, '')
            if not forever:
                log.info('Reading finished. total lines:%d' % line_num)
                break

    def _iter_lines(self, forever=False):
        """Yield ``(line_num, fields)``, splitting every line with ``self.split``.

        A blank line (including the EOF sentinel) yields an empty list.
        """
        for line_num, raw_line in self._iter_raw_lines(forever):
            if not raw_line:  # end of a sentence
                yield (line_num, [])
            else:
                yield (line_num, raw_line.split(self.split))

    def _read_lines(self):
        """Return every ``(line_num, fields)`` pair of a single pass as a list."""
        return list(self._iter_lines())

    def _iter_blocks(self, forever=False):
        """Group consecutive non-empty lines into blocks.

        Yields ``(line_num, rows)``. ``rows`` is a list of field lists. ``line_num`` is the
        number of the blank line (or EOF sentinel) that closed the block.
        """
        rows = []
        for line_num, line_features in self._iter_lines(forever):
            if line_features:
                rows.append(line_features)
            elif rows:
                yield line_num, rows
                rows = []

    def _read_blocks(self):
        """Return every ``(line_num, rows)`` block of a single pass as a list."""
        return list(self._iter_blocks())

    def iter(self, forever=False) -> Sentence:
        """Yield the records of the corpus; must be implemented by subclasses."""
        raise NotImplementedError()

    def read(self) -> List[Sentence]:
        """Read the whole corpus (one pass) into a list."""
        return list(self.iter())

    def __iter__(self):
        return self.iter()


# 每行一条的文本分类数据读取器
class LineBasedTCCorpusReader(CorpusReader):
    """Line-based text-classification reader.

    One sample per line. Fields are separated by ``split``. The first field is the class
    label and the remaining fields are the words of the sentence. With ``char_base=True``
    the text after the label is treated as unsegmented, and every character becomes a token.
    """
    def __init__(self, file_path, char_base=False, label_name='category', **kwargs):
        """

        :param file_path: path of the corpus file
        :param char_base: True if the text is not word-segmented (one token per character). default=False
        :param label_name: label key under which the class name is stored on each sentence
        :param kwargs: forwarded to CorpusReader (lang, strip, split)
        """
        super().__init__(file_path, **kwargs)
        self.char_base = char_base
        self.label_name = label_name

    def iter(self, forever=False) -> Sentence:
        """Yield one labelled ``Sentence`` per non-empty line; re-read endlessly if ``forever``."""
        for line_num, line_fields in self._iter_lines(forever):
            if not line_fields:
                continue
            stn = Sentence(lang=self.lang)
            stn.set_label(self.label_name, line_fields[0])
            if self.char_base:
                # Use all text after the label. Taking only field 1 would truncate text that
                # contains the separator and raise IndexError on a label-only line.
                token_strings = ''.join(line_fields[1:])
            else:
                token_strings = line_fields[1:]
            for token_string in token_strings:
                stn.add_token(token_string)
            yield stn


# 中文分词数据读取器
class WordSegmentCorpusReader(CorpusReader):
    """
    Reader for Chinese word-segmentation data: one sentence per line, one word per field.
    Every character becomes a token and receives a BMES tag
    (S = single-char word, B = begin, M = middle, E = end).
    """
    def __init__(self, file_path, tag_name: str = 'ws', proto: str = 'BMES', **kwargs):
        """

        :param file_path: path of the corpus file
        :param tag_name: name of the segmentation tag added to each token
        :param proto: tagging scheme; only 'BMES' is implemented (the value is stored but not checked)
        :param kwargs: forwarded to CorpusReader (strip, split); the language is always LANG.ZH
        """
        super().__init__(file_path, LANG.ZH, **kwargs)
        self.tag_name = tag_name
        self.proto = proto

    def iter(self, forever=False) -> Sentence:
        """Yield one character-level ``Sentence`` per non-empty line; re-read endlessly if ``forever``."""
        for line_num, line_words in self._iter_lines(forever):
            if not line_words:
                continue
            stn = Sentence(lang=self.lang)
            for word in line_words:
                last = len(word) - 1
                for i, c in enumerate(word):
                    if last == 0:
                        tag = 'S'
                    elif i == 0:
                        tag = 'B'
                    elif i == last:
                        tag = 'E'
                    else:
                        tag = 'M'
                    # both position args are the current sentence length (presumably start/end index — confirm)
                    tok = Token(c, len(stn), len(stn))
                    tok.add_tag(self.tag_name, tag)
                    stn.add_token(tok)
            yield stn


class ConllReader(CorpusReader):
    """
    Base reader for CoNLL-style files: blank-line separated blocks of n rows by m columns,
    one token per row and one field per column. Provides shared helpers for subclasses.
    """
    TOKEN_TAG = Token.TOKEN_TAG
    SKIP_TAG = '__SKIP'

    def __init__(self, file_path, tag_names: List[str] = None, lang: LANG = LANG.ZH, **kwargs):
        """

        :param file_path: path of the CoNLL file
        :param tag_names: column names. Without ``TOKEN_TAG`` in the list, the token is assumed to be
                          in column 0 and one name is needed per remaining column
                          (len(tag_names) == columns - 1). Listing ``TOKEN_TAG`` marks a token
                          column that is not the first one. Columns named ``SKIP_TAG`` are ignored.
        :param lang: language of the corpus
        """
        super().__init__(file_path, lang, **kwargs)
        names = [] if tag_names is None else tag_names
        if self.TOKEN_TAG not in names:
            names = [self.TOKEN_TAG] + names
        self.tag_names = names
        self.token_id = names.index(self.TOKEN_TAG)

    def add_tokens_from_block(self, sentence, block, line_num=0):
        """Append one token per row of ``block`` to ``sentence``, tagging it with every named column.

        :raises ValueError: if a row's column count differs from ``len(self.tag_names)``
        """
        expected_width = len(self.tag_names)
        ignored_names = [self.TOKEN_TAG, self.SKIP_TAG]
        for row in block:
            if len(row) != expected_width:
                raise ValueError('line %s and tag_name %s not match at file %s:%d ' %
                                 (row, self.tag_names, self.file_path, line_num))
            sentence.add_token(row[self.token_id])
            for name, value in zip(self.tag_names, row):
                if name not in ignored_names:
                    sentence[-1].add_tag(name, value)


# 这里指的 Tag 是指序列标注任务，tag是词上的标签

class TagCorpusReader(ConllReader):
    """
    CoNLL reader for sequence labeling. One token per row; each row may hold several fields
    (targets and features), which are stored as token tags. Sentences are separated by blank lines.
    """
    def iter(self, forever=False):
        """Yield one ``Sentence`` per block; logs the count once a single pass finishes."""
        produced = 0
        for line_num, block in self._iter_blocks(forever):
            sentence = Sentence(lang=self.lang)
            self.add_tokens_from_block(sentence, block, line_num)
            yield sentence
            produced += 1
        log.info('total sentence:%d' % produced)


# 这里指的 Label 是指文本分类任务，label是句子上的标签

class LabelCorpusReader(ConllReader):
    """
    CoNLL reader for text classification. Each block starts with a row whose first field is the
    class name prefixed by "__". Every following row is one token, and its extra fields become
    token tags. Sentences are separated by blank lines.
    """
    def __init__(self, file_path, label_name='category', **kwargs):
        """
        :param file_path: path of the CoNLL file
        :param label_name: label key under which the class name is stored on each sentence
        :param kwargs: forwarded to ConllReader (tag_names, lang, ...)
        """
        super().__init__(file_path, **kwargs)
        self.label_name = label_name

    def iter(self, forever=False):
        """Yield one labelled ``Sentence`` per block.

        :raises ValueError: if a block's first field is not "__" followed by at least one character
        """
        produced = 0
        for line_num, block in self._iter_blocks(forever):
            sentence = Sentence(lang=self.lang)
            head = block[0][0]
            if len(head) < 3 or not head.startswith('__'):
                raise ValueError('illegal class_name "%s" at file %s:%d ' % (head, self.file_path, line_num))
            sentence.set_label(self.label_name, head[2:])
            self.add_tokens_from_block(sentence, block[1:], line_num)
            yield sentence
            produced += 1
        log.info('Reading finished. total sentence:%d' % produced)


# 文档读取器
class DocumentCorpusReader(CorpusReader):
    """Reader for documents separated by one or more blank lines.

    Each document is yielded as the list of its (stripped) non-empty lines.
    """

    def iter(self, forever=False) -> List[str]:
        """Yield each document as a list of lines; re-read endlessly if ``forever``."""
        document = []
        for _, line in self._iter_raw_lines(forever):
            if line:
                document.append(line)
            elif document:
                # A blank line (or the EOF sentinel) closes the current document. Extra blank
                # lines with no pending document are skipped instead of becoming '' entries.
                yield document
                document = []


# csv 文件读取器
class CSVReader(CorpusReader):
    """Read a UTF-8 CSV file row by row.

    With ``header=True`` the first record holds the column names, and every later row is
    yielded as a dict mapping column name to cell. Otherwise each row is yielded as a list of strings.
    """
    def __init__(self, file_path, header=True, **kwargs):
        """
        :param file_path: path of the CSV file
        :param header: whether the first record is a header row
        :param kwargs: forwarded to CorpusReader (``strip``/``split`` are not used by this reader)
        """
        super().__init__(file_path, **kwargs)
        self.header = header
        self.header_names = None

    def iter(self, forever=False) -> Union[List[str], Dict[str, str]]:
        """Yield the data rows of the file; re-read endlessly if ``forever``."""
        import csv
        while True:
            line_num = 0
            # The csv module requires newline='' so quoted fields containing line breaks are
            # parsed correctly; codecs.open would also split on unicode line separators.
            with open(self.file_path, encoding='utf-8', newline='') as fr:
                rows = csv.reader(fr)
                if self.header:
                    # The first *record* is the header, even if it spans several physical lines.
                    self.header_names = next(rows, None)
                for item in rows:
                    yield dict(zip(self.header_names, item)) if self.header else item
                    line_num += 1
            if not forever:
                break
        log.info('Reading finished. total csv lines:%d' % line_num)


class BatchedCorpus(object):
    def __init__(self, reader: Union[CorpusReader, List], batch_size=32, shuffle=False, preload=False):
        """
        Wrap a CorpusReader or a list so that iterating yields lists of ``batch_size`` items
        (the last batch may be shorter). Items may optionally be shuffled on every pass.

        :param reader: a CorpusReader, or an in-memory list of items
        :param batch_size: number of items per batch
        :param shuffle: if true, preload will be override as True.
        :param preload: if true, all data will be read into memory.
        :raises ValueError: if ``reader`` is neither a CorpusReader nor a list
        """
        self.reader: CorpusReader = reader
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Shuffling needs random access, so it forces preloading.
        self.preload = preload or shuffle
        if isinstance(reader, CorpusReader):
            self.data = None if not self.preload else reader.read()
        elif isinstance(reader, list):
            # A list is already in memory.
            self.preload = True
            self.data = reader
        else:
            raise ValueError('reader must be List[Sentence] or reader')

    def _iter_item(self):
        """Yield single items: from memory (optionally shuffled) or streamed from the reader."""
        if not self.preload:
            yield from self.reader
            return
        count = len(self.data)
        indices = np.random.permutation(count) if self.shuffle else range(count)
        for idx in indices:
            yield self.data[idx]

    def iter_batch(self):
        """Yield consecutive lists of ``batch_size`` items; a trailing partial batch is yielded too."""
        pending = []
        for item in self._iter_item():
            pending.append(item)
            if len(pending) == self.batch_size:
                yield pending
                pending = []
        if pending:
            yield pending

    def __iter__(self):
        return self.iter_batch()

    def __len__(self):
        """Number of items (not batches) when data is in memory, otherwise -1.

        NOTE(review): Python's len() raises ValueError when __len__ returns a negative value,
        so len() on a non-preloaded corpus fails; callers must check ``.data`` first.
        """
        if self.data is None:
            return -1
        return len(self.data)


class DataSet(object):
    """Bundle of batched train / dev / test corpora."""

    def __init__(self, train: List[Sentence]=None, dev: List[Sentence]=None, test: List[Sentence]=None,
                 shuffle=True, batch_size=32, dev_split=0.1, shuffle_train_before_split_dev=False):
        """

        :param train: training items; a list, or a CorpusReader that is read fully into memory
        :param dev: development items (list or CorpusReader)
        :param test: test items (list or CorpusReader)
        :param shuffle: shuffle train set.
        :param batch_size: number of items per batch
        :param dev_split: if dev is None and train is given, the last ``dev_split`` fraction of
                          the train set is used as dev. None disables the split.
        :param shuffle_train_before_split_dev: when splitting off the dev set automatically, first shuffle
                                               (a copy of) the train set with a fixed seed
        """
        if isinstance(train, CorpusReader):
            train = train.read()
        # Splitting needs a train set; without the `train is not None` check, DataSet() with no
        # train data fails on len(None).
        if dev_split is not None and dev is None and train is not None:
            if shuffle_train_before_split_dev:
                from random import Random
                train = list(train)  # shuffle a copy so the caller's list is not reordered
                Random(1).shuffle(train)  # fixed seed -> reproducible split
            train_dev_len = len(train)
            train_split = int(train_dev_len * (1.0 - dev_split))
            train, dev = train[:train_split], train[train_split:]
            log.info('train set [len=%d] split to train[:%d] (len=%d) and dev[%d:] (len=%d)' %
                     (train_dev_len, train_split, len(train), train_split, len(dev)))
        self.batch_size = batch_size
        self.train_set = BatchedCorpus(train, batch_size, shuffle, preload=True) if train else None
        self.dev_set = BatchedCorpus(dev, batch_size) if dev else None
        self.test_set = BatchedCorpus(test, batch_size) if test else None

    def iter_train(self):
        """Iterate train batches (train_set must not be None)."""
        return self.train_set.__iter__()

    def iter_dev(self):
        """Iterate dev batches (dev_set must not be None)."""
        return self.dev_set.__iter__()

    def iter_test(self):
        """Iterate test batches (test_set must not be None)."""
        return self.test_set.__iter__()


class ReaderDataSet(DataSet):
    """DataSet whose splits are read from CoNLL-style files with ``__Reader_Class``."""
    __Reader_Class = TagCorpusReader

    def __init__(self, train: str = None, dev: str = None, test: str = None, tag_names: List[str] = None, batch_size=32,
                 shuffle=True, lang=LANG.ZH):
        """
        :param train: path of the train file, or None
        :param dev: path of the dev file, or None
        :param test: path of the test file, or None
        :param tag_names: column names passed to every reader
        :param batch_size: number of items per batch
        :param shuffle: shuffle the train set (which also preloads it into memory)
        :param lang: language of the corpora
        """
        # dev_split=None: there is no in-memory train list here to carve a dev set from.
        # The default dev_split would try to split a missing train set.
        super().__init__(shuffle=shuffle, batch_size=batch_size, dev_split=None)
        self.tag_names = [] if tag_names is None else tag_names

        self.train_file = train
        self.train_set = None
        if self.train_file is not None:
            # preload is tied to shuffle: shuffling needs the whole train set in memory
            self.train_set = BatchedCorpus(self._make_reader(self.train_file, lang), batch_size, shuffle, shuffle)

        self.dev_file = dev
        self.dev_set = None
        if self.dev_file is not None:
            self.dev_set = BatchedCorpus(self._make_reader(self.dev_file, lang), batch_size)

        self.test_file = test
        self.test_set = None
        if self.test_file is not None:
            self.test_set = BatchedCorpus(self._make_reader(self.test_file, lang), batch_size)

    def _make_reader(self, file_path, lang):
        """Build the reader for one split. Keyword args keep this valid for readers such as
        LabelCorpusReader, whose second positional parameter is not ``tag_names``."""
        return self.__Reader_Class(file_path, tag_names=self.tag_names, lang=lang)


class SeqTagDataSet(DataSet):
    """Data set for sequence labeling.

    NOTE(review): ``__Reader_Class`` is name-mangled to ``_SeqTagDataSet__Reader_Class`` and nothing
    in ``DataSet`` reads it, so this class currently behaves exactly like ``DataSet``. Subclassing
    ``ReaderDataSet`` may have been intended — confirm before relying on it.
    """
    __Reader_Class = TagCorpusReader


class LabelDataSet(DataSet):
    """Data set for text classification.

    NOTE(review): ``__Reader_Class`` is name-mangled to ``_LabelDataSet__Reader_Class`` and nothing
    in ``DataSet`` reads it, so this class currently behaves exactly like ``DataSet``. Subclassing
    ``ReaderDataSet`` may have been intended — confirm before relying on it.
    """
    __Reader_Class = LabelCorpusReader

