{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import bcolz\n",
    "import json\n",
    "\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import pandas as ps\n",
    "\n",
    "from examples_queue import ExamplesQueue\n",
    "from interval_queue import IntervalQueue\n",
    "from bcolz_reader import BcolzReader\n",
    "from dataset_interval_reader import DatasetIntervalReader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Write some dummy test data\n",
    "\n",
    "Creates some genome-wide data directories, both with one-hot fasta sequences and bigwigs.\n",
    "\n",
    "Also creates random 1000bp intervals and random integer labels for some number of tasks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "NUM_SEQ_CHARS = 4\n",
    "SEQ_LEN_CHR = int(1e6)\n",
    "NUM_CHRS = 5\n",
    "NUM_DATAFILES = 3\n",
    "\n",
    "NUM_INTERVALS = 500\n",
    "INTERVAL_LENGTH = 1000\n",
    "NUM_TASKS = 6\n",
    "\n",
    "DATA_DIR = 'test-data'\n",
    "\n",
    "FA_DIRS = [os.path.join(DATA_DIR, 'sequence-fa-{}'.format(i)) for i in range(NUM_DATAFILES)]\n",
    "BW_DIRS = [os.path.join(DATA_DIR, 'genome-bgw-{}'.format(i)) for i in range(NUM_DATAFILES)]\n",
    "\n",
    "BLOSC_CPARAMS = bcolz.cparams(clevel=5, shuffle=bcolz.SHUFFLE, cname='lz4')\n",
    "\n",
    "if not os.path.isdir(DATA_DIR):\n",
    "    os.mkdir(DATA_DIR)\n",
    "\n",
    "for FA_DIR in FA_DIRS:\n",
    "    if not os.path.isdir(FA_DIR):\n",
    "        os.mkdir(FA_DIR)\n",
    "\n",
    "for BW_DIR in BW_DIRS:\n",
    "    if not os.path.isdir(BW_DIR):\n",
    "        os.mkdir(BW_DIR)\n",
    "\n",
    "def random_fasta_seq():\n",
    "    seq_idxs = np.random.randint(0, NUM_SEQ_CHARS, SEQ_LEN_CHR)\n",
    "    seq_arr = np.zeros((NUM_SEQ_CHARS, SEQ_LEN_CHR))\n",
    "    seq_arr[seq_idxs, np.arange(SEQ_LEN_CHR, dtype=int)] = 1\n",
    "    return seq_arr\n",
    "\n",
    "def random_bw_data():\n",
    "    # Just use low-frequency wave function for now\n",
    "    bw_data = np.sin(np.arange(SEQ_LEN_CHR) / 1e-2)\n",
    "    return bw_data\n",
    "\n",
    "def random_labels():\n",
    "    # Just random labels for now, as ints\n",
    "    labels = np.random.randint(0, 10, size=(NUM_INTERVALS, NUM_TASKS))\n",
    "    return labels\n",
    "\n",
    "def random_intervals():\n",
    "    interval_starts = np.random.randint(0, SEQ_LEN_CHR - INTERVAL_LENGTH, size=NUM_INTERVALS)\n",
    "    interval_ends = interval_starts + INTERVAL_LENGTH\n",
    "    interval_chrs = np.random.randint(0, NUM_CHRS, size=NUM_INTERVALS)\n",
    "    interval_chrs = np.array(list(map(lambda x: 'chr{}'.format(x), interval_chrs)))\n",
    "    return interval_chrs, interval_starts, interval_ends\n",
    "\n",
    "def dump_to_disk(chr_key, arr, base_dir, is_transpose=False):\n",
    "    target_fname = os.path.join(base_dir, chr_key)\n",
    "    if is_transpose:\n",
    "        arr = arr.T\n",
    "    c_arr = bcolz.carray(arr, cparams=BLOSC_CPARAMS, rootdir=target_fname, mode='w')\n",
    "    c_arr.flush()\n",
    "\n",
    "def write_metadata(base_dir, is_transpose=False):\n",
    "    # Check the first file to get the shape\n",
    "    arr_shape = bcolz.carray(rootdir=os.path.join(base_dir, 'chr0'), mode='r').shape\n",
    "    if is_transpose:\n",
    "        arr_shape = arr_shape[::-1]\n",
    "    chr_shapes = {'chr{}'.format(i): arr_shape for i in range(NUM_CHRS)}\n",
    "    type_str = 'array_2D_transpose_bcolz' if is_transpose else 'array_bcolz'\n",
    "    metadata = {'type': type_str, 'file_shapes': chr_shapes}\n",
    "    with open(os.path.join(base_dir, 'metadata.json'), 'w') as fp:\n",
    "        json.dump(metadata, fp)\n",
    "\n",
    "for FA_DIR in FA_DIRS:\n",
    "    seq_arrs = {'chr{}'.format(i): random_fasta_seq() for i in range(NUM_CHRS)}\n",
    "    for chr_key, arr in seq_arrs.items():\n",
    "        dump_to_disk(chr_key, arr, FA_DIR, is_transpose=True)\n",
    "    write_metadata(FA_DIR, is_transpose=True)\n",
    "\n",
    "for BW_DIR in BW_DIRS:\n",
    "    bw_arrs = {'chr{}'.format(i): random_bw_data() for i in range(NUM_CHRS)}\n",
    "    for chr_key, arr in bw_arrs.items():\n",
    "        dump_to_disk(chr_key, arr, BW_DIR)\n",
    "    write_metadata(BW_DIR)\n",
    "\n",
    "    \n",
    "labels = random_labels()\n",
    "interval_chrs, interval_starts, interval_ends = random_intervals()\n",
    "\n",
    "intervals = {\n",
    "    'chrom': interval_chrs,\n",
    "    'start': interval_starts,\n",
    "    'end': interval_ends,\n",
    "}\n",
    "\n",
    "datafile_paths = {os.path.basename(p): p for p in FA_DIRS + BW_DIRS}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up the readers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The data we want to read is in `test-data/`.\n",
    "\n",
    "We just need to use `DatasetIntervalReader` to create readers for all the files."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "reader = DatasetIntervalReader(intervals, datafile_paths, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "batch_size = 32\n",
    "\n",
    "outputs = reader.dequeue_many(batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'data/genome-bgw-0': <tf.Tensor 'examples-queue_DequeueMany:8' shape=(32, 1000) dtype=float32>,\n",
       " 'data/genome-bgw-1': <tf.Tensor 'examples-queue_DequeueMany:7' shape=(32, 1000) dtype=float32>,\n",
       " 'data/genome-bgw-2': <tf.Tensor 'examples-queue_DequeueMany:6' shape=(32, 1000) dtype=float32>,\n",
       " 'data/sequence-fa-0': <tf.Tensor 'examples-queue_DequeueMany:4' shape=(32, 4, 1000) dtype=float32>,\n",
       " 'data/sequence-fa-1': <tf.Tensor 'examples-queue_DequeueMany:3' shape=(32, 4, 1000) dtype=float32>,\n",
       " 'data/sequence-fa-2': <tf.Tensor 'examples-queue_DequeueMany:5' shape=(32, 4, 1000) dtype=float32>,\n",
       " 'intervals/chrom': <tf.Tensor 'examples-queue_DequeueMany:2' shape=(32,) dtype=string>,\n",
       " 'intervals/end': <tf.Tensor 'examples-queue_DequeueMany:1' shape=(32,) dtype=int64>,\n",
       " 'intervals/start': <tf.Tensor 'examples-queue_DequeueMany:0' shape=(32,) dtype=int64>,\n",
       " 'labels': <tf.Tensor 'examples-queue_DequeueMany:9' shape=(32, 6) dtype=int64>}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "s = tf.InteractiveSession()\n",
    "\n",
    "s.run(tf.global_variables_initializer())\n",
    "\n",
    "# Note that you must start queue runners before fetching any of the dequeues.\n",
    "queue_runner_threads = tf.train.start_queue_runners(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<Thread(Thread-4, started daemon 123145512886272)>,\n",
       " <Thread(Thread-5, started daemon 123145517092864)>,\n",
       " <Thread(Thread-6, started daemon 123145521299456)>,\n",
       " <Thread(Thread-7, started daemon 123145525506048)>,\n",
       " <Thread(Thread-8, started daemon 123145529712640)>]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# This is the list of threads that are enquing examples\n",
    "queue_runner_threads"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### Try extracting some data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "my_batch = s.run(outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['labels',\n",
       " 'intervals/start',\n",
       " 'data/sequence-fa-2',\n",
       " 'data/sequence-fa-1',\n",
       " 'data/sequence-fa-0',\n",
       " 'data/genome-bgw-1',\n",
       " 'data/genome-bgw-0',\n",
       " 'data/genome-bgw-2',\n",
       " 'intervals/chrom',\n",
       " 'intervals/end']"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "my_batch.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0, 5, 2, 5, 7, 5],\n",
       "       [6, 1, 5, 2, 0, 8],\n",
       "       [8, 7, 3, 8, 7, 4],\n",
       "       [9, 7, 6, 1, 3, 8],\n",
       "       [3, 9, 8, 0, 5, 0],\n",
       "       [5, 7, 5, 4, 2, 9],\n",
       "       [2, 9, 8, 3, 6, 7],\n",
       "       [9, 4, 2, 1, 8, 8],\n",
       "       [1, 0, 3, 4, 0, 5],\n",
       "       [9, 3, 0, 0, 1, 7]])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "my_batch['labels'][:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[[ 1.,  0.,  0., ...,  0.,  0.,  0.],\n",
       "        [ 0.,  0.,  0., ...,  0.,  0.,  1.],\n",
       "        [ 0.,  0.,  0., ...,  0.,  1.,  0.],\n",
       "        [ 0.,  1.,  1., ...,  1.,  0.,  0.]],\n",
       "\n",
       "       [[ 1.,  0.,  0., ...,  0.,  1.,  0.],\n",
       "        [ 0.,  0.,  1., ...,  1.,  0.,  0.],\n",
       "        [ 0.,  1.,  0., ...,  0.,  0.,  1.],\n",
       "        [ 0.,  0.,  0., ...,  0.,  0.,  0.]],\n",
       "\n",
       "       [[ 1.,  0.,  0., ...,  0.,  0.,  1.],\n",
       "        [ 0.,  0.,  0., ...,  0.,  0.,  0.],\n",
       "        [ 0.,  1.,  0., ...,  1.,  0.,  0.],\n",
       "        [ 0.,  0.,  1., ...,  0.,  1.,  0.]],\n",
       "\n",
       "       ..., \n",
       "       [[ 0.,  0.,  1., ...,  0.,  1.,  0.],\n",
       "        [ 0.,  1.,  0., ...,  0.,  0.,  0.],\n",
       "        [ 1.,  0.,  0., ...,  1.,  0.,  0.],\n",
       "        [ 0.,  0.,  0., ...,  0.,  0.,  1.]],\n",
       "\n",
       "       [[ 0.,  0.,  1., ...,  0.,  0.,  0.],\n",
       "        [ 1.,  0.,  0., ...,  0.,  0.,  0.],\n",
       "        [ 0.,  0.,  0., ...,  0.,  1.,  1.],\n",
       "        [ 0.,  1.,  0., ...,  1.,  0.,  0.]],\n",
       "\n",
       "       [[ 1.,  0.,  0., ...,  1.,  1.,  0.],\n",
       "        [ 0.,  0.,  1., ...,  0.,  0.,  0.],\n",
       "        [ 0.,  1.,  0., ...,  0.,  0.,  1.],\n",
       "        [ 0.,  0.,  0., ...,  0.,  0.,  0.]]], dtype=float32)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "my_batch['data/sequence-fa-1'][:10, :20]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ -9.72820282e-01,  -7.21626580e-01,  -2.71724135e-01,\n",
       "          2.53000885e-01,   7.08059013e-01,   9.68144417e-01,\n",
       "          9.61639345e-01,   6.90335155e-01,   2.28938684e-01,\n",
       "         -2.95498848e-01],\n",
       "       [ -8.75662088e-01,  -9.99636233e-01,  -8.48348320e-01,\n",
       "         -4.63457286e-01,   4.90523763e-02,   5.48054874e-01,\n",
       "          8.96143734e-01,   9.97468472e-01,   8.24127972e-01,\n",
       "          4.23853785e-01],\n",
       "       [  8.66041660e-01,   9.99972641e-01,   8.58548880e-01,\n",
       "          4.80713159e-01,  -2.94928234e-02,  -5.31577587e-01,\n",
       "         -8.87285948e-01,  -9.98669267e-01,  -8.35056722e-01,\n",
       "         -4.41501111e-01],\n",
       "       [ -7.45452762e-01,  -9.80340302e-01,  -9.45279121e-01,\n",
       "         -6.49923742e-01,  -1.75603911e-01,   3.47070605e-01,\n",
       "          7.74174988e-01,   9.88100767e-01,   9.29940939e-01,\n",
       "          6.15710437e-01],\n",
       "       [  2.57217377e-01,  -2.67524838e-01,  -7.18600810e-01,\n",
       "         -9.71801281e-01,  -9.57404315e-01,  -6.79374337e-01,\n",
       "         -2.14270353e-01,   3.09835643e-01,   7.48624563e-01,\n",
       "          9.81270552e-01],\n",
       "       [ -1.29671350e-01,   3.90272349e-01,   8.02749813e-01,\n",
       "          9.94180202e-01,   9.11850929e-01,   5.78432381e-01,\n",
       "          8.57353061e-02,  -4.30570006e-01,  -8.28312576e-01,\n",
       "         -9.97969151e-01],\n",
       "       [ -4.10453528e-01,   1.07803658e-01,   5.96375763e-01,\n",
       "          9.20728505e-01,   9.91547346e-01,   7.89331496e-01,\n",
       "          3.69763553e-01,  -1.51623353e-01,  -6.31258905e-01,\n",
       "         -9.37069535e-01],\n",
       "       [ -9.99710321e-01,  -8.74256313e-01,  -5.08065164e-01,\n",
       "         -1.97199243e-03,   5.04664183e-01,   8.72334898e-01,\n",
       "          9.99797463e-01,   8.51953566e-01,   4.69513834e-01,\n",
       "         -4.22122963e-02],\n",
       "       [  5.57772875e-01,   6.06978536e-02,  -4.53091085e-01,\n",
       "         -8.42115879e-01,  -9.99253690e-01,  -8.81234765e-01,\n",
       "         -5.20557046e-01,  -1.65375955e-02,   4.92035717e-01,\n",
       "          8.65120947e-01],\n",
       "       [ -4.47176918e-02,   4.67298180e-01,   8.50637794e-01,\n",
       "          9.99743879e-01,   8.73558223e-01,   5.06827593e-01,\n",
       "          5.35774161e-04,  -5.05903542e-01,  -8.73036146e-01,\n",
       "         -9.99767542e-01]], dtype=float32)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "my_batch['data/genome-bgw-1'][:10, :10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda env:tfnew]",
   "language": "python",
   "name": "conda-env-tfnew-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
