##################################
#                                #
# Last modified 07/18/2015       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import os

def _usage():
    """Print the command-line usage message and exit with status 1."""
    print('usage: python %s alignment_fasta radius reference_sequence_ID outfile' % sys.argv[0])
    print('\tradius: 0 for the residue itself, 1 for the three residues including it, 2 for the 5, etc.\t')
    sys.exit(1)


def _read_fasta(path):
    """Read a FASTA alignment into a dict mapping sequence ID -> sequence.

    The ID is the header line without its leading '>' (cut at any further
    '>').  Sequence lines are stripped and concatenated, so wrapped records
    are supported.  If an ID is repeated, the last record wins.

    Raises ValueError if non-blank sequence data precedes the first header.
    """
    sequences = {}
    seq_id = None
    chunks = []
    with open(path) as handle:
        for line in handle:
            if line.startswith('>'):
                if seq_id is not None:
                    sequences[seq_id] = ''.join(chunks)
                seq_id = line.strip().split('>')[1]
                chunks = []
            elif seq_id is not None:
                chunks.append(line.strip())
            elif line.strip():
                raise ValueError('%s: sequence data found before the first FASTA header' % path)
    if seq_id is not None:
        sequences[seq_id] = ''.join(chunks)
    return sequences


def _truncation_points(sequences, ref_id):
    """Map each truncated non-reference sequence ID to the index of its last 'X'.

    A sequence counts as truncated when, ignoring trailing gaps ('-') and
    whitespace, it ends in 'X'.
    """
    points = {}
    for seq_id, seq in sequences.items():
        if seq_id == ref_id:
            continue
        if seq.replace('-', ' ').strip().endswith('X'):
            points[seq_id] = seq.rfind('X')
    return points


def _conservation_rows(sequences, ref_id, radius):
    """Yield (position, residue, window, conserved, total) per ungapped reference residue.

    position  -- 1-based index among the reference's non-gap residues.
    residue   -- the reference residue at that position.
    window    -- the reference residues from `radius` residues before to
                 `radius` residues after (counted over non-gap reference
                 columns, clipped at the ends).
    conserved -- number of other sequences whose characters at exactly those
                 alignment columns equal `window`.
    total     -- number of other sequences, minus those truncated (see
                 _truncation_points) at or before this residue's column.

    All sequences are expected to be aligned (same length as the reference);
    a shorter sequence raises IndexError.
    """
    ref_seq = sequences[ref_id]
    # Invariant across residues: hoisted out of the per-residue loop.
    others = [seq for seq_id, seq in sequences.items() if seq_id != ref_id]
    truncated_at = list(_truncation_points(sequences, ref_id).values())
    columns = [col for col, aa in enumerate(ref_seq) if aa != '-']
    n_res = len(columns)
    for i, col in enumerate(columns):
        window_cols = columns[max(i - radius, 0):min(i + radius + 1, n_res)]
        ref_window = ''.join(ref_seq[c] for c in window_cols)
        conserved = sum(1 for seq in others
                        if ''.join(seq[c] for c in window_cols) == ref_window)
        # Sequences truncated at or before this column no longer cover it.
        total = len(others) - sum(1 for x in truncated_at if x <= col)
        yield i + 1, ref_seq[col], ref_window, conserved, total


def run():
    """Command-line entry point: tabulate conservation of a reference sequence.

    usage: python script alignment_fasta radius reference_sequence_ID outfile

    Writes a tab-separated file with header
    '#position\\tRes\\tSeq\\tConservation' and one row per non-gap residue of
    the reference: its 1-based position, the residue, the reference window
    of +/- `radius` residues, and 'conserved/total' (see _conservation_rows).

    Exits with status 1 (after printing usage or an error) on missing
    arguments, a non-integer or negative radius, or an unknown reference ID.
    """
    # Four user arguments are required; sys.argv[4] is read below.
    if len(sys.argv) < 5:
        _usage()

    fasta = sys.argv[1]
    try:
        radius = int(sys.argv[2])
    except ValueError:
        radius = -1  # reported as a usage error just below
    if radius < 0:
        _usage()
    ref_id = sys.argv[3]
    outfilename = sys.argv[4]

    sequences = _read_fasta(fasta)
    if ref_id not in sequences:
        sys.stderr.write('reference sequence ID %s not found in %s\n' % (ref_id, fasta))
        sys.exit(1)

    with open(outfilename, 'w') as outfile:
        outfile.write('#position\tRes\tSeq\tConservation\n')
        for pos, res, window, conserved, total in _conservation_rows(sequences, ref_id, radius):
            outfile.write('%d\t%s\t%s\t%d/%d\n' % (pos, res, window, conserved, total))


# Run only when executed as a script, so the module can be imported
# (e.g. to reuse its helpers) without reading sys.argv or exiting.
if __name__ == '__main__':
    run()

