#
# Copyright (c) 2020 10X Genomics, Inc. All rights reserved.
#
"""Data loading utilities which have limited dependencies on anything else."""

from __future__ import annotations

import json
import os

import pandas as pd

from cellranger import constants as cr_constants
from cellranger.spatial.pipeline_mode import PipelineMode, Product, SlideType

# String and other constants used in our spatial assay.
# Names of the columns that hold spot coordinates scaled to the low-resolution
# image (filled in by get_lowres_coordinates).
IMAGEX_LOWRES = "pxl_col_in_lowres"
IMAGEY_LOWRES = "pxl_row_in_lowres"

# Column names of tissue_positions.csv, in file order. Also used as the column
# names when reading legacy files that have no header row.
TISSUE_POSITIONS_HEADER = [
    "barcode",
    "in_tissue",
    "array_row",
    "array_col",
    "pxl_row_in_fullres",
    "pxl_col_in_fullres",
]

# pandas dtypes applied to each tissue_positions.csv column when it is read.
TISSUE_POSITIONS_HEADER_TYPES = {
    "barcode": "str",
    "in_tissue": "int32",
    "array_row": "int32",
    "array_col": "int32",
    "pxl_row_in_fullres": "float64",
    "pxl_col_in_fullres": "float64",
}

# Maximum high and low image dimension for display and, sometimes, processing.
# HIRES_MAX_DIM_DICT holds larger hi-res limits for specific pipeline modes;
# presumably HIRES_MAX_DIM_DEFAULT applies to any mode not listed (the lookup
# happens outside this module section -- confirm at call sites).
HIRES_MAX_DIM_DEFAULT = 2000
HIRES_MAX_DIM_DICT = {
    PipelineMode(Product.VISIUM, SlideType.XL): 4000,
    PipelineMode(Product.CYT, SlideType.XL): 4000,
    PipelineMode(Product.CYT, SlideType.VISIUM_HD): 6000,
}
LORES_MAX_DIM = 600

# Maximum dimension of the web summary image. Should be no larger than the
# hi-res maximum dimensions (HIRES_MAX_DIM_DEFAULT / HIRES_MAX_DIM_DICT).
HD_WS_MAX_DIM = 2000

# Allowed values of the `dark_images` MRO parameter.
DARK_IMAGES_NONE = 0
DARK_IMAGES_CHANNELS = 1
DARK_IMAGES_COLORIZED = 2
DISALLOWED_DARK_IMAGES_EXTENSION = [".png"]

# Slide ID prefixes of production slides (compared against the upper-cased ID).
VISIUM_PRODUCTION_SLIDE_PREFIXES = ["V1", "V2", "V3", "V4", "V5"]

# Slide ID prefixes for *any* HD slide, not just production slides.
VISIUM_HD_SLIDE_PREFIXES = ["H1", "SJ", "14072023", "14082023", "26062023", "RD", "UN"]

# Slide IDs that are used generically in CytAssist runs, so slide ID checking
# should be disabled when one of these is seen.
SLIDE_ID_EXCEPTIONS = [
    "H13UA-N778",
    "001744-121",
    "001744-122",
    "001744-123",
    "001744-124",
    "001744-126",
    "001744-127",
    "001744-128",
    "001744-129",
    "001744-210",
    "001744-212",
    "001744-213",
    "H1-RRRDDDD",
    "H1-6C48CQT",
    "H1-MVDT69H",
]

# File-name suffix of the run-info file inside an extracted CytAssist tarball,
# and the prefix of the line in that file that carries the run name.
RUN_INFO_FILE_SUFFIX = "run-info.csv"
RUN_INFO_LINE_PREFIX = "Run Name,"

# Map from slide capture area to the capture area generated by the
# CytAssist machine. The CytAssist machine generates capture areas "A" and "B".
# Capture area "A" is mapped to slide capture area "A1" in all slides.
# Capture area "B" is mapped to slide capture area "D1" in standard slides and
# "B1" in XL slides.
SLIDE_CAPTURE_AREA_TO_CYTASSIST_CAPTURE_AREA_MAP = {"A1": "A", "B1": "B", "D1": "B"}

# A map from CytAssist capture area to the various suffixes the CytAssist
# videos and images use for that capture area.
CAPTURE_AREA_TO_MACHINE_SUFFIX = {"A": ["A", "A1"], "B": ["B", "D1", "B1", "D"]}

# Tags (last `_`-separated field of the file name, extension removed) of TIFFs
# in the CytAssist tarball that should not be picked up.
TAGS_OF_TIFS_TO_NOT_PICK = ["PreSandwichClosing", "SandwichDone"]

# Maximum number of `_` separators a CytAssist TIFF file name is split on
# (used as `maxsplit`), so the name splits into at most 7 fields.
NUMBER_CYTASSIST_TIFF_FIELD_SEPARATORS = 6


# Map from slide capture area seen in the TIFF metadata to the capture area
# used by the pipeline. The CytAssist TIFFs contain capture areas "A1" and "D1"
# for standard SD slides and standard HD slides. However, they contain capture
# areas "A" and "B" for SD XL slides. The pipeline only uses capture areas
# "A1", "B1" and "D1". Hence this map.
CYTASSIST_TIFF_CAPTURE_AREA_TO_PIPELINE_CAPTURE_AREA = {
    "A1": "A1",
    "A": "A1",
    "B1": "B1",
    "B": "B1",
    "D": "D1",
    "D1": "D1",
}


def parse_slide_sample_area_id(slide_sample_area_id):
    """Split a pipeline input such as ``V19L01-006-B1`` into slide ID and area ID.

    The last two characters are taken as the capture area ID. Everything before
    the single separator character preceding them is taken as the slide sample
    ID, e.g. ``"V19L01-006-B1" -> ("V19L01-006", "B1")``. Neither the separator
    nor the area ID is validated here.

    Args:
        slide_sample_area_id: combined slide sample + capture area ID.

    Returns:
        tuple: ``(slide_sample_id, area_id)``.
    """
    area_id = slide_sample_area_id[-2:]
    # Skip the two area characters plus the one separator before them.
    slide_sample_id = slide_sample_area_id[:-3]
    return slide_sample_id, area_id


def is_production_slide(slide_sample_area_id: str) -> bool:
    """Check whether a slide sample area ID belongs to a production slide.

    The ID is upper-cased and compared against VISIUM_PRODUCTION_SLIDE_PREFIXES.

    Args:
        slide_sample_area_id (str): slide sample area ID

    Returns:
        bool: True if the ID starts with any production slide prefix
    """
    normalized_id = slide_sample_area_id.upper()
    # str.startswith accepts a tuple and returns True if any prefix matches.
    return normalized_id.startswith(tuple(VISIUM_PRODUCTION_SLIDE_PREFIXES))


def is_hd_slide(slide_sample_area_id: str) -> bool:
    """Check whether a slide sample area ID belongs to an HD slide.

    The ID is upper-cased and compared against VISIUM_HD_SLIDE_PREFIXES, which
    covers any HD slide, not only production ones.

    Args:
        slide_sample_area_id (str): slide sample area ID

    Returns:
        bool: True if the ID starts with any HD slide prefix
    """
    normalized_id = slide_sample_area_id.upper()
    # str.startswith accepts a tuple and returns True if any prefix matches.
    return normalized_id.startswith(tuple(VISIUM_HD_SLIDE_PREFIXES))


def get_cytassist_capture_area(slide_sample_area_id: str) -> str:
    """Translate a slide sample area ID into the CytAssist machine capture area.

    The area ID (last two characters) is upper-cased and looked up in
    SLIDE_CAPTURE_AREA_TO_CYTASSIST_CAPTURE_AREA_MAP.

    Args:
        slide_sample_area_id (str): slide sample area ID

    Raises:
        ValueError: if the inferred area ID is not one of A1, B1 or D1

    Returns:
        str: Capture ID generated by the CytAssist machine - either "A" or "B"
    """
    _, area_id = parse_slide_sample_area_id(slide_sample_area_id)
    capture_area = SLIDE_CAPTURE_AREA_TO_CYTASSIST_CAPTURE_AREA_MAP.get(area_id.upper())
    if not capture_area:
        raise ValueError(
            "Invalid Capture ID. "
            f"Slide sample area ID input: {slide_sample_area_id}, capture area ID inferred: {area_id}."
            " Valid slide capture area IDs are A1, B1 or D1"
        )
    return capture_area


def _read_cytassist_run_name(base_folder, file_names):
    """Return the run name from the single run-info file in ``file_names``, or None.

    Returns None when there is not exactly one ``*run-info.csv`` file or it has
    no ``Run Name,`` line. If several such lines exist, the last one wins.
    """
    run_info_files = [name for name in file_names if name.endswith(RUN_INFO_FILE_SUFFIX)]
    if len(run_info_files) != 1:
        return None

    run_name = None
    # The run info is an invalid CSV file, so scan it line by line.
    with open(os.path.join(base_folder, run_info_files[0])) as f:
        for line in f:
            if line.startswith(RUN_INFO_LINE_PREFIX):
                # Lines from file iteration keep their trailing newline; without
                # stripping it, the run name can never be found in a file name.
                run_name = line.removeprefix(RUN_INFO_LINE_PREFIX).strip()
    return run_name


def _select_tifs_by_trailing_fields(base_folder, tif_names, machine_suffixes, maxsplit=-1):
    """Select TIFFs whose last two `_`-separated fields are a wanted area and tag.

    Each name is split on ``_`` at most ``maxsplit`` times (-1 means no limit).
    The second-to-last field must be one of ``machine_suffixes``, and the last
    field (extension removed) must not be in TAGS_OF_TIFS_TO_NOT_PICK. Names with
    fewer than two fields are skipped rather than raising IndexError.
    """
    selected = []
    for name in tif_names:
        fields = name.split("_", maxsplit=maxsplit)
        if len(fields) < 2:
            continue
        area_field, tag_field = fields[-2], fields[-1]
        tag = tag_field.rsplit(".", maxsplit=1)[0]
        if area_field in machine_suffixes and tag not in TAGS_OF_TIFS_TO_NOT_PICK:
            selected.append(os.path.join(base_folder, name))
    return selected


def get_cytassist_images_from_extracted_tgz_folder(
    base_folder: str | bytes | os.PathLike, capture_area: str
) -> list[str | bytes | os.PathLike]:
    """Gets CytAssist images corresponding to a capture area from extracted tgz folder.

    Three strategies are tried in order, and the first that finds any image wins:

    1. If a run-info file gives the run name, split each ``.tif`` name on the
       run name and parse the `_`-separated fields that follow it.
    2. Split each name on at most NUMBER_CYTASSIST_TIFF_FIELD_SEPARATORS `_`
       separators and use the last two fields. This allows `_` in the sample
       name but not in the run name.
    3. Split on every `_` and use the last two fields. This allows `_` in the
       run name but not in the sample name.

    Args:
        base_folder (str | bytes | os.PathLike): extracted tgz folder
        capture_area (str): CytAssist capture area ("A" or "B")

    Returns:
        list[str | bytes | os.PathLike]: paths of the matching TIFF files
        (empty if none match)
    """
    file_names = os.listdir(base_folder)
    tif_names = [name for name in file_names if name.endswith(".tif")]
    machine_suffixes = CAPTURE_AREA_TO_MACHINE_SUFFIX.get(capture_area, [])
    run_name = _read_cytassist_run_name(base_folder, file_names)

    possible_cytassist_images = []
    # If run_name is found use it to extract the TIF file.
    if run_name:
        possible_cytassist_images = [
            os.path.join(base_folder, name)
            for name in tif_names
            if len(after_split := name.split(run_name)) == 2
            and len(fields := after_split[1].split("_")) > 3
            and fields[2] in machine_suffixes
            and fields[3].rsplit(".", maxsplit=1)[0] not in TAGS_OF_TIFS_TO_NOT_PICK
        ]

    # If run_name could not be found or run_name based processing failed, allow
    # the sample name to have `_` but no underscore in the run name.
    if not possible_cytassist_images:
        possible_cytassist_images = _select_tifs_by_trailing_fields(
            base_folder,
            tif_names,
            machine_suffixes,
            maxsplit=NUMBER_CYTASSIST_TIFF_FIELD_SEPARATORS,
        )

    # If still nothing, allow the run name to have `_` but no underscore in the
    # sample name.
    if not possible_cytassist_images:
        possible_cytassist_images = _select_tifs_by_trailing_fields(
            base_folder, tif_names, machine_suffixes
        )

    return possible_cytassist_images


def get_all_images_from_tgz_folder(
    base_folder: str | bytes | os.PathLike,
) -> list[str | bytes | os.PathLike]:
    """Get open close images in the cytassist tarball."""
    return [os.path.join(base_folder, x) for x in os.listdir(base_folder) if x.endswith(".tif")]


def get_galfile_path(barcode_whitelist: str) -> str:
    """Build the path of the ``.gal`` file that belongs to a barcode whitelist.

    The file is ``<barcode_whitelist>.gal`` inside
    ``cr_constants.BARCODE_WHITELIST_PATH``. Existence is not checked.
    """
    galfile_name = barcode_whitelist + ".gal"
    return os.path.join(cr_constants.BARCODE_WHITELIST_PATH, galfile_name)


def read_from_json(filename):
    """Parse and return the JSON content of ``filename``."""
    with open(filename) as handle:
        return json.load(handle)


def get_scalefactors(scalefactors_fn: str) -> dict[str, float]:
    """Load the scale factors JSON file.

    Args:
        scalefactors_fn: path to the scale factors JSON file (e.g. one that
            contains ``tissue_lowres_scalef``, as read by get_lowres_coordinates).

    Returns:
        dict[str, float]: the parsed JSON object, returned unvalidated.
    """
    with open(scalefactors_fn) as handle:
        scalefactors = json.load(handle)
    return scalefactors


def get_lowres_coordinates(tissue_positions_csv: str, scalefactors_json: str) -> pd.DataFrame:
    """Return the tissue positions with extra lowres-image pixel coordinate columns.

    The full-resolution row/column pixel coordinates are multiplied by the
    ``tissue_lowres_scalef`` scale factor. The results are stored in the
    IMAGEY_LOWRES (row) and IMAGEX_LOWRES (column) columns.

    Args:
        tissue_positions_csv (str): Path to the tissue_positions.csv
        scalefactors_json (str): Path to the scalefactors_json.json

    Returns:
        pd.DataFrame: tissue positions indexed by barcode, plus the two lowres
        coordinate columns.
    """
    coords = read_tissue_positions_csv(tissue_positions_csv)
    lowres_scalef = get_scalefactors(scalefactors_json)["tissue_lowres_scalef"]

    # Row column first, then column column, which sets the new columns' order.
    for lowres_col, fullres_col in (
        (IMAGEY_LOWRES, "pxl_row_in_fullres"),
        (IMAGEX_LOWRES, "pxl_col_in_fullres"),
    ):
        coords[lowres_col] = coords[fullres_col] * lowres_scalef
    return coords


def estimate_mem_gb_pandas_csv(csv_filename: str | None) -> float:
    """Memory required to load the CSV such as tissue positions, filtered barcodes using pandas.

    If the input file foes not exists, returns 0

    Args:
        csv_filename (str | None): Filename

    Returns:
        float: memory in GB
    """
    if csv_filename is None or (not os.path.exists(csv_filename)):
        return 0.0
    # Empirically estimated memory by loading the csv file and checking
    # the RSS used
    mem_gb_per_gb_on_disk = 4.1
    file_size_gb = os.path.getsize(csv_filename) / (1024**3)
    return mem_gb_per_gb_on_disk * file_size_gb


def read_tissue_positions_csv(tissue_positions_fn) -> pd.DataFrame:
    """Read a tissue positions CSV into a DataFrame indexed by ASCII-encoded barcode.

    Both the headered format and the legacy header-less format are accepted. A
    first line containing any digit is treated as data, not as a header. In that
    case TISSUE_POSITIONS_HEADER supplies the column names.

    Args:
        tissue_positions_fn (str): Filename

    Returns:
        pd.DataFrame: CSV contents typed per TISSUE_POSITIONS_HEADER_TYPES, with
        the barcode column converted to bytes and used as the index.
    """
    # For backwards compatibility, detect whether the file has a header row.
    with open(tissue_positions_fn) as handle:
        leading_line = handle.readline()
    has_header = not any(ch.isdigit() for ch in leading_line)

    header_kwargs = {"header": 0} if has_header else {"names": TISSUE_POSITIONS_HEADER}
    positions = pd.read_csv(
        tissue_positions_fn,
        sep=",",
        dtype=TISSUE_POSITIONS_HEADER_TYPES,
        **header_kwargs,
    )
    positions["barcode"] = positions["barcode"].str.encode(encoding="ascii")
    return positions.set_index("barcode")
