##################################
#                                #
# Last modified 2018/01/04       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import numpy as np
import matplotlib, copy
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pylab import *
from matplotlib.patches import Circle, RegularPolygon
from matplotlib.collections import PatchCollection
from mpl_toolkits.mplot3d import Axes3D

# Axis letters, in the order the X/Y/Z columns are given on the command line.
_AXES = ('X', 'Y', 'Z')


def _print_usage():
    """Print the command-line help text."""
    print('usage: python %s datafile title labelX fieldX labelY fieldY labelZ fieldZ outfile '
          '[-noR] [-log10 pseudocounts_value] [-log10X pseudocounts_value] '
          '[-log10Y pseudocounts_value] [-log10Z pseudocounts_value] '
          '[-log2 pseudocounts_value] [-log2X pseudocounts_value] '
          '[-log2Y pseudocounts_value] [-log2Z pseudocounts_value] [-sameXYscale] '
          '[-max number] [-maxX number] [-maxY number] [-maxZ number] '
          '[-min number] [-minX number] [-minY number] [-minZ number] '
          '[-color c] [-titlesize size] [-dotsize size] [-resize factor] '
          '[-colorByValue fieldID cmap] [-alpha value] [-saveps]' % sys.argv[0])
    print('       enter labels from the command line with "_" instead of spaces, the code will replace them with spaces')
    print('       the script will ignore header lines in the input file starting with #')
    print('       default color is red')
    print('       default size is 20x20, you can resize by using the -resize option')
    print('       if you use the -max/-min options together with a -log10/-log2 option, the limits will have to be log-transformed')


def _option_value(flag, cast=float, default=None):
    """Return the argument following *flag* in sys.argv, converted by *cast*.

    Returns *default* when *flag* is absent. Matching is exact, so '-max'
    does not match '-maxX'.
    """
    if flag not in sys.argv:
        return default
    return cast(sys.argv[sys.argv.index(flag) + 1])


def _parse_limits(flag, word):
    """Collect display limits from flag+X/Y/Z and the global *flag*.

    *word* ('lower' or 'larger') is only used in the progress messages.
    Returns a dict mapping axis letter -> limit. The global flag (e.g. -max)
    is applied last, so it overrides any per-axis value (e.g. -maxX).
    """
    limits = {}
    for axis in _AXES:
        value = _option_value(flag + axis)
        if value is not None:
            limits[axis] = value
            print('will limit display on the %s axis to numbers %s than %s' % (axis, word, value))
    value = _option_value(flag)
    if value is not None:
        for axis in _AXES:
            limits[axis] = value
        print('will limit display on all axes to numbers %s than %s' % (word, value))
    return limits


def _parse_log_options(labels):
    """Parse the -log10[X|Y|Z] and -log2[X|Y|Z] options.

    Flags are processed in the order -log10X, -log10Y, -log10Z, -log10,
    -log2X, -log2Y, -log2Z, -log2. A later flag overwrites the pseudocount an
    earlier one set for the same axis. Every flag with a non-zero pseudocount
    appends ' + <pseudocount>' to that axis's entry in *labels*, which is
    modified in place.

    Returns (use_log, pseudo):
    - use_log maps a base (10 or 2) to a dict of axis letter -> bool.
    - pseudo maps axis letter -> pseudocount for every axis that received one.
    """
    use_log = {10: dict.fromkeys(_AXES, False), 2: dict.fromkeys(_AXES, False)}
    pseudo = {}
    for base in (10, 2):
        flag = '-log%d' % base
        requests = [(flag + axis, (axis,)) for axis in _AXES] + [(flag, _AXES)]
        for option, axes in requests:
            value = _option_value(option)
            if value is None:
                continue
            for axis in axes:
                use_log[base][axis] = True
                pseudo[axis] = value
                if value != 0:
                    labels[axis] += ' + ' + str(value)
    return use_log, pseudo


def _log_transform(value, pseudo, use_log10, use_log2):
    """Apply the requested log transforms (log10 first, then log2).

    If value + pseudo is not positive, the log is skipped and the value is
    replaced by the pseudocount itself.
    """
    if use_log10:
        value = pseudo if value + pseudo <= 0 else math.log10(value + pseudo)
    if use_log2:
        value = pseudo if value + pseudo <= 0 else math.log(value + pseudo, 2)
    return value


def _read_columns(path, fields, color_field, use_log, pseudo):
    """Read the X/Y/Z (and optional colour) columns from a tab-separated file.

    Lines starting with '#' are ignored. A line whose requested fields are
    missing or non-numeric is reported with 'skipping ...' and ignored.
    Every 100000th line number is printed as a progress indicator.

    Returns (columns, color_values):
    - columns maps axis letter -> list of floats, log-transformed as requested.
    - color_values is the list of floats from *color_field*, or None when
      *color_field* is None.
    """
    columns = dict((axis, []) for axis in _AXES)
    color_values = None if color_field is None else []
    with open(path) as handle:
        for lineno, line in enumerate(handle, 1):
            if lineno % 100000 == 0:
                print(lineno)
            if line.startswith('#'):
                continue
            parts = line.strip().split('\t')
            try:
                raw = dict((axis, float(parts[fields[axis]])) for axis in _AXES)
                if color_field is not None:
                    color_value = float(parts[color_field])
            except (ValueError, IndexError):
                print('skipping %s' % line.strip())
                continue
            for axis in _AXES:
                columns[axis].append(_log_transform(
                    raw[axis], pseudo.get(axis, 0), use_log[10][axis], use_log[2][axis]))
            if color_values is not None:
                color_values.append(color_value)
    return columns, color_values


def _axis_limits(columns, same_scale, min_limit, max_limit):
    """Return a dict axis letter -> [lower, upper] display limits.

    With *same_scale*, every axis spans the range of all the data, padded by
    0.1 on both sides. Otherwise each axis spans its own data range.
    Explicit -min*/-max* values then replace the computed bounds.
    """
    if same_scale:
        lower = min(min(columns[axis]) for axis in _AXES) - 0.1
        upper = max(max(columns[axis]) for axis in _AXES) + 0.1
        limits = dict((axis, [lower, upper]) for axis in _AXES)
    else:
        limits = dict((axis, [min(columns[axis]), max(columns[axis])]) for axis in _AXES)
    for axis, value in max_limit.items():
        limits[axis][1] = value
    for axis, value in min_limit.items():
        limits[axis][0] = value
    return limits


def _tick_labels(ticks, log10, log2):
    """Format tick positions as strings, undoing the axis log transforms.

    The forward transform is log10 then log2, so the inverse applies 2**
    first and then 10**.
    """
    labels = []
    for tick in ticks:
        if log2:
            tick = math.pow(2, tick)
        if log10:
            tick = math.pow(10, tick)
        labels.append(str(tick))
    return labels


def run():
    """Read a tab-separated file and save a 3D scatter plot of three columns.

    Positional arguments (sys.argv[1:10]): datafile title labelX fieldX
    labelY fieldY labelZ fieldZ outfile. Field numbers are 0-based column
    indices. The optional flags are listed in the usage message.
    """
    # Nine positional arguments plus the script name; sys.argv[9] is read below.
    if len(sys.argv) < 10:
        _print_usage()
        sys.exit(1)

    datafile = sys.argv[1]
    title = sys.argv[2].replace('_', ' ')
    labels = {}
    fields = {}
    for position, axis in enumerate(_AXES):
        labels[axis] = sys.argv[3 + 2 * position].replace('_', ' ')
        fields[axis] = int(sys.argv[4 + 2 * position])
    outfilename = sys.argv[9]

    # NOTE: -dotsize is listed in the usage text but is not applied; markers
    # are drawn at matplotlib's default size.
    resize = _option_value('-resize', float, 1)
    alpha = _option_value('-alpha', float, 1)
    titlesize = _option_value('-titlesize', float, 50)
    draw_r = '-noR' not in sys.argv
    max_limit = _parse_limits('-max', 'lower')
    min_limit = _parse_limits('-min', 'larger')
    color = _option_value('-color', str, 'r')

    color_field = None
    cbv_cmap = None
    if '-colorByValue' in sys.argv:
        position = sys.argv.index('-colorByValue')
        color_field = int(sys.argv[position + 1])
        cbv_cmap = sys.argv[position + 2]

    same_scale = '-sameXYscale' in sys.argv
    use_log, pseudo = _parse_log_options(labels)
    save_eps = '-saveps' in sys.argv

    columns, color_values = _read_columns(datafile, fields, color_field, use_log, pseudo)
    if not columns['X']:
        sys.exit('no plottable rows found in %s' % datafile)
    limits = _axis_limits(columns, same_scale, min_limit, max_limit)

    fig = figure(figsize=(20 * resize, 20 * resize))
    # Full-figure 3D axes. Unlike Axes3D(fig), this is still attached to the
    # figure by newer matplotlib releases.
    ax = fig.add_axes([0, 0, 1, 1], projection='3d')

    if color_values is not None:
        ax.scatter(columns['X'], columns['Y'], columns['Z'],
                   c=color_values, cmap=cbv_cmap, alpha=alpha)
    else:
        ax.scatter(columns['X'], columns['Y'], columns['Z'], c=color, alpha=alpha)

    ax.set_title(title, size=titlesize * resize, weight='bold')
    ax.set_xlabel(labels['X'], size=30 * resize, weight='bold')
    ax.set_ylabel(labels['Y'], size=30 * resize, weight='bold')
    ax.set_zlabel(labels['Z'], size=30 * resize, weight='bold')
    ax.set_xlim(*limits['X'])
    ax.set_ylim(*limits['Y'])
    ax.set_zlim(*limits['Z'])

    # Read all tick positions first, then relabel them in un-logged units.
    xticks, yticks, zticks = ax.get_xticks(), ax.get_yticks(), ax.get_zticks()
    ax.set_xticklabels(_tick_labels(xticks, use_log[10]['X'], use_log[2]['X']),
                       size=20 * resize, weight='bold')
    ax.set_yticklabels(_tick_labels(yticks, use_log[10]['Y'], use_log[2]['Y']),
                       size=20 * resize, weight='bold')
    ax.set_zticklabels(_tick_labels(zticks, use_log[10]['Z'], use_log[2]['Z']),
                       size=20 * resize, weight='bold')

    if draw_r:
        # Pearson r between the X and Y columns only (Z is not used).
        # NaN/inf values are replaced by nan_to_num first.
        r = np.corrcoef(np.nan_to_num(columns['X']), np.nan_to_num(columns['Y']))[0, 1]
        rtext = 'r=%.2f' % r
        print(rtext)
        plt.figtext(0.8, 0.12, rtext, size=35 * resize)

    savefig(outfilename)

    if save_eps:
        savefig(outfilename + '.eps', format='eps')
# Only draw the plot when executed as a script, not when imported.
if __name__ == '__main__':
    run()
