In [41]:
import numpy as np
import h5py
import keras
from collections import OrderedDict

# Input files for every cell type, keyed by cell-type name. Each entry has:
#   scores_h5file           - HDF5 file whose top-level keys are scoring methods
#   scores_seqids           - text file with one scored sequence id per line
#   positive_sequences_file - gzipped, tab-separated (sequence id, sequence) rows
#   motifmatches_file       - tab-separated table of motif hits
# The score files are given relative to the working directory; the other two
# are absolute paths.
celltype_to_corefiles = dict(
    A549=dict(
        scores_h5file='A549_scores_5k.h5',
        scores_seqids='A549_toppredpos_5k.txt',
        positive_sequences_file="/users/eprakash/projects/benchmarking/newdata/A549/A549.summits.400bp.implanted.valid.bed.gz",
        motifmatches_file='/users/eprakash/projects/benchmarking/newdata/A549/A549.motif.matches.txt',
    ),
    HepG2=dict(
        scores_h5file='HepG2_scores_5k.h5',
        scores_seqids='HepG2_toppredpos_5k.txt',
        positive_sequences_file="/users/eprakash/projects/benchmarking/newdata/HepG2/HepG2.summits.400bp.implanted.valid.bed.gz",
        motifmatches_file='/users/eprakash/projects/benchmarking/newdata/HepG2/HepG2.motif.matches.txt',
    ),
    H1ESC=dict(
        scores_h5file='H1ESC_scores_5k.h5',
        scores_seqids='H1ESC_toppredpos_5k.txt',
        positive_sequences_file="/users/eprakash/projects/benchmarking/newdata/H1ESC/H1ESC.summits.400bp.implanted.valid.bed.gz",
        motifmatches_file='/users/eprakash/projects/benchmarking/newdata/H1ESC/H1ESC.motif.matches.txt',
    ),
    K562=dict(
        scores_h5file='K562_scores_5k.h5',
        scores_seqids='K562_toppredpos_5k.txt',
        positive_sequences_file="/users/eprakash/projects/benchmarking/newdata/deepsea_K562/K562.pos.summits.valid.implanted.bed.gz",
        motifmatches_file='/users/eprakash/projects/benchmarking/newdata/K562/K562.pos.motif.matches.txt',
    ),
)


In [42]:
!ls

A549_motifauprcs.tsv		      H1ESC_scores_1k.h5
A549_motifaurocs.tsv		      H1ESC_scores_5k.h5
A549_motifscoring_results.json.gz     H1ESC_toppredpos_1k.txt
A549_scores_1k.h5		      H1ESC_toppredpos_5k.txt
A549_scores_5k.h5		      HepG2_motifauprcs.tsv
A549_toppredpos_1k.txt		      HepG2_motifaurocs.tsv
A549_toppredpos_5k.txt		      HepG2_motifscoring_results.json.gz
CompareGCdist.ipynb		      HepG2_scores_1k.h5
ComputeMotifAuROCs.ipynb	      HepG2_scores_5k.h5
Compute_Scores.ipynb		      HepG2_toppredpos_1k.txt
evanbcopy			      HepG2_toppredpos_5k.txt
GM12878_motifauprcs.tsv		      K562_scores_1k.h5
GM12878_motifaurocs.tsv		      K562_toppredpos_1k.txt
GM12878_motifscoring_results.json.gz  PlotOutputOnReference.ipynb
H1ESC_motifauprcs.tsv		      retrain_models
H1ESC_motifaurocs.tsv		      VizMotifAuROCResults.ipynb
H1ESC_motifscoring_results.json.gz


In [43]:
import h5py
import numpy as np
import gzip
from collections import namedtuple, defaultdict
import sys
from sklearn.metrics import roc_auc_score, average_precision_score

# One motif hit. `start`/`end` are 0-indexed, half-open coordinates within the
# sequence called `seqname` (see the conversion in load_motif_matches).
MotifMatch = namedtuple("MotifMatch",
                        ["motifname", "seqname", "start", "end",
                         "strand", "hitstrength", "matchstring"])


def load_motif_matches(motif_match_file, seqnames_to_include):
    """Parse a tab-separated motif-hit file into per-sequence MotifMatch lists.

    Every row must have exactly seven tab-separated fields:
    motif name, sequence name, start, end, strand, hit strength, match string.
    The start/end columns are 1-indexed and inclusive; they are converted to
    0-indexed half-open coordinates. Only the part of the motif name between
    the first and second "-" is kept.

    Args:
        motif_match_file: path of the motif-hit file.
        seqnames_to_include: iterable of sequence names to keep; rows for any
            other sequence are skipped.

    Returns:
        dict mapping each retained sequence name to its list of MotifMatch
        records, in file order. Sequences without any hit are absent.

    Raises:
        ValueError: if a row does not have seven fields or a numeric field
            cannot be parsed.
        IndexError: if a motif name contains no "-".
    """
    # A set gives O(1) membership tests; the caller in this file passes a
    # list, which would otherwise be scanned once per row of the file.
    wanted_seqnames = set(seqnames_to_include)
    seqname_to_motifmatches = {}
    # `with` guarantees the file handle is released even if a row is malformed.
    with open(motif_match_file) as motif_fh:
        for row in motif_fh:
            (motifname, seqname, homerstart, homerend,
             strand, hitstrength, matchstring) = row.rstrip().split("\t")
            if seqname not in wanted_seqnames:
                continue
            motifmatch = MotifMatch(
                motifname=motifname.split("-")[1],
                seqname=seqname,
                start=int(homerstart)-1,  # 1-indexed-inclusive to 0-indexed-inclusive
                end=int(homerend),  # 1-indexed-inclusive to 0-indexed-exclusive
                strand=strand,
                hitstrength=float(hitstrength),
                matchstring=matchstring)
            seqname_to_motifmatches.setdefault(seqname, []).append(motifmatch)
    return seqname_to_motifmatches


def onehot_encode(seqs):
    """One-hot encode DNA strings into a numpy array.

    Channels are ordered A, C, G, T; 'N' becomes an all-zero row. Sequences
    are upper-cased first, so the encoding is case-insensitive. Any other
    character raises KeyError.

    Args:
        seqs: iterable of sequence strings.

    Returns:
        np.ndarray built from one 4-element row per base of each sequence.
    """
    base_to_row = {
        'A': [1,0,0,0],
        'C': [0,1,0,0],
        'G': [0,0,1,0],
        'T': [0,0,0,1],
        'N': [0,0,0,0],
    }
    encoded = []
    for seq in seqs:
        encoded.append([base_to_row[base] for base in seq.upper()])
    return np.array(encoded)


def get_indices_of_subset(superset_seqnames, subset_seqnames):
    """Return, for each name in the subset, its index within the superset.

    When a name occurs several times in `superset_seqnames`, the index of its
    last occurrence is used. A subset name missing from the superset raises
    KeyError.

    Args:
        superset_seqnames: iterable of names defining the index space.
        subset_seqnames: iterable of names to look up.

    Returns:
        list of ints, one per entry of `subset_seqnames`, in the same order.
    """
    position_of = {}
    for position, name in enumerate(superset_seqnames):
        # later duplicates overwrite earlier ones
        position_of[name] = position
    return [position_of[name] for name in subset_seqnames]


def load_posseqs(corefiles, pos_idx_ordering):
    """Load the selected positive sequences and their one-hot encoding.

    Reads the gzipped, tab-separated file at
    corefiles['positive_sequences_file'], taking the second column of every
    row as the sequence, then keeps the rows listed in `pos_idx_ordering`
    (in that order).

    Args:
        corefiles: dict with a 'positive_sequences_file' path.
        pos_idx_ordering: iterable of 0-based row indices into that file.

    Returns:
        (onehot, posseqs): the one-hot array produced by onehot_encode and
        the list of the selected sequence strings.
    """
    # Context manager so the gzip handle is closed rather than leaked.
    with gzip.open(corefiles['positive_sequences_file']) as positives_fh:
        all_posseqs = [line.decode("utf-8").rstrip().split("\t")[1]
                       for line in positives_fh]
    posseqs = [all_posseqs[idx] for idx in pos_idx_ordering]
    return onehot_encode(posseqs), posseqs


def load_scores(corefiles, onehot_seqs):
    """Load per-position importance scores for every method in the HDF5 file.

    Each top-level key of corefiles['scores_h5file'] is treated as a scoring
    method. Its dataset is read fully into memory and summed over the last
    axis.

    Args:
        corefiles: dict with a 'scores_h5file' path.
        onehot_seqs: one-hot encoding of the scored sequences; it must
            broadcast against every dataset (presumably both are
            (num_seqs, seq_len, 4) -- confirm against the score files).

    Returns:
        dict mapping method name to the scores summed over the last axis.

    Raises:
        AssertionError: if masking the scores with the one-hot encoding
            changes any per-position sum.
    """
    method_to_scores = {}
    # Context manager so the HDF5 file is closed once everything is in memory.
    with h5py.File(corefiles['scores_h5file'], "r") as h5pyfile:
        for method in h5pyfile.keys():
            # np.array(...[:]) copies the data out before the file is closed
            scores = np.array(h5pyfile[method][:])
            summed_scores = np.sum(scores, axis=-1)
            # sanity check with onehot_seqs: the entries that the one-hot
            # mask zeroes out must contribute exactly nothing to each sum
            assert np.max(np.abs(np.sum(scores*onehot_seqs, axis=-1)
                                 - summed_scores))==0.0
            method_to_scores[method] = summed_scores
    return method_to_scores


def get_sum_scores_in_window(scores, windowlen):
    """Sum `scores` over every length-`windowlen` window along the last axis.

    Args:
        scores: 2-D array of shape (num_rows, num_cols).
        windowlen: window length.

    Returns:
        Array of shape (num_rows, num_cols - windowlen + 1); entry [r, i] is
        the sum of scores[r, i:i+windowlen].

    Raises:
        AssertionError: if `scores` is not 2-D or the result does not have the
            expected shape.
    """
    assert scores.ndim == 2
    num_rows, num_cols = scores.shape
    running_totals = np.cumsum(scores, axis=-1)
    # Prepend a zero column so that prefix_sums[:, k] is the sum of the
    # first k entries of each row.
    prefix_sums = np.pad(running_totals,
                         ((0, 0), (1, 0)),
                         mode='constant',
                         constant_values=0)
    assert prefix_sums.shape == (num_rows, num_cols + 1)
    # window sum = prefix sum at window end minus prefix sum at window start
    window_sums = prefix_sums[:, windowlen:] - prefix_sums[:, :-windowlen]
    assert window_sums.shape == (num_rows, num_cols - (windowlen - 1))
    return window_sums


def get_scores_for_common_sequences(corefiles):
    """Load scores and sequences for the ids listed in the scores-seqids file.

    Args:
        corefiles: dict with 'positive_sequences_file' (gzipped, tab-separated,
            sequence id in the first column), 'scores_seqids' (one id per
            line) and 'scores_h5file' entries.

    Returns:
        (method_to_scores, onehot_posseqs, posseqs, seqnames) where
        `method_to_scores` comes from load_scores, `onehot_posseqs`/`posseqs`
        from load_posseqs, and `seqnames` are the scored ids with the
        "dinuc_shuffled_motifs_implanted_" marker removed.

    Raises:
        KeyError: if a scored id is missing from the positives file.
    """
    #######
    #Load all the seqnames
    # Context managers so neither file handle is leaked.
    with gzip.open(corefiles['positive_sequences_file']) as positives_fh:
        positives_seqnames = [line.decode("utf-8").rstrip().split("\t")[0]
                              for line in positives_fh]
    with open(corefiles['scores_seqids']) as seqids_fh:
        subset_seqnames = [line.rstrip() for line in seqids_fh]
    ########
    #Figure out the mapping from sequence to indices for the common seqnames
    positives_idx_ordering = get_indices_of_subset(superset_seqnames=positives_seqnames,
                                                   subset_seqnames=subset_seqnames)
    ########
    #Load the data using the idx ordering
    onehot_posseqs, posseqs = load_posseqs(corefiles=corefiles,
                                           pos_idx_ordering=positives_idx_ordering)
    method_to_scores = load_scores(corefiles=corefiles,onehot_seqs=onehot_posseqs)
    #strip away the 'dinuc_shuffled_motifs_implanted_' marker
    # (str.replace removes every occurrence, not only a leading one)
    seqnames = [x.replace("dinuc_shuffled_motifs_implanted_", "")
                for x in subset_seqnames]
    return method_to_scores, onehot_posseqs, posseqs, seqnames


def get_motifmatches_and_nullwindows_mask(corefiles, seqnames, seqs=None):
    """Locate motif hits in each sequence and mask the windows free of any hit.

    Args:
        corefiles: dict with a 'motifmatches_file' path.
        seqnames: sequence names, in processing order.
        seqs: sequence strings aligned with `seqnames`. Optional for backward
            compatibility only: this function used to read a module-level
            variable called `seqs` implicitly, and still falls back to it
            when the argument is omitted. Prefer passing it explicitly.

    Returns:
        A 4-tuple:
        - motifmatches_in_seqs: list (aligned with `seqnames`) of the
          MotifMatch lists of each sequence.
        - motifname_to_hitlocations: motif name -> list of
          (sequence index, hit start) pairs.
        - motifname_to_motiflen: motif name -> motif length, taken to be the
          length of the motif NAME.
        - motiflen_to_nullwindowsmask: motif length -> boolean array, True
          for every window of that length that overlaps no motif hit.

    Raises:
        NameError: if `seqs` is neither passed nor defined at module level.
        KeyError: if a name in `seqnames` has no motif match in the file.
        AssertionError: on a length mismatch between `seqs` and `seqnames`,
            or if a reported match string differs from the sequence.
    """
    if seqs is None:
        # Hidden-state fallback kept so existing two-argument calls still work.
        if "seqs" not in globals():
            raise NameError(
                "get_motifmatches_and_nullwindows_mask needs `seqs`: pass it "
                "explicitly or define a module-level `seqs`")
        seqs = globals()["seqs"]
    # zip() below would silently truncate on a mismatch, so check up front.
    assert len(seqs)==len(seqnames), (len(seqs), len(seqnames))

    print("Reading in motif file")
    sys.stdout.flush()
    seqname_to_motifmatches = load_motif_matches(
        motif_match_file=corefiles['motifmatches_file'],
        seqnames_to_include=seqnames)
    print("Read motif file")
    sys.stdout.flush()
    motifmatches_in_seqs = [seqname_to_motifmatches[x] for x in seqnames]

    #Get locations of each motif
    #Also get a mask for locations obscured by the motifs
    #covered_positions has a 1 if there is a motif at the
    # position and 0 otherwise
    covered_positions = []
    motifname_to_hitlocations = defaultdict(list)
    motifname_to_motiflen = {}
    for seqidx,(motifmatches, seq) in enumerate(zip(motifmatches_in_seqs, seqs)):
        covered_positions_entry = np.zeros(len(seq))
        for motifmatch in motifmatches:
            #sanity check
            assert motifmatch.matchstring == seq[motifmatch.start:motifmatch.end].upper(), (
                        motifmatch.matchstring, seq[motifmatch.start:motifmatch.end])
            covered_positions_entry[motifmatch.start:motifmatch.end] = 1
            motifname_to_hitlocations[motifmatch.motifname].append(
                (seqidx, motifmatch.start))
            # The motif length is derived from the NAME, so repeated hits of
            # the same motif always agree on it.
            # NOTE(review): this presumably relies on the name being the
            # motif's consensus string -- confirm against the motif file.
            motifname_to_motiflen[motifmatch.motifname] = len(motifmatch.motifname)
        covered_positions.append(covered_positions_entry)
    covered_positions = np.array(covered_positions)
    assert len(covered_positions)==len(seqs)

    #get a mapping from motiflen to windows with all zeros, i.e. the negatives
    motiflens = sorted(set(len(y.motifname) for x in motifmatches_in_seqs for y in x))
    motiflen_to_nullwindowsmask = {}
    for motiflen in motiflens:
        # a window is "null" when none of its positions is covered by a hit
        coveredposition_windowsums = get_sum_scores_in_window(
            scores=covered_positions, windowlen=motiflen)
        nullwindowsmask = (coveredposition_windowsums==0.0)
        motiflen_to_nullwindowsmask[motiflen] = nullwindowsmask
        print("Number of null windows for length",motiflen, np.sum(nullwindowsmask))

    return (motifmatches_in_seqs, motifname_to_hitlocations,
            motifname_to_motiflen, motiflen_to_nullwindowsmask)


def compute_motif_scores(method_to_scores, motifname_to_hitlocations,
                         motifname_to_motiflen, motiflen_to_nullwindowsmask):
    """Score how well each method separates motif hits from null windows.

    For every motif and every method, the method's scores are summed over
    windows of the motif's length. The windows starting at the motif's hits
    are the positives and the null windows of that length are the negatives;
    auROC and auPRC are computed on that labelling.

    Args:
        method_to_scores: method name -> 2-D array of per-position scores.
        motifname_to_hitlocations: motif name -> list of
            (sequence index, hit start) pairs.
        motifname_to_motiflen: motif name -> motif length.
        motiflen_to_nullwindowsmask: motif length -> boolean mask over the
            windows of that length, True for null windows.

    Returns:
        A 5-tuple of dicts keyed by motif name:
        (motifname_to_method_to_hitscores, motifname_to_method_to_auroc,
         motifname_to_method_to_auprc, motifname_to_numhits,
         motifname_to_baselineauprc).
    """
    motiflen_to_motifnames = defaultdict(list)
    for motifname, motiflen in motifname_to_motiflen.items():
        motiflen_to_motifnames[motiflen].append(motifname)
    motifname_to_method_to_hitscores = defaultdict(dict)
    motifname_to_method_to_auroc = defaultdict(dict)
    motifname_to_method_to_auprc = defaultdict(dict)
    motifname_to_numhits = {}
    motifname_to_baselineauprc = {}
    for motiflen in sorted(motiflen_to_motifnames.keys()):
        print("Doing motifs of length",motiflen)
        sys.stdout.flush()
        nullwindowsmask = motiflen_to_nullwindowsmask[motiflen]
        num_nullwindows = int(np.sum(nullwindowsmask))

        # Hit counts and baselines do not depend on the method, so compute
        # them once per motif instead of once per (method, motif).
        for motifname in motiflen_to_motifnames[motiflen]:
            num_hits = len(motifname_to_hitlocations[motifname])
            num_windows = num_hits + num_nullwindows
            motifname_to_numhits[motifname] = num_hits
            # Chance-level average precision is the positive prevalence,
            # P/(P+N). (The earlier P/N overstated it.)
            motifname_to_baselineauprc[motifname] = (
                num_hits/num_windows if num_windows > 0 else float("nan"))

        for method in sorted(method_to_scores.keys()):
            print("Method",method)
            sys.stdout.flush()
            scores = method_to_scores[method]
            windowsum_scores = get_sum_scores_in_window(
                             scores=scores, windowlen=motiflen)
            assert nullwindowsmask.shape==windowsum_scores.shape
            nullwindowscores = windowsum_scores[nullwindowsmask]
            for motifname in motiflen_to_motifnames[motiflen]:
                hitlocations = motifname_to_hitlocations[motifname]
                # zip(*...) turns [(seqidx, start), ...] into a
                # (seqidxs, starts) pair for numpy fancy indexing
                hitscores = [float(x) for x in
                             windowsum_scores[tuple(zip(*hitlocations))]]
                motifname_to_method_to_hitscores[motifname][method] = hitscores
                # hits first, then null windows; arrays avoid building Python
                # lists with one element per null window
                y_true = np.concatenate(
                    [np.ones(len(hitscores), dtype=int),
                     np.zeros(len(nullwindowscores), dtype=int)])
                y_score = np.concatenate(
                    [np.asarray(hitscores, dtype=float), nullwindowscores])
                auroc = roc_auc_score(y_true=y_true, y_score=y_score)
                auprc = average_precision_score(y_true=y_true, y_score=y_score)
                motifname_to_method_to_auroc[motifname][method] = auroc
                motifname_to_method_to_auprc[motifname][method] = auprc
    return (motifname_to_method_to_hitscores,
            motifname_to_method_to_auroc, motifname_to_method_to_auprc,
            motifname_to_numhits, motifname_to_baselineauprc)


In [44]:
import json
celltypes = ['HepG2', 'H1ESC', 'A549']
#celltypes = ['HepG2', 'H1ESC']

# For every cell type: load scores and sequences, locate the motif hits and
# the null windows, score every method, then dump the results to
# "<celltype>_motifscoring_results.json" in the working directory.
for celltype in celltypes:
    print("\n\nON",celltype)
    #get the scores for the different methods for those common sequences
    corefiles = celltype_to_corefiles[celltype]
    # NOTE: `seqs` must keep this exact name -- the next call also reads it
    # from the notebook namespace rather than only from its arguments.
    method_to_scores, onehot_seqs, seqs, seqnames = (
        get_scores_for_common_sequences(corefiles=corefiles))
    (motifmatches_in_seqs, motifname_to_hitlocations,
     motifname_to_motiflen, motiflen_to_nullwindowsmask) = (
        get_motifmatches_and_nullwindows_mask(corefiles=corefiles,
                                              seqnames=seqnames))
    (motifname_to_method_to_hitscores,
     motifname_to_method_to_auroc,
     motifname_to_method_to_auprc,
     motifname_to_numhits,
     motifname_to_baselineauprc) = compute_motif_scores(
        method_to_scores=method_to_scores,
        motifname_to_hitlocations=motifname_to_hitlocations,
        motifname_to_motiflen=motifname_to_motiflen,
        motiflen_to_nullwindowsmask=motiflen_to_nullwindowsmask)

    #save things to json
    # `with` flushes and closes the file before anything else (e.g. a later
    # gzip step) touches it.
    with open(celltype+"_motifscoring_results.json",'w') as results_fh:
        json.dump(
            {'motifname_to_hitlocations': motifname_to_hitlocations,
             'motifname_to_method_to_hitscores': motifname_to_method_to_hitscores,
             'motifname_to_method_to_auroc': motifname_to_method_to_auroc,
             'motifname_to_method_to_auprc': motifname_to_method_to_auprc,
             'motifname_to_numhits': motifname_to_numhits,
             'motifname_to_baselineauprc': motifname_to_baselineauprc},
            results_fh)



ON HepG2
Reading in motif file
Read motif file
Number of null windows for length 8 774435
Number of null windows for length 10 689252
Number of null windows for length 12 613849
Doing motifs of length 8
Method dldefault_dinucshuffref10
Method dlrescale_dinucshuffref10
Method dlrescale_gcref
Method dlrescale_zeroref
Method gradtimesinp
Method intgrad10_dinucshuffref10
Method intgrad10_gcref
Method intgrad10_zeroref
Method intgrad20_dinucshuffref10
Method intgrad20_gcref
Method intgrad20_zeroref
Method ism
Doing motifs of length 10
Method dldefault_dinucshuffref10
Method dlrescale_dinucshuffref10
Method dlrescale_gcref
Method dlrescale_zeroref
Method gradtimesinp
Method intgrad10_dinucshuffref10
Method intgrad10_gcref
Method intgrad10_zeroref
Method intgrad20_dinucshuffref10
Method intgrad20_gcref
Method intgrad20_zeroref
Method ism
Doing motifs of length 12
Method dldefault_dinucshuffref10
Method dlrescale_dinucshuffref10
Method dlrescale_gcref
Method dlrescale_zeroref
Method gradtime

In [45]:
!gzip -f *.json

In [46]:
import gzip

# Turn each cell type's gzipped JSON results into two TSV tables (auROC and
# auPRC): one row per motif, ordered by decreasing number of hits, and one
# column per scoring method.
for celltype in celltypes:
    # Text mode + json.load; `with` closes the gzip handle.
    with gzip.open(celltype+"_motifscoring_results.json.gz",
                   "rt", encoding="utf-8") as results_fh:
        dicts = json.load(results_fh)
    motifname_to_method_to_auroc = dicts['motifname_to_method_to_auroc']
    motifname_to_method_to_auprc = dicts['motifname_to_method_to_auprc']
    motifname_to_numhits = dicts['motifname_to_numhits']
    motifname_to_baselineauprc = dicts['motifname_to_baselineauprc']

    motifnames = [x[0] for x in sorted(motifname_to_numhits.items(), key=lambda x: -x[1])]
    # The method columns are taken from the first motif; with no motifs at
    # all only the fixed header columns are written.
    methods = (sorted(motifname_to_method_to_auroc[motifnames[0]])
               if motifnames else [])
    # Both outputs are closed even if a row fails to format.
    with open(celltype+"_motifaurocs.tsv",'w') as auroc_outf, \
         open(celltype+"_motifauprcs.tsv",'w') as auprc_outf:
        auroc_outf.write("Motifname\tnumhits\t"+"\t".join(methods)+'\n')
        auprc_outf.write("Motifname\tnumhits\tbaselineauprc\t"+"\t".join(methods)+'\n')
        for motifname in motifnames:
            baselineauprc = motifname_to_baselineauprc[motifname]
            auroc_outf.write(motifname
                       +"\t"+str(motifname_to_numhits[motifname])
                       +"\t"+"\t".join(str(motifname_to_method_to_auroc[motifname][method])
                                           for method in methods)+"\n")
            auprc_outf.write(motifname
                       +"\t"+str(motifname_to_numhits[motifname])
                       +"\t"+str(baselineauprc)
                       +"\t"+"\t".join(str(motifname_to_method_to_auprc[motifname][method])
                                           for method in methods)+"\n")
    

In [47]:
!head -10 H1ESC_motifaurocs.tsv

Motifname	numhits	dldefault_dinucshuffref10	dlrescale_dinucshuffref10	dlrescale_gcref	dlrescale_zeroref	gradtimesinp	intgrad10_dinucshuffref10	intgrad10_gcref	intgrad10_zeroref	intgrad20_dinucshuffref10	intgrad20_gcref	intgrad20_zeroref	ism
WAMCGCGS	8423	0.6244153505141627	0.6254442843402404	0.6363459540499196	0.6364488445551334	0.5950267209226232	0.6272823593486698	0.6596459919207023	0.6564096196853018	0.6277104516330176	0.6546788947783404	0.6593487063254057	0.659663653109249
GGGAAAAA	6894	0.6188666404285545	0.6213784108825389	0.6112557554849323	0.6008716805170822	0.603216449287973	0.6327942262607315	0.5984257001455939	0.6057445745618684	0.6333120859850716	0.6019673312964564	0.6025685146379381	0.680825409972838
CGCCGCTCTA	6502	0.6333379588538612	0.632504845009095	0.6724589452959753	0.6798868528806336	0.5895650973615383	0.5961780709131546	0.6881397405459655	0.678104416125431	0.5963843817256717	0.6827408376427694	0.6840763028433764	0.5787040164029813
KCCGGTTT	6495	0.5846704964092966