o
    Uݢgh                     @   s   d dl Zd dlmZ d dlmZmZmZ d dlm	Z
 ejjddZdd Zdd	 Zd
d Zd$ddZdd Zd$ddZdd Zdd Zdd Zdd Zdd Zdd Zdd ZG d d! d!ZG d"d# d#ZdS )%    N)optimize)betalngammalnloggamma*   )seedc                 C   s6   t | d }t dd | D }tt|t|S )z=Inverse Simpson Index, or the effective diversity of power 2.   c                 s   s    | ]}|d  V  qdS )r   N ).0vr	   r	   _/oak/stanford/groups/akundaje/marinovg/programs/cellranger-9.0.1/lib/python/cellranger/stats.py	<genexpr>   s    z&effective_diversity.<locals>.<genexpr>)npsumtk_statsrobust_dividefloat)counts	numeratordenominatorr	   r	   r   effective_diversity   s   r   c                 C   s   t j| dd\}}t | ||dk }t | }i }tt t| | | | D ]\}}||vr5d||< q(|| ||< ||  d7  < q(|S )a_  Produces an incremental count vector from a categorial sample vector.

    Takes a set of drawn samples from a multinomial or similar distribution
    distribution and produces a vector of the same size, with each element the
    the number of times the sample at that position has been seen
    indices.  I.e., it would turn the following vector of sample draws:
    [1, 1, 2, 1, 0, 2, 1, 3] into
    [1, 2, 1, 3, 1, 2, 4, 1]

    Args:
      sample_draws (np.ndarray(int)): A vector of sample indices.

    Returns:
      inc_counts (np.ndarray(int)): The incremental counts of those sample indices.

    T)return_counts   r   )r   uniqueisin	ones_likeziparangelen)sample_drawsidxscmaskZ
inc_countsr   ijr	   r	   r   incremental_counts_from_sample   s   
$
r%   c                 C   s   t j| |dS )an  Produces a count of items in the sample draws.

    Takes an array of samples drawn from a distribution and returns a new
    array of counts of how many times each feature was seen in the draws.

    Args:
      sample_draws (np.ndarray(int)): A vector of sample indices.

    Returns:
      counts (np.ndarray(int): A vector of counts for each feature index.

    )	minlength)r   bincount)r   num_featuresr	   r	   r   collapse_draws_to_counts5   s   r)   c                 C   s   | j d }tj|td}|du rt| jddd }t|d }t|D ]9}| j| | j|d  }}| j	|| }	| j
|| }
|j|	ddd}|| t|
d   |
|   ||< q%|S )a(  Computes the multinomial log-likelihood for many barcodes.

    Multinomial log-likelihood for a single barcode where the count of UMIs for
    feature i is x_i is:
    l = log(gamma(sum_i(x_i) + 1)) - sum_i(log(gamma(x_i + 1)) + sum_i(log(p_i) * x_i)

    Because the input matrix is sparse and zero counts do not contribute to the
    likelihood, we can rapidly remove zero elements from the computation.

    Args:
      matrix (scipy.sparse.csc_matrix): Matrix of UMI counts (feature x barcode)
      logp (np.ndarray(float)): The natural log of the multinomial probability
        vector across features
      n (int, optional): The total number of UMIs per barcode.  Saves computation
        if it can be precomputed.

    Returns:
      loglk (np.ndarray(float)): Log-likelihood for each barcode
    r   dtypeNr   axisclipr-   mode)shaper   zerosr   asarrayr   r   rangeindptrindicesdatatake)matrixlogpnnum_bcsloglkconstsr#   	idx_startidx_endr    rowZ
short_logpr	   r	   r   eval_multinomial_loglikelihoodsE   s   
*rB   c                 C   sB   t | }tdt| d }t|t| ||   }t|S )a  Computes the cumulative multinomial log-likelihood for a vector.

    Given a vector of sample draws from a multinomial distribution, computes the
    log-likelihood of all of the draws up to each draw.  This is done incrementally,
    by first pre-computing the rank R of each draw as the number of times that
    feature was seen when it was drawn.  Then the incremental log-likelihood from
    that draw is

    dl = log(i) - log(R_i) + log(p_i)

    Args:
      sample_draws (np.ndarray(int)): A vector of sample draws
      logp (np.ndarray(float)): The log-probability of sampling each feature

    Returns:
      loglk (np.ndarray(float): The cumulative log-likelihood up to each draw
    r   )r%   r   r   r   logcumsum)r   r:   marginal_countsnvalsr=   r	   r	   r   )eval_multinomial_loglikelihood_cumulativeh   s   
rG   c                 C   s   | j d }t|}|du rt| jddd }t|tt|| }t| j d D ]9}| j| | j|d  }}| j	|| }	| j
|| }
|j|	ddd}|| t|
  t||
  ||< q-|S )a?  Computes the Dirichlet-multinomial log-likelihood for many barcodes.

    Dirichlet-Multinomial log-likelihood for a single barcode where the count
    of UMIs for feature i is x_i is:
    l = log(sum_i(x_i)) + log(beta(sum_i(a_i), sum_i(x_i))) -
        sum_i(log(x_i)) - sum_i(log(beta(a_i, x_i)))

    Because the input matrix is sparse and zero counts do not contribute to the
    likelihood, we can rapidly remove zero elements from the computation.

    Args:
      matrix (scipy.sparse.csc_matrix): Matrix of UMI counts (feature x barcode)
      alpha (np.ndarray(float)): The vector of Dirichlet parameters for each feature
      n (int, optional): The total number of UMIs per barcode.  Saves computation
        if it can be precomputed.

    Returns:
      loglk (np.ndarray(float)): Log-likelihood for each barcode
    r   Nr   r,   r.   r/   )r1   r   r2   r3   r   rC   r   r4   r5   r6   r7   r8   )r9   alphar;   r<   r=   r>   Zbc_indexr?   r@   r    rA   Zshort_alphar	   r	   r   )eval_dirichlet_multinomial_loglikelihoods   s   

*rI   c                 C   sl   t | }tdt| d }t|}t|t| t|| d  t|d ||    }t|S )a  Computes the cumulative Dirichlet-multinomial log-likelihood for a vector.

    Given a vector of sample draws from a Dirichlet-multinomial distribution, computes the
    log-likelihood of all of the draws up to each draw.  This is done incrementally,
    by first pre-computing the rank R of each draw as the number of times that
    feature was seen when it was drawn.  Then the incremental log-likelihood from
    that draw is

    dl = log(i) - log(R_i) - log(i + sum_i(a_i) - 1) + log(R_i + a_i - 1)

    Args:
      sample_draws (np.ndarray(int)): A vector of sample draws
      alpha (np.ndarray(float)): The Dirichlet parameter for each feature

    Returns:
      loglk (np.ndarray(float): The cumulative log-likelihood up to each draw
    r   )r%   r   r   r   r   rC   rD   )r   rH   rE   rF   Zalpha_0r=   r	   r	   r   3eval_dirichlet_multinomial_loglikelihood_cumulative   s   

rJ   c                 C   s~   | dd|f } t | jdd }dd }ddg}tj||| ||fd}|js1td	|j t	d
|j
 d|  |j
S )aX  Estimates the best-fit overdispersion parameter for data.

    Uses a Dirichlet-multinomial to maximize the log-likelihood of the ambient
    barcode signal, given an input probability per feature that is scaled by
    a fixed overdispersion parameter.

    Args:
      matrix (scipy.sparse.csc_matrix): Matrix of UMI counts (feature x barcode)
      ambient_bcs (np.array): Array of barcode indexes to use for the ambient background
      p (np.ndarray(float)): The estimated multinomial probability vector across features

    Returns:
      alpha (float): Best-fit overdispersion for the data
    Nr   r,   c                 S   s   t t|| | |d S )N)r;   )r   r   rI   )rH   r9   pumis_per_bcr	   r	   r   ambient_loglk   s   z8estimate_dirichlet_overdispersion.<locals>.ambient_loglkgMbP?i'  )boundsargszCould not find valid alpha: zAlpha = zC: maximizes the likelihood of the ambient barcodes. Search bounds: )r   r3   r   flattenr   minimize_scalarsuccess
ValueErrormessageprintx)r9   Zambient_bcsrK   rL   rM   rN   resultr	   r	   r   !estimate_dirichlet_overdispersion   s   rX   c                 C   s   t j| d}t||S )a  Returns a fixed size array of multinomial draws from the input probabilities.

    Given a number of samples to draw and a probability vector for all features
    to sample from, produces an array of feature indices of the desired size,
    drawn according to a multinomial distribution.  Because it is significantly
    faster to generate random numbers on the unit interval, we use that and then
    find the feature index using the cumulative probabilities of the probability
    vector.

    Args:
      num_draws (int): The number of samples to draw from the distribution
      p_cumulative (np.ndarray(float)): The cumulative probability of all possible
        features.  The length of this array should be the number of features.
        It should be sorted from smallest to largest, and the last entry should
        be 1.

    Returns:
      sample_draws (np.ndarray(float)): A random set of draws from a multinomial
        distribution.
    )size)RNGrandomr   searchsorted)	num_drawsp_cumulativeZrng_numsr	   r	   r   draw_multinomial_sample   s   r_   c                 C   s   t |}t| t|S )a  Returns an array of Dirichlet-multinomial draws from the input parameters.

    Given a number of samples to draw and a vector of alpha parameters for all
    features to sample from, produces an array of feature indices of the desired
    size, drawn according to a Dirichlet-multinomial distribution.  Note that
    Dirichlet-multinomial draws are *not* independent and identically-distributed,
    so that all draws for a sample must be done with a single call or the
    sample variance will be under-stated.

    Args:
      num_draws (int): The number of samples to draw from the distribution
      alpha (np.ndarray(float)): The Dirichlet parameters for each feature.
      Normalized to a sum of 1, these are the mean probabilities of each feature
      over many draws of many samples.  The length of this array should be the
      number of features.

    Returns:
      sample_draws (np.ndarray(float)): A random set of draws from a Dirichlet-
        multinomial distribution.
    )rZ   	dirichletr_   r   rD   )r]   rH   probsr	   r	   r   !draw_dirichlet_multinomial_sample   s   
rb   c           
      C   s   t t |}t|}t jt||ftd}t | }t | }t	|D ]}t
||}	t|	||d  |dd|f< q%||fS )a  Simulate draws from a multinomial distribution for many values of N.

    Note that the samples within each simulation are not independent; we generate
    N draws for the largest value of N and use subsamples of the largest simulation
    for smaller values of N.

    Args:
      p (np.ndarray(float)): Probability of observing each feature.
      umis_per_bc (np.ndarray(int)): UMI counts per barcode.
      num_sims (int): Number of simulations to perform.

    Returns:
      distinct_ns (np.ndarray(int)): an array containing the distinct N values
          that were simulated.
      log_likelihoods (np.ndarray(float)): a len(distinct_ns) x num_sims matrix
          containing the simulated log likelihoods.
    r*   r   N)r   flatnonzeror'   maxr2   r   r   rD   rC   r4   r_   rG   )
rK   rL   num_sims
distinct_nmax_nr=   r^   r:   r#   drawr	   r	   r   #simulate_multinomial_loglikelihoods  s   


 ri   c                 C   sl   t t |}t|}t jt||ftd}t|D ]}t|| }t	|| |d  |dd|f< q||fS )a  Simulate draws from a Dirichlet-multinomial distribution for many values of N.

    Note that the samples within each simulation are not independent; we generate
    N draws for the largest value of N and use subsamples of the largest simulation
    for smaller values of N.

    Args:
      alpha (np.ndarray(float)): Dirichlet parameter for each feature.
      umis_per_bc (np.ndarray(int)): UMI counts per barcode.
      num_sims (int): Number of simulations to perform.

    Returns:
      distinct_ns (np.ndarray(int)): an array containing the distinct N values
          that were simulated.
      log_likelihoods (np.ndarray(float)): a len(distinct_ns) x num_sims matrix
          containing the simulated log likelihoods.
    r*   r   N)
r   rc   r'   rd   r2   r   r   r4   rb   rJ   )rH   rL   re   rf   rg   r=   r#   rh   r	   r	   r   -simulate_dirichlet_multinomial_loglikelihoods8  s   
rj   c           
      C   s   t | t |ks
J |jd t |ksJ t|| }|jd }t | }t|}t|D ]}t||| ddf || k }	td|	 d|  ||< q-|S )a  Compute p-values for observed multinomial log-likelihoods.

    Args:
      umis_per_bc (nd.array(int)): UMI counts per barcode
      obs_loglk (nd.array(float)): Observed log-likelihoods of each barcode deriving from an ambient profile
      sim_n (nd.array(int)): Multinomial N for simulated log-likelihoods
      sim_loglk (nd.array(float)): Simulated log-likelihoods of shape (len(sim_n), num_simulations)

    Returns:
      pvalues (nd.array(float)): p-values
    r   r   N)r   r1   r   r\   r2   r4   r   r   )
rL   Z	obs_loglkZsim_nZ	sim_loglkZ	sim_n_idxre   Znum_barcodespvaluesr#   Znum_lower_loglkr	   r	   r   compute_ambient_pvaluesX  s   

"rl   c                   @   s(   e Zd ZdZdd Zdd Zdd ZdS )	CurveaC  Curve information for plotting.

    Attributes:
        x     (list of int): Curve x-axis (number of cells)
        y     (list of float): Curve y-axis (expected value of number of
              unique clonotypes)
        y_std (list of float): Standard deviation of number of unique
              clonotypes in a curve
        y_ciu (list of float): Upper bound of 95% confidence interval
              for number of unique clonotypes in a curve
        y_cil (list of float): Lower bound of 95% confidence interval
              for number of unique clonotypes in a curve
    c                 C   s"   g | _ g | _g | _g | _g | _dS )zInitialize an empty curve.N)rV   yy_stdy_ciuy_cilselfr	   r	   r   __init__  s
   
zCurve.__init__c                 C   s$   t | jdkst | jdkrdS dS )zRCheck if curve is empty or calculated.

        Returns:
            bool
        r   TF)r   rn   rV   rr   r	   r	   r   is_empty  s   zCurve.is_emptyc                 C   s<   dd | j | j| j| jfD }||d t|krdS dS )zcCheck if calculated curve is consistent (length match).

        Returns:
            bool
        c                 S      g | ]}t |qS r	   )r   )r
   r#   r	   r	   r   
<listcomp>      z'Curve.is_consistent.<locals>.<listcomp>r   FT)rV   rn   rq   rp   countr   )rs   Zlen_listr	   r	   r   is_consistent  s   zCurve.is_consistentN)__name__
__module____qualname____doc__rt   ru   rz   r	   r	   r	   r   rm   u  s
    
rm   c                   @   s   e Zd ZdZdd Zdd Zdeeef fddZdefd	d
Z	dd Z
dd Zd'defddZ	d(dededefddZdd Zd)dedefddZdd Zd d! Z	d*d"ed#edefd$d%Zd&S )+	Diversitya  Represents diversity with rarefaction and extrapolation curves.

    Attributes:
        sorted_hist         (list of tuple): A histogram sorted by abundance
                                            (most abundant has lowest index)
        freq_counts         (dict of int: int): Frequency counts table
        N                   (int): total number of counts in histogram
        rarefaction_curve   (Curve): Rarefaction curve information
        extrapolation_curve (Curve): Extrapolation curve information
    c                 C   s6   t |dd| _|  | _|  | _t | _t | _d S )NT)reverse)	sortedsorted_hist_calc_nN_get_freq_countsfreq_countsrm   rarefaction_curveextrapolation_curve)rs   histr	   r	   r   rt     s
   

zDiversity.__init__c                 C   s   | j dks|  dv rdS dS )Nr   )r   FTr   	f_0_chao1rr   r	   r	   r   is_diversity_curve_possible  s   z%Diversity.is_diversity_curve_possiblereturnc                 C   s6   i }| j D ]}||vrd||< q||  d7  < q|S )z]Get frequency counts (f_k in Colwell et al 2012).

        Returns:
            dict
        r   )r   )rs   Zret_dictZabundr	   r	   r   r     s   

zDiversity._get_freq_countsc                 C   s
   t | jS )zmCalculates number of samples.

        (length of input histogram)

        Returns:
            int
        )r   r   rr   r	   r	   r   r     s   
zDiversity._calc_nc                 C   sj   || j | kr	dS t| j | d t| j | d  t| j d  t| j | | d  }tt|S )zCalculates alpha parameter in Colwell et al 2012.

        Args:
            n (int): rarefaction point n
            k (int): Frequency

        Returns:
            int
        r   r   )r   r   r   expreal)rs   r;   kZ	log_alphar	   r	   r   _alpha  s   
zDiversity._alphac                 C   s   t t| j }d}| j D ]\}}|| ||| 7 }q|| }d}| j D ]\}}|d| || d | t |d |    7 }q*|t|fS )zCalculates rarefaction and standard deviation of rarefaction for a single point n.

        Args:
            n (int): rarefaction point n

        Returns:
            float, float
        r   r   r   )	r   r   r   valuesitemsr   assemblage_size_estimater   sqrt)rs   r;   Znum_clonZsum_helper_exp_valr   ry   exp_valZsum_helper_std_devr	   r	   r   _rarefaction  s   

zDiversity._rarefaction(   	num_stepsc                 C   s  t | j|d  }|dkrd}td| j|}t||k r%td| j| |}d}|gt| }|gt| }|gt| }|gt| }t|D ]'\}	}
| |
\||	< ||	< ||	 d||	   ||	< ||	 d||	   ||	< qGt|| j_|| j_	|| j_
|| j_|| j_dS )zCalculates rarefaction curve for num_steps between 1 and N.

        Args:
            num_steps (int): Number of steps between 1 and N
        r   r           \(\?N)intr   r4   r   	enumerater   listr   rV   rn   ro   rq   rp   )rs   r   	step_sizeZrc_xplaceholderZrc_y_expZrc_y_stdZrc_y_ciuZrc_y_cilidxr;   r	   r	   r   calc_rarefaction_curve  s(   z Diversity.calc_rarefaction_curveToriginstdevc                 C   s   |  | g }|| jj| jjd|d|d |r?|| jj| jjddd  | jj| jjddd  d|d dd|d	d
 |S )a]  Create a plotly curve for rarefaction.

        Args:
            origin (str): Origin string (see vdj inputs)
            color: color of the curve
            num_steps (int): number of extrapolation steps
            stdev (bool): flag for calculation of standard deviation (and confidence intarvals)

        Returns:
            [dict]
        Zscatterlines)rV   rn   typenamer0   
line_colorNr   z (Stdev)Ztoselfzrgba(255,255,255,0)g?)rV   rn   r   r   fillr   Z	fillcolorZopacity)r   appendr   rV   rn   rp   rq   )rs   r   colorr   r   r   r	   r	   r   calc_rarefaction_curve_plotly  s0   

z'Diversity.calc_rarefaction_curve_plotlyc           
      C   st   t t| j }t |  }t | jd }t | j}t || j }dd|| |  |  }|||  }d}	||	fS )aZ  Calculates extrapolation for a single point (N + m).

        standard deviation is not implemented.

        Args:
            n_plus_m (int): Distance of extrapolation point to N
                (how many more samples collected)

        Returns:
            float: Expected value
            None: Standard deviation (not implemented)
        r   g      ?N)r   r   r   r   r   r   )
rs   n_plus_mZs_obsf_0Zf_1r   mZbracketsr   Zstd_devr	   r	   r   _extrapolationI  s   
zDiversity._extrapolation  rg   c                 C   s  t || j |d  }t| j||}t||k r!t| j|| |}dgt| }dgt| }dgt| }dgt| }t|D ]-\}	}
| |
\||	< ||	< ||	 durn||	 d||	   ||	< ||	 d||	   ||	< qAt|| j_|| j_	|| j_
|| j_|| j_dS )a  Calculates extrapolation curve for num_steps between N and max_n.

        Args:
            num_steps (int): Number of steps between N and max_n
            max_n (int): The limit to which the function extrapolates

        Returns:
            None
        r   r   Nr   )r   r   r4   r   r   r   r   r   rV   rn   ro   rq   rp   )rs   r   rg   r   Zec_xZec_y_expZec_y_stdZec_y_ciuZec_y_cilr   r   r	   r	   r   calc_extrapolation_curvef  s&   
z"Diversity.calc_extrapolation_curvec                 C   sf   d| j v rd| j v rt| j d d td| j d   S d| j v r1t| j d | j d d  d S dS )zEstimate f0 (number of clonotypes that we haven't seen yet) using chao1 estimator.

        Based on: Colwell et al 2012

        Returns:
            float
        r   r   g       @g      )r   r   rr   r	   r	   r   r     s
   $
 zDiversity.f_0_chao1c                 C   s   | j |   S )zEstimate assemblage size (number of total unique clonotypes) using chao1 estimator.

        Based on: Colwell et al 2012

        Returns:
            float
        r   rr   r	   r	   r   r     s   z"Diversity.assemblage_size_estimaterc_stepsec_stepsc                 C   sL  | j  r| || | j r| | | j  rdS | j   r%dS d}t|dp}|| t| jj	D ]*\}}|| jj
| | jj| | jj| g}	ddd |	D dg d }
||
 q8t| j j	D ]*\}}|| j j
| | j j| | j j| g}	dd	d |	D d
g d }
||
 qiW d   dS 1 sw   Y  dS )aw  Calculate rarefaction and extrapolation diversity curves.

        Save them in csv format in outfile.

        Args:
            outfile (str): Path to output file
            rc_steps (int): Rarefaction curve steps between 1 and N
            ec_steps (int): Extrapolation steps between N and max_n
            max_n (int): Maximum number of cells extrapolating to
        Nzx,y,y_cil,y_ciu,type
w,c                 S   rv   r	   strr
   r;   r	   r	   r   rw     rx   z0Diversity.create_both_curves.<locals>.<listcomp>Zrarefaction
c                 S   rv   r	   r   r   r	   r	   r   rw     rx   Zextrapolationr   )r   ru   r   r   r   rz   openwriter   rV   rn   rq   rp   join)rs   outfiler   r   rg   headerZ	outhandler   rV   numbersliner	   r	   r   create_both_curves  s@   










zDiversity.create_both_curvesN)r   )r   T)r   r   )r   r   r   )r{   r|   r}   r~   rt   r   dictr   r   r   r   r   r   r   boolr   r   r   r   r   r   r	   r	   r	   r   r     s:    	

+r   )N)numpyr   scipyr   scipy.specialr   r   r   tenkit.statsstatsr   r[   default_rngrZ   r   r%   r)   rB   rG   rI   rJ   rX   r_   rb   ri   rj   rl   rm   r   r	   r	   r	   r   <module>   s(   
#
"&  -