#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
from dlab.config import add_config
from pytorch_pretrained_bert.optimization import BertAdam

from dlab.log import log


class InlineBertAdamOptimizer(object):
    """Wrap ``BertAdam`` with BERT-style weight-decay grouping and a linear warmup schedule.

    Parameters whose name contains any of ``'bias'``, ``'LayerNorm.bias'`` or
    ``'LayerNorm.weight'`` get no weight decay; every other parameter uses
    ``weight_decay``.

    Args:
        model: object exposing ``named_parameters()`` yielding ``(name, param)``
            pairs (presumably a ``torch.nn.Module`` -- TODO confirm at call sites).
        total_batches_for_train: total number of optimizer updates; forwarded to
            ``BertAdam`` as ``t_total``.
        bert_lr: peak learning rate handed to ``BertAdam``.
        bert_warmup_proportion: fraction of ``t_total`` used for linear warmup.
        weight_decay: decay applied to the parameters not excluded above
            (defaults to the previously hard-coded 0.01).
    """

    def __init__(self, model, total_batches_for_train, bert_lr=5e-5, bert_warmup_proportion=0.1,
                 weight_decay=0.01):
        self.lr = bert_lr
        self.warm_up = bert_warmup_proportion
        self.global_step = 0
        self.total_step = total_batches_for_train

        named_params = list(model.named_parameters())
        # Substring match: 'bias' alone already matches 'LayerNorm.bias'.
        no_decay = ('bias', 'LayerNorm.bias', 'LayerNorm.weight')

        def _excluded(name):
            """Return True when ``name`` should be exempt from weight decay."""
            return any(nd in name for nd in no_decay)

        optimizer_grouped_parameters = [
            {'params': [p for n, p in named_params if not _excluded(n)],
             'weight_decay': weight_decay},
            {'params': [p for n, p in named_params if _excluded(n)],
             'weight_decay': 0.0},
        ]
        # With both `warmup` and `t_total` set, BertAdam applies its own
        # warmup-linear schedule to `lr` on every update.
        self.optimizer = BertAdam(optimizer_grouped_parameters,
                                  lr=self.lr,
                                  warmup=self.warm_up,
                                  t_total=int(total_batches_for_train))

    def step(self):
        """Apply one optimizer update, clear gradients and advance ``global_step``.

        The learning-rate schedule is not adjusted here. ``BertAdam`` was
        constructed with ``warmup``/``t_total``, so it already scales the base
        ``lr`` internally. Also writing a pre-scheduled value into
        ``param_group['lr']`` would apply the schedule a second time.
        """
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.global_step += 1

    @staticmethod
    def warmup_linear(x, warmup=0.002):
        """Return the linear warmup/decay multiplier at training progress ``x``.

        Args:
            x: fraction of training completed (``global_step / total_steps``).
            warmup: fraction of training spent ramping up.

        Returns:
            ``x / warmup`` while ``x < warmup``, otherwise ``1 - x``. The result
            is clamped at 0, so progress past the end of training never yields
            a negative learning rate.
        """
        if x < warmup:
            return x / warmup
        return max(0.0, 1.0 - x)


@add_config('bert_lr', float, 5e-5, 'The initial learning rate for BertAdam.')
@add_config('bert_warmup_proportion', float, 0.1, 'Proportion of training to perform linear learning rate warmup for.')
class BertAdamOptimizer(InlineBertAdamOptimizer):
    """Config-driven variant of ``InlineBertAdamOptimizer``.

    Reads ``'bert_lr'`` and ``'bert_warmup_proportion'`` from ``config`` and
    keeps ``config`` on the instance. Both keys are declared through the
    ``add_config`` decorators. That is a project helper, presumably registering
    the options with their defaults (5e-5 and 0.1) -- confirm in ``dlab.config``.
    """

    def __init__(self, config, model, total_batches_for_train):
        self.config = config
        # Looked up in the same order as before: learning rate, then warmup.
        peak_lr = config['bert_lr']
        warmup_fraction = config['bert_warmup_proportion']
        super().__init__(
            model,
            total_batches_for_train,
            bert_lr=peak_lr,
            bert_warmup_proportion=warmup_fraction,
        )

