# ReXpress Version 0.9.1
# http://bio.math.berkeley.edu/ReXpress
# Copyright 2013 Adam Roberts and Lorian Schaeffer. All Rights Reserved

import sys
import numpy
import pysam
import itertools
import filecmp
import argparse
import subprocess
import os
import time
import pickle
import tempfile
import shutil

class SamIterator(object):
	'''
	Lock-step iterator over two consistently-sorted SAM/BAM files.

	Yields (alignedread, file_index) tuples. Reads from the second file are
	yielded only while their query name matches reads seen in the first file.
	Reads from the first file are yielded when matched, or unconditionally
	when return_all is True.
	'''

	def __init__(self, sam0, sam1, return_all):
		'''
		sam0, sam1 - open pysam file objects supporting fetch(until_eof=True)
		return_all - if True, yield every read of sam0 even without a match
		'''
		self.files = sam0, sam1
		self.iterable = (sam0.fetch(until_eof=True), sam1.fetch(until_eof=True))
		# Prime the one-read lookahead buffer from each file.
		self.data = [next(self.iterable[0]), next(self.iterable[1])]
		self.current_match = self.data[1].qname
		self.matched = False
		self.return_all = return_all

	def __iter__(self):
		return self

	def __next__(self):
		'''
		Returns the next (alignedread, file_index) pair, advancing the file
		the returned read came from.
		'''
		while True:
			if not self.matched and self.data[0] is None:
				if self.data[1] is not None: #for debugging assertion error
					print("ERROR: Your BAM file is either not consistantly sorted, or does not include all reads.")
				assert self.data[1] is None
				raise StopIteration
			elif self.data[1] is None and not self.return_all:
				raise StopIteration
			elif self.data[1] is None:
				self.matched = False
				return self.update_return_items(0)
			else:
				if self.data[0].qname == self.current_match:
					self.matched = True
					return self.update_return_items(0)
				elif self.matched and self.data[1].qname == self.current_match:
					return self.update_return_items(1)
				elif self.matched: #reset when no longer actually matched
					self.matched = False
					self.current_match = self.data[1].qname
				elif self.return_all:
					return self.update_return_items(0)
				else: #skip reads that don't appear in both files
					self.update_return_items(0)

	# Python 2 iterator protocol alias.
	next = __next__

	def update_return_items(self, i):
		'''
		Returns stored read (and indicator of what file it is from). Updates iterator from same
		file once to contain new data. If file has no more reads, sets data to None.
		'''
		return_data = self.data[i]

		try:
			self.data[i] = next(self.iterable[i])
		except StopIteration:
			self.data[i] = None

		return return_data, i

class SamIteratorSolo(object):
	'''
	Single-file counterpart of SamIterator: iterates one SAM/BAM file,
	yielding (alignedread, 0) tuples so callers can treat both iterator
	types uniformly.
	'''

	def __init__(self, sam0):
		'''
		sam0 - open pysam file object supporting fetch(until_eof=True)
		'''
		self.files = sam0
		self.iterable = sam0.fetch(until_eof=True)
		# Prime the one-read lookahead buffer.
		self.data = next(self.iterable)

	def __iter__(self):
		return self

	def __next__(self):
		if self.data is None:
			raise StopIteration
		return self.update_return_items()

	# Python 2 iterator protocol alias.
	next = __next__

	def update_return_items(self):
		'''
		Returns stored read (and indicator of what file it is from). Updates iterator from same
		file once to contain new data. If file has no more reads, sets data to None.
		'''
		return_data = self.data

		try:
			self.data = next(self.iterable)
		except StopIteration:
			self.data = None

		return return_data, 0

class Fasta(object):
	'''
	Compares two fasta files and uses differences to create new fasta files.
	'''
	def __init__(self, original_fasta_filename, new_fasta_filename):
		'''
		Loads two fasta files and stores sequence/name information in dictionary format. Finds
		sequences that have been added, deleted, modified (same name, different sequence) and
		renamed (same sequence, different name).
		'''
		self.o_seqs = self.load_fasta(original_fasta_filename)
		self.n_seqs = self.load_fasta(new_fasta_filename)

		# Find new, deleted, and modified sequences
		self.new = set(self.n_seqs) - set(self.o_seqs)
		self.deleted = set(self.o_seqs) - set(self.n_seqs)

		self.modified = set()
		for t in set(self.o_seqs) & set(self.n_seqs):
			if self.o_seqs[t] != self.n_seqs[t]:
				self.modified.add(t)

		# Find sequences that are the same, but were renamed
		orig_inv = self.invert_fasta(self.o_seqs) #discarded after initialization
		self.renamed = set() #contains new names of renamed sequences
		self.renamed_dict = {} #maps old name -> new name for renamed sequences
		for n in self.new:
			if self.n_seqs[n] in orig_inv and orig_inv[self.n_seqs[n]] in self.deleted:
				self.renamed.add(n)
				self.renamed_dict[orig_inv[self.n_seqs[n]]] = n

		# Remove renamed sequences from new and deleted sequences
		self.new = self.new - self.renamed
		for old_name in self.renamed_dict:
			if old_name in self.deleted:
				self.deleted.remove(old_name)

	def load_fasta(self, fasta_filename):
		'''
		Returns dictionary of sequences indexed by name from fasta_filename
		'''
		seqs = {}
		targ = ''
		seq = ''
		# `with` guarantees the handle is closed (the original leaked it).
		with open(fasta_filename, 'r') as f:
			for line in f:
				if line[0] == '>':
					if targ != '':
						seqs[targ] = seq
					targ = line[1:].strip()
					if targ in seqs:
						print("WARNING: Target '{0}' is repeated in FASTA file.".format(targ))
					seq = ''
				else:
					seq += line.strip()
		if targ != '': #flush the final record
			seqs[targ] = seq
		return seqs

	def component_fastas(self, components, references, path=""):
		'''
		Creates a fasta file for every component.
			components - list of sets containing tid's
			references - list of reference names, such that the list index is the tid
		'''
		for i, c in enumerate(components): #component fasta filenames should line up correctly with bams
			component_seqs = {}
			for tid in c:
				try:
					component_seqs[references[tid]] = self.n_seqs[references[tid]]
				except KeyError:
					# Name not in the new fasta: it was renamed; look up under the new name.
					component_seqs[self.renamed_dict[references[tid]]] = self.n_seqs[self.renamed_dict[references[tid]]]
			self.output_fasta("component_" + str(i) + ".fasta", component_seqs, path)

	def make_dict(self, seq_names):
		'''
		Converts set of sequence names into dictionary of sequences, indexed by name
		'''
		seq_dict = {}
		for name in seq_names:
			seq_dict[name] = self.n_seqs[name]
		return seq_dict

	def output_fasta(self, filename, sequences, path="", chars_per_line=80):
		'''
		Writes sequence dictionaries to fasta filename. Returns 1 on success,
		0 when there was nothing to write (no file is created).
			sequences - can be set/list of names, or dict of names/seqs
		'''
		if isinstance(sequences, (list, set)):
			fasta_seqs = self.make_dict(sequences)
		else:
			fasta_seqs = sequences
		if not fasta_seqs: #empty file
			return 0
		with open(os.path.join(path, filename), 'w') as o:
			for name, seq in fasta_seqs.items():
				o.write('>%s\n' % name)
				# First line is always written (even for an empty sequence),
				# matching the original wrapping behavior.
				o.write('%s\n' % seq[0:chars_per_line])
				for start in range(chars_per_line, len(seq), chars_per_line):
					o.write('%s\n' % seq[start:start + chars_per_line])
		return 1

	def invert_fasta(self, f):
		'''
		Inverts dictionary f (sequence -> first name seen with that sequence)
		'''
		inv = dict()
		for n, s in f.items():
			if s not in inv:
				inv[s] = n
		return inv

class TargetGraph(object):
	'''
	Undirected graph storing the relationship between targets. Nodes are indexed by the target ID.
	Edges are stored once per pair under the key (min(u,v), max(u,v)).
	'''
	def __init__(self, N):
		'''
		Constructor initializes the TargetGraph with N disconnected targets.
			N - The number of targets in the TargetGraph (fixed).
		'''
		self.N = N # number of nodes in graph
		self.edges = {} # (u,v) with u<v -> positive accumulated weight
		self.neighbors = [set() for u in range(N)] # neighbors directly adjacent to index
		self.total_weight = 0 #tracks total weight of all edges for later averaging
		self.ignore = set() #records reads that align to too many targets and should be ignored
		self.has_reads = set() #tracks what targets have aligned reads

	def edge_weight(self, u, v):
		'''
		Accessor for the edge weight between two nodes (Target IDs). Returns 0 if no edge exists.
			u - Target ID for first node.
			v - Target ID for second node.
		'''
		u, v = min(u, v), max(u, v)
		return self.edges.get((u, v), 0)

	def weighted_degree(self, u):
		'''
		Returns the sum of weights of all edges incident to node u.
		'''
		degree = 0
		for v in self.neighbors[u]:
			degree += self.edge_weight(u, v)
		return degree

	def add_edge(self, u, v, w=1):
		'''
		Adds an edge between the two nodes with the given weight. Increments the weight if an edge
		already exists. Both endpoints are recorded in has_reads even when u == v
		(a self-edge is used to mark a target as having reads without storing an edge).
			u - Target ID for first node.
			v - Target ID for second node.
			w - Weight of edge to add (must be positive; returns -1 otherwise).
		'''
		assert(u < self.N and v < self.N)
		if w <= 0:
			return -1
		self.has_reads.add(u)
		self.has_reads.add(v)
		if u == v: #record that self node has reads, but don't make self-edge
			return
		u, v = min(u, v), max(u, v)
		self.edges[(u, v)] = self.edges.get((u, v), 0) + w
		self.neighbors[u].add(v)
		self.neighbors[v].add(u)
		self.total_weight += w #increment total weight counter

	def component(self, target):
		'''
		Returns the set of all nodes (the component) that can be reached by following
		the neighbor links of the input target, including itself.
		'''
		visited = set([target])
		unseen = self.neighbors[target] - visited # set of all neighbors left to visit

		while len(unseen) > 0:
			current_target = unseen.pop()
			if current_target not in visited:
				visited.add(current_target)
				unseen |= self.neighbors[current_target] - visited # union of both sets, minus visited
		return visited

	def all_components(self):
		'''
		Computes and returns a list of all components (sets of node indices) in the graph as well
		as a map from target nodes to the index of the component it is a member of.
		'''
		visited = set()
		components = []
		component_map = {}
		for current_target in range(self.N):
			if current_target not in visited:
				current_component = self.component(current_target)
				visited |= current_component # union of both sets
				components.append(current_component)
				# Map members to the list index of the component just appended.
				# (The original used len(components), a 1-based value that both
				# disagreed with the docstring and collided with the
				# "smallest unused bundle id" assumption downstream.)
				index = len(components) - 1
				for j in current_component:
					component_map[j] = index

		return components, component_map

	def weight_map(self):
		'''
		Returns a dict mapping each existing edge (u,v), u<v, to its weight.
		Since self.edges already stores exactly the positive-weight edges in
		canonical order, this is a copy of it (the original rebuilt the same
		dict by probing all O(N^2) node pairs).
		'''
		return dict(self.edges)

	def remove_nodes(self, to_remove):
		'''
		Returns a copy of the graph with the given nodes removed and the
		remaining nodes renumbered contiguously from 0. The ignore set is
		copied; has_reads is not (matching the original behavior).
		'''
		id_map = -1 * numpy.ones(self.N, dtype=int)
		next_id = 0
		for i in range(self.N):
			if i not in to_remove:
				id_map[i] = next_id
				next_id += 1

		T = TargetGraph(next_id)
		for (u, v), w in self.edges.items():
			if u not in to_remove and v not in to_remove:
				T.add_edge(int(id_map[u]), int(id_map[v]), w)

		T.ignore = self.ignore.copy()

		return T

class Partition(object):
	'''
	The Partition class keeps track of a partition of a TargetGraph made up of a subset of the targets
	in the graph. The weight of the cut between the partition and the rest of the TargetGraph is kept
	along with the set of neighbors of the partition and the sum of the weight of cut edges to each
	of those neighbors.
	'''
	def __init__(self, T, B):
		'''
		Partition constructor adds initial set of targets.
		 T - TargetGraph object to partition.
		 B - Initial set of targets in partition. Must be in tid form, not reference name form.
		'''
		self.T = T
		self.B = set() #will eventually contain all targets in the partition
		self.cut_weight = 0 #total weight of edges crossing the cut
		self.incidence_weight = 0 #total weight of edges incident to the partition
		self.neighbors = set()
		self.neighbor_weights = {} #neighbor tid -> summed weight of its edges into the partition
		for node in B:
			self.add_greedy_target(node)

	def add_greedy_target(self, u):
		'''
		Adds the target to the partition, updating cut weights, neighbor sets, and neighbor weights.
			u - The ID of the target (node) to add.
		'''
		if u in self.B:
			raise AssertionError("{0} is already a member of the partition {1}".format(u, self.B))

		# Get new neighbors and add to current neighbors
		new_neighbors = self.T.neighbors[u] - self.B - self.neighbors
		self.neighbors.update(new_neighbors)

		# Remove u from neighbors: its former cut edges become internal
		if u in self.neighbors:
			self.cut_weight -= self.neighbor_weights[u]
			del self.neighbor_weights[u]
			self.neighbors.remove(u)

		# Look for edges between u and all neighbors
		for neighbor in self.neighbors:
			weight = self.T.edge_weight(neighbor, u)
			self.cut_weight += weight
			self.incidence_weight += weight
			self.neighbor_weights[neighbor] = self.neighbor_weights.get(neighbor, 0) + weight

		# Add u to the partition
		self.B.add(u)

	def optimal_greedy_neighbor(self):
		'''
		Finds the neighbor that, when added, will lead to the smallest cut density. Returns -1 if no neighbors exist.
		'''
		# float('inf') instead of numpy.Inf (removed in NumPy 2.0).
		best_density = float('inf')
		best_neighbor = -1
		for u in self.neighbors:
			w_degree = self.T.weighted_degree(u)
			# Density of the cut if u were added to the partition.
			density = float(self.cut_weight + w_degree - 2 * self.neighbor_weights[u]) / (self.incidence_weight + w_degree - self.neighbor_weights[u])
			if density < best_density:
				best_density = density
				best_neighbor = u

		return best_neighbor

	def max_weight_neighbor(self):
		'''
		Finds and returns the target ID of the neighbor to the partition with the maximum edge weight
		crossing the cut between the partition and the rest of the TargetGraph. Returns -1 if no neighbors exist.
		'''
		if not self.neighbor_weights:
			return -1
		return max(self.neighbor_weights, key=self.neighbor_weights.__getitem__) #returns neighbor with greatest weight, not actual weight

	def random_neighbor_by_weight(self):
		'''
		Samples and returns the target ID of a neighbor to the partition at random, with
		probability proportional to the weight of its cut edges. Returns -1 if no neighbors exist.
		'''
		if not self.neighbor_weights: #no neighbors
			return -1

		r = numpy.random.random_sample()

		for neighbor, weight in self.neighbor_weights.items():
			weighted = float(weight) / float(self.cut_weight)
			if r < weighted:
				break
			r = r - weighted

		# Falls through to the last neighbor if rounding left r positive.
		return neighbor

	def cut_density(self):
		'''
		Fraction of the partition's incident edge weight that crosses the cut.
		'''
		return float(self.cut_weight) / self.incidence_weight

	def __len__(self):
		'''
		Accessor for the number of targets in the partition.
		'''
		return len(self.B)

def greedy_partitioner(G, B, density_threshold):
	'''
	Grows a partition of G around the seed targets in B until its cut density
	drops to the threshold. At each step the neighbor with the largest total
	edge weight into the partition is absorbed.
		G - The TargetGraph to partition.
		B - The set of targets (IDs) that must be in the partition.
		density_threshold - The maximum cut density allowed for the partition.
	'''
	partition = Partition(G, B)
	while (partition.cut_density() > density_threshold
			and len(partition.B) <= G.N
			and partition.neighbor_weights):
		partition.add_greedy_target(partition.max_weight_neighbor())

	return partition

def random_greedy_partitioner(G, B, density_threshold, N):
	'''
	Finds N random greedy partitions containing the set of targets in B and returns the one with the
	smallest cardinality whose cut density is below the threshold. If multiple minimum-cardinality
	partitions are found, the one with the lowest cut density is returned.
		G - The TargetGraph to partition.
		B - The set of targets (IDs) that must be in the partition.
		density_threshold - The maximum cut density allowed for the partition.
		N - The number of random partitions to sample.
	'''
	best_P = None
	for i in range(N):
		P = Partition(G, B)
		# Grow by weighted random neighbor selection until the density
		# threshold is met or no neighbors remain.
		while P.cut_density() > density_threshold and len(P.B) <= G.N and P.neighbor_weights:
			u = P.random_neighbor_by_weight()
			P.add_greedy_target(u)

		# `is None` rather than truthiness: Partition defines __len__, so an
		# empty partition is falsy and would otherwise be wrongly replaced.
		if best_P is None:
			best_P = P
		elif len(P) < len(best_P) or (len(P) == len(best_P) and P.cut_density() < best_P.cut_density()):
			best_P = P

	return best_P

def build_graph_single(bam_filename, max_alignments):
	'''
	Builds a TargetGraph from a sorted BAM file: an edge is added between every
	pair of targets sharing an aligned read. Reads aligning to max_alignments
	or more targets are recorded in the graph's ignore set instead.
		bam_filename - sorted BAM file (alignments for a read must be adjacent)
		max_alignments - alignment-count cutoff above which a read is ignored
	'''
	samfile = pysam.Samfile(bam_filename, "rb")
	samGraph = TargetGraph(samfile.nreferences)

	def flush(read_name, targets):
		# Connect all targets sharing the read, unless it aligns too widely.
		if len(targets) < max_alignments:
			for u, v in itertools.combinations(targets, 2):
				samGraph.add_edge(u, v)
		else:
			samGraph.ignore.add(read_name)

	last_read_name = ''
	current_targets = set()
	for alignedread in samfile:
		if alignedread.tid != -1 and (not alignedread.is_paired or alignedread.is_proper_pair):
			if last_read_name == alignedread.qname:
				current_targets.add(alignedread.tid)
			else:
				flush(last_read_name, current_targets)
				current_targets = set([alignedread.tid]) #reset list of transcripts that current read aligns to
				last_read_name = alignedread.qname

	# Flush the final read group; the loop above only flushes on a name
	# change, so the last read's edges were previously dropped.
	if current_targets:
		flush(last_read_name, current_targets)

	return samGraph

# Updates given TargetGraph from new reads in new bam file
def update_graph(samGraph, original_bam_filename, new_bam_filename, max_alignments):
	'''
	Extends samGraph with the targets of the new BAM (tids offset by the
	original file's reference count) and adds edges for every read appearing
	in both files, connecting new targets to each other and to old targets.
	Mutates and returns samGraph.
	'''
	original_sam = pysam.Samfile(original_bam_filename, "rb")
	new_sam = pysam.Samfile(new_bam_filename, "rb")

	assert(samGraph.N == original_sam.nreferences)
	samGraph.N += new_sam.nreferences
	samGraph.neighbors += [set() for i in range(new_sam.nreferences)]

	new_sam_offset = original_sam.nreferences

	def flush(read_name, targets):
		# Emit edges for one completed read group: among new targets, between
		# old and new targets, plus self-edges to populate has_reads.
		if read_name in samGraph.ignore: #avoid reads that map to tons of transcripts
			return
		for u, v in itertools.combinations(targets[1], 2):
			samGraph.add_edge(u, v)
		for u, v in itertools.product(targets[0], targets[1]):
			samGraph.add_edge(u, v)
		for u in targets[0] | targets[1]:
			samGraph.add_edge(u, u) #record all targets that have reads

	last_read_name = None
	current_targets = [set(), set()] #old,new targets that all align to the current read
	buildIterator = SamIterator(original_sam, new_sam, False)

	for alignedread, which_file in buildIterator:
		if alignedread.tid != -1 and alignedread.qname not in samGraph.ignore and (not alignedread.is_paired or alignedread.is_proper_pair):
			offset = new_sam_offset if which_file == 1 else 0

			if last_read_name == alignedread.qname:
				current_targets[which_file].add(alignedread.tid + offset)
				if len(current_targets[0]) + len(current_targets[1]) > max_alignments:
					samGraph.ignore.add(last_read_name)
			else:
				flush(last_read_name, current_targets)
				current_targets = [set(), set()]
				current_targets[which_file].add(alignedread.tid + offset) #reset list of transcripts that current read aligns to
				last_read_name = alignedread.qname

	# Flush the final read group; the loop above only flushes on a name
	# change, so the last read's edges were previously dropped.
	if last_read_name is not None:
		flush(last_read_name, current_targets)

	return samGraph

# Creates a SAM for each component and populates it with the relevant alignments. Also creates a
# new SAM file for the updated targets.
def extract_component_alignments(components, original_bam_filename, update_bam_filename="", path="", temppath="", reads_to_ignore=[], tids_to_remove=set(), renamed_references=dict()):
	'''
	components - list of sets of tids; one BAM is written per component
	original_bam_filename - sorted BAM of alignments to the old targets
	update_bam_filename - optional BAM of alignments to the new targets
	path - directory for the combined updated BAM
	temppath - directory for the per-component BAMs
	reads_to_ignore - read names to exclude from component BAMs
	tids_to_remove - collection of tids (combined numbering) to drop
	renamed_references - maps old reference names to their new names
	'''
	update_exists = update_bam_filename != ""

	original_sam = pysam.Samfile(original_bam_filename, "rb")
	updated_bam_filename = "updated_bamfile.bam"

	if update_exists:
		update_sam = pysam.Samfile(update_bam_filename, "rb")
	else: #no updates, no deletions, no renames -> just copy the previous file
		# Truthiness, not `== []`: tids_to_remove is normally a set, and a
		# set never compares equal to a list, so the original shortcut was dead.
		if not renamed_references and not tids_to_remove:
			shutil.copyfile(original_bam_filename, os.path.join(path, updated_bam_filename))
			return 0

	update_sam_offset = original_sam.nreferences

	#Replace references with renamed names
	edited_references = [renamed_references.get(r, r) for r in original_sam.references]

	#Create full (combined) reference list
	if update_exists:
		references = edited_references + list(update_sam.references)
		lengths = list(original_sam.lengths) + list(update_sam.lengths)
	else:
		references = edited_references
		lengths = list(original_sam.lengths)

	#Filter reference set for those that are being removed. tids_to_remove must be a set for speed reasons.
	new_references = [references[tid] for tid in range(len(references)) if tid not in tids_to_remove]
	new_lengths = [lengths[tid] for tid in range(len(references)) if tid not in tids_to_remove]

	#Create mapping from full reference list to new reference list.
	#One extra slot so that tid -1 indexes the last entry and maps to -1.
	new_tid_map = -1 * numpy.ones(len(references) + 1, dtype=int)
	new_tid = 0
	for tid in range(len(references)):
		if tid not in tids_to_remove:
			new_tid_map[tid] = new_tid
			new_tid += 1

	#Create an updated output SAM
	new_sam = pysam.Samfile(os.path.join(path, updated_bam_filename), mode='wb', referencenames=new_references, referencelengths=new_lengths)

	#Create an output SAM for each component
	component_sams = []
	# dtype=int: numpy.int was removed from NumPy's namespace.
	component_tid_map = -1 * numpy.ones(len(references), dtype=int)
	for i, c in enumerate(components):
		#Build header for component sams
		ref_c, len_c = [], []
		for j, tid in enumerate(c):
			ref_c.append(references[tid])
			len_c.append(lengths[tid])
			component_tid_map[tid] = j
		component_sams.append(pysam.Samfile(os.path.join(temppath, 'component_' + str(i) + '.bam'), mode='wb', referencenames=ref_c, referencelengths=len_c))

	#Use a set for O(1) membership tests inside the per-alignment loop.
	ignore_set = set(reads_to_ignore)

	# Iterate over all aligned reads in both files
	if update_exists:
		inputIterator = SamIterator(original_sam, update_sam, True)
	else:
		inputIterator = SamIteratorSolo(original_sam)
	for alignedread, which_file in inputIterator:
		# Write every alignment to replacement SAM file
		tid = alignedread.tid
		if which_file == 1: #set offset based on file source
			tid += update_sam_offset

		is_proper_pair = alignedread.is_proper_pair
		is_unmapped = alignedread.is_unmapped

		# int() cast: numpy scalar -> plain int before handing to pysam.
		alignedread.tid = int(new_tid_map[tid])
		alignedread.rnext = alignedread.tid
		if alignedread.tid == -1:
			#Target was removed: mark the read unmapped in the combined file.
			alignedread.is_unmapped = True
			alignedread.is_proper_pair = False
		new_sam.write(alignedread)

		#Restore the original fields before component output.
		alignedread.tid = tid
		alignedread.is_unmapped = is_unmapped
		alignedread.is_proper_pair = is_proper_pair

		if tid != -1 and alignedread.qname not in ignore_set and (not alignedread.is_paired or alignedread.is_proper_pair): #skips writing unmapped reads
			for i, c in enumerate(components):
				if tid in c:
					alignedread.tid = int(component_tid_map[tid])
					alignedread.rnext = alignedread.tid
					assert(alignedread.tid >= 0)
					component_sams[i].write(alignedread)

def get_partition_for_target(args, readGraph, target):
	'''
	Returns the set of tids to re-process for the given target: the whole
	connected component when it is small (or partitioning is disabled),
	otherwise a partition grown around the target by the selected algorithm.
		args - parsed command-line namespace (selects the algorithm)
		readGraph - TargetGraph to search
		target - tid of the new/deleted target
	'''
	component = readGraph.component(target)
	if args.no_partitions or len(component) < args.partition_threshold:
		return component
	elif (args.greedy):
		P = greedy_partitioner(readGraph, set([target]), args.greedy_threshold)
	elif (args.random_greedy):
		P = random_greedy_partitioner(readGraph, set([target]), args.greedy_threshold, args.random_greedy_trials)
	elif (args.edge_weight_min):
		# NOTE(review): `min_weight_partitioner` is not defined anywhere in
		# this file and `edge_weight_min` is not an argparse option, so this
		# branch would raise if ever reached. It is unreachable with the
		# current CLI: the mutually exclusive required group guarantees one
		# of the branches above is taken.
		P = min_weight_partitioner(readGraph, set([target]), 2)
	return P.B

def make_dir(path):
	'''
	Creates the directory (and any missing parents). A no-op if the directory
	already exists; raises OSError if it does not exist and cannot be created.
	'''
	try:
		os.makedirs(path)
	except OSError:
		if not os.path.isdir(path): #directory does not exist and we can't create it
			raise OSError("Could not create directory {0}".format(path))

# Command-line interface (see the argparse setup in main): one required clustering
# mode flag (-g/--greedy, -r/--random-greedy, or -n/--no-partitions), followed by the
# original alignment file (.bam), original eXpress results (results.xprs), original
# params file (params.xprs), read 1 fastq file, old FASTA file, and new FASTA file.
# Options: [-2] read 2 fastq file, [-p] CPUs to use for bowtie, [-d] output directory,
# [-s] stored target graph pickle from a previous run, plus --greedy-threshold,
# --partition-threshold, --random-greedy-trials, and --max-alignments.
def main():
	'''
	Runs the full reXpress pipeline: compares the old and new FASTA files,
	aligns reads to the brand-new targets with Bowtie 2, builds or loads and
	then updates the target graph, partitions it around the new/deleted
	targets, re-runs eXpress on each affected component, and merges all
	results into updated_results.xprs.
	'''
	# Variables for various clustering decisions
	# Argument parsing
	parser = argparse.ArgumentParser(description='Finds reads most likely to be affected by particular transcript changes, and reruns them through eXpress.')
	alg_group = parser.add_mutually_exclusive_group(required=True)
	alg_group.add_argument('-g','--greedy', action='store_true',
					help='Clusters transcripts using a greedy algorithm')
	alg_group.add_argument('-r','--random-greedy', action='store_true',
					help='Clusters transcripts using a random greedy algorithm')
	alg_group.add_argument('-n','--no-partitions', action='store_true',
					help='Does not partition graph.')
	parser.add_argument('original_bam_filename', metavar='Original BAM filename',
					help='Original Bowtie alignment of reads to "old" transcripts. Must be sorted and in BAM format.')
	parser.add_argument('init_results_filename', metavar='eXpress results.xprs',
					help='results.xprs produced by eXpress when run on original alignments (above).')
	parser.add_argument('init_params_filename',metavar='eXpress params.xprs',
					help='params.xprs produced by eXpress when run on original alignments (above).')
	parser.add_argument('r1_fastq_filename', metavar='Read 1 fastq file',
					help='Fastq file of sequenced reads (read 1 if paired end).')
	parser.add_argument('original_fasta_filename', metavar='Old FASTA file',
					help='Original FASTA file of transcripts.')
	parser.add_argument('new_fasta_filename', metavar='New FASTA file',
					help='New FASTA file of transcripts.')
	parser.add_argument('-2', '--read2',
					help="Optional read 2 fastq file of sequenced reads.")
	parser.add_argument('-p', '--parallel', type=int, default = 1,
					help="Optional number of CPUs/cores to use for running bowtie.")
	parser.add_argument('-d', '--directory', type=str, default = '',
					help="Optional directory to create files in.")
	parser.add_argument('-s', '--stored-target-graph', type=str,
					help="Optional file containing stored target graph, created during previous run of reXpress.")
	# NOTE(review): declared type=int with a float default of 0.1; a
	# fractional value passed on the command line would fail int()
	# conversion -- type=float was probably intended.
	parser.add_argument('--greedy-threshold', type=int, default=0.1)
	parser.add_argument('--partition-threshold', metavar='Min component size to parition', type=int, default=500,
					help='Minimum size of component to partition if enabled.')
	parser.add_argument('--random-greedy-trials', metavar='Random greedy trials', type=int, default=10,
					help='Number of trials used for random greedy partitioning if enabled.')
	parser.add_argument('--max-alignments', metavar='Max alignments', type=int, default=500,
					help='Maximum number of alignments per read. Reads with more alignments will be ignored.')

	args = parser.parse_args()

	# Put filenames in our namespace
	original_bam_filename = args.original_bam_filename
	results_filename = args.init_results_filename
	params_filename = args.init_params_filename
	r1_fastq_filename = args.r1_fastq_filename
	r2_fastq_filename = args.read2
	original_fasta_filename = args.original_fasta_filename
	new_fasta_filename = args.new_fasta_filename

	if (args.stored_target_graph):
		old_pickled_graph = args.stored_target_graph

	# Store all files in subdirectories
	if args.directory:
		make_dir(args.directory) #final permanent files will live here
	tempdir = tempfile.mkdtemp(dir=args.directory) #temporary files will live here and be deleted later

	# Set filenames that our pipeline will be creating and using
	only_new_fasta_filename = os.path.join(tempdir, "only_new.fasta")
	new_filename = os.path.join(tempdir, "new_alignments")
	new_bam_filename = new_filename + ".bam"
	new_index = os.path.join(tempdir, 'only_new')
	updated_results_filename = "updated_results.xprs"
	new_pickled_graph = "saved_target_graph.pickle"

	# Compare FASTA files to create new file for alignment purposes
	print "Creating new FASTA file..."
	fasta = Fasta(original_fasta_filename,new_fasta_filename)
	fasta_exists = True
	if fasta.output_fasta(only_new_fasta_filename, fasta.new) == 0: #fasta file was empty and not created
		fasta_exists = False

	def run_bowtie():
		# Run Bowtie to create alignments to new.fasta

		print "Building new Bowtie 2 index..."
		proc = subprocess.Popen(['bowtie2-build', only_new_fasta_filename, new_index],stdout=subprocess.PIPE)

		output = proc.communicate()
		if proc.poll() is None:
			raise IOError("bowtie2-build process did not terminate.")

		# Pipeline: bowtie2 -> gawk (keep header lines and alignments where
		# neither mate is unmapped) -> samtools view (compress SAM to BAM).
		print "Aligning reads to new index with Bowtie 2..."
		proc3 = subprocess.Popen(['samtools', 'view', '-bhS', '-o' + new_bam_filename, '-'],stdin=subprocess.PIPE) #compress to bam
		proc2 = subprocess.Popen(['gawk', '{if (substr($1,1,1) == "@" || (and($2, 0x4)==0) && and($2, 0x8)==0) print}'],stdin=subprocess.PIPE,stdout=proc3.stdin) #remove unmapped reads
		if r2_fastq_filename: #paired end
			proc1 = subprocess.Popen(['bowtie2', '-k %d' % (args.max_alignments+1), '--rdg 6,5', '--rfg 6,5', '--score-min L,-.6,-.4', '--no-discordant', '--no-mixed', '--reorder', '-p %d' % args.parallel, '-x ' + new_index, '-1' +r1_fastq_filename, '-2' +r2_fastq_filename],stdin=subprocess.PIPE,stdout=proc2.stdin)
		else: #single read
			proc1 = subprocess.Popen(['bowtie2', '-k %d' % (args.max_alignments+1), '--rdg 6,5', '--rfg 6,5', '--score-min L,-.6,-.4', '--no-discordant', '--no-mixed', '--reorder', '-p %d' % args.parallel, '-x ' + new_index, '-U ' + r1_fastq_filename],stdin=subprocess.PIPE,stdout=proc2.stdin)
		# NOTE(review): proc1 and proc2 are given stdin=PIPE that is never
		# written to; the poll/sleep loops below stand in for proc.wait().
		while proc1.poll() is None: #wait for bowtie to finish
			time.sleep(60)
		proc2.stdin.close()
		while proc2.poll() is None:
			time.sleep(5)
			proc3.stdin.close()
		while proc3.poll() is None:
			time.sleep(5)

	# Align only if a BAM is not already present and there are new targets.
	bam_exists = False
	if os.path.exists(new_bam_filename):
		# Check that something aligned
		for line in pysam.Samfile(new_bam_filename,'rb').fetch(until_eof=True): #this loop won't run if there are no readlines
			bam_exists = True
			break
	elif fasta_exists:
		run_bowtie()
		# Check that something aligned
		for line in pysam.Samfile(new_bam_filename,'rb').fetch(until_eof=True): #this loop won't run if there are no readlines
			bam_exists = True
			break

	# Build list of tids for new and deleted transcripts
	original_sam = pysam.Samfile(original_bam_filename,"rb")

	# Combined reference list: old references (renamed where applicable)
	# followed by the new BAM's references; list index serves as the tid.
	full_references = []
	for ref in original_sam.references:
		if ref in fasta.renamed_dict: #catch references that should be renamed
			full_references.append(fasta.renamed_dict[ref])
		else:
			full_references.append(ref)
	if bam_exists:
		new_sam = pysam.Samfile(new_bam_filename,"rb")
		full_references = full_references + [r for r in new_sam.references]

	new_tids = [full_references.index(r) for r in fasta.new]
	deleted_tids = [full_references.index(r) for r in fasta.deleted]

	affected_targets = new_tids + deleted_tids

	# Build or load graph
	if (args.stored_target_graph):
		print "Loading stored transcript graph..."
		originalGraph = pickle.load(file(old_pickled_graph,'r'))
	else:
		print "Building transcript graph..."
		originalGraph = build_graph_single(original_bam_filename, args.max_alignments)

	# Update graph
	if bam_exists:
		print "Updating transcript graph..."
		readGraph = update_graph(originalGraph,original_bam_filename,new_bam_filename,args.max_alignments)
	else:
		readGraph = originalGraph

	# Store updated graph without deleted targets (note that this only removes nodes from a copy, not the real graph)
	# NOTE(review): update_graph mutates originalGraph in place, so the pickle
	# also contains the update-file targets -- confirm that is intended.
	pickle.dump(originalGraph.remove_nodes(set(deleted_tids)), file(os.path.join(args.directory,new_pickled_graph), 'w'))

	# Partition graph into blocks for each new or deleted transcript.
	affected_partitions = []
	unaligned_targets = set()
	for node in affected_targets:
		if not (node in readGraph.has_reads) and not(node in deleted_tids): #catch nodes without aligned reads that haven't been deleted
			assert(len(readGraph.component(node)) == 1)
			unaligned_targets.add(full_references[node])
		else:
			affected_partitions.append(get_partition_for_target(args, readGraph, node))

	#Merge overlapping partitions (repeat until no two partitions intersect)
#	print "{0} partitions containing {1} nodes".format(len(affected_partitions),len(reduce(set.union, affected_partitions, set())))
	while True:
		merged_partitions = []
		for p1,p2 in itertools.combinations(affected_partitions,2):
			if len(p1&p2) != 0 and not(p1 in merged_partitions or p2 in merged_partitions): #skip partitions that are already merged
				merged_partitions.append(p1|p2)
		if len(merged_partitions) == 0:
			break
		merge_union = reduce(set.union, merged_partitions, set())
		affected_partitions = merged_partitions + [p for p in affected_partitions if len(p&merge_union) == 0]

	#Remove deleted targets from partitions
	set_deleted = set(deleted_tids) #deleted_tids is a list
	affected_partitions = [p-set_deleted for p in affected_partitions]
	affected_partitions = [p for p in affected_partitions if len(p)]

	# Build fasta files for each component
	print "Building fasta files for each component..."
	fasta.component_fastas(affected_partitions, full_references, tempdir)

	# Build bam files for each component, as well as final composite sam
	print "Building bam files for each component..."
	if bam_exists:
		extract_component_alignments(affected_partitions,original_bam_filename,new_bam_filename, args.directory, tempdir, readGraph.ignore, deleted_tids, fasta.renamed_dict)
	else:
		extract_component_alignments(affected_partitions,original_bam_filename,"", args.directory, tempdir, readGraph.ignore, deleted_tids, fasta.renamed_dict)

	def run_express():
		# Run eXpress on each component bam/fasta pair
		print "Running eXpress on each component..."
		for i in xrange(len(affected_partitions)):
			print "Running eXpress on component {0}".format(i)
			# Create directory so eXpress results don't overwrite themselves
			c_path = os.path.join(tempdir, 'component_' +str(i))
			make_dir(c_path)
			# Run eXpress
#			proc = subprocess.Popen(['express', '-f 0.75', '--aux-param-file ' +params_filename, '-o ' +os.path.join(path,c_path), os.path.join(path, 'component_' +str(i)+ '.fasta'), os.path.join(path, 'component_' +str(i)+ '.bam')], shell=True)
			os.system('express -f 0.75 --aux-param-file ' +params_filename + ' -B 20 -o ' + c_path + ' ' +	os.path.join(tempdir, 'component_' +str(i)+ '.fasta') + ' ' + os.path.join(tempdir, 'component_' +str(i)+ '.bam'))
#			output = proc.communicate()
#			if proc.poll() is None:
#				raise IOError("eXpress process did not terminate.")
	run_express()

	def update_results():
		print "Generating updated_results.xprs..."
		# Combine rows from all the results.xprs files
		component_results = []
		new_results = {}
		sum_est_counts = 0 # Keep track of total estimated_counts, for updating fpkms
		for i in xrange(len(affected_partitions)):
			component_results = open(os.path.join(tempdir,'component_' +str(i),"results.xprs"),'r')
#			print "Component {0}".format(i)
			for line in component_results:
				if line[0:6] != 'bundle': #skip header
					split_line = line.strip().split('\t')
					sum_est_counts += float(split_line[6])
					new_results[split_line[1]] = split_line
		# Fall back to the original results for targets not re-estimated above.
		original_results = open(results_filename,'r')
		for line in original_results:
			if line[0:6] == 'bundle': #header
				header = line
			else:
				split_line = line.strip().split('\t')
				line_name = split_line[1]
				sum_est_counts += float(split_line[6])
				if (line_name not in new_results) and (line_name not in fasta.deleted) and (line_name != 'target_id'):
					new_results[line_name] = split_line

		# Write xprs results to file, updating FPKMs and bundle_IDs:
		all_comps, component_map = readGraph.all_components()
		next_bundle = len(all_comps) #indicates smallest UNUSED bundleid

		updated_results = open(os.path.join(args.directory, updated_results_filename),'w')
		updated_results.write(header)
		for name,field in new_results.iteritems():
			if name in fasta.renamed_dict: #catch renamed targets
				name = fasta.renamed_dict[name]
			bundleid = component_map[full_references.index(name)]
			try:
				# FPKM = 10^9 * est_counts / (total est_counts * eff_length);
				# field[6] and field[3] are presumably est_counts and
				# eff_length -- confirm against the eXpress output columns.
				updated_fpkm = pow(10,9)*float(field[6])/(sum_est_counts*float(field[3]))
			except:
				updated_fpkm = 0 #catch when field[3] is 0
			# Rebuild the row: new bundle id, original columns 2-9, then the
			# recomputed FPKM in place of field[10], then the rest.
			updated_line = '%d\t%s\t%s\t%f\t%s\n' % (bundleid, name, '\t'.join(field[2:10]), updated_fpkm, '\t'.join(field[11:]))
			updated_results.write(updated_line)
		for name in unaligned_targets: #add transcripts that had no aligned reads
			length = len(fasta.n_seqs[name])
			updated_results.write('%d\t%s\t%d\t%f\t0.0\t0.0\t0.0\t0.0\t0.0\t0.0\t0.0\t0.0\t0.0\t0.0\n' % (next_bundle,name,length,length))
			next_bundle += 1 #put each unmapped read in its own bundle
		updated_results.close()
	update_results()

	# Clean up temp files
	shutil.rmtree(tempdir)
	print "Pipeline finished."

# Script entry point: run the full reXpress pipeline.
if __name__ == "__main__":
	main()
