##################################
#                                #
# Last modified 06/04/2014       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math
import random

def _read_fasta(path, verbose):
    """Parse the FASTA file at *path* into a {header: sequence} dict.

    The header is the text following the first '>' on a header line
    (split on '>', surrounding whitespace stripped).  Sequence lines are
    stripped and concatenated.  When *verbose* is true a progress message
    is printed every 1,000,000 input lines.

    A record with an empty header that is followed by another header is
    dropped.  Lines appearing before the first header are ignored, and a
    file without any header yields an empty dict.
    """
    seq_dict = {}
    current = ''
    chunks = []
    seen_header = False
    with open(path) as handle:
        for line_no, line in enumerate(handle, 1):
            if verbose and line_no % 1000000 == 0:
                print('%d lines processed' % line_no)
            if line.startswith('>'):
                if current != '':
                    seq_dict[current] = ''.join(chunks)
                current = line.strip().split('>')[1]
                chunks = []
                seen_header = True
            elif seen_header:
                # Collect pieces and join once per record instead of
                # growing a string with += on every line.
                chunks.append(line.strip())
    if seen_header:
        seq_dict[current] = ''.join(chunks)
    return seq_dict


def run():
    """Tile reads across every sequence of a FASTA file.

    Command line (read from sys.argv):
        transcriptome_fasta readlength numberReads outfile

    Passing '-' as outfile writes the reads to stdout and suppresses the
    progress messages that are otherwise printed to stdout.

    The per-position rate is
        numberReads / (total sequence length - readlength * number of sequences).
    At every start position of every sequence longer than readlength, the
    integer part of that rate is emitted as identical copies of the read,
    and one further copy is emitted with probability equal to the
    fractional part.  Sequences are processed in sorted header order;
    sequences of length <= readlength are skipped.

    Exits with status 1 when too few arguments are given or when the
    denominator of the rate is zero.
    """
    # Four positional arguments are required, so argv needs five entries;
    # checking for fewer than four let sys.argv[4] raise IndexError.
    if len(sys.argv) < 5:
        print('usage: python %s transcriptome_fasta readlength numberReads outfile' % sys.argv[0])
        print('\t use - to print to stdout instead of writing to file')
        sys.exit(1)

    fasta = sys.argv[1]
    readlength = int(sys.argv[2])
    NumReads = int(sys.argv[3])
    outputfilename = sys.argv[4]

    doStdOut = outputfilename == '-'

    SeqDict = _read_fasta(fasta, not doStdOut)
    keys = sorted(SeqDict)

    TotalLength = float(sum(len(SeqDict[name]) for name in keys))
    denominator = TotalLength - readlength * len(keys)
    if denominator == 0:
        sys.stderr.write('error: total sequence length minus readlength * '
                         'number of sequences is zero; cannot compute the '
                         'per-position read rate\n')
        sys.exit(1)
    Fraction = NumReads / denominator

    p = Fraction % 1            # probability of one extra copy per position
    copies = int(Fraction - p)  # guaranteed copies per position

    out = sys.stdout if doStdOut else open(outputfilename, 'w')
    try:
        for name in keys:
            seq = SeqDict[name]
            if len(seq) <= readlength:
                continue
            if not doStdOut:
                # Progress message; stdout is free because reads go to the file.
                print(name)
            for i in range(len(seq) - readlength + 1):
                read = seq[i:i + readlength]
                header = '>' + name + ':' + str(i) + '-' + str(i + readlength)
                for j in range(copies):
                    # NOTE(review): only the stdout form carries the copy
                    # index suffix; file output repeats the same header for
                    # every copy.  Kept as-is -- confirm whether intended.
                    if doStdOut:
                        out.write(header + '-' + str(j) + '\n')
                    else:
                        out.write(header + '\n')
                    out.write(read + '\n')
                if random.random() > p:
                    continue
                out.write(header + '\n')
                out.write(read + '\n')
    finally:
        # Close the output file even if writing fails; never close stdout.
        if not doStdOut:
            out.close()

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()

