##################################
#                                #
# Last modified 05/07/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from cistematic.core.geneinfo import geneinfoDB
from cistematic.genomes import Genome

def getSequence(genome,chromosome,start,stop,sense):
    """Return the sequence of a genomic interval on the requested strand.

    genome     -- genome identifier passed to cistematic's Genome()
    chromosome -- chromosome name; a leading 'chr' prefix is removed before
                  the name is handed to Genome.sequence()
    start,stop -- interval coordinates; stop-start bases are requested
                  beginning at start
    sense      -- 'F' returns the sequence as stored, 'R' returns its
                  reverse complement

    Raises ValueError if sense is neither 'F' nor 'R', and KeyError if a
    reverse-complemented sequence contains a character other than
    A/C/G/T/N (either case).
    """
    # Reject bad strand codes up front instead of failing later on an
    # unassigned local variable.
    if sense not in ('F', 'R'):
        raise ValueError("sense must be 'F' or 'R', got %r" % (sense,))

    hg = Genome(genome)

    # Only drop the prefix when it is really there, so that names without
    # a 'chr' prefix are passed through unchanged.
    if chromosome.startswith('chr'):
        chromosome = chromosome[3:]

    sequence = hg.sequence(chromosome,start,stop-start)
    if sense == 'R':
        complement = {'A':'T','T':'A','G':'C','C':'G','N':'N',
                      'a':'t','t':'a','g':'c','c':'g','n':'n'}
        # Walk the forward sequence backwards, complementing each base;
        # join once rather than growing a string base by base.
        sequence = ''.join([complement[base] for base in reversed(sequence)])

    return sequence

def _intOption(argv, flag, default=None):
    """Return the integer that follows flag in argv, or default if flag is absent."""
    if flag not in argv:
        return default
    return int(argv[argv.index(flag) + 1])

def run():
    """Write a FASTA file with one sequence per region listed in the input file.

    Command line (read from sys.argv):
        genome inputfilename outputfilename [options]

    Options (all field IDs are zero-based indexes into the tab-separated line):
        -chrfield fieldID  column holding the chromosome; start and stop are
                           read from the next two columns (default 1)
        -usepeak fieldID   column holding a peak position; the region becomes
                           peak +/- the -seqradius value
        -narrowPeak        chromosome from column 0, peak position computed as
                           column 1 + column 9 (presumably the narrowPeak
                           start plus summit offset); region is peak +/- radius
        -seqradius bp      radius used by -usepeak and -narrowPeak
        -first number      only the first number input lines are read
                           (comment and skipped lines count towards the limit)
        -name fieldID      column used as the FASTA header instead of
                           chr:start-stop
        -strand fieldID    column holding the strand; 'R' or '-' gives the
                           reverse complement, anything else the forward strand

    Lines starting with '#', lines containing 'random' and blank lines are
    skipped. Exits with status 1 when too few arguments are given or when a
    peak mode is requested without -seqradius.
    """
    argv = sys.argv

    # Three positional arguments are required after the script name.
    if len(argv) < 4:
        print('usage: python %s  genome inputfilename outpurfilename [-seqradius bp] [-usepeak fieldID] [-chrfield fieldID] [-first number] [-name fieldID] [-strand fieldID] [-narrowPeak]' % argv[0])
        print('       default sequence names format is chr:start-stop format')
        print('       by default the forward strand is outputted')
        sys.exit(1)

    genome = argv[1]
    inputfilename = argv[2]
    outfilename = argv[3]

    fieldID = _intOption(argv, '-chrfield', 1)
    nameField = _intOption(argv, '-name')
    strandField = _intOption(argv, '-strand')
    peakFieldID = _intOption(argv, '-usepeak')
    radius = _intOption(argv, '-seqradius')
    top = _intOption(argv, '-first')
    doNarrowPeak = '-narrowPeak' in argv
    doPeak = doNarrowPeak or peakFieldID is not None

    # Both peak modes need a radius; fail before any file is touched.
    if doPeak and radius is None:
        sys.exit('error: -usepeak and -narrowPeak require -seqradius bp')

    with open(inputfilename) as inputdatafile:
        with open(outfilename, 'w') as outfile:
            linesRead = 0
            for line in inputdatafile:
                linesRead += 1
                # Stop reading once the requested number of lines is consumed.
                if top is not None and linesRead > top:
                    break
                if line.startswith('#') or 'random' in line:
                    continue
                if not line.strip():
                    continue
                fields = line.strip().split('\t')

                if doNarrowPeak:
                    chromosome = fields[0]
                    peak = int(fields[1]) + int(fields[9])
                    start = peak-radius
                    stop = peak+radius
                elif doPeak:
                    # Honour -chrfield here as well (its default of 1 matches
                    # the previously hard-coded column).
                    chromosome = fields[fieldID].strip()
                    peak = int(fields[peakFieldID].strip())
                    start = peak-radius
                    stop = peak+radius
                else:
                    chromosome = fields[fieldID].strip()
                    start = int(fields[fieldID+1].strip())
                    stop = int(fields[fieldID+2].strip())

                # Forward strand unless the strand column explicitly says reverse,
                # so an unrecognised value never reuses the previous row's strand.
                sense = 'F'
                if strandField is not None and fields[strandField].strip() in ('R', '-'):
                    sense = 'R'

                sequence = getSequence(genome,chromosome,start,stop,sense)

                if nameField is not None:
                    outfile.write('>'+fields[nameField]+'\n')
                else:
                    outfile.write('>'+chromosome+':'+str(start)+'-'+str(stop)+'\n')
                outfile.write(sequence)
                outfile.write('\n')
   
# Script entry point: run() is called unconditionally when this file is
# executed (there is no __main__ guard, so importing the module also runs it).
run()
