##################################
#                                #
# Last modified 2022/04/09       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string

def _read_fasta(path):
    """Parse the FASTA file at *path* into a {name: sequence} dict.

    A record name is the text that follows the leading '>' on the
    (whitespace-stripped) header line, up to the next '>' if there is one.
    Sequence lines are stripped and concatenated.  If the same name occurs
    more than once, the last record wins.

    Raises ValueError if non-blank sequence data appears before the first
    header; blank lines in that position are ignored.
    """
    records = {}
    name = None
    chunks = []
    # 'with' guarantees the input handle is closed, even on error.
    with open(path) as handle:
        for line in handle:
            if line.startswith('>'):
                if name is not None:
                    records[name] = ''.join(chunks)
                name = line.strip().split('>')[1]
                chunks = []
            elif name is not None:
                chunks.append(line.strip())
            elif line.strip():
                raise ValueError(
                    'sequence data found before the first FASTA header in %s'
                    % path)
    # Flush the final record; an empty file simply yields no records.
    if name is not None:
        records[name] = ''.join(chunks)
    return records


def _group_contigs(records, max_size):
    """Greedily bin record names into groups of at least *max_size* bases.

    Records are visited from longest to shortest (ties broken by name,
    descending).  A group is closed as soon as its cumulative length reaches
    *max_size*; a non-empty remainder becomes the last group.

    Returns an ordered list of (group_name, [record names]) pairs, with
    groups named 'mergedcontigs1', 'mergedcontigs2', ...
    """
    by_size = sorted(
        ((len(seq), name) for name, seq in records.items()), reverse=True)

    groups = []
    members = []
    total = 0
    for size, name in by_size:
        members.append(name)
        total += size
        if total >= max_size:
            groups.append(('mergedcontigs' + str(len(groups) + 1), members))
            members = []
            total = 0
    # Only a remainder that actually carries sequence becomes a group.
    if total > 0:
        groups.append(('mergedcontigs' + str(len(groups) + 1), members))
    return groups


def _write_merged(path, groups, records, spacer, line_width=100):
    """Write each group as one FASTA record to *path*.

    The member sequences of a group are joined with *spacer* and the result
    is wrapped at *line_width* characters per line.
    """
    with open(path, 'w') as outfile:
        for group_name, members in groups:
            outfile.write('>' + group_name + '\n')
            merged = spacer.join(records[name] for name in members)
            for start in range(0, len(merged), line_width):
                outfile.write(merged[start:start + line_width] + '\n')


def run():
    """Merge the contigs of a FASTA file into larger N-padded sequences.

    Command line: fasta chunk_size N_length outfile

      fasta       input FASTA file
      chunk_size  minimum cumulative length at which a merged sequence
                  is closed
      N_length    number of 'N' characters inserted between merged contigs
      outfile     output FASTA file

    Prints a usage message and exits with status 1 when fewer than four
    arguments are supplied.
    """
    # Four positional arguments are required, i.e. len(sys.argv) >= 5.
    if len(sys.argv) < 5:
        print('usage: python %s fasta chunk_size N_length outfile'
              % sys.argv[0])
        sys.exit(1)

    fasta = sys.argv[1]
    max_size = int(sys.argv[2])
    spacer = int(sys.argv[3]) * 'N'
    outfilename = sys.argv[4]

    records = _read_fasta(fasta)

    print('finished inputting fasta')

    groups = _group_contigs(records, max_size)
    _write_merged(outfilename, groups, records, spacer)

# Run the merge only when this file is executed as a script, so importing
# the module does not trigger command-line parsing or sys.exit().
if __name__ == '__main__':
    run()
