##################################
#                                #
# Last modified 10/12/2015       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import os

def _read_alignment(path):
    """Parse a FASTA alignment file.

    Returns a tuple (names, seqs): names is the list of record names in
    order of first appearance, and seqs maps each name to its sequence
    with the line breaks removed.  When a name occurs more than once the
    last record wins.

    Raises ValueError if the file holds no records, if sequence text
    precedes the first '>' header, or if the sequences differ in length.
    """
    names = []
    seqs = {}

    def store(name, chunks):
        if name not in seqs:
            names.append(name)
        seqs[name] = ''.join(chunks)

    name = None
    chunks = []
    with open(path) as handle:
        for line in handle:
            if line.startswith('>'):
                if name is not None:
                    store(name, chunks)
                # The record name is the text between the first and the
                # second '>' on the header line.
                name = line.strip().split('>')[1]
                chunks = []
            elif name is None:
                # Blank lines before the first header are harmless; any
                # other content has no record to belong to.
                if line.strip():
                    raise ValueError('sequence data found before the first FASTA header in %s' % path)
            else:
                chunks.append(line.strip())
    if name is not None:
        store(name, chunks)

    if not names:
        raise ValueError('no FASTA records found in %s' % path)

    length = len(seqs[names[0]])
    for name in names:
        if len(seqs[name]) != length:
            raise ValueError('sequence %s has length %d, expected %d; input is not an alignment'
                             % (name, len(seqs[name]), length))

    return names, seqs


def _kept_columns(names, seqs, cutoff):
    """Return the indexes of the columns whose non-gap fraction is >= cutoff."""
    total = len(names)
    kept = []
    # zip(*...) walks the alignment column by column; all rows have the
    # same length, so no column is dropped.
    for i, column in enumerate(zip(*[seqs[name] for name in names])):
        gaps = column.count('-')
        if 1 - float(gaps) / total >= cutoff:
            kept.append(i)
    return kept


def run():
    """Filter alignment columns by coverage and write the result.

    Reads its arguments from sys.argv:
        alignment_fasta     input alignment in FASTA format
        coverage_fraction   minimum fraction of non-gap ('-') characters
                            a column needs in order to be kept
        outfile             path of the file to write
        -phylipOutput       write one 'name<TAB>sequence' line per record
        -nexusOutput TYPE   write a NEXUS data block with datatype=TYPE;
                            takes precedence over -phylipOutput

    Without either option the output is FASTA with each sequence on a
    single line.  Records are written in the order they appear in the
    input.  Prints the usage message and exits with status 1 when required
    arguments are missing.  Raises ValueError on malformed input.
    """
    usage = 'usage: python %s alignment_fasta coverage_fraction outfile [-phylipOutput] [-nexusOutput dna|protein]' % sys.argv[0]

    # Three positional arguments are required, so argv needs >= 4 entries.
    if len(sys.argv) < 4:
        print(usage)
        sys.exit(1)

    fasta = sys.argv[1]
    coverageCutoff = float(sys.argv[2])
    outfilename = sys.argv[3]

    doPO = '-phylipOutput' in sys.argv

    doNO = '-nexusOutput' in sys.argv
    datatype = None
    if doNO:
        position = sys.argv.index('-nexusOutput') + 1
        if position >= len(sys.argv):
            # The option was given without its datatype value.
            print(usage)
            sys.exit(1)
        datatype = sys.argv[position]

    names, seqs = _read_alignment(fasta)
    kept = _kept_columns(names, seqs, coverageCutoff)

    def filtered(name):
        sequence = seqs[name]
        return ''.join([sequence[i] for i in kept])

    with open(outfilename, 'w') as outfile:
        if doNO:
            outfile.write('#NEXUS' + '\n')
            outfile.write('begin data;' + '\n')
            # ntax is the number of sequences, nchar the number of kept columns.
            outfile.write('  dimensions ntax=' + str(len(names)) + ' nchar=' + str(len(kept)) + ';\n')
            outfile.write('  format datatype=' + datatype + ' interleave=no gap=-;\n')
            outfile.write('  matrix' + '\n')
            for name in names:
                outfile.write('  ' + name + '\t' + filtered(name) + '\n')
            outfile.write('  ;\n')
            outfile.write('end;\n')
        elif doPO:
            # NOTE(review): no 'ntaxa nchar' header line is written here;
            # kept as in the original output -- confirm downstream tools
            # accept this layout.
            for name in names:
                outfile.write(name + '\t' + filtered(name) + '\n')
        else:
            for name in names:
                outfile.write('>' + name + '\n')
                outfile.write(filtered(name) + '\n')


# Run only when executed as a script, so that importing this module does
# not parse sys.argv, touch files, or call sys.exit().
if __name__ == '__main__':
    run()

