#!/usr/bin/env python3
"""
Scraper for HDMA motif lexicon from the MOTIFS.html page.

Extracts the pattern → motif_name_safe → annotation mapping from:
https://greenleaflab.github.io/HDMA/MOTIFS.html

The page uses DataTables with embedded JavaScript data.
"""

import re
import json
import sys
import argparse
import csv
from pathlib import Path
from typing import Dict, List, Optional
try:
    import requests
except ImportError:
    print("Error: requests module not found. Install with: pip install requests", file=sys.stderr)
    sys.exit(1)

try:
    from bs4 import BeautifulSoup
except ImportError:
    print("Error: beautifulsoup4 module not found. Install with: pip install beautifulsoup4", file=sys.stderr)
    sys.exit(1)

try:
    import pandas as pd
    HAS_PANDAS = True
except ImportError:
    HAS_PANDAS = False
    print("Warning: pandas not available, will use CSV writer instead", file=sys.stderr)


def extract_datatable_data(html_content: str) -> Optional[List]:
    """
    Extract DataTable data from HTML page.
    
    The data is embedded in JavaScript as a DataTable initialization.
    We need to find the data array in the script tags.
    """
    # Look for the data directly in the HTML (sometimes it's in a script tag as JSON)
    # DataTables often embed data as: var tableData = [[...], [...]];
    
    # First, try to find JSON data in script tags
    json_pattern = r'var\s+\w*[Dd]ata\w*\s*=\s*(\[\[[\s\S]*?\]\])'
    matches = re.finditer(json_pattern, html_content, re.MULTILINE)
    for match in matches:
        try:
            data_str = match.group(1)
            # Try to parse as JSON
            data = json.loads(data_str)
            if isinstance(data, list) and len(data) > 0 and isinstance(data[0], list):
                return data
        except (json.JSONDecodeError, ValueError):
            continue
    
    soup = BeautifulSoup(html_content, 'html.parser')
    
    # Find all script tags
    scripts = soup.find_all('script')
    
    for script in scripts:
        if script.string is None:
            continue
            
        script_content = script.string
        
        # Look for DataTable initialization with data
        # Pattern: $(document).ready(function() { $('#table_id').DataTable({ data: [...] })
        # The data might be in the format: data: [[...], [...]]
        
        # Try to find the data array - look for nested arrays
        # Pattern: data: [[...], [...]]
        # We need to match balanced brackets
        data_patterns = [
            # Look for data: followed by array
            r'"data"\s*:\s*(\[\[[\s\S]*?\]\])',
            r"'data'\s*:\s*(\[\[[\s\S]*?\]\])",
            r'data\s*:\s*(\[\[[\s\S]*?\]\])',
        ]
        
        for pattern in data_patterns:
            matches = re.finditer(pattern, script_content, re.DOTALL)
            for match in matches:
                try:
                    data_str = match.group(1)
                    # Clean up - remove trailing commas before closing brackets
                    data_str = re.sub(r',\s*\]', ']', data_str)
                    data_str = re.sub(r',\s*\]\s*\]', ']]', data_str)
                    
                    # Try to parse as JSON
                    data = json.loads(data_str)
                    if isinstance(data, list) and len(data) > 0:
                        return data
                except (json.JSONDecodeError, ValueError) as e:
                    continue
        
        # Alternative: Look for variable assignments with data arrays
        var_patterns = [
            r'var\s+\w+\s*=\s*(\[\[[\s\S]*?\]\])',
            r'const\s+\w+\s*=\s*(\[\[[\s\S]*?\]\])',
            r'let\s+\w+\s*=\s*(\[\[[\s\S]*?\]\])',
        ]
        
        for pattern in var_patterns:
            matches = re.finditer(pattern, script_content, re.DOTALL)
            for match in matches:
                try:
                    data_str = match.group(1)
                    data_str = re.sub(r',\s*\]', ']', data_str)
                    data_str = re.sub(r',\s*\]\s*\]', ']]', data_str)
                    data = json.loads(data_str)
                    if isinstance(data, list) and len(data) > 0:
                        return data
                except (json.JSONDecodeError, ValueError):
                    continue
    
    return None


def extract_from_datatable_init(html_content: str) -> Optional[List[Dict]]:
    """
    Alternative method: Extract data from DataTable initialization.
    
    Look for the full DataTable config object and extract the data array.
    """
    # Find the DataTable initialization
    # Pattern: $('#table').DataTable({ ... data: [...] ... })
    
    # Use regex to find the DataTable config
    pattern = r'\$\([^)]+\)\.DataTable\s*\(\s*(\{[\s\S]*?\})\s*\)'
    matches = re.finditer(pattern, html_content)
    
    for match in matches:
        config_str = match.group(1)
        # Try to extract the data field
        data_match = re.search(r'"data"\s*:\s*(\[\[[\s\S]*?\]\])', config_str)
        if data_match:
            try:
                data_str = data_match.group(1)
                data = json.loads(data_str)
                if isinstance(data, list) and len(data) > 0:
                    return data
            except (json.JSONDecodeError, ValueError):
                continue
    
    return None


def parse_html_table(html_content: str) -> Optional[List[List]]:
    """
    Fallback: Parse the HTML table directly if JavaScript extraction fails.

    Returns a list of rows — the <thead> header row first when present,
    followed by each non-empty <tbody> row — or None when no table or no
    usable rows are found.
    """
    soup = BeautifulSoup(html_content, 'html.parser')
    table = soup.find('table')
    if table is None:
        return None

    extracted = []

    # Header row, if the table declares one.
    thead = table.find('thead')
    if thead:
        header_cells = [cell.get_text(strip=True) for cell in thead.find_all(['th', 'td'])]
        if header_cells:
            extracted.append(header_cells)

    # Data rows, one list of cell texts per <tr>.
    body = table.find('tbody')
    if body:
        for row in body.find_all('tr'):
            values = [cell.get_text(strip=True) for cell in row.find_all(['td', 'th'])]
            if values:
                extracted.append(values)

    return extracted or None


def scrape_hdma_motifs(url: str = "https://greenleaflab.github.io/HDMA/MOTIFS.html"):
    """
    Scrape the HDMA motif lexicon from the HTML page.

    Fetches the page, then tries three extraction strategies in order
    (embedded JavaScript arrays, DataTable init config, plain HTML table)
    and normalizes the result: undoes column-major (transposed) data,
    infers or applies header names, drops malformed rows, removes the
    first (index-like) column and the original 'pattern' column, and
    renames 'best_match' to 'pattern'.

    Returns either a pandas DataFrame (if available) or a dict with
    'headers' (list of column names) and 'rows' (list of row lists).

    Raises requests.HTTPError on a failed fetch and ValueError when no
    data can be extracted or the extracted data has an unexpected shape.
    """
    print(f"Fetching {url}...")
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    
    html_content = response.text
    
    # Try multiple extraction methods
    data = None
    
    # Method 1: Extract from DataTable JavaScript
    print("Attempting to extract DataTable data from JavaScript...")
    data = extract_datatable_data(html_content)
    
    if data is None:
        print("Trying alternative extraction method...")
        data = extract_from_datatable_init(html_content)
    
    if data is None:
        print("Falling back to HTML table parsing...")
        data = parse_html_table(html_content)
    
    if data is None:
        raise ValueError("Could not extract data from HTML page")
    
    # Define column names based on the HDMA MOTIFS page structure
    # From the page, the columns are:
    # NOTE(review): this list mirrors the page layout at the time of
    # writing — confirm against the live page if extraction misbehaves.
    columns = [
        'pattern_class', 'idx_uniq', 'motif_name', 'motif_name_safe', 
        'annotation', 'annotation_broad', 'category', 'query_consensus',
        'cwm_fwd', 'cwm_rev', 'total_hits', 'total_n_seqlets',
        'n_component_celltypes', 'top_organ', 'cwm_entropy', 'entropy_ratio',
        'pattern', 'best_match', 'best_match_TOMTOM_qval'
    ]
    
    # Convert data to appropriate format
    # Check if data is a flat array (single list) or nested (list of lists)
    if not isinstance(data[0], list):
        raise ValueError(f"Data appears to be a flat array with {len(data)} elements. Expected nested array (list of lists).")
    
    # Data is nested (list of lists)
    num_top_level = len(data)
    if num_top_level == 0:
        raise ValueError("Extracted data is empty")
    
    num_elements_first = len(data[0]) if isinstance(data[0], list) else 1
    
    # Check if data is transposed (columns first instead of rows first)
    # If we have ~20 top-level elements and each has ~508 elements, it's likely transposed
    # Expected: 508 rows with 19 columns
    # Transposed: 19-20 columns with 508 rows each
    if num_top_level <= len(columns) + 5 and num_elements_first > len(columns) * 2:
        # Data appears to be transposed - transpose it
        print(f"Detected transposed data: {num_top_level} columns with {num_elements_first} rows each. Transposing...", file=sys.stderr)
        # Transpose: convert columns to rows
        rows = []
        for i in range(num_elements_first):
            # Pad with '' for ragged columns shorter than the first one.
            row = [col[i] if i < len(col) else '' for col in data]
            rows.append(row)
        data = rows
        num_top_level = len(data)
        num_elements_first = len(data[0]) if data else 0
    
    # Now check if first row might be headers
    if num_top_level > 0 and isinstance(data[0], list):
        # Check if first row looks like headers (all strings, not numeric)
        # Only the first 5 cells are sampled; rows shorter than 5 cells
        # are never treated as headers.
        first_row_all_strings = all(isinstance(x, str) for x in data[0][:5]) if len(data[0]) >= 5 else False
        first_row_not_numeric = not any(
            str(x).replace('.', '').replace('-', '').isdigit() 
            for x in data[0][:5] 
            if isinstance(x, str)
        )
        
        if first_row_all_strings and first_row_not_numeric and len(data[0]) <= len(columns) + 5:
            # First row might be headers
            headers = data[0]
            rows = data[1:]
        else:
            # Use predefined columns
            headers = columns[:num_elements_first] if num_elements_first <= len(columns) else columns
            rows = data
    else:
        headers = columns
        rows = data
    
    # Validate that all rows have the same length
    if rows:
        expected_len = len(rows[0])
        # Filter out rows that don't match expected length
        valid_rows = [row for row in rows if len(row) == expected_len]
        if len(valid_rows) != len(rows):
            print(f"Warning: Filtered out {len(rows) - len(valid_rows)} rows with incorrect length", file=sys.stderr)
        rows = valid_rows
        
        # Ensure headers match row length
        if len(headers) != expected_len:
            if len(headers) < expected_len:
                # Extend headers with generic names
                headers = headers + [f'col_{i}' for i in range(len(headers), expected_len)]
                print(f"Warning: Extended headers to {expected_len} columns", file=sys.stderr)
            else:
                # Truncate headers
                headers = headers[:expected_len]
                print(f"Warning: Truncated headers to {expected_len} columns", file=sys.stderr)
    
    # Fix column names and remove unwanted columns
    if rows and len(headers) > 0:
        # Debug: print first row to understand structure
        if len(rows) > 0 and len(rows[0]) > 0:
            print(f"Debug: First row has {len(rows[0])} columns", file=sys.stderr)
            print(f"Debug: Headers: {headers[:10]}", file=sys.stderr)
            print(f"Debug: First row values: {rows[0][:10]}", file=sys.stderr)
        
        # Find indices of columns to rename/remove
        # Column 0 (first column) should be removed (has numeric data, not pattern)
        # The 'pattern' column should also be removed
        # After removal, we need to rename:
        # - best_match -> pattern
        # Keep HTML's motif_name as-is (don't rename motif_name_safe to motif_name)
        # Keep motif_name_safe and annotation as-is
        
        # Find the index of columns before removing anything
        best_match_idx = None
        pattern_idx = None
        
        for i, col_name in enumerate(headers):
            if col_name == 'best_match':
                best_match_idx = i
            elif col_name == 'pattern':
                pattern_idx = i
        
        # Remove first column (index 0) which has wrong data
        if len(rows[0]) > 0:
            rows = [row[1:] for row in rows]
            if len(headers) > 0:
                headers = headers[1:]
                # Adjust indices after removing first column
                if best_match_idx is not None and best_match_idx > 0:
                    best_match_idx -= 1
                if pattern_idx is not None and pattern_idx > 0:
                    pattern_idx -= 1
        
        # Remove 'pattern' column if it exists
        if pattern_idx is not None:
            rows = [[row[i] for i in range(len(row)) if i != pattern_idx] for row in rows]
            headers = [headers[i] for i in range(len(headers)) if i != pattern_idx]
            # Adjust indices after removing pattern column
            if best_match_idx is not None and best_match_idx > pattern_idx:
                best_match_idx -= 1
        
        # Now rename the columns (only rename best_match to pattern)
        new_headers = []
        for i, col_name in enumerate(headers):
            if i == best_match_idx:
                new_headers.append('pattern')
            else:
                # Keep all other column names as-is (including HTML's motif_name)
                new_headers.append(col_name)
        headers = new_headers
    
    # Return a DataFrame when pandas is importable, a plain dict otherwise.
    if HAS_PANDAS:
        if rows:
            return pd.DataFrame(rows, columns=headers)
        else:
            return pd.DataFrame(columns=headers)
    else:
        return {'headers': headers, 'rows': rows}


def main():
    """CLI entry point: scrape the HDMA motif lexicon and write selected columns to a TSV file."""
    parser = argparse.ArgumentParser(
        description='Scrape HDMA motif lexicon from MOTIFS.html page'
    )
    parser.add_argument(
        '--url',
        default='https://greenleaflab.github.io/HDMA/MOTIFS.html',
        help='URL of the HDMA MOTIFS page'
    )
    parser.add_argument(
        '--output',
        type=str,
        help='Output TSV file path (default: hdma_motif_mapping.tsv)'
    )
    parser.add_argument(
        '--columns',
        nargs='+',
        default=['pattern', 'motif_name_safe', 'annotation_broad', 'best_match'],
        help='Columns to include in output (default: pattern motif_name_safe annotation_broad best_match)'
    )
    
    args = parser.parse_args()
    
    try:
        # Scrape the data (DataFrame when pandas is installed, dict otherwise)
        result = scrape_hdma_motifs(args.url)
        
        # Output path is shared by both branches below.
        output_path = Path(args.output) if args.output else Path('hdma_motif_mapping.tsv')
        
        # Handle both DataFrame and dict formats
        if HAS_PANDAS and isinstance(result, pd.DataFrame):
            df = result
            print(f"Extracted {len(df)} motifs")
            print(f"Columns: {list(df.columns)}")
            
            # Select requested columns (if they exist)
            available_columns = [col for col in args.columns if col in df.columns]
            if available_columns:
                df_output = df[available_columns].copy()
            else:
                print("Warning: None of the requested columns found. Using all columns.", file=sys.stderr)
                df_output = df.copy()
            
            # Rename motif_name_safe to motif_name in the output
            df_output = df_output.rename(columns={'motif_name_safe': 'motif_name'})
            
            # Save to TSV
            df_output.to_csv(output_path, sep='\t', index=False)
            print(f"Saved {len(df_output)} motifs to {output_path}")
            
            # Print sample
            print("\nSample data:")
            print(df_output.head(10).to_string())
        else:
            # Fallback: plain csv writer over {'headers': ..., 'rows': ...}
            headers = result['headers']
            rows = result['rows']
            
            print(f"Extracted {len(rows)} motifs")
            print(f"Columns: {headers}")
            
            # Find column indices for requested columns
            col_indices = [headers.index(col) for col in args.columns if col in headers]
            output_headers = [headers[i] for i in col_indices]
            
            if not output_headers:
                print("Warning: None of the requested columns found. Using all columns.", file=sys.stderr)
                col_indices = list(range(len(headers)))
                output_headers = headers
            
            # Hoist the bound check out of the loops; -1 also guards the
            # degenerate case of no columns at all (max([]) would raise).
            max_idx = max(col_indices) if col_indices else -1
            
            # Write TSV, skipping rows too short to hold every selected column
            with open(output_path, 'w', newline='') as f:
                writer = csv.writer(f, delimiter='\t')
                writer.writerow(output_headers)
                for row in rows:
                    if len(row) > max_idx:
                        writer.writerow([row[idx] for idx in col_indices])
            
            print(f"Saved {len(rows)} motifs to {output_path}")
            
            # Print sample
            print("\nSample data (first 10 rows):")
            for row in rows[:10]:
                if len(row) > max_idx:
                    print("\t".join([str(row[idx]) for idx in col_indices]))
        
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        sys.exit(1)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()

