# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import os
from pathlib import Path

import nsysstats
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

from nsys_recipe.lib import export
from nsys_recipe.log import logger

# Export format used by DataReader when the caller's hints do not provide a
# 'format' entry (read_sql_query uses 'sqlite' instead).
DEFAULT_FORMAT = "parquetdir"


class TableNotFoundError(Exception):
    """Raised when a requested table is not present in the export."""

    def __init__(self, table):
        message = f"Table '{table}' does not exist."
        super().__init__(message)


class ColumnNotFoundError(Exception):
    """Raised when a requested column is not present in a table of the export."""

    def __init__(self, column, table):
        message = f"Column '{column}' does not exist in table '{table}'."
        super().__init__(message)


class InvalidExportError(Exception):
    """Raised when an export file fails validation."""

    def __init__(self, filename):
        message = f"Could not validate {filename}."
        super().__init__(message)


class _Loader:
    output_suffix = ""

    def __init__(self, report_path):
        self._report_path = report_path

    def validate_export_time(self, path):
        """Check whether the export file is newer than the report file."""
        return os.path.getctime(self._report_path) < os.path.getctime(path)

    def validate_table(self, path, table, columns=None):
        """Check whether the given table and columns are present in the
        export file.

        Parameters
        ----------
        path : str
            Path to the export file or directory.
        table : str
            Name of the table to validate.
        columns : list of str, optional
            List of columns to validate. If not given, all columns will be validated.

        Returns
        -------
        handle : object
            Object that can be used as input to the 'read_table' function.
        """
        raise NotImplementedError()

    def read_table(self, handle, table, columns=None):
        """Read table from the export file.

        Parameters
        ----------
        handle : object
            Object obtained from the 'validate_table' function.
        table : str
            Name of the table to read.
        columns : list of str, optional
            List of columns to read. If not given, all columns will be read.

        Returns
        -------
        result : dataframe
            Table as a dataframe.
        """
        raise NotImplementedError()

    def list_tables(self, path):
        raise NotImplementedError()

    def get_export_path(self):
        if not self.output_suffix:
            raise NotImplementedError("output_suffix must be set.")

        return str(self._report_path.with_suffix("")) + self.output_suffix


class ParquetLoader(_Loader):
    """Loader for exports laid out as a directory holding one
    '<table>.parquet' file per table."""

    output_suffix = "_pqtdir"
    file_extension = ".parquet"

    def validate_table(self, path, table, columns=None):
        """Check that the table file exists, is newer than the report and
        contains the requested columns.

        Returns the opened 'pq.ParquetFile', to be passed to 'read_table'.
        Raises TableNotFoundError, InvalidExportError or ColumnNotFoundError.
        """
        table_file = Path(path) / f"{table}{self.file_extension}"

        if not table_file.exists():
            raise TableNotFoundError(table)

        if not self.validate_export_time(table_file):
            raise InvalidExportError(table_file)

        parquet_file = pq.ParquetFile(table_file)
        available = parquet_file.schema.names

        # Report the first requested column that is absent from the schema.
        for name in columns or []:
            if name not in available:
                raise ColumnNotFoundError(name, table)

        return parquet_file

    def read_table(self, handle, table, columns=None):
        """Read the requested columns from the handle returned by
        'validate_table' and convert them to a pandas dataframe."""
        arrow_table = handle.read(columns)
        return arrow_table.to_pandas()

    def list_tables(self, path):
        """Return the stems of the '*.parquet' files found in 'path', or an
        empty list if 'path' does not exist."""
        directory = Path(path)

        if directory.exists():
            pattern = f"*{self.file_extension}"
            return [entry.stem for entry in directory.glob(pattern)]

        return []


class ArrowLoader(_Loader):
    """Loader for exports laid out as a directory holding one
    '<table>.arrow' stream file per table."""

    output_suffix = "_arwdir"
    file_extension = ".arrow"

    def validate_table(self, path, table, columns=None):
        """Check that the table file exists, is newer than the report and
        contains the requested columns.

        Returns the opened 'pa.RecordBatchStreamReader', to be passed to
        'read_table'. Raises TableNotFoundError, InvalidExportError or
        ColumnNotFoundError.
        """
        table_file = Path(path) / f"{table}{self.file_extension}"

        if not table_file.exists():
            raise TableNotFoundError(table)

        if not self.validate_export_time(table_file):
            raise InvalidExportError(table_file)

        reader = pa.RecordBatchStreamReader(table_file)
        available = reader.schema.names

        # Report the first requested column that is absent from the schema.
        for name in columns or []:
            if name not in available:
                raise ColumnNotFoundError(name, table)

        return reader

    def read_table(self, handle, table, columns=None):
        """Read the whole stream into a pandas dataframe and, when 'columns'
        is given, return only those columns."""
        frame = handle.read_pandas()
        if columns:
            return frame[columns]
        return frame

    def list_tables(self, path):
        """Return the stems of the '*.arrow' files found in 'path', or an
        empty list if 'path' does not exist."""
        directory = Path(path)

        if directory.exists():
            pattern = f"*{self.file_extension}"
            return [entry.stem for entry in directory.glob(pattern)]

        return []


class SqliteLoader(_Loader):
    """Loader for exports stored as a single SQLite file.

    The opened 'nsysstats.Report' is cached and reused as long as the
    requested path and the file's ctime are unchanged.
    """

    output_suffix = ".sqlite"
    file_extension = ".sqlite"

    def __init__(self, report_path):
        super().__init__(report_path)
        # Cached report object and the ctime of the file it was opened from.
        self._sql_report = None
        self._sqlite_report_ctime = None

    def _get_sql_report(self, path):
        """Return the cached report for 'path', reopening it when there is no
        cached report, the path differs, or the file's ctime has changed."""
        if (
            self._sql_report is None
            or self._sql_report.dbfile != path
            or self._sqlite_report_ctime != os.path.getctime(path)
        ):
            self._sql_report = nsysstats.Report(path)
            self._sqlite_report_ctime = os.path.getctime(path)

        return self._sql_report

    def validate_tables(self, path, tables):
        """Check that the export exists, is newer than the report and
        contains every table in 'tables'.

        Parameters
        ----------
        path : str
            Path to the SQLite export file.
        tables : list of str or None
            Table names to check. None or an empty list skips the table check.

        Returns
        -------
        handle : object
            Report object usable with 'read_table' and 'read_sql_query'.

        Raises
        ------
        InvalidExportError
            If the file is missing or not newer than the report.
        TableNotFoundError
            If one of the tables does not exist.
        """
        if not Path(path).exists() or not self.validate_export_time(path):
            raise InvalidExportError(path)

        sql_report = self._get_sql_report(path)

        tables = tables or []
        for table in tables:
            if not sql_report.table_exists(table):
                raise TableNotFoundError(table)

        return sql_report

    def validate_table(self, path, table, columns=None):
        """Validate a single table and, optionally, its columns.

        Returns the report object to be passed to 'read_table'. Raises
        InvalidExportError, TableNotFoundError or ColumnNotFoundError.
        """
        sql_report = self.validate_tables(path, [table])

        columns = columns or []
        for column in columns:
            if not sql_report.table_col_exists(table, column):
                raise ColumnNotFoundError(column, table)

        return sql_report

    def read_sql_query(self, handle, query):
        """Run 'query' on the handle's database connection and return the
        result as a dataframe."""
        return pd.read_sql(query, handle.dbcon)

    def read_table(self, handle, table, columns=None):
        """Read the given columns (or all columns) of 'table' into a
        dataframe."""
        # NOTE(review): table and column names are interpolated into the SQL
        # text unquoted; callers are expected to have checked them with
        # 'validate_table' first.
        column_query = ",".join(columns) if columns else "*"
        query = f"SELECT {column_query} FROM {table}"
        return self.read_sql_query(handle, query)

    def list_tables(self, path):
        """Return the tables of the SQLite export at 'path', or an empty list
        if the path is None, does not exist or cannot be opened."""
        # 'Path(path) is None' can never be true; check the argument itself
        # and the file's existence, like the directory-based loaders do.
        if path is None or not Path(path).exists():
            return []

        # Best effort: any failure to open or inspect the file is reported
        # as "no tables".
        try:
            return self._get_sql_report(path).tables
        except Exception:
            return []


class ServiceFactory:
    """Create loader instances for a report file and cache one per export
    format."""

    def __init__(self, report_path):
        """Store the report path.

        Raises
        ------
        FileNotFoundError
            If 'report_path' does not exist.
        """
        resolved_path = Path(report_path)
        # 'exists' must be called: the bare bound method is always truthy,
        # which silently disabled this check.
        if not resolved_path.exists():
            raise FileNotFoundError(f"{report_path} does not exist.")

        self._report_path = resolved_path
        # Maps format name -> loader instance created for it.
        self._service_instances = {}

    def _create_service(self, format):
        """Instantiate the loader matching 'format'.

        Raises
        ------
        NotImplementedError
            If 'format' is not one of 'parquetdir', 'arrowdir' or 'sqlite'.
        """
        loader_map = {
            "parquetdir": ParquetLoader,
            "arrowdir": ArrowLoader,
            "sqlite": SqliteLoader,
        }

        if format not in loader_map:
            raise NotImplementedError("Invalid format type.")

        return loader_map[format](self._report_path)

    def get_service(self, format):
        """Return the loader for 'format', creating it on first use."""
        if format not in self._service_instances:
            self._service_instances[format] = self._create_service(format)

        return self._service_instances[format]


class DataReader:
    """The DataReader class provides a high-level interface for exporting and
    reading data from Nsight Systems report files."""

    def __init__(self, report_path):
        self._service_factory = ServiceFactory(report_path)
        self._report_path = Path(report_path)

    def _handle_exceptions(self, e):
        """Log missing-table and missing-column errors; re-raise anything
        else."""
        if isinstance(e, TableNotFoundError):
            logger.error(
                f"{self._report_path}: {e}"
                " Please ensure the table name is correct or re-try with a recent version of Nsight Systems."
            )
        elif isinstance(e, ColumnNotFoundError):
            logger.error(
                f"{self._report_path}: {e}"
                " Please ensure the column name is correct or re-try with a recent version of Nsight Systems."
            )
        else:
            raise e

    def get_export_path(self, service, path):
        """Return the export path for 'service'.

        When 'path' is given, the export keeps its default file name but is
        placed inside the 'path' directory.
        """
        export_path = service.get_export_path()

        if path is not None:
            export_path = str(Path(path) / Path(export_path).name)

        return export_path

    def _check_deprecation_for_recipes(self, result_dict, hints):
        """Return False (and log an error) if one of the tables that recipes
        rely on was read but is empty; return True otherwise.

        The check is skipped when the 'check_deprecation' hint is False or
        when 'result_dict' is None.
        """
        check_deprecation = hints.get("check_deprecation", True)

        if not check_deprecation or result_dict is None:
            return True

        required_tables = (
            "ANALYSIS_DETAILS",
            "TARGET_INFO_SESSION_START_TIME",
            "NIC_ID_MAP",
        )

        for table, df in result_dict.items():
            if not df.empty:
                continue

            if table in required_tables:
                logger.error(
                    f"{self._report_path}: Report is outdated and does not contain '{table}'."
                    " Please generate a new report file using a recent version of Nsight Systems."
                )
                return False

        return True

    def _read_tables_and_report_missing(self, table_column_dict, report_missing, hints):
        """Read known tables and report any missing tables.

        When 'report_missing' is True, the first failure is passed to
        '_handle_exceptions' and (None, None) is returned. Otherwise the
        failing table is recorded in 'missing_dict'.

        Returns
        -------
        result_dict : dict
            Dictionary mapping table names to dataframes.
        missing_dict : dict
            Dictionary mapping missing table names to lists of column names.
        """
        service = self._service_factory.get_service(hints.get("format", DEFAULT_FORMAT))
        export_path = self.get_export_path(service, hints.get("path"))

        result_dict = {}
        missing_dict = {}

        for table, columns in table_column_dict.items():
            try:
                handle = service.validate_table(export_path, table, columns)
                result_dict[table] = service.read_table(handle, table, columns)
            except Exception as e:
                if report_missing:
                    self._handle_exceptions(e)
                    return None, None

                missing_dict[table] = columns

        return result_dict, missing_dict

    def _read_sql_query(self, query, tables, report_missing, hints):
        """Validate 'tables' in the existing export and run 'query' on it.

        Returns the resulting dataframe, or None if validation or the query
        failed. When 'report_missing' is True, the failure is passed to
        '_handle_exceptions' before returning None.
        """
        format_type = hints.get("format", "sqlite")
        service = self._service_factory.get_service(format_type)
        export_path = self.get_export_path(service, hints.get("path"))

        try:
            handle = service.validate_tables(export_path, tables)
            return service.read_sql_query(handle, query)
        except Exception as e:
            if report_missing:
                self._handle_exceptions(e)
            return None

    def export(self, tables, hints=None):
        """Export 'tables' from the report file using the format, path and
        export arguments given in 'hints'. Returns the result of
        'export.export_file'."""
        if hints is None:
            hints = {}

        format_type = hints.get("format", DEFAULT_FORMAT)
        service = self._service_factory.get_service(format_type)
        export_path = self.get_export_path(service, hints.get("path"))
        export_args = hints.get("export_args", None)

        return export.export_file(
            self._report_path, tables, format_type, export_path, export_args
        )

    def read_tables(self, table_column_dict, hints=None):
        """Read tables into dataframes.

        Parameters
        ----------
        table_column_dict : dict
            Dictionary mapping table names to column names to be read.
        hints : dict, optional
            Additional configurations. The supported hints are:
            - 'format' (str): the export file format. Default is 'parquetdir'.
            - 'path' (str): the export file path. Default is in the same
                directory as the report file.
            - 'overwrite' (bool): whether to fresh export even though the
                existing file is valid. Default is False.
            - 'check_deprecation' (bool): whether to check if report file is
                deprecated for recipes. Default is True.
            - 'export_args' (list): a list of arguments to be passed when
                calling `nsys export`.

        Returns
        -------
        result : dict or None
            Dictionary containing the dataframes for each table, or None if
            there was an error reading at least one table.
        """
        if hints is None:
            hints = {}

        overwrite = hints.get("overwrite", False)
        if overwrite:
            # Skip the existing export entirely and re-export every table.
            result_dict, missing_dict = {}, table_column_dict
        else:
            result_dict, missing_dict = self._read_tables_and_report_missing(
                table_column_dict, False, hints
            )

        if missing_dict:
            missing_tables = list(missing_dict.keys())

            if not self.export(missing_tables, hints):
                return None

            remaining_dict, _ = self._read_tables_and_report_missing(
                missing_dict, True, hints
            )

            if remaining_dict is None:
                return None

            result_dict.update(remaining_dict)

        if not self._check_deprecation_for_recipes(result_dict, hints):
            return None

        return result_dict

    def read_table(self, table, columns=None, hints=None):
        """Read a single table into a dataframe.

        Parameters
        ----------
        table : str
            Name of the table to read.
        columns : list of str, optional
            List of columns to read. If not given, all columns will be read.
        hints : dict, optional
            Additional configurations. The supported hints are:
            - 'format' (str): the export file format. Default is 'parquetdir'.
            - 'path' (str): the export file path. Default is in the same
                directory as the report file.
            - 'overwrite' (bool): whether to fresh export even though the
                existing file is valid. Default is False.
            - 'check_deprecation' (bool): whether to check if report file is
                deprecated for recipes. Default is True.
            - 'export_args' (list): a list of arguments to be passed when
                calling `nsys export`.

        Returns
        -------
        result : dataframe or None
            Dataframe containing the table, or None if there was an error.
        """
        result_dict = self.read_tables({table: columns}, hints)

        # 'read_tables' returns None on failure; propagate that instead of
        # failing on a '.get' call on None.
        if result_dict is None:
            return None

        return result_dict.get(table)

    def read_sql_query(self, query, tables=None, hints=None):
        """Read the SQL query into a dataframe.

        Parameters
        ----------
        query : str
            SQL query to execute.
        tables : list of str or str, optional
            If specified, the function will export the tables before executing
            the query and check whether the table names are valid. If no
            tables are provided, all tables will be exported, and no checks
            will be made before executing the query.
        hints : dict, optional
            Additional configurations. The supported hints are:
            - 'format' (str): the export file format. Default is 'sqlite'.
            - 'path' (str): the export file path. Default is in the same
                directory as the report file.
            - 'overwrite' (bool): whether to fresh export even though the
                existing file is valid. Default is False.
            - 'export_args' (list): a list of arguments to be passed when
                calling `nsys export`.

        Returns
        -------
        result : dataframe or None
            Result of the SQL query, or None if there was an error.

        Raises
        ------
        NotImplementedError
            If the 'format' hint is given and is not 'sqlite'.
        """
        # Work on a copy so the caller's dictionary is left untouched.
        hints = dict(hints) if hints is not None else {}
        if hints.setdefault("format", "sqlite") != "sqlite":
            raise NotImplementedError("Invalid format type.")

        if isinstance(tables, str):
            tables = [tables]

        # With 'overwrite', the existing export must not be used: go straight
        # to a fresh export and query its result.
        if not hints.get("overwrite", False):
            df = self._read_sql_query(query, tables, False, hints)
            if df is not None:
                return df

        if not self.export(tables, hints):
            return None

        return self._read_sql_query(query, tables, True, hints)

    def list_tables(self, hints=None):
        """List the available tables in the report file."""
        if hints is None:
            hints = {}

        service = self._service_factory.get_service(hints.get("format", DEFAULT_FORMAT))
        export_path = self.get_export_path(service, hints.get("path"))

        return service.list_tables(export_path)
