##################################
#                                #
# Last modified 07/17/2015       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import numpy

def normcovariance(Vi,Vj):
    """Return the normalized covariance (Pearson correlation) of Vi and Vj.

    The covariance and both standard deviations are population statistics
    (divided by the vector length, i.e. numpy's default ddof=0).

    Parameters:
        Vi, Vj: equal-length sequences of numbers.

    Returns:
        The correlation as a numpy float, or the string 'nan' when it is
        undefined: either vector has zero variance, or the vectors are empty.

    Raises:
        ValueError: if Vi and Vj have different lengths.
    """
    if len(Vi) != len(Vj):
        # Continuing would either index past the end of Vj or silently
        # correlate mismatched data, so stop here.
        raise ValueError('vectors Vi, Vj of differing length')
    if len(Vi) == 0:
        return 'nan'

    ai = numpy.asarray(Vi, dtype=float)
    aj = numpy.asarray(Vj, dtype=float)
    stdi = ai.std()
    stdj = aj.std()
    # Zero variance makes the denominator zero.  For 0/1 vectors this is
    # exactly the "mean is 0 or 1" case; testing the standard deviation
    # keeps the guard correct for arbitrary numeric input as well.
    if stdi == 0 or stdj == 0:
        return 'nan'

    cov = numpy.mean((ai - ai.mean()) * (aj - aj.mean()))
    return cov / (stdi * stdj)

def _read_fasta(path):
    """Parse the FASTA file at path into a dict of sequence ID -> sequence."""
    seqs = {}
    name = None
    chunks = []
    with open(path) as handle:
        for line in handle:
            if line.startswith('>'):
                if name is not None:
                    seqs[name] = ''.join(chunks)
                # The ID is the text following the first '>' on the line.
                name = line.strip().split('>')[1]
                chunks = []
            elif name is not None:
                chunks.append(line.strip())
    if name is not None:
        seqs[name] = ''.join(chunks)
    return seqs


def _window(sequence, positions, center, radius):
    """Return the residues of sequence at positions[center-radius .. center+radius].

    The window is clipped at both ends of positions.
    """
    lo = max(center - radius, 0)
    hi = min(center + radius + 1, len(positions))
    return ''.join([sequence[positions[p]] for p in range(lo, hi)])


def run():
    """Compute per-residue conservation and pairwise correlation of conservation.

    Command line: alignment_fasta radius reference_sequence_ID outfile

    Only alignment columns where the reference sequence is not a gap ('-')
    are considered.  For each such position, every other sequence is scored
    1 if its window of residues (the position plus `radius` neighbouring
    positions on each side, clipped at the ends) is identical to the
    reference window, and 0 otherwise.

    The tab-delimited output holds three header lines (reference residues,
    1-based position numbers preceded by the number of sequences, and the
    fraction of sequences conserved at each position), followed by one row
    per position with residue, position number, conserved fraction and the
    normalized covariance of its 0/1 vector against every position.
    """
    # Four arguments are read below, so argv needs at least five entries.
    if len(sys.argv) < 5:
        print('usage: python %s alignment_fasta radius reference_sequence_ID outfile' % sys.argv[0])
        print('\tradius: 0 for the residue itself, 1 for the three residues including it, 2 for the 5, etc.	')
        sys.exit(1)

    fasta = sys.argv[1]
    radius = int(sys.argv[2])
    refID = sys.argv[3]
    outfilename = sys.argv[4]

    SeqDict = _read_fasta(fasta)
    if refID not in SeqDict:
        sys.exit('reference sequence ID %s not found in %s' % (refID, fasta))

    refSeq = SeqDict[refID]
    # Alignment columns at which the reference carries a residue, not a gap.
    ResiduePositions = [col for col, AA in enumerate(refSeq) if AA != '-']
    NumPositions = len(ResiduePositions)

    otherIDs = [ID for ID in SeqDict if ID != refID]
    Total = len(otherIDs)
    if Total == 0:
        # The conserved fraction would be a division by zero.
        sys.exit('alignment %s contains no sequences other than the reference' % fasta)

    # One 0/1 conservation vector per position (one entry per non-reference
    # sequence, always in otherIDs order).  Built once here rather than
    # once per (i, j) pair.
    Indicators = []
    PDict = {}
    for i in range(NumPositions):
        seqRef = _window(refSeq, ResiduePositions, i, radius)
        flags = []
        for ID in otherIDs:
            if _window(SeqDict[ID], ResiduePositions, i, radius) == seqRef:
                flags.append(1)
            else:
                flags.append(0)
        Indicators.append(flags)
        PDict[i] = sum(flags)/(Total + 0.0)

    outline1 = '#\t\t'
    outline2 = '\t' + str(len(SeqDict)) + '\t'
    outline3 = '\t\t'
    for i in range(NumPositions):
        outline1 = outline1 + '\t' + refSeq[ResiduePositions[i]]
        outline2 = outline2 + '\t' + str(i+1)
        outline3 = outline3 + '\t' + str(PDict[i])

    with open(outfilename,'w') as outfile:
        outfile.write(outline1 + '\n')
        outfile.write(outline2 + '\n')
        outfile.write(outline3 + '\n')

        for i in range(NumPositions):
            fields = [refSeq[ResiduePositions[i]], str(i+1), str(PDict[i])]
            for j in range(NumPositions):
                fields.append(str(normcovariance(Indicators[i], Indicators[j])))
            outfile.write('\t'.join(fields) + '\n')


# Run the analysis only when executed as a script, so that importing this
# module (e.g. to reuse its functions) does not parse sys.argv or exit.
if __name__ == '__main__':
    run()

