1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
|
package file
import (
"fmt"
"io"
"log"
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
"golang.org/x/net/context"
)
// File carries the zones this package serves together with the handler
// that ServeDNS delegates to when a query matches none of those zones.
type File struct {
	Next  middleware.Handler
	Zones Zones
}

// Zones groups the served zones with the list of names that incoming
// query names are matched against.
type Zones struct {
	// Z is looked up with the name that matching against Names returned.
	Z map[string]*Zone
	// Names is the list handed to middleware.Zones for matching.
	Names []string
}
// ServeDNS answers r from the zone obtained by matching the query name
// against f.Zones.Names. Queries that match no zone, or whose zone has no
// entry in f.Zones.Z, are handed to f.Next.
//
// Only class INET is accepted; any other class yields RcodeServerFailure and
// an error. A nil zone entry, an expired zone, or a ServerFailure lookup
// result yield RcodeServerFailure without an error. NOTIFY messages are
// acknowledged (and may start an inbound transfer) when z.isNotify accepts
// them and are dropped with a log line otherwise. AXFR and IXFR questions
// are delegated to Xfr.
func (f File) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
	state := middleware.State{W: w, Req: r}

	if state.QClass() != dns.ClassINET {
		return dns.RcodeServerFailure, fmt.Errorf("can only deal with ClassINET")
	}

	qname := state.Name()
	zone := middleware.Zones(f.Zones.Names).Matches(qname)
	if zone == "" {
		return f.Next.ServeDNS(ctx, w, r)
	}

	z, found := f.Zones.Z[zone]
	if !found {
		return f.Next.ServeDNS(ctx, w, r)
	}
	if z == nil {
		return dns.RcodeServerFailure, nil
	}

	if r.Opcode == dns.OpcodeNotify {
		if !z.isNotify(state) {
			log.Printf("[INFO] Dropping notify from %s for %s", state.IP(), zone)
			return dns.RcodeSuccess, nil
		}

		// Acknowledge the notify before checking whether a transfer is needed.
		ack := new(dns.Msg)
		ack.SetReply(r)
		ack.Authoritative, ack.RecursionAvailable, ack.Compress = true, true, true
		state.SizeAndDo(ack)
		w.WriteMsg(ack)

		log.Printf("[INFO] Notify from %s for %s: checking transfer", state.IP(), zone)
		transfer, err := z.shouldTransfer()
		if !transfer {
			log.Printf("[INFO] Notify from %s for %s: no serial increase seen", state.IP(), zone)
		} else {
			z.TransferIn()
		}
		// The error is reported independently of the transfer decision above.
		if err != nil {
			log.Printf("[WARNING] Notify from %s for %s: failed primary check: %s", state.IP(), zone, err)
		}
		return dns.RcodeSuccess, nil
	}

	if z.Expired != nil && *z.Expired {
		log.Printf("[ERROR] Zone %s is expired", zone)
		return dns.RcodeServerFailure, nil
	}

	switch state.QType() {
	case dns.TypeAXFR, dns.TypeIXFR:
		transferOut := Xfr{z}
		return transferOut.ServeDNS(ctx, w, r)
	}

	answer, ns, extra, result := z.Lookup(qname, state.QType(), state.Do())

	reply := new(dns.Msg)
	reply.SetReply(r)
	reply.Authoritative, reply.RecursionAvailable, reply.Compress = true, true, true

	switch result {
	case Success:
		reply.Answer, reply.Ns, reply.Extra = answer, ns, extra
	case NameError:
		reply.Rcode = dns.RcodeNameError
		reply.Ns = ns
	case NoData:
		reply.Ns = ns
	case ServerFailure:
		return dns.RcodeServerFailure, nil
	}

	state.SizeAndDo(reply)
	reply, _ = state.Scrub(reply)
	w.WriteMsg(reply)
	return dns.RcodeSuccess, nil
}
// Parse reads zone data from f with dns.ParseZone, using the fully qualified
// form of origin and fileName, and collects the records into a new Zone.
//
// The SOA record is stored in the zone's SOA field and RRSIGs covering the
// SOA are appended to its SIG field; every other record goes through Insert.
// A parse error, or the presence of an NSEC3 or NSEC3PARAM record, is logged
// and returned with a nil Zone.
func Parse(f io.Reader, origin, fileName string) (*Zone, error) {
	tokens := dns.ParseZone(f, dns.Fqdn(origin), fileName)
	z := NewZone(origin)

	for tok := range tokens {
		if tok.Error != nil {
			log.Printf("[ERROR] Failed to parse `%s': %v", origin, tok.Error)
			return nil, tok.Error
		}

		rr := tok.RR
		switch rr.Header().Rrtype {
		case dns.TypeSOA:
			z.SOA = rr.(*dns.SOA)

		case dns.TypeNSEC3, dns.TypeNSEC3PARAM:
			err := fmt.Errorf("NSEC3 zone is not supported, dropping")
			log.Printf("[ERROR] Failed to parse `%s': %v", origin, err)
			return nil, err

		case dns.TypeRRSIG:
			// Signatures over the SOA are kept apart from the other records.
			sig, isSig := rr.(*dns.RRSIG)
			if isSig && sig.TypeCovered == dns.TypeSOA {
				z.SIG = append(z.SIG, sig)
			} else {
				z.Insert(rr)
			}

		default:
			z.Insert(rr)
		}
	}

	return z, nil
}
|