#!/usr/bin/python3
import argparse
import hashlib
import logging
import subprocess
import sys
import time
from pathlib import Path
from typing import Any

import git
import jinja2
import yaml
from jinja2 import meta

from dnsdeploy import data_dir
from dnsdeploy.dnssec import Dnssec, ZoneSigner
from dnsdeploy.util import ascii_transliterate, get_validated_symlink
from dnsdeploy.zone import get_referenced_files, process_includes

LOGGER = logging.getLogger("dnsdeploy")

# Directory the rendered zone files (and their ".signed" counterparts and the
# per-view catalog zones) are written to. main() deletes any entry in it that
# the current run did not produce.
OUTPUT_DIR = "/var/lib/authdns/zones"

# Destination of the BIND configuration rendered from bind9.conf.j2 in main().
BIND_CONFIG_FILE = "/var/lib/authdns/zones.conf"

# Working directory used when running named-checkzone.
# NOTE(review): presumably matches named's `directory` option so that relative
# paths are interpreted the same way as by named itself — confirm.
BIND_WORKING_DIRECTORY = "/var/cache/bind"

# Jinja2 template rendered once per view into that view's catalog zone.
CATALOG_ZONE_TEMPLATE_FILE = "/var/lib/authdns/catalog.zone.j2"

# Header prepended to every generated zone file (regular and catalog zones)
HEADER = """; WARNING!
; This file was automatically generated from a template
; Do NOT edit this file directly!

"""


def get_template_files(
    repo_path: Path,
    template_env: jinja2.Environment,
    template: jinja2.Template,
    tmpl_name: str,
) -> set[Path]:
    """Get the full list of files parsed to process this template.

    The result feeds the git-history lookup used for the zone serial, so it
    contains:

    * the template itself (``<repo_path>/zones/<tmpl_name>``),
    * templates it references statically (include/import/extends) — only the
      references found in this template's own source, not transitive ones,
    * files reported by ``get_referenced_files`` for a render of the template
      made with placeholder serial values and no DNSSEC keys.

    Args:
        repo_path: root of the repository holding the ``zones`` directory.
        template_env: environment whose loader serves ``tmpl_name``.
        template: the already-loaded template for ``tmpl_name``.
        tmpl_name: template name as known to ``template_env``'s loader.

    Returns:
        Set of file paths. Template paths are resolved to absolute paths, so
        they do not depend on the process's current working directory.
    """
    loader: jinja2.FileSystemLoader = template_env.loader  # type: ignore
    raw_src = loader.get_source(template_env, tmpl_name)[0]
    ast = template_env.parse(raw_src)

    zones_dir = repo_path / "zones"
    # Resolve the primary template like the referenced ones below. A relative
    # path (e.g. from a relative --git-dir) would otherwise be interpreted by
    # git relative to the repository, matching nothing.
    files = {(zones_dir / tmpl_name).resolve()}
    files.update(
        (zones_dir / ref).resolve()
        for ref in meta.find_referenced_templates(ast)
        # None is yielded for dynamic references whose target is unknown here
        if ref is not None
    )
    files.update(
        get_referenced_files(
            repo_path,
            template.render(
                {
                    "serial_num": "DUMMY",
                    "serial_comment": "DUMMY",
                    "dnssec_keys": lambda: [],
                }
            ),
        )
    )

    return files


def make_serial(repo: git.Repo, flist: list[Path] | set[Path]) -> tuple[str, str]:
    """Make a zone serial number from git metadata.

    The serial has the form ``YYYYMMDDNN``:

    * ``YYYYMMDD`` is the UTC commit date of the newest commit touching any
      path in ``flist``;
    * ``NN`` is the two-digit count of recent commits touching those paths
      that were committed on that same UTC date.

    Args:
        repo: repository whose history is inspected.
        flist: files whose history determines the serial.

    Returns:
        ``(serial, comment)``. ``comment`` is ``"<serial> <first 8 chars of
        the commit hash> <ASCII-transliterated subject line>"``. If no commit
        touches the paths, ``("1", "Unknown/uncommitted")`` is returned.
    """
    paths = flist if isinstance(flist, list) else list(flist)
    # It might be nice to track uncommitted index/workdir changes here too
    try:
        commit = next(repo.iter_commits(max_count=1, paths=paths))
    except StopIteration:
        return "1", "Unknown/uncommitted"

    date = time.strftime("%Y%m%d", time.gmtime(commit.committed_date))
    # Inspect at most 99 commits so the counter always fits in two digits.
    # A third digit would produce an 11-digit serial, beyond the 32-bit SOA
    # serial range.
    # NOTE: past 99 commits in one UTC day the serial stops increasing.
    commits = sum(
        1
        for c in repo.iter_commits(max_count=99, paths=paths)
        if time.strftime("%Y%m%d", time.gmtime(c.committed_date)) == date
    )

    sernum = f"{date}{commits:02}"
    # Commits created with an empty message have no subject line at all.
    message_lines = str(commit.message).splitlines()
    subject = ascii_transliterate(message_lines[0]) if message_lines else ""
    return sernum, f"{sernum} {str(commit)[:8]} {subject}"


def make_serial_for_zone(
    repo: git.Repo,
    repo_path: Path,
    template_env: jinja2.Environment,
    template: jinja2.Template,
    tmpl_name: str,
) -> tuple[str, str]:
    """Compute ``(serial, comment)`` for a zone template.

    Collects every file the template depends on via get_template_files() and
    derives the serial from their git history with make_serial().
    """
    return make_serial(
        repo,
        get_template_files(repo_path, template_env, template, tmpl_name),
    )


def process_tmpl(
    template_path: Path,
    repo: git.Repo,
    template_env: jinja2.Environment,
    output_path: Path,
    signer: ZoneSigner,
) -> bool:
    """Process a template file from the templates directory.

    Renders the zone template with a git-derived serial and the zone's DNSSEC
    key callback, then expands includes with process_includes(). The result,
    prefixed with HEADER, is written to ``output_path`` unless the file already
    holds identical content.

    Args:
        template_path: path of the template; its parent's parent is taken as
            the repository root.
        repo: git repository used to derive the serial.
        template_env: environment whose loader serves ``template_path.name``.
        output_path: destination zone file.
        signer: provides ``generate_and_get_keys`` to the template as
            ``dnssec_keys``.

    Returns:
        True if ``output_path`` was (re)written, False if it was left unchanged.
    """
    # Templates live in <repo>/zones/, so the repository root is two levels up.
    repo_path = template_path.parent.parent
    template = template_env.get_template(template_path.name)
    serial_num, serial_comment = make_serial_for_zone(
        repo, repo_path, template_env, template, template_path.name
    )
    rendered = template.render(
        {
            "serial_num": serial_num,
            "serial_comment": serial_comment,
            "dnssec_keys": signer.generate_and_get_keys,
        },
    )

    output = f"{HEADER}{process_includes(repo_path, rendered)}\n"

    # Jinja2's FileSystemLoader decodes templates as UTF-8 by default. Use the
    # same encoding on disk rather than the locale-dependent default. A
    # missing or undecodable old file cannot match, so it is rewritten.
    try:
        unchanged = output_path.read_text(encoding="utf-8") == output
    except (FileNotFoundError, UnicodeDecodeError):
        unchanged = False
    if unchanged:
        LOGGER.debug("Skipping writing %s, already identical content", output_path)
        return False

    output_path.write_text(output, encoding="utf-8")
    return True


def generate_zones_config(
    template_path: Path, views: dict[str, Any], tsig_suffix: str
) -> str:
    """Render the BIND zone configuration and return it as a string.

    The template at ``template_path`` is rendered with ``views`` and
    ``tsig_suffix`` in its context. Undefined template variables raise an error
    because the environment uses StrictUndefined.
    """
    loader = jinja2.FileSystemLoader(str(template_path.parent))
    env = jinja2.Environment(loader=loader, undefined=jinja2.StrictUndefined)
    config_template = env.get_template(template_path.name)
    context = {"views": views, "tsig_suffix": tsig_suffix}
    return config_template.render(context)


def write_catalog_zone(
    template_path: Path,
    repo: git.Repo,
    repo_path: Path,
    view_name: str,
    zones: dict[str, Any],
    output_path: Path,
) -> bool:
    """Render the catalog zone for one view; write it if its content changed.

    Each member zone is keyed by the SHA-1 hex digest of its name (presumably
    used by the template as the catalog member label — confirm against
    catalog.zone.j2). Each entry records the zone name and the view serving
    its data: ``view_name`` when the entry has its own ``"file"``, otherwise
    the entry's ``"in_view"``.

    Args:
        template_path: Jinja2 template for the catalog zone.
        repo: git repository used to derive the serial.
        repo_path: repository root; the serial follows ``zones.yaml`` there.
        view_name: view this catalog zone belongs to.
        zones: zone name -> entry holding either ``"file"`` or ``"in_view"``.
        output_path: destination catalog zone file.

    Returns:
        True if ``output_path` was (re)written, False if it was left unchanged.
    """
    members = {
        hashlib.sha1(zname.encode("utf-8")).hexdigest(): {
            "name": zname,
            "view": view_name if "file" in zone else zone["in_view"],
        }
        for zname, zone in zones.items()
    }

    template_env = jinja2.Environment(
        loader=jinja2.FileSystemLoader(template_path.parent),
        undefined=jinja2.StrictUndefined,
    )

    # Resolve so a relative repo_path is not misinterpreted by git, which
    # treats relative pathspecs as relative to the repository.
    serial_num, serial_comment = make_serial(
        repo, [(repo_path / "zones.yaml").resolve()]
    )
    template = template_env.get_template(template_path.name)
    rendered = template.render(
        {
            "serial_num": serial_num,
            "serial_comment": serial_comment,
            "zones": members,
            "view": view_name,
        }
    )
    output = f"{HEADER}{rendered}\n"

    # Same encoding as the UTF-8 templates. A missing or undecodable old file
    # cannot match, so it is rewritten.
    try:
        unchanged = output_path.read_text(encoding="utf-8") == output
    except (FileNotFoundError, UnicodeDecodeError):
        unchanged = False
    if unchanged:
        LOGGER.debug("Skipping writing %s, already identical content", output_path)
        return False

    output_path.write_text(output, encoding="utf-8")
    return True


def main() -> None:
    """Render, sign, validate and deploy all DNS zones, then reload named.

    Steps:
      1. Render every zone template listed in ``<git-dir>/zones.yaml`` into
         OUTPUT_DIR, once per view the zone is mapped to.
      2. Render one catalog zone per view from CATALOG_ZONE_TEMPLATE_FILE.
      3. Sign zones that have DNSSEC keys (when the unsigned file changed or
         no signed file exists yet), then check every zone file with
         named-checkzone. Exit with status 1 if any check fails.
      4. Delete entries in OUTPUT_DIR that this run did not produce.
      5. Render the BIND config into BIND_CONFIG_FILE, then run
         named-checkconf and ``rndc reload``.

    NOTE(review): zone files are written into OUTPUT_DIR before validation,
    and BIND_CONFIG_FILE before named-checkconf. Nothing is rolled back if a
    check fails.
    """
    parser = argparse.ArgumentParser(
        prog="dnsdeploy",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--debug", action="store_true", help="enable verbose logging")
    # Git checkout holding zones.yaml and the zones/ template directory
    parser.add_argument("--git-dir", type=Path, required=True)
    # YAML mapping of view name -> view settings (each needs a "catalog" key)
    parser.add_argument("--views-config", type=Path, required=True)
    parser.add_argument(
        "--key-dir",
        type=Path,
        default=Path("/var/lib/authdns/keys"),
        help="path to directory holding DNSSEC keys",
    )
    # Passed to bind9.conf.j2 as ``tsig_suffix``
    parser.add_argument("--view-tsig-suffix", default="view.tsig.dns.majava.org.")
    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

    zone_dir = Path(OUTPUT_DIR)

    # Counter is just for the final summary output text
    zcount = 0

    # Repo object for git history access
    repo = git.Repo(args.git_dir)
    # Set up template engine
    templates_dir_path = args.git_dir / "zones"
    template_env = jinja2.Environment(
        loader=jinja2.FileSystemLoader(templates_dir_path),
        undefined=jinja2.StrictUndefined,
    )

    # For cleanup: every file produced by this run; anything else in
    # zone_dir is deleted at the end.
    written_paths: set[Path] = set()

    with Path(args.git_dir / "zones.yaml").open("r") as f:
        metadata = yaml.safe_load(f)

    with args.views_config.open("r") as f:
        views = yaml.safe_load(f)
    # Each view collects the zones it serves; filled in below and later used
    # for catalog zones, validation and the BIND config.
    for view_name, view_data in views.items():
        view_data["zones"] = {}

    dnssec = Dnssec(args.key_dir)

    # --- Step 1: render zone templates ---
    for zname, zdata in metadata["zones"].items():
        signer = dnssec.get_zone(zname)

        # A plain template name means the zone exists only in the "default" view
        if not isinstance(zdata, dict):
            zdata = {
                "default": zdata,
            }

        for view_name, template_name in zdata.items():
            template_path = templates_dir_path / template_name

            zcount += 1
            # Symlinked templates are checked by get_validated_symlink(); the
            # output file still uses the name listed in zones.yaml.
            if template_path.is_symlink():
                template_path = get_validated_symlink(template_path)

            output_path = zone_dir / template_name

            changed = process_tmpl(
                template_path,
                repo,
                template_env,
                output_path,
                signer,
            )

            written_paths.add(output_path)

            # KeyError here if zones.yaml names a view missing from --views-config
            views[view_name]["zones"][zname] = {
                "file": output_path,
                "changed": changed,
            }

        # Views without their own copy of this zone record it as served from
        # the "default" view (used as the member's view in catalog zones).
        # NOTE(review): this happens even if the zone has no "default" entry
        # in zones.yaml — confirm that is intended.
        for view_name, view_data in views.items():
            if zname not in view_data["zones"]:
                view_data["zones"][zname] = {"in_view": "default"}

    LOGGER.info("Processed %s zones into directory %s", zcount, zone_dir)

    # --- Step 2: one catalog zone per view ---
    for view_name, view_data in views.items():
        catalog_zone = view_data["catalog"]
        catalog_zone_file = zone_dir / f"{catalog_zone}.zone"

        changed = write_catalog_zone(
            Path(CATALOG_ZONE_TEMPLATE_FILE),
            repo,
            args.git_dir,
            view_name,
            view_data["zones"],
            catalog_zone_file,
        )
        written_paths.add(catalog_zone_file)

        # Register the catalog zone itself (added after rendering, so it is
        # not listed as its own member) so it gets signed, validated and
        # configured like the other zones.
        view_data["zones"][f"{catalog_zone}."] = {
            "file": catalog_zone_file,
            "changed": changed,
        }

    LOGGER.info(
        "Generated %s catalog zones into directory %s",
        len(views),
        zone_dir,
    )

    # --- Step 3: sign where keys exist, then validate every zone file ---
    validate_success = True
    for view_name, view_data in views.items():
        for zone_name, zone_data in view_data["zones"].items():
            # "in_view" entries have no file of their own in this view
            if "file" not in zone_data:
                continue

            signer = dnssec.get_zone(zone_name)
            if signer.has_keys():
                signed_file = (
                    zone_data["file"].parent / f"{zone_data['file'].name}.signed"
                )

                # TODO: check if the zone needs re-signing?
                if zone_data.get("changed", False) or not signed_file.exists():
                    signer.sign_zone_file(zone_data["file"])

                written_paths.add(signed_file)
                # From here on (validation, BIND config) the signed file is used
                zone_data["file"] = signed_file

            result = subprocess.run(
                [
                    "/usr/bin/named-checkzone",
                    zone_name,
                    str(zone_data["file"]),
                ],
                cwd=BIND_WORKING_DIRECTORY,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
            )
            # Keep checking the remaining zones so all errors are reported at once
            if result.returncode != 0:
                print(result.stdout, file=sys.stderr, flush=True)
                validate_success = False
    if not validate_success:
        LOGGER.error("Found syntax errors, aborting")
        sys.exit(1)

    LOGGER.info("Validated syntax for all zone files")

    # --- Step 4: remove files this run did not produce ---
    # NOTE(review): Path.unlink() fails on directories; this assumes zone_dir
    # contains only files.
    for existing_file in zone_dir.glob("*"):
        if existing_file in written_paths:
            continue

        existing_file.unlink()
        LOGGER.warning("Removed unused zone file %s", existing_file)

    # --- Step 5: write the BIND config, check it and reload named ---
    bind_config = generate_zones_config(
        data_dir / "bind9.conf.j2", views, args.view_tsig_suffix
    )

    with Path(BIND_CONFIG_FILE).open("w") as f:
        f.write(bind_config)
        LOGGER.info("Wrote config file to %s", BIND_CONFIG_FILE)

    # check_call raises CalledProcessError (aborting before reload) on failure
    subprocess.check_call(["/usr/bin/named-checkconf"])
    subprocess.check_call(["/usr/sbin/rndc", "reload"])
    LOGGER.info("Reloaded named")
