# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

from typing import Optional, Union

import numpy as np
import pandas as pd

from nsys_recipe.lib import nvtx, overlap


class RangeColumnUnifier:
    """Unify range columns in a dataframe.

    While the context is active, the dataframe's time columns are renamed to
    "start" and "end" so range operations can be written against fixed
    column names. For a point-in-time dataframe (start column == end column)
    an "end" column is added as a copy of "start". On exit, the original
    column names are restored. Any pre-existing "start"/"end" columns that are
    not the time columns themselves are set aside on entry and put back on
    exit. The dataframe is modified in place.

    Parameters
    ----------
    df : dataframe
        Dataframe to operate on.
    start_column : str, optional
        Name of the column containing the start time. If not provided, the
        name will be deduced.
    end_column : str, optional
        Name of the column containing the end time. If not provided, the
        name will be deduced.

    Raises
    ------
    ValueError
        If the resolved start or end column is not present in ``df``.
    NotImplementedError
        Propagated from ``get_time_cols`` when a column name must be deduced
        and the dataframe has no recognized time column.

    Usage
    -----
    with RangeColumnUnifier(df, start_column, end_column) as unifier:
        # Perform range operations on the dataframe.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        start_column: Optional[str] = None,
        end_column: Optional[str] = None,
    ):
        self._df = df
        # Snapshot of df taken on __enter__, before any renaming; used by
        # filter_by_projected_nvtx when df itself is the CUDA dataframe.
        self._original_df: Optional[pd.DataFrame] = None

        self._start_column, self._end_column = get_time_cols(
            df, start_column, end_column
        )
        # Unrelated "start"/"end" columns set aside while the context is active.
        self._temp_start: Optional[pd.Series] = None
        self._temp_end: Optional[pd.Series] = None

        # Report the resolved names; the raw arguments may be None when the
        # names were deduced.
        if self._start_column not in df:
            raise ValueError(
                f"Column '{self._start_column}' not found in dataframe."
            )
        if self._end_column not in df:
            raise ValueError(f"Column '{self._end_column}' not found in dataframe.")

    def __enter__(self):
        self._original_df = self._df.copy()

        self.save_original_cols()
        self.unify_cols()
        return self

    def __exit__(self, *args):
        # Returning None lets any exception from the with-body propagate.
        self.restore_original_cols()

    def save_original_cols(self) -> None:
        """Pop "start"/"end" columns that would collide with the unified names."""
        if "start" in self._df.columns and self._start_column != "start":
            self._temp_start = self._df.pop("start")
        if "end" in self._df.columns and self._end_column != "end":
            self._temp_end = self._df.pop("end")

    def unify_cols(self) -> None:
        """Rename the time columns to "start"/"end" in place."""
        # Time range.
        if self._start_column != self._end_column:
            self._df.rename(
                columns={self._start_column: "start", self._end_column: "end"},
                inplace=True,
            )
        # Point in time.
        else:
            self._df.rename(columns={self._start_column: "start"}, inplace=True)
            self._df["end"] = self._df["start"]

    def restore_original_cols(self) -> None:
        """Undo unify_cols and reinstate any columns set aside on entry."""
        # For a point in time, "end" was synthesized as a copy of "start", so
        # it is dropped rather than renamed.
        if self._start_column == self._end_column:
            self._df.drop(columns=["end"], inplace=True)
        else:
            self._df.rename(columns={"end": self._end_column}, inplace=True)

        self._df.rename(columns={"start": self._start_column}, inplace=True)

        # Rows may have been dropped while the context was active; select the
        # surviving rows by label so the saved values line up.
        if self._temp_start is not None:
            self._df["start"] = self._temp_start.loc[self._df.index]

        if self._temp_end is not None:
            self._df["end"] = self._temp_end.loc[self._df.index]

    def filter_by_time_range(
        self,
        start_time: Optional[int],
        end_time: Optional[int],
        strict_start: Optional[int] = None,
        strict_end: Optional[int] = None,
    ) -> None:
        """Filter the dataframe in place to retain only events that overlap
        the given range.

        An event is kept when it starts at or before ``end_time`` and ends at
        or after ``start_time``. This means events that only partially
        overlap the range, or span all of it, are kept too. A bound that is
        None is not applied.

        Parameters
        ----------
        start_time : int or None
            Start time of the desired range.
        end_time : int or None
            End time of the desired range.
        strict_start : int, optional
            If specified, discard all events that start before strict_start.
        strict_end : int, optional
            If specified, discard all events that end after strict_end.
        """
        if self._df is None or self._df.empty:
            return

        mask = pd.Series(True, index=self._df.index)
        if end_time is not None:
            mask &= self._df["start"] <= end_time
        if strict_end is not None:
            mask &= self._df["end"] <= strict_end
        if start_time is not None:
            mask &= self._df["end"] >= start_time
        if strict_start is not None:
            mask &= self._df["start"] >= strict_start

        self._df.drop(self._df[~mask].index, inplace=True)

    def apply_time_offset(self, session_offset: int) -> None:
        """Synchronize session start times.

        Parameters
        ----------
        session_offset : int
            Offset added in place to both the "start" and "end" columns.
        """
        if self._df is None or self._df.empty:
            return

        self._df.loc[:, "start"] += session_offset
        self._df.loc[:, "end"] += session_offset

    def _get_filtered_nvtx(
        self,
        nvtx_df: pd.DataFrame,
        range_name: str,
        domain_name: Optional[str],
        index: Optional[int],
    ) -> tuple[Optional[pd.DataFrame], Optional[str]]:
        """Select the NVTX rows matching a range name, domain, and index.

        ``domain_name`` None restricts to the default domain, "*" accepts any
        domain, and any other value restricts to that named domain.
        ``index`` (if given) picks one matching row by position.

        Returns
        -------
        (dataframe or None, str or None)
            The matching rows and None, or None and an error message.
        """
        filtered_nvtx_df = nvtx_df[nvtx_df["text"] == range_name]
        if filtered_nvtx_df.empty:
            return None, f"Range '{range_name}' not found."

        if domain_name is None:
            filtered_nvtx_df = nvtx.filter_by_domain_id(
                filtered_nvtx_df, nvtx.DEFAULT_DOMAIN_ID
            )

            if filtered_nvtx_df.empty:
                return None, f"Range '{range_name}' not found in the default domain."
        elif domain_name != "*":
            # The domain lookup runs on the full nvtx_df; the result is then
            # intersected with the rows already matched by name.
            domain_df = nvtx.filter_by_domain_name(nvtx_df, domain_name)
            filtered_nvtx_df = filtered_nvtx_df[
                filtered_nvtx_df.index.isin(domain_df.index)
            ]

            if filtered_nvtx_df.empty:
                return (
                    None,
                    f"Range '{range_name}' not found in domain '{domain_name}'.",
                )

        if index is not None:
            row_count = len(filtered_nvtx_df)
            if index >= row_count:
                return None, f"Index {index} exceeds the number of rows ({row_count})."
            # Negative indices count from the end (iloc semantics); reject
            # those that would fall off the front instead of raising.
            if index < -row_count:
                return None, f"Index {index} is out of range for {row_count} rows."

            filtered_nvtx_df = filtered_nvtx_df.iloc[[index]]

        return filtered_nvtx_df, None

    def _filter_by_overlap(self, filtered_df: pd.DataFrame) -> None:
        """Drop rows of self._df whose range does not overlap any range in
        ``filtered_df``, as grouped by ``overlap.group_overlapping_ranges``."""
        type_df1 = self._df[["start", "end"]].assign(type="df1")
        type_df2 = filtered_df[["start", "end"]].assign(type="df2")

        # reset_index removes label clashes between the two frames. After
        # this, the first len(self._df) rows of all_df are the rows of
        # self._df in their original order.
        all_df = pd.concat([type_df1, type_df2]).reset_index(drop=True)
        all_df["group"] = overlap.group_overlapping_ranges(all_df)

        group_df1 = all_df[all_df["type"] == "df1"]
        group_df2 = all_df[all_df["type"] == "df2"]

        # group_df1 carries the reset labels (0..n-1), not self._df's labels.
        # Apply the mask positionally so it lines up with self._df even when
        # its index is not a default RangeIndex, e.g. after
        # filter_by_time_range has dropped rows.
        keep = group_df1["group"].isin(group_df2["group"]).to_numpy()
        self._df.drop(self._df.index[~keep], inplace=True)

    def filter_by_nvtx(
        self,
        nvtx_df: pd.DataFrame,
        range_name: str,
        domain_name: Optional[str] = None,
        index: Optional[int] = None,
    ) -> Optional[str]:
        """Filter the dataframes based on the matching NVTX ranges.

        Parameters
        ----------
        nvtx_df : dataframe
            Dataframe containing the NVTX ranges.
        range_name : str
            Name of the NVTX range to filter by.
        domain_name : str
            Name of the NVTX domain to filter by. None means the default
            domain, and "*" means any domain.
        index : int
            Index of the NVTX range to filter by.

        Returns
        -------
        err_msg : str or None
            Error message if there was an error, or None otherwise.
        """
        if self._df is None or self._df.empty:
            return None

        filtered_nvtx_df, err_msg = self._get_filtered_nvtx(
            nvtx_df, range_name, domain_name, index
        )

        if err_msg is not None:
            return err_msg

        assert filtered_nvtx_df is not None
        self._filter_by_overlap(filtered_nvtx_df)
        return None

    def filter_by_projected_nvtx(
        self,
        nvtx_df: pd.DataFrame,
        cuda_df: pd.DataFrame,
        range_name: str,
        domain_name: Optional[str] = None,
        index: Optional[int] = None,
    ) -> Optional[str]:
        """Filter the dataframes based on the matching projected NVTX ranges.

        Parameters
        ----------
        nvtx_df : dataframe
            Dataframe containing the NVTX ranges.
        cuda_df : dataframe
            Dataframe containing the CUDA ranges used to project the NVTX ranges.
        range_name : str
            Name of the NVTX range to filter by.
        domain_name : str
            Name of the NVTX domain to filter by. None means the default
            domain, and "*" means any domain.
        index : int
            Index of the NVTX range to filter by.

        Returns
        -------
        err_msg : str or None
            Error message if there was an error, or None otherwise.
        """
        if self._df is None or self._df.empty:
            return None

        filtered_nvtx_df, err_msg = self._get_filtered_nvtx(
            nvtx_df, range_name, domain_name, index
        )

        if err_msg is not None:
            return err_msg

        assert filtered_nvtx_df is not None

        # The projection functions need both the "start/end" and
        # "gpu_start/gpu_end" columns. If self._df is a CUDA dataframe that
        # should be used for the projection, we need to restore the original
        # range columns that were renamed.
        if self._df is cuda_df:
            assert self._original_df is not None
            cuda_df = self._original_df

        proj_nvtx_df = nvtx.project_nvtx_onto_gpu(filtered_nvtx_df, cuda_df)
        # NOTE(review): an empty projection leaves self._df unfiltered rather
        # than emptying it. Confirm this is intended.
        if proj_nvtx_df.empty:
            return None

        self._filter_by_overlap(proj_nvtx_df)
        return None


def get_time_cols(
    df: pd.DataFrame,
    start_column: Optional[str] = None,
    end_column: Optional[str] = None,
) -> tuple[str, str]:
    """Get the start and end time column names.

    Parameters
    ----------
    df : dataframe
        Dataframe to extract time columns from.
    start_column : str, optional
        Name of the column containing the start time. If not provided, the
        name will be deduced.
    end_column : str, optional
        Name of the column containing the end time. If not provided, the
        name will be deduced.

    Returns
    -------
    time_columns : tuple
        Tuple containing the start and end time column names. For
        point-in-time data, both names are the same column.

    Raises
    ------
    NotImplementedError
        If a name has to be deduced but the dataframe has none of the
        recognized time columns ("start", "timestamp", "rawTimestamp").
    """
    # Both names were given explicitly. Nothing needs deducing, so the
    # dataframe does not have to follow one of the known layouts.
    if start_column and end_column:
        return start_column, end_column

    if "start" in df.columns:
        if "end" in df.columns:
            # Time range.
            default_start_column, default_end_column = ("start", "end")
        else:
            # Point in time.
            default_start_column, default_end_column = ("start", "start")
    elif "timestamp" in df.columns:
        # Point in time.
        default_start_column, default_end_column = ("timestamp", "timestamp")
    elif "rawTimestamp" in df.columns:
        # Point in time.
        default_start_column, default_end_column = ("rawTimestamp", "rawTimestamp")
    else:
        raise NotImplementedError("Unknown time column format.")

    start_column = start_column or default_start_column
    end_column = end_column or default_end_column

    return start_column, end_column


def replace_id_with_value(
    main_df: pd.DataFrame,
    str_df: pd.DataFrame,
    id_column: str,
    value_col_name: Optional[str] = None,
) -> pd.DataFrame:
    """Swap the IDs in 'id_column' of 'main_df' for their string values.

    Parameters
    ----------
    main_df : dataframe
        Dataframe containing 'id_column'. It is not modified.
    str_df : dataframe
        'StringId'-style lookup table with an "id" column and a "value"
        column.
    id_column : str
        Column of 'main_df' holding the IDs to resolve.
    value_col_name : str, optional
        Name given to the resolved string column. Defaults to 'id_column'.

    Returns
    -------
    dataframe
        A new dataframe in which the ID column is replaced by the looked-up
        values, placed after the columns of 'main_df'. IDs without a match
        become missing values (left merge). The index comes from
        ``DataFrame.merge``, so 'main_df''s index is not preserved.
    """
    output_name = value_col_name or id_column
    lookup_df = str_df.rename(columns={"id": id_column})

    return (
        main_df.merge(lookup_df, on=id_column, how="left")
        .drop(columns=[id_column])
        .rename(columns={"value": output_name})
    )


def add_cols_from_global_pid(df: pd.DataFrame) -> None:
    """Add a 'pid' column decoded from the packed 'globalPid' column.

    Bit layout: <Hardware ID:8><VM ID:8><Process ID:24><ThreadID:24>

    An existing 'pid' column is left untouched. ``df`` is modified in place.
    """
    packed_pid = df["globalPid"]
    if "pid" in df:
        return

    # The shift amount uses the column's own dtype. Pandas extension dtypes
    # are possible here, so the exact type can't be inferred statically.
    pid_shift = np.array(24, dtype=packed_pid.dtype)  # type: ignore[arg-type]
    df["pid"] = (packed_pid >> pid_shift) & 0x00FFFFFF


def add_cols_from_global_tid(df: pd.DataFrame) -> None:
    """Add 'pid' and 'tid' columns decoded from the packed 'globalTid' column.

    Bit layout: <Hardware ID:8><VM ID:8><Process ID:24><ThreadID:24>

    Columns that already exist are left untouched. ``df`` is modified in
    place.
    """
    packed_tid = df["globalTid"]
    low_24_bits = 0x00FFFFFF

    if "pid" not in df:
        # The shift amount uses the column's own dtype. Pandas extension
        # dtypes are possible here, so the exact type can't be inferred
        # statically.
        pid_shift = np.array(24, dtype=packed_tid.dtype)  # type: ignore[arg-type]
        df["pid"] = (packed_tid >> pid_shift) & low_24_bits
    if "tid" not in df:
        df["tid"] = packed_tid & low_24_bits


def add_cols_from_type_id(df: pd.DataFrame) -> None:
    """Add a 'gpuId' column taken from the low byte of the 'typeId' column.

    Bit layout: <Hardware ID:8><VM ID:8><Source ID:16><Event tag:24><GPU ID:8>

    An existing 'gpuId' column is left untouched. ``df`` is modified in place.
    """
    packed_type = df["typeId"]
    if "gpuId" in df:
        return

    df["gpuId"] = packed_type & 0xFF


def decompose_bit_fields(df: pd.DataFrame) -> None:
    """Decode packed ID columns of ``df`` into separate columns, in place.

    The first packed column found, in the order 'globalPid', 'globalTid',
    'typeId', is decoded. Only that one is decoded. For example, a dataframe
    holding both 'globalPid' and 'globalTid' only gets 'pid' from
    'globalPid'.
    """
    decoders = (
        ("globalPid", add_cols_from_global_pid),
        ("globalTid", add_cols_from_global_tid),
        ("typeId", add_cols_from_type_id),
    )
    for packed_column, decode in decoders:
        if packed_column in df:
            decode(df)
            break


def filter_by_pattern(
    df: pd.DataFrame, pattern: Union[str, list[str]], column: str
) -> pd.DataFrame:
    """Filter the dataframe based on the matching regex pattern.

    Parameters
    ----------
    df : DataFrame
        Dataframe to filter. It is not modified.
    pattern : str or list of str
        Regex pattern or a list of patterns to use for filtering. A list is
        treated as an alternation: a row matches if any pattern matches.
    column : str
        Name of the column to check for matches. The whole value must match
        (``str.fullmatch``).

    Returns
    -------
    DataFrame
        The matching rows, re-indexed from 0. Rows whose value in ``column``
        is missing never match.
    """
    if isinstance(pattern, list):
        pattern = "|".join(pattern)

    # Without na=False, missing values yield NA in the mask, and boolean
    # indexing with NA raises ValueError.
    matches = df[column].str.fullmatch(pattern, na=False)
    return df[matches].reset_index(drop=True)
