##################################
#                                #
# Last modified 06/11/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import os
import string
import random
import numpy

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    Supported characters are A, C, G, T and N in upper or lower case; the
    case of each base is preserved in the output.  Any other character
    raises KeyError.
    """

    DNA = {'A':'T','T':'A','G':'C','C':'G','N':'N','a':'t','t':'a','g':'c','c':'g','n':'n'}
    # Build the result with a single join instead of repeated string
    # concatenation, which is quadratic in the sequence length.
    return ''.join([DNA[base] for base in reversed(preliminarysequence)])

def getRNAFoldEnergies(input):
    """Extract the energy values from a file of RNAfold output.

    Every line whose last character (ignoring the line terminator) is ')'
    is treated as a structure line; the text between the last '(' and that
    closing ')' is converted to a float.  All other lines are skipped.

    input -- path of the file to read.

    Returns the list of energies in file order.  Raises ValueError if the
    text inside the final parentheses of a matching line is not a number.
    """

    valuelist=[]
    # The context manager guarantees the handle is closed even if float()
    # raises part-way through the file.
    with open(input) as linelist:
        for line in linelist:
            # Strip the terminator first so that a final line lacking a
            # trailing newline is not silently ignored.
            stripped = line.rstrip('\n')
            if not stripped.endswith(')'):
                continue
            valuelist.append(float(stripped[:-1].rpartition('(')[2]))
    return valuelist

def _readFasta(fasta):
    """Read a FASTA file into a dictionary mapping sequence name to sequence.

    The name is the text between the first '>' of a header line and any
    further '>' on that line; each name is echoed to stdout as it is read.
    Sequence lines are stripped of surrounding whitespace and concatenated.
    Raises ValueError if non-blank data precedes the first header.
    """
    SequenceDict = {}
    name = None
    pieces = []
    with open(fasta) as inputdatafile:
        for line in inputdatafile:
            if line.startswith('>'):
                if name is not None:
                    SequenceDict[name] = ''.join(pieces)
                name = line.strip().split('>')[1]
                print(name)
                pieces = []
            elif name is not None:
                pieces.append(line.strip())
            elif line.strip():
                raise ValueError('sequence data found before the first FASTA header in ' + fasta)
    if name is not None:
        SequenceDict[name] = ''.join(pieces)
    return SequenceDict


def _foldEnergies(RNAFoldPath, records, seqpath, foldpath):
    """Write (name, sequence) records to seqpath as FASTA, run RNAfold on it
    with output redirected to foldpath, and return the parsed energies."""
    with open(seqpath, 'w') as seqfile:
        for name, seq in records:
            seqfile.write('>' + name + '\n')
            seqfile.write(seq + '\n')
    cmd = RNAFoldPath + ' < ' + seqpath + ' > ' + foldpath
    status = os.system(cmd)
    # A failed RNAfold run would otherwise surface later as a bare
    # IndexError or as silent nan z-scores.
    if status != 0:
        raise RuntimeError('command failed with status %s: %s' % (status, cmd))
    return getRNAFoldEnergies(foldpath)


def _strandZScores(sequence, nwindows, length, NumIter, RNAFoldPath, seqpath, foldpath, label):
    """Return the z-scores of the windows sequence[i:i+length], i in
    0..nwindows-1, each computed against NumIter shuffles of that window.

    label is the progress prefix printed every 100 windows.
    """
    zscores = []
    for i in range(nwindows):
        if i % 100 == 0:
            print('%s %s' % (label, i))
        subseq = sequence[i:i+length]
        nativeEnergies = _foldEnergies(RNAFoldPath, [('read', subseq)], seqpath, foldpath)
        if len(nativeEnergies) > 1:
            print('error, multiple energies returned for sequence')
        if not nativeEnergies:
            raise RuntimeError('no energy returned by RNAfold for window starting at %s' % i)
        subseqEnergy = nativeEnergies[0]
        shuffledRecords = []
        for j in range(NumIter):
            s = list(subseq)
            random.shuffle(s)
            shuffledRecords.append(('read' + str(j), ''.join(s)))
        shuffledSubseqEnergy = _foldEnergies(RNAFoldPath, shuffledRecords, seqpath, foldpath)
        SSEmean = numpy.mean(shuffledSubseqEnergy)
        SSEstd = numpy.std(shuffledSubseqEnergy)
        # Energies are negated, so a window that folds more stably (lower
        # energy) than its shuffles gets a positive z-score.  When all
        # shuffles have the same energy SSEstd is 0 and numpy yields
        # inf/nan rather than raising.
        zscores.append((-subseqEnergy - (-SSEmean)) / SSEstd)
    return zscores


def _writeScores(path, keys, ScoreDict, strand):
    """Write one tab-separated line (name, pos, pos+1, score) per scored
    position of the given strand, ordered by name and then position."""
    with open(path, 'w') as outfile:
        for name in keys:
            scores = ScoreDict[name][strand]
            for pos in sorted(scores):
                outline = name + '\t' + str(pos) + '\t' + str(pos+1) + '\t' + str(scores[pos])
                outfile.write(outline + '\n')


def run():
    """Compute per-position RNAfold energy z-scores for both strands.

    Command line: fasta regionlength number_iterations RNAfold_path
    outfile_prefix.  For every position of every sequence in the FASTA
    file, the window of regionlength bases starting there (wrapping around
    the end, i.e. treating the sequence as circular) is folded with
    RNAfold and compared with number_iterations random shuffles of the
    same window.  Results are written to <outfile_prefix>.plus.wig and
    <outfile_prefix>.minus.wig; <outfile_prefix>.temp and
    <outfile_prefix>.temp2 are used as scratch files and removed at the end.
    """

    # Five arguments follow the script name, so argv must hold six entries;
    # checking for fewer let sys.argv[5] raise IndexError.
    if len(sys.argv) < 6:
        print('usage: python %s fasta regionlength number_iterations RNAfold_path outfile_prefix' % sys.argv[0])
        print(' Note: circular chromosomes are assumed')
        sys.exit(1)

    fasta = sys.argv[1]
    length = int(sys.argv[2])
    NumIter = int(sys.argv[3])
    RNAFoldPath = sys.argv[4]
    outfileprefix = sys.argv[5]

    SequenceDict = _readFasta(fasta)
    if not SequenceDict:
        print('error: no sequences found in ' + fasta)
        sys.exit(1)

    seqpath = outfileprefix + '.temp'
    foldpath = outfileprefix + '.temp2'

    keys = sorted(SequenceDict)
    ScoreDict = {}
    try:
        for name in keys:
            forward = SequenceDict[name]
            total = len(forward)
            ScoreDict[name] = {'+': {}, '-': {}}

            # Appending the first `length` bases lets windows near the end
            # wrap around the circular sequence.
            plus = _strandZScores(forward + forward[0:length], total, length,
                                  NumIter, RNAFoldPath, seqpath, foldpath, '+ strand ')
            for i, zscore in enumerate(plus):
                ScoreDict[name]['+'][i] = zscore

            reverse = getReverseComplement(forward)
            minus = _strandZScores(reverse + reverse[0:length], total, length,
                                   NumIter, RNAFoldPath, seqpath, foldpath, '- strand ')
            for i, zscore in enumerate(minus):
                # NOTE(review): minus-strand keys run from 1 to total while
                # plus-strand keys run from 0 to total-1; this looks like a
                # possible off-by-one but is kept as-is -- confirm the
                # intended coordinate convention before changing it.
                ScoreDict[name]['-'][total - i] = zscore
    finally:
        # Clean up the scratch files even when a fold fails part-way.
        for path in (seqpath, foldpath):
            if os.path.exists(path):
                os.remove(path)

    _writeScores(outfileprefix + '.plus.wig', keys, ScoreDict, '+')
    _writeScores(outfileprefix + '.minus.wig', keys, ScoreDict, '-')
   
# Script entry point: run() is called unconditionally, so it also executes
# if this file is imported as a module.
run()
