##################################
#                                #
# Last modified 2022/05/03       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import os
import string

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    The input is walked from its last character to its first and every
    base is replaced by its complement (A<->T, G<->C, N->N); upper and
    lower case are preserved.

    Raises KeyError if the sequence contains a character other than
    A, C, G, T, N (in either case).
    """
    DNA = {'A':'T','T':'A','G':'C','C':'G','N':'N','a':'t','t':'a','g':'c','c':'g','n':'n'}
    # A single join keeps this linear in the sequence length; building the
    # result with repeated "+" copies the growing string on every step.
    return ''.join([DNA[base] for base in reversed(preliminarysequence)])

def _readLines(path):
    """Yield the lines of `path`, decompressing '.bz2' / '.gz' files on the fly.

    Compressed files are piped through the external `bzip2` / `gunzip`
    programs; anything else is opened directly.
    """
    # Local import so this block does not depend on the file's import header.
    import subprocess

    if path.endswith('.bz2'):
        command = ['bzip2', '-cd', path]
    elif path.endswith('.gz'):
        command = ['gunzip', '-c', path]
    else:
        command = None

    if command is None:
        handle = open(path)
        try:
            for line in handle:
                yield line
        finally:
            handle.close()
        return

    # An argument list (no shell) keeps file names with spaces or shell
    # metacharacters from being misinterpreted.
    proc = subprocess.Popen(command, stdout=subprocess.PIPE, universal_newlines=True)
    try:
        for line in proc.stdout:
            yield line
    finally:
        proc.stdout.close()
        proc.wait()


def _writeRegions(outfile, regions, sequence, doU, doCAP):
    """Write one FASTA record to `outfile` for every region of one chromosome.

    `regions` maps a record header (starting with '>') to a
    (chromosome, start, stop, sense) tuple and `sequence` is the full
    sequence of that chromosome.
    """
    for ID in regions:
        (chrom, start, stop, sense) = regions[ID]
        outfile.write(ID + '\n')
        # A negative start (peak - radius near a chromosome start) is clamped
        # to 0 so the slice does not wrap around to the end of the sequence.
        regionSequence = sequence[max(0, start):stop]
        if sense == '-':
            regionSequence = getReverseComplement(regionSequence)
        if doU:
            regionSequence = regionSequence.replace('T','U').replace('t','u')
        if doCAP:
            regionSequence = regionSequence.upper()
        outfile.write(regionSequence + '\n')


def run():
    """Extract the sequences of a set of regions from a FASTA file.

    Command line: fasta inputfilename outputfilename [options].  The regions
    file is tab-delimited; lines starting with '#' are skipped.  Options:

      -chrfield N     index of the chromosome field (default 1); start and
                      stop are read from the two following fields
      -usepeak N|bed  centre the region on the value in field N, or on the
                      midpoint of start/stop when 'bed' is given
      -narrowPeak     chromosome is field 0 and the centre is field 1 plus
                      field 9
      -seqradius bp   half-width of the region around the centre (required
                      with -usepeak and -narrowPeak)
      -strand N       field holding the strand; 'R' or '-' regions are
                      reverse complemented
      -name N[,N...]  build record names from these fields joined by '::'
      -rename string  use `string` instead of the chromosome in record names
      -first N        only use the first N lines of the regions file
      -TtoU           replace T/t with U/u in the output
      -allCAP         upper-case the output

    Exits with status 1 after printing a message when required arguments
    are missing.
    """
    # Three positional arguments are required, so argv must hold at least 4 items.
    if len(sys.argv) < 4:
        print('usage: python %s fasta inputfilename outpurfilename [-seqradius bp] [-usepeak fieldID|bed] [-chrfield fieldID] [-first number] [-name fieldID(s)] [-rename string] [-strand fieldID] [-narrowPeak] [-TtoU] [-allCAP]' % sys.argv[0])
        print('\tby default the forward strand is outputted')
        print('\tdefault sequence names format is chr:start-stop format')
        print('\tthe fasta file can be in .gz or .bz2 format, but it has to end with these suffixes')
        print('\tuse the -name option if you want to change the sequence output format - the content of the fields indicated will be used and the sequence name construct from those separated by "::"')
        sys.exit(1)

    fieldID = 1
    if '-chrfield' in sys.argv:
        fieldID = int(sys.argv[sys.argv.index('-chrfield') + 1])

    doCAP = False
    if '-allCAP' in sys.argv:
        doCAP = True
        print('will capitalized output sequences')

    doReName = False
    ReNameString = None
    if '-rename' in sys.argv:
        doReName = True
        ReNameString = sys.argv[sys.argv.index('-rename') + 1]

    doName = False
    nameFields = []
    if '-name' in sys.argv:
        doName = True
        nameFields = [int(f) for f in sys.argv[sys.argv.index('-name') + 1].split(',')]

    doStrand = False
    strandField = None
    if '-strand' in sys.argv:
        doStrand = True
        strandField = int(sys.argv[sys.argv.index('-strand') + 1])

    doPeak = False
    doBed = False
    peakFieldID = None
    if '-usepeak' in sys.argv:
        doPeak = True
        if sys.argv[sys.argv.index('-usepeak') + 1] == 'bed':
            doBed = True
        else:
            peakFieldID = int(sys.argv[sys.argv.index('-usepeak') + 1])

    radius = None
    if '-seqradius' in sys.argv:
        radius = int(sys.argv[sys.argv.index('-seqradius') + 1])

    doNarrowPeak = False
    if '-narrowPeak' in sys.argv:
        doNarrowPeak = True
        doPeak = True

    # Peak-centred regions cannot be computed without a radius; fail early
    # with a message instead of a NameError on the first data line.
    if doPeak and radius is None:
        print('the -seqradius option is required with -usepeak and -narrowPeak')
        sys.exit(1)

    doU = False
    if '-TtoU' in sys.argv:
        print('will output RNA sequence')
        doU = True

    doFirst = False
    firstN = None
    if '-first' in sys.argv:
        doFirst = True
        firstN = int(sys.argv[sys.argv.index('-first') + 1])

    fasta = sys.argv[1]
    inputfilename = sys.argv[2]
    outfilename = sys.argv[3]

    # chromosome -> {record header -> (chromosome, start, stop, sense)}
    RegionDict = {}

    lineNumber = 0
    for line in _readLines(inputfilename):
        lineNumber += 1
        # Every line (comments included) counts towards the -first limit;
        # nothing after the limit is used, so stop reading.
        if doFirst and lineNumber > firstN:
            break
        if line.startswith('#') or not line.strip():
            continue
        fields = line.strip().split('\t')
        if doNarrowPeak:
            # Field 9 is read below, so at least 10 fields are needed.
            if len(fields) < 10:
                continue
            chrom = fields[0]
            peak = int(fields[1]) + int(fields[9])
            start = peak - radius
            stop = peak + radius
        elif doPeak:
            if doBed:
                peak = int((int(fields[fieldID + 1]) + int(fields[fieldID + 2]))/2.0)
            else:
                peak = int(fields[peakFieldID].strip())
            chrom = fields[fieldID].strip()
            start = peak - radius
            stop = peak + radius
        else:
            chrom = fields[fieldID].strip()
            start = int(fields[fieldID+1].strip())
            stop = int(fields[fieldID+2].strip())

        # Forward strand unless the strand field explicitly says reverse.
        sense = '+'
        if doStrand and fields[strandField] in ('R', '-'):
            sense = '-'

        if doName:
            ID = '>' + '::'.join([fields[f] for f in nameFields])
        elif doReName:
            ID = '>' + ReNameString + ':' + str(start) + '-' + str(stop)
        else:
            ID = '>' + chrom + ':' + str(start) + '-' + str(stop)
        RegionDict.setdefault(chrom, {})[ID] = (chrom, start, stop, sense)

    with open(outfilename, 'w') as outfile:
        currentChrom = None
        keep = False
        sequence = []
        for line in _readLines(fasta):
            if line.startswith('>'):
                # A new record starts: flush the previous chromosome first.
                if keep:
                    _writeRegions(outfile, RegionDict[currentChrom], ''.join(sequence), doU, doCAP)
                currentChrom = line.strip().split('>')[1]
                print(currentChrom)
                sequence = []
                keep = currentChrom in RegionDict
            elif keep:
                # Only chromosomes that have regions are held in memory.
                sequence.append(line.strip())
        # The last record has no following header to trigger the flush.
        if keep:
            _writeRegions(outfile, RegionDict[currentChrom], ''.join(sequence), doU, doCAP)
   
# Only execute when run as a script, so that importing this module does not
# start parsing the importer's command line.
if __name__ == '__main__':
    run()
