##################################
#                                #
# Last modified 2017/02/28       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _read_filelist(path):
    """Parse the dataset list file.

    Each line that is neither a '#' comment nor blank must have the form
    ``label <tab> filename <tab> comma-separated 0-based field indices``.

    Returns:
        dict mapping label -> (filename, sorted list of int field indices).
        If a label appears more than once, its last line wins.

    Raises:
        IndexError: if a line has fewer than three tab-separated fields.
        ValueError: if a field index is not an integer.
    """
    specs = {}
    with open(path) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            label = fields[0]
            filename = fields[1]
            field_ids = sorted(int(field_id) for field_id in fields[2].split(','))
            specs[label] = (filename, field_ids)
    return specs


def _read_elements(path, field_ids):
    """Return the set of distinct elements in a tab-separated data file.

    An element is the tuple of the values found in columns ``field_ids``
    (0-based) of one data line. '#' comment lines and blank lines are skipped.

    Raises:
        IndexError: if a data line has fewer columns than a requested index.
    """
    elements = set()
    with open(path) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            # Strip only the line terminator: a full strip() would also drop
            # empty leading/trailing tab columns and shift the field indices.
            columns = line.rstrip('\r\n').split('\t')
            elements.add(tuple(columns[field_id] for field_id in field_ids))
    return elements


def _format_table(element_sets, doFraction):
    """Build the overlap table as a list of tab-separated lines (no newlines).

    Layout (labels sorted):
        '#dataset' '' label...                    header row
        '' '#Elements' size...                    per-dataset element counts
        label1 size1 overlap(label1, label2)...   one row per dataset

    With ``doFraction`` each overlap is divided by the row dataset's size
    ('nan' when that dataset is empty).
    """
    labels = sorted(element_sets)
    sizes = dict((label, len(element_sets[label])) for label in labels)
    lines = [
        '#dataset\t' + ''.join('\t' + label for label in labels),
        '\t#Elements' + ''.join('\t' + str(sizes[label]) for label in labels),
    ]
    for label1 in labels:
        size1 = sizes[label1]
        cells = [label1, str(size1)]
        for label2 in labels:
            overlap = len(element_sets[label1] & element_sets[label2])
            if not doFraction:
                cells.append(str(overlap))
            elif size1 == 0:
                cells.append('nan')
            else:
                cells.append(str(overlap / float(size1)))
        lines.append('\t'.join(cells))
    return lines


def run():
    """Command-line entry point: write a pairwise element-overlap matrix.

    Usage: python script.py filelist outputfilename [-fraction]

    Reads every dataset named in ``filelist``, collects its distinct elements,
    and writes (and echoes to stdout) a matrix of intersection sizes. With
    ``-fraction`` the intersection sizes are divided by the row dataset's size.
    Exits with status 1 and prints usage if fewer than two arguments are given.
    The output file is only opened after all inputs have been read, so an input
    error does not truncate an existing output file.
    """
    # Both filelist and outputfilename are required (argv[1] and argv[2]).
    if len(sys.argv) < 3:
        # Single-argument print(...) behaves the same under Python 2 and 3.
        print('usage: python %s filelist outputfilename [-fraction]' % sys.argv[0])
        print('filelist format: label <tab> filename <tab> fields (comma-separated)')
        sys.exit(1)

    filelist = sys.argv[1]
    outfilename = sys.argv[2]

    doFraction = '-fraction' in sys.argv
    if doFraction:
        print('will output fractions, not element counts')

    element_sets = {}
    for label, (filename, field_ids) in _read_filelist(filelist).items():
        element_sets[label] = _read_elements(filename, field_ids)

    lines = _format_table(element_sets, doFraction)
    with open(outfilename, 'w') as outfile:
        for outline in lines:
            print(outline)
            outfile.write(outline + '\n')
   
# Only run when executed as a script, so importing this module does not
# parse sys.argv and potentially call sys.exit().
if __name__ == '__main__':
    run()
