#!/usr/bin/python
# Author: Jingyi Jessica Li and Nathan Boley
# Program Completion Date: 09/18/2011
# Modification Date(s): 02/17/2012
# Copyright (c) 2011, Jingyi Jessica Li lijy03@gmail.com and Nathan Boley nboley@gmail.com
# All rights reserved.

# Module-level run-time switches.
# NOTE(review): none of these is read in the code around here; judging by the
# names they toggle pausing on errors, profiling and verbose output -- confirm
# where they are consumed before changing them.
PAUSE_ON_ERROR = True
DO_PROFILE = False
VERBOSE = True

import multiprocessing

import sys
import re

from operator import attrgetter, itemgetter
#from itertools import izip, product

# Prefer the legacy `rpy` interface. If it is not installed, fall back to
# rpy2: import its numpy converter and, where supported, activate automatic
# numpy conversion (rpy2 versions older than 2.2.0 have no `activate`).
# Either way, `r` ends up bound to `rpy.r` or `rpy2.robjects.r`.
try:
    from rpy import r
except ImportError:
    import rpy2.robjects
    from rpy2.robjects.numpy2ri import numpy2ri
    try:
        rpy2.robjects.numpy2ri.activate()
    except AttributeError:
        print "This rpy2 version is older than 2.2.0."
    r = rpy2.robjects.r            


import numpy
import scipy
import pysam
import random

import os

from collections import namedtuple
# Immutable genomic region record: chromosome name, strand (compared against
# '-' elsewhere in this module), and start/stop coordinates.
GenomicInterval = namedtuple('GenomicInterval', ['chr', 'strand', 'start', 'stop'])


#########################################################################################
class GeneBoundaries( dict ):
    """Dictionary mapping gene name -> GeneModel, built from a GTF file.

    Only exon lines (feature type 'exon', 'EXON' or 'Exon') are used. Genes
    whose exons do not all share one chromosome and one strand are removed
    from the mapping and kept, as raw exon sets, in ``self._invalid_genes``.
    """
    Exon = namedtuple('Exon', ['chr', 'strand', 'start', 'stop'])

    @staticmethod
    def _parse_gtf_line( line ):
        """Split one GTF line into ``(gene_name, feature_type, GenomicInterval)``.

        The attribute column is read as alternating key/value tokens; each
        value has its last character (the trailing ';') removed, so quoted
        values keep their quotes. ``gene_name`` is used when present,
        otherwise ``gene_id``.
        """
        data = re.split( r"\s+", line.strip() )
        # the type of element - ie exon, transcript, etc.
        feature_type = data[2]

        # parse the meta data, and grab the gene name (falling back to the id)
        meta_data = dict( zip( data[8::2], ( i[:-1] for i in data[9::2] ) ) )
        try:
            gene_name = meta_data[ 'gene_name' ]
        except KeyError:
            gene_name = meta_data[ 'gene_id' ]

        # add or strip the "chr" prefix, presumably so chromosome names match
        # the alignment file's naming.
        # NOTE(review): `ifchr` is not defined in this class; it is presumably a
        # module-level flag set elsewhere in the script -- confirm.
        chrom = data[0]
        if chrom.startswith( "chr" ) and not ifchr:
            chrom = data[0][3:]
        elif not chrom.startswith( "chr" ) and ifchr:
            chrom = "chr" + chrom
        return gene_name, feature_type, GenomicInterval( chrom, data[6], int( data[3] ), int( data[4] ) )

    def _validate_exons( self ):
        """Move genes whose exons disagree on chromosome or strand to ``self._invalid_genes``.

        The offending gene is reported on stdout together with the reason.
        """
        self._invalid_genes = {}
        # iterate over a snapshot of the names: invalid genes are removed below
        for gene_name in list( self.keys() ):
            exons = list( self[ gene_name ] )
            try:
                if not all( exon.chr == exons[0].chr for exon in exons ):
                    raise ValueError( "There appears to be a gene with exons from different chromosomes." )
                if not all( exon.strand == exons[0].strand for exon in exons ):
                    raise ValueError( "There appears to be a gene with exons from different strands." )
            except ValueError as inst:
                print( "Error in %s: %s" % ( gene_name, inst ) )
                self._invalid_genes[ gene_name ] = self.pop( gene_name )

    def __init__( self, fp ):
        """Parse the GTF lines in ``fp`` and build one GeneModel per valid gene.

        Args:
            fp: iterable of GTF text lines (e.g. an open file). Comment lines
                (starting with '#') and blank lines are skipped.
        """
        # iterate through each line in the gtf, and organize exons
        # based upon the genes that they came from
        for line in fp:
            if line.startswith( '#' ) or not line.strip():
                continue
            gene, feature_type, data = self._parse_gtf_line( line )
            # skip non-exon lines
            if feature_type not in ( 'exon', 'EXON', 'Exon' ):
                continue
            self.setdefault( gene, set() ).add( data )

        # make sure that all of the exons make sense. ie, from the
        # same chromosome and strand
        self._validate_exons()

        # since all remaining genes are on one strand and one chromosome,
        # replace each exon set by a GeneModel (chr, strand, sorted boundaries)
        for gene_name in list( self.keys() ):
            exons = list( self[ gene_name ] )
            self[ gene_name ] = GeneModel( gene_name, exons[0].strand, exons[0].chr,
                                           ( ( exon.start, exon.stop ) for exon in exons ) )

#########################################################################################
class GeneModel( object ):
    """A gene: name, strand, chromosome and its sorted (start, stop) exon boundaries."""
    @staticmethod
    def exons_do_overlap( bndry1, bndry2 ):
        """Return True when bndry1 ends at or after the start of bndry2.

        This is an overlap test only when bndry1 does not start after bndry2,
        which holds for index pairs (i, j), i < j, of the sorted exon_bndrys.
        Touching boundaries (bndry1 stop == bndry2 start) count as overlapping.
        """
        return( bndry1[1] >= bndry2[0] )

    def _get_boundaries( self ):
        """Return the GenomicInterval spanning every exon of the gene."""
        min_val = min( i[0] for i in self.exon_bndrys )
        max_val = max( i[1] for i in self.exon_bndrys )
        return GenomicInterval( self.chromosome, self.strand, min_val, max_val )

    def __init__( self, name, strand, chromosome, exon_bndrys ):
        self.name = name
        self.strand = strand
        self.chromosome = chromosome

        # sorted so that, for i < j, exon i never starts after exon j
        self.exon_bndrys = sorted( exon_bndrys )

        self.boundaries = self._get_boundaries()

        # computed lazily by _find_possible_junctions
        self._possible_junctions = None

    def _find_possible_junctions( self ):
        """Compute and cache every index pair (i, j), i < j, of non-overlapping exons."""
        # local alias to keep the nested helpers short
        exon_bndrys = self.exon_bndrys

        def junction_is_valid( bndry1, bndry2 ):
            """Stub for future junction filtering."""
            return True

        def iter_junctions( ):
            n_exons = len( exon_bndrys )
            for i in range( n_exons ):
                # j starts at i + 1: every later exon is a candidate partner.
                # (i, i) is never needed because an exon with stop >= start
                # always overlaps itself.
                for j in range( i + 1, n_exons ):
                    dont_overlap = not self.exons_do_overlap( exon_bndrys[i], exon_bndrys[j] )
                    valid = junction_is_valid( exon_bndrys[i], exon_bndrys[j] )
                    if dont_overlap and valid:
                        yield (i, j)

        self._possible_junctions = list( iter_junctions() )

    def iter_possible_junctions( self ):
        """Return an iterator over the cached (i, j) junction index pairs."""
        if self._possible_junctions is None:
            self._find_possible_junctions()
        return iter( self._possible_junctions )

#########################################################################################
class Reads( pysam.Samfile ):
    """Subclass the samfile object to include a method that returns reads and their pairs.

    For the pairing helpers, read 1 is kept only when its ``is_reverse`` flag
    matches the gene strand ('-' == reverse), and read 2 only when it does not.
    """
    def _iter_paired_reads( self, gene_bndry ):
        """Yield ``(read1, read2)`` mate pairs overlapping ``gene_bndry``.

        Mates are matched by ``qname``; read 1s without a matching read 2, and
        pairs whose mates have different ``qlen``, are skipped.
        """
        # whether or not the gene is on the negative strand
        gene_strnd_is_rev = ( gene_bndry.strand == '-' )

        # a single pass over the region collects both mates
        reads_pair1 = []
        reads_pair2 = {}
        for read in self.fetch( gene_bndry.chr, gene_bndry.start, gene_bndry.stop ):
            if read.is_read1:
                if read.is_reverse == gene_strnd_is_rev:
                    reads_pair1.append( read )
            elif read.is_reverse != gene_strnd_is_rev:
                # a later read 2 with the same qname replaces an earlier one
                reads_pair2[ read.qname ] = read

        # iterate through the read pairs
        for read1 in reads_pair1:
            read2 = reads_pair2.get( read1.qname )
            # if there is no mate, skip this read
            if read2 is None:
                continue

            # an ungapped alignment must span exactly qlen bases
            assert ( read1.qlen == read1.aend - read1.pos ) or ( len( read1.cigar ) > 1 )
            assert ( read2.qlen == read2.aend - read2.pos ) or ( len( read2.cigar ) > 1 )

            if read1.qlen != read2.qlen:
                print( "ERROR: unequal read lengths %i and %i" % ( read1.qlen, read2.qlen ) )
                continue

            yield read1, read2

    def get_paired_reads( self, gene_bndry ):
        """Yield ``(read_length, (start_a, end_a, start_b, end_b, read_a, read_b))``.

        For '+' genes mate a is read 1; for '-' genes mate a is read 2.
        """
        gene_strnd_is_rev = ( gene_bndry.strand == '-' )

        for read1, read2 in self._iter_paired_reads( gene_bndry ):
            if not gene_strnd_is_rev:
                yield read1.qlen, ( read1.pos, read1.aend, read2.pos, read2.aend, read1, read2 )
            else:
                yield read1.qlen, ( read2.pos, read2.aend, read1.pos, read1.aend, read2, read1 )

    def iter_junction_reads( self, gene_bndry ):
        """Yield reads in ``gene_bndry`` whose CIGAR has more than one operation."""
        for read in self.fetch( gene_bndry.chr, gene_bndry.start, gene_bndry.stop ):
            if len( read.cigar ) > 1:
                yield read

#########################################################################################
def iter_transcripts( exons ):
    """Yield every ordered set of pairwise non-overlapping exons.

    The exons are first sorted (stably) by start, and each transcript is a
    list of increasing indices into that *sorted* order. Two exons count as
    non-overlapping when the later one starts at or after the earlier one's
    stop, so touching exons may be chained.

    The enumeration is a depth-first walk over an explicit stack:
      * every single exon seeds a transcript;
      * after a transcript is yielded it is extended by each exon from the
        first one starting at or after its last exon's stop, through the end.
        (If that first candidate has start >= stop, no extension is made.)
    For three disjoint exons the order is
    [0], [0, 2], [0, 1], [0, 1, 2], [1], [1, 2], [2].

    Args:
        exons: iterable of (start, stop) pairs.

    Yields:
        lists of indices into the start-sorted exon list.
    """
    Exon = namedtuple('Exon', ['start', 'stop'])

    ordered = sorted( ( Exon( e[0], e[1] ) for e in exons ), key=attrgetter('start') )
    n_exons = len( ordered )
    if not n_exons:
        return

    def first_disjoint_after( idx ):
        # first index after idx whose exon starts at or after ordered[idx].stop
        cur_stop = ordered[idx].stop
        candidate = idx + 1
        while candidate < n_exons and ordered[candidate].start < cur_stop:
            candidate += 1
        return candidate

    # seed the stack so that exon 0 is popped first
    stack = [ [idx] for idx in range( n_exons - 1, -1, -1 ) ]
    while stack:
        transcript = stack.pop()
        yield transcript

        begin = first_disjoint_after( transcript[-1] )
        if begin >= n_exons:
            continue
        # guard carried over from the original helper: skip extension when the
        # first candidate exon is degenerate (start >= stop)
        if not ordered[begin].start < ordered[begin].stop:
            continue
        for idx in range( begin, n_exons ):
            stack.append( transcript + [idx] )
#########################################################################################   

######################################################################################### 
def build_array_from_transcripts( transcripts ):
    """Convert exon-index transcripts into 0/1 indicator rows.

    Args:
        transcripts: iterable (a list or a generator such as the output of
            iter_transcripts) of iterables of non-negative exon indices.

    Returns:
        One list per transcript, each ``max_index + 1`` long, with 1 at the
        transcript's exon indices and 0 elsewhere. Empty input gives ``[]``.
    """
    # materialize first: a generator would otherwise be exhausted while
    # searching for the largest index, leaving nothing to convert
    transcripts = [ list( exons ) for exons in transcripts ]
    all_indices = [ i for exons in transcripts for i in exons ]
    width = max( all_indices ) + 1 if all_indices else 0

    transcripts_array = []
    for exons in transcripts:
        transcript = [0] * width
        for exon in exons:
            transcript[ exon ] = 1
        transcripts_array.append( transcript )
    return transcripts_array
######################################################################################### 

#########################################################################################
def renameTranscripts( G_is, old_G, new_G ):
    """Re-express exon-indicator transcripts over a new exon set.

    Every old exon is mapped to all new exons it overlaps (strictly: touching
    endpoints do not count). A transcript's new indicator row has a 1 at each
    new exon covered by any of its old exons flagged with 1.

    Args:
        G_is: iterable of 0/1 rows indexed like ``old_G``.
        old_G: sequence of (start, stop) old exons.
        new_G: sequence of (start, stop) new exons.

    Returns:
        A list of 0/1 lists, each ``len(new_G)`` long.
    """
    def overlaps( a, b ):
        return a[1] > b[0] and a[0] < b[1]

    n_new = len( new_G )
    # for each old exon, the indices of the new exons it overlaps
    covering = [
        [ j for j, new_exon in enumerate( new_G ) if overlaps( old_exon, new_exon ) ]
        for old_exon in old_G
    ]

    renamed = []
    for G_i in G_is:
        row = [0] * n_new
        for i, flag in enumerate( G_i ):
            if flag == 1:
                for j in covering[ i ]:
                    row[ j ] = 1
        renamed.append( row )
    return renamed

#########################################################################################

#########################################################################################
def _iter_mate_pairs( reads, gene_strnd_is_rev ):
    """Match read-1/read-2 mates among ``reads`` and yield ``(read_length, positions)``.

    Read 1 is kept when its ``is_reverse`` equals ``gene_strnd_is_rev``; read 2
    when it differs. Mates are matched by ``qname`` (a later read 2 with the
    same name replaces an earlier one). Pairs with unequal ``qlen`` are
    reported and skipped. ``positions`` is (start, end) of read 1 then read 2
    for '+' genes, and of read 2 then read 1 for '-' genes.
    """
    reads_pair1 = []
    reads_pair2 = {}
    for read in reads:
        if read.is_read1:
            if read.is_reverse == gene_strnd_is_rev:
                reads_pair1.append( read )
        elif read.is_reverse != gene_strnd_is_rev:
            reads_pair2[ read.qname ] = read

    for read1 in reads_pair1:
        read2 = reads_pair2.get( read1.qname )
        # if there is no mate, skip this read
        if read2 is None:
            continue

        # an ungapped alignment must span exactly qlen bases
        assert ( read1.qlen == read1.aend - read1.pos ) or ( len( read1.cigar ) > 1 )
        assert ( read2.qlen == read2.aend - read2.pos ) or ( len( read2.cigar ) > 1 )

        if read1.qlen != read2.qlen:
            print( "ERROR: unequal read lengths %i and %i" % ( read1.qlen, read2.qlen ) )
            continue
        if not gene_strnd_is_rev:
            yield read1.qlen, ( read1.pos, read1.aend, read2.pos, read2.aend )
        else:
            yield read1.qlen, ( read2.pos, read2.aend, read1.pos, read1.aend )


def process_reads( gene, samfile, read_type ):
    """Yield ``(read_length, positions)`` for the reads overlapping ``gene``.

    ``read_type`` selects how the reads are interpreted:
      * "paired-end": only matched mate pairs are yielded, with 4-tuple
        positions (see _iter_mate_pairs);
      * "single-end": every read is yielded as (qlen, (pos, aend));
      * anything else (mixed / unknown): reads with ``is_paired`` false are
        yielded as single-end first, then mate pairs among the paired reads.
    """
    # whether or not the gene is on the negative strand
    gene_strnd_is_rev = ( gene.boundaries.strand == '-' )
    region = ( gene.boundaries.chr, gene.boundaries.start, gene.boundaries.stop )

    if read_type == "paired-end":
        for item in _iter_mate_pairs( samfile.fetch( *region ), gene_strnd_is_rev ):
            yield item

    elif read_type == "single-end":
        for read in samfile.fetch( *region ):
            yield read.qlen, ( read.pos, read.aend )

    else:
        # divide reads into single-end and paired-end ones, yield single-end reads
        reads_paired = []
        for read in samfile.fetch( *region ):
            if read.is_paired:
                reads_paired.append( read )
            else:
                yield read.qlen, ( read.pos, read.aend )
        for item in _iter_mate_pairs( reads_paired, gene_strnd_is_rev ):
            yield item

#########################################################################################

#########################################################################################
from quantile import quantile

def get_frag_len( genes, samfile ):
    """Estimate fragment-length statistics for each read length.

    Mate pairs are collected only from genes with exactly one exon, and
    collection stops (between genes) once about ``max_num`` candidate reads
    have been gathered. For each matched pair, the fragment length runs from
    the start of the mate listed first to the end of the other mate: read 1
    then read 2 on '+' genes, read 2 then read 1 on '-' genes. The orientation
    used is that of the gene each read 1 was collected from.

    Args:
        genes: iterable of GeneModel-like objects (``exon_bndrys`` and
            ``boundaries`` with ``chr``/``start``/``stop``/``strand``).
        samfile: an opened alignment file providing ``fetch``.

    Returns:
        dict mapping read length -> [5% quantile, 95% quantile, mean, sd] of
        the fragment lengths seen for that read length. Every read length in
        the whole file gets an entry.
        NOTE(review): read lengths with no matched pair get an empty list; how
        ``quantile`` handles [] is not visible here, and numpy mean/std of an
        empty array are nan -- confirm this is acceptable to callers.
    """
    # one (initially empty) list of fragment lengths per read length in the file
    frag_len = {}
    for read in samfile.fetch():
        frag_len[ read.qlen ] = []

    reads_pair1 = []  # (read 1, strand flag of the gene it was collected from)
    reads_pair2 = {}

    max_num = 100000  # max number of collected candidate reads
    n_num = 0  # number of collected candidate reads

    for gene in genes:
        if n_num >= max_num:
            break
        if len( gene.exon_bndrys ) != 1:
            continue
        gene_strnd_is_rev = ( gene.boundaries.strand == '-' )
        for read in samfile.fetch( gene.boundaries.chr, gene.boundaries.start, gene.boundaries.stop ):
            if read.is_read1 and read.is_reverse == gene_strnd_is_rev:
                reads_pair1.append( ( read, gene_strnd_is_rev ) )
                n_num += 1
            elif not read.is_read1 and read.is_reverse != gene_strnd_is_rev:
                reads_pair2[ read.qname ] = read
                n_num += 1

    for read1, read1_gene_is_rev in reads_pair1:
        read2 = reads_pair2.get( read1.qname )
        # if there is no mate, skip this read
        if read2 is None:
            continue
        # an ungapped alignment must span exactly qlen bases
        assert ( read1.qlen == read1.aend - read1.pos ) or ( len( read1.cigar ) > 1 )
        assert ( read2.qlen == read2.aend - read2.pos ) or ( len( read2.cigar ) > 1 )
        if read1.qlen != read2.qlen:
            print( "ERROR: unequal read lengths %i and %i" % ( read1.qlen, read2.qlen ) )
            continue
        if not read1_gene_is_rev:
            frag_len[ read1.qlen ].append( int( read2.aend ) - int( read1.pos ) )
        else:
            frag_len[ read1.qlen ].append( int( read1.aend ) - int( read2.pos ) )

    frag_len_summary = {}
    for rl, lengths in frag_len.items():
        # quantile type 7 -- presumably R's default definition; confirm in the quantile module
        l_quantile = quantile( lengths, 0.05, 7 )
        u_quantile = quantile( lengths, 0.95, 7 )
        values = numpy.array( lengths )
        frag_len_summary[ rl ] = [ l_quantile, u_quantile, values.mean(), values.std() ]
    return frag_len_summary
#########################################################################################

#########################################################################################
def buildExons( raw_exons, min_len=1 ):
	"""Split possibly-overlapping raw exons into non-overlapping exon pieces.

	Every distinct coordinate in raw_exons becomes a breakpoint. Consecutive
	breakpoints define candidate intervals, and only intervals that overlap
	some raw exon are kept. Intervals whose length (stop - start) is below
	min_len are merged into a directly adjacent interval where possible.
	Adjacent means one interval starts exactly at the other's stop + 1.
	Next, each interval endpoint lying exactly 1 away from a breakpoint is
	snapped onto the first such breakpoint. Finally, duplicates are removed
	and the result is sorted by start.

	Args:
		raw_exons: iterable of (start, stop) pairs; duplicates are ignored.
		min_len: minimum interval length, measured as stop - start.

	Returns:
		A list of [start, stop] lists sorted by start.
	"""
	# drop duplicate raw exons
	raw_exons = [ list(x) for x in set( tuple(x) for x in raw_exons ) ]
	# every distinct start/stop coordinate becomes a sorted breakpoint
	bds = []
	for exon in raw_exons:
		bds.extend( exon )
	bds = sorted( list( set( bds ) ) )
	m = len( bds ) - 1
	# consecutive breakpoints -> candidate intervals. All but the last end one
	# position before the next breakpoint; the last one ends on the final breakpoint.
	intervals = []
	for i in range( m ):
		if i < m - 1:
			intervals.append( [ bds[i], bds[i+1]-1 ] )
		else:
			intervals.append( [ bds[i], bds[i+1] ] )
	#### function "if_overlap"
	# strict overlap: intervals that only touch at an endpoint do not overlap
	def if_overlap( int1, int2 ):
		return not ( int1[1] <= int2[0] or int2[1] <= int1[0] )
	####
	# keep only intervals that overlap at least one raw exon
	pos_intervals = []
	for int1 in intervals:
		tmp = any( [ if_overlap( int1, int2 ) for int2 in raw_exons ] )
		if tmp:
			pos_intervals.append( int1 )
	intervals = pos_intervals
	lens = [ int1[1] - int1[0] for int1 in intervals ]
	# indices of short intervals that have no adjacent neighbour to merge into
	singular_list = []
	#min_len = 1 minimum exon length
	# repeatedly merge one short interval into an adjacent neighbour
	if len( lens ) > 1:
		arg1 = any( [ x < min_len for x in lens ] )
		while arg1:
			# NOTE(review): short_idx comes from a set, so short_idx[0] is not
			# guaranteed to be the smallest index (it usually is for small ints
			# in CPython) -- confirm the processing order does not matter.
			short_idx = list( set( i for i in range(len(lens)) if lens[i] < min_len ) - set( singular_list ) )
			k = short_idx[0]
			# interior interval: a = contiguous with the right neighbour,
			# b = contiguous with the left neighbour
			if k > 0 and k < len( intervals ) - 1:
				a = intervals[k+1][0] == intervals[k][1] + 1
				b = intervals[k][0] == intervals[k-1][1] + 1
				if a and b:
					# both contiguous: extend the right neighbour when it is not
					# longer than the left one, otherwise extend the left one
					if lens[k+1] <= lens[k-1]:
						intervals[k+1][0] = intervals[k][0]
					else:
						intervals[k-1][1] = intervals[k][1]
					intervals.pop(k)
				elif a:
					intervals[k+1][0] = intervals[k][0]
					intervals.pop(k)
				elif b:
					intervals[k-1][1] = intervals[k][1]
					intervals.pop(k)
				else:
					# NOTE(review): indices stored here are not shifted when a
					# later intervals.pop() removes an earlier interval, so they
					# may end up pointing at a different interval -- confirm.
					singular_list.append(k)
			# first interval: can only merge to the right
			# NOTE(review): if only one interval remains and it is short,
			# intervals[k+1] below raises IndexError -- confirm inputs avoid this.
			elif k == 0:
				if intervals[k+1][0] == intervals[k][1] + 1:
					intervals[k+1][0] = intervals[k][0]
					intervals.pop(k)
				else:
					singular_list.append(k)
			# last interval: can only merge to the left
			elif k == len(intervals) - 1:
				if intervals[k][0] == intervals[k-1][1] + 1:
					intervals[k-1][1] = intervals[k][1]
					intervals.pop(k)
				else:
					singular_list.append(k)
			#if len( intervals ) == 0:
			#	continue
			# recompute lengths; keep looping while a short, non-singular interval remains
			lens = [ int1[1] - int1[0] for int1 in intervals ]
			if len( singular_list ) == 0:
				arg1 = any( [ x < min_len for x in lens ] )
			else:
				arg1 = any( [ lens[i] < min_len for i in range(len(lens)) if not i in singular_list ] )
	
	exons = intervals
	
	# match starting and ending positions of exons to those of raw_exons:
	# an endpoint exactly 1 away from a breakpoint is moved onto the first
	# such breakpoint (in sorted order); exact matches are left unchanged
	raw_pos_list = bds
	for k in range( len( exons) ):
		for j in range(2):
			temp1 = [ x - exons[k][j] for x in raw_pos_list ]
			temp2 = [ i for i in range(len(temp1)) if abs(temp1[i])==1 ]
			if len( temp2 ) > 0:
				exons[k][j] = raw_pos_list[ temp2[0] ]
				
	
	exons = list( set( tuple(x) for x in exons ) ) # make the exons unique, by converting each exon to a tuple first
	sorted_exons = sorted( exons, key=lambda exon: exon[0] )   # sort by exon start position
	exons = [ list(x) for x in sorted_exons ] # restore each exon back to a list
				
	return exons
		
						
	
#########################################################################################

#########################################################################################
def build_bins( gene, samfile, read_type, new_exons ):
    """Count reads per exon bin and read length for ``gene``.

    Each read yielded by process_reads is converted into a "bin". A bin holds,
    for each of the read's positions, the 1-based index of the first exon in
    ``new_exons`` containing it (bounds inclusive), or 0 if none contains it.
    Paired-end bins swap their two mates when the first mate's start lies in
    a later exon than the second mate's start. A bin is counted only when
    every position falls inside an exon and each mate's left index is <= its
    right index.

    Returns:
        dict mapping ``str(bin)`` (e.g. "[1, 2]") to a dict
        {read_length: number of reads}.
    """
    def read2bin( read ):
        # convert a read's genomic positions to 1-based exon indices (0 = no exon);
        # read is (start, end) for single-end or (s1, e1, s2, e2) for paired-end
        temp = []
        for pos in read:
            for i, exon in enumerate( new_exons ):
                if exon[0] <= pos <= exon[1]:
                    temp.append( i + 1 )
                    break
            else:
                # no exon contains pos (also the case when new_exons is empty)
                temp.append( 0 )
        if len( temp ) > 2 and temp[0] > temp[2]:
            temp = [ temp[2], temp[3], temp[0], temp[1] ]
        return temp

    bins = {}
    for rl, read in process_reads( gene, samfile, read_type ):
        bin_idx = read2bin( read )
        # every position must map to an exon, and each mate must run left-to-right
        if 0 in bin_idx:
            continue
        if bin_idx[0] > bin_idx[1]:
            continue
        if len( bin_idx ) > 2 and bin_idx[2] > bin_idx[3]:
            continue
        # direct dict membership: O(1), unlike scanning keys() (a list on Python 2)
        counts = bins.setdefault( str( bin_idx ), {} )
        counts[ rl ] = counts.get( rl, 0 ) + 1
    return bins
#########################################################################################

#########################################################################################
def filterTranscripts( G_is, filtered_bins ):
    """Keep the candidate transcripts whose exon junctions are supported by read bins.

    Junctions (pairs of 1-based exon indices) are taken from the bins:
      * each mate spanning two different exons;
      * for paired-end bins, the gap between the mates when it is exactly
        one exon step, or when no bin contains any exon inside the gap;
      * consecutive "single" exons, i.e. exons that only ever appear in bins
        by themselves.
    A multi-exon transcript is kept when every pair of its consecutive exons
    is a junction (in either order). A one-exon transcript is kept when its
    exon is a single exon. If nothing survives, ``G_is`` is returned unchanged.

    Args:
        G_is: sequence of 0/1 rows, one entry per exon.
        filtered_bins: sequence of bins (lists of 2 or 4 exon indices).

    Returns:
        The list of kept rows, or ``G_is`` itself when none is kept (or when
        ``G_is`` is empty).
    """
    if len( G_is ) == 0:
        return G_is

    def exon_has_reads( exon ):
        return any( exon in b for b in filtered_bins )

    junc_mat = []
    for b in filtered_bins:
        if b[0] != b[1]:
            junc_mat.append( b[0:2] )
        if len( b ) > 2:  # paired-end bins
            if b[2] != b[3]:
                junc_mat.append( b[2:4] )
            if b[1] + 1 == b[2]:
                junc_mat.append( b[1:3] )
            elif b[1] + 1 < b[2]:
                # if the unshown (missing) exons have no reads mapped to them,
                # assume there is a junction skipping them
                if not any( exon_has_reads( i ) for i in range( b[1] + 1, b[2] ) ):
                    junc_mat.append( b[1:3] )

    # exons (1-based) that only ever appear alone in their bins
    single_exons = []
    for exon in range( 1, len( G_is[0] ) + 1 ):
        members = set()
        for b in filtered_bins:
            if exon in b:
                members.update( b )
        if len( members ) == 1:
            single_exons.append( exon )

    # two contiguous single exons are assumed to be joined by a junction
    single_sorted = sorted( single_exons )
    for left, right in zip( single_sorted, single_sorted[1:] ):
        if right - left == 1:
            junc_mat.append( [ left, right ] )

    # unordered exon pairs; a set gives O(1) lookups
    supported = set( frozenset( j ) for j in junc_mat )

    G_is_filtered = []
    for G_i in G_is:
        e = [ i + 1 for i, flag in enumerate( G_i ) if flag == 1 ]
        if len( e ) > 1 and all( frozenset( pair ) in supported for pair in zip( e, e[1:] ) ):
            G_is_filtered.append( G_i )

    if single_exons:
        for G_i in G_is:
            if sum( G_i ) == 1 and any( G_i[ exon - 1 ] == 1 for exon in single_exons ):
                G_is_filtered.append( G_i )

    if len( G_is_filtered ) == 0:
        G_is_filtered = G_is

    return G_is_filtered
	
#########################################################################################

#########################################################################################
def build_long_transcripts( new_exons, filtered_bins ):
    """Assemble the longest exon chains implied by the read bins.

    Junctions (pairs of 1-based exon indices) come from each mate spanning two
    exons and, for paired-end bins, from the gap between the mates when it is
    one exon step or when no bin contains any exon inside the gap. Junctions
    are processed in sorted order. A junction (a, b) appends b to every chain
    currently ending at a, or starts a new chain [a, b] when no chain ends at
    a. Exons that only ever appear alone in their bins, and that are not
    already in a chain, become one-exon chains.

    Args:
        new_exons: the exon list; only its length is used.
        filtered_bins: sequence of bins (lists of 2 or 4 exon indices).

    Returns:
        One 0/1 list of length ``len(new_exons)`` per chain.
    """
    n_exons = len( new_exons )

    def exon_has_reads( exon ):
        return any( exon in b for b in filtered_bins )

    junctions = set()
    for b in filtered_bins:
        if b[0] != b[1]:
            junctions.add( tuple( b[0:2] ) )
        if len( b ) <= 2:
            continue
        # paired-end bin
        if b[2] != b[3]:
            junctions.add( tuple( b[2:4] ) )
        gap_first, gap_end = b[1] + 1, b[2]
        # an empty gap (adjacent mates) trivially has no exon with reads
        if gap_first <= gap_end and not any( exon_has_reads( x ) for x in range( gap_first, gap_end ) ):
            junctions.add( tuple( b[1:3] ) )

    lonely_exons = []
    for exon in range( 1, n_exons + 1 ):
        companions = set()
        for b in filtered_bins:
            if exon in b:
                companions.update( b )
        if len( companions ) == 1:
            lonely_exons.append( exon )

    chains = []
    for head, tail in sorted( junctions ):
        grown = False
        for chain in chains:
            if chain[-1] == head:
                chain.append( tail )
                grown = True
        if not grown:
            chains.append( [ head, tail ] )

    for exon in lonely_exons:
        if not any( exon in chain for chain in chains ):
            chains.append( [ exon ] )

    indicator_rows = []
    for chain in chains:
        row = [0] * n_exons
        for exon in chain:
            row[ exon - 1 ] = 1
        indicator_rows.append( row )
    return indicator_rows
	
#########################################################################################

#########################################################################################
def build_F( G_is, filtered_bins, G, read_lengths, frag_param, dist, GC_correction, chr_seq=None ):
	"""Build the matrix of bin probabilities given each candidate transcript.

	G_is          : list of candidate transcripts, each a 0/1 indicator list
	                over the subexons in G (1 = transcript contains the subexon).
	filtered_bins : list of bins, each a list of 1-based subexon indices.  A
	                bin with exactly 4 entries is treated as a paired-end bin,
	                any other bin as a single-end bin.
	G             : list of subexons as (start, stop) pairs; lengths are
	                stop - start + 1 (1-based inclusive coordinates, as also
	                implied by the chr_seq slicing below).
	read_lengths  : list of read lengths.
	frag_param    : one (lower bound, upper bound, mean, sd) fragment-length
	                tuple per entry of read_lengths, in the same order.
	dist          : "Exponential" or "Normal" fragment-length distribution.
	GC_correction : if true, the fragment-start density within each subexon is
	                proportional to its G/C count in chr_seq instead of uniform.
	chr_seq       : chromosome sequence; only used when GC_correction is true.

	Returns a numpy array with one column per transcript in G_is and one row
	per (read length, bin) pair, ordered read-length-major: all bins for
	read_lengths[0], then all bins for read_lengths[1], and so on.
	"""
	frag_lower_bd = []
	frag_upper_bd = []
	frag_mean = []
	frag_sd = []
	for frag in frag_param:
		frag_lower_bd.append( frag[0] )
		frag_upper_bd.append( frag[1] )
		frag_mean.append( frag[2] )
		frag_sd.append( frag[3] )
	exon_lengths = [ exon[1] - exon[0] + 1 for exon in G ]
	
	densities = []
	# if GC content is not included, the density is uniform; otherwise the density is proportional to the GC content in exons
	if GC_correction:
		exon_GC_count = [ chr_seq.count( "G", exon[0]-1, exon[1] ) + chr_seq.count( "g", exon[0]-1, exon[1] ) + chr_seq.count( "C", exon[0]-1, exon[1] ) + chr_seq.count( "c", exon[0]-1, exon[1] ) for exon in G ]
				
	for G_i in G_is:
		temp = []
		if GC_correction:
			denominator = sum( [ exon_GC_count[ index ] for index in range( len(G_i) ) if G_i[ index ] == 1 ] )
			for i in range( len( G_i ) ):
				if G_i[ i ] == 0:
					temp.append( 0 )
				else:
					temp.append( exon_GC_count[i] / ( exon_lengths[i] * float(denominator) ) )
		else:
			# uniform density: 1 / (total length of the transcript's subexons)
			temp_len = sum( [ exon_lengths[ index ] for index in range( len(G_i) ) if G_i[ index ] == 1 ] )
			temp_den = 1 / float(temp_len)
			for i in G_i:
				if i == 0:
					temp.append( 0 )
				else:
					temp.append( temp_den )
		densities.append( temp )		

	######## calculate F ############################
	
	# functions
	#### start of function "Int_Exp"
	def Int_Exp( l1, u1, l2, u2, lower_len, upper_len, mean, sd, den, dist ):
		"""Integrate over fragment starts x in [l1, u1] the probability that the
		fragment end lies in [l2, u2], under the fragment-length distribution
		truncated to [lower_len, upper_len]; the result is scaled by den."""
		from scipy.integrate import quad
		if dist == "Exponential":
			from scipy.special import gdtr
			# gdtr( rate, 1, x ) is the exponential CDF with rate 1/mean
			rate = 1 / mean
			# normalizing constant of the truncation; independent of x
			b = gdtr( rate, 1, upper_len ) - gdtr( rate, 1, lower_len )
			def temp_FUN(x):
				a = gdtr( rate, 1, min(u2-x, upper_len) ) - gdtr( rate, 1, max(l2-x, lower_len) )
				return max( a/b, 0 )
		elif dist == "Normal":
			from scipy.stats import norm
			# normalizing constant of the truncation; independent of x
			b = norm.cdf( upper_len, mean, sd ) - norm.cdf( lower_len, mean, sd )
			def temp_FUN(x):
				a = norm.cdf( min(u2-x, upper_len), mean, sd ) - norm.cdf( max(l2-x, lower_len), mean, sd )
				return max( a/b, 0 )
		else:
			sys.exit("Undefined fragment length distribution.  Please use \"trunExp\" or \"Normal\".")
		return quad( temp_FUN, l1, u1 )[0] * den
	#### end of function "Int_Exp"
	
	#### start of function "F_paired" (handles both paired-end and single-end bins)
	def F_paired( G_i, G, bin_list, read_len, lower_len, upper_len, mean, sd, exon_den, dist ):
		#### start of function "f_paired"
		def f_paired( b_j, G_i, G, read_len, lower_len, upper_len, mean, sd, exon_den, dist ):
			# Positions are offsets along the spliced transcript, starting at 0 at
			# the start of subexon s.  [l1, u1] bounds the fragment start and
			# [l2, u2] the fragment end; setting u1 = l1 zeroes the integral.
			exons = sorted( list( set( b_j ) ) )
			
			n = len( read_len ) # the number of different read lengths
			if all( [ G_i[x-1] == 1 for x in exons ] ):
				min_s = 0 # the transcript start
				max_e = 0
				for i in range( len(G_i) ):
					if G_i[i] == 1:
						max_e += G[i][1] - G[i][0] + 1
				# --- the transcript end
				s = exons[ 0 ]
				e = exons[ -1 ]
				L1 = G[ s-1 ][1] - G[ s-1 ][0] + 1
				if len( exons ) == 1:
					l1 = [0] * n
					u1 = [ L1 - x for x in read_len ]
					l2 = read_len
					u2 = [L1] * n
				elif len( exons ) == 2:
					L2 = G[ e-1 ][1] - G[ e-1 ][0] + 1
					if b_j[1] == b_j[0] and b_j[3] == b_j[2]:
						l1 = [0] * n
						u1 = [ L1 - x for x in read_len ]
						l2 = [ L1 + x for x in read_len ]
						u2 = [L1 + L2] * n
						if e - s > 1:
							for i in range( s, e-1 ):
								if G_i[i] == 1:
									l2 = [ x + G[i][1] - G[i][0] + 1 for x in l2 ]
									u2 = [ x + G[i][1] - G[i][0] + 1 for x in u2]
					elif b_j[1] == b_j[0] and b_j[3] > b_j[2]:
						l1 = [0] * n
						u1 = [ L1 - x for x in read_len ]
						l2 = [L1] * n
						u2 = [ L1 + x for x in read_len ]
						if e - s > 1:
							mid_len = 0
							for i in range( s, e-1 ):
								if G_i[i] == 1:
									mid_len += G[i][1] - G[i][0] + 1
							for k in range( len( read_len ) ):
								if mid_len > read_len[k]:
									u1[k] = l1[k] # force the integral to be 0
								else:
									l2[k] = l2[k] + mid_len
					elif b_j[1] > b_j[0] and b_j[3] == b_j[2]:
						l1 = [ L1 - x for x in read_len ]
						u1 = [L1] * n
						l2 = [ L1 + x for x in read_len ]
						u2 = [L1 + L2] * n
						if e - s > 1:
							mid_len = 0
							for i in range( s, e-1 ):
								if G_i[i] == 1:
									mid_len += G[i][1] - G[i][0] + 1
							for k in range( len( read_len ) ):
								if mid_len > read_len[k]:
									u1[k] = l1[k] # force the integral to be 0
								else:
									l1[k] = l1[k] + mid_len
					elif b_j[1] > b_j[0] and b_j[3] > b_j[2]:
						l1 = [ L1 - x for x in read_len ]
						u1 = [L1] * n
						l2 = [L1] * n
						u2 = [ L1 + x for x in read_len ]
						if e - s > 1:
							mid_len = 0
							for i in range( s, e-1 ):
								if G_i[i] == 1:
									mid_len += G[i][1] - G[i][0] + 1
							for k in range( len( read_len ) ):
								if mid_len > read_len[k]:
									u1[k] = l1[k] # force the integral to be 0
								else:
									l1[k] = l1[k] + mid_len
									l2[k] = l2[k] + mid_len
							
				elif len( exons ) == 3:
					m = exons[1]
					# full subexon lengths (stop - start + 1), consistent with L1 above
					L2 = G[m-1][1] - G[m-1][0] + 1
					L3 = G[e-1][1] - G[e-1][0] + 1
					if b_j[1] == b_j[0]:
						l1 = [0] * n
						u1 = [ L1 - x for x in read_len ]
						l2 = [L1 + L2] * n
						u2 = [ L1 + L2 + x for x in read_len ]
						if m - s > 1:
							for i in range( s, m-1 ):
								if G_i[i] == 1:
									l2 = [ x + G[i][1] - G[i][0] + 1 for x in l2 ]
									u2 = [ x + G[i][1] - G[i][0] + 1 for x in u2 ]
						if e - m > 1:
							mid_len = 0
							for i in range( m, e-1 ):
								if G_i[i] == 1:
									mid_len += G[i][1] - G[i][0] + 1
							for k in range( len( read_len ) ):
								if mid_len > read_len[k]:
									u1[k] = l1[k] # force the integral to be 0
								else:
									l2[k] = l2[k] + mid_len
					elif b_j[3] == b_j[2]:
						l1 = [ L1 - x for x in read_len ]
						u1 = [L1] * n
						l2 = [ L1 + L2 + x for x in read_len ]
						u2 = [L1 + L2 + L3] * n
						if m - s > 1:
							mid_len = 0
							for i in range( s, m-1 ):
								if G_i[i] == 1:
									mid_len += G[i][1] - G[i][0] + 1
							for k in range( len( read_len ) ):
								if mid_len > read_len[k]:
									u1[k] = l1[k] # force the integral to be 0
								else:
									l1[k] = l1[k] + mid_len
									l2[k] = l2[k] + mid_len
									u2[k] = u2[k] + mid_len
						if e - m > 1:
							for i in range( m, e-1 ):
								if G_i[i] == 1:
									l2 = [ x + G[i][1] - G[i][0] + 1 for x in l2 ]
									u2 = [ x + G[i][1] - G[i][0] + 1 for x in u2 ]	
						
					elif b_j[2] == b_j[1]:
						l1 = [ L1 - x for x in read_len ]
						u1 = [L1] * n
						l2 = [L1 + L2] * n
						u2 = [ L1 + L2 + x for x in read_len ]
						if m - s > 1:
							mid_len = 0
							for i in range( s, m-1 ):
								if G_i[i] == 1:
									mid_len += G[i][1] - G[i][0] + 1
							for k in range( len( read_len ) ):
								if mid_len > read_len[k]:
									u1[k] = l1[k] # force the integral to be 0
								else:
									l1[k] = l1[k] + mid_len
									l2[k] = l2[k] + mid_len
									u2[k] = u2[k] + mid_len
						if e - m > 1:
							mid_len = 0
							for i in range( m, e-1 ):
								if G_i[i] == 1:
									mid_len += G[i][1] - G[i][0] + 1
							for k in range( len( read_len ) ):
								if mid_len > read_len[k]:
									u1[k] = l1[k] # force the integral to be 0
								else:
									l2[k] = l2[k] + mid_len
					elif b_j[2] == b_j[0]:
						l1 = [ L1 - x for x in read_len ]
						u1 = [L1] * n
						l2 = [L1] * n
						u2 = [ L1 + x for x in read_len ]
						mid_len_1 = 0 # the length of exons between s and m
						mid_len_2 = L2 # the length of exons between s and e
						if m - s > 1:
							for i in range( s, m-1 ):
								if G_i[i] == 1:
									mid_len_1 += G[i][1] - G[i][0] + 1
									mid_len_2 += G[i][1] - G[i][0] + 1
						if e - m > 1:
							for i in range( m, e-1 ):
								if G_i[i] == 1:
									mid_len_2 += G[i][1] - G[i][0] + 1
						for k in range( len( read_len ) ):
							if mid_len_2 > read_len[k] or mid_len_1 > read_len[k]:
								u1[k] = l1[k] # force the integral to be 0
							else:
								l1[k] = l1[k] + mid_len_1
								l2[k] = l2[k] + mid_len_2

					elif b_j[3] == b_j[1]:
						l1 = [ L1 - x for x in read_len ]
						u1 = [L1] * n
						l2 = [ L1 + L2 ] * n
						u2 = [ L1 + L2 + x for x in read_len ]
						mid_len_1 = l2[0] # the length of exons between s and e
						mid_len_2 = 0 # the length of exons between m and e
						if m - s > 1:
							for i in range( s, m-1 ):
								if G_i[i] == 1:
									mid_len_1 += G[i][1] - G[i][0] + 1
						if e - m > 1:
							for i in range( m, e-1 ):
								if G_i[i] == 1:
									mid_len_1 += G[i][1] - G[i][0] + 1
									mid_len_2 += G[i][1] - G[i][0] + 1
						for k in range( len( read_len ) ):
							if mid_len_1 > read_len[k] or mid_len_2 > read_len[k]:
								u1[k] = l1[k] # force the integral to be 0
							else:
								l1[k] = l1[k] + mid_len_1
								l2[k] = l2[k] + mid_len_2
											
				elif len( exons ) == 4:
					m1 = exons[1]
					m2 = exons[2]
					# full subexon lengths (stop - start + 1), consistent with L1 above
					L2 = G[m1-1][1] - G[m1-1][0] + 1
					L3 = G[m2-1][1] - G[m2-1][0] + 1
					L4 = G[e-1][1] - G[e-1][0] + 1
					l1 = [ L1 - x for x in read_len ]
					u1 = [L1] * n
					l2 = [L1 + L2 + L3] * n
					u2 = [ L1 + L2 + L3 + x for x in read_len ]
					if m2 - m1 > 1:
						for i in range( m1, m2-1 ):
							if G_i[i] == 1:
								l2 = [ x + G[i][1] - G[i][0] + 1 for x in l2 ]
								u2 = [ x + G[i][1] - G[i][0] + 1 for x in u2 ]
					if m1 - s > 1:
						mid_len = 0
						for i in range( s, m1-1 ):
							if G_i[i] == 1:
								mid_len += G[i][1] - G[i][0] + 1
						for k in range( len( read_len ) ):
							if mid_len > read_len[k]:
								u1[k] = l1[k] # force the integral to be 0
							else:
								l1[k] = l1[k] + mid_len
								l2[k] = l2[k] + mid_len
								u2[k] = u2[k] + mid_len
					if e - m2 > 1:
						mid_len = 0
						for i in range( m2, e-1 ):
							if G_i[i] == 1:
								mid_len += G[i][1] - G[i][0] + 1
						for k in range( len( read_len ) ):
							if mid_len > read_len[k]:
								u1[k] = l1[k] # force the integral to be 0
							else:
								l2[k] = l2[k] + mid_len
				# debugging aid: report bins whose exon pattern matched none of
				# the cases above (the boundary correction below then fails)
				try:
					l1
				except NameError:
					print( "%s %s %s" % ( b_j, G_i, G ) )
				# correct the boundaries
				l1 = [ max( x, min_s ) for x in l1 ]
				u1 = [ max( x, min_s ) for x in u1 ]
				l2 = [ min( x, max_e ) for x in l2 ]
				u2 = [ min( x, max_e ) for x in u2 ]
				
				# calculation of probs				
				result = []
				for i in range( len(lower_len) ):
					if l1[i] != u1[i] and l2[i] != u2[i]:
						temp = Int_Exp( float(l1[i]), float(u1[i]), float(l2[i]), float(u2[i]), float(lower_len[i]), float(upper_len[i]), float(mean[i]), float(sd[i]), exon_den[s-1], dist )
					else:
						temp = float(0)
					result.append( temp )		
			else:
				# the transcript lacks a subexon of this bin
				result = [0] * len(lower_len)
			return result
			#### end of function "f_paired"
		#### start of function "f_single"	
		def f_single( b_j, G_i, G, read_len, lower_len, upper_len, mean, sd, exon_den, dist ):
			exons = sorted( list( set( b_j ) ) )

			n = len( read_len ) # the number of different read lengths
			if all( [ G_i[x-1] == 1 for x in exons ] ):
				min_s = 0 # the transcript start
				max_e = 0
				for i in range( len(G_i) ):
					if G_i[i] == 1:
						max_e += G[i][1] - G[i][0] + 1
				# --- the transcript end (max_e = transcript length)
				s = exons[ 0 ] # bin start exon
				e = exons[ -1 ] # bin stop exon
				trpt_exons = [ i+1 for i in range( len(G_i) ) if G_i[i] == 1 ]
				trpt_s = trpt_exons[ 0 ]
				trpt_e = trpt_exons[ -1 ]
				L1 = G[ s-1 ][1] - G[ s-1 ][0] + 1
				if len( exons ) == 1:
					if ( s==trpt_s or s==trpt_e ): # the bin start exon is in the transcript end
						l1 = [0] * n
						u1 = [ L1 - x for x in read_len ]
						l2 = read_len
						u2 = [max_e] * n
						# correct the boundaries
						u1 = [ max( x, min_s ) for x in u1 ]
						# calculation of probs
						result = []
						for i in range( len(lower_len) ):
							if l1[i] != u1[i] and l2[i] != u2[i]:
								temp = Int_Exp( float(l1[i]), float(u1[i]), float(l2[i]), float(u2[i]), float(lower_len[i]), float(upper_len[i]), float(mean[i]), float(sd[i]), exon_den[s-1], dist )
							else:
								temp = float(0)
							result.append( temp )
					else:                          # the bin start exon is in the middle of the transcript
						l_exons = [ x for x in trpt_exons if x < s ]
						r_exons = [ x for x in trpt_exons if x > s ]
						l_frag_len = max_e - sum([ G[i-1][1] - G[i-1][0] + 1 for i in l_exons ])
						r_frag_len = max_e - sum([ G[i-1][1] - G[i-1][0] + 1 for i in r_exons ])
						l1_l = [0] * n
						u1_l = [ L1 - x for x in read_len ]
						l2_l = read_len
						u2_l = [l_frag_len] * n
						l1_r = [0] * n
						u1_r = [ L1 - x for x in read_len ]
						l2_r = [L1] * n
						u2_r = [r_frag_len] * n
						# correct the boundaries
						u1_l = [ max( x, min_s ) for x in u1_l ]
						u1_r = [ max( x, min_s ) for x in u1_r ]
						l2_l = [ min( x, l_frag_len ) for x in l2_l ]
						# calculation of probs
						result = []
						for i in range( len(lower_len) ):
							if l1_l[i] != u1_l[i] and l2_l[i] != u2_l[i]:
								temp1 = Int_Exp( float(l1_l[i]), float(u1_l[i]), float(l2_l[i]), float(u2_l[i]), float(lower_len[i]), float(upper_len[i]), float(mean[i]), float(sd[i]), exon_den[s-1], dist )
							else:
								temp1 = float(0)
							if l1_r[i] != u1_r[i] and l2_r[i] != u2_r[i]:
								temp2 = Int_Exp( float(l1_r[i]), float(u1_r[i]), float(l2_r[i]), float(u2_r[i]), float(lower_len[i]), float(upper_len[i]), float(mean[i]), float(sd[i]), exon_den[s-1], dist )
							else:
								temp2 = float(0)	
							result.append( temp1 + temp2 )
				elif len( exons ) == 2:
					L2 = G[ e-1 ][1] - G[ e-1 ][0] + 1
					l_exons = [ x for x in trpt_exons if x < s ]
					r_exons = [ x for x in trpt_exons if x > s ]
					l_frag_len = max_e - sum([ G[i-1][1] - G[i-1][0] + 1 for i in l_exons ])
					r_frag_len = max_e - sum([ G[i-1][1] - G[i-1][0] + 1 for i in r_exons ])
					l1_l = [ L1 - x for x in read_len ]
					u1_l = [L1] * n
					l2_l = [L1] * n
					u2_l = [l_frag_len] * n
					l1_r = [ L2 - x for x in read_len ]
					u1_r = [L2] * n
					l2_r = [ L2 + x for x in read_len ]
					u2_r = [r_frag_len] * n
					# correct the boundaries
					l2_r = [ min( x, r_frag_len ) for x in l2_r ]
					# calculation of probs
					result = []
					for i in range( len(lower_len) ):
						if l1_l[i] != u1_l[i] and l2_l[i] != u2_l[i]:
							temp1 = Int_Exp( float(l1_l[i]), float(u1_l[i]), float(l2_l[i]), float(u2_l[i]), float(lower_len[i]), float(upper_len[i]), float(mean[i]), float(sd[i]), exon_den[s-1], dist )
						else:
							temp1 = float(0)
						if l1_r[i] != u1_r[i] and l2_r[i] != u2_r[i]:
							temp2 = Int_Exp( float(l1_r[i]), float(u1_r[i]), float(l2_r[i]), float(u2_r[i]), float(lower_len[i]), float(upper_len[i]), float(mean[i]), float(sd[i]), exon_den[e-1], dist )
						else:
							temp2 = float(0)
						result.append( temp1 + temp2 )	
			else:
				# the transcript lacks a subexon of this bin
				result = [0] * len(lower_len)
			return result
			#### end of function "f_single"
					
		# one row of probabilities (over read lengths) per bin
		cond_probs = []
		for bin_j in bin_list:
			if len(bin_j) == 4: # paired-end bins
				temp = f_paired( bin_j, G_i, G, read_len, lower_len, upper_len, mean, sd, exon_den, dist )
			else: # single-end bins
				temp = f_single( bin_j, G_i, G, read_len, lower_len, upper_len, mean, sd, exon_den, dist )
			cond_probs.append( temp )
		# flatten read-length-major: all bins for the first read length, then the next, ...
		temp = numpy.transpose(cond_probs)
		temp2 = []
		for x in temp:
			temp2.extend(x)
		return( temp2 )
		#### end of function "F_paired"
	
	# calculation: one column per candidate transcript
	F = []
	for j in range( len(G_is) ):
		temp = F_paired( G_is[ j ], G, filtered_bins, read_lengths, frag_lower_bd, frag_upper_bd, frag_mean, frag_sd, densities[ j ], dist )
		F.append( temp )
	return( numpy.transpose(F) )

#########################################################################################

#########################################################################################
#def estimateTranscriptFrequencies( F, bin_counts, bins, G, G_is, lambda, total_n_reads ):
#	sum_bin_counts_by_rl = [ float( sum(x) ) for x in bin_counts ]
#	n = len( total_n_reads )
#	mean_total_n_reads = sum( total_n_reads ) / float(n)
#	total_n_reads_ratios = [ float(x) / mean_total_n_reads for x in total_n_reads ]
#	b = [] # used for isoform discovery
#	normalized_bin_counts = [] # used for quantification
#	for i in range( len( bin_counts ) ):
#		temp1 = [ x / sum_bin_counts_by_rl[ i ] * total_n_reads_ratios[ i ] for x in bin_counts[ i ] ]
#		b.extend( temp1 )
#		temp2 = [ x / total_n_reads_ratios[ i ] for x in bin_counts[ i ] ]
#		normalized_bin_counts.extend( temp2 )
	
#	if len(b) == 1:
		
		

#########################################################################################

#########################################################################################
def _get_annotated_isoforms( gtf_fp, gene, new_exons ):
	"""Read the annotated isoforms of `gene` from a GTF file as subexon indices.

	Returns a dict mapping transcript name to a list (possibly with repeats) of
	1-based indices into new_exons overlapped by the transcript's exons.  A
	transcript with any exon that overlaps no subexon is dropped entirely.
	"""
	annotated_isoforms = {}
	# transcripts already dropped; remembered so that a later exon line of the
	# same transcript does not re-create a truncated entry
	dropped_isoforms = set()

	gtf_fp.seek(0) # move the pointer to the start of the gtf file
	line = gtf_fp.readline()
	while line:
		data = re.split( r"\s+", line.strip() )
		# skip blank/header lines that have no feature-type column
		if len( data ) > 2 and data[2] == 'exon':
			# parse the meta data, and grab the gene name
			meta_data = dict( zip( data[8::2], ( i[:-1] for i in data[9::2] ) ) )
			try: gene_name = meta_data[ 'gene_name' ]
			except KeyError: gene_name = meta_data[ 'gene_id' ]
			if gene_name == gene.name: # the gene in the gtf line is the gene here
				try: transcript_name = meta_data[ 'transcript_name' ]
				except KeyError: transcript_name = meta_data[ 'transcript_id' ]
				if transcript_name not in dropped_isoforms:
					# subexons in "new_exons" overlapped by this exon
					exon_start = int( data[3] )
					exon_stop = int( data[4] )
					overlapping = [ i+1 for i, sub in enumerate( new_exons ) if not ( exon_stop < sub[0] or exon_start > sub[1] ) ]
					if overlapping:
						annotated_isoforms.setdefault( transcript_name, [] ).extend( overlapping )
					else:
						annotated_isoforms.pop( transcript_name, None )
						dropped_isoforms.add( transcript_name )
			elif len( annotated_isoforms ) > 0:
				# past this gene's block of lines (assumes the GTF groups lines by gene)
				break
		line = gtf_fp.readline()
	return annotated_isoforms

def get_linear_models( genes_names, samfile, read_type, frag_len, total_num_reads, choose_frag_len_dist, if_use_GC_correction, gene_name_specified, chr_seqs=None, mode='discovery' ): # process the data into the linear model components ready for sparse estimation
	"""Build the linear-model components of every gene.

	genes_names          : iterable of gene objects exposing .name,
	                       .exon_bndrys, .chromosome and .strand.
	samfile, read_type   : passed through to build_bins.
	frag_len             : mapping read length -> fragment-length parameter
	                       tuple (lower, upper, mean, sd) used by build_F.
	total_num_reads      : not used here.
	choose_frag_len_dist : fragment-length distribution name for build_F.
	if_use_GC_correction : enable GC-content weighting in build_F.
	gene_name_specified  : if non-empty, only the gene with this name is used.
	chr_seqs             : mapping chromosome -> sequence; required with GC
	                       correction.
	mode                 : 'discovery' enumerates candidate transcripts; any
	                       other value uses the annotated isoforms read from
	                       the module-level `gtf_fp` (defined elsewhere —
	                       NOTE(review): confirm it is set before this call).

	Returns a dict gene name -> model dict with keys 'num_subexons', 'chr',
	'strand', 'subexons', 'filtered_bins', 'bin_counts' (read length -> list
	of counts in bin order), 'filtered_transcripts' (list of 0/1 vectors in
	discovery mode, dict transcript name -> 0/1 numpy vector otherwise) and
	'F_mat'.
	"""
	gene_linear_models = {}

	for gene in genes_names:
		# check to make sure we have at least 1 exon ( this should always be true )
		if len( gene.exon_bndrys ) == 0:
			continue
		if gene_name_specified != "" and gene.name != gene_name_specified:
			continue

		# get non-overlapping exons
		new_exons = buildExons( gene.exon_bndrys, 10 )
		# enumerate all of existing ( filtered ) bins
		bin_built = build_bins( gene, samfile, read_type, new_exons )

		# collect the subexons that appear in any bin (bin names embed the indices)
		existing_exons = set()
		for bin_name in bin_built.keys():
			existing_exons.update( int(x) for x in re.findall( r'\d+', bin_name ) )
		# if no reads are in the gene
		if not existing_exons:
			continue

		# remove non-existent exons and re-enumerate the bins
		existing_exons = sorted( existing_exons )
		if len( existing_exons ) < len( new_exons ):
			new_exons = [ new_exons[i-1] for i in existing_exons ]
		bin_built = build_bins( gene, samfile, read_type, new_exons )
		bins = sorted( bin_built.keys() )
		# convert bin names from strings to vectors of subexon indices
		filtered_bins = [ [ int(x) for x in re.findall( r'\d+', bin_name ) ] for bin_name in bins ]
		
		if mode == "discovery":
			if len( new_exons ) < 20:
				# for genes < 20 existing exons, we enumerate all the possible transcripts
				transcripts = list( iter_transcripts( new_exons ) )
				possible_transcripts = build_array_from_transcripts( transcripts )
				filtered_transcripts = filterTranscripts( possible_transcripts, filtered_bins )
			else:
				# for genes >= 20 exons, only consider the longest possible transcripts
				filtered_transcripts = build_long_transcripts( new_exons, filtered_bins )
		else:
			annotated_isoforms = _get_annotated_isoforms( gtf_fp, gene, new_exons )
			if not annotated_isoforms:
				continue
			# reformat the annotated isoforms as 0/1 indicator vectors over subexons
			filtered_transcripts = {}
			for transcript_name, exon_idx in annotated_isoforms.items():
				indicator = numpy.zeros( len( new_exons ) )
				for i in set( exon_idx ):
					indicator[ i-1 ] = 1
				filtered_transcripts[ transcript_name ] = indicator
				
		# get bin counts by read lengths (each list follows the order of `bins`)
		bin_counts = {}
		for bin_name in bins:
			bin_count = bin_built[ bin_name ]
			for rl in bin_count.keys():
				bin_counts.setdefault( rl, [] ).append( bin_count[ rl ] )
		# not all the possible read lengths may exist in this gene, so F is only
		# built for the existing ones.  The fragment parameters must be listed in
		# the same (sorted) order as the read lengths handed to build_F.
		read_lengths = sorted( bin_counts.keys() )
		existing_frag_len = [ frag_len[ rl ] for rl in read_lengths ]

		# build the F matrix
		if len( filtered_bins ) == 1:
			# NOTE(review): fixed placeholder for single-bin genes; presumably
			# handled specially downstream — confirm
			F_mat = numpy.array( [ [1,1], [1,1] ] )
		else:
			if isinstance( filtered_transcripts, dict ):
				# list() keeps it indexable for build_F
				candidate_transcripts = list( filtered_transcripts.values() )
			else:
				candidate_transcripts = filtered_transcripts
			if if_use_GC_correction:
				chr_seq = chr_seqs[ gene.chromosome ]
			else:
				chr_seq = None
			F_mat = build_F( candidate_transcripts, filtered_bins, new_exons, read_lengths, existing_frag_len, choose_frag_len_dist, if_use_GC_correction, chr_seq )
		
		# put the gene model in the dictionary
		n = len( new_exons ) # number of subexons
		gene_linear_models[ gene.name ] = { 'num_subexons':n, 'chr':gene.chromosome, 'strand':gene.strand, 'subexons':new_exons, 'filtered_bins':filtered_bins, 'bin_counts':bin_counts, 'filtered_transcripts':filtered_transcripts, 'F_mat':F_mat }
	
	return gene_linear_models
#########################################################################################

#########################################################################################
def estimate_lambda( gene_linear_models, frag_len, total_num_reads, r_interp=None ): # estimate lambda by running lasso multiple times with one half of the reads
	"""Choose a lasso penalty for each subexon count by subsampling stability.

	For every gene with at least 2 subexons, more than one candidate transcript
	and more than one bin, half of the reads of each read length are subsampled
	n_resamp times, and r_interp.LassoEst is run for every candidate lambda.
	The stability of a gene is the mean selection frequency over the
	transcripts selected at least once.  It is averaged over all genes with the
	same number of subexons, and the lambda with the highest average is kept
	(stopping early once it reaches 1).

	gene_linear_models : dict as returned by get_linear_models.
	frag_len           : not used here.
	total_num_reads    : mapping read length -> total number of reads, used
	                     to weight the read lengths.
	r_interp           : object providing LassoEst( b, AW, lambda ), returning
	                     a 0/1 vector of selected transcripts; required.

	Returns ( selected_lambdas, pos_num_genes ): dicts keyed by subexon count
	giving the chosen lambda (0 when none) and the number of genes that
	contributed.  Both always contain the entry 1 -> 0.
	"""
	# lambda candidates
	lambda_candidates = [ 0.000001, 0.001, 0.01, 0.04, 0.07, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1 ]
	# number of half-subsamples per gene
	n_resamp = 50
	
	selected_lambdas = { 1:0 }
	pos_num_genes = { 1:0 }
	# stability measures: subexon count -> lambda -> one value per gene
	q_i = {}
	
	# weight (ratio) of each read length: total_num_reads / mean( total_num_reads )
	mean_total_num_reads = sum( total_num_reads.values() ) / float( len( total_num_reads ) )
	ratios = {}
	for rl, count in total_num_reads.items():
		ratios[ rl ] = float( count ) / mean_total_num_reads
	
	for gene in gene_linear_models.keys():
		model = gene_linear_models[ gene ]
		n = model[ 'num_subexons' ]
		if n < 2:
			continue # skip 1-subexon genes
		# initialise once per subexon count so results accumulate across genes
		if n not in q_i:
			q_i[ n ] = dict( ( ld, [] ) for ld in lambda_candidates )
		
		bin_counts = model[ 'bin_counts' ]
		filtered_transcripts = model[ 'filtered_transcripts' ]
		F_mat = model[ 'F_mat' ]
		# annotated models store transcripts as a name -> vector dict
		if isinstance( filtered_transcripts, dict ):
			candidate_transcripts = list( filtered_transcripts.values() )
		else:
			candidate_transcripts = filtered_transcripts
		n_transcripts = len( candidate_transcripts )
		## if there is only 1 transcript candidate or 1 bin count, skip it
		if n_transcripts == 1 or len( next( iter( bin_counts.values() ) ) ) == 1:
			continue
				
		## otherwise, do lasso estimation with various lambdas
		read_lengths = sorted( bin_counts.keys() )
		# 1. total number of reads per read length, and the reads listed as bin indices
		total_counts = {}
		reads_in_bin_indices = {}
		for rl in read_lengths:
			total_counts[ rl ] = sum( bin_counts[ rl ] )
			reads_in_bin_indices[ rl ] = []
			for index in range( len( bin_counts[ rl ] ) ):
				reads_in_bin_indices[ rl ].extend( [ index ] * bin_counts[ rl ][ index ] )
		# 2. sample 1/2 of the reads in each run; weighted bin proportions per sample
		b_resamp = []
		for _ in range( n_resamp ):
			sample_props = []
			for rl in read_lengths:
				reads = random.sample( reads_in_bin_indices[ rl ], total_counts[ rl ] // 2 )
				counts = [ 0 ] * len( bin_counts[ rl ] )
				for read in reads:
					counts[ read ] += 1
				n_sampled = sum( counts )
				if n_sampled == 0:
					# too few reads to subsample: no information for this read length
					sample_props.extend( [ 0.0 ] * len( counts ) )
				else:
					sample_props.extend( [ float( x ) / n_sampled * ratios[ rl ] for x in counts ] )
			b_resamp.append( sample_props )
		# 3. the number of exons in each candidate transcript
		G_is_nexons = [ sum( x ) for x in candidate_transcripts ]
		# 4. weight (ratio) matrix: one row per (read length, bin), one column per transcript
		ratio_mat = []
		for rl in read_lengths:
			ratio_mat.extend( [ [ ratios[ rl ] ] * n_transcripts ] * len( bin_counts[ rl ] ) )
		# 5. the design matrix in sparse estimation
		AW = F_mat * numpy.array( [ G_is_nexons ] * len( F_mat ) ) * numpy.array( ratio_mat )
		# 6. run lasso for each lambda and each subsample
		for ld in lambda_candidates:
			# float accumulator: R may hand back doubles
			q_irk = numpy.zeros( n_transcripts )
			for i in range( n_resamp ):
				# 0-1 vector of len( candidate_transcripts ): is each transcript discovered?
				results = r_interp.LassoEst( numpy.array( b_resamp[i] ), AW, ld )
				q_irk += numpy.asarray( results, dtype=float )
			pos_len = len( [ x for x in q_irk if x > 0 ] )
			if pos_len > 0:
				q_ir = sum( q_irk ) / ( n_resamp * float( pos_len ) )
			else:
				q_ir = 0
			q_i[ n ][ ld ].append( q_ir )
	################################# end of loop ################################	
	# average the stability measure over the genes with n subexons
	for n in q_i.keys():	
		ld_chosen = 0
		q = 0
		pos_n_genes = len( q_i[ n ][ lambda_candidates[0] ] )
		print( "%s %s" % ( n, pos_n_genes ) )
		if pos_n_genes > 0:
			for ld in lambda_candidates:
				q_temp = sum( q_i[ n ][ ld ] ) / float( pos_n_genes )
				if q_temp > q:
					q = q_temp
					ld_chosen = ld
				if q == 1:
					break			
		selected_lambdas[ n ] = ld_chosen
		pos_num_genes[ n ] = pos_n_genes
	return selected_lambdas, pos_num_genes
#########################################################################################

#########################################################################################
def estimate_transcript_freqs( gene_name, gene_linear_models, frag_len, total_num_reads, ld, r_interp=None, mode='discovery' ):
    if r_interp == None:
        r_interp = rpy.r

    #if gene_name != '"NBPF11"':
	#	return
    
    if VERBOSE:
        print( "==================NEW GENE %s ====================" % gene_name)
    

    #if len( gene.exon_bndrys ) < 15:
	#    return
    # read in subexons

    # if no linear model has been built for "gene_name" (0 bin counts), skip it
    if not gene_linear_models.has_key( gene_name ):
	    return
    
    # read in the subexons
    new_exons = gene_linear_models[ gene_name ][ 'subexons' ]

	# check to make sure we have at least 1 exon 
    # ( this should always be true )
    if len( new_exons ) == 0:
        print( "BUG PROBABLY - can't fit with 0 exons but this doesn't seem possible." )
        if PAUSE_ON_ERROR:
            raw_input("Press enter to conitnue...")
        return

	# read in ( filtered ) bins and bin counts
    filtered_bins = gene_linear_models[ gene_name ][ 'filtered_bins' ]
    bin_counts = gene_linear_models[ gene_name ][ 'bin_counts' ]

    # read in filtered transcripts
    filtered_transcripts = gene_linear_models[ gene_name ][ 'filtered_transcripts' ]

    # read in the F matrix
    F_mat = gene_linear_models[ gene_name ][ 'F_mat' ]

    # strand
    strand = gene_linear_models[ gene_name ][ 'strand' ]

    # gene name
    name = re.split('"', gene_name)[1]

    # print parameters
    print "# of possible transcripts: %d" %( len(filtered_transcripts) )
    print "# of bins: %d" %( len(filtered_bins) )
    print "# of bin count values: %d" %( len(bin_counts.values()) )

    # estimation
    if isinstance( ld, dict ) or isinstance( ld, float ):
        if isinstance( ld, dict ):
            ld_temp = ld[ len( new_exons ) ]
        else:
            ld_temp = ld
        if not isinstance( filtered_transcripts, dict ):
            freqs = r_interp.estimateTranscriptFrequencies( 
	                numpy.array(F_mat), 
	                numpy.array(bin_counts.values()),             \
	                numpy.array(filtered_bins),               \
	                numpy.array(new_exons), \
	                numpy.array(filtered_transcripts),          \
	                float(ld_temp),                       \
	                numpy.array( total_num_reads.values() ),   \
					mode
	           )
        else:
            freqs = r_interp.estimateTranscriptFrequencies( 
	                numpy.array(F_mat), 
	                numpy.array(bin_counts.values()),             \
	                numpy.array(filtered_bins),               \
	                numpy.array(new_exons), \
	                numpy.array(filtered_transcripts.values()),          \
	                float(ld_temp),                       \
	                numpy.array( total_num_reads.values() ),   \
					mode
	           )
        output_list = [] # store the output lines
        if not len(freqs) == 0:
            n_trpt = len(freqs)/3
            for i in range( n_trpt ):
                trpt = freqs[i]
                exon_idx = [ int(x) for x in re.findall(r'\d+', trpt) ]
                combined_exons = []
                for j in range(len(exon_idx)):
                    if j == 0:
                        combined_exons.append( new_exons[exon_idx[j]-1] )
                    else:
                        if new_exons[ exon_idx[j]-1 ][0] == new_exons[ exon_idx[j-1]-1 ][1] + 1 or new_exons[ exon_idx[j]-1 ][0] == new_exons[ exon_idx[j-1]-1 ][1] or new_exons[ exon_idx[j]-1 ][0] == new_exons[ exon_idx[j-1]-1 ][1] - 1:
                            #combined_exons[ -1 ][1] = new_exons_temp[ exon_idx[j]-1 ][1]
                            combined_exons[ -1 ] = [ combined_exons[ -1 ][0], new_exons[ exon_idx[j]-1 ][1] ]
                        else:
                            combined_exons.append( new_exons[exon_idx[j]-1] )
                if not isinstance( filtered_transcripts, dict ):
                    transcript_name = name + '.' + str(i+1)
                else:
                    transcript_name = filtered_transcripts.keys()[i]

                output_list.append( gene_linear_models[ gene_name ][ 'chr' ] + '\t' + 'SLIDE' + '\t' + 'transcript\t' + str(combined_exons[0][0]) + '\t' + str(combined_exons[-1][1]) + '\t.\t' + strand + '\t.\t' + 'gene_id "' + name + '"; ' + 'transcript_id "' + transcript_name + '"; RPKM "' + freqs[n_trpt+i] + '"; frac "' + freqs[2*n_trpt+i] + '";\n'  )

                if strand == '+':
	                for k in range(len(combined_exons)):
	                    output_list.append( gene_linear_models[ gene_name ][ 'chr' ] + '\t' + 'SLIDE' + '\t' + 'exon\t' + str(combined_exons[k][0]) + '\t' + str(combined_exons[k][1]) + '\t.\t' + strand + '\t.\t' + 'gene_id "' + name + '"; ' + 'transcript_id "' + transcript_name + '"; exon_number "' + str(k+1) + '"; RPKM "' + freqs[n_trpt+i] + '"; frac "' + freqs[2*n_trpt+i] + '";\n'  )
                else:
	                for k in range( len(combined_exons)-1, -1, -1 ):
	                    output_list.append( gene_linear_models[ gene_name ][ 'chr' ] + '\t' + 'SLIDE' + '\t' + 'exon\t' + str(combined_exons[k][0]) + '\t' + str(combined_exons[k][1]) + '\t.\t' + strand + '\t.\t' + 'gene_id "' + name + '"; ' + 'transcript_id "' + transcript_name + '"; exon_number "' + str(k+1) + '"; RPKM "' + freqs[n_trpt+i] + '"; frac "' + freqs[2*n_trpt+i] + '";\n'  )
        return output_list
    else:
        ld_temp = ld
        freqs = r_interp.estimateTranscriptFrequencies( 
                numpy.array(F_mat), 
                numpy.array(bin_counts.values()),             \
                numpy.array(filtered_bins),               \
                numpy.array(new_exons), \
                numpy.array(filtered_transcripts),          \
                ld_temp[0],                       \
                numpy.array( total_num_reads.values() )   \
           )
        output1_list = [] # store the output lines for the smaller lambda
        if not len(freqs) == 0:
            n_trpt = len(freqs)/3
            for i in range( n_trpt ):
                trpt = freqs[i]
                exon_idx = [ int(x) for x in re.findall(r'\d+', trpt) ]
                combined_exons = []
                for j in range(len(exon_idx)):
                    if j == 0:
                        combined_exons.append( new_exons[exon_idx[j]-1] )
                    else:
                        if new_exons[ exon_idx[j]-1 ][0] == new_exons[ exon_idx[j-1]-1 ][1] + 1 or new_exons[ exon_idx[j]-1 ][0] == new_exons[ exon_idx[j-1]-1 ][1] or new_exons[ exon_idx[j]-1 ][0] == new_exons[ exon_idx[j-1]-1 ][1] - 1:
                            #combined_exons[ -1 ][1] = new_exons_temp[ exon_idx[j]-1 ][1]
                            combined_exons[ -1 ] = [ combined_exons[ -1 ][0], new_exons[ exon_idx[j]-1 ][1] ]
                        else:
                            combined_exons.append( new_exons[exon_idx[j]-1] )

                output1_list.append( gene_linear_models[ gene_name ][ 'chr' ] + '\t' + 'SLIDE' + '\t' + 'transcript\t' + str(combined_exons[0][0]) + '\t' + str(combined_exons[-1][1]) + '\t.\t' + strand + '\t.\t' + 'gene_id "' + name + '"; ' + 'transcript_id "' + name + '.' + str(i+1) + '"; RPKM "' + freqs[n_trpt+i] + '"; frac "' + freqs[2*n_trpt+i] + '";\n'  )

                if strand == '+':
	                for k in range(len(combined_exons)):
	                    output1_list.append( gene_linear_models[ gene_name ][ 'chr' ] + '\t' + 'SLIDE' + '\t' + 'exon\t' + str(combined_exons[k][0]) + '\t' + str(combined_exons[k][1]) + '\t.\t' + strand + '\t.\t' + 'gene_id "' + name + '"; ' + 'transcript_id "' + name + '.' + str(i+1) + '"; exon_number "' + str(k+1) + '"; RPKM "' + freqs[n_trpt+i] + '"; frac "' + freqs[2*n_trpt+i] + '";\n'  )
                else:
	                for k in range( len(combined_exons)-1, -1, -1 ):
	                    output1_list.append( gene_linear_models[ gene_name ][ 'chr' ] + '\t' + 'SLIDE' + '\t' + 'exon\t' + str(combined_exons[k][0]) + '\t' + str(combined_exons[k][1]) + '\t.\t' + strand + '\t.\t' + 'gene_id "' + name + '"; ' + 'transcript_id "' + name + '.' + str(i+1) + '"; exon_number "' + str(k+1) + '"; RPKM "' + freqs[n_trpt+i] + '"; frac "' + freqs[2*n_trpt+i] + '";\n'  )

        freqs = r_interp.estimateTranscriptFrequencies( 
                numpy.array(F_mat), 
                numpy.array(bin_counts.values()),             \
                numpy.array(filtered_bins),               \
                numpy.array(new_exons), \
                numpy.array(filtered_transcripts),          \
                ld_temp[1],                       \
                numpy.array( total_num_reads.values() )   \
           )
        output2_list = [] # store the output lines for the larger lambda
        if not len(freqs) == 0:
            n_trpt = len(freqs)/3
            for i in range( n_trpt ):
                trpt = freqs[i]
                exon_idx = [ int(x) for x in re.findall(r'\d+', trpt) ]
                combined_exons = []
                for j in range(len(exon_idx)):
                    if j == 0:
                        combined_exons.append( new_exons[exon_idx[j]-1] )
                    else:
                        if new_exons[ exon_idx[j]-1 ][0] == new_exons[ exon_idx[j-1]-1 ][1] + 1 or new_exons[ exon_idx[j]-1 ][0] == new_exons[ exon_idx[j-1]-1 ][1] or new_exons[ exon_idx[j]-1 ][0] == new_exons[ exon_idx[j-1]-1 ][1] - 1:
                            #combined_exons[ -1 ][1] = new_exons_temp[ exon_idx[j]-1 ][1]
                            combined_exons[ -1 ] = [ combined_exons[ -1 ][0], new_exons[ exon_idx[j]-1 ][1] ]
                        else:
                            combined_exons.append( new_exons[exon_idx[j]-1] )

                output2_list.append( gene_linear_models[ gene_name ][ 'chr' ] + '\t' + 'SLIDE' + '\t' + 'transcript\t' + str(combined_exons[0][0]) + '\t' + str(combined_exons[-1][1]) + '\t.\t' + strand + '\t.\t' + 'gene_id "' + name + '"; ' + 'transcript_id "' + name + '.' + str(i+1) + '"; RPKM "' + freqs[n_trpt+i] + '"; frac "' + freqs[2*n_trpt+i] + '";\n'  )

                if strand == '+':
	                for k in range(len(combined_exons)):
	                    output2_list.append( gene_linear_models[ gene_name ][ 'chr' ] + '\t' + 'SLIDE' + '\t' + 'exon\t' + str(combined_exons[k][0]) + '\t' + str(combined_exons[k][1]) + '\t.\t' + strand + '\t.\t' + 'gene_id "' + name + '"; ' + 'transcript_id "' + name + '.' + str(i+1) + '"; exon_number "' + str(k+1) + '"; RPKM "' + freqs[n_trpt+i] + '"; frac "' + freqs[2*n_trpt+i] + '";\n'  )
                else:
	                for k in range( len(combined_exons)-1, -1, -1 ):
	                    output2_list.append( gene_linear_models[ gene_name ][ 'chr' ] + '\t' + 'SLIDE' + '\t' + 'exon\t' + str(combined_exons[k][0]) + '\t' + str(combined_exons[k][1]) + '\t.\t' + strand + '\t.\t' + 'gene_id "' + name + '"; ' + 'transcript_id "' + name + '.' + str(i+1) + '"; exon_number "' + str(k+1) + '"; RPKM "' + freqs[n_trpt+i] + '"; frac "' + freqs[2*n_trpt+i] + '";\n'  )
		
        return output1_list, output2_list
	
#########################################################################################    

#########################################################################################
def usage():
    """Print the command-line help for slide.py and exit with status 1.

    Called whenever the command line is malformed (wrong number/order of
    arguments, an option argument not starting with "--") or a required
    input such as the fasta directory is missing.
    """
    help_lines = (
        "#"*40,
        "USAGE:",
        "\tpython slide.py input.gtf reads.bam output.gtf --read_type --mode --thread --total_num_reads --choose_frag_len_dist --if_calculate_frag_len_param --if_estimate_lambda --user_specified_lambda --if_use_GC_correction --fasta_dir --gene_name\n",
        "\t###Required###",
        "\tinput.gtf: an input gtf file containing annotated or de novo assembled transcripts and exons;",
        # pysam looks for the standard BAM index (reads.bam.bai)
        "\treads.bam: an input bam file containing mapped paired-end/single-end RNA-Seq reads, requiring an associated .bai index file;",
        "\toutput.gtf: an output gtf file containing discovered mRNA transcripts/isoforms with estimated abundance;\n",
        "\t--read_type: values in {single-end, paired-end, mixed, unknown} [e.g. --read_type single-end];\n",
        "\t###Optional###",
        "\t--mode: values in {discovery, estimation}; 'discovery' means that SLIDE first discovers mRNA isoforms and then estimate their abundance; 'estimation' means that SLIDE directly estimates the abundance of isoforms in the 'input.gtf' file [e.g. --mode discovery]; DEFAULT = discovery;\n",
        "\t--thread: the number of threads/processes used in parallel computing [e.g. --thread 2]; DEFAULT = 1;",
        "\t--total_num_reads: the total number of reads if there is only one read length in 'reads.bam' [e.g. --total_num_reads 6000000]; for multiple read lengths, please ignore this option; DEFAULT = NULL;",
        # the example must name the option that the argument parser actually accepts
        "\t--choose_frag_len_dist: choose the fragment length distribution as truncated Exponential or truncated Normal with range (5% percentile, 95% percentile), mean and sd given by the '--if_calculate_frag_len_param' option [e.g. --choose_frag_len_dist Exponential/Normal]; DEFAULT = Exponential;",
        "\t--if_calculate_frag_len_param: whether or not to calculate empirical fragment length distributional parameters (5% percentile, 95% percentile, mean, sd) from given reads in 'reads.bam' [e.g. --if_calculate_frag_len_param TRUE]; when FALSE, use default parameters (100, 300, 200, 30); DEFAULT = FALSE;",
        "\t--if_estimate_lambda: whether or not to estimate lambda by a stability criterion (Meinshausen and Buhlmann, 2010) [e.g. --if_estimate_lambda TRUE]; when FALSE, use default lambda values 0.01 and 0.2; DEFAULT = FALSE;",
        "\t--user_specified_lambda: a lambda value input by the user [e.g. --user_specified_lambda 0.3]; if provided, it will be used and override the '--if_estimate_lambda' option;",
        "\t--if_use_GC_correction: whether or not to use GC content information to model the starting position distribution of cDNA fragments [e.g. --if_use_GC_correction TRUE]; when FALSE, model the starting position distribution as uniform in a transcript; DEFAULT = FALSE;",
        "\t--fasta_dir: the directory of genome fasta sequences, one fasta file for each chromosome; note that the file names [e.g. chr1.fa] must match the chromosome names in 'input.gtf'; REQUIRED when '--if_use_GC_correction' is TRUE;",
        '\t--gene_name: if specified, only isoforms of the gene with the specified name will be returned [e.g. "CG12370"];  Please note that the specified gene name should be in "input.gtf";',
        "#"*40,
    )
    for line in help_lines:
        print( line )
    sys.exit( 1 )
    
    
#########################################################################################
################################### MAIN ####################################
#########################################################################################

if __name__ == '__main__':

    ### read in arguments from the command line
    # Expected layout:
    #   slide.py input.gtf reads.bam output.gtf --read_type TYPE [--option value ...]
    # i.e. three positional file arguments followed by "--option value" pairs, which
    # makes len(sys.argv) even; anything else prints the usage and exits.
    if len( sys.argv ) < 5:
        usage()
    if ( len( sys.argv ) % 2 != 0 ) or ( not sys.argv[1].endswith('.gtf') ) or ( not sys.argv[2].endswith('.bam') ) or ( not sys.argv[3].endswith('.gtf') ) or ( not sys.argv[4].startswith('--read_type') ):
        usage()

    # Each recognised option binds a module-level variable of the same meaning;
    # options that are not given stay unbound and receive their defaults later
    # in the script.
    for i in range( 4, len( sys.argv ), 2 ):
        # if the option argument doesn't start with "--", quit
        if not sys.argv[i].startswith('--'):
            usage()
        option = sys.argv[i][2:]
        option_value = sys.argv[i+1]
        if option == 'read_type':
            if option_value not in ( "paired-end", "single-end", "mixed", "unknown" ):
                print( "The 'read_type' value should be in {paired-end, single-end, mixed, unknown}" )
                sys.exit(1)
            read_type = option_value
        elif option == 'mode':
            if option_value not in ( "discovery", "estimation" ):
                print( "The 'mode' value should be in {discovery, estimation}" )
                sys.exit(1)
            mode = option_value
        elif option == 'thread':
            num_threads = int( option_value )
        elif option == 'total_num_reads':
            total_num_reads = int( option_value )
        elif option == 'choose_frag_len_dist':
            if option_value != 'Exponential' and option_value != 'Normal':
                print( "Wrong distribution: %s. Please specify it as 'Exponential' or 'Normal'" % option_value )
                sys.exit( 1 )
            # bind the name that the defaulting code and get_linear_models() read
            choose_frag_len_dist = option_value
        elif option == 'if_calculate_frag_len_param':
            if_calculate_frag_len_param = option_value in ( 'TRUE', 'True', 'T', 'true' )
        elif option == 'if_estimate_lambda':
            if_estimate_lambda = option_value in ( 'TRUE', 'True', 'T', 'true' )
        elif option == 'user_specified_lambda':
            ld = float( option_value )
        elif option == 'if_use_GC_correction':
            if_use_GC_correction = option_value in ( 'TRUE', 'True', 'T', 'true' )
        elif option == 'fasta_dir':
            fasta_dir = option_value
        elif option == 'gene_name':
            # gene names parsed from the GTF keep their surrounding double quotes
            gene_name_specified = '"' + option_value + '"'
        else:
            # previously ignored silently; a typo in an option name is easy to miss
            print( "Warning: unrecognized option '--%s' is ignored." % option )

    # open R source code
    pathname = os.path.dirname( sys.argv[0] )
    r.source( os.path.join( pathname, "FUNCTIONS.R" ) )

##########################################################################################    
    ### handle the required arguments ###

	# load the bam file (RNA-Seq reads)
    samfile = pysam.Samfile( sys.argv[2], "rb" )

    # make sure if the chromosome names in the bam file start with "chr" or not
    for read in samfile.fetch():
        samchr = samfile.getrname( read.tid )
        break
    global ifchr
    ifchr = samchr.startswith("chr")

    # get the exon boundaries (annotations)
    gtf_fp = open( sys.argv[1] )
    genes = GeneBoundaries( gtf_fp )
    
    # output gtf
    output_filename_split = re.split("\.gtf", sys.argv[3])

##########################################################################################    
    ### handle the optional arguments ###

    # if the mode is not specified, use the default mode = 'discovery'
    try:
        mode
    except NameError:
        mode = 'discovery'
    print "The mode is " + mode + "."

    # if the number of threads is not specified, use the default 1
    try:
        num_threads
    except NameError:
        num_threads = 1
    print "The number of threads is %d." % num_threads
    
    # count total number of reads for each read length
    try:
        total_num_reads
    except NameError:
        total_num_reads = {}
        for read in samfile:
            if not total_num_reads.has_key(read.qlen):
                if read_type == "paired-end":	
                    total_num_reads[read.qlen] = 0.5
                else:
                    total_num_reads[read.qlen] = 1
            else:
                if read_type == "paired-end":	
                    total_num_reads[read.qlen] += 0.5
                else:
                    total_num_reads[read.qlen] += 1
    if not isinstance(total_num_reads, dict):
        temp = total_num_reads
        total_num_reads = {}
        for read in samfile:
			total_num_reads[read.qlen] = temp
			break
    for rl in sorted(total_num_reads.keys()):
        print "There are %d reads with length %d." % ( total_num_reads[rl], rl ) 


    # Fragment length distribution: 'Exponential' unless chosen on the command line.
    choose_frag_len_dist = globals().get( 'choose_frag_len_dist', 'Exponential' )
    print( "Use '%s' as the fragment length distribution." % choose_frag_len_dist )

    # Per read length, the fragment length parameters are
    # [5% percentile, 95% percentile, mean, sd]: estimated from the reads only when
    # requested, otherwise the defaults (100, 300, 200, 30) for every read length.
    if_calculate_frag_len_param = globals().get( 'if_calculate_frag_len_param', False )
    if not if_calculate_frag_len_param:
        print("Use default parameters for the fragment length distribution:")
        frag_len = dict( ( rl, [100,300,200,30] ) for rl in total_num_reads.keys() )
    else:
        print("Calculate parameters of the empirical fragment length distribution:")
        frag_len = get_frag_len( genes.values(), samfile )
    for rl in sorted( frag_len.keys() ):
        params = frag_len[ rl ]
        print( "\tFor reads of length %d: the 5%% percentile is %f, the 95%% percentile is %f, the mean is %f, and the sd is %f." % ( rl, float( params[0] ), float( params[1] ), float( params[2] ), float( params[3] ) ) )

    # lambda is estimated by the stability criterion only when requested
    if_estimate_lambda = globals().get( 'if_estimate_lambda', False )
    
    # Whether or not to use GC content correction in the modeling of the conditional
    # probability matrix.  When it is used, the genome fasta files (one per
    # chromosome) are read into chr_seqs: chromosome name -> sequence string.
    try:
        if_use_GC_correction
    except NameError:
        if_use_GC_correction = False

    if not if_use_GC_correction:
        print( "GC content correction is not applied." )
    else:
        try:
            fasta_dir
        except NameError:
            print( "Please specify the directory of genome fasta files" )
            usage()
        # os.listdir below needs a directory, not merely an existing path
        if not os.path.isdir( fasta_dir ):
            print( "The directory of genome fasta files doesn't exist.  Please specify it correctly." )
            usage()
        print( "GC content correction is applied." )
        fasta_dir = os.path.abspath( fasta_dir )
        fasta_list = os.listdir( fasta_dir )
        chr_seqs = {} # a dictionary containing chromosome sequences
        for name in fasta_list:
            # skip fasta index files (e.g. chr1.fa.fai from samtools faidx): they share
            # the chromosome name of their fasta file and would overwrite its sequence
            if name.endswith( ".fai" ):
                continue
            # chromosome name = file name up to the first "." (chr1.fa -> chr1),
            # renamed to follow the "chr" prefix convention of the BAM file (ifchr)
            chr_name = name.split(".")[0]
            if chr_name.startswith("chr") and not ifchr:
                # drop the "chr" prefix itself (previously chr_name[0][3:], always '')
                chr_name = chr_name[3:]
            elif not chr_name.startswith("chr") and ifchr:
                chr_name = "chr" + chr_name
            # assumes one record per file: skip the first (header) line and
            # concatenate the remaining sequence lines
            with open( os.path.join( fasta_dir, name ) ) as seq_file:
                chr_seqs[ chr_name ] = ''.join( [ seq.strip() for seq in seq_file.readlines()[1:] ] )

    try:
        gene_name_specified
    except NameError:
        gene_name_specified = ""
       
##########################################################################################    
    ### get linear model components for every gene ###
    print("Build linear models for all the genes in '%s'..." % sys.argv[1])
    if if_use_GC_correction:
        gene_linear_models = get_linear_models( genes.values(), samfile, read_type, frag_len, total_num_reads, choose_frag_len_dist, if_use_GC_correction, gene_name_specified, chr_seqs, mode=mode )
    else:
        gene_linear_models = get_linear_models( genes.values(), samfile, read_type, frag_len, total_num_reads, choose_frag_len_dist, if_use_GC_correction, gene_name_specified, mode=mode )

	
    ### if the user specifies an lambda value, input it into SLIDE; otherwise, if the user chooses to estimate lambda, estimate lambda values by a stability criterion; otherwise, use the default lambda values (0.01 for more isoforms with lower confidence, 0.2 for fewer isoforms with higher confidence)
    try:
        ld
    except NameError:
        if if_estimate_lambda:
            print("Calculate lambda values for sparse estimation by a stability criterion...")
            ld, ngenes = estimate_lambda( gene_linear_models, frag_len, total_num_reads, r )
            for n in sorted(ld.keys()):
                if n < 2:
                    print "For 1-exon genes, there is no need to use sparse estimation."
                else:
                    print "For %d-exon genes, lambda = %f (selected from %d genes)" % ( n, ld[n], ngenes[n] )
            output_filename = output_filename_split[0]+".gtf"
            output = open( output_filename, "w" )
        elif mode=='discovery':
            print("Use default lambda values: lambda = 0.01 for returning a larger number of isoforms including lowly confident ones; lambda = 0.2 for returning a smaller number of highly confident isoforms.")
            ld = [ 0.01, 0.2 ]
            output1_filename = output_filename_split[0]+"_more.gtf"
            output2_filename = output_filename_split[0]+"_fewer.gtf"
            output1 = open( output1_filename, "w" )
            output2 = open( output2_filename, "w" )
        else:
			ld = 0.0

    if isinstance( ld, float ):
        if mode=='discovery':
            print "Use user-specified lambda value: %f" % ld
        else:
            print "This is to estimate the abundance of annotated isoforms only."
        output_filename = output_filename_split[0]+".gtf"
        output = open( output_filename, "w" )
    if mode=='discovery':
        print "###Warning: The isoform discovery results are highly dependent on the lambda values.  Since we don't know the biological truth (i.e. true isoforms), we encourage you to judge the results based on your biological knowledge/experience.  If you find that there are too many short isoforms as fragments of the 'true isoforms' in the results, you may increase the lambda value to obtain a smaller number of longer isoforms.  On the contratry, if you find that there are some missing isoforms in the results, you may decrease the lambda value to obtain a larger number of possible isoforms.  We will highly appreciate your feedback on the results of SLIDE."

    ### main estimation function
    def estimate_genes_expression( gene_names ):
        """Estimate transcript frequencies for a batch of genes.

        gene_names: iterable of gene names.  Names without a linear model, or
        different from gene_name_specified when that is non-empty, are skipped.

        Returns {chrom: {(start_pos, gene_name): output lines}} when a single
        output file is written (ld is a dict or a float).  Otherwise, for the
        default pair of lambda values, it returns a tuple of two such dicts
        (smaller lambda, larger lambda).

        Keys sort by gene start position first, so sorting them still yields
        genomic order.  The gene name is part of the key so that two genes
        starting at the same position on the same chromosome no longer
        overwrite each other's output.

        Reads the module-level names gene_linear_models, gene_name_specified,
        frag_len, total_num_reads, ld, r and mode set up by the main script.
        """
        single_output = isinstance(ld, dict) or isinstance(ld, float)
        output_dict = {}
        output1_dict = {}
        output2_dict = {}

        for gene_name in gene_names:
            if gene_name not in gene_linear_models:
                continue
            if gene_name_specified != "" and gene_name != gene_name_specified:
                continue

            model = gene_linear_models[ gene_name ]
            # output is grouped by chromosome, then keyed by gene start (ties broken by name)
            chrom = model[ 'chr' ]
            key = ( model[ 'subexons' ][0][0], gene_name )
            if single_output:
                output_dict.setdefault( chrom, {} )[ key ] = estimate_transcript_freqs( gene_name, gene_linear_models, frag_len, total_num_reads, ld, r, mode=mode )
            else:
                lines1, lines2 = estimate_transcript_freqs( gene_name, gene_linear_models, frag_len, total_num_reads, ld, r, mode=mode )
                output1_dict.setdefault( chrom, {} )[ key ] = lines1
                output2_dict.setdefault( chrom, {} )[ key ] = lines2

        if single_output:
            return output_dict
        return output1_dict, output2_dict

##### multiple threads #####
    if num_threads > 1:
        pool = multiprocessing.Pool(processes=num_threads)
        keys = tuple(genes.keys())    
        size = len( keys )/num_threads
        
        keys_list = []
        for loop in xrange( num_threads-1 ):
            keys_list.append( keys[loop*size:(loop+1)*size]  )
        keys_list.append( keys[(num_threads-1)*size:]  )
        results = pool.map( estimate_genes_expression, keys_list )
        if isinstance(ld, dict) or isinstance(ld, float):
            output_dict = {}
            for output_dict_part in results:
                for chrom in output_dict_part.keys():
                    if chrom in output_dict.keys():
                        output_dict[chrom].update( output_dict_part[chrom] )
                    else:
                        output_dict[chrom] = output_dict_part[chrom]	
            for chrom in sorted( output_dict.keys() ):
	            for start_pos in sorted( output_dict[ chrom ].keys() ):
	                for line in output_dict[ chrom ][ start_pos ]:
	                    output.write( line )
            output.close()
        else:
            output1_dict = {}
            output2_dict = {}
            for output1_dict_part, output2_dict_part in results:
                for chrom in output1_dict_part.keys():
                    if chrom in output1_dict.keys():
                        output1_dict[chrom].update( output1_dict_part[chrom] )
                        output2_dict[chrom].update( output2_dict_part[chrom] )
                    else:
                        output1_dict[chrom] = output1_dict_part[chrom]
                        output2_dict[chrom] = output2_dict_part[chrom]
            for chrom in sorted( output1_dict.keys() ):
	            for start_pos in sorted( output1_dict[ chrom ].keys() ):
	                for line in output1_dict[ chrom ][ start_pos ]:
	                    output1.write( line )
	                for line in output2_dict[ chrom ][ start_pos ]:
	                    output2.write( line )
            output1.close()
            output2.close()
    else:
        if isinstance(ld, dict) or isinstance(ld, float):
            output_dict = estimate_genes_expression( genes.keys() )
            for chrom in sorted( output_dict.keys() ):
                for start_pos in sorted( output_dict[ chrom ].keys() ):
                    for line in output_dict[ chrom ][ start_pos ]:
                        output.write( line )
            output.close()
        else:
            output1_dict, output2_dict = estimate_genes_expression( genes.keys() )
            for chrom in sorted( output1_dict.keys() ):
                for start_pos in sorted( output1_dict[ chrom ].keys() ):
                    for line in output1_dict[ chrom ][ start_pos ]:
                        output1.write( line )
                    for line in output2_dict[ chrom ][ start_pos ]:
                        output2.write( line )
            output1.close()
            output2.close()

    """
    # if we want to keep track of the insert distribution
    
    with open( "insert_dist.txt", "w" ) as fp:
        print >> fp, "insert,read_len,cnt"
        for key in sorted( insert_dist ):
            print >> fp, ",".join( ( str(key[0]), str(key[1]), str(insert_dist[key]) ) )
    """
    samfile.close()
    gtf_fp.close()
    
    
 

   
    
