##################################
#                                #
# Last modified 2020/07/28       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import os
import gzip
import string

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    Case is preserved ('a' -> 't', 'A' -> 'T') and N/n map to themselves.
    Any other character raises KeyError.
    """

    DNA = {'A':'T','T':'A','G':'C','C':'G','N':'N','a':'t','t':'a','g':'c','c':'g','n':'n'}
    # One join over the reversed characters instead of repeated string
    # concatenation, which is quadratic in the sequence length.
    return ''.join(DNA[base] for base in reversed(preliminarysequence))

def _option(flag):
    """Return the command-line argument that immediately follows *flag*."""
    return sys.argv[sys.argv.index(flag) + 1]


def _open_text(path):
    """Open *path* for line-by-line text reading, decompressing by suffix.

    '.gz' files are read with gzip and '.bz2' files with bz2; anything else
    is opened as a plain text file. No shell commands are involved.
    """
    if path.endswith('.gz'):
        handle = gzip.open(path, 'rb')
    elif path.endswith('.bz2'):
        import bz2
        handle = bz2.BZ2File(path, 'rb')
    else:
        return open(path)
    if sys.version_info[0] >= 3:
        # Compressed streams yield bytes on Python 3; wrap them so callers
        # always iterate over str lines, as they do on Python 2.
        import io
        return io.TextIOWrapper(handle)
    return handle


def _load_genome(fasta):
    """Read a FASTA file into a dict mapping header name -> sequence string.

    The key is the header line with whitespace stripped and the text after
    the first '>' (the whole header, not split on spaces). Each name is
    printed as it is encountered. Lines before the first header are ignored.
    """
    genome = {}
    chrom = None
    chunks = []
    with _open_text(fasta) as handle:
        for line in handle:
            if line.startswith('>'):
                if chrom is not None:
                    genome[chrom] = ''.join(chunks)
                chrom = line.strip().split('>')[1]
                print(chrom)
                chunks = []
            else:
                chunks.append(line.strip())
    # Store the last record; guarded so an empty FASTA yields an empty dict.
    if chrom is not None:
        genome[chrom] = ''.join(chunks)
    return genome


def run():
    """Append the genomic sequence of every interval in a tab-delimited file.

    Command line: fasta inputfilename outputfilename [options] (see usage).
    Every non-comment input line is written to the output with its sequence
    appended as a final tab-separated column; lines starting with '#' get a
    'sequence' column header appended instead. Input and FASTA files may be
    plain, '.gz' or '.bz2'.
    """

    if len(sys.argv) < 4:
        print('usage: python %s fasta inputfilename outputfilename [-seqradius bp] [-usepeak fieldID|bed] [-chrfield fieldID] [-first number] [-name fieldID(s)] [-strand fieldID] [-narrowPeak] [-TtoU] [-allCAP]' % sys.argv[0])
        print('\tby default the forward strand is outputted')
        print('\tdefault sequence names format is chr:start-stop format')
        print('\tthe fasta file can be in .gz or .bz2 format, but it has to end with these suffixes')
        print('\tuse the -name option if you want to change the sequence output format - the content of the fields indicated will be used and the sequence name construct from those separated by "::"')
        sys.exit(1)

    fieldID = 0
    if '-chrfield' in sys.argv:
        fieldID = int(_option('-chrfield'))

    doCAP = '-allCAP' in sys.argv
    if doCAP:
        print('will capitalize output sequences')

    # NOTE: -name only affects FASTA-style sequence names; this script appends
    # a column to the input rows instead, so the option has no effect here.

    strandField = None
    if '-strand' in sys.argv:
        strandField = int(_option('-strand'))

    doPeak = False
    doBed = False
    peakFieldID = None
    if '-usepeak' in sys.argv:
        doPeak = True
        peakArg = _option('-usepeak')
        if peakArg == 'bed':
            doBed = True
        else:
            peakFieldID = int(peakArg)

    radius = None
    if '-seqradius' in sys.argv:
        radius = int(_option('-seqradius'))

    doNarrowPeak = '-narrowPeak' in sys.argv
    if doNarrowPeak:
        doPeak = True

    # Peak-centred modes need a radius; fail clearly instead of NameError.
    if doPeak and radius is None:
        print('-usepeak and -narrowPeak require -seqradius')
        sys.exit(1)

    doU = '-TtoU' in sys.argv
    if doU:
        print('will output RNA sequence')

    fasta = sys.argv[1]
    inputfilename = sys.argv[2]
    outfilename = sys.argv[3]

    firstN = None
    if '-first' in sys.argv:
        firstN = int(_option('-first'))

    with open(outfilename, 'w') as outfile:
        GenomeDict = _load_genome(fasta)
        with _open_text(inputfilename) as infile:
            for i, line in enumerate(infile, 1):
                # -first N limits processing to the first N lines (headers
                # included); nothing after that is used, so stop reading.
                if firstN is not None and i > firstN:
                    break
                if line.startswith('#'):
                    outfile.write(line.strip() + '\tsequence\n')
                    continue
                fields = line.strip().split('\t')
                if fields == ['']:
                    continue  # blank line: nothing to look up
                if doNarrowPeak:
                    # narrowPeak has 10 columns; column 10 (index 9) is the
                    # summit offset from chromStart, or -1 if none was called.
                    if len(fields) < 10:
                        continue
                    chrom = fields[0]
                    summit = int(fields[9])
                    if summit < 0:
                        peak = int((int(fields[1]) + int(fields[2])) / 2.0)
                    else:
                        peak = int(fields[1]) + summit
                    start = peak - radius
                    stop = peak + radius
                elif doPeak:
                    if doBed:
                        peak = int((int(fields[fieldID + 1]) + int(fields[fieldID + 2])) / 2.0)
                    else:
                        peak = int(fields[peakFieldID].strip())
                    chrom = fields[fieldID].strip()
                    start = peak - radius
                    stop = peak + radius
                else:
                    chrom = fields[fieldID].strip()
                    start = int(fields[fieldID + 1].strip())
                    stop = int(fields[fieldID + 2].strip())
                # Clamp at 0: a negative bound would index from the end of
                # the chromosome string and return the wrong sequence.
                sequence = GenomeDict[chrom][max(start, 0):max(stop, 0)]
                if strandField is not None and fields[strandField] in ('R', '-'):
                    sequence = getReverseComplement(sequence)
                if doCAP:
                    sequence = sequence.upper()
                if doU:
                    sequence = sequence.replace('T', 'U').replace('t', 'u')
                outfile.write(line.strip() + '\t' + sequence + '\n')
   
# Run only when executed as a script, so importing this module (e.g. to reuse
# getReverseComplement) does not parse sys.argv or call sys.exit.
if __name__ == '__main__':
    run()
