#!/usr/bin/env python3
# Copyright (C) 2019 Checkmk GmbH - License: GNU General Public License v2
# This file is part of Checkmk (https://checkmk.com). It is subject to the terms and
# conditions defined in the file COPYING, which is part of this source code package.

import getopt
import json
import os
import subprocess
import sys
import time
import traceback
from pathlib import Path

import requests
import urllib3
from requests_kerberos import HTTPKerberosAuth  # type: ignore[import]

import cmk.utils.password_store
from cmk.utils.crypto.secrets import AutomationUserSecret
from cmk.utils.user import UserId

# NOTE(review): presumably resolves password-store references in the command
# line arguments before they are parsed further down -- confirm against
# cmk.utils.password_store.
cmk.utils.password_store.replace_passwords()

# Silence urllib3's InsecureRequestWarning for all requests made by this script.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)


def usage() -> None:
    """Write the command line help of check_bi_aggr to stderr."""
    help_text = """
USAGE: check_bi_aggr -b <BASE_URL> -a <AGGR_NAME> -u <USER> -s <SECRET>
                     [-m <AUTH_MODE>] [-r] [-n <HOSTNAME>] [-t <TIMEOUT>] [-d]

OPTIONS:
  -b BASE_URL           The base URL to the monitoring environment, e.g.
                        http://<hostname>/<site-id>
  -a AGGR_NAME          Name of the aggregation, not the aggregation group.
                        It is possible that there are multiple aggregations
                        with an equal name, but you should ensure, that it
                        is a unique one to prevent confusions
  -u USER               User-ID of an automation user which is permitted to
                        see all contents of the aggregation
  -s SECRET             Automation secret of the user
  --use-automation-user Use credentials from the local "automation" user
  -m AUTH_MODE          Authentication mode, either "basic", "digest",
                        "header" or "kerberos", defaults to "header"
  -t TIMEOUT            HTTP connect timeout in seconds (Default: 60)
  -r                    track downtimes. This requires the hostname to be set.
  -n HOSTNAME           The hostname for which this check is run.
  --in-downtime S       S can be "ok" or "warn". Force this state if the
                        aggregate is in scheduled downtime. OK states will always
                        be unchanged.
  --acknowledged S      Same as --in-downtime, but for acknowledged aggregates.
  -d                    Enable debug mode
  -h, --help            Show this help message and exit

"""
    sys.stderr.write(help_text)


# Option definition for getopt: a trailing ":" (short) or "=" (long) marks an
# option that takes an argument.
short_options = "b:a:u:s:m:t:n:dhr"
long_options = ["help", "in-downtime=", "acknowledged=", "use-automation-user"]

try:
    opts, args = getopt.getopt(sys.argv[1:], short_options, long_options)
except getopt.GetoptError as err:
    sys.stderr.write("%s\n" % err)
    sys.exit(1)

# Defaults; overridden by the command line options evaluated below.
base_url = None
aggr_name = None

username = ""
password: str | None = None
use_automation_user = False

auth_mode = "header"
timeout = 60
debug = False
opt_in_downtime = None
opt_acknowledged = None
track_downtime = False
hostname = None

for o, a in opts:
    if o in ["-h", "--help"]:
        usage()
        sys.exit(0)
    elif o == "-b":
        base_url = a
    elif o == "-a":
        aggr_name = a
    elif o == "-u":
        username = a
    elif o == "-s":
        password = a
    elif o == "-m":
        auth_mode = a
    elif o == "-t":
        # Report a malformed timeout as a regular usage error instead of letting
        # the ValueError escape as a traceback.
        try:
            timeout = int(a)
        except ValueError:
            sys.stderr.write("Invalid timeout %r: please provide an integer number of seconds.\n" % a)
            sys.exit(1)
    elif o == "-r":
        track_downtime = True
    elif o == "-n":
        hostname = a
    elif o == "-d":
        debug = True
    elif o == "--in-downtime":
        opt_in_downtime = a
    elif o == "--acknowledged":
        opt_acknowledged = a
    elif o == "--use-automation-user":
        use_automation_user = True

# Validate the mandatory arguments; every failure prints the usage and exits with 1.
if not base_url:
    sys.stderr.write("Please provide the URL to the monitoring instance.\n")
    usage()
    sys.exit(1)

if not aggr_name:
    sys.stderr.write("Please provide the name of the aggregation.\n")
    usage()
    sys.exit(1)

if use_automation_user:
    # --use-automation-user replaces any -u/-s given on the command line.
    username = "automation"
    try:
        password = AutomationUserSecret(UserId(username)).read()
    except (OSError, ValueError):
        sys.stderr.write('Unable to read credentials for "automation" user.\n')
        sys.exit(1)

if not username or not password:
    sys.stderr.write("Please provide valid user credentials.\n")
    usage()
    sys.exit(1)

if track_downtime and not hostname:
    sys.stderr.write("Please provide a hostname when using downtime tracking.\n")
    usage()
    sys.exit(1)


def init_auth(pw: str) -> requests.auth.AuthBase | None:
    """Create the requests auth handler for the globally selected ``auth_mode``.

    Args:
        pw: Password (or automation secret) belonging to the global ``username``.

    Returns:
        An auth object for the "basic", "digest" and "kerberos" modes, or
        ``None`` for the "header" mode.

    Raises:
        Exception: If ``auth_mode`` is none of the four known modes.

    In kerberos mode a ticket is requested via ``kinit`` first; if that fails
    the details are written to stderr and the process exits with code 1.
    """
    # The simple modes need no preparation, so they are handled up front.
    if auth_mode == "header":
        return None
    if auth_mode == "basic":
        return requests.auth.HTTPBasicAuth(username, pw)
    if auth_mode == "digest":
        return requests.auth.HTTPDigestAuth(username, pw)
    if auth_mode != "kerberos":
        raise Exception("Unknown auth mode: " + auth_mode)

    # kinit reads the password from stdin
    kinit_result = subprocess.run(
        ["kinit", username],
        input=f"{pw}\n",
        capture_output=True,
        encoding="utf-8",
        check=False,
    )
    # Any stderr output counts as a failure, even with a zero return code
    kinit_failed = kinit_result.returncode or kinit_result.stderr
    if kinit_failed:
        sys.stderr.write("Error getting Kerberos Ticket:\n")
        sys.stderr.write(
            f"stdout: {kinit_result.stdout}\nstderr: {kinit_result.stderr}\nrc: {kinit_result.returncode}"
        )
        sys.exit(1)

    return HTTPKerberosAuth(principal=username)


# REST API action that reports the state of the requested BI aggregations
endpoint_url = f"{base_url.rstrip('/')}/check_mk/api/1.0/domain-types/bi_aggregation/actions/aggregation_state/invoke"

auth = init_auth(password)

if debug:
    sys.stderr.write("URL: %s\n" % endpoint_url)

# Fetch the aggregation state; every failure here is reported with exit code 3.
try:
    r = requests.post(
        endpoint_url,
        timeout=timeout,
        auth=auth,
        # Without an auth handler ("header" mode) the credentials are sent as a
        # bearer token instead.
        headers={"Authorization": f"Bearer {username} {password}"} if not auth else None,
        json={"filter_names": [aggr_name]},
    )
    r.raise_for_status()
    raw_response = r.text
except requests.Timeout:
    sys.stdout.write("ERROR: Socket timeout while opening URL: %s\n" % (endpoint_url))
    sys.exit(3)
except requests.URLRequired as e:
    sys.stdout.write("UNKNOWN: %s\n" % e)
    sys.exit(3)
except Exception as e:
    sys.stdout.write(
        f"ERROR: Exception while opening URL: {endpoint_url} - {e}\n{traceback.format_exc()}"
    )
    sys.exit(3)

try:
    response_data = json.loads(raw_response)
except Exception as e:
    sys.stdout.write(f"ERROR: Invalid response ({e}): {raw_response}\n")
    sys.exit(3)

# TypeError is caught as well so that a payload of unexpected shape (e.g. a list
# instead of an object) ends in a regular error message rather than a traceback.
try:
    aggr_data = response_data["aggregations"][aggr_name]
    aggr_state = aggr_data["state"]
except (KeyError, TypeError):
    sys.stdout.write(f"ERROR Aggregation {aggr_name} does not exist or user is not permitted\n")
    sys.exit(3)

# -1 and any other value outside the four known states is reported as UNKNOWN.
# This also guarantees that the state is a valid index and a valid exit code.
if not isinstance(aggr_state, int) or not 0 <= aggr_state <= 3:
    aggr_state = 3

aggr_output = "Aggregation state is %s" % ["OK", "WARN", "CRIT", "UNKNOWN"][aggr_state]

# Handle downtimes and acknowledgements
is_aggr_in_downtime = aggr_data["in_downtime"]
if opt_in_downtime and is_aggr_in_downtime:
    aggr_output += ", currently in downtime"
    if opt_in_downtime == "ok":
        aggr_state = 0
    else:  # "warn"
        aggr_state = min(aggr_state, 1)

if track_downtime:
    # connect to livestatus
    try:
        import livestatus
    except ImportError:
        sys.stderr.write(
            "The python livestatus api module is missing. Please install from\n"
            "Check_MK livestatus sources to a python import path.\n"
        )
        sys.exit(1)

    socket_path = Path(os.environ["OMD_ROOT"]) / "tmp/run/live"

    conn = livestatus.SingleSiteConnection(f"unix:{socket_path}")

    now = time.time()
    # find out if, according to previous tracking, there already is a downtime
    ids = conn.query_table(
        (
            "GET downtimes\n"
            "Columns: id\n"
            "Filter: host_name = %s\n"
            "Filter: service_description = %s\n"
            "Filter: author = tracking\n"
            "Filter: end_time > %d"
        )
        % (hostname, aggr_name, now)
    )
    downtime_tracked = len(ids) > 0
    if downtime_tracked != is_aggr_in_downtime:
        # there is a discrepancy between tracked downtime state and the real state
        if is_aggr_in_downtime:
            # need to track downtime; the end time is the largest signed 32 bit
            # value, i.e. the downtime stays until it is removed again below
            conn.command(
                "[%d] SCHEDULE_SVC_DOWNTIME;%s;%s;%d;%d;1;0;0;"
                "tracking;Automatic downtime" % (now, hostname, aggr_name, now, 2147483647)
            )
        else:
            # the aggregation left its downtime: remove all tracked downtimes
            for dt_id in ids:
                conn.command("[%d] DEL_SVC_DOWNTIME;%d" % (now, dt_id[0]))

is_aggr_acknowledged = aggr_data["acknowledged"]
if opt_acknowledged and is_aggr_acknowledged:
    aggr_output += ", is acknowledged"
    if opt_acknowledged == "ok":
        aggr_state = 0
    else:  # "warn"
        aggr_state = min(aggr_state, 1)

# The (possibly overridden) state doubles as the exit code of the check
sys.stdout.write("%s\n" % aggr_output)
sys.exit(aggr_state)
