B
     \X                 @   s  d Z ddlmZmZmZ ddlZddlZddlZddlm	Z	m
Z
mZmZmZmZmZ ddlZddlmZ ddlmZ ddlmZmZ dd	lmZmZmZ eeZdZdZ d
Z!dZ"dZ#e	Z$e
Z%d& Z'x&e(e'D ]\Z)Z*e+ej,e e*e) qW dd Z-e. Z/G dd de0Z1G dd de0Z2dddZ3e3e4j5d  Z6ye1 7 Z8W n   dZ8Y nX e8dk rfdZ9ndZ9dd Z:dd Z;d Z<d!Z=G d"d# d#e0Z>d$Z?d%Z@d&ZAd'ZBd(d) ZCd*d+ ZDd,d- ZEeFd.ZGeFd/ZHeFd0ZIeFd1Jd2d3ZKeFd4ZLd5d6d7d8d9d:d;d<d=d>d?d@hZMeFdAZNeFdBZOeFdCZPeFdDZQeFdEZReFdFZSeFdGZTeFdHZUdIdJ ZVdKdL ZWdMdN ZXdOdP ZYdQdR ZZdS )Sz(
This is a direct translation of nvvm.h
    )print_functionabsolute_importdivisionN)c_void_pc_intPOINTERc_char_pc_size_tbyrefc_char)ir)config   )	NvvmErrorNvvmSupportError)get_libdeviceopen_libdeviceopen_cudalib         a  
NVVM_SUCCESS
NVVM_ERROR_OUT_OF_MEMORY
NVVM_ERROR_PROGRAM_CREATION_FAILURE
NVVM_ERROR_IR_VERSION_MISMATCH
NVVM_ERROR_INVALID_INPUT
NVVM_ERROR_INVALID_PROGRAM
NVVM_ERROR_INVALID_IR
NVVM_ERROR_INVALID_OPTION
NVVM_ERROR_NO_MODULE_IN_PROGRAM
NVVM_ERROR_COMPILATION
c               C   s(   y
t   W n tk
r   dS X dS dS )z(
    Return if libNVVM is available
    FTN)NVVMr    r   r   6lib/python3.7/site-packages/numba/cuda/cudadrv/nvvm.pyis_available1   s
    
r   c               @   s   e Zd ZdZeeeeefeeefeeefeeee	efeeeeefeeee	feeefeeee	feeefd	Z
dZdd Zdd Zdd	d
ZdS )r   zProcess-wide singleton.
    )	nvvmVersionnvvmCreateProgramnvvmDestroyProgramnvvmAddModuleToProgramnvvmCompileProgramnvvmGetCompiledResultSizenvvmGetCompiledResultnvvmGetProgramLogSizenvvmGetProgramLogNc             C   s   t  | jd krt|  | _}ytddd|_W n8 tk
rj } zd | _d}t|| W d d }~X Y nX xF|j	 D ]8\}}t
|j|}|d |_|dd  |_t||| qxW W d Q R X | jS )NZnvvmT)Zcccz;libNVVM cannot be found. Do `conda install cudatoolkit`:
%sr   r   )
_nvvm_lock_NVVM__INSTANCEobject__new__r   driverOSErrorr   _PROTOTYPESitemsgetattrZrestypeZargtypessetattr)clsZinsteerrmsgnameprotofuncr   r   r   r'   j   s    

zNVVM.__new__c             C   s8   t  }t  }| t|t|}| |d |j|jfS )NzFailed to get version.)r   r   r
   check_errorvalue)selfmajorminorerrr   r   r   get_version   s
    zNVVM.get_versionFc             C   s2   |r.t |t| }|r*t| td n|d S )Nr   )r   RESULT_CODE_NAMESprintsysexit)r6   errormsgr>   excr   r   r   r4      s    zNVVM.check_error)F)__name__
__module____qualname____doc__nvvm_resultr   r   nvvm_programr   r	   r*   r%   r'   r:   r4   r   r   r   r   r   ?   s   

r   c               @   s<   e Zd Zdd Zdd Zdd Zdd Zd	d
 Zdd ZdS )CompilationUnitc             C   s4   t  | _t | _| jt| j}| j|d d S )NzFailed to create CU)r   r(   rG   _handler   r
   r4   )r6   r9   r   r   r   __init__   s    zCompilationUnit.__init__c             C   s*   t  }|t| j}|j|ddd d S )NzFailed to destroy CUT)r>   )r   r   r
   rI   r4   )r6   r(   r9   r   r   r   __del__   s    zCompilationUnit.__del__c             C   s*   | j | j|t|d}| j |d dS )z
         Add a module level NVVM IR to a compilation unit.
         - The buffer should contain an NVVM module IR either in the bitcode
           representation (LLVM3.0) or in the text representation.
        NzFailed to add module)r(   r   rI   lenr4   )r6   bufferr9   r   r   r   
add_module   s    zCompilationUnit.add_modulec             K   sh  g }d|kr | dr |d |dr>|d| d  |dr\|d| d  d}x@|D ]8}||krftt| |}|d|d	d
|f  qfW |rdtt|	 }t
d|tt| dd |D  }| j| jt||}| |d t }	| j| jt|	}| |d t|	j  }
| j| j|
}| |d |  | _|
dd S )aF  Perform Compliation

        The valid compiler options are

         *   - -g (enable generation of debugging information)
         *   - -opt=
         *     - 0 (disable optimizations)
         *     - 3 (default, enable optimizations)
         *   - -arch=
         *     - compute_20 (default)
         *     - compute_30
         *     - compute_35
         *   - -ftz=
         *     - 0 (default, preserve denormal values, when performing
         *          single-precision floating-point operations)
         *     - 1 (flush denormal values to zero, when performing
         *          single-precision floating-point operations)
         *   - -prec-sqrt=
         *     - 0 (use a faster approximation for single-precision
         *          floating-point square root)
         *     - 1 (default, use IEEE round-to-nearest mode for
         *          single-precision floating-point square root)
         *   - -prec-div=
         *     - 0 (use a faster approximation for single-precision
         *          floating-point division and reciprocals)
         *     - 1 (default, use IEEE round-to-nearest mode for
         *          single-precision floating-point division and reciprocals)
         *   - -fma=
         *     - 0 (disable FMA contraction)
         *     - 1 (default, enable FMA contraction)
         *
         debugz-gZoptz-opt=%darchz-arch=%s)ZftzZ	prec_sqrtZprec_divZfmaz-%s=%d_-z, zunsupported option {0}c             S   s   g | ]}t |d qS )utf8)r   encode).0xr   r   r   
<listcomp>   s   z+CompilationUnit.compile.<locals>.<listcomp>zFailed to compile
z&Failed to get size of compiled result.zFailed to get compiled result.N)popappendgetintboolreplacejoinmapreprkeysr   formatr   rL   r(   r   rI   
_try_errorr	   r    r
   r   r5   r!   get_loglog)r6   ZoptionsoptsZother_optionskvZoptstrZc_optsr9   reslenZptxbufr   r   r   compile   s8    #






zCompilationUnit.compilec             C   s   | j |d||  f  d S )Nz%s
%s)r(   r4   rd   )r6   r9   r@   r   r   r   rc      s    zCompilationUnit._try_errorc             C   sl   t  }| j| jt|}| j|d |jdkrht|j  }| j| j|}| j|d |j	dS dS )Nz#Failed to get compilation log size.r   zFailed to get compilation log.rS    )
r	   r(   r"   rI   r
   r4   r5   r   r#   decode)r6   ri   r9   Zlogbufr   r   r   rd      s    
zCompilationUnit.get_logN)	rB   rC   rD   rJ   rK   rN   rj   rc   rd   r   r   r   r   rH      s   
TrH   ze-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64ze-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64)    @      )r   r   )r   r   )
)   r   )rp   r   )r   r   )r   r   )r   r   )r   rp   )r   r   )   r   )rq   r   )rq   rp   )	)r   r   )r   r   )r   r   )r   rp   )r   r   )rq   r   )rq   r   )rq   rp   )   r   c             C   sX   xNt tD ]B\}}|| kr|S || kr
|dkr@td| |  q
t|d  S q
W td S )Nr   z@GPU compute capability %d.%d is not supported (requires >=%d.%d)r   )	enumerateSUPPORTED_CCr   )ZmycciZccr   r   r   
_find_arch$  s    rw   c             C   s"   t jrt j}nt| |f}d| S )z1Matches with the closest architecture option
    zcompute_%d%d)r   ZFORCE_CUDA_CCrw   )r7   r8   rP   r   r   r   get_arch_option7  s    rx   z
Please define environment variable NUMBAPRO_LIBDEVICE=/path/to/libdevice
/path/to/libdevice -- is the path to the directory containing the libdevice.*.bc
files in the installation of CUDA.  (requires CUDA >=8.0)
zMissing libdevice file for {arch}.
Please ensure you have package cudatoolkit 8.0.
Install package by:

    conda install cudatoolkit=8.0
c               @   s4   e Zd Zi ZddddgZdd Zdd Zd	d
 ZdS )	LibDevice
compute_20Z
compute_30Z
compute_35Z
compute_50c             C   sT   || j kr>| |}t|dkr0ttj|dt|| j |< || _| j | | _dS )z@
        arch --- must be result from get_arch_option()
        N)rP   )	_cache__get_closest_archr   RuntimeErrorMISSING_LIBDEVICE_FILE_MSGrb   r   rP   bc)r6   rP   r   r   r   rJ   X  s    

zLibDevice.__init__c             C   s*   | j d }x| j D ]}||kr|}qW |S )Nr   )_known_arch)r6   rP   ZresZ	potentialr   r   r   r|   e  s
    
zLibDevice._get_closest_archc             C   s   | j S )N)r   )r6   r   r   r   rZ   l  s    zLibDevice.getN)rB   rC   rD   r{   r   rJ   r|   rZ   r   r   r   r   ry   O  s   ry   z
define internal i32 @___numba_cas_hack(i32* %ptr, i32 %cmp, i32 %val) alwaysinline {
    %out = cmpxchg volatile i32* %ptr, i32 %cmp, i32 %val monotonic
    ret i32 %out
}
ae  
define internal double @___numba_atomic_double_add(double* %ptr, double %val) alwaysinline {
entry:
    %iptr = bitcast double* %ptr to i64*
    %old2 = load volatile i64, i64* %iptr
    br label %attempt

attempt:
    %old = phi i64 [ %old2, %entry ], [ %cas, %attempt ]
    %dold = bitcast i64 %old to double
    %dnew = fadd double %dold, %val
    %new = bitcast double %dnew to i64
    %cas = cmpxchg volatile i64* %iptr, i64 %old, i64 %new monotonic
    %repeat = icmp ne i64 %cas, %old
    br i1 %repeat, label %attempt, label %done

done:
    %result = bitcast i64 %old to double
    ret double %result
}
a  
define internal {T} @___numba_atomic_{T}_max({T}* %ptr, {T} %val) alwaysinline {{
entry:
    %ptrval = load volatile {T}, {T}* %ptr
    ; Check if val is a NaN and return *ptr early if so
    %valnan = fcmp uno {T} %val, %val
    br i1 %valnan, label %done, label %lt_check

lt_check:
    %dold = phi {T} [ %ptrval, %entry ], [ %dcas, %attempt ]
    ; Continue attempts if dold < val or dold is NaN (using ult semantics)
    %lt = fcmp ult {T} %dold, %val
    br i1 %lt, label %attempt, label %done

attempt:
    ; Attempt to swap in the larger value
    %iold = bitcast {T} %dold to {Ti}
    %iptr = bitcast {T}* %ptr to {Ti}*
    %ival = bitcast {T} %val to {Ti}
    %cas = cmpxchg volatile {Ti}* %iptr, {Ti} %iold, {Ti} %ival monotonic
    %dcas = bitcast {Ti} %cas to {T}
    br label %lt_check

done:
    ; Return max
    %ret = phi {T} [ %ptrval, %entry ], [ %dold, %lt_check ]
    ret {T} %ret
}}
a  
define internal {T} @___numba_atomic_{T}_min({T}* %ptr, {T} %val) alwaysinline{{
entry:
    %ptrval = load volatile {T}, {T}* %ptr
    ; Check if val is a NaN and return *ptr early if so
    %valnan = fcmp uno {T} %val, %val
    br i1 %valnan, label %done, label %gt_check

gt_check:
    %dold = phi {T} [ %ptrval, %entry ], [ %dcas, %attempt ]
    ; Continue attempts if dold > val or dold is NaN (using ugt semantics)
    %lt = fcmp ugt {T} %dold, %val
    br i1 %lt, label %attempt, label %done

attempt:
    ; Attempt to swap in the smaller value
    %iold = bitcast {T} %dold to {Ti}
    %iptr = bitcast {T}* %ptr to {Ti}*
    %ival = bitcast {T} %val to {Ti}
    %cas = cmpxchg volatile {Ti}* %iptr, {Ti} %iold, {Ti} %ival monotonic
    %dcas = bitcast {Ti} %cas to {T}
    br label %gt_check

done:
    ; Return min
    %ret = phi {T} [ %ptrval, %entry ], [ %dold, %gt_check ]
    ret {T} %ret
}}
c             C   sF   |   }x2t|D ]&\}}|drd}|t||< P qW d|S )z@
    Find the line containing the datalayout and replace it
    ztarget datalayoutztarget datalayout = "{0}"
)
splitlinesrt   
startswithrb   default_data_layoutr^   )llvmirlinesrv   ZlnZtmpr   r   r   _replace_datalayout  s    
r   c          
   K   s   t  }t|ddd}t| } dtfdtfdtjddd	fd
tjddd	fdtjddd	fdtjddd	fg}x|D ]\}}| 	||} qvW t
| } || d ||  |jf |}t|S )NrP   rz   )rP   z.declare i32 @___numba_cas_hack(i32*, i32, i32)z;declare double @___numba_atomic_double_add(double*, double)z7declare float @___numba_atomic_float_max(float*, float)floati32)TZTiz;declare double @___numba_atomic_double_max(double*, double)ZdoubleZi64z7declare float @___numba_atomic_float_min(float*, float)z;declare double @___numba_atomic_double_min(double*, double)rS   )rH   ry   rZ   r   ir_numba_cas_hackir_numba_atomic_double_addir_numba_atomic_maxrb   ir_numba_atomic_minr]   llvm39_to_34_irrN   rT   rj   patch_ptx_debug_pubnames)r   rf   ZcuZ	libdeviceZreplacementsZdeclfnptxr   r   r   llvm_to_ptx  s,    r   c             C   sV   xP|  d}|dk rP |  d|}|dk r2td| d| | |d d  } qW | S )z
    Patch PTX to workaround .debug_pubnames NVVM error::

        ptxas fatal   : Internal error: overlapping non-identical data

    s   .section .debug_pubnamesr      }zmissing "}"Nr   )find
ValueError)r   startstopr   r   r   r     s    
 r   z	\!\d+\s*=zmetadata\s*\![{'\"0-9]z\!\d+z,\!{i32 \d, \!\"Debug Info Version\", i32 \d} z\s+z"^attributes #\d+ = \{ ([\w\s]+)\ }ZalwaysinlineZcoldZ
inlinehintZminsizeZnoduplicateZnoinlineZnoreturnZnounwindZoptnoneZoptiszeZreadnonereadonlyz"\bgetelementptr\s(?:inbounds )?\(?z=\s*\bload\s(?:\bvolatile\s)?z(call\s[^@]+\))(\s@)z\s*!range\s+!\d+z
[,{}()[\]]z\bnonnull\bz"\b(local_unnamed_addr|writeonly)\bz\((.*)\)c             C   sT  dd }g }x:|   D ],}|dr4|dd}| drVd|krV|dd}t|rdt|kr|d	d
}|dd}|d}|d|d  ||d d  }}dd }d	|t
||f}|drqt|dk	rtdd |}|drDt|}|d }	d	dd |	D }	||d|	}d|krt|}|dkrptd|f | }
|d|
 |||
d  }d|krt|}|r| }
|d|
 |||
d  }d|krtd|}td|d}d|krtt|}d |kr0d|kr0tt|}td|}|| qW d!	|S )"z+
    Convert LLVM 3.9 IR for LLVM 3.4.
    c             S   s   d}d}xpt | |}|d kr.td| f P | }|d}|dkrT|dkrvP q
|dkrf|d7 }q
|dkr
|d8 }q
W | |d   S )Nr   zfailed parsing leading type: %s,z{[(r   z)]})re_type_toksearchr}   endgrouplstrip)sZ	par_levelposmtokr   r   r   parse_out_leading_type,  s"    

z/llvm39_to_34_ir.<locals>.parse_out_leading_typez!numba.llvm.dbg.cuz!llvm.dbg.cuz%tail call void asm sideeffect "// dbgz
!numba.dbgz!dbgNz!{zmetadata !{z!"zmetadata !"=r   c             S   s   d|  d S )Nz	metadata r   )r   )r   r   r   r   fix_metadata_refY  s    z)llvm39_to_34_ir.<locals>.fix_metadata_refr   zsource_filename =c             S   s   dS )Nrk   r   )r   r   r   r   <lambda>_  s    z!llvm39_to_34_ir.<locals>.<lambda>zattributes #c             s   s   | ]}|t kr|V  qd S )N)supported_attributes)rU   ar   r   r   	<genexpr>e  s    z"llvm39_to_34_ir.<locals>.<genexpr>zgetelementptr z failed parsing getelementptr: %szload zcall z\1*\2rk   r   z@llvm.memsetZdeclarer   )r   r   r]   r   re_metadata_defmatchre_metadata_correct_usager   r   r^   re_metadata_refsubre_unsupported_keywordsre_attributes_defr   splitre_getelementptrr}   r   re_loadre_callre_rangerstripre_parenthesized_list_replace_llvm_memset_usage _replace_llvm_memset_declarationre_annotationsrY   )r   r   ZbuflineZassigposZlhsZrhsr   r   Zattrsr   r   r   r   r   (  sf    


"










r   c             C   sP   t | dd}td|d d}|dd| d|}d|S )	zNReplace `llvm.memset` usage for llvm7+.

    Used as functor for `re.sub.
    r   r   zalign (\d+)r   rs   zi32 {}z, z({}))listr   r   rer   insertrb   r^   )r   paramsZalignoutr   r   r   r     s
    
r   c             C   s4   t | dd}|dd d|}d|S )zTReplace `llvm.memset` declaration for llvm7+.

    Used as functor for `re.sub.
    r   r   rs   r   z, z({}))r   r   r   r   r^   rb   )r   r   r   r   r   r   r     s    
r   c             C   s   ddl m}m}m}m} | j}| ||d|| df}|||}|d}|	| t
d}	||	d|	d|	d|	dg}
|d|
 d S )	Nr   )MetaDataMetaDataStringConstantTypeZkernelr   znvvm.annotationsrm   rp   znvvmir.version)Zllvmlite.llvmpy.corer   r   r   r   modulerZ   r[   Zget_or_insert_named_metadataaddr   ZIntTypeZadd_metadataZadd_named_metadata)Zlfuncr   r   r   r   r   ZopsZmdZnmdr   Zmd_verr   r   r   set_cuda_kernel  s    


"r   c             C   s
   t | _d S )N)r   data_layout)r   r   r   r   fix_data_layout  s    r   )[rE   Z
__future__r   r   r   r=   Zloggingr   Zctypesr   r   r   r   r	   r
   r   Z	threadingZllvmliter   Znumbar   r?   r   r   Zlibsr   r   r   Z	getLoggerrB   ZloggerZADDRSPACE_GENERICZADDRSPACE_GLOBALZADDRSPACE_SHAREDZADDRSPACE_CONSTANTZADDRSPACE_LOCALrG   rF   r   r;   rt   rv   rg   r-   modulesr   ZLockr$   r&   r   rH   r   tuple__itemsize__r   r:   ZNVVM_VERSIONru   rw   rx   ZMISSING_LIBDEVICE_MSGr~   ry   r   r   r   r   r   r   r   rj   r   r   r   r]   Zre_metadata_debuginfor   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   <module>   s   $
Q}

&"











j