package ns

import (
	"log/slog"
	"net"
	"time"

	"github.com/miekg/dns"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"

	"git.majava.org/software/edif/config"
)

// timeout is the limit applied to every exchange made by the Querier's
// dns.Client (set in NewQuerier). The client treats it as one budget
// covering dial, write and read combined.
const timeout = 5 * time.Second

// Querier will do DNS queries against nameservers in a zone.
//
// It bundles the DNS client used for every query with the Prometheus gauge
// vectors that the zone checks write to. Construct it with NewQuerier; the
// zero value has a nil client and nil gauges and is not usable.
type Querier struct {
	// c performs every DNS exchange. recursor is the address passed to
	// NewQuerier; it is not read in this part of the file (presumably used
	// when locating a zone's parent name server — TODO confirm).
	c        *dns.Client
	recursor string

	// Gauge vectors written by the checks. serial holds the SOA serial per
	// (zone, nameserver), set to 0 when that server could not be queried;
	// nameserverNs and nameserverParent hold 1 per NS host published by the
	// zone's own server and by the parent respectively; dnssecExpiry holds
	// the earliest RRSIG expiration seen at the zone apex; dnssecSignedKey
	// and dnssecParentKey hold 1 per SHA-256 DS digest derived from the
	// zone's DNSKEYs and published by the parent respectively.
	serial           *prometheus.GaugeVec
	nameserverNs     *prometheus.GaugeVec
	nameserverParent *prometheus.GaugeVec
	dnssecExpiry     *prometheus.GaugeVec
	dnssecSignedKey  *prometheus.GaugeVec
	dnssecParentKey  *prometheus.GaugeVec
}

// NewQuerier builds a Querier whose DNS client abandons an exchange after
// timeout. It creates the six zone_* gauge vectors through promauto.With(reg).
// promauto panics if registration fails, for example when the same metric
// names are already registered with reg. A nil reg leaves the gauges
// unregistered.
//
// recursor is stored unchanged; it is not read in this part of the file
// (presumably the resolver used to find parent servers — TODO confirm).
func NewQuerier(reg prometheus.Registerer, recursor string) Querier {
	factory := promauto.With(reg)

	// gaugeVec creates (and, via factory, registers) one gauge vector.
	gaugeVec := func(name, help string, labels ...string) *prometheus.GaugeVec {
		return factory.NewGaugeVec(prometheus.GaugeOpts{Name: name, Help: help}, labels)
	}

	q := Querier{
		c:        &dns.Client{Timeout: timeout},
		recursor: recursor,
	}

	q.serial = gaugeVec("zone_serial",
		"Current serial served by that specific name server",
		"zone", "nameserver")
	q.nameserverNs = gaugeVec("zone_nameserver_ns",
		"Name server listed in an NS record",
		"zone", "nameserver")
	q.nameserverParent = gaugeVec("zone_nameserver_parent",
		"Name server listed in a parent record",
		"zone", "nameserver")
	q.dnssecExpiry = gaugeVec("zone_dnssec_expiry",
		"Earliest timestamp of DNSSEC signature record expiry",
		"zone")
	q.dnssecSignedKey = gaugeVec("zone_dnssec_signed_key",
		"DNSSEC signing key present in the zone file",
		"zone", "key_digest")
	q.dnssecParentKey = gaugeVec("zone_dnssec_parent_key",
		"DNSSEC signing key present in the parent zone",
		"zone", "key_digest")

	return q
}

// checkConfiguredNameservers sends a non-recursive SOA query for the zone to
// every configured name server on port 53. It records each server's serial in
// the zone_serial gauge, or 0 when the query errors, returns a non-success
// rcode, or does not return exactly one SOA answer.
//
// It returns the host:port address of the first server that produced a
// usable SOA answer, or "" when none did.
func (q *Querier) checkConfiguredNameservers(zone config.Zone) string {
	logger := slog.With("zone", zone.Name)
	firstWorking := ""

	for idx, server := range zone.Nameservers {
		logger.Debug("querying name server", "index", idx, "ns", server.Hostname)

		labels := prometheus.Labels{"zone": zone.Name, "nameserver": server.Hostname}
		addr := net.JoinHostPort(server.Hostname, "53")

		// fetchSOA performs the query and reports whether a usable SOA came
		// back, logging the reason whenever it did not.
		soa, usable := func() (*dns.SOA, bool) {
			query := &dns.Msg{}
			query.SetQuestion(dns.Fqdn(zone.Name), dns.TypeSOA)
			query.RecursionDesired = false

			reply, _, err := q.c.Exchange(query, addr)
			switch {
			case err != nil:
				logger.Warn("failed to query nameserver", "ns", addr, "error", err)
				return nil, false
			case reply.Rcode != dns.RcodeSuccess:
				logger.Warn("failed to query nameserver", "ns", addr, "rcode", dns.RcodeToString[reply.Rcode])
				return nil, false
			case len(reply.Answer) != 1:
				logger.Warn("got unexpected number of answers", "ns", addr, "answers", reply.Answer)
				return nil, false
			}

			record, isSOA := reply.Answer[0].(*dns.SOA)
			if !isSOA {
				logger.Warn("failed to parse record", "ns", addr, "record", reply.Answer)
			}
			return record, isSOA
		}()

		if !usable {
			q.serial.With(labels).Set(0)
			continue
		}

		logger.Debug("got serial", "ns", server.Hostname, "serial", soa.Serial)
		q.serial.With(labels).Set(float64(soa.Serial))

		if firstWorking == "" {
			firstWorking = addr
		}
	}

	return firstWorking
}

// checkPublishedNameservers asks the zone's own name server at ns for the
// zone's NS RRset. ns is a host:port address, as returned by
// checkConfiguredNameservers. The result is mirrored into the
// zone_nameserver_ns gauge as one series with value 1 per listed name server.
//
// The zone's previous series are removed only after a successful answer, so
// a failed query leaves the last known state in place. Answer records that
// are not NS records are logged and skipped.
func (q *Querier) checkPublishedNameservers(zone config.Zone, ns string) {
	logger := slog.With("zone", zone.Name)
	logger.Debug("checking published name servers", "ns", ns)

	msg := &dns.Msg{}
	msg.SetQuestion(dns.Fqdn(zone.Name), dns.TypeNS)
	// SetQuestion turns RD on; we want the server's own data, not recursion.
	msg.RecursionDesired = false

	r, _, err := q.c.Exchange(msg, ns)
	if err != nil {
		// ns is one of the zone's configured name servers, not the parent.
		logger.Warn("failed to query nameserver", "ns", ns, "error", err)
		return
	}
	if r.Rcode != dns.RcodeSuccess {
		logger.Warn("failed to query nameserver", "ns", ns, "rcode", dns.RcodeToString[r.Rcode])
		return
	}

	q.nameserverNs.DeletePartialMatch(prometheus.Labels{"zone": zone.Name})

	for _, rr := range r.Answer {
		record, ok := rr.(*dns.NS)
		if !ok {
			logger.Warn("failed to parse record", "ns", ns, "record", rr)
			continue
		}

		q.nameserverNs.With(prometheus.Labels{"zone": zone.Name, "nameserver": record.Ns}).Set(1)
	}
}

// checkDelegation looks up the zone's parent name server with
// getParentNameserver. It then asks that server, on port 53, which NS and DS
// records it publishes for the zone. The results are mirrored into the
// zone_nameserver_parent gauge (1 per name server) and the
// zone_dnssec_parent_key gauge (1 per SHA-256 DS digest). DS records using
// other digest types are skipped.
//
// Every query is sent with the RD bit cleared, so the parent answers from
// its own data. A gauge's previous series for the zone are removed only once
// the corresponding query succeeds. A failed query is logged and ends the
// check, leaving the last known state in place.
func (q *Querier) checkDelegation(zone config.Zone) {
	logger := slog.With("zone", zone.Name)

	parentNs, err := q.getParentNameserver(zone)
	if err != nil {
		logger.Warn("failed to query parent server", "error", err)
		return
	}

	logger.Debug("checking delegation", "ns", parentNs)

	address := net.JoinHostPort(parentNs, "53")

	// query sends a non-recursive question of type qtype for the zone to the
	// parent. It returns the reply, or nil after logging when the exchange
	// fails or the parent answers with a non-success rcode.
	query := func(qtype uint16) *dns.Msg {
		msg := &dns.Msg{}
		msg.SetQuestion(dns.Fqdn(zone.Name), qtype)
		// SetQuestion sets RD on every call, so it must be cleared after each one.
		msg.RecursionDesired = false

		r, _, err := q.c.Exchange(msg, address)
		if err != nil {
			logger.Warn("failed to query parent server", "ns", parentNs, "error", err)
			return nil
		}
		if r.Rcode != dns.RcodeSuccess {
			logger.Warn("failed to query parent server", "ns", parentNs, "rcode", dns.RcodeToString[r.Rcode])
			return nil
		}
		return r
	}

	r := query(dns.TypeNS)
	if r == nil {
		return
	}

	q.nameserverParent.DeletePartialMatch(prometheus.Labels{"zone": zone.Name})
	// A referral carries the delegation NS set in the authority section. A
	// parent that is also authoritative for the zone returns it as an answer,
	// so check both.
	for _, section := range [][]dns.RR{r.Ns, r.Answer} {
		for _, rr := range section {
			record, ok := rr.(*dns.NS)
			if !ok {
				logger.Warn("failed to parse record", "ns", parentNs, "record", rr)
				continue
			}

			q.nameserverParent.With(prometheus.Labels{"zone": zone.Name, "nameserver": record.Ns}).Set(1)
		}
	}

	r = query(dns.TypeDS)
	if r == nil {
		return
	}

	q.dnssecParentKey.DeletePartialMatch(prometheus.Labels{"zone": zone.Name})
	for _, rr := range r.Answer {
		record, ok := rr.(*dns.DS)
		if !ok {
			logger.Warn("failed to parse record", "ns", parentNs, "record", rr)
			continue
		}

		if record.DigestType != dns.SHA256 {
			// Only SHA-256 digests are tracked, matching the digests computed
			// for the zone's own keys; other digest types are ignored.
			continue
		}

		q.dnssecParentKey.With(prometheus.Labels{"zone": zone.Name, "key_digest": record.Digest}).Set(1)
	}
}

// checkDnssec inspects the zone's DNSSEC state as served by ns, a host:port
// address. It updates the zone_dnssec_expiry and zone_dnssec_signed_key
// gauges.
//
// It first asks for the RRSIG records at the zone apex and records the
// earliest signature expiration. If no signature comes back, the zone is
// treated as unsigned and both gauges' series for the zone are removed.
// Otherwise it fetches the apex DNSKEY RRset and records the SHA-256 DS
// digest of every key, replacing the zone's previous digests.
//
// All queries are sent with RD cleared. A failed query is logged and leaves
// the gauges it would have updated unchanged.
func (q *Querier) checkDnssec(zone config.Zone, ns string) {
	logger := slog.With("zone", zone.Name)

	// query sends a non-recursive question of type qtype for the zone apex to
	// ns. It returns the reply, or nil after logging when the exchange fails
	// or the server answers with a non-success rcode.
	query := func(qtype uint16) *dns.Msg {
		msg := &dns.Msg{}
		msg.SetQuestion(dns.Fqdn(zone.Name), qtype)
		// SetQuestion sets RD on every call, so it must be cleared after each one.
		msg.RecursionDesired = false

		r, _, err := q.c.Exchange(msg, ns)
		if err != nil {
			logger.Warn("failed to query server", "ns", ns, "error", err)
			return nil
		}
		if r.Rcode != dns.RcodeSuccess {
			logger.Warn("failed to query server", "ns", ns, "rcode", dns.RcodeToString[r.Rcode])
			return nil
		}
		return r
	}

	r := query(dns.TypeRRSIG)
	if r == nil {
		return
	}

	// Earliest RRSIG expiration at the apex; 0 doubles as "none seen".
	// NOTE(review): compared as plain uint32 rather than RFC 4034 serial
	// arithmetic, which is fine until the 32-bit timestamp wraps in 2106.
	expiration := uint32(0)
	for _, rr := range r.Answer {
		record, ok := rr.(*dns.RRSIG)
		if !ok {
			logger.Warn("failed to parse record", "ns", ns, "record", rr)
			continue
		}

		if expiration == 0 || expiration > record.Expiration {
			expiration = record.Expiration
		}
	}

	if expiration == 0 {
		q.dnssecExpiry.DeletePartialMatch(prometheus.Labels{"zone": zone.Name})
		q.dnssecSignedKey.DeletePartialMatch(prometheus.Labels{"zone": zone.Name})
		return
	}

	q.dnssecExpiry.With(prometheus.Labels{"zone": zone.Name}).Set(float64(expiration))

	r = query(dns.TypeDNSKEY)
	if r == nil {
		return
	}

	q.dnssecSignedKey.DeletePartialMatch(prometheus.Labels{"zone": zone.Name})
	for _, rr := range r.Answer {
		record, ok := rr.(*dns.DNSKEY)
		if !ok {
			logger.Warn("failed to parse record", "ns", ns, "record", rr)
			continue
		}

		// NOTE(review): every DNSKEY is hashed, including zone-signing keys
		// that normally have no DS in the parent — confirm consumers expect it.
		ds := record.ToDS(dns.SHA256)
		if ds == nil {
			// ToDS returns nil when the key cannot be packed and hashed.
			logger.Warn("failed to compute key digest", "ns", ns, "record", rr)
			continue
		}
		q.dnssecSignedKey.With(prometheus.Labels{"zone": zone.Name, "key_digest": ds.Digest}).Set(1)
	}
}

// FetchData runs one round of checks for zone. It first polls every
// configured name server for its SOA serial. Using the first server that
// answered, it then records the NS set that server publishes and the
// parent's delegation (both skipped when zone.NoDelegation is set), followed
// by the zone's DNSSEC signature and key state.
//
// When no configured name server answered usefully, it logs an error and
// skips the remaining checks.
func (q *Querier) FetchData(zone config.Zone) {
	reachable := q.checkConfiguredNameservers(zone)

	switch {
	case reachable == "":
		slog.With("zone", zone.Name).Error("did not find a working name server")
		return
	case !zone.NoDelegation:
		q.checkPublishedNameservers(zone, reachable)
		q.checkDelegation(zone)
	}

	q.checkDnssec(zone, reachable)
}
