##################################
#                                #
# Last modified 03/26/2015       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set

def run():
    """Expand a CSV table of collapsed reads into per-column FASTA files.

    Command line (read from ``sys.argv``)::

        python script.py input startField [-splitByRepeat repeatFieldID]

    ``input`` is a comma-separated file. Lines starting with ``#`` and blank
    lines are ignored. The line starting with ``seq,`` is the header: every
    field from ``startField`` (0-based) onwards names one output column.
    In every other line the first field is the sequence and the fields from
    ``startField`` onwards are integer read multiplicities, one per column.
    Rows whose largest multiplicity is 0 are dropped. If the same
    (sequence, repeat) pair occurs more than once, the last row wins.

    For each column a file ``<column name>.fa`` is written in the current
    directory; with ``-splitByRepeat`` one file ``<repeat>.<column name>.fa``
    is written per distinct value of field ``repeatFieldID``. Each sequence
    is emitted as many times as its multiplicity, with headers ``>read1``,
    ``>read2``, ... The counter is kept per column, so with
    ``-splitByRepeat`` it is shared by all repeat files of that column.

    Prints the usage text and exits with status 1 when ``startField`` is
    missing or ``-splitByRepeat`` is not followed by a field index.
    """
    argv = sys.argv
    splitFlag = '-splitByRepeat'
    doRepeatSplit = splitFlag in argv

    # Both positional arguments are required (sys.argv[2] is read below),
    # and the optional flag must be followed by its field index.
    badArgs = len(argv) < 3
    if doRepeatSplit and argv.index(splitFlag) + 1 >= len(argv):
        badArgs = True

    if badArgs:
        print('usage: python %s input startField [-splitByRepeat repeatFieldID]' % sys.argv[0])
        print('assumed format:')
        print('\t#SELECT seq,[otherfield]]vasa_het_GEO_mult,vasa_mut_GEO_mult FROM tbl_colin_dros_cell2009_GEO_mappers')
        print('\tseq,vasa_het_GEO_mult,vasa_mut_GEO_mult')
        print('\tAAAAAAAAACAGCTAGAGTGAAG,0,0')
        print('\tAAAAAAAAACGAAAAGCACAATA,0,0')
        print('the start field [0-based] should refer to the first field with read multiplicity')
        sys.exit(1)

    inputPath = argv[1]
    startFieldID = int(argv[2])
    RFID = None
    if doRepeatSplit:
        RFID = int(argv[argv.index(splitFlag) + 1])

    ReadDict = {}       # (seq, repeat) -> tuple of multiplicities
    NameDict = {}       # column index (relative to startField) -> column name
    ReadCountDict = {}  # column index -> number of reads written so far
    repeats = set()     # distinct repeat labels seen (split mode only)

    # The context manager guarantees the input handle is released.
    with open(inputPath) as lineslist:
        for line in lineslist:
            if line.startswith('#'):
                continue
            fields = line.strip().split(',')
            if fields == ['']:
                # Blank line: nothing to parse.
                continue
            if line.startswith('seq,'):
                for i, name in enumerate(fields[startFieldID:]):
                    NameDict[i] = name
                    ReadCountDict[i] = 0
                continue
            seq = fields[0]
            counts = [int(value) for value in fields[startFieldID:]]
            if max(counts) == 0:
                continue
            R = ''
            if doRepeatSplit:
                R = fields[RFID]
                repeats.add(R)
            ReadDict[(seq, R)] = tuple(counts)

    # Handles are always keyed by (column, repeat); without splitting the
    # repeat is ''. Keying on the mode rather than on the repeat's value
    # keeps an empty repeat field in split mode from hitting a missing key.
    OutFileDict = {}
    try:
        for i in NameDict:
            if doRepeatSplit:
                for R in repeats:
                    OutFileDict[(i, R)] = open(R + '.' + NameDict[i] + '.fa', 'w')
            else:
                OutFileDict[(i, '')] = open(NameDict[i] + '.fa', 'w')

        for (seq, R), counts in ReadDict.items():
            for i, multiplicity in enumerate(counts):
                if multiplicity <= 0:
                    continue
                outfile = OutFileDict[(i, R)]
                for j in range(multiplicity):
                    ReadCountDict[i] += 1
                    outfile.write('>read' + str(ReadCountDict[i]) + '\n')
                    outfile.write(seq + '\n')
    finally:
        # Close whatever was opened, even if writing failed part-way.
        for outfile in OutFileDict.values():
            outfile.close()

# Script entry point. The call is unconditional (no __main__ guard), so
# run() also executes, and reads sys.argv, when this file is imported.
run()
