from __future__ import absolute_import
from __future__ import division

import functools

import numpy as np

import copy
import inspect
import types as python_types
import warnings

from keras import backend as K
from keras import activations
from keras import initializations
from keras import regularizers
from keras import constraints
from keras.engine import InputSpec
from keras.engine import Layer
from keras.engine import Merge
from keras.utils.generic_utils import func_dump
from keras.utils.generic_utils import func_load
from keras.utils.generic_utils import get_from_module

from keras.utils.np_utils import conv_output_length
from keras.utils.np_utils import conv_input_length

from keras.layers import (
    Activation, AveragePooling1D, BatchNormalization,
    Convolution1D, Dense, Dropout, Flatten, Input,
    MaxPooling1D, Merge, Permute, Reshape,
    PReLU, Lambda, merge
)

# from keras.pooling import AveragePooling1D
# from keras.pooling import AveragePooling2D
# from keras.pooling import AveragePooling3D
# from keras.pooling import MaxPooling1D
# from keras.pooling import MaxPooling2D
# from keras.pooling import MaxPooling3D

class DenseAfterRevcompConv1D(Dense):
    '''Dense layer meant to follow 1D convolutional or pooling layers that
    have reverse-complement weight sharing.

    Only half of the effective weight matrix is stored as a trainable
    parameter; `call` derives the other half by flipping the stored half
    along both the length axis and the channel axis.

    # Input shape
        3D tensor with shape: `(samples, input_length, num_chan)`, where
        `input_length` is defined and `num_chan` is even.
    # Output shape
        2D tensor with shape: `(samples, output_dim)`.
    '''

    def build(self, input_shape):
        '''Creates the half-sized weight matrix and the optional bias.

        # Arguments
            input_shape: shape tuple of length 3; the last entry is the
                number of channels and the second-to-last is the length.
        '''
        assert len(input_shape) == 3, "layer designed to follow 1D conv/pool"
        num_chan = input_shape[-1]
        input_length = input_shape[-2]
        # The weight shape below is derived from the length, so it has to
        # be known when the layer is built.
        assert input_length is not None,\
         "not implemented for undefined input length"
        assert num_chan%2 == 0,\
         ("num_chan should be even if input is a 1D conv/pool layer with"+
          " reverse-complement weight sharing")
        self.num_chan = num_chan
        self.input_length = input_length
        self.input_spec = [InputSpec(dtype=K.floatx(),
                                     ndim='2+')]

        # This module enables true division, so `/` would yield a float
        # dimension; floor division keeps the weight shape integral.
        self.W = self.add_weight((input_length*(num_chan//2),
                                  self.output_dim),
                                 initializer=self.init,
                                 name='{}_W'.format(self.name),
                                 regularizer=self.W_regularizer,
                                 constraint=self.W_constraint)
        if self.bias:
            self.b = self.add_weight((self.output_dim,),
                                     initializer='zero',
                                     name='{}_b'.format(self.name),
                                     regularizer=self.b_regularizer,
                                     constraint=self.b_constraint)
        else:
            self.b = None

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights
        self.built = True

    def get_output_shape_for(self, input_shape):
        '''Returns `(input_shape[0], output_dim)` for a 3D input shape.'''
        assert input_shape and len(input_shape) == 3
        output_shape = [input_shape[0], self.output_dim]
        return tuple(output_shape)

    def call(self, x, mask=None):
        '''Applies the dense transform using the tied full weight matrix.

        The stored weights are viewed as `(input_length, num_chan/2,
        output_dim)`, concatenated along the channel axis with a copy that
        is reversed along both the length and channel axes, and flattened
        back to `(input_length*num_chan, output_dim)` before the dot
        product with the flattened input.
        '''
        half_chan = self.num_chan//2
        W = K.reshape(self.W, (self.input_length, half_chan,
                               self.output_dim))
        concatenated_reshaped_W = K.reshape(K.concatenate(
            tensors=[W, W[::-1,::-1,:]], axis=1),
            (self.input_length*self.num_chan, self.output_dim))
        reshaped_x = K.reshape(x, (-1, self.input_length*self.num_chan))
        output = K.dot(reshaped_x, concatenated_reshaped_W)
        if self.bias:
            output += self.b
        return self.activation(output)


class DenseAfterRevcompWeightedSum(Dense):
    '''Dense layer meant to follow WeightedSum layers
    that have input_is_revcomp_conv=True.

    Only half of the effective weight matrix is stored; `call` appends a
    copy reversed along the input axis to obtain the full matrix.

    # Input shape
        2D tensor with shape: `(samples, input_dim)`, `input_dim` even.
    '''

    def build(self, input_shape):
        '''Creates the half-sized weight matrix and the optional bias.

        # Arguments
            input_shape: shape tuple of length 2; the last entry is the
                (even) input dimensionality.
        '''
        assert len(input_shape) == 2
        input_dim = input_shape[-1]
        assert input_dim%2 == 0,\
         ("input_dim should be even if input is WeightedSum layer with"+
          " input_is_revcomp_conv=True")
        self.input_dim = input_dim
        self.input_spec = [InputSpec(dtype=K.floatx(),
                                     ndim='2+')]

        # This module enables true division, so `/` would yield a float
        # dimension; floor division keeps the weight shape integral.
        self.W = self.add_weight((input_dim//2, self.output_dim),
                                 initializer=self.init,
                                 name='{}_W'.format(self.name),
                                 regularizer=self.W_regularizer,
                                 constraint=self.W_constraint)
        if self.bias:
            self.b = self.add_weight((self.output_dim,),
                                     initializer='zero',
                                     name='{}_b'.format(self.name),
                                     regularizer=self.b_regularizer,
                                     constraint=self.b_constraint)
        else:
            self.b = None

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights
        self.built = True

    def call(self, x, mask=None):
        '''Applies the dense transform using the tied full weight matrix.

        The full matrix is the stored half stacked (along the input axis)
        on top of a copy of itself with its rows in reverse order.
        '''
        output = K.dot(x, K.concatenate(
                             tensors=[self.W, self.W[::-1,:]], axis=0))
        if self.bias:
            output += self.b
        return self.activation(output)


class RevCompConv1D(Convolution1D):
    '''Convolution1D variant whose output additionally contains the
    reverse-complement filters, with tied weights, appended along the
    channel dimension. The output channels are mirrored: the reverse
    complement of the channel at (0-based) index i sits at index -(i + 1).
    # Example
    ```python
        # apply a reverse-complemented convolution 1d of length 20
        # to a sequence with 100bp input, with 2*64 output filters
        model = Sequential()
        model.add(RevCompConv1D(nb_filter=64, filter_length=20,
                                border_mode='same', input_shape=(100, 4)))
        # now model.output_shape == (None, 100, 128)
        # add a new reverse-complemented conv1d on top
        model.add(RevCompConv1D(nb_filter=32, filter_length=10,
                                border_mode='same'))
        # now model.output_shape == (None, 100, 64)
    ```
    # Arguments
        nb_filter: Number of non-reverse-complemented convolution kernels
            to use (half the dimensionality of the output).
        filter_length: The extension (spatial or temporal) of each filter.
        init: name of initialization function for the weights of the layer
            (see [initializations](../initializations.md)),
            or alternatively, Theano function to use for weights
            initialization. Only relevant if no `weights` argument is passed.
        activation: name of activation function to use
            (see [activations](../activations.md)),
            or alternatively, elementwise Theano function.
            If nothing is specified, no activation is applied
            (ie. "linear" activation: a(x) = x).
        weights: list of numpy arrays to set as initial weights
            (the reverse-complemented portion should not be included;
            it is derived from the stored weights inside `call`).
        border_mode: 'valid', 'same' or 'full'.
            ('full' requires the Theano backend.)
        subsample_length: factor by which to subsample output.
        W_regularizer: instance of [WeightRegularizer](../regularizers.md)
            (eg. L1 or L2 regularization), applied to the main weights matrix.
        b_regularizer: instance of [WeightRegularizer](../regularizers.md),
            applied to the bias.
        activity_regularizer: instance of
            [ActivityRegularizer](../regularizers.md),
            applied to the network output.
        W_constraint: instance of the [constraints](../constraints.md) module
            (eg. maxnorm, nonneg), applied to the main weights matrix.
        b_constraint: instance of the [constraints](../constraints.md) module,
            applied to the bias.
        bias: whether to include a bias
            (i.e. make the layer affine rather than linear).
        input_dim: Number of channels/dimensions in the input.
            Either this argument or the keyword argument `input_shape` must
            be provided when using this layer as the first layer in a model.
        input_length: Length of input sequences, when it is constant.
            This argument is required if you are going to connect
            `Flatten` then `Dense` layers upstream
            (without it, the shape of the dense outputs cannot be computed).
    # Input shape
        3D tensor with shape: `(samples, steps, input_dim)`.
    # Output shape
        3D tensor with shape: `(samples, new_steps, 2*nb_filter)`.
        `steps` value might have changed due to padding.
    '''

    def get_output_shape_for(self, input_shape):
        '''Returns `(samples, new_steps, 2*nb_filter)` for the given
        input shape.'''
        num_samples, num_steps = input_shape[0], input_shape[1]
        new_steps = conv_output_length(num_steps,
                                       self.filter_length,
                                       self.border_mode,
                                       self.subsample[0])
        return (num_samples, new_steps, 2*self.nb_filter)

    def call(self, x, mask=None):
        '''Convolves `x` with the stored filters and their reverse
        complements, then adds the tied bias and applies the activation.'''
        # Reverse-complement kernel: flipped along the length axis (dim 0)
        # and the input-channel axis (dim 2). Dim 1 is a dummy axis of
        # size 1 (see the 'build' method of Convolution1D). The additional
        # flip of the last (output-channel) axis is what mirrors the
        # output channels.
        flipped_W = self.W[::-1,:,::-1,::-1]
        tied_W = K.concatenate([self.W, flipped_W], axis=-1)
        # The bias is tied with the same mirrored channel ordering.
        tied_b = (K.concatenate([self.b, self.b[::-1]], axis=-1)
                  if self.bias else None)
        # conv2d is used for the 1D convolution, hence the dummy axis
        # that is added before and squeezed out after.
        expanded_x = K.expand_dims(x, 2)
        feature_maps = K.conv2d(expanded_x, tied_W,
                                strides=self.subsample,
                                border_mode=self.border_mode,
                                dim_ordering='tf')
        feature_maps = K.squeeze(feature_maps, 2)
        if self.bias:
            feature_maps += K.reshape(tied_b, (1, 1, 2*self.nb_filter))
        return self.activation(feature_maps)

class RevCompConv1DBatchNorm(Layer):
    '''Batch norm that shares weights over reverse complement channels.

    The input is expected to be 3D with an even number of channels in the
    last axis. Gamma, beta and the running statistics are stored for half
    of the channels only; each stored entry is applied both to a channel
    in the first half and to its mirrored counterpart in the second half.

    # Arguments
        epsilon: small float added for numerical stability.
        mode: integer, 0, 1 or 2.
            0: batch statistics in the training phase, running averages
               otherwise.
            1: sample-wise normalization over the last axis.
            2: batch statistics in both phases.
        axis: channel axis; must be -1 or 2.
        momentum: momentum of the running mean/std moving averages.
        weights: optional list of initial weights passed to `set_weights`.
        beta_init: name of the initializer for beta.
        gamma_init: name of the initializer for gamma.
        gamma_regularizer: optional regularizer applied to gamma.
        beta_regularizer: optional regularizer applied to beta.
    '''
    def __init__(self, epsilon=1e-3, mode=0, axis=-1, momentum=0.99,
                 weights=None, beta_init='zero', gamma_init='one',
                 gamma_regularizer=None, beta_regularizer=None, **kwargs):
        self.supports_masking = True
        self.beta_init = initializations.get(beta_init)
        self.gamma_init = initializations.get(gamma_init)
        self.epsilon = epsilon
        self.mode = mode
        assert axis==-1 or axis==2, "Intended for Conv1D"
        self.axis = axis
        self.momentum = momentum
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.initial_weights = weights
        if self.mode == 0:
            self.uses_learning_phase = True
        super(RevCompConv1DBatchNorm, self).__init__(**kwargs)

    def build(self, input_shape):
        '''Creates gamma, beta and the running statistics, each with one
        entry per pair of reverse-complement channels.

        # Arguments
            input_shape: shape tuple of length 3 with a defined length
                (entry 1) and an even number of channels.
        '''
        # Check the rank first so that a wrong-rank input fails with this
        # message instead of an IndexError on the lookups below.
        assert len(input_shape)==3,\
         "Implementation done with RevCompConv1D input in mind"
        self.input_spec = [InputSpec(shape=input_shape)]
        self.num_input_chan = input_shape[self.axis]
        self.input_len = input_shape[1]
        assert self.input_len is not None,\
         "not implemented for undefined input len"
        assert self.num_input_chan%2 == 0, "should be even for revcomp input"
        shape = (self.num_input_chan//2,)

        self.gamma = self.add_weight(shape,
                                     initializer=self.gamma_init,
                                     regularizer=self.gamma_regularizer,
                                     name='{}_gamma'.format(self.name))
        self.beta = self.add_weight(shape,
                                    initializer=self.beta_init,
                                    regularizer=self.beta_regularizer,
                                    name='{}_beta'.format(self.name))
        self.running_mean = self.add_weight(shape, initializer='zero',
                                            name='{}_running_mean'.format(self.name),
                                            trainable=False)
        self.running_std = self.add_weight(shape, initializer='one',
                                           name='{}_running_std'.format(self.name),
                                           trainable=False)

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights
        self.built = True

    def call(self, x, mask=None):
        '''Normalizes `x` so that mirrored channel pairs share statistics
        and parameters.

        The second half of the channels is reversed and stacked under the
        first half along the length axis, the half-width tensor is
        normalized, and the two halves are then put back side by side
        along the channel axis in the original channel order.

        # Raises
            ValueError: if `self.mode` is not 0, 1 or 2.
        '''
        orig_x = x
        half_chan = self.num_input_chan//2
        #create a fake x by concatenating reverse-complemented pairs
        #along the length dimension
        x = K.concatenate(
            tensors=[x[:,:,:half_chan],
                     x[:,:,half_chan:][:,:,::-1]],
            axis=1)
        if self.mode == 0 or self.mode == 2:
            assert self.built, 'Layer must be built before being called'

            reduction_axes = list(range(3))
            del reduction_axes[self.axis]
            broadcast_shape = [1] * 3
            broadcast_shape[self.axis] = half_chan

            x_normed, mean, std = K.normalize_batch_in_training(
                x, self.gamma, self.beta, reduction_axes,
                epsilon=self.epsilon)

            if self.mode == 0:
                # Register the updates against the tensor the layer was
                # actually called on; the concatenated tensor above is
                # internal to this method, so updates keyed on it could
                # not be looked up from the layer's inputs.
                self.add_update([K.moving_average_update(self.running_mean, mean, self.momentum),
                                 K.moving_average_update(self.running_std, std, self.momentum)], orig_x)

                # need broadcasting
                broadcast_running_mean = K.reshape(self.running_mean, broadcast_shape)
                broadcast_running_std = K.reshape(self.running_std, broadcast_shape)
                broadcast_beta = K.reshape(self.beta, broadcast_shape)
                broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                x_normed_running = K.batch_normalization(
                    x, broadcast_running_mean, broadcast_running_std,
                    broadcast_beta, broadcast_gamma,
                    epsilon=self.epsilon)

                # pick the normalized form of x corresponding to the training phase
                x_normed = K.in_train_phase(x_normed, x_normed_running)

        elif self.mode == 1:
            # sample-wise normalization
            m = K.mean(x, axis=-1, keepdims=True)
            std = K.sqrt(K.var(x, axis=-1, keepdims=True) + self.epsilon)
            x_normed = (x - m) / (std + self.epsilon)
            x_normed = self.gamma * x_normed + self.beta
        else:
            # Without this branch an unknown mode would surface as an
            # unbound `x_normed` below.
            raise ValueError('Unsupported batch normalization mode: ' +
                             str(self.mode))
        #recover the reverse-complemented channels
        true_x_normed = K.concatenate(
            tensors=[x_normed[:,:self.input_len,:],
                     x_normed[:,self.input_len:,:][:,:,::-1]],
            axis=2)
        return true_x_normed

    def get_config(self):
        '''Returns the layer configuration merged with the base config.'''
        config = {'epsilon': self.epsilon,
                  'mode': self.mode,
                  'axis': self.axis,
                  'gamma_regularizer': self.gamma_regularizer.get_config() if self.gamma_regularizer else None,
                  'beta_regularizer': self.beta_regularizer.get_config() if self.beta_regularizer else None,
                  'momentum': self.momentum}
        base_config = super(RevCompConv1DBatchNorm, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

