#!/usr/bin/env python3
# Copyright (C) 2019 Checkmk GmbH - License: GNU General Public License v2
# This file is part of Checkmk (https://checkmk.com). It is subject to the terms and
# conditions defined in the file COPYING, which is part of this source code package.
"""
This program overwrites __import__ and tracks memory before
and after any import statement issued. The resulting data can
be emitted in various formats for further processing.
"""

import argparse
import contextlib
import gc
import importlib
import inspect
import json
import os
import subprocess
import sys
from collections.abc import Iterator
from sysconfig import get_path
from types import FrameType
from typing import Any, NamedTuple

# Byte multipliers for the unit suffixes that appear in /proc/*/smaps
KiB = 2**10
MiB = 2**20


# Struct to keep the values. We're only interested in the
# private memory, as this is the only one the program has any
# control over. We can't possibly reduce any other type of memory.
class Memory(NamedTuple):
    """Memory figures in bytes, either absolute (_get_mem) or a difference (_get_delta)."""
    # Summed from the ``Private_*`` lines of smaps.
    private: int
    # Summed from the ``Pss:`` lines of smaps.
    pss: int
    # Summed from the ``Rss:`` lines of smaps.
    rss: int


def _get_python_stdlib() -> Iterator[str]:
    std_lib = get_path("stdlib")
    # noinspection PyTypeChecker
    for top, _dirs, files in os.walk(std_lib):
        if "dist-packages" in top:
            continue

        if "__init__.py" in files:
            yield top[len(std_lib) + 1 :]

        for nm in files:
            if nm != "__init__.py" and nm[-3:] == ".py":
                yield os.path.join(top, nm)[len(std_lib) + 1 : -3].replace("/", ".")


# We pull the names of Python's stdlib so we can color it differently.
# It is only used for ``name in _STDLIB`` tests, so a frozenset gives O(1) lookups.
_STDLIB = frozenset(_get_python_stdlib())


def log(msg, *args):
    """Write the %-formatted message followed by a newline to stderr and flush it."""
    stream = sys.stderr
    stream.writelines((msg % args, "\n"))
    stream.flush()


class ConsoleFormatter:
    """Minimal progress display: one character per tracker event on stdout.

    ``>`` is printed when an import starts, ``X`` when its size is recorded
    and ``<`` when it is finished.
    """

    __slots__ = ["out"]

    def __init__(self, args) -> None:  # type: ignore[no-untyped-def]
        # `args` is not used by this formatter.
        self.out = sys.stdout

    def _emit(self, mark):
        # Flush right away so the progress is visible while imports are running.
        stream = self.out
        stream.write(mark)
        stream.flush()

    def push(self, importer, imported):
        self._emit(">")

    def pop(self):
        self._emit("<")

    def setsize(self, _size):
        self._emit("X")

    def finish(self):
        """Nothing to finalize for console output."""
        return None


class JSONFormatter:
    """
    Collect the tracked imports as a tree and write it to a JSON file.

    Output looks like this:

    {"name": "root",
     "children": [{"name": "child1", "value": 12345},
                  {"name": "child2", "value": 23456,
                   "children": [{...}]}]}

    """

    __slots__ = ("metric", "stack", "open")

    def __init__(self, args) -> None:  # type: ignore[no-untyped-def]
        self.open = args.open
        self.metric = {"name": "root"}
        # TODO: The stack is an almost untypeable mess...
        # It holds the path of tree nodes from the root to the import in progress.
        self.stack: list[Any] = [self.metric]

    def push(self, from_, module):
        """Enter the node for `module` below the node currently on top of the stack."""
        parent = self.stack[-1]
        children = parent.setdefault("children", [])

        # Prevent duplicate children with the same name, which can happen
        # if there are multiple from-imports: reuse the existing node instead.
        # The existing node may not have a "value" yet (e.g. its import failed),
        # so it must not be touched here.
        node = next((kid for kid in children if kid["name"] == module), None)
        if node is None:
            node = {"name": module}
            children.append(node)
        self.stack.append(node)

    def setsize(self, _size):
        """Add `_size` to the node on top of the stack.

        Nodes are shared between imports of the same name (see push()), so the
        sizes are accumulated instead of letting the last import win.
        """
        node = self.stack[-1]
        node["value"] = node.get("value", 0) + _size

    def pop(self):
        """Leave the node on top of the stack."""
        self.stack.pop()

    def finish(self):
        """Post-process the tree and write it to ``import-tree.<pid>.json`` in the cwd."""
        # HACKETY HACK
        self.stack[0] = remove_importlib(self.stack[0])

        # First we sum up all the children to their parents, so each parent containing space
        # consuming children has a value with at least their sum.
        sum_values(self.stack[0])
        adjust_values(self.stack[0])

        # We now remove everything of value <= 0 because we don't have any interest in any of that.
        # FIXME Not sure if this is even still the case.
        # remove_zeros(self.stack[0])

        filename = os.path.abspath("import-tree.%d.json" % (os.getpid(),))
        with open(filename, "w") as jsf:
            jsf.write(self.render_json())
        sys.stderr.write(f"Written to {filename}\n")
        sys.stderr.flush()

        if self.open:
            _spawn("gvim", filename)

    def render_json(self) -> str:
        """Return the tree, starting at the root dict, as indented JSON text."""
        return json.dumps(self.stack[0], indent=4)


class GraphVizFormatter:
    """Record importer -> imported edges and write them as a GraphViz ``dot`` file.

    Names starting with ``cmk`` are colored red, names outside the standard
    library blue; everything else keeps the default color.
    """

    __slots__ = ("from_", "to", "out", "open")

    def __init__(self, args) -> None:  # type: ignore[no-untyped-def]
        self.open = args.open
        # Parallel stacks describing the imports currently in progress.
        self.from_: list = []
        self.to: list = []
        # Recorded edges as (importer, imported, size) tuples.
        self.out: list = []

    @staticmethod
    def _color_of(name):
        # Returns the color for `name`, or None if it keeps the default color.
        if name.startswith("cmk"):
            return "red"
        if name not in _STDLIB:
            return "blue"
        return None

    def push(self, from_, to):
        self.from_.append(from_)
        self.to.append(to)

    def pop(self):
        self.from_.pop()
        self.to.pop()

    def setsize(self, size):
        edge = (self.from_[-1], self.to[-1], size)
        self.out.append(edge)

    def finish(self):
        """Write ``import-graph.<pid>.dot`` to the cwd and optionally render and show it."""
        filename = f"import-graph.{os.getpid()}.dot"
        with open(filename, "w") as dot:
            dot.write("digraph G {\n")
            dot.write("   overlap = false;\n")

            # Draw edges.
            for source, target, _size in self.out:
                dot.write(f'    "{source}" -> "{target}"')
                attrs = []
                for endpoint in (source, target):
                    color = self._color_of(endpoint)
                    if color:
                        attrs.append("color=" + color)
                if attrs:
                    dot.write(" [" + ",".join(attrs) + "]")
                dot.write(";\n")

            # Style nodes.
            for source, target, _size in self.out:
                for endpoint in (source, target):
                    color = self._color_of(endpoint)
                    if color:
                        dot.write(f'   "{endpoint}" [color={color}];\n')

            dot.write("}\n")
        log("Written to %s", os.path.abspath(filename))

        if self.open:
            self._write_and_open_images(filename)

    def _write_and_open_images(self, filename):
        """Render `filename` to a PNG with ``dot`` and open the image with ``eog``."""
        log("Launching eog")

        # A ``neato`` rendering used to be produced here as well:
        #       neato_image = "neato-%s.png" % (filename,)
        #       os.system("neato -Tpng -o%s %s" % (neato_image, filename))
        #       log("Wrote %s", os.path.abspath(neato_image))
        #       _spawn("eog", neato_image)

        dot_image = "dot-%s.png" % (filename,)
        renderer = subprocess.Popen(["dot", "-Tpng", "-o" + dot_image, filename])
        # Leaving the context waits for the renderer to terminate.
        with renderer:
            pass
        log("Wrote %s", os.path.abspath(dot_image))
        _spawn("eog", dot_image)


# Maps the names accepted by --formatter to the class producing that output.
FORMATTERS = dict(
    json=JSONFormatter,
    dot=GraphVizFormatter,
    console=ConsoleFormatter,
)


class Formatter:
    """Fan every tracker event out to a set of concrete formatters.

    Used as a context manager: when the ``with`` block is left without an
    exception, ``finish()`` is called on every formatter. If the block raised,
    nothing is finished and the exception propagates to the caller.
    """

    def __init__(self, *formatters) -> None:  # type: ignore[no-untyped-def]
        self.formatters = formatters

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning a falsy value lets the interpreter propagate a pending exception.
        # __exit__ must not re-raise it itself; a bare ``raise`` here only works while
        # an exception is actually being handled.
        if exc_type is None:
            for formatter in self.formatters:
                formatter.finish()
        return False

    def push(self, importer, imported):
        for formatter in self.formatters:
            formatter.push(importer, imported)

    def pop(self):
        for formatter in self.formatters:
            formatter.pop()

    def setsize(self, _size):
        for formatter in self.formatters:
            formatter.setsize(_size)


# The active fan-out formatter. Assigned by main(); track_growth() raises while it is None.
_FORMATTER: Formatter | None = None


def _get_mem(pid: int) -> Memory:
    """Return the memory usage of process `pid`, summed over all of its mappings.

    ``/proc/<pid>/smaps`` is read in one go and the ``Private_*``, ``Pss:`` and
    ``Rss:`` lines of every mapping are added up. All results are in bytes.

    Raises:
        RuntimeError: if an amount uses a unit other than ``kB`` or ``MB``.
    """

    def _parse_amount(_line: str) -> int:
        # Lines look like "Private_Dirty:        12 kB".
        _name, amount, unit = _line.split()
        val = int(amount)

        if unit == "kB":
            return val * KiB
        if unit == "MB":
            return val * MiB
        raise RuntimeError(
            "Unsupported memory unit for Private_* memory info from /proc/%d/smaps" % pid
        )

    with open("/proc/%d/smaps" % pid) as fp:
        smaps = fp.read()

    private = 0
    pss = 0
    rss = 0
    for line in smaps.split("\n"):
        if line.startswith("Private_"):
            private += _parse_amount(line)
        elif line.startswith("Pss:"):
            pss += _parse_amount(line)
        elif line.startswith("Rss:"):
            rss += _parse_amount(line)

    return Memory(private=private, pss=pss, rss=rss)


def _get_delta(mem1: Memory, mem2: Memory) -> Memory:
    """Return the field-wise difference ``mem2 - mem1``."""
    growth = {
        field: getattr(mem2, field) - getattr(mem1, field)
        for field in ("private", "pss", "rss")
    }
    return Memory(**growth)


class MeasureMemory:
    """Context manager taking a memory snapshot of this process on entry and on exit.

    ``diff()`` may only be called after the ``with`` block has been left,
    because the second snapshot is taken in ``__exit__``.
    """

    __slots__ = ["pid", "mem_before", "mem_after"]

    def __init__(self) -> None:
        self.pid = os.getpid()

    def __enter__(self):
        self.mem_before = _get_mem(self.pid)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # The snapshot is taken even if the block raised.
        self.mem_after = _get_mem(self.pid)
        # A falsy return value makes the interpreter propagate a pending exception;
        # __exit__ must not re-raise it itself.
        return False

    def diff(self):
        """Return the growth between the two snapshots as a Memory tuple."""
        return _get_delta(self.mem_before, self.mem_after)


@contextlib.contextmanager
def track_growth(parent, module, fromlist):
    """Measure the private memory growth of the wrapped block and report it.

    The (parent, module) pair is pushed to the active formatter before the
    block runs and popped afterwards, also if the block raises. The size is
    only reported if the block finished without an exception.

    This is a separate contextmanager so we can actually gracefully recover from any errors.
    `fromlist` is accepted but not used.

    Raises:
        Exception: if no formatter has been installed yet.
    """
    formatter = _FORMATTER
    if formatter is None:
        raise Exception("no formatter")

    formatter.push(parent, module)
    try:
        with MeasureMemory() as mem:
            yield
        # The second snapshot is taken when the ``with`` block is left, so the
        # difference can only be computed out here.
        formatter.setsize(mem.diff().private)
    finally:
        formatter.pop()


# Keep a reference to the original import function; _importer() delegates to it.
__imp = __import__

# NOTE(review): not referenced anywhere in the visible code; possibly a leftover.
stack = ["__main__"]


def _back(frame: FrameType | None) -> FrameType:
    assert frame is not None
    b = frame.f_back
    assert b is not None
    return b


def _importer(
    name: str,
    _globals: dict | None = None,
    _locals: dict | None = None,
    fromlist: list | None = None,
    level: int = 0,
) -> Any:
    """Replacement for ``__import__`` that measures the memory growth of every import.

    The arguments are those of the builtin ``__import__`` and are passed on to
    it unchanged, except that missing globals/locals are taken from the calling
    frame. `level` defaults to 0 like the builtin; negative levels are rejected
    by Python 3.

    Returns:
        Whatever the original ``__import__`` returns.

    Raises:
        ImportError: also for a ValueError raised by the original ``__import__``.
    """
    # Wrap __import__ and do some shenanigans to actually track what's happening.
    if fromlist is None:
        fromlist = []

    current = inspect.currentframe()
    # The __name__ found three frames up the stack is reported as the importer.
    parent = _back(_back(_back(current))).f_globals.get("__name__")

    if _globals is None or _locals is None:
        # Someone did call __import__('foo') without specifying
        # globals and locals. I'm looking at you, six!
        # We get our globals from the parent frame.
        _globals = _back(current).f_globals
        _locals = _back(current).f_locals

    def module_name_from_import_stmt(line):
        # may be 'exec("import foo as bar")'. Thank you reportlab.
        if "from" in line:
            ret = line.split("from")[1].split()[0]
        elif "import" in line:
            ret = line.split("import")[1].split()[0]
        else:
            ret = "? (%s)" % line.strip()
        return ret

    def looks_like_import(lines):
        # WARNING: `lines` actually may be None
        return lines and any(part in lines[0] for part in ("import", "from"))

    def hunt_for_module_name(frames):
        # Some crappy^H^H^H^H^H^H external packages do some weird stuff with exec and import.
        # Sometimes it's not even possible to figure out what has been imported as the import
        # stmts are read from a file and executed through layers of wrapper functions leaving us
        # with strange frames on the stack containing no source code whatsoever. We must skip those.
        for _frame in frames:
            # _frame[4] is `code_context`, meaning the source code actually doing the
            # import (in some way...). It's a list of source lines, but I have only
            # seen lengths of 1 until now. Example: ["   from foo import bar\n"]
            if not looks_like_import(_frame[4]):
                continue

            # We take the first one we can find.
            candidate = module_name_from_import_stmt(_frame[4][0])
            if candidate.startswith("."):
                # Relative import: prefix the name found in the source line with the
                # __name__ of the frame two levels up.
                return _back(_back(current)).f_globals["__name__"] + candidate
            return candidate
        return None

    module_name = _globals.get("__name__")

    # c-libraries sometimes don't have a name, or have no _globals and thus somehow fall through
    # to the base globals() so we figure out how they were called when they were imported and
    # reconstruct their real name that way.
    if not module_name or module_name == "__main__":
        # Walking the stack is expensive, so it is only done when really needed.
        without_us = [frame for frame in inspect.stack() if frame[1] != __file__]
        module_name = hunt_for_module_name(without_us) or "? (unknown)"

    # Here goes the real magic.
    with track_growth(parent, module_name, fromlist):
        try:
            return __imp(name, _globals, _locals, fromlist, level)
        except ValueError as e:
            # Until Python 3.9 a call to __import__ fails with a ValueError whenever
            # a module tries to do a relative-import extending its package boundaries.
            # Starting with 3.9 this will be an ImportError
            raise ImportError(e) from e


def main(args):
    """Install the tracking import hook and import all requested modules.

    `args` is the parsed command line: ``args.formatter`` holds the selected
    formatter names (one single-element list per ``-f`` option, or None),
    ``args.modules`` the modules to import. Without any ``-f`` option the
    console formatter is used.

    Raises:
        KeyError: if a formatter name is not a key of FORMATTERS.
    """
    global _FORMATTER
    # Imported here because it is only needed to install the hook.
    import builtins

    formatters = [FORMATTERS[key[0]](args) for key in (args.formatter or [])]
    if not args.formatter:
        formatters.append(FORMATTERS["console"](args))

    _FORMATTER = Formatter(*formatters)

    # Use the `builtins` module instead of `__builtins__`: the latter is an
    # implementation detail that is either the module or its dict.
    # TODO: The types don't really match:
    # expr: Callable[[str,     dict   [Any, Any] | None, dict   [Any, Any] | None, list[Any] | None, int], Any]
    # var:  Callable[[unicode, Mapping[str, Any] | None, Mapping[str, Any] | None, Sequence[str]   , int], Any]
    builtins.__import__ = _importer  # type: ignore[assignment]

    # Not sure how much this helps.
    gc.disable()

    with _FORMATTER:
        for module in args.modules:
            importlib.import_module(module)


@contextlib.contextmanager
def basedir() -> Iterator[None]:
    """Put the root of the surrounding git repository first in ``sys.path`` for the block.

    Raises:
        subprocess.CalledProcessError: if ``git rev-parse`` fails, e.g. outside a repository.
    """
    output = subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
    root = output.decode("utf-8").strip()
    # So cmk will actually be found...
    sys.path.insert(0, root)
    try:
        yield
    finally:
        # Imported code may have changed sys.path, so index 0 is not necessarily
        # our entry anymore: remove it by value instead.
        with contextlib.suppress(ValueError):
            sys.path.remove(root)


def _spawn(*args):
    # Give the process it's own process group so it
    # won't be killed when the script exits.
    # Close stdout and stderr so we won't have an ugly terminal.
    subprocess.Popen(  # pylint: disable=consider-using-with
        args,
        stdout=open("/dev/null", "w"),  # pylint: disable=consider-using-with
        stderr=open("/dev/null", "w"),  # pylint: disable=consider-using-with
        start_new_session=True,
    )


def remove_zeros(param: dict) -> dict:
    """Recursively drop all children without a truthy "value" and return `param`.

    The tree is modified in place. A "children" list that ends up empty is
    removed from its node altogether.
    """
    if "children" in param:
        # remove_zeros() returns the very node it was given, so every kid is
        # cleaned up first and then judged by its own value.
        survivors = [kid for kid in param["children"] if remove_zeros(kid).get("value")]
        if survivors:
            param["children"][:] = survivors
        else:
            del param["children"]
    return param


def remove_importlib(val: dict[Any, Any] | list[Any] | str) -> dict[Any, Any] | list[Any] | str:
    if isinstance(val, list):
        ret: list = []
        for entry in val:
            if entry["name"] == "importlib":
                ret.extend(entry.get("children", []))
            else:
                ret.append(entry)
        return ret
    if isinstance(val, dict):
        if val["name"] == "importlib":
            raise Exception("Too late!")
        return {key: remove_importlib(value) for key, value in val.items()}
    return val


def sum_values(node: dict) -> None:
    """Store in every node, as "child_value", what its children add up to.

    For each direct child its "value" and its own "child_value" are counted;
    missing keys count as 0. The tree is modified in place.

    >>> node = {'value': 0, 'children': [
    ...    {'value': 1}, {'value': -2, 'children': [{'value': 5}]}
    ... ]}
    >>> sum_values(node)

    >>> node['child_value']
    4

    """
    total = 0
    for kid in node.get("children", []):
        # Children first, so their "child_value" is available below.
        sum_values(kid)
        total += kid.get("value", 0) + kid.get("child_value", 0)

    node["child_value"] = total


def adjust_values(node: dict) -> None:
    """Split the "value" of every node into its "own" part and the part of its children.

    Expects "child_value" to be present on every node. If the value is smaller
    than the child value, it is taken as the own part and the value is raised
    by the child value. The tree is modified in place.

    >>> node = {
    ...     'child_value': 4,
    ...     'children': [
    ...         {'child_value': 0, 'value': 1},
    ...         {'child_value': 5,
    ...             'children': [{'child_value': 0, 'value': 5}],
    ...             'value': -2}],
    ...     'value': 0
    ... }
    >>> adjust_values(node)
    >>> node['own']
    0

    """
    for kid in node.get("children", []):
        adjust_values(kid)

    total = node.get("value", 0)
    from_children = node.get("child_value", 0)

    if total > from_children:
        node["own"] = total - from_children
    elif total < from_children:
        node["own"] = total
        node["value"] = total + from_children
    elif total == from_children:
        node["own"] = 0

    assert node["own"] + node["child_value"] == node["value"]


if __name__ == "__main__":
    # Command line entry point: parse the options, then run the tracker with
    # the git repository root on sys.path.
    parser = argparse.ArgumentParser(description="Track import memory", epilog=__doc__)
    parser.add_argument("modules", metavar="MODULE", type=str, nargs="+")
    parser.add_argument(
        "-o", "--open", action="store_true", help="Directly show the resulting file (if possible)"
    )
    parser.add_argument(
        "-f",
        "--formatter",
        metavar="OUTPUT-FORMAT",
        type=str,
        # Reject unknown names right here instead of failing later with a KeyError.
        choices=list(FORMATTERS),
        help="May be one of: {}".format(", ".join(FORMATTERS.keys())),
        action="append",
        nargs=1,
    )
    _args = parser.parse_args()

    with basedir():
        main(_args)
