##################################
#                                #
# Last modified 03/17/2016       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import os
import subprocess

# Nucleotide letters whose first occurrence marks where real sequence starts.
_BASES = 'ACGT'

# Columns copied from the Ka/Ks calculator report, in output order.
_OUTPUT_COLUMNS = ('Ka', 'Ks', 'Ka/Ks', 'Length', 'S-Sites', 'N-Sites',
                   'Substitutions', 'S-Substitutions', 'N-Substitutions')


def _read_fasta(path):
    """Return {name: upper-cased sequence} for every record in a FASTA file.

    The record name is the text between the first and second '>' of the
    header line. Lines that appear before the first header are ignored; an
    empty file yields an empty dict.
    """
    records = {}
    name = None
    chunks = []
    with open(path) as handle:
        for line in handle:
            if line.startswith('>'):
                if name is not None:
                    records[name] = ''.join(chunks).upper()
                name = line.strip().split('>')[1]
                chunks = []
            elif name is not None:
                chunks.append(line.strip())
    if name is not None:
        records[name] = ''.join(chunks).upper()
    return records


def _remove_if_present(path):
    """Delete path, ignoring the case where it does not exist."""
    try:
        os.remove(path)
    except OSError:
        # Temp files are legitimately absent, e.g. when no pair was processed.
        pass


def _first_base_index(aligned):
    """Return the index of the first A/C/G/T in aligned, or 0 if none."""
    hits = [aligned.find(base) for base in _BASES]
    hits = [hit for hit in hits if hit >= 0]
    return min(hits) if hits else 0


def _align_pair(muscle, outprefix, records, s1, s2):
    """Align two sequences with MUSCLE and return {name: aligned sequence}."""
    fasta_path = outprefix + '.temp.fa'
    aligned_path = outprefix + '.temp.muscle.fa'
    with open(fasta_path, 'w') as handle:
        for name in (s1, s2):
            handle.write('>' + name + '\n')
            handle.write(records[name] + '\n')
    args = [muscle, '-in', fasta_path, '-out', aligned_path]
    print(' '.join(args))
    # Argument list instead of a shell string: paths need no quoting.
    subprocess.call(args)
    aligned = _read_fasta(aligned_path)
    _remove_if_present(fasta_path)
    _remove_if_present(aligned_path)
    return aligned


def _write_axt(path, aligned1, aligned2, start, cut):
    """Write the pair as a three-line record, trimmed at both ends.

    Columns before start and the last cut columns are dropped. The end index
    is computed explicitly because a negative slice bound of -0 would yield
    an empty string when cut is 0.
    """
    with open(path, 'w') as handle:
        handle.write('sequence' + '\n')
        for aligned in (aligned1, aligned2):
            end = max(start, len(aligned) - cut)
            handle.write(aligned[start:end] + '\n')


def _run_kaks(kaks, axt_path, result_path, genetic_code, model, label):
    """Run the Ka/Ks calculator; return (header fields, value fields).

    Both lists are empty when the calculator produced no output file.
    """
    args = [kaks, '-i', axt_path, '-o', result_path,
            '-c', genetic_code, '-m', model]
    print(label + ' '.join(args))
    # Drop any earlier result so a failed run cannot be mistaken for success.
    _remove_if_present(result_path)
    subprocess.call(args)
    if not os.path.exists(result_path):
        return [], []
    with open(result_path) as handle:
        header = handle.readline().strip().split()
        values = handle.readline().strip().split()
    return header, values


def _write_average(outfile, label, values):
    """Write one '#Average <label>: <mean or NA>' summary line."""
    if values:
        mean = str(sum(values) / len(values))
    else:
        mean = 'NA'
    outfile.write('#Average ' + label + ': ' + mean + '\n')


def run():
    """Compute pairwise Ka/Ks for all sequences in a FASTA file.

    Command-line arguments (sys.argv[1:7]):
        KaKs_location    executable invoked with -i/-o/-c/-m
        fasta            input FASTA file with the sequences
        muscle_location  MUSCLE executable invoked with -in/-out
        genetic_code     value passed to the calculator's -c option
        model            value passed to the calculator's -m option
        outfilename      output table; also the prefix of all temp files

    Every unordered pair of sequences (names in sorted order) is aligned
    with MUSCLE, trimmed, and passed to the calculator. One tab-separated
    row per pair is written, followed by the average Ka, Ks and Ka/Ks over
    the pairs where none of the three is 'NA'.

    Exits with status 1 when too few arguments are given or when the
    calculator yields no usable report for a pair after three attempts.
    """
    # Six arguments plus the script name are required.
    if len(sys.argv) < 7:
        print('usage: python %s KaKs_location fasta muscle_location genetic_code model outfilename' % sys.argv[0])
        sys.exit(1)

    kaks, fasta, muscle, genetic_code, model, outprefix = sys.argv[1:7]

    records = _read_fasta(fasta)
    names = sorted(records)

    axt_path = outprefix + '.temp.axt'
    result_path = outprefix + '.temp.KaKs'

    ka_values = []
    ks_values = []
    kaks_values = []

    with open(outprefix, 'w') as outfile:
        outfile.write('#seq1\tseq2\t' + '\t'.join(_OUTPUT_COLUMNS) + '\n')

        for i, s1 in enumerate(names):
            for s2 in names[i + 1:]:
                aligned = _align_pair(muscle, outprefix, records, s1, s2)
                aligned1 = aligned[s1]
                aligned2 = aligned[s2]

                # Columns to drop from the end so that incomplete trailing
                # codons of the ungapped sequences are not scored.
                trailing = max(len(aligned1.replace('-', '')) % 3,
                               len(aligned2.replace('-', '')) % 3)
                print('trailing %d' % trailing)

                # NOTE(review): the previous code derived both offsets from
                # s1 (copy-paste slip); each sequence now supplies its own.
                start = min(_first_base_index(aligned1),
                            _first_base_index(aligned2))

                # Retry with one and then two extra columns trimmed from
                # the end when the calculator rejects the alignment.
                for extra, label in ((0, ''), (1, 'RETRY:  '), (2, 'RETRY2:  ')):
                    _write_axt(axt_path, aligned1, aligned2, start,
                               trailing + extra)
                    header, values = _run_kaks(kaks, axt_path, result_path,
                                               genetic_code, model, label)
                    if len(header) >= 2:
                        break
                else:
                    print('problem with sequences %s %s exiting' % (s1, s2))
                    sys.exit(1)

                columns = [header.index(name) for name in _OUTPUT_COLUMNS]
                row = [s1, s2] + [values[column] for column in columns]
                outfile.write('\t'.join(row) + '\n')

                # row[2:5] holds Ka, Ks, Ka/Ks; skip the pair if any is NA.
                ka, ks, ratio = row[2:5]
                if 'NA' not in (ka, ks, ratio):
                    ka_values.append(float(ka))
                    ks_values.append(float(ks))
                    kaks_values.append(float(ratio))

        _write_average(outfile, 'Ka', ka_values)
        _write_average(outfile, 'Ks', ks_values)
        _write_average(outfile, 'KaKs', kaks_values)

    _remove_if_present(axt_path)
    _remove_if_present(result_path)

# Run the pipeline only when executed as a script, so that importing this
# module does not parse sys.argv or call sys.exit().
if __name__ == '__main__':
    run()

