{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a7e60bee",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import h5py\n",
    "import sys\n",
    "import os\n",
    "\n",
    "sys.path.append(\"../2_train_models\")\n",
    "from file_configs import MergedFilesConfig\n",
    "\n",
    "from common_functions import load_coords"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d3bcfcb1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# specify what set of models to look at\n",
    "cell_type = \"K562\"\n",
    "\n",
    "# these usually don't change\n",
    "model_type = \"strand_merged_umap\"\n",
    "data_type = \"procap\"\n",
    "\n",
    "# size of the model inputs and outputs\n",
    "in_window = 2114\n",
    "slice_len = 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "0bc9fdf3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2024-03-12 21:23:20--  https://hgdownload.cse.ucsc.edu/goldenPath/hg38/database/rmsk.txt.gz\n",
      "Resolving hgdownload.cse.ucsc.edu (hgdownload.cse.ucsc.edu)... 128.114.198.53\n",
      "Connecting to hgdownload.cse.ucsc.edu (hgdownload.cse.ucsc.edu)|128.114.198.53|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 155633856 (148M) [application/x-gzip]\n",
      "Saving to: ‘/users/kcochran/projects/procapnet/annotations/rmsk.txt.gz’\n",
      "\n",
      "/users/kcochran/pro 100%[===================>] 148.42M   106MB/s    in 1.4s    \n",
      "\n",
      "2024-03-12 21:23:22 (106 MB/s) - ‘/users/kcochran/projects/procapnet/annotations/rmsk.txt.gz’ saved [155633856/155633856]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# download the hg38 repeatmasker/repbase annotation of all repeats from UCSC\n",
    "\n",
    "! wget https://hgdownload.cse.ucsc.edu/goldenPath/hg38/database/rmsk.txt.gz -O \"/users/kcochran/projects/procapnet/annotations/rmsk.txt.gz\"\n",
    "! gunzip \"/users/kcochran/projects/procapnet/annotations/rmsk.txt.gz\"\n",
    "! awk -v OFS=\"\\t\" '{ print $6, $7, $8, $11, $12, $13 }' \"/users/kcochran/projects/procapnet/annotations/rmsk.txt\" > \"/users/kcochran/projects/procapnet/annotations/rmsk.bed\"\n",
    "! rm \"/users/kcochran/projects/procapnet/annotations/rmsk.txt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "638e6c03",
   "metadata": {},
   "outputs": [],
   "source": [
    "repeat_masker_bed = \"/users/kcochran/projects/procapnet/annotations/rmsk.bed\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "e336fbe9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "chr1\t10000\t10468\t(TAACCC)n\tSimple_repeat\tSimple_repeat\r\n",
      "chr1\t10468\t11447\tTAR1\tSatellite\ttelo\r\n",
      "chr1\t11504\t11675\tL1MC5a\tLINE\tL1\r\n",
      "chr1\t11677\t11780\tMER5B\tDNA\thAT-Charlie\r\n",
      "chr1\t15264\t15355\tMIR3\tSINE\tMIR\r\n",
      "chr1\t15797\t15849\t(TGCTCC)n\tSimple_repeat\tSimple_repeat\r\n",
      "chr1\t16712\t16744\t(TGG)n\tSimple_repeat\tSimple_repeat\r\n",
      "chr1\t18906\t19048\tL2a\tLINE\tL2\r\n",
      "chr1\t19971\t20405\tL3\tLINE\tCR1\r\n",
      "chr1\t20530\t20679\tPlat_L3\tLINE\tCR1\r\n"
     ]
    }
   ],
   "source": [
    "! head $repeat_masker_bed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5c6e8dfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the config object (filepath holder) for these models\n",
    "config = MergedFilesConfig(cell_type, model_type, data_type)\n",
    "\n",
    "proj_dir = config.proj_dir\n",
    "tmp_data_dir = \"te_enrichment_files/\"\n",
    "os.makedirs(tmp_data_dir, exist_ok=True)\n",
    "\n",
    "# load the chrom, start, end info for all peaks\n",
    "coords = load_coords(config.all_peak_path, in_window)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f9bf9140",
   "metadata": {},
   "outputs": [],
   "source": [
    "# for each pattern we want to look at, we need to:\n",
    "#  1) load all the genomic sequences supporting this pattern (seqlets)\n",
    "#  2) write their coordinates to a bed file\n",
    "#  3) get enrichment stats for those coordinates vs. repeat annotations\n",
    "\n",
    "def extract_genomic_coords_at_seqlets(modisco_results_path, pattern_i,\n",
    "                                      coords, in_window, slice_len):\n",
    "    \n",
    "    # this function figures out the genomic coordinates for each sequence\n",
    "    # that modisco clustered into a pattern of interest\n",
    "        \n",
    "    def seqlet_coord_to_input_coord(seqlet_coord):\n",
    "        # seqlet coordinates are by default w.r.t. the model output window;\n",
    "        # this adjusts coords to be w.r.t. the whole peak / model input window\n",
    "        return seqlet_coord + ((in_window - slice_len) // 2)\n",
    "    \n",
    "    def load_seqlets(modisco_results, pattern_i):\n",
    "        patterns_grp = modisco_results[\"pos_patterns\"]\n",
    "        pattern_grp = patterns_grp[\"pattern_\" + str(pattern_i)]\n",
    "        return pattern_grp[\"seqlets\"]\n",
    "        \n",
    "    # open the modisco hdf5 results file object\n",
    "    modisco_results = h5py.File(modisco_results_path, \"r\")\n",
    "    \n",
    "    # hunt through the hdf5 file structure for the seqlet info\n",
    "    seqlets = load_seqlets(modisco_results, pattern_i)\n",
    "    \n",
    "    coord_indexes = seqlets[\"example_idx\"][:]\n",
    "    seqlet_starts = seqlets[\"start\"][:]\n",
    "    seqlet_ends = seqlets[\"end\"][:]\n",
    "    seqlet_rcs = seqlets[\"is_revcomp\"][:]\n",
    "    \n",
    "    # adjust seqlet coordinates so that they're normal genomic coords\n",
    "    \n",
    "    input_starts = seqlet_coord_to_input_coord(seqlet_starts)\n",
    "    input_ends = seqlet_coord_to_input_coord(seqlet_ends)\n",
    "\n",
    "    genomic_coords = []\n",
    "    for coord_index, input_start, input_end in zip(coord_indexes, input_starts, input_ends):\n",
    "        chrom, peak_start = coords[coord_index][:2]\n",
    "        genomic_coords.append((chrom, peak_start + input_start, peak_start + input_end))\n",
    "\n",
    "    modisco_results.close()\n",
    "        \n",
    "    return genomic_coords\n",
    "\n",
    "\n",
    "def write_coords_to_bed(bed_filepath, coords):\n",
    "    to_write = \"\\n\".join([\"\\t\".join([str(i) for i in coord]) for coord in coords])\n",
    "    \n",
    "    if bed_filepath.endswith(\".gz\"):\n",
    "        with gzip.open(bed_filepath, \"w\") as f:\n",
    "            f.write(to_write.encode())\n",
    "    else:\n",
    "        with open(bed_filepath, \"w\") as f:\n",
    "            f.write(to_write)\n",
    "\n",
    "            \n",
    "def write_all_seqlet_coords_to_beds(task, tmp_data_dir = tmp_data_dir):\n",
    "    # this is specific to K562\n",
    "    \n",
    "    if task == \"profile\":\n",
    "        # labels for each of the patterns stored in the modisco results file\n",
    "        pattern_labels = [\"BRE/SP\", \"CA-Inr\", \"ETS\", \"NFY\", \"NRF1\",\n",
    "                          \"ATF1\", \"TATA\", \"THAP11\", \"YY1\", \"AP1\",\n",
    "                          \"DPR_CGG\", \"DPR_CGG\", \"DPR_CGG\", \"TA-Inr\", \"DPR_CGG\",\n",
    "                          \"CTCF\", \"DPR_CGG\", \"NRF1-like\", \"DPR_CGG\", \"ZBTB33\",\n",
    "                          \"DPR_CGG\", \"TCT\", \"BRE/SP_TE\", \"TATATA\", \"CA-Inr_TE\",\n",
    "                          \"DPR_CGG\", \"DPR_CGG\", \"TATA_TE\", \"CA-Inr_dimer\", \"DPR_CGG\", \"DPR_CGG\",\n",
    "                          \"ZBTBT33_TATA_TE\", \"TE\", \"GC-rich\", \"CA-Inr_TE\", \"YY1-like\",\n",
    "                          \"NFY_TE\", \"ZBTB33_TE\", \"ETS_TE\", \"ETS_dimer\", \"NFY_C-rich\",\n",
    "                          \"ETS_CA-Inr_dimer\", \"YY1-like\"]\n",
    "\n",
    "        # modisco results hdf5 object containing the motifs/patterns and subpatterns\n",
    "        modisco_results_path = config.modisco_profile_results_path\n",
    "    else:\n",
    "        pattern_labels = [\"ETS\", \"BRE/SP\", \"NRF1\", \"ETS-like\", \"NFY\", \"ATF1\",\n",
    "                          \"CpG\", \"AP1\", \"CpG_spacing\", \"CpG_spacing\", \"THAP11\",\n",
    "                          \"CpG_spacing\", \"CpG_spacing\", \"CpG_spacing\",\n",
    "                          \"ZBTB33\", \"CpG_spacing\", \"BRE/SP_TE\", \"CpG_spacing\", \"CpG_spacing\",\n",
    "                          \"THAP11-like\", \"CpG_spacing\", \"BRE/SP_TE\", \"ZBTB33_TE\", \"CpG_spacing\",\n",
    "                          \"BRE/SP_TE\", \"TE\", \"BRE/SP_TE\", \"TE\", \"TE\",\n",
    "                          \"ETS-like\", \"BRE/SP_TE\", \"Unknown\", \"NFY_TE\"]\n",
    "\n",
    "        modisco_results_path = config.modisco_counts_results_path\n",
    "    \n",
    "    \n",
    "    for pattern_i, pattern_label in enumerate(pattern_labels):\n",
    "        if \"TE\" in pattern_label:\n",
    "\n",
    "            seqlet_coords = extract_genomic_coords_at_seqlets(modisco_results_path, pattern_i,\n",
    "                                                              coords, in_window, slice_len)\n",
    "\n",
    "            bed_filepath = tmp_data_dir + task + \".pattern_\" + str(pattern_i) + \".seqlets.bed\"\n",
    "            write_coords_to_bed(bed_filepath, seqlet_coords)\n",
    "            \n",
    "            \n",
    "for task in [\"profile\", \"counts\"]:\n",
    "    write_all_seqlet_coords_to_beds(task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4785f29a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# then, run repeat enrichment analysis on each of the bed files\n",
    "# (for each of the patterns that might be a TE, for both model tasks)\n",
    "\n",
    "bed_paths = [tmp_data_dir + fpath for fpath in os.listdir(tmp_data_dir)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "05adddda",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\hline\n",
      "TE Pattern & Repeat Overlap (\\%) & Most Common Repeat Family & Fraction of Repeat Overlap Matching Family (\\%) \\\\\n",
      "\\hline\n",
      "\\endhead\n",
      "\\hline\n",
      "\\endfoot\n",
      "\\hline\n",
      "\\endlastfoot\n",
      "Profile TE Pattern 1 & 82\\% & ERV1 (LTR) & 100\\% \\\\\n",
      "Profile TE Pattern 2 & 84\\% & ERVL-MaLR (LTR) & 97\\% \\\\\n",
      "Profile TE Pattern 3 & 81\\% & ERV1 (LTR) & 100\\% \\\\\n",
      "Profile TE Pattern 4 & 91\\% & ERV1 (LTR) & 99\\% \\\\\n",
      "Profile TE Pattern 5 & 75\\% & Alu (SINE) & 93\\% \\\\\n",
      "Profile TE Pattern 6 & 81\\% & ERVL-MaLR (LTR) & 100\\% \\\\\n",
      "Profile TE Pattern 7 & 70\\% & ERV1 (LTR) & 100\\% \\\\\n",
      "Profile TE Pattern 8 & 78\\% & Alu (SINE) & 94\\% \\\\\n",
      "Profile TE Pattern 9 & 67\\% & ERV1 (LTR) & 93\\% \\\\\n",
      "Counts TE Pattern 1 & 96\\% & ERV1 (LTR) & 100\\% \\\\\n",
      "Counts TE Pattern 2 & 82\\% & L1 (LINE) & 97\\% \\\\\n",
      "Counts TE Pattern 3 & 82\\% & ERVL-MaLR (LTR) & 98\\% \\\\\n",
      "Counts TE Pattern 4 & 73\\% & L1 (LINE) & 96\\% \\\\\n",
      "Counts TE Pattern 5 & 95\\% & ERVL-MaLR (LTR) & 96\\% \\\\\n",
      "Counts TE Pattern 6 & 91\\% & Alu (SINE) & 97\\% \\\\\n",
      "Counts TE Pattern 7 & 85\\% & ERV1 (LTR) & 86\\% \\\\\n",
      "Counts TE Pattern 8 & 73\\% & ERVL-MaLR (LTR) & 90\\% \\\\\n",
      "Counts TE Pattern 9 & 82\\% & ERV1 (LTR) & 93\\% \\\\\n",
      "Counts TE Pattern 10 & 61\\% & Alu (SINE) & 94\\%\\label{Tab2}\\\\\n"
     ]
    }
   ],
   "source": [
    "import subprocess\n",
    "\n",
    "def run_bedtools_intersect(filepath_a, filepath_b, dest_filepath, other_args=[]):\n",
    "    cmd = [\"bedtools\", \"intersect\"]\n",
    "    cmd += [\"-a\", filepath_a]\n",
    "    cmd += [\"-b\", filepath_b]\n",
    "    for arg in other_args:\n",
    "        cmd += [arg]\n",
    "        \n",
    "    with open(dest_filepath, \"w\") as outf:\n",
    "        subprocess.call(cmd, stdout=outf)\n",
    "        \n",
    "\n",
    "def calc_repeat_overlap_frac(results_fpath, seqlets_bed):\n",
    "    # returns the fractions of seqlets that overlapped any repeat annotation\n",
    "    with open(seqlets_bed) as f:\n",
    "        num_seqlets = sum([1 for line in f])\n",
    "    \n",
    "    with open(results_fpath) as resultsf:\n",
    "        hits = [line.strip().split() for line in resultsf]\n",
    "        \n",
    "    # don't double-count stuff that overlapped two annotations?\n",
    "    hit_coords = set([tuple(hit[:3]) for hit in hits])\n",
    "    \n",
    "    return len(hit_coords) / num_seqlets\n",
    "    \n",
    "    \n",
    "def get_most_common_repeat_type(results_fpath, repeat_col):\n",
    "    # the repeat masker file has 3 columns of repeat labels:\n",
    "    # -3 is the most specific label / subfamily,\n",
    "    # -2 is the least specific or the overall class (LINE, SINE, LTR...)\n",
    "    # -1 is the middle or family (e.g. Alu)\n",
    "    \n",
    "    assert repeat_col in [-3, -2, -1], repeat_col\n",
    "    \n",
    "    with open(results_fpath) as resultsf:\n",
    "        hits = [line.strip().split() for line in resultsf]\n",
    "    \n",
    "    # for every repeat type overlapped, count how often it was overlapped\n",
    "    repeat_types, repeat_type_counts = np.unique([hit[repeat_col] for hit in hits],\n",
    "                                                 return_counts=True)\n",
    "    \n",
    "    # sort repeat types overlapped so the most common is first\n",
    "    repeat_types_and_counts = sorted(list(zip(repeat_type_counts, repeat_types)),\n",
    "                                     reverse=True)\n",
    "    \n",
    "    # calc the fraction of repeat overlaps that came from this most-common one\n",
    "    total = sum([count[0] for count in repeat_types_and_counts])\n",
    "    frac_most_common_repeat_type = repeat_types_and_counts[0][0] / total\n",
    "    \n",
    "    # return name of most common repeat type, and fraction of overlaps it was\n",
    "    return repeat_types_and_counts[0][1], frac_most_common_repeat_type\n",
    "    \n",
    "\n",
    "def run_bedtools_intersect_calc_repeat_overlap(rmsk_bed, seqlets_bed):\n",
    "    assert not seqlets_bed.endswith(\".gz\") # probably doesn't work with gzipped beds\n",
    "    \n",
    "    tmpf = \"tmp.bed\"\n",
    "    \n",
    "    run_bedtools_intersect(seqlets_bed, rmsk_bed, tmpf, other_args=[\"-wa\", \"-wb\"])\n",
    "        \n",
    "    repeat_overlap_frac = calc_repeat_overlap_frac(tmpf, seqlets_bed)\n",
    "    \n",
    "    repeat_class_hit = get_most_common_repeat_type(tmpf, -2)\n",
    "    repeat_family_hit = get_most_common_repeat_type(tmpf, -1)\n",
    "        \n",
    "    return repeat_overlap_frac, repeat_class_hit, repeat_family_hit\n",
    "\n",
    "    \n",
    "def sort_bed_paths(bed_paths):\n",
    "    # orders filenames by the pattern number that is inside the filename;\n",
    "    # would make sense to plot in an order consistent with the supp fig of patterns\n",
    "    \n",
    "    sort_by = []\n",
    "    for bed_path in bed_paths:\n",
    "        pattern_num = int(bed_path.split(\"pattern_\")[-1].split(\".\")[0])\n",
    "        sort_by.append(pattern_num)\n",
    "        \n",
    "    bed_paths = np.array(bed_paths)[np.argsort(sort_by)[::-1]]\n",
    "    return bed_paths\n",
    "    \n",
    "\n",
    "def make_table(repeat_masker_bed, bed_paths):\n",
    "    # this function prints a full LateX table you can copy-paste into overleaf\n",
    "    \n",
    "    # table preamble stuff\n",
    "    print(r'\\hline')\n",
    "    print(r'TE Pattern & Repeat Overlap (\\%) & Most Common Repeat Family & Fraction of Repeat Overlap Matching Family (\\%) \\\\')\n",
    "    print(r'\\hline')\n",
    "    print(r'\\endhead')\n",
    "    print(r'\\hline')\n",
    "    print(r'\\endfoot')\n",
    "    print(r'\\hline')\n",
    "    print(r'\\endlastfoot')\n",
    "\n",
    "    # separate the patterns by task, sort in numeric order\n",
    "    prof_bed_paths = sort_bed_paths([bed_path for bed_path in bed_paths if \"profile\" in bed_path])\n",
    "    counts_bed_paths = sort_bed_paths([bed_path for bed_path in bed_paths if \"counts\" in bed_path])\n",
    "\n",
    "    # first, list out repeat enrichment of profile patterns\n",
    "    # (formatted to be like a row in a table)\n",
    "    for pattern_i, bed_path in enumerate(prof_bed_paths):\n",
    "        pattern_label = \"Profile TE Pattern \" + str(pattern_i + 1)\n",
    "        row_str = pattern_label + \" & \"\n",
    "        \n",
    "        repeat_overlap_results = run_bedtools_intersect_calc_repeat_overlap(repeat_masker_bed, bed_path)\n",
    "        repeat_overlap_frac, repeat_class_hit, repeat_family_hit = repeat_overlap_results\n",
    "        \n",
    "        row_str += \"%0.f\" % (repeat_overlap_frac * 100) + r'\\% & '\n",
    "        \n",
    "        row_str += repeat_family_hit[0] + \" (\" + repeat_class_hit[0] + \") & \"\n",
    "        row_str += \"%0.f\" % (repeat_family_hit[1] * 100) + r'\\% \\\\'\n",
    "        \n",
    "        print(row_str)\n",
    "        \n",
    "    print(r'\\hline')\n",
    "        \n",
    "    # then do the same for counts patterns\n",
    "    for pattern_i, bed_path in enumerate(counts_bed_paths):\n",
    "        pattern_label = \"Counts TE Pattern \" + str(pattern_i + 1)\n",
    "        row_str = pattern_label + \" & \"\n",
    "        \n",
    "        repeat_overlap_results = run_bedtools_intersect_calc_repeat_overlap(repeat_masker_bed, bed_path)\n",
    "        repeat_overlap_frac, repeat_class_hit, repeat_family_hit = repeat_overlap_results\n",
    "        \n",
    "        row_str += \"%0.f\" % (repeat_overlap_frac * 100) + r'\\% & '\n",
    "        \n",
    "        row_str += repeat_family_hit[0] + \" (\" + repeat_class_hit[0] + \") & \"\n",
    "        \n",
    "        ########################\n",
    "        if pattern_i != len(counts_bed_paths) - 1:\n",
    "            row_str += \"%0.f\" % (repeat_family_hit[1] * 100) + r'\\% \\\\'\n",
    "        else:\n",
    "            row_str += \"%0.f\" % (repeat_family_hit[1] * 100) + r'\\%\\label{Tab2}\\\\'\n",
    "            \n",
    "        print(row_str)\n",
    "    \n",
    "    \n",
    "make_table(repeat_masker_bed, bed_paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4e4808b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:procap_A100] *",
   "language": "python",
   "name": "conda-env-procap_A100-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
