##################################
#                                #
# Last modified 2019/07/08       #
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set
import time
from collections import OrderedDict
from seqdataloader.batchproducers import coordbased
from seqdataloader.batchproducers.coordbased import coordstovals
from seqdataloader.batchproducers.coordbased import coordbatchproducers
from seqdataloader.batchproducers.coordbased import coordbatchtransformers
import keras
import keras.layers as kl
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
from seqdataloader.batchproducers import coordbased
from seqdataloader.batchproducers.coordbased import coordstovals
from seqdataloader.batchproducers.coordbased import coordbatchproducers
from seqdataloader.batchproducers.coordbased import coordbatchtransformers

def multinomial_nll(true_counts, logits):
    """Return the multinomial negative log-likelihood averaged over the batch.

    Args:
        true_counts: observed counts; the last axis holds the categories and
            the first axis is treated as the batch axis.
        logits: predicted logits over the same last-axis categories as
            ``true_counts``.

    Returns:
        Scalar tensor: the summed ``-log p(true_counts)`` over all examples,
        divided by ``tf.shape(true_counts)[0]``.
    """
    # Each example's own total count parameterizes its Multinomial.
    counts_per_example = tf.reduce_sum(true_counts, axis=-1)
    dist = tfp.distributions.Multinomial(total_count=counts_per_example,
                                         logits=logits)
    # tf.cast(..., tf.float32) is the portable replacement for tf.to_float,
    # which is deprecated in TF 1.x and removed in TF 2.x.
    batch_size = tf.cast(tf.shape(true_counts)[0], tf.float32)
    return -tf.reduce_sum(dist.log_prob(true_counts)) / batch_size

class MultichannelMultinomialNLL(object):
    """Loss callable that sums ``multinomial_nll`` over ``n`` channels.

    Channel ``i`` is ``true_counts[..., i]`` / ``logits[..., i]``; the
    per-channel losses are added together (not averaged).

    Args:
        n: number of channels on the last axis; must be at least 1.
    """

    def __init__(self, n):
        # Instances have no __name__ by default; presumably set so Keras can
        # name this loss like a plain function -- NOTE(review): confirm.
        self.__name__ = "MultichannelMultinomialNLL"
        self.n = n

    def __call__(self, true_counts, logits):
        """Return the sum of per-channel multinomial NLLs.

        Raises:
            ValueError: if ``n`` is less than 1 (there is nothing to sum).
        """
        if self.n < 1:
            raise ValueError(
                "MultichannelMultinomialNLL requires n >= 1 channels, got %r"
                % (self.n,))
        total = multinomial_nll(true_counts[..., 0], logits[..., 0])
        for i in range(1, self.n):
            total += multinomial_nll(true_counts[..., i], logits[..., i])
        return total

    def get_config(self):
        """Return the constructor arguments needed to re-create this loss."""
        return {"n": self.n}

#If we want to avoid zero-padding, then the size of the output predictions
# will depend on the size of the input sequence supplied. We define the
# API for an AbstractProfileModel class which returns the length of the
# output profile in addition to returning the model, given information
# on the size of the input sequence and the model parameters.
class AbstractProfileModel(object):
    """Interface for profile models built without zero-padding.

    Because the output length then depends on the input length, a concrete
    model reports that length alongside building the model itself.
    """

    def get_output_profile_len(self):
        """Return the length of the predicted profile; subclasses override."""
        raise NotImplementedError

    def get_model(self):
        """Return the model; subclasses override."""
        raise NotImplementedError
  
#The architecture by Žiga Avsec involves residual connections, which means
# that the layers being added together in an elementwise fashion need
# to have the same dimensions. To achieve this without zero-padding, we
# have to trim away the flanks of earlier convolutional layers. That
# is what this function is meant to do. (Note that the original BP-net
# architecture zero-pads; this is a modification to avoid the zero
# padding and use information on actual sequence instead)
def trim_flanks_of_conv_layer(conv_layer, output_len, width_to_trim, filters):
    layer = keras.layers.Lambda(
        lambda x: x[:,
          int(0.5*(width_to_trim)):-(width_to_trim-int(0.5*(width_to_trim)))],
        output_shape=(output_len, filters))(conv_layer)
    return layer
        
#This model architecture is based on BP-Net by Žiga Avsec
# https://drive.google.com/file/d/1kg6Ic0-FvJtVUva9Mh3FPnOAHJcN6VB-/view
#It has been modified for this specific use-case.
class BPnetArch(AbstractProfileModel):
    """BPNet-style model predicting counts and a profile without zero-padding.

    Every convolution uses ``padding='valid'``, so each layer shortens the
    sequence. The residual (elementwise Add) connections therefore trim the
    flanks of earlier layers to the current length before summing.

    Args:
        input_seq_len: length of the one-hot DNA input (4 channels).
        c_task_weight: loss weight for each counts ('mse') output.
        filters: filters in the first and in each dilated convolution.
        n_dil_layers: number of dilated convolutions; layer ``i`` (1-based)
            uses dilation rate ``2**i``.
        conv1_kernel_size: kernel width of the first convolution.
        dil_kernel_size: kernel width of each dilated convolution.
        outconv_kernel_size: kernel width of the profile output convolution.
        lr: learning rate for the Adam optimizer.
    """

    # (output name, units) of each counts head, trained with 'mse'.
    _COUNT_TASKS = (("TF.logcount", 1),)
    # (output name, channels) of each profile head, trained with
    # MultichannelMultinomialNLL.
    _PROFILE_TASKS = (("TF.profile", 1),)

    def __init__(self, input_seq_len, c_task_weight, filters,
                 n_dil_layers, conv1_kernel_size,
                 dil_kernel_size,
                 outconv_kernel_size, lr):
        self.input_seq_len = input_seq_len
        self.c_task_weight = c_task_weight
        self.filters = filters
        self.n_dil_layers = n_dil_layers
        self.conv1_kernel_size = conv1_kernel_size
        self.dil_kernel_size = dil_kernel_size
        self.outconv_kernel_size = outconv_kernel_size
        self.lr = lr

    def _dil_width_lost(self, i):
        """Return how many positions the i-th (1-based) dilated conv removes."""
        return (2 ** i) * (self.dil_kernel_size - 1)

    def get_embedding_len(self):
        """Return the length left after the first and all dilated convolutions."""
        embedding_len = self.input_seq_len - (self.conv1_kernel_size - 1)
        for i in range(1, self.n_dil_layers + 1):
            embedding_len -= self._dil_width_lost(i)
        return embedding_len

    def get_output_profile_len(self):
        """Return the profile length after the output convolution.

        This does not account for the ``Xbin`` stride applied by
        ``get_keras_model``; with ``Xbin > 1`` the model output is shorter.
        """
        return self.get_embedding_len() - (self.outconv_kernel_size - 1)

    @staticmethod
    def _sum_layers(layers):
        """Add ``layers`` elementwise; a single layer is returned unchanged.

        Keras merge layers reject a one-element list, which would otherwise
        happen for the first dilated layer and when ``n_dil_layers == 0``.
        """
        if len(layers) == 1:
            return layers[0]
        return kl.Add()(layers)

    @staticmethod
    def _concat(tensors):
        """Concatenate on the last axis; a single tensor is returned unchanged.

        Keras ``Concatenate`` rejects a one-element list, and the heads below
        currently have nothing (e.g. no control tracks) to concatenate.
        """
        if len(tensors) == 1:
            return tensors[0]
        return kl.concatenate(tensors, axis=-1)

    def _build_dilated_stack(self, first_conv):
        """Return the residual sum of ``first_conv`` and all dilated convs.

        Each dilated conv reads the sum of all earlier layers. Because the conv
        is 'valid', those earlier layers are trimmed on both flanks to the new
        length before being kept for later sums.
        """
        curr_layer_size = self.input_seq_len - (self.conv1_kernel_size - 1)
        prev_layers = [first_conv]
        for i in range(1, self.n_dil_layers + 1):
            prev_sum = self._sum_layers(prev_layers)
            conv_output = kl.Conv1D(filters=self.filters,
                                    kernel_size=self.dil_kernel_size,
                                    padding='valid',
                                    activation='relu',
                                    dilation_rate=2 ** i)(prev_sum)
            width_to_trim = self._dil_width_lost(i)
            curr_layer_size -= width_to_trim
            prev_layers = [trim_flanks_of_conv_layer(
                conv_layer=x, output_len=curr_layer_size,
                width_to_trim=width_to_trim, filters=self.filters)
                for x in prev_layers]
            prev_layers.append(conv_output)
        return self._sum_layers(prev_layers)

    def get_keras_model(self, Xbin=1):
        """Build and compile the Keras model.

        Args:
            Xbin: stride of the final width-1 profile convolution (default 1,
                i.e. no binning).

        Returns:
            A compiled ``keras.models.Model`` whose only input is 'sequence'
            and whose outputs are the counts heads followed by the profile
            heads.
        """
        # One-hot encoded DNA sequence.
        inp = kl.Input(shape=(self.input_seq_len, 4), name='sequence')
        first_conv = kl.Conv1D(filters=self.filters,
                               kernel_size=self.conv1_kernel_size,
                               padding='valid',
                               activation='relu')(inp)
        combined_conv = self._build_dilated_stack(first_conv)

        # Global average pooling of the residual stack feeds the counts heads.
        gap_combined_conv = kl.GlobalAvgPool1D()(combined_conv)

        model_outputs = []
        losses = []
        loss_weights = []

        for name, numunits in self._COUNT_TASKS:
            count_out = kl.Dense(units=numunits, name=name)(
                self._concat([gap_combined_conv]))
            model_outputs.append(count_out)
            losses.append('mse')
            loss_weights.append(self.c_task_weight)

        for name, numunits in self._PROFILE_TASKS:
            profile_out_precontrol = kl.Conv1D(
                filters=numunits,
                kernel_size=self.outconv_kernel_size,
                padding='valid')(combined_conv)
            profile_out = kl.Conv1D(
                filters=numunits, kernel_size=1, strides=Xbin, name=name)(
                    self._concat([profile_out_precontrol]))
            model_outputs.append(profile_out)
            losses.append(MultichannelMultinomialNLL(numunits))
            # MultichannelMultinomialNLL adds the per-channel losses together;
            # downweight by the channel count so tasks with more channels are
            # not implicitly upweighted (Avanti Shrikumar).
            loss_weights.append(1.0 / numunits)

        model = keras.models.Model(inputs=[inp], outputs=model_outputs)
        model.compile(keras.optimizers.Adam(lr=self.lr),
                      loss=losses,
                      loss_weights=loss_weights)
        return model

def qc_func(inputs, targets):
    """Quality-control filter for training batches.

    Returns a boolean array, True for each row whose first
    ``'CUTNRUN.POU5F1.logcount'`` target value is at most 8. ``inputs`` is
    unused.
    """
    return targets['CUTNRUN.POU5F1.logcount'][:, 0] <= 8


def _parse_args(argv):
    """Return ``(fasta_path, vplot_path, xbin, ybin)``; print usage and exit if too few."""
    if len(argv) < 5:
        print('usage: python %s fasta V-plot.bgz Xbin Ybin' % argv[0])
        sys.exit(1)
    return argv[1], argv[2], int(argv[3]), int(argv[4])


def _fetch_vplot(vplot_path, chrom, start, end, isplusstrand):
    """Return the V-plot values for one region of ``vplot_path``.

    Placeholder: reading the V-plot file is not implemented yet.
    """
    raise NotImplementedError(
        "V-plot fetching from %s is not implemented (%s:%s-%s)"
        % (vplot_path, chrom, start, end))


def _make_vplot_coords_to_vals(vplot_path):
    """Return a coords-to-values callable producing ``{'vplot': array}``."""
    def vplot_coords_to_vals(coors):
        # NOTE(review): parameter named ``coors`` because seqdataloader's
        # CoordsToValsJoiner appears to pass coordinates by that keyword --
        # confirm against the installed version.
        vplots = [_fetch_vplot(vplot_path, c.chrom, c.start, c.end,
                               c.isplusstrand)
                  for c in coors]
        return {'vplot': np.array(vplots)}
    return vplot_coords_to_vals


def _make_inputs_coordstovals(fasta_path, vplot_path, seq_len, out_pred_len):
    """Return the joined model inputs: sequence, control CUT&RUN track, V-plot.

    Counts are log-transformed as log(counts + 1); the control profile is
    smoothed with windows of 1 and 50 (a window of 1 returns the original
    profile).
    """
    return coordstovals.core.CoordsToValsJoiner(
        coordstovals_list=[
            coordbased.coordstovals.fasta.PyfaidxCoordsToVals(
                genome_fasta_path=fasta_path,
                mode_name="sequence",
                center_size_to_use=seq_len),
            coordstovals.bigwig.MultiTrackCountsAndProfile(
                counts_mode_name="CUTNRUN.control.standard.logcount",
                profile_mode_name="CUTNRUN.control.standard.profile",
                bigwig_paths=["CUTNRUN.IgG.standard.5p.counts.bigWig"],
                counts_and_profiles_transformer=(
                    coordstovals.bigwig.LogCountsPlusOne().chain(
                        coordstovals.bigwig.SmoothProfiles(
                            smoothing_windows=[1, 50]))),
                center_size_to_use=out_pred_len),
            _make_vplot_coords_to_vals(vplot_path),
        ])


def _make_targets_coordstovals(out_pred_len):
    """Return the joined ChIP-seq and CUT&RUN target tracks."""
    return coordstovals.core.CoordsToValsJoiner(
        coordstovals_list=[
            coordstovals.bigwig.PosAndNegSeparateLogCounts(
                counts_mode_name="ChIPseq.POU5F1.logcount",
                profile_mode_name="ChIPseq.POU5F1.profile",
                pos_strand_bigwig_path="ChIPseq.POU5F1.merged.5p.counts.plus.bigWig",
                neg_strand_bigwig_path="ChIPseq.POU5F1.merged.5p.counts.minus.bigWig",
                center_size_to_use=out_pred_len),
            coordstovals.bigwig.PosAndNegSeparateLogCounts(
                counts_mode_name="ChIPseq.NANOG.logcount",
                profile_mode_name="ChIPseq.NANOG.profile",
                pos_strand_bigwig_path="ChIPseq.NANOG.merged.5p.counts.plus.bigWig",
                neg_strand_bigwig_path="ChIPseq.NANOG.merged.5p.counts.minus.bigWig",
                center_size_to_use=out_pred_len),
            coordstovals.bigwig.MultiTrackCountsAndProfile(
                counts_mode_name="CUTNRUN.POU5F1.logcount",
                profile_mode_name="CUTNRUN.POU5F1.profile",
                bigwig_paths=["CUTNRUN.POU5F1.5p.counts.bigWig"],
                counts_and_profiles_transformer=(
                    coordstovals.bigwig.LogCountsPlusOne().chain(
                        coordstovals.bigwig.SmoothProfiles(
                            smoothing_windows=[10]))),
                center_size_to_use=out_pred_len),
        ])


def _make_batch_generators(inputs_coordstovals, targets_coordstovals):
    """Return ``(train, valid)`` KerasBatchGenerators for ``fit_generator``.

    Training batches come from train_2k_around_summits.bed.gz with
    reverse-complement augmentation, up to 200bp uniform jitter, per-epoch
    shuffling and ``qc_func`` filtering. Validation batches come from
    valid_2k_around_summits.bed.gz without augmentation or shuffling.
    """
    train = coordbased.core.KerasBatchGenerator(
        coordsbatch_producer=coordbatchproducers.SimpleCoordsBatchProducer(
            bed_file="train_2k_around_summits.bed.gz",
            coord_batch_transformer=(
                coordbatchtransformers.ReverseComplementAugmenter().chain(
                    coordbatchtransformers.UniformJitter(
                        maxshift=200, chromsizes_file="hg38.chrom.sizes"))),
            batch_size=64,
            shuffle_before_epoch=True,
            seed=1234),
        inputs_coordstovals=inputs_coordstovals,
        targets_coordstovals=targets_coordstovals,
        qc_func=qc_func)
    valid = coordbased.core.KerasBatchGenerator(
        coordsbatch_producer=coordbatchproducers.SimpleCoordsBatchProducer(
            bed_file="valid_2k_around_summits.bed.gz",
            batch_size=64,
            shuffle_before_epoch=False,
            seed=1234),
        inputs_coordstovals=inputs_coordstovals,
        targets_coordstovals=targets_coordstovals)
    return train, valid


def run():
    """Parse the command line, build the data pipeline and model, and train.

    Usage: ``python <script> fasta V-plot.bgz Xbin Ybin``.

    Returns:
        The trained Keras model.
    """
    fasta_path, vplot_path, xbin, ybin = _parse_args(sys.argv)
    # NOTE(review): ybin is parsed but not used anywhere yet.

    seq_len = 1346
    modelwrapper = BPnetArch(input_seq_len=seq_len,
                             c_task_weight=500,
                             filters=64,
                             n_dil_layers=6,
                             conv1_kernel_size=21,
                             dil_kernel_size=3,
                             outconv_kernel_size=75,
                             lr=0.001)
    out_pred_len = modelwrapper.get_output_profile_len()
    print('%s %s' % (out_pred_len, seq_len - out_pred_len))

    inputs_coordstovals = _make_inputs_coordstovals(
        fasta_path, vplot_path, seq_len, out_pred_len)
    targets_coordstovals = _make_targets_coordstovals(out_pred_len)
    train_gen, valid_gen = _make_batch_generators(
        inputs_coordstovals, targets_coordstovals)

    # Sanity check: print the shape of every input and target in one batch.
    sampinputs, samptargets = train_gen[0]
    for key in sampinputs:
        print('%s %s' % (key, sampinputs[key].shape))
    for key in samptargets:
        print('%s %s' % (key, samptargets[key].shape))

    # NOTE(review): the model outputs are named "TF.logcount"/"TF.profile",
    # but the target modes above use ChIPseq/CUTNRUN names, and xbin > 1
    # shortens the profile output below out_pred_len. Reconcile these before
    # training can succeed.
    model = modelwrapper.get_keras_model(xbin)
    model.summary()  # summary() prints itself and returns None.
    early_stopping_callback = keras.callbacks.EarlyStopping(
        patience=10, restore_best_weights=True)
    model.fit_generator(train_gen,
                        epochs=200,
                        validation_data=valid_gen,
                        callbacks=[early_stopping_callback])
    # Restore explicitly in case training ran all epochs without an early
    # stop; best_weights stays None if no epoch ever improved.
    best_weights = getattr(early_stopping_callback, 'best_weights', None)
    if best_weights is not None:
        model.set_weights(best_weights)

    # TODO: persist the trained model (e.g. model.save(...)); nothing is
    # written to disk yet.
    return model

# Only train when executed as a script, not when this module is imported
# (e.g. to reuse BPnetArch or the loss classes).
if __name__ == "__main__":
    run()
