#!/usr/bin/env python3
# Copyright (C) 2019 Checkmk GmbH - License: GNU General Public License v2
# This file is part of Checkmk (https://checkmk.com). It is subject to the terms and
# conditions defined in the file COPYING, which is part of this source code package.

# DB2 support requires installation of the IBM Data Server Client:
#  http://www-01.ibm.com/support/docview.wss?uid=swg27016878
# as well as the ibm_db Python DBI driver for DB2:
#  https://pypi.org/pypi/ibm_db

# SQLAnywhere support requires installation of the SAP SQL Anywhere binaries:
# https://help.sap.com/docs/SAP_SQL_Anywhere/a3e900ad39b94d689987e838835f39fe/8157f2236ce21014a9f387041b1c1047.html?locale=en-US&version=17.0
# as well as the sqlanydb Python DB driver for SqlAnywhere:
# https://pypi.org/project/sqlanydb/
"""Check_MK SQL Test"""
import argparse
import logging
import os
import sys
from collections.abc import Callable, Sequence
from typing import Any, NoReturn

import cmk.utils.password_store

# NOTE(review): presumably substitutes password-store references on the command
# line by the stored passwords -- confirm in cmk.utils.password_store.  It runs at
# import time, i.e. before main() calls parse_args().
cmk.utils.password_store.replace_passwords()

LOG = logging.getLogger(__name__)

# Port used by connect() when no --port is given, keyed by the --dbms choice.
DEFAULT_PORTS: dict[str, int] = {
    "postgres": 5432,
    "mysql": 3306,
    "mssql": 1433,
    "oracle": 1521,
    "db2": 50000,
    "sqlanywhere": 2638,
}

# "No levels" range: default of --warning/--critical.  process_result() compares
# the given levels against it to decide whether levels were configured at all.
MP_INF: tuple[float, float] = (float("-inf"), float("+inf"))

#   . parse commandline arguments


def levels(values: str) -> tuple[float, float]:
    """Parse a ``LOWER:UPPER`` range given on the command line.

    An empty side means "unbounded" and becomes -inf (lower) or +inf (upper).
    A value without exactly one colon raises ValueError, which argparse reports
    as an invalid argument.
    """
    low_text, high_text = values.split(":")
    low = MP_INF[0] if not low_text else float(low_text)
    high = MP_INF[1] if not high_text else float(high_text)
    return low, high


def sql_cmd_piece(values: str) -> str:
    """Unescape one piece of the SQL command.

    The two-character sequences ``\\n`` and ``\\;`` are turned into a real
    newline and a plain semicolon, respectively.
    """
    for escaped, literal in ((r"\n", "\n"), (r"\;", ";")):
        values = values.replace(escaped, literal)
    return values


def _build_parser(prog: str) -> argparse.ArgumentParser:
    """Create the argument parser of this check (its description is the module docstring)."""
    parser = argparse.ArgumentParser(
        prog=prog, description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    # flags
    parser.add_argument(
        "-v",
        "--verbose",
        action="count",
        default=0,
        help="""Verbose mode: print SQL statement and levels
                             (for even more output use -vv)""",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="""Debug mode: let Python exceptions come through""",
    )
    parser.add_argument(
        "-m",
        "--metrics",
        nargs="?",
        metavar="METRIC_NAME",
        const="performance_data",
        help="""Add performance data to the output. Store data with metric_name in RRD.""",
    )
    parser.add_argument(
        "-o",
        "--procedure",
        action="store_true",
        help="""Treat the main argument as a procedure instead
                              of an SQL-Statement""",
    )
    parser.add_argument(
        "-i",
        "--input",
        metavar="CSV",
        default=[],
        type=lambda s: s.split(","),
        help="""Comma separated list of values of input variables
                             if required by the procedure""",
    )
    # optional arguments
    parser.add_argument(
        "-d",
        "--dbms",
        default="postgres",
        choices=["postgres", "mysql", "mssql", "oracle", "db2", "sqlanywhere"],
        help='''Name of the database management system.
                             Default is "postgres"''',
    )
    parser.add_argument(
        "-H",
        "--hostname",
        metavar="HOST",
        default="127.0.0.1",
        help='''Hostname or IP-Address where the database lives.
                             Default is "127.0.0.1"''',
    )
    parser.add_argument(
        "-P",
        "--port",
        default=None,
        type=int,
        help="""Port used to connect to the database.
                             Default depends on DBMS""",
    )
    parser.add_argument(
        "-w",
        "--warning",
        metavar="RANGE",
        default=MP_INF,
        type=levels,
        help="""Lower and upper level for the warning state,
                             separated by a colon""",
    )
    parser.add_argument(
        "-c",
        "--critical",
        metavar="RANGE",
        default=MP_INF,
        type=levels,
        help="""Lower and upper level for the critical state,
                             separated by a colon""",
    )
    parser.add_argument(
        "-t",
        "--text",
        default="",
        help="""Additional text prefixed to the output""",
    )

    # required arguments
    parser.add_argument(
        "-n",
        "--name",
        required=True,
        help="""Name of the database on the DBMS""",
    )
    parser.add_argument(
        "-u",
        "--user",
        required=True,
        help="""Username for database access""",
    )
    parser.add_argument(
        "-p",
        "--password",
        required=True,
        help="""Password for database access""",
    )
    parser.add_argument(
        "cmd",
        metavar="SQL-Statement|Procedure",
        type=sql_cmd_piece,
        nargs="+",
        help="""Valid SQL-Statement for the selected database.
                             The statement must return at least a number and a
                             string, plus optional performance data.

                             Alternatively: If the "-o" option is given,
                             treat the argument as a procedure name.

                             The procedure must return one output variable,
                             whose content is evaluated the same way as the
                             output of the SQL-Statement""",
    )
    return parser


def _setup_logging(args: argparse.Namespace) -> None:
    """Configure logging from ``args.verbose`` and dump the parsed arguments at DEBUG level."""
    fmt = "%(message)s"
    if args.verbose > 1:
        fmt = "%(levelname)s: %(lineno)s: " + fmt
        if args.dbms == "mssql":
            # NOTE(review): presumably read by the TDS library underneath pymssql;
            # sends its protocol dump to stdout.
            os.environ["TDSDUMP"] = "stdout"
    # 0 -> WARNING, -v -> INFO, -vv -> DEBUG (never below 0)
    logging.basicConfig(level=max(30 - 10 * args.verbose, 0), format=fmt)

    # V-VERBOSE INFO; credentials are masked
    for key, val in vars(args).items():
        if key in ("user", "password"):
            val = "****"
        LOG.debug("argparse: %s = %r", key, val)


def parse_args(argv: list[str]) -> argparse.Namespace:
    """Parse the command line and set up logging.

    Args:
        argv: full argument vector; ``argv[0]`` (the program path) only
            provides the program name shown in the help output.

    Returns:
        The parsed namespace.  ``cmd`` is joined into a single string
        (pieces separated by a blank).
    """
    parser = _build_parser(str(os.path.basename(argv[0])))
    args = parser.parse_args(argv[1:])
    args.cmd = " ".join(args.cmd)

    _setup_logging(args)
    return args


# .


def bail_out(exit_code: int, output: str) -> NoReturn:
    """Write ``output`` as one line to stdout and terminate with ``exit_code``."""
    line = "%s\n" % output
    stream = sys.stdout
    stream.write(line)
    raise SystemExit(exit_code)


#   . DBMS specific code here!
#
# For every DBMS, define a <dbms>_connect and a <dbms>_execute function and
# register them in the dispatch dicts of the connect() and execute() functions below.
#
def _default_execute(
    cursor: Any, cmd: str, inpt: Sequence[str], procedure: bool
) -> list[tuple[Any, ...]]:
    """Run ``cmd`` on a PEP 249 cursor and return all rows it produced.

    Args:
        cursor: an open DB API cursor.
        cmd: the SQL statement, or the procedure name if ``procedure`` is true.
        inpt: input values for the procedure (not used for a plain statement).
        procedure: call ``cmd`` via ``cursor.callproc`` instead of ``cursor.execute``.

    Returns:
        The result of ``cursor.fetchall()``.
    """
    if procedure:
        LOG.info("SQL Procedure Name: %s", cmd)
        LOG.info("Input Values: %s", inpt)
        # PEP 249: callproc() leaves its argument untouched and *returns* a
        # modified copy with output parameters filled in -- log that copy.
        params_after_call = cursor.callproc(cmd, inpt)
        LOG.debug("inpt after 'callproc' = %r", params_after_call)
    else:
        LOG.info("SQL Statement: %s", cmd)
        cursor.execute(cmd)

    return cursor.fetchall()


def postgres_connect(host: str, port: int, db_name: str, user: str, pwd: str) -> Any:
    """Open a PostgreSQL connection with psycopg2 (imported only when needed)."""
    import psycopg2  # type: ignore[import] # pylint: disable=import-outside-toplevel

    conn_params = {
        "host": host,
        "port": port,
        "database": db_name,
        "user": user,
        "password": pwd,
    }
    return psycopg2.connect(**conn_params)


def postgres_execute(
    cursor: Any, cmd: str, inpt: Sequence[str], procedure: bool
) -> list[tuple[Any, ...]]:
    """Execute ``cmd`` on PostgreSQL; no DBMS specific handling, delegates to _default_execute."""
    return _default_execute(cursor, cmd, inpt, procedure)


def mysql_connect(host: str, port: int, db_name: str, user: str, pwd: str) -> Any:
    """Open a MySQL connection with pymysql (imported only when needed)."""
    import pymysql  # pylint: disable=import-outside-toplevel

    return pymysql.connect(
        user=user,
        passwd=pwd,
        host=host,
        port=port,
        db=db_name,
    )


def mysql_execute(
    cursor: Any, cmd: str, inpt: Sequence[str], procedure: bool
) -> list[tuple[Any, ...]]:
    """Execute ``cmd`` on MySQL; no DBMS specific handling, delegates to _default_execute."""
    return _default_execute(cursor, cmd, inpt, procedure)


def mssql_connect(host: str, port: int, db_name: str, user: str, pwd: str) -> Any:
    """Open an MS SQL Server connection with pymssql (imported only when needed)."""
    import pymssql  # type: ignore[import] # pylint: disable=import-outside-toplevel

    conn_params = {
        "host": host,
        "port": port,
        "database": db_name,
        "user": user,
        "password": pwd,
    }
    return pymssql.connect(**conn_params)


def mssql_execute(
    cursor: Any, cmd: str, _inpt: Sequence[str], procedure: bool
) -> list[tuple[Any, ...]]:
    """Execute ``cmd`` on MS SQL Server and return all fetched rows.

    A procedure is run as an ``EXEC <name> [%s, ...]`` statement rather than
    via ``cursor.callproc``, and its rows are read with ``fetchall``.  The input
    values ``_inpt`` (``-i``) are passed as its positional parameters.  For a
    plain statement they are not used, as in _default_execute.  The leading
    underscore of ``_inpt`` is kept only to leave the signature unchanged.
    """
    params: tuple[str, ...] = ()
    if procedure:
        LOG.info("SQL Procedure Name: %s", cmd)
        LOG.info("Input Values: %s", _inpt)
        params = tuple(_inpt)
        # pymssql uses the "pyformat" paramstyle: one %s placeholder per value;
        # quoting/escaping of the values is left to the driver.
        placeholders = ", ".join(["%s"] * len(params))
        cmd = f"EXEC {cmd} {placeholders}" if params else "EXEC " + cmd
    else:
        LOG.info("SQL Statement: %s", cmd)

    if params:
        cursor.execute(cmd, params)
    else:
        # Without inputs the statement is sent exactly as given, with no
        # parameter argument (unchanged behavior for this case).
        cursor.execute(cmd)

    return cursor.fetchall()


def oracle_connect(host: str, port: int, db_name: str, user: str, pwd: str) -> Any:
    """Connect to an Oracle database with python-oracledb.

    The database is addressed by the Easy Connect string ``host:port/db_name``
    (db_name presumably being the service name).  User and password are passed
    as separate arguments rather than embedded as ``user/pwd@...``, so that
    credentials containing "/" or "@" cannot corrupt the connect string.
    Exits with state 3 via bail_out() if oracledb cannot be imported.
    """
    # Also look for oracledb in the system interpreter's site-packages directory.
    sys.path.append(
        f"/usr/lib/python{sys.version_info.major}.{sys.version_info.minor}/site-packages"
    )
    try:
        import oracledb  # type: ignore[import] # pylint: disable=import-error,import-outside-toplevel
    except ImportError as exc:
        bail_out(3, "%s. Please install it via 'pip install oracledb'." % exc)

    return oracledb.connect(user=user, password=pwd, dsn=f"{host}:{port}/{db_name}")


def oracle_execute(
    cursor: Any, cmd: str, inpt: Sequence[str], procedure: bool
) -> list[tuple[Any, ...]]:
    """Execute ``cmd`` on Oracle and return the resulting rows.

    For a plain statement the rows come from ``cursor.fetchall()``.

    For a procedure, it is called with the input values followed by one string
    OUT variable.  As documented for ``--procedure``, the content of that
    variable is evaluated like the output of a SQL statement.  It is split at
    "," and returned as the only row, e.g. "1,some text,42" ->
    ``[("1", "some text", "42")]``.  There is no result set to fetch after
    ``callproc``.  A NULL OUT value yields an empty list.
    Exits with state 3 via bail_out() if oracledb cannot be imported.
    """
    try:
        import oracledb  # pylint: disable=import-error,import-outside-toplevel
    except ImportError as exc:
        bail_out(3, "%s. Please install it via 'pip install oracledb'." % exc)

    if not procedure:
        LOG.info("SQL Statement: %s", cmd)
        cursor.execute(cmd)
        return cursor.fetchall()

    LOG.info("SQL Procedure Name: %s", cmd)
    LOG.info("Input Values: %s", inpt)

    outvar = cursor.var(oracledb.STRING)
    parameters = [*inpt, outvar]
    cursor.callproc(cmd, parameters)

    LOG.debug("parameters after 'callproc' = %r", parameters)
    LOG.debug("outvar = %r", outvar)

    # Only the OUT variable carries the result; the inputs are plain strings
    # (calling getvalue() on them, as done before, raised AttributeError).
    out_value = outvar.getvalue()
    LOG.debug("outvar.getvalue() = %r", out_value)
    if out_value is None:
        return []

    params_result = str(out_value).split(",")
    LOG.debug("params_result = %r", params_result)
    return [tuple(params_result)]


def db2_connect(host: str, port: int, db_name: str, user: str, pwd: str) -> Any:
    """Connect to DB2 with the IBM data server driver and wrap it as a DBI connection.

    Exits with state 3 via bail_out() if ibm_db / ibm_db_dbi cannot be imported.
    NOTE(review): values are inserted into the connection string unquoted, so a
    password containing ";" would presumably break it -- confirm with ibm_db docs.
    """
    # IBM data server driver
    try:
        import ibm_db  # type: ignore[import] # pylint: disable=import-error,import-outside-toplevel
        import ibm_db_dbi  # type: ignore[import] # pylint: disable=import-error,import-outside-toplevel
    except ImportError as exc:
        bail_out(3, "%s. Please install it via pip." % exc)

    settings = (
        ("DRIVER", "{IBM DB2 ODBC DRIVER}"),
        ("DATABASE", db_name),
        ("HOSTNAME", host),
        ("PORT", port),
        ("PROTOCOL", "TCPIP"),
        ("UID", user),
        ("PWD", pwd),
    )
    cstring = "".join("%s=%s;" % pair for pair in settings)
    return ibm_db_dbi.Connection(ibm_db.connect(cstring, "", ""))


def db2_execute(
    cursor: Any, cmd: str, inpt: Sequence[str], procedure: bool
) -> list[tuple[Any, ...]]:
    """Execute ``cmd`` on DB2; no DBMS specific handling, delegates to _default_execute."""
    return _default_execute(cursor, cmd, inpt, procedure)


def sqlanywhere_connect(host: str, port: int, db_name: str, user: str, pwd: str) -> Any:
    """Connect to SAP SQL Anywhere via sqlanydb, addressing the server as ``host:port``.

    Exits with state 3 via bail_out() if sqlanydb cannot be imported.
    """
    try:
        import sqlanydb  # type: ignore[import] # pylint: disable=import-error,import-outside-toplevel
    except ImportError as exc:
        bail_out(3, "%s. Please install it via 'pip install sqlanydb'." % exc)
    server = "%s:%s" % (host, port)
    return sqlanydb.connect(uid=user, pwd=pwd, dbn=db_name, host=server)


def sqlanywhere_execute(
    cursor: Any, cmd: str, inpt: Sequence[str], procedure: bool
) -> list[tuple[Any, ...]]:
    """Execute ``cmd`` on SQL Anywhere; no DBMS specific handling, delegates to _default_execute."""
    return _default_execute(cursor, cmd, inpt, procedure)


# .


def connect(dbms: str, host: str, port: int | None, db_name: str, user: str, pwd: str) -> Any:
    """Open and return a connection to ``db_name`` on the chosen DBMS.

    The DBMS specific ``*_connect`` function -- and with it the python driver it
    imports -- is selected by ``dbms``.  If ``port`` is None, the DBMS's entry
    in DEFAULT_PORTS is used.

    Raises:
        KeyError: for an unsupported ``dbms``.
    """
    connectors = {
        "postgres": postgres_connect,
        "mysql": mysql_connect,
        "mssql": mssql_connect,
        "oracle": oracle_connect,
        "db2": db2_connect,
        "sqlanywhere": sqlanywhere_connect,
    }
    effective_port = DEFAULT_PORTS[dbms] if port is None else port
    open_connection = connectors[dbms]
    return open_connection(host, effective_port, db_name, user, pwd)


def execute(
    dbms: str, connection: Any, cmd: str, inpt: Sequence[str], procedure: bool = False
) -> list[tuple[Any, ...]]:
    """Execute the sql statement, or call the procedure, and return the fetched rows.

    The DBMS specific ``*_execute`` function is selected by ``dbms``.  Some of
    them make corrections for libraries that do not adhere to the python SQL API:
    https://www.python.org/dev/peps/pep-0249/

    The cursor and ``connection`` are closed before this function returns or raises.

    Raises:
        KeyError: if ``dbms`` has no execute function.
    """
    executors: dict[str, Callable[..., list[tuple[Any, ...]]]] = {
        "postgres": postgres_execute,
        "mysql": mysql_execute,
        "mssql": mssql_execute,
        "oracle": oracle_execute,
        "db2": db2_execute,
        "sqlanywhere": sqlanywhere_execute,
    }

    # Nested try/finally: the connection is closed even if creating the cursor
    # or closing it fails.
    try:
        cursor = connection.cursor()
        try:
            result = executors[dbms](cursor, cmd, inpt, procedure)
        finally:
            cursor.close()
    finally:
        connection.close()

    LOG.info("SQL Result:\n%r", result)
    return result


def process_result(
    result: list[tuple[Any, ...]],
    warn: tuple[float, float],
    crit: tuple[float, float],
    metrics: str | None,
    debug: bool,
) -> tuple[int, str]:
    """Turn the first row (!) of the SQL result into a check state and output text.

    Only ``result[0]`` is considered.  It is expected to be a sequence
    ``(numerical_value[, text[, performance_data]])``:

    * ``numerical_value`` is converted with ``float()``.
    * ``text`` is the output text; for a one-column row the value itself is used.
    * ``performance_data`` is appended as `` | <metrics>=<value>`` if ``metrics``
      is set.  If that column is missing it is silently left out, unless
      ``debug`` is set, in which case the IndexError propagates.

    If levels were given (``warn`` or ``crit`` differ from MP_INF), the state is
    1 (2) when the value lies outside ``[lower, upper)`` of ``warn`` (``crit``).
    The value is then also appended to the text.  Without levels the value
    itself must be a state 0-3.

    Returns:
        ``(state, text)``.  Exits with state 3 via bail_out() if ``result`` is
        empty, or if no levels are given and the value is not a valid state.
    """
    if not result:
        bail_out(3, "SQL statement/procedure returned no data")
    row0 = result[0]

    number = float(row0[0])

    # handle case where sql query only results in one column
    if len(row0) == 1:
        text = "%s" % row0[0]
    else:
        text = "%s" % row0[1]

    perf = ""
    if metrics:
        try:
            perf = f" | {metrics}={str(row0[2])}"
        except IndexError:
            if debug:
                raise

    state = 0
    if warn != MP_INF or crit != MP_INF:
        if not warn[0] <= number < warn[1]:
            state = 1
        if not crit[0] <= number < crit[1]:
            state = 2
        text += ": %s" % number
    else:  # no levels were given
        if number in (0, 1, 2, 3):
            state = int(number)
        else:
            # "%d" would truncate (1.5 -> "<1>") and raise for nan/inf.  Show the
            # real value, but keep integral values without a trailing ".0".
            shown = int(number) if number.is_integer() else number
            bail_out(3, "<%s> is not a state, and no levels given" % shown)

    return state, text + perf


def main(argv: list[str] | None = None) -> None:
    """Run the check: parse ``argv`` (default: sys.argv), query the database, report.

    The result is reported via bail_out().  Any exception raised while
    connecting, executing or processing is reported as state 3 with a one-line
    error message, unless --debug is given, in which case it propagates.
    """
    args = parse_args(argv or sys.argv)

    msg = "connecting to database"
    try:
        conn = connect(args.dbms, args.hostname, args.port, args.name, args.user, args.password)

        msg = "executing SQL command"
        result = execute(args.dbms, conn, args.cmd, args.input, procedure=args.procedure)

        msg = "processing result of SQL statement/procedure"
        state, text = process_result(
            result,
            args.warning,
            args.critical,
            metrics=args.metrics,
            debug=args.debug,
        )
    except Exception as exc:
        if args.debug:
            raise
        # Keep the message on one line.  Strip the parentheses of a stringified
        # args tuple (presumably from multi-argument driver exceptions).  Replace
        # both escaped "\n" (as shown in a tuple's repr) and real newlines.
        errmsg = str(exc).strip("()").replace(r"\n", " ").replace("\n", " ")
        bail_out(3, f"Error while {msg}: {errmsg}")

    bail_out(state, args.text + text)


if __name__ == "__main__":
    # Script entry point; main() reports the result and exits via bail_out().
    main()
