##################################
#                                #
# Last modified 07/02/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set
import math
from scipy import linspace, polyval, polyfit, sqrt, stats, randn
import numpy as np

def run():

    if len(sys.argv) < 7:
        print 'usage: python %s spikeInput spikeRPKM spikeFieldID spikeFPKMFieLdID FPKM_file FPKM_fieldID(s) outfilename' % sys.argv[0]
        print '\tNote: FPKM_fieldID(s) should be comma-separated'
        print '\tNote: If there are spikes which you wish to exclude, remove them from the concentrations file; the script will only look at those that are present in it'
        print '\tNote: spikeInput file format: <spike name> \t <transcript copies>; spike names should match between the spikeInput and spikeRPKM files'
        sys.exit(1)

    spikeInput = sys.argv[1]
    spikeRPKM = sys.argv[2]
    spikeID = int(sys.argv[3])
    spikeFPKMID = int(sys.argv[4])
    FPKM = sys.argv[5]
    fields = sys.argv[6].split(',')
    FPKMfields = []
    for ID in fields:
        FPKMfields.append(int(ID))
    outfilename = sys.argv[7]

    SpikeDict = {}

    lineslist = open(spikeInput)
    for line in lineslist:
        fields = line.strip().split('\t')
        spike = fields[0]
        copies = float(fields[1])
        SpikeDict[spike]={}
        SpikeDict[spike]['copies'] = copies

    lineslist = open(spikeRPKM)
    for line in lineslist:
        if line.startswith('#') or line.startswith('tracking_id'):
            continue
        fields = line.strip().split('\t')
        spike = fields[spikeID]
        RPKM = float(fields[spikeFPKMID])
        if SpikeDict.has_key(spike):
            SpikeDict[spike]['RPKM'] = RPKM

    RPKM = []
    copies = []
    logRPKM = []
    logcopies = []
    for spike in SpikeDict.keys():
        RPKM.append(SpikeDict[spike]['RPKM'])
        copies.append(SpikeDict[spike]['copies'])
        logRPKM.append(math.log(SpikeDict[spike]['RPKM']+1))
        logcopies.append(math.log(SpikeDict[spike]['copies']+1))


    copies = np.array(copies)
    RPKM = np.array(RPKM)
    RPKM = RPKM[:,np.newaxis]
    a,_,_,_ = np.linalg.lstsq(RPKM,copies)

    print a
    print RPKM
    print a*RPKM
    print copies

    logcopies = np.array(logcopies)
    logRPKM = np.array(logRPKM)
    logRPKM = logRPKM[:,np.newaxis]
    loga,_,_,_ = np.linalg.lstsq(logRPKM,logcopies)

#    print loga
#    print logRPKM
#    print math.exp(loga*logRPKM), loga*logRPKM
#    print math.exp(logcopies), logcopies


    outfile = open(outfilename, 'w')

    lineslist = open(FPKM)
    for line in lineslist:
        fields = line.strip().split('\t')
        if line.startswith('#') or line.startswith('tracking_id'):
            outline = line.strip()
            for ID in FPKMfields:
                outline = outline + '\t' + fields[ID] + '_converted_to_copies'
            for ID in FPKMfields:
                outline = outline + '\t' + fields[ID] + '_converted_to_copies_log'
            outfile.write(outline + '\n')
            continue
        outline = line.strip()
        for ID in FPKMfields:
            FPKM = float(fields[ID])
            copies = a*FPKM
            outline = outline + '\t' + str(copies[0])
        for ID in FPKMfields:
            logFPKM = math.log(float(fields[ID])+1)
            logcopies = loga*logFPKM
            outline = outline + '\t' + str(math.exp(logcopies[0])-1)
        outfile.write(outline + '\n')

# from scipy import linspace, polyval, polyfit, sqrt, stats, randn
# (a,b)=polyfit(RPKM,concentrations,1)
# copies = polyval([a,b],RPKM)
# err = sqrt(sum((concentrationsR-concentrations)**2)/len(concentrationsR))

# Run only when executed as a script, not when the module is imported.
if __name__ == '__main__':
    run()
