##################################
#                                #
# Last modified 2018/05/21       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set
import numpy as np
from sklearn.metrics import mean_squared_error
from sklearn.ensemble import GradientBoostingRegressor

def run():

    if len(sys.argv) < 9:
        print 'usage: python %s datafilename chrFieldID DNAse_field_ID TFs_field_IDs N_estimators learning_rate outprefix [-max_features int] [-max_depth int/None]' % sys.argv[0]
        print '\tformat of TFs_field_IDs: combination of comma separated and start-end (including end)'
        print '\theader line is assumed' 
        print '\tuse - for input if you want to read from standard input' 
        print '\t!!!! Make sure the input file is not sorted by signal !!!!g' 
        sys.exit(1)

    doCS=False
    if '-cufflinksStatus' in sys.argv:
        doCS=True
        print 'will discard FAIL status entries'

    datafilename = sys.argv[1]
    chrFieldID = int(sys.argv[2])
    DNAseFieldID = int(sys.argv[3])
    TFIDfields = []
    if ',' in sys.argv[4]:
        blocks = sys.argv[4].split(',')
        for block in blocks:
            if '-' in block:
                fields1 = int(block.split('-')[0])
                fields2 = int(block.split('-')[1])
                for ID in range(fields1,fields2+1):
                    TFIDfields.append(ID)
            else:
                TFIDfields.append(int(block))
    Nest = int(sys.argv[5])
    LR = float(sys.argv[6])
    outprefix = sys.argv[7]

    predictors = []
    DNAse = []
    coordinates = []

    labels = []

    if datafilename == '-':
        lineslist  = sys.stdin
    else:
        lineslist  = open(datafilename)
    for line in lineslist:
        fields = line.strip().split('\t')
        if line[0]=='#':
            for ID in TFIDfields:
                labels.append(fields[ID])
            continue
        DNAse_scores = float(fields[DNAseFieldID])
        DNAse.append(DNAse_scores)
        chr_start_end = fields[chrFieldID] + ':' + fields[chrFieldID + 1] + '-' + fields[chrFieldID + 2]
        coordinates.append(chr_start_end)
        scores = []
        for ID in TFIDfields:
            s = float(fields[ID])
            if s > 0:
                s = 1.0
            else:
                s = 0.0
            scores.append(s)
        predictors.append(scores)

    print 'finished inputting data'

    y_train = DNAse[:len(DNAse)/2]
    y_test = DNAse[len(DNAse)/2:]

    X_train = predictors[:len(DNAse)/2]
    X_test = predictors[len(DNAse)/2:]

#    params = {'n_estimators':Nest, 'learning_rate':LR, 'max_depth':1, 'random_state':0, 'loss':'ls'}

    params = {'n_estimators':Nest, 'learning_rate':LR, 'loss':'ls'}

    if '-max_features' in sys.argv:
        params['max_features'] = int(sys.argv[sys.argv.index('-max_features') + 1])

    if '-max_depth' in sys.argv:
        params['max_depth'] = int(sys.argv[sys.argv.index('-max_depth') + 1])

    est = GradientBoostingRegressor(**params).fit(X_train, y_train)

    mse = mean_squared_error(y_test, est.predict(X_test))

    print "MSE: %.4f" % mse

    outfile = open(outprefix + '.deviance', 'w')

    outline = '#Boosting_iteration\tTraining Set Deviance\tTest Set Deviance'
    outfile.write(outline + '\n')

    test_score = np.zeros((params['n_estimators'],), dtype=np.float64)

    for i, y_pred in enumerate(est.staged_predict(X_test)):
        test_score[i] = est.loss_(y_test, y_pred)

    for i, y_pred in enumerate(est.staged_predict(X_test)):
#        outline = str(params['n_estimators'] + 1) + '\t' + est.train_score_[i] + '\t' + test_score[i]
        outline = str(i + 1) + '\t' + str(est.train_score_[i]) + '\t' + str(test_score[i])
        outfile.write(outline + '\n')

    outfile.close()
        
    outfile = open(outprefix + '.feature_importance', 'w')

    outline = '#Feature\tImportance'
    outfile.write(outline + '\n')

    feature_importance = est.feature_importances_
    feature_importance = 100.0 * (feature_importance / feature_importance.max())

    for i in range(len(labels)):
        outline = labels[i] + '\t' + str(feature_importance[i])
        outfile.write(outline + '\n')

    outfile.close()

# Script entry point: guard the call so importing this module for reuse
# does not immediately execute the whole pipeline.
if __name__ == '__main__':
    run()