aboutsummaryrefslogtreecommitdiff
path: root/middleware/file/file.go
blob: 50ae3fd26315872c17b17de8ca8ee645adb8d686 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package file

import (
	"fmt"
	"io"
	"log"

	"github.com/miekg/coredns/middleware"

	"github.com/miekg/dns"
	"golang.org/x/net/context"
)

// File is a DNS handler that answers queries from the zones held in Zones
// and hands queries for names outside those zones to Next.
type File struct {
	Next  middleware.Handler
	Zones Zones
}

// Zones is the set of zones served by a File. Z maps a zone name to its
// loaded data; Names lists the zone names that query names are matched
// against.
type Zones struct {
	Z     map[string]*Zone
	Names []string
}

// ServeDNS answers a DNS query from the zones loaded into f.Zones.
//
// Only class INET is supported. A query whose name does not fall inside one
// of f.Zones.Names, or whose matched zone has no entry in f.Zones.Z, is
// passed to f.Next. NOTIFY messages are acknowledged when z.isNotify accepts
// them and may trigger an inbound transfer via z.TransferIn. AXFR/IXFR
// queries are delegated to Xfr. Everything else is answered with z.Lookup.
//
// It returns the response rcode and an error. The paths that return
// dns.RcodeServerFailure do so without writing a message.
// NOTE(review): presumably the caller sends the SERVFAIL reply in that case;
// confirm against the middleware package.
func (f File) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
	state := middleware.State{W: w, Req: r}

	if state.QClass() != dns.ClassINET {
		return dns.RcodeServerFailure, fmt.Errorf("can only deal with ClassINET")
	}
	qname := state.Name()
	zone := middleware.Zones(f.Zones.Names).Matches(qname)
	if zone == "" {
		return f.Next.ServeDNS(ctx, w, r)
	}
	z, ok := f.Zones.Z[zone]
	if !ok {
		return f.Next.ServeDNS(ctx, w, r)
	}
	if z == nil {
		// The zone name is configured but its entry is nil
		// (presumably not loaded yet).
		return dns.RcodeServerFailure, nil
	}

	if r.Opcode == dns.OpcodeNotify {
		if !z.isNotify(state) {
			log.Printf("[INFO] Dropping notify from %s for %s", state.IP(), zone)
			return dns.RcodeSuccess, nil
		}

		// Acknowledge the NOTIFY first; the transfer check runs after the
		// reply has been written.
		m := new(dns.Msg)
		m.SetReply(r)
		m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true
		state.SizeAndDo(m)
		w.WriteMsg(m)

		log.Printf("[INFO] Notify from %s for %s: checking transfer", state.IP(), zone)
		transfer, err := z.shouldTransfer()
		if transfer {
			z.TransferIn()
		} else if err == nil {
			// Only claim "no serial increase" when the check actually succeeded;
			// a failed check is reported by the warning below instead.
			log.Printf("[INFO] Notify from %s for %s: no serial increase seen", state.IP(), zone)
		}
		if err != nil {
			log.Printf("[WARNING] Notify from %s for %s: failed primary check: %s", state.IP(), zone, err)
		}
		return dns.RcodeSuccess, nil
	}

	if z.Expired != nil && *z.Expired {
		log.Printf("[ERROR] Zone %s is expired", zone)
		return dns.RcodeServerFailure, nil
	}

	if state.QType() == dns.TypeAXFR || state.QType() == dns.TypeIXFR {
		xfr := Xfr{z}
		return xfr.ServeDNS(ctx, w, r)
	}

	answer, ns, extra, result := z.Lookup(qname, state.QType(), state.Do())

	m := new(dns.Msg)
	m.SetReply(r)
	m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true

	switch result {
	case Success:
		m.Answer = answer
		m.Ns = ns
		m.Extra = extra
	case NameError:
		m.Ns = ns
		m.Rcode = dns.RcodeNameError
	case NoData:
		m.Ns = ns
	case ServerFailure:
		return dns.RcodeServerFailure, nil
	}

	state.SizeAndDo(m)
	// NOTE(review): the error from Scrub is discarded; confirm the returned
	// message is always safe to write.
	m, _ = state.Scrub(m)
	w.WriteMsg(m)
	return dns.RcodeSuccess, nil
}

// Parse reads a zone in master-file format from f and returns a new Zone,
// or an error.
//
// origin is the zone's name; it is made fully qualified before being handed
// to dns.ParseZone. fileName is also passed to dns.ParseZone (presumably for
// error messages and $INCLUDE handling; confirm against miekg/dns).
//
// SOA records become z.SOA. RRSIGs covering the SOA are collected in z.SIG.
// Every other record is added with z.Insert. Parsing stops with an error on
// the first parse error, and on any NSEC3/NSEC3PARAM record, since NSEC3
// zones are not supported.
func Parse(f io.Reader, origin, fileName string) (*Zone, error) {
	tokens := dns.ParseZone(f, dns.Fqdn(origin), fileName)
	// dns.ParseZone delivers tokens over a channel filled by a producer
	// goroutine. If we return early, drain the channel so that goroutine is
	// not left blocked on a send. On the success path the channel has already
	// been consumed to completion and this loop exits immediately.
	defer func() {
		for range tokens {
		}
	}()

	z := NewZone(origin)
	for tok := range tokens {
		if tok.Error != nil {
			log.Printf("[ERROR] Failed to parse `%s': %v", origin, tok.Error)
			return nil, tok.Error
		}
		switch tok.RR.Header().Rrtype {
		case dns.TypeSOA:
			z.SOA = tok.RR.(*dns.SOA)
		case dns.TypeNSEC3, dns.TypeNSEC3PARAM:
			err := fmt.Errorf("NSEC3 zone is not supported, dropping")
			log.Printf("[ERROR] Failed to parse `%s': %v", origin, err)
			return nil, err
		case dns.TypeRRSIG:
			// Signatures over the SOA are kept on the zone instead of being
			// inserted; all other RRSIGs fall through to Insert.
			if sig, ok := tok.RR.(*dns.RRSIG); ok && sig.TypeCovered == dns.TypeSOA {
				z.SIG = append(z.SIG, sig)
				continue
			}
			fallthrough
		default:
			z.Insert(tok.RR)
		}
	}
	return z, nil
}