#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>

from collections import defaultdict
from typing import List, Any, Tuple
import random

from dlab.data import Sentence


class ListsManager(object):
    def __init__(self):
        self.sets = defaultdict(list)

    def add_data(self, data: List[Any], set_name: str):
        self.sets[set_name].extend(data)

    def _filter_by_ids(self, data: List[Any], ids: List[int]) -> Tuple[List, List]:
        """

        :param data:
        :param ids:
        :return: data in remain, data in ids
        """
        filter = [0] * len(data)
        for id in ids:
            filter[id] = 1
        sets = ([], [])
        for id, item in enumerate(data):
            sets[filter[id]].append(item)
        return sets

    def mv(self, ori: str, tar: str, ids: List[int]=None):
        if ids is None:
            temp_list = self.sets.pop(ori)
        else:
            self.sets[ori], temp_list = self._filter_by_ids(self.sets[ori], ids)
        self.sets[tar].extend(temp_list)

    def cp(self, ori: str, tar: str, ids: List[int]=None):
        if ids is None:
            temp_list = self.sets[ori][:]
        else:
            _, temp_list = self._filter_by_ids(self.sets[ori], ids)
        self.sets[tar].extend(temp_list)

    def print_item(self, name: str, idx: int):
        print('# %s:%8d\t%s' % (name, idx, self.sets[name][idx]))

    def show(self, name: str, start=0, end=None, step=1):
        s = slice(start, end, step)
        start, end , step = s.indices(len(self.sets[name]))
        i = start
        while i < end:
            self.print_item(name, i)
            i += step

    def head(self, name: str, n: int=5):
        return self.show(name, start=0, end=n)

    def tail(self, name: str, n: int=5):
        return self.show(name, start=-n)

    def sample(self, name: str, n: int=5):
        ids = random.sample(range(len(self.sets[name])), n)
        for idx in ids:
            self.print_item(name, idx)

    def len(self, name: str=None):
        if name is None:
            return {name: len(s) for name, s in self.sets.items()}
        return len(self.sets[name])


class FilterManagerBase(ListsManager):
    """Base for list managers that split data into an unfiltered and a filtered set.

    It only declares the two conventional set names; every list operation is
    inherited from ``ListsManager``.
    """

    # Conventional name of the set holding data that has not been filtered out.
    UNFILTERED: str = '__unfiltered_set'
    # Conventional name of the set receiving filtered data.
    FILTERED: str = '__filtered_set'


class FilterAction(ListsManager):
    """A pending move of items from ``manager.sets[ori]`` to ``manager.sets[tar]``.

    The action's own sets hold *indices* into the manager's sets, not items:

    * ``FROM``   -- indices into ``manager.sets[ori]`` that stay in place
      (when built by ``FilterManager.filter``, the non-matching ones);
    * ``CHANGE`` -- indices into ``manager.sets[ori]`` that ``commit`` moves;
    * ``TO``     -- indices into ``manager.sets[tar]`` at creation time.

    Printing resolves these indices against the manager at call time. The
    indices are only meaningful while the manager's sets are unchanged, which is
    why ``commit`` may be called only once.
    """
    FROM = '__from_set'
    CHANGE = '__change_set'
    TO = '__to_set'

    def __init__(self, manager: FilterManagerBase, ori: str, tar: str):
        super().__init__()
        self.manager = manager
        # Only these three index sets may be displayed through print_item/show.
        self.legal_set = [self.FROM, self.CHANGE, self.TO]
        self.ori = ori
        self.tar = tar
        # Which manager set each of the action's index sets points into.
        self.sets_name_2_origin_map = {
            self.FROM: ori,
            self.CHANGE: ori,
            self.TO: tar,
        }
        # Set by commit(); a second commit would apply now-stale indices.
        self._committed = False

    def print_item(self, name: str, idx: int):
        """Print the manager item referenced by position ``idx`` of index set ``name``.

        :raises ValueError: if ``name`` is not one of FROM / CHANGE / TO.
        """
        if name not in self.legal_set:
            raise ValueError('set %s can not be reachable in %s.' % (name, self.__class__.__name__))
        m_set_name = self.sets_name_2_origin_map[name]
        m_set_idx = self.sets[name][idx]
        print('# %s:%d\t%s' % (m_set_name, m_set_idx, self.manager.sets[m_set_name][m_set_idx]))

    def commit(self):
        """Move the ``CHANGE`` items from ``ori`` to ``tar`` in the manager.

        Prints how many items were moved.

        :raises RuntimeError: if this action was already committed. Its indices
            refer to ``ori`` as it was before the first commit, so re-applying
            them would move the wrong items.
        """
        if self._committed:
            raise RuntimeError('%s has already been committed.' % self.__class__.__name__)
        n_moved = self.len(self.CHANGE)
        self.manager.mv(self.ori, self.tar, self.sets[self.CHANGE])
        # Only mark committed after the move succeeded, so a failed move can be retried.
        self._committed = True
        print('%d items moved' % n_moved)

    def show_changes(self, start=0, end=None, step=1):
        """Print the manager items that ``commit`` would move (slice of ``CHANGE``)."""
        self.show(self.CHANGE, start=start, end=end, step=step)

    def show_ori(self, start=0, end=None, step=1):
        """Print the manager items in ``ori`` that would stay in place (slice of ``FROM``)."""
        self.show(self.FROM, start=start, end=end, step=step)


class FilterManager(FilterManagerBase):
    """List manager that filters items from one set into another via reviewable actions.

    Typical flow: ``add_data(items)``, then ``act = manager.filter(pred)``.
    Inspect the result with ``act.show_changes()`` and apply it with
    ``act.commit()``.
    """
    # Redeclared (same values as the base) because the default arguments below
    # are evaluated in this class body, where inherited attributes are not visible.
    UNFILTERED = '__unfiltered_set'
    FILTERED = '__filtered_set'

    def add_data(self, data: List[Any], set_name: str = UNFILTERED):
        """Append ``data`` to ``set_name`` (the unfiltered set by default)."""
        super().add_data(data, set_name=set_name)

    def _filter(self, data: List[Sentence], func) -> Tuple[List, List]:
        """Partition ``data`` by predicate ``func``, preserving order.

        :return: ``(remain, selected)`` -- items where ``func`` is falsy, then
            items where it is truthy (same ordering as ``_filter_by_ids``).
        """
        remain = []
        selected = []
        for stn in data:
            (selected if func(stn) else remain).append(stn)
        return remain, selected

    def filter(self, filter_function, from_set: str = UNFILTERED, to_set: str = FILTERED):
        """Build (but do not apply) a ``FilterAction`` moving matching items.

        :param filter_function: predicate called on each item of ``from_set``.
        :param from_set: set to take matching items from.
        :param to_set: set the matching items would be moved to.
        :return: a ``FilterAction`` whose CHANGE set holds the matching indices
            of ``from_set`` and whose FROM set holds the non-matching ones.
            Call its ``commit()`` to apply the move.
        """
        source = self.sets[from_set]
        match_ids = [idx for idx, item in enumerate(source) if filter_function(item)]
        act = FilterAction(self, from_set, to_set)
        act.add_data(list(range(len(source))), act.FROM)
        act.add_data(list(range(len(self.sets[to_set]))), act.TO)
        act.mv(act.FROM, act.CHANGE, match_ids)
        return act

