##################################
#                                #
# Last modified 10/27/2015       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set
import matplotlib
matplotlib.use('Agg')
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.path as mpath
import matplotlib.lines as mlines
import matplotlib.patches as mpatches
from matplotlib.collections import PatchCollection

def label(xy, text):
    """Write *text* centred horizontally just above the point *xy*.

    The text is placed 0.05 data units above ``xy[1]`` on the current
    pyplot axes, in a size-14 sans-serif font.
    """
    x_pos = xy[0]
    y_pos = xy[1] + 0.05
    plt.text(x_pos, y_pos, text, ha="center", family='sans-serif', size=14)

def _parse_config(path):
    """Return a dict mapping each GTF path in the config file to its colour.

    Lines starting with '#' and blank lines are skipped. Every other line is
    tab-separated: column 1 is the GTF path, column 2 the colour used as the
    gene boxes' facecolor.
    """
    colors = {}
    with open(path) as handle:
        for line in handle:
            if line.startswith('#') or line.strip() == '':
                continue
            fields = line.strip().split('\t')
            colors[fields[0]] = fields[1]
    return colors


def _parse_fasta(path):
    """Read a FASTA file.

    Returns ``(genome, last_name)``: ``genome`` maps each record name (the
    header text after '>') to its concatenated sequence, and ``last_name`` is
    the name of the final record (None if the file has no header). Lines
    before the first header are ignored.
    """
    genome = {}
    name = None
    chunks = []
    with open(path) as handle:
        for line in handle:
            if line.startswith('>'):
                if name is not None:
                    genome[name] = ''.join(chunks)
                name = line.strip().split('>')[1]
                chunks = []
            elif name is not None:
                chunks.append(line.strip())
    if name is not None:
        genome[name] = ''.join(chunks)
    return genome, name


def _gtf_attribute(attributes, key):
    """Return the quoted value of ``key "value"`` in a GTF attribute column, or None."""
    marker = key + ' "'
    if marker not in attributes:
        return None
    return attributes.split(marker)[1].split('"')[0]


def _parse_genes(colors, genome, reverse_strand):
    """Collect colour, strand and coordinates per gene from every GTF file.

    Returns ``(genes, last_chrom)``. ``genes`` maps ``(gene_id, gene_name)``
    to a dict with these keys:

    - 'c': colour of the GTF the gene was first seen in.
    - 'strand': strand of its first feature.
    - 'coordinates': every feature start and end.

    ``last_chrom`` is the chromosome of the last feature line read (None if
    none). When ``reverse_strand`` is true, strands are flipped and
    coordinates mirrored against the chromosome length taken from
    ``genome``. Features whose strand is neither '+' nor '-' are skipped
    with a warning on stderr.

    Raises ValueError for a feature line without a gene_id attribute.
    """
    flip = {'+': '-', '-': '+'}
    genes = {}
    last_chrom = None
    skipped = 0
    for gtf, color in colors.items():
        with open(gtf) as handle:
            for line in handle:
                if line.startswith('#') or line.strip() == '':
                    continue
                fields = line.strip().split('\t')
                chrom = fields[0]
                last_chrom = chrom
                strand = fields[6]
                if strand not in flip:
                    # Unstranded features ('.') cannot be placed above or
                    # below the axis.
                    skipped += 1
                    continue
                left = int(fields[3])
                right = int(fields[4])
                if reverse_strand:
                    # Mirror onto the opposite strand. left/right swap order
                    # here; the later min()/max() over coordinates undoes it.
                    strand = flip[strand]
                    length = len(genome[chrom])
                    left = length - left
                    right = length - right
                gene_id = _gtf_attribute(fields[8], 'gene_id')
                if gene_id is None:
                    raise ValueError('%s: feature without gene_id: %s'
                                     % (gtf, line.strip()))
                gene_name = _gtf_attribute(fields[8], 'gene_name')
                if gene_name is None:
                    gene_name = gene_id
                key = (gene_id, gene_name)
                if key not in genes:
                    genes[key] = {'c': color, 'strand': strand,
                                  'coordinates': []}
                genes[key]['coordinates'].extend((left, right))
    if skipped:
        sys.stderr.write('warning: skipped %d feature(s) without a +/- strand\n'
                         % skipped)
    return genes, last_chrom


def _assign_layers(gene_list):
    """Pack genes into numbered layers so no two genes on the same strand
    and layer overlap.

    ``gene_list`` holds ``(strand, left, right, gene_id, gene_name)`` tuples
    sorted ascending, with intervals treated as half-open ``[left, right)``.
    Each gene goes on the lowest layer (starting at 1) whose last gene ends
    at or before ``left``. Genes arrive in order of ``left`` within a
    strand, so for non-empty intervals this uses as many layers as the
    largest number of genes overlapping one position.

    Returns ``(layers, max_layer)``. ``layers`` maps ``(gene_id, gene_name)``
    to its layer; ``max_layer`` is the highest layer used (0 for no genes).
    """
    layer_ends = {}  # strand -> list; entry k = right end of last gene on layer k+1
    layers = {}
    max_layer = 0
    for strand, left, right, gene_id, gene_name in gene_list:
        ends = layer_ends.setdefault(strand, [])
        free = [k for k, end in enumerate(ends) if end <= left]
        if free:
            index = free[0]
            ends[index] = right
        else:
            ends.append(right)
            index = len(ends) - 1
        layer = index + 1
        layers[(gene_id, gene_name)] = layer
        max_layer = max(max_layer, layer)
    return layers, max_layer


def run():
    """Command-line entry point: draw a linear gene map of one chromosome.

    Arguments: config fasta bar_height figure_length figure_width_inches
    font_size outputfileprefix [-reverseStrand] [-layers N]

    Each gene spans the min..max of all its GTF feature coordinates and is
    drawn as a coloured box. '+' strand genes go above a central line and
    '-' strand genes below it. Boxes are stacked into layers so overlapping
    genes on one strand do not collide, and gene names are written
    vertically beyond the outermost layer. ``-layers N`` overrides the
    number of layers reserved on each side.

    Writes <outputfileprefix>.png (200 dpi) and <outputfileprefix>.eps.
    """
    if len(sys.argv) < 8:
        print('usage: python %s config fasta bar_height figure_length figure_width_inches font_size outputfileprefix [-reverseStrand] [-layers N]' % sys.argv[0])
        print('\tconfig file format')
        print('\t#GTF\tcolor')
        print('\tOnly one chromosome per GTF file is assumed')
        print('\tSplit the general GTF file into subsets if you want them colored differently, e.g. protein coding genes, rRNAs, etc.')
        print('\tbar height and figure length in base pair units (i.e. relative to the length of the chromosome)')
        print('\tNote: as of now it does not plot introns!!!!')
        sys.exit(1)

    config = sys.argv[1]
    fasta = sys.argv[2]
    bar_height = float(sys.argv[3])
    fig_length = float(sys.argv[4])
    fig_width_in = float(sys.argv[5])
    font_size = int(sys.argv[6])
    outfileprefix = sys.argv[7]
    reverse_strand = '-reverseStrand' in sys.argv

    colors = _parse_config(config)
    genome, last_record = _parse_fasta(fasta)
    genes, last_chrom = _parse_genes(colors, genome, reverse_strand)

    print(list(genome.keys()))

    # The plotted chromosome is the one named on the last GTF feature read,
    # falling back to the last FASTA record when no feature was read.
    chrom = last_chrom if last_chrom is not None else last_record
    total_len = len(genome[chrom]) + 0.0

    print('plotting genes')

    gene_list = sorted(
        (info['strand'], min(info['coordinates']), max(info['coordinates']),
         gene_id, gene_name)
        for (gene_id, gene_name), info in genes.items())
    layers, max_layer = _assign_layers(gene_list)
    for strand, left, right, gene_id, gene_name in gene_list:
        print('%s %s %s %s %s' % (gene_name, layers[(gene_id, gene_name)],
                                  left, right, strand))

    if '-layers' in sys.argv:
        max_layer = int(sys.argv[sys.argv.index('-layers') + 1])

    bar = (bar_height / fig_length) * fig_width_in  # one layer's height, inches
    # max_layer bars on each side of the axis, plus room for the labels.
    height = (3 * max_layer + 3) * bar
    mid = height / 2.0  # y of the chromosome axis line

    fig = plt.figure(figsize=(fig_width_in, height))
    ax = fig.add_subplot(111, aspect='equal')

    # Fractions of the figure width where the chromosome starts and ends
    # (centred, since fig_length may exceed the chromosome length).
    leftpos = 0.5 * (fig_length - total_len) / fig_length
    rightpos = 1 - 0.5 * (fig_length - total_len) / fig_length
    print('%s %s %s %s %s %s' % (total_len, fig_length, fig_width_in, leftpos,
                                 rightpos, (rightpos - leftpos) * fig_width_in))

    # White backdrop over the whole canvas so the axes span the full figure.
    ax.add_patch(mpatches.Rectangle((0, 0), fig_width_in, height,
                                    fill=True, color='white'))
    # A zero-height outlined rectangle renders as the chromosome axis line.
    ax.add_patch(mpatches.Rectangle((leftpos * fig_width_in, mid),
                                    (rightpos - leftpos) * fig_width_in, 0,
                                    fill=None, color='black'))

    print(max_layer)
    print('%s %s %s %s' % (leftpos * fig_width_in, mid,
                           (rightpos - leftpos) * fig_width_in, 0))

    for strand, left, right, gene_id, gene_name in gene_list:
        color = genes[(gene_id, gene_name)]['c']
        layer = layers[(gene_id, gene_name)]
        x0 = (leftpos + left / fig_length) * fig_width_in
        x1 = (leftpos + right / fig_length) * fig_width_in
        if strand == '+':
            # '+' genes stack upwards from the axis; labels sit above them.
            y0, box_h = mid + (layer - 1) * bar, bar
            label_y, valign = mid + (max_layer + 0.2) * bar, 'baseline'
        else:
            # '-' genes stack downwards (negative height); labels hang below.
            y0, box_h = mid - (layer - 1) * bar, -bar
            label_y, valign = mid - (max_layer + 0.2) * bar, 'top'
        ax.add_patch(mpatches.Rectangle((x0, y0), x1 - x0, box_h,
                                        facecolor=color, linestyle='solid',
                                        edgecolor='black', lw=0.5))
        ax.text((x0 + x1) / 2, label_y, gene_name, fontsize=font_size,
                rotation=90, ha='center', va=valign)

    plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
    plt.axis('equal')
    plt.axis('off')
    plt.savefig(outfileprefix + '.png', dpi=200)
    plt.savefig(outfileprefix + '.eps', format='eps')
    plt.close(fig)

#
#    print 'plotting domains'
#
#    Dlist = list(Set(Dlist))
#
#    fig = plt.figure()
#    fig.set_size_inches(FW,BH*len(Dlist))
#    ax = fig.add_subplot(111)
#
#    step = 1./(len(Dlist) + 1)
#    pos = 1 - 0.5*step
#    for (domain,c) in Dlist:
#        rect = matplotlib.patches.Rectangle((0.1,pos), 0.5, 0.66*step, color=c, linestyle = 'solid', edgecolor = 'black')
#        ax.add_patch(rect)
#        rect = matplotlib.patches.Rectangle((0.1,pos), 0.5, 0.66*step, fill=None, color = 'black')
#        ax.add_patch(rect)
#        plt.text(0.625, pos + 0.3*step, domain, fontsize=BH*12)
#        pos -= step
#
#    plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
#    plt.axis('equal')
#    plt.axis('off')
#    plt.savefig(outfileprefix + '.domains.eps', format='eps')
    
# Only run when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()
