##################################
#                                #
# Last modified 2019/07/12       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys

def _usage_exit(prog):
    """Print the command-line usage message and exit with status 1."""
    print('usage: python %s contig_fasta output [-minContigSize bp]' % prog)
    sys.exit(1)


def _iter_fasta_sequences(fasta):
    """Yield the upper-cased sequence of every record in the FASTA file `fasta`.

    Sequence lines are stripped and concatenated, so multi-line records are
    supported; a header immediately followed by another header yields ''.
    Raises ValueError if non-blank text appears before the first '>' header.
    """
    chunks = None  # stays None until the first '>' header has been seen
    with open(fasta) as handle:
        for line in handle:
            if line.startswith('>'):
                if chunks is not None:
                    yield ''.join(chunks).upper()
                chunks = []
            elif chunks is not None:
                chunks.append(line.strip())
            elif line.strip():
                raise ValueError("%s: sequence data found before the first '>' header" % fasta)
    if chunks is not None:
        yield ''.join(chunks).upper()


def _collect_contigs(fasta, min_contig_size):
    """Return (lengths, acgt_bases) for the contigs of `fasta` that are kept.

    A contig is dropped when `min_contig_size` is not None and its length is
    below it. `lengths` is sorted longest first; `acgt_bases` is the number of
    A/C/G/T characters (case-insensitive) summed over the kept contigs.
    Every record is counted, even if two records share the same name.
    """
    lengths = []
    acgt_bases = 0
    for seq in _iter_fasta_sequences(fasta):
        if min_contig_size is not None and len(seq) < min_contig_size:
            continue
        lengths.append(len(seq))
        acgt_bases += sum(seq.count(base) for base in 'ACGT')
    lengths.sort(reverse=True)
    return lengths, acgt_bases


def _nx(lengths_desc, total, fraction):
    """Return the first length at which the running sum reaches fraction * total.

    `lengths_desc` must be sorted longest first (fraction=0.5 gives N50,
    fraction=0.9 gives N90). Returns None if the threshold is never reached.
    """
    running = 0
    for length in lengths_desc:
        running += length
        if running >= fraction * total:
            return length
    return None


def run():
    """Compute N50, N90 and the non-ACGT fraction of a contig FASTA file.

    Command line: ``contig_fasta output [-minContigSize bp]``. Each statistic
    is printed to stdout and written to `output` as a ``label<TAB>value``
    line. With -minContigSize, contigs shorter than `bp` are excluded from
    all three statistics.

    Exits with status 1 and the usage message when arguments are missing, and
    with an error message when the FASTA is malformed or no sequence remains.
    """
    argv = sys.argv
    if len(argv) < 3:  # both contig_fasta and output are required
        _usage_exit(argv[0])

    fasta = argv[1]
    outfilename = argv[2]

    min_contig_size = None
    if '-minContigSize' in argv:
        value_index = argv.index('-minContigSize') + 1
        if value_index >= len(argv):
            _usage_exit(argv[0])
        min_contig_size = int(argv[value_index])
        print('will discard contigs shorter than %d base pairs' % min_contig_size)

    try:
        lengths, acgt_bases = _collect_contigs(fasta, min_contig_size)
    except ValueError as err:
        sys.exit(str(err))

    # float() keeps the division below a true division under Python 2 as well
    total = float(sum(lengths))
    if total == 0:
        sys.exit('%s: no sequence data left to summarise' % fasta)

    results = [
        ('N50', _nx(lengths, total, 0.5)),
        ('N90', _nx(lengths, total, 0.9)),
        ('non-ACGT bases', (total - acgt_bases) / total),
    ]
    with open(outfilename, 'w') as outfile:
        for label, value in results:
            outline = label + '\t' + str(value)
            print(outline)
            outfile.write(outline + '\n')

# Only run when executed as a script, so the module can be imported safely.
if __name__ == '__main__':
    run()

