##################################
#                                #
# Last modified 2023/04/12       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import numpy as np
import matplotlib, copy
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pylab import *
from matplotlib.patches import Circle, RegularPolygon
from matplotlib.collections import PatchCollection

def _arg_after(flag, cast=str, offset=1):
    """Return the token `offset` places after `flag` in sys.argv, converted by `cast`.

    The caller must already know that `flag` is present. Raises IndexError when
    the value is missing and ValueError when `cast` rejects it.
    """
    return cast(sys.argv[sys.argv.index(flag) + offset])


def _with_pseudocount(label, pseudo):
    """Return `label` with ' + <pseudo>' appended, or unchanged when `pseudo` is 0."""
    if pseudo == 0:
        return label
    return label + ' + ' + str(pseudo)


def _log_transform(value, pseudo, bases):
    """Log-transform `value + pseudo` for each requested base (10 first, then 2).

    `bases` is a set that may contain 10 and/or 2. When `value + pseudo` is not
    positive the logarithm is undefined; as before, the value is then replaced
    by `pseudo` itself.
    """
    if 10 in bases:
        value = pseudo if value + pseudo <= 0 else math.log10(value + pseudo)
    if 2 in bases:
        value = pseudo if value + pseudo <= 0 else math.log(value + pseudo, 2)
    return value


def _rank_axis(n, pseudo, bases):
    """Return the 1-based ranks 1..n, as log10/log2 of (rank + pseudo) if requested.

    Ranks start at 1 (as in the untransformed case) so that a pseudocount of 0
    is valid together with a log transform.
    """
    if 10 in bases:
        return [math.log10(rank + pseudo) for rank in range(1, n + 1)]
    if 2 in bases:
        return [math.log(rank + pseudo, 2) for rank in range(1, n + 1)]
    return list(range(1, n + 1))


def _tick_labels(ticks, bases):
    """Build tick label strings, mapping t to 10**t and/or 2**t on log axes."""
    labels = []
    for tick in ticks:
        if 10 in bases:
            tick = math.pow(10, tick)
        if 2 in bases:
            tick = math.pow(2, tick)
        labels.append(str(tick))
    return labels


def _descending_order(values):
    """Return the indices that sort `values` in descending order (stable)."""
    return sorted(range(len(values)), key=values.__getitem__, reverse=True)


def run():
    """Scatter-plot two tab-separated columns of a data file and save the figure.

    Usage (the full option list is printed when arguments are missing):
        python script.py datafile title labelX fieldX labelY fieldY outfile [options]

    fieldX, fieldY and the -colorByValue field are 0-based column indices into
    each tab-split line. Lines starting with '#' are ignored, rows whose X or Y
    value is NaN are dropped silently, and rows with missing or non-numeric
    columns are reported and skipped. The figure is saved to `outfile` (and to
    `outfile + '.eps'` with -saveps).

    Exits with status 1 when too few arguments are given, when the
    conflicting options -rankX and -rankY are both used, and when the file
    yields no plottable rows.
    """
    if len(sys.argv) < 8:  # script name + 7 positional arguments
        print('usage: python %s datafile title labelX fieldX labelY fieldY outfile [-noR] [-log10 pseudocounts_value] [-log10X pseudocounts_value] [-log10Y pseudocounts_value] [-log2 pseudocounts_value] [-log2X pseudocounts_value] [-log2Y pseudocounts_value] [-sameXYscale] [-max number] [-maxX number] [-maxY number] [-min number] [-minX number] [-minY number] [-color c] [-titlesize size] [-dotsize size] [-resize factor] [-rankX] [-rankY] [-colorByValue fieldID cmap] [-alpha value] [-saveps]' % sys.argv[0])
        print('       enter labels from the command line with "_" instead of spaces, the code will replace them with spaces')
        print('       the script will ignore header lines in the input file starting with #')
        print('       default color is red')
        print('       default size is 20x20, you can resize by using the -resize option')
        print('       if you use the -max option together with the -log10 option, the -max number will have to be a log10 transform')
        sys.exit(1)

    datafile = sys.argv[1]
    title = sys.argv[2].replace('_', ' ')
    label1 = sys.argv[3].replace('_', ' ')
    field1 = int(sys.argv[4])
    label2 = sys.argv[5].replace('_', ' ')
    field2 = int(sys.argv[6])
    outfilename = sys.argv[7]

    doRankX = '-rankX' in sys.argv
    if doRankX:
        print('will use the rank of the values on the X axis for the Y axis')
    doRankY = '-rankY' in sys.argv
    if doRankY:
        print('will use the rank of the values on the Y axis for the X axis')
    if doRankX and doRankY:
        # The message has always said "exiting"; now the script actually does.
        print('options collision --- [-rankX] and [-rankY], exiting')
        sys.exit(1)

    dotsize = _arg_after('-dotsize', int) if '-dotsize' in sys.argv else 20
    resize = _arg_after('-resize', float) if '-resize' in sys.argv else 1
    ALPHA = _arg_after('-alpha', float) if '-alpha' in sys.argv else 1
    titlesize = _arg_after('-titlesize', float) if '-titlesize' in sys.argv else 50
    doR = '-noR' not in sys.argv
    doSameScale = '-sameXYscale' in sys.argv
    doPostScript = '-saveps' in sys.argv
    color = _arg_after('-color') if '-color' in sys.argv else 'r'

    # Explicit axis limits; None means "use the data range". -max/-min are
    # handled after the per-axis flags, so they override them.
    MaxNumberX = MaxNumberY = MinNumberX = MinNumberY = None
    if '-maxX' in sys.argv:
        MaxNumberX = _arg_after('-maxX', float)
        print('will limit display on the X axis to numbers lower than %s' % MaxNumberX)
    if '-maxY' in sys.argv:
        MaxNumberY = _arg_after('-maxY', float)
        print('will limit display on the Y axis to numbers lower than %s' % MaxNumberY)
    if '-max' in sys.argv:
        MaxNumberX = MaxNumberY = _arg_after('-max', float)
        print('will limit display on both axes to numbers lower than %s' % MaxNumberX)
    if '-minX' in sys.argv:
        MinNumberX = _arg_after('-minX', float)
        print('will limit display on the X axis to numbers larger than %s' % MinNumberX)
    if '-minY' in sys.argv:
        MinNumberY = _arg_after('-minY', float)
        print('will limit display on the Y axis to numbers larger than %s' % MinNumberY)
    if '-min' in sys.argv:
        MinNumberX = MinNumberY = _arg_after('-min', float)
        print('will limit display on both axes to numbers larger than %s' % MinNumberX)

    doCBV = '-colorByValue' in sys.argv
    if doCBV:
        CBVFieldID = _arg_after('-colorByValue', int)
        CBVcmap = _arg_after('-colorByValue', str, 2)

    # Log-transform options. Each matching flag records its base for the
    # affected axis, overrides that axis' pseudocount, and appends the
    # pseudocount to the axis label, in this fixed order.
    logX, logY = set(), set()
    pseudoX = pseudoY = 0.0
    for flag, base, axes in (('-log10X', 10, 'X'), ('-log10Y', 10, 'Y'),
                             ('-log10', 10, 'XY'), ('-log2X', 2, 'X'),
                             ('-log2Y', 2, 'Y'), ('-log2', 2, 'XY')):
        if flag not in sys.argv:
            continue
        pseudo = _arg_after(flag, float)
        if 'X' in axes:
            logX.add(base)
            pseudoX = pseudo
            label1 = _with_pseudocount(label1, pseudo)
        if 'Y' in axes:
            logY.add(base)
            pseudoY = pseudo
            label2 = _with_pseudocount(label2, pseudo)

    A, B, C = [], [], []  # X values, Y values, colour values (-colorByValue only)
    with open(datafile) as lineslist:
        for lineno, line in enumerate(lineslist, 1):
            if lineno % 100000 == 0:
                print(lineno)  # progress indicator for large files
            if line.startswith('#'):
                continue
            # Drop only the line terminator: strip() would also remove a
            # leading tab and shift every column on rows with an empty first field.
            fields = line.rstrip('\r\n').split('\t')
            try:
                score1 = float(fields[field1])
                score2 = float(fields[field2])
                if math.isnan(score1) or math.isnan(score2):
                    continue  # NaN rows are dropped silently
                if doCBV:
                    scoreC = float(fields[CBVFieldID])
            except (ValueError, IndexError):
                # Non-numeric value or too few columns (e.g. a blank line).
                print('skipping %s' % line.strip())
                continue
            A.append(_log_transform(score1, pseudoX, logX))
            B.append(_log_transform(score2, pseudoY, logY))
            if doCBV:
                C.append(scoreC)

    if not A:
        print('no plottable data found in %s, exiting' % datafile)
        sys.exit(1)

    if doRankX:
        order = _descending_order(A)
        A = [A[k] for k in order]
        if doCBV:
            C = [C[k] for k in order]  # keep each colour with its point
        B = _rank_axis(len(A), pseudoY, logY)

    if doRankY:
        order = _descending_order(B)
        B = [B[k] for k in order]
        if doCBV:
            C = [C[k] for k in order]  # keep each colour with its point
        A = _rank_axis(len(B), pseudoX, logX)

    fig = plt.figure(figsize=(20 * resize, 20 * resize))
    # One explicitly placed axes. The old code also called add_subplot(), which
    # created a second, unused axes that was still drawn under this one.
    ax = fig.add_axes((0.10, 0.10, 0.8, 0.8))

    if doSameScale:
        # Same range on both axes, padded by 0.1 on each side.
        low = min(min(A), min(B)) - 0.1
        high = max(max(A), max(B)) + 0.1
        lowerlimitX = lowerlimitY = low
        upperlimitX = upperlimitY = high
    else:
        lowerlimitX, upperlimitX = min(A), max(A)
        lowerlimitY, upperlimitY = min(B), max(B)

    if MaxNumberX is not None:
        upperlimitX = MaxNumberX
    if MaxNumberY is not None:
        upperlimitY = MaxNumberY
    if MinNumberX is not None:
        lowerlimitX = MinNumberX
    if MinNumberY is not None:
        lowerlimitY = MinNumberY

    if doCBV:
        ax.scatter(A, B, c=C, s=dotsize * resize, cmap=CBVcmap, alpha=ALPHA)
    else:
        ax.scatter(A, B, c=color, s=dotsize * resize, alpha=ALPHA)
    ax.set_title(title, size=titlesize * resize, weight='bold')
    ax.set_xlabel(label1, size=30 * resize, weight='bold')
    ax.set_ylabel(label2, size=30 * resize, weight='bold')
    ax.set_xlim(lowerlimitX, upperlimitX)
    ax.set_ylim(lowerlimitY, upperlimitY)

    # On log axes, label ticks with the back-transformed values (10**t / 2**t).
    ax.set_xticklabels(_tick_labels(ax.get_xticks(), logX), size=20 * resize, weight='bold')
    ax.set_yticklabels(_tick_labels(ax.get_yticks(), logY), size=20 * resize, weight='bold')

    if doR:
        r = np.corrcoef(np.nan_to_num(A), np.nan_to_num(B))[0, 1]
        # Round to two decimals; slicing str(r) truncated and broke on negatives.
        rtext = 'r=%.2f' % r
        print(rtext)
        fig.text(0.8, 0.12, rtext, size=35 * resize)

    fig.savefig(outfilename)
    if doPostScript:
        fig.savefig(outfilename + '.eps', format='eps')
    plt.close(fig)
   
# Run only when executed as a script, so the module can be imported safely.
if __name__ == '__main__':
    run()
