##################################
#                                #
# Last modified 11/11/2015       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import os
from sets import Set

# IUPAC nucleotide ambiguity codes used when two or three of A/C/G/T tie
# for the highest count in an alignment column. Any other tie (e.g. one
# involving a gap '-' or 'N', or a four-way A/C/G/T tie) is reported as 'N'.
_IUPAC_AMBIGUITY = {
    frozenset('AG'): 'R',
    frozenset('CT'): 'Y',
    frozenset('CG'): 'S',
    frozenset('AT'): 'W',
    frozenset('GT'): 'K',
    frozenset('AC'): 'M',
    frozenset('CGT'): 'B',
    frozenset('AGT'): 'D',
    frozenset('ACT'): 'H',
    frozenset('ACG'): 'V',
}

# Number of residues written per line in the output consensus FASTA.
_FASTA_LINE_WIDTH = 50


def _read_alignment(path):
    """Parse a FASTA file into a dict mapping sequence ID -> sequence.

    The ID is the header text after '>' up to the first space. Sequence
    lines are stripped of surrounding whitespace, concatenated and
    upper-cased. Lines appearing before the first header are ignored, and
    a later record with a repeated ID replaces the earlier one.

    :param path: path to the FASTA file.
    :returns: dict {sequence_id: uppercase sequence string}; empty if the
        file contains no '>' header lines.
    """
    records = {}
    seq_id = None
    chunks = []
    with open(path) as handle:
        for line in handle:
            if line.startswith('>'):
                if seq_id is not None:
                    records[seq_id] = ''.join(chunks).upper()
                seq_id = line.strip()[1:].split(' ')[0]
                chunks = []
            elif seq_id is not None:
                chunks.append(line.strip())
    if seq_id is not None:
        records[seq_id] = ''.join(chunks).upper()
    return records


def _call_column(column):
    """Return the consensus character for one alignment column.

    The most frequent character wins. If several characters tie for the
    highest count, the tie is encoded with an IUPAC ambiguity code when it
    is made of two or three of A/C/G/T, and as 'N' otherwise.

    :param column: list of single-character strings, one per sequence.
    """
    counts = {base: column.count(base) for base in set(column)}
    top = max(counts.values())
    tied = frozenset(base for base, n in counts.items() if n == top)
    if len(tied) == 1:
        return next(iter(tied))
    return _IUPAC_AMBIGUITY.get(tied, 'N')


def run(argv=None):
    """Build a consensus sequence from a FASTA multiple-sequence alignment.

    Command line: ``python script.py aligment_fasta outfileprefix``

    Writes two files:

    * ``<outfileprefix>.fa`` -- one record named ``consensus``, wrapped at
      50 characters per line.
    * ``<outfileprefix>.discordant_positions`` -- tab-separated table with a
      ``#Pos``/``Consensus``/<sequence IDs> header and one row per column
      where the sequences disagree: 0-based position, consensus call, then
      each sequence's character (sequence IDs in sorted order).

    :param argv: argument vector, defaults to ``sys.argv``. ``argv[1]`` is
        the alignment FASTA and ``argv[2]`` the output file prefix.
    :raises SystemExit: with status 1 (and a usage line on stdout) when
        fewer than two arguments are given; with an error message when the
        FASTA holds no records or its sequences differ in length.
    """
    if argv is None:
        argv = sys.argv
    # Both the alignment FASTA (argv[1]) and the output prefix (argv[2])
    # are required.
    if len(argv) < 3:
        prog = argv[0] if argv else ''
        print('usage: python %s aligment_fasta outfileprefix' % prog)
        sys.exit(1)

    fasta = argv[1]
    outprefix = argv[2]

    genome = _read_alignment(fasta)
    if not genome:
        sys.exit('error: no FASTA records found in %s' % fasta)

    seq_ids = sorted(genome)
    seqs = [genome[seq_id] for seq_id in seq_ids]
    length = len(seqs[0])
    # Column-wise comparison only makes sense for a proper alignment.
    if any(len(seq) != length for seq in seqs):
        sys.exit('error: sequences in %s are not all the same length' % fasta)

    consensus = []
    discordant_rows = []
    for pos in range(length):
        column = [seq[pos] for seq in seqs]
        if len(set(column)) == 1:
            # Every sequence agrees; nothing to report for this column.
            consensus.append(column[0])
            continue
        call = _call_column(column)
        consensus.append(call)
        discordant_rows.append('\t'.join([str(pos), call] + column) + '\n')
    consensus_seq = ''.join(consensus)

    with open(outprefix + '.fa', 'w') as fasta_out:
        fasta_out.write('>consensus\n')
        for start in range(0, len(consensus_seq), _FASTA_LINE_WIDTH):
            chunk = consensus_seq[start:start + _FASTA_LINE_WIDTH]
            fasta_out.write(chunk + '\n')

    with open(outprefix + '.discordant_positions', 'w') as pos_out:
        pos_out.write('\t'.join(['#Pos', 'Consensus'] + seq_ids) + '\n')
        pos_out.writelines(discordant_rows)

# Run only when executed as a script, so importing this module has no side
# effects (no argv parsing, no file I/O, no sys.exit).
if __name__ == '__main__':
    run()

