##################################
#                                #
# Last modified 01/28/2015       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import pysam
import string
from sets import Set
import os
import random
import numpy as np

# Complement of every supported nucleotide code: A/C/G/T/N plus the IUPAC
# ambiguity codes. Built once at import time instead of on every call.
_COMPLEMENT = {
    'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', 'N': 'N',
    'R': 'Y', 'Y': 'R', 'K': 'M', 'M': 'K', 'S': 'S', 'W': 'W',
    'B': 'V', 'V': 'B', 'D': 'H', 'H': 'D',
}
# Lower-case (soft-masked) bases complement to lower-case bases.
_COMPLEMENT.update(
    dict((base.lower(), comp.lower()) for base, comp in list(_COMPLEMENT.items()))
)


def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a nucleotide sequence.

    Args:
        preliminarysequence: string of nucleotide codes. A, C, G, T, N and
            the IUPAC ambiguity codes are accepted in either case; the case
            of every base is preserved in the result.

    Returns:
        A string of the same length holding the complement of each base,
        in reverse order. An empty input yields an empty string.

    Raises:
        KeyError: if the sequence contains a character that is not a
            supported nucleotide code.
    """
    # A single join keeps this linear in the sequence length; repeated
    # string concatenation would copy the partial result on every step.
    return ''.join(_COMPLEMENT[base] for base in reversed(preliminarysequence))

def _readFasta(fasta):
    """Load a FASTA file into a dict mapping record name to sequence.

    The record name is the header text that follows the leading '>'.
    Sequence lines are stripped and concatenated. Lines that appear before
    the first header are ignored, and an empty file yields an empty dict.
    """
    SeqDict = {}
    name = None
    chunks = []
    with open(fasta) as inputdatafile:
        for line in inputdatafile:
            if line.startswith('>'):
                if name is not None:
                    SeqDict[name] = ''.join(chunks)
                name = line.strip().split('>')[1]
                chunks = []
            elif name is not None:
                chunks.append(line.strip())
    if name is not None:
        SeqDict[name] = ''.join(chunks)
    return SeqDict


def _placeSites(SeqDict, TGL, Nsites, ExpGamma, OP):
    """Randomly pick occupancy sites and draw a strength for each one.

    Every position of every sequence becomes a site with probability
    Nsites/TGL. Returns a dict of dicts, PeakDict[chrom][(position, S)] = 1,
    where S is the integer site strength. When OP is an open file, one
    tab-separated line (chrom, start, end) is written per site, covering
    the site position +/- 100 and clamped to the sequence bounds.
    """
    PeakDict = {}
    for chrom in SeqDict:
        # Progress goes to stderr: stdout carries the FASTA read output.
        sys.stderr.write(chrom + '\n')
        chromLength = len(SeqDict[chrom])
        PeakDict[chrom] = {}
        for i in range(chromLength):
            # randint(1, TGL) <= Nsites holds with probability Nsites/TGL.
            if random.randint(1, TGL) > Nsites:
                continue
            # Site strength: exponential with mean ExpGamma (numpy's scale
            # parameter).
            # NOTE(review): rounding to the nearest integer is an assumption;
            # a site that draws 0 is still reported but produces no reads.
            S = int(round(np.random.exponential(ExpGamma)))
            PeakDict[chrom][(i, S)] = 1
            if OP is not None:
                start = max(0, i - 100)
                end = min(chromLength, i + 100)
                OP.write(chrom + '\t' + str(start) + '\t' + str(end) + '\n')
    return PeakDict


def _printSiteReads(sequence, chrom, i, S, FLmean, FLstdev, RL):
    """Print up to S reads, as FASTA records, from fragments centred on i.

    Each fragment spans roughly i - FLmean/2 to i + FLmean/2, with both ends
    independently jittered by a normal deviate of standard deviation FLstdev.
    A read of length RL is taken from the fragment start (forward strand) or
    reverse-complemented from the fragment end (reverse strand), each with
    probability 1/2.
    """
    for _ in range(S):
        pos1 = i - FLmean // 2 + int(round(np.random.normal(0, FLstdev)))
        pos2 = i + FLmean // 2 + int(round(np.random.normal(0, FLstdev)))
        # Keep the fragment inside the sequence; a negative index would
        # otherwise silently wrap around to the other end of the string.
        pos1 = max(0, pos1)
        pos2 = min(len(sequence), pos2)
        # Skip fragments too short to hold a full-length read.
        if pos2 - pos1 < RL:
            continue
        if random.random() < 0.5:
            read = sequence[pos1:pos1 + RL]
            strand = 'for'
        else:
            read = getReverseComplement(sequence[pos2 - RL:pos2])
            strand = 'rev'
        print('>' + chrom + ':' + str(pos1) + '-' + str(pos2) + '::' + strand)
        print(read)


def run():
    """Simulate reads around random occupancy sites of a FASTA genome.

    Command-line arguments (sys.argv[1:]):
        fasta, N occupancy sites, FRiP value, number of unique fragments,
        number of reads, fragment length mean, fragment length stdev,
        read length, site strength distribution mean,
        and optionally '-outputPeaks outfile'.

    Reads are printed to stdout as FASTA records named
    '>chrom:start-end::strand'. The total genome length and the name of each
    sequence being processed are reported on stderr so they do not end up in
    the FASTA output. With -outputPeaks, the simulated sites are also written
    to the given file.

    Background reads are not generated: the FRiP value, the number of unique
    fragments and the number of reads are parsed but not used.

    Exits with status 1 after printing the usage text when fewer than the
    nine required arguments are given, or when -outputPeaks has no file name.
    """
    # Nine positional arguments are required, so sys.argv needs 10 entries.
    if len(sys.argv) < 10:
        print('usage: python %s <fasta> <N occupancy sites> <FRiP value> <number unique fragments> <number reads> <fragment length mean> <fragment length stdev> <read length> <site strength distribution mean> [-outputPeaks outfile]' % sys.argv[0])
        print('\tNote: the script will print to stdout, use bzip2 or gzip to capture the output to a compressed fasta file')
        print('\tNote: The FRiP values is in the [0;1] range')
        print('\tNote: The site strength distribution is modelled as exponential; the input required the mean (beta, as parametrized in numpy)')
        sys.exit(1)

    fasta = sys.argv[1]
    Nsites = int(sys.argv[2])
    # FRiP, Nfrags and Nreads are only validated here; nothing below uses
    # them because background read generation is not implemented.
    FRiP = float(sys.argv[3])
    Nfrags = int(sys.argv[4])
    Nreads = int(sys.argv[5])
    FLmean = int(sys.argv[6])
    FLstdev = int(sys.argv[7])
    RL = int(sys.argv[8])
    ExpGamma = float(sys.argv[9])

    OP = None
    if '-outputPeaks' in sys.argv:
        nameIndex = sys.argv.index('-outputPeaks') + 1
        if nameIndex >= len(sys.argv):
            sys.stderr.write('-outputPeaks requires an output file name\n')
            sys.exit(1)
        OP = open(sys.argv[nameIndex], 'w')

    try:
        SeqDict = _readFasta(fasta)

        # Total genome length across all sequences.
        TGL = sum(len(sequence) for sequence in SeqDict.values())
        sys.stderr.write(str(TGL) + '\n')

        PeakDict = _placeSites(SeqDict, TGL, Nsites, ExpGamma, OP)
    finally:
        # Close the peak file even if reading or site placement fails.
        if OP is not None:
            OP.close()

    for chrom in PeakDict:
        for (i, S) in PeakDict[chrom]:
            _printSiteReads(SeqDict[chrom], chrom, i, S, FLmean, FLstdev, RL)
### Background reads: not implemented yet (run() parses FRiP, Nfrags and Nreads but never uses them)


# Run the simulation only when executed as a script, so that importing this
# module does not parse sys.argv or exit the interpreter.
if __name__ == '__main__':
    run()

