#!/usr/bin/env python3
"""
Simple script to count the motifs (patterns) in a modisco h5 file, i.e. the
entries under its pos_patterns and neg_patterns groups.

Usage:
    python count_motifs_h5.py [h5_file_path]
    
If no path is provided, the default path from config.sh is used:
    /oak/stanford/groups/akundaje/airanman/refs/HDMA/motif_compendium.modisco_object.h5
"""

import os
import re
import sys

import h5py

# Fallback h5 path used when no path is given on the command line
# (the same default the module docstring attributes to config.sh).
DEFAULT_H5_PATH = (
    "/oak/stanford/groups/akundaje/airanman/refs/HDMA/"
    "motif_compendium.modisco_object.h5"
)


def _natural_sort_key(name):
    """Return a sort key that orders embedded integers numerically.

    ``re.split`` with a capturing group alternates text and digit runs, so
    every odd position is a digit run; converting those to ``int`` makes
    ``pattern_2`` sort before ``pattern_10`` while keeping element types
    aligned position-by-position for comparison.
    """
    return [
        int(part) if i % 2 else part
        for i, part in enumerate(re.split(r"(\d+)", name))
    ]


def _collect_motif_names(h5_root, groups=("pos_patterns", "neg_patterns")):
    """Collect group-qualified motif names from the pattern groups of a file.

    Parameters
    ----------
    h5_root : mapping-like
        An open ``h5py.File`` (or anything supporting ``in``, ``[]`` and
        ``.keys()`` the same way).
    groups : tuple of str
        Names of the top-level groups whose children are motif patterns.

    Returns
    -------
    dict
        ``{group: ["<group>.<child>", ...]}`` for every group present in
        *h5_root*, children in natural order. Absent groups are omitted.
    """
    found = {}
    for group in groups:
        if group in h5_root:
            children = sorted(h5_root[group].keys(), key=_natural_sort_key)
            # Child names are only unique within their own group (modisco
            # reuses pattern_0, pattern_1, ... in both pos and neg), so
            # qualify them to keep distinct motifs from colliding.
            found[group] = [f"{group}.{child}" for child in children]
    return found


def count_motifs_in_h5(h5_path):
    """
    Count motifs in a modisco h5 file.

    Motifs are the children of the ``pos_patterns`` and ``neg_patterns``
    groups. They are counted per group and reported with group-qualified
    names (e.g. ``pos_patterns.pattern_0``), so a positive and a negative
    pattern that share a child name are counted as two motifs, not one.

    Parameters
    ----------
    h5_path : str
        Path to the h5 file

    Returns
    -------
    dict
        ``total_motifs`` (int); ``motif_names`` (list of str, positive
        patterns first, then negative, each in natural order);
        ``num_pos_patterns`` and ``num_neg_patterns`` (int).

    Raises
    ------
    FileNotFoundError
        If *h5_path* does not exist.
    """
    if not os.path.exists(h5_path):
        raise FileNotFoundError(f"H5 file not found: {h5_path}")

    print(f"Reading h5 file: {h5_path}")

    with h5py.File(h5_path, 'r') as f:
        print("\nExtracting motif names from h5 structure...")
        by_group = _collect_motif_names(f)

    pos_motifs = by_group.get('pos_patterns', [])
    neg_motifs = by_group.get('neg_patterns', [])

    # Report only the groups that actually exist in the file.
    if 'pos_patterns' in by_group:
        print(f"Found {len(pos_motifs)} positive patterns")
    if 'neg_patterns' in by_group:
        print(f"Found {len(neg_motifs)} negative patterns")

    motif_names = pos_motifs + neg_motifs
    return {
        'total_motifs': len(motif_names),
        'motif_names': motif_names,
        'num_pos_patterns': len(pos_motifs),
        'num_neg_patterns': len(neg_motifs),
    }


def _print_motif_listing(names, head=20, tail=10):
    """Print the first *head* names and, if more remain, up to *tail* last ones.

    Numbering is 1-based and matches each name's position in *names*. The
    tail section lists only names not already printed in the head section,
    so short lists are never repeated or numbered with zero/negative indices.
    """
    total = len(names)
    shown = names[:head]
    print(f"\nFirst {len(shown)} motif names:")
    for i, name in enumerate(shown, 1):
        print(f"  {i}. {name}")

    remaining = total - len(shown)
    if remaining <= 0:
        return
    print(f"\n... and {remaining} more")

    # Start the tail after the head so no name is printed twice.
    tail_start = max(len(shown), total - tail)
    if tail_start < total:
        print(f"\nLast {total - tail_start} motif names:")
        for i, name in enumerate(names[tail_start:], tail_start + 1):
            print(f"  {i}. {name}")


def main(argv=None):
    """Count motifs in an h5 file and print a summary.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments without the program name; defaults to
        ``sys.argv[1:]``. The first one, if given, is the h5 path;
        otherwise ``DEFAULT_H5_PATH`` is used.

    Exits with status 1 if the file is missing or cannot be read.
    """
    args = sys.argv[1:] if argv is None else argv
    h5_path = args[0] if args else DEFAULT_H5_PATH

    try:
        result = count_motifs_in_h5(h5_path)
    except FileNotFoundError as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        # Top-level boundary: report any read failure with a traceback
        # and exit non-zero.
        import traceback
        print(f"Error reading h5 file: {e}", file=sys.stderr)
        traceback.print_exc()
        sys.exit(1)

    rule = '=' * 80
    print(f"\n{rule}")
    print("RESULTS:")
    print(rule)
    print(f"Total motifs found: {result['total_motifs']}")

    if result['motif_names']:
        _print_motif_listing(result['motif_names'])
    else:
        print("\nWarning: Could not extract motif names from h5 file structure")
        print("The file structure may be different than expected.")

    print(f"\n{rule}")


if __name__ == "__main__":
    # main() returns None on success (exit status 0) and calls sys.exit(1)
    # itself on failure.
    sys.exit(main())

