U
    f!|                     @   s  d Z ddlZddlZddlZddlmZ ddlZddlZddl	Z	ddl
mZmZmZ ddlmZ ddlmZmZmZmZ ddlmZ ddlmZ dd	lmZmZ dd
lmZmZmZ ddl m!Z! ddddgZ"eddd\Z#Z$e %e#dd Z&e$dd Z'eddd\Z#Z$e %e#dd Z(e$dd Z)e&e'fe(e)fgZ*edddddd\Z+Z,ee,Z,e+e,fgZ-e Z.e.j/Z0e.j1Z2dd Z3dd Z4dd Z5e	j67d e*d!d" Z8e	j67d e-d#d$ Z9e	j67d e*d%d& Z:e	j67d e-d'd( Z;d)d* Z<d+d, Z=d-d. Z>d/d0 Z?d1d2 Z@d3d4 ZAd5d6 ZBd7d8 ZCd9d: ZDd;d< ZEd=d> ZFd?d@ ZGdAdB ZHdCdD ZIdEdF ZJdGdH ZKe	j67dIeegdJdK ZLdLdM ZMe!eNdNdOdP ZOe	j67dIeegdQdR ZPdSdT ZQe!edNdUdV ZRdWdX ZSdYdZ ZTd[d\ ZUe	j67d]ejVejWge	j67d^eegd_d` ZXdadb ZYe	j67d^eegdcdd ZZe	j67dIeegdedf Z[e	j67dIeege	j67dgdhdidjgdkdl Z\e	j67dIeegdmdn Z]dS )ozD
Testing for Multi-layer Perceptron module (sklearn.neural_network)
    N)StringIO)assert_allcloseassert_almost_equalassert_array_equal)
csr_matrix)load_digits	load_irismake_multilabel_classificationmake_regression)ConvergenceWarning)roc_auc_score)MLPClassifierMLPRegressor)LabelBinarizerMinMaxScalerscale)ignore_warningsidentitylogistictanhrelu   T)Zn_classZ
return_X_y      
   g      4@g      Y@   )	n_samples
n_featuresZbiasnoiserandom_statec               
   C   s   t d d } td d }g }td}dd }|D ]\}td|dd}ttd || | W 5 Q R X |t	||j
d	 ||j
d g q2tt|d D ] }|| ||d  k stqd S )
Nd   r   c                 S   s   t t | S N)npsumabs)x r&   L/tmp/pip-target-lpfmz8o1/lib/python/sklearn/neural_network/tests/test_mlp.py<lambda>G       ztest_alpha.<locals>.<lambda>r      )hidden_layer_sizesalphar   categoryr   )X_digits_binaryy_digits_binaryr"   aranger   r   r   fitappendarraycoefs_rangelenallAssertionError)XyZalpha_vectorsZalpha_valuesZabsolute_sumr,   mlpir&   r&   r'   
test_alpha@   s    
 r>   c               
   C   s  t dddgg} t dg}tddddddd	dd
}dgd	 |_dgd	 |_d|_t ddgddgddgg|jd< t dgdgg|jd< t ddg|jd< t dg|jd< g d	 |_g d	 |_d|_d|_	d|_
d|_dg|jd  |_dg|jd  |_d|_d|_t j|_g |_d|_dd |jD |_dd |jD |_|j| |ddgd t|jd t ddgddgddggdd t|jd t dgdggdd t|jd t ddgdd t|jd t ddd t|| d  d!dd d S )"Ng333333?皙?gffffff?r   sgdg?r   r*   r   )solverlearning_rate_initr,   
activationr   max_iterr+   momentum皙?g333333?      ?      ?r   c                 S   s   g | ]}t |qS r&   r"   Z
zeros_like).0Z
interceptsr&   r&   r'   
<listcomp>   s    ztest_fit.<locals>.<listcomp>c                 S   s   g | ]}t |qS r&   rI   )rJ   Zcoefsr&   r&   r'   rK      s     classesgJ+?guX?gડ2?g'?g.NV?gVSbb)decimalgF??g#070?gY,?g~?gZd;O?)r   r*   gS?)r"   r4   r   r5   intercepts_
n_outputs_Z_coef_gradsZ_intercept_gradsZn_features_in_n_iter_Zlearning_rate_	n_layers_Zout_activation_Zt_inf
best_loss_Zloss_curve__no_improvement_countZ_intercept_velocityZ_coef_velocitypartial_fitr   predict_probar:   r;   r<   r&   r&   r'   test_fitU   s\    "

"	rY   c                     s  dD ]} d}d}t jjdd}||| dt t |d |  }t |tD ]}t	|dddd	ddd

 | t dd jj D } jd gjg jg }g g g g   tjd D ]}t  jd ||d  f t  jd ||d  f || }	||d  }
t |	|
f t |
 qއ fdd}||\}}t t |}t |d}t |}d}t|D ]F}|d d |f | }||| d ||| d  |d  ||< qt|| qVqd S )N)r   r      r   *   )seedr*   lbfgsh㈵>rF   )rC   r+   rA   r,   rB   rD   r   c                 S   s   g | ]}|  qS r&   )Zravel)rJ   lr&   r&   r'   rK      s     z!test_gradient.<locals>.<listcomp>r   c              	      s    |  S r!   )Z_loss_grad_lbfgs)tr:   YZactivationsZ
coef_gradsZdeltasZintercept_gradsr<   r&   r'   loss_grad_fun   s          z$test_gradient.<locals>.loss_grad_fung       @)r"   randomRandomStateZrandmodr1   r   fit_transformACTIVATION_TYPESr   r2   Zhstackr5   rO   shaper+   rP   r3   r6   rR   emptyZzerossizeeyer   )Zn_labelsr   r   r   r;   rC   thetaZlayer_unitsr=   Zfan_inZfan_outrc   valueZgradZnumgradnEepsilonZdthetar&   ra   r'   test_gradient   sX    

	
""
rr   zX,yc           	   	   C   s   | d d }|d d }| dd  }|j d |jjf}tD ]\}tddddd|d}||| ||}|||dks|t|j d |jjf|ks:tq:d S )	N   r   r]   2   Tr*   rA   r+   rD   shuffler   rC   ffffff?)	ri   dtypekindrh   r   r2   predictscorer9   )	r:   r;   ZX_trainZy_trainZX_testZexpected_shape_dtyperC   r<   Z	y_predictr&   r&   r'   test_lbfgs_classification   s"    
r|   c              	   C   s`   t D ]V}tddddd|d}|| | |dkrF|| |dksZtq|| |d	kstqd S )
Nr]   rt   rs   Tr*   ru   r   r?   g\(\?)rh   r   r2   r{   r9   )r:   r;   rC   r<   r&   r&   r'   test_lbfgs_regression  s    r}   c              
   C   sX   d}t D ]J}tddd|dd|d}tt  || | ||jksHtW 5 Q R X qd S )Nr   r]   rt   rs   Tr*   )rA   r+   rD   max_funrv   r   rC   )rh   r   pytestwarnsr   r2   rQ   r9   r:   r;   r~   rC   r<   r&   r&   r'    test_lbfgs_classification_maxfun$  s    	r   c                 C   sZ   d}t D ]L}tdddd|dd|d}tt  || | ||jksJtW 5 Q R X qd S )	Nr   r]   rt   g        rs   Tr*   )rA   r+   tolrD   r~   rv   r   rC   )rh   r   r   r   r   r2   rQ   r9   r   r&   r&   r'   test_lbfgs_regression_maxfun9  s    
r   c               
   C   s   ddgddgddgddgg} ddddg}d	D ]}t d
d|dddd}ttd. || | |jj}|| | |jj}W 5 Q R X |dkr||kstq,|dkr,|jtd|j	 |ks,tq,d S )Nr   r   r*      rZ   r   )
invscalingconstantr@      g      ?T)rA   r+   learning_raterD   power_t
warm_startr-   r   r   	   )
r   r   r   r2   
_optimizerr   r9   rB   powr   )r:   r;   r   r<   Zprev_etaZpost_etar&   r&   r'   test_learning_rate_warmstartO  s(    r   c               	   C   s   t dddd\} }tddddddd	d
}|| | || |dksHttddddddd	d}tdD ]}|j| |dddddgd qf|| |dksttdd}|| ||  d S )Nrt   r   Tr   r   Zreturn_indicatorr]   r^   rs   r   rF   )rA   r+   r,   rD   r   rC   rB   g
ףp=
?r@   )rA   r+   rD   r   rC   r,   rB   r    r*   r   r   r   rL   ?early_stopping)r	   r   r2   r{   r9   r6   rV   rz   )r:   r;   r<   r=   r&   r&   r'   test_multilabel_classificationh  s<      
		
r   c                  C   sD   t ddd\} }tddddd}|| | || |dks@td S )	Nr   rZ   )r   	n_targetsr]   rt   r*   )rA   r+   rD   r   r   )r
   r   r2   r{   r9   rX   r&   r&   r'   test_multioutput_regression  s       r   c               	   C   s\   ddgg} dg}t dd}|j| |ddgd tt |j| |ddgd W 5 Q R X d S )Nr   r   r   r@   rA   r*   rL   )r   rV   r   raises
ValueErrorr:   r;   clfr&   r&   r'   test_partial_fit_classes_error  s    

r   c               
   C   s   t D ]\} }tddddddd}ttd || | W 5 Q R X || }tddddd	}tdD ]}|j| |t	|d
 qd|| }t
|| || |dkstqd S )Nr@   r    r*   r   r^   rF   )rA   rD   r   r   r,   rB   r-   )rA   r   r,   rB   rL   rw   )classification_datasetsr   r   r   r2   rz   r6   rV   r"   uniquer   r{   r9   )r:   r;   r<   pred1r=   pred2r&   r&   r'   test_partial_fit_classification  s.    	
   

r   c                  C   sv   t dd} | jdgdgdggdddgdddd	gd
 | dggd	g | dgdgdgdggdddd	gdksrtd S )Nr   )r   r*   r   r   abcdrL   r   )r   rV   r{   r9   )r   r&   r&   r'   test_partial_fit_unseen_classes  s    
*r   c               
   C   s   t } t}dD ]}tddddd| jd |d}tjd	d
 || | W 5 Q R X || }tdddd| jd |d}tdD ]}|	| | q||| }t
|| || |}|dkstqd S )N)r   r   r@   r    r   r*   {Gz?r   )rA   rD   rC   r   rB   
batch_sizerE   T)record)rA   rC   rB   r   r   rE   g?)X_regy_regr   ri   warningscatch_warningsr2   rz   r6   rV   r   r{   r9   )r:   r;   rE   r<   r   r=   r   r{   r&   r&   r'   test_partial_fit_regression  s:    	


r   c               	   C   s^   ddgddgg} ddg}t t tddj| |dgd W 5 Q R X ttd	dd
rZtd S )Nr   r   r*   r   r   r@   r   rL   r]   rV   )r   r   r   r   rV   hasattrr9   )r:   r;   r&   r&   r'   test_partial_fit_errors  s
    "r   c               	   C   sr   t jd} d}t t jj}|| j|dfd }| j|d}t }d}t	j
t|d ||| W 5 Q R X d S )Nr   r   r   )rk   zrSolver produced non-finite parameter weights. The input data may contain large values and need to be preprocessed.match)r"   rd   re   Zfinfofloat64maxuniformZstandard_normalr   r   r   r   r2   )rngr   Zfmaxr:   r;   r   msgr&   r&   r'   test_nonfinite_params  s    r   c            	   	   C   s   t d d } td d }tdddd}ttd || | W 5 Q R X || }|| }|jd d }}|j	dd	}|j	dd	}|j||fkst
t|| t|t| t||d d df d
kst
d S )Nrt   rZ   r   r*   )r+   rC   r   r-   r   r   ZaxisrH   )r/   r0   r   r   r   r2   rW   predict_log_probari   argmaxr9   r   r   r"   logr   	r:   r;   r   y_probay_log_probar   	n_classes	proba_maxproba_log_maxr&   r&   r'   test_predict_proba_binary  s    


r   c            	   	   C   s   t d d } td d }tdd}ttd || | W 5 Q R X || }|| }|jd t	
|j }}|jdd}|jdd}|j||fkstt|| t|t	| d S )Nr   rZ   )r+   r-   r   r*   r   )X_digits_multiy_digits_multir   r   r   r2   rW   r   ri   r"   r   rk   r   r9   r   r   r   r   r&   r&   r'   test_predict_proba_multiclass"  s    



r   c            	      C   s   t dddd\} }|j\}}tdddd}|| | || }|j||fksRtt|dk| || }|jd	d
}|jd	d
}|	d	d	 
|	d	d	 dkstt|| t|t| d S )Nrt   r   Tr   r]      rA   r+   r   rG   r*   r   g|=)r	   ri   r   r2   rW   r9   r   r   r   r#   dotr   r"   r   )	r:   rb   r   r   r   r   r   r   r   r&   r&   r'   test_predict_proba_multilabel7  s"      



&
r   c                  C   s   t ddddd\} }dD ]\}tdddd|d}tdddd|d}|| | || | t|jd |jd stqtdddddd}tddddd	d}|| | || | t|jd |jd rtd S )
Nrt   rZ   r*   r   )r   r   r   r   )TF)r+   rD   r   r   rv   TF)r
   r   r2   r"   Zarray_equalr5   r9   )r:   r;   rv   Zmlp1Zmlp2r&   r&   r'   test_shuffleO  sH            r   c                  C   s   t d d } td d }t| }tdddd}|| | || }||| ||}t|| || }||}t|| d S )Nrt   r]      r*   r   )r/   r0   r   r   r2   rz   r   r   )r:   r;   ZX_sparser<   r   r   r&   r&   r'   test_sparse_matricesu  s    




r   c                  C   sF   ddgddgg} ddg}t dddd	}|| | |j|jksBtd S )
Nr   r   r*   r   r   rG     r@   )r   rD   rA   )r   r2   rD   rQ   r9   r   r&   r&   r'   test_tolerance  s
    r   c               	   C   s   ddgddgg} ddg}t ddddd}tj}t  t_}ttd	 || | W 5 Q R X || | |t_d
| ks~t	d S )Nr   r   r*   r   r   r@   r   )rA   rD   verboser+   r-   Z	Iteration)
r   sysstdoutr   r   r   r2   rV   getvaluer9   )r:   r;   r   Z
old_stdoutoutputr&   r&   r'   test_verbose_sgd  s    r   MLPEstimatorc                 C   s   t d d }td d }d}| |dddd}||| |j|jksHt|jd ksVtt|jt	sft|j}|j
}t||kst|| |d kst|| |d kst| |ddd	d}||| |jd kst|j
d kst|jd k	std S )
Nr    rF   r   r@   T)r   rD   rA   r   r   F)r/   r0   r2   rD   rQ   r9   rT   
isinstancevalidation_scores_listZbest_validation_score_r   )r   r:   r;   r   Zmlp_estimatorZvalid_scoresZbest_valid_scorer&   r&   r'   test_early_stopping  s8          r   c                  C   sX   ddgddgg} ddg}t dddd	d
}|| | |j|jksDtd|jjksTtd S )Nr   r   r*   r   r   rG   r   r@   Zadaptive)r   rD   rA   r   gư>)r   r2   rD   rQ   r9   r   r   r   r&   r&   r'   test_adaptive_learning_rate  s    r   r-   c            
   
   C   sf  t } t}tdgd dgd  }tdgd dgd  dgd  }tdgd dgd  dgd  }tdgd	 dgd	  dgd
  dgd
  }tdgd dgd  dgd  dgd  dgd  }tdddd| |}|| | || | ||||fD ]V}tdddd| |}dt| }	tjt	t
|	d || | W 5 Q R X q
d S )Nr   K   r*   (   r   F   rt   r   %   &   r   r   r]   T)r+   rA   r   z}warm_start can only be used where `y` has the same classes as in the previous call to fit. Previously got [0 1 2], `y` has %sr   )X_irisy_irisr"   r4   r   r2   r   r   r   r   reescape)
r:   r;   Z
y_2classesZ
y_3classesZy_3classes_altZ
y_4classesZ
y_5classesr   Zy_imessager&   r&   r'   test_warm_start  s*    $$.8 r   c                 C   sV   t t }}d}| ddd|d}||| ||jks8t||| ||jksRtd S )Nr   r   r@   T)r+   rA   r   rD   )r   r   r2   rQ   r9   )r   r:   r;   rD   r   r&   r&   r'   test_warm_start_full_iteration  s    
   r   c                  C   sj   t d d } td d }d}d}dD ]@}t||d|d}|| | |j|d ksVt||jks$tq$d S )Nr    r   r   )r   rZ   r   rt   r    r@   r   rD   rA   n_iter_no_changer*   )r/   r0   r   r2   rU   r9   rQ   )r:   r;   r   rD   r   r   r&   r&   r'   test_n_iter_no_change  s       r   c                  C   sh   t d d } td d }d}tj}d}t||d|d}|| | |j|ksPt|j|jd ksdtd S )Nr    g    eAr   r@   r   r*   )	r/   r0   r"   rS   r   r2   rQ   r9   rU   )r:   r;   r   r   rD   r   r&   r&   r'   test_n_iter_no_change_inf
  s       r   c               	   C   s\   ddgddgddgddgg} ddddg}t dd}tjtd	d
 || | W 5 Q R X d S )Nr*   r   r   r   rZ   r   Tr   z0The least populated class in y has only 1 memberr   )r   r   r   r   r2   rX   r&   r&   r'   test_early_stopping_stratified$  s    
 r   c                  C   s   t ddddd} | td d td d  | tdd  }| tdd  }t ddddd}|td d tjtd d  |tdd  tj}|tdd  tj}t	|| t
||dd d S )	Nr^   rZ   r   r*   rt   r,   r+   r   rD   ,  r   Zrtol)r   r2   X_digitsy_digitsrz   rW   astyper"   float32r   r   )mlp_64pred_64Zproba_64mlp_32pred_32Zproba_32r&   r&   r'   "test_mlp_classifier_dtypes_casting0  s(          $
r   c                  C   s   t ddddd} | td d td d  | tdd  }t ddddd}|td d tjtd d  |tdd  tj}t||dd d S )	Nr^   r   r*   rt   r   r   -C6?r   )	r   r2   r   r   rz   r   r"   r   r   )r   r   r   r   r&   r&   r'   !test_mlp_regressor_dtypes_castingD  s"          $r   rx   	Estimatorc                    s   t  t }}|ddddd}||d d |d d  ||dd  }t fdd|jD sjtt fd	d|jD st|t	kr|j
 kstd S )
Nr^   r   r*   rt   r   r   c                    s   g | ]}|j  kqS r&   rx   )rJ   Z	interceptr   r&   r'   rK   ^  s     z)test_mlp_param_dtypes.<locals>.<listcomp>c                    s   g | ]}|j  kqS r&   r   )rJ   Zcoefr   r&   r'   rK   `  s     )r   r   r   r2   rz   r8   rO   r9   r5   r   rx   )rx   r   r:   r;   r<   predr&   r   r'   test_mlp_param_dtypesT  s    r  c           
      C   s   t ddddd}dggdg }}||| | d }t|| t|}dggd	g }}tdD ]}||| qb||}	t|	|d
d dS )zYLoading from MLP and partial fitting updates weights. Non-regression
    test for #19626.)r[   r[   r   r   )r+   r   rB   rD   r   r   zmlp.pklr*   r   r   N)	r   r2   joblibdumploadr6   rV   rz   r   )
Ztmp_pathZpre_trained_estimatorfeaturestargetZpickled_fileZload_estimatorZfine_tune_featuresZfine_tune_target_Zpredicted_valuer&   r&   r'   (test_mlp_loading_from_joblib_partial_fitf  s        

r	  c              	   C   s   t d}tjd}|j|ddddgd}|jtddd	d
}| ddd}t	
  t	dt ||| W 5 Q R X dS )zCheck that feature names are preserved when early stopping is enabled.

    Feature names are required for consistency checks during scoring.

    Non-regression test for gh-24846
    Zpandasr   r   r   Z	colname_aZ	colname_b)datacolumnsr*   Z	colname_y)r
  nameTrF   )r   Zvalidation_fractionerrorN)r   Zimportorskipr"   rd   re   Z	DataFrameZrandnZSeriesfullr   r   simplefilterUserWarningr2   )r   pdr   r:   r;   modelr&   r&   r'   test_preserve_feature_names  s    

r  c                 C   sT   | ddddd}| tt t|j}|jdd | tt t|j|ksPtdS )z0Check that early stopping works with warm start.r   r   T)rD   r   r   r      rD   N)r2   r   r   r7   r   
set_paramsr9   )r   r<   Zn_validation_scoresr&   r&   r'   'test_mlp_warm_start_with_early_stopping  s       
r  rA   r@   Zadamr]   c              	   C   s   | |dddt jdd}tt |tt W 5 Q R X |jdksFt	|j
dd tt |tt W 5 Q R X |jdkst	dS )	zCheck that we stop the number of iteration at `max_iter` when warm starting.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/24764
    TFr   r   )rA   r   r   rD   r   r   r  r  N)r"   rS   r   r   r   r2   r   r   rQ   r9   r  )r   rA   r  r&   r&   r'   "test_mlp_warm_start_no_convergence  s    	r  c              	   C   sB   | ddd tt}d}tjt|d |tt W 5 Q R X dS )zoCheck partial fit does not fail after fit when early_stopping=True.

    Non-regression test for gh-25693.
    Tr   )r   r   z0partial_fit does not support early_stopping=Truer   N)r2   r   r   r   r   r   rV   )r   r<   r   r&   r&   r'   test_mlp_partial_fit_after_fit  s    r  )^__doc__r   r   r   ior   r  numpyr"   r   Znumpy.testingr   r   r   Zscipy.sparser   Zsklearn.datasetsr   r   r	   r
   Zsklearn.exceptionsr   Zsklearn.metricsr   Zsklearn.neural_networkr   r   Zsklearn.preprocessingr   r   r   Zsklearn.utils._testingr   rh   r   r   rg   r   r   r/   r0   r   r   r   Zregression_datasetsZirisr
  r   r  r   r>   rY   rr   markZparametrizer|   r}   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   RuntimeWarningr   r   r   r   r   r   r   r   r   r  r	  r  r  r  r  r&   r&   r&   r'   <module>   s       

aA



&


%&

	




