import logging
import subprocess
import time
from collections.abc import Generator
from enum import IntEnum, StrEnum
from pathlib import Path
from typing import cast

import dns.zone
from dns.rdataclass import RdataClass
from dns.rdatatype import RdataType
from dns.rdtypes.ANY.RRSIG import RRSIG

LOGGER = logging.getLogger(__name__)


# A zone is re-signed once its SOA signature expires sooner than this window.
TWO_WEEKS_SECONDS = 14 * 24 * 60 * 60


class KeyType(StrEnum):
    KSK = "KSK"
    ZSK = "ZSK"


class Algorithm(IntEnum):
    """DNSSEC algorithm numbers, passed to ``dnssec-keygen -a``.

    They also appear zero-padded in key file names.
    See https://en.wikipedia.org/wiki/Domain_Name_System_Security_Extensions#Algorithms
    """

    ECDSA_SHA256 = 13  # ECDSAP256SHA256
    ED25519 = 15  # ED25519


# Algorithms for which a ZSK and a KSK are ensured before a zone is signed.
SUPPORTED_ALGORITHMS = [
    Algorithm.ECDSA_SHA256,
    Algorithm.ED25519,
]


class ZoneSigner:
    """Manage DNSSEC keys for one zone and sign its zone file with BIND tools.

    Key files are created in, and looked up from, :attr:`directory`.
    Key generation is delegated to ``dnssec-keygen`` and signing to
    ``dnssec-signzone``.
    """

    def __init__(
        self,
        zone_name: str,
        directory: Path,
    ) -> None:
        """Create a signer for *zone_name* that keeps its keys in *directory*.

        Args:
            zone_name: Zone name without a trailing dot; a dot is appended
                when the signed zone file is parsed.
            directory: Directory holding this zone's key files. It is created
                on demand by :meth:`ensure_keys`.
        """
        self.zone_name = zone_name
        self.directory = directory

        # get_keys_to_sign_with() sets this to True. It is not read inside this
        # class; presumably callers consult it (TODO confirm).
        self.signature_requested = False

    def generate_key(
        self,
        key_type: KeyType,
        algorithm: Algorithm,
    ) -> None:
        """Create a new *key_type* key for *algorithm* with ``dnssec-keygen``.

        The key files are written into :attr:`directory` (``-K``).

        Raises:
            subprocess.CalledProcessError: if ``dnssec-keygen`` exits non-zero.
        """
        LOGGER.info(
            "Generating %s %s DNSSEC key for %s",
            key_type.name,
            algorithm.name,
            self.zone_name,
        )

        subprocess.check_output(
            [
                "/usr/bin/dnssec-keygen",
                "-K",
                str(self.directory),
                "-f",
                key_type.value,
                "-a",
                str(algorithm.value),
                self.zone_name,
            ]
        )

    def get_key_files(self, algorithm: Algorithm | None = None) -> Generator[Path]:
        """Yield this zone's ``.key`` files in :attr:`directory`.

        Args:
            algorithm: If given, only keys of this algorithm are yielded.
        """
        return self.directory.glob(get_key_glob(self.zone_name, algorithm))

    def get_keys_to_sign_with(self) -> Generator[Path]:
        """Ensure keys exist for every supported algorithm, then return them all.

        This also sets :attr:`signature_requested` to ``True``.
        """
        for algorithm in SUPPORTED_ALGORITHMS:
            self.ensure_keys(algorithm=algorithm)
        self.signature_requested = True
        return self.get_key_files()

    def ensure_keys(self, algorithm: Algorithm) -> None:
        """Generate a ZSK and a KSK for *algorithm* unless any key already exists.

        Creates :attr:`directory`, including missing parents, if needed.

        Raises:
            subprocess.CalledProcessError: if key generation fails.
        """
        # exist_ok=True avoids the race between an exists() check and mkdir().
        self.directory.mkdir(parents=True, exist_ok=True)

        # Any key file of this algorithm counts as "has keys". A lone ZSK or
        # KSK is therefore not completed with its missing counterpart.
        if any(self.get_key_files(algorithm=algorithm)):
            LOGGER.debug(
                "Zone %s has existing %s DNSSEC keys",
                self.zone_name,
                algorithm.name,
            )
            # TODO: key rotation?
            return

        self.generate_key(key_type=KeyType.ZSK, algorithm=algorithm)
        self.generate_key(key_type=KeyType.KSK, algorithm=algorithm)

    def sign_zone_file(self, zone_unsigned: Path, zone_signed: Path) -> None:
        """Sign *zone_unsigned* into *zone_signed* with all of this zone's keys.

        Raises:
            subprocess.CalledProcessError: if ``dnssec-signzone`` exits non-zero.
        """
        keys = [str(key) for key in self.get_key_files()]

        subprocess.check_output(
            [
                "/usr/bin/dnssec-signzone",
                "-K",
                str(self.directory),
                "-o",
                self.zone_name,
                "-f",
                str(zone_signed),
                str(zone_unsigned),
                *keys,
            ],
            cwd=str(self.directory),
        )

    def needs_resigning(self, signed_file: Path) -> bool:
        """Return whether *signed_file*'s apex SOA signature expires within two weeks.

        Returns ``False`` when no expiration can be found in the file.
        """
        expiration = get_expiry(
            signed_zone_file=signed_file, origin=f"{self.zone_name}."
        )
        if expiration is None:
            LOGGER.debug("No expiration for zone %s (%s)", signed_file, self.zone_name)
            return False

        remaining = expiration - time.time()
        LOGGER.debug(
            "Zone %s signature expires in %s seconds",
            self.zone_name,
            remaining,
        )
        return remaining < TWO_WEEKS_SECONDS

    def get_extra_serial(self, zone_file: Path, current_serial: int) -> int:
        """Return the re-signing counter for *zone_file* at *current_serial*.

        The counter is persisted as ``"<serial> <counter>"`` in the bump file
        next to *zone_file*.
        - If the signed file does not exist, returns ``0``.
        - A stored counter is reused only if it was written for
          *current_serial*; otherwise the bump file is deleted and the counter
          restarts at ``0``.
        - If the signature is close to expiry, the counter is incremented and
          written back.

        Presumably the caller adds the result to the SOA serial to force a
        re-sign (TODO confirm).

        Raises:
            ValueError: if the bump file content is malformed.
        """
        signature_file = get_signature_file(zone_file)
        if not signature_file.exists():
            LOGGER.debug(
                "Signature file %s does not exist, zone %s not signed",
                signature_file,
                self.zone_name,
            )
            return 0

        bump_file = get_bump_file(zone_file)

        counter = 0
        if bump_file.exists():
            # A malformed file raises ValueError rather than silently
            # resetting the counter, which could move the serial backwards.
            expected_serial, counter_str = bump_file.read_text().split()
            if expected_serial == str(current_serial):
                counter = int(counter_str)
            else:
                LOGGER.debug(
                    "Removing %s, serial changed from %s to %s",
                    bump_file,
                    expected_serial,
                    current_serial,
                )
                bump_file.unlink()

        if self.needs_resigning(signature_file):
            LOGGER.info("Zone %s needs a new DNSSEC signature", self.zone_name)
            # This is safe: we're by definition dealing with very
            # old files here, so any newer updates will have much higher
            # serials than this one.
            counter += 1
            bump_file.write_text(f"{current_serial} {counter}\n")

        return counter


class Dnssec:
    """Cache of per-zone :class:`ZoneSigner` objects under one base directory."""

    def __init__(self, base_dir: Path) -> None:
        """Start with no zones; each zone's keys go in ``base_dir / <zone>``."""
        self.base_dir = base_dir
        self.zones: dict[str, ZoneSigner] = {}

    def get_zone(self, name: str) -> ZoneSigner:
        """Return the signer for *name*, creating it on first use.

        A trailing dot on *name* is ignored, so ``"a.b."`` and ``"a.b"``
        share one signer.
        """
        key = name.removesuffix(".")
        signer = self.zones.get(key)
        if signer is None:
            signer = ZoneSigner(key, self.base_dir / key)
            self.zones[key] = signer
        return signer


def get_expiry(signed_zone_file: Path, origin: str) -> int | None:
    """Return the earliest expiration of the apex SOA signatures in a zone file.

    Args:
        signed_zone_file: Zone file to parse, e.g. output of ``dnssec-signzone``.
        origin: Absolute zone origin with a trailing dot, e.g. ``"example.com."``.

    Returns:
        The smallest ``expiration`` among the RRSIG records covering the apex
        SOA, or ``None`` if there is no apex node or no such RRSIG.
    """
    with signed_zone_file.open("r") as f:
        zone = dns.zone.from_file(f, origin=origin)

    apex = zone.get_node(origin)
    if not apex:
        LOGGER.warning("Found no apex record '%s'", origin)
        return None

    # get_rdataset() returns None for a missing rdataset. Unlike
    # find_rdataset(create=True), it does not add an empty rdataset to the node.
    rrsig = apex.get_rdataset(
        rdclass=RdataClass.IN,
        rdtype=RdataType.RRSIG,
        covers=RdataType.SOA,
    )
    if rrsig is None:
        return None

    # With several signing algorithms the SOA carries several signatures;
    # the earliest-expiring one decides when re-signing is needed.
    return min(
        (cast(RRSIG, record).expiration for record in rrsig),  # type: ignore
        default=None,
    )


def get_signature_file(path: Path) -> Path:
    """Return the signed counterpart of *path*: ``<name>.signed`` in the same directory."""
    signed_name = path.name + ".signed"
    return path.parent.joinpath(signed_name)


def get_bump_file(path: Path) -> Path:
    """Return the re-signing counter file for *path*.

    The file is ``<name>.signature-counter`` in the same directory.
    """
    counter_name = path.name + ".signature-counter"
    return path.parent.joinpath(counter_name)


def get_key_glob(zone_name: str, algorithm: Algorithm | None) -> str:
    """Return a glob matching ``dnssec-keygen`` ``.key`` files for *zone_name*.

    Key files are named ``K<zone>.+<algorithm:03>+<keyid>.key``. When
    *algorithm* is given, only that algorithm's keys match. A trailing dot
    on *zone_name* is ignored.
    """
    zone = zone_name.removesuffix(".")
    if algorithm is None:
        # Anchor on "+" so that e.g. "Kexample.*.key" cannot also match the
        # key files of a longer name such as "example.com".
        return f"K{zone}.+*.key"
    return f"K{zone}.+{algorithm.value:03}+*.key"
