aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/setup/file.go34
-rw-r--r--middleware/etcd/cname_test.go96
-rw-r--r--middleware/etcd/handler.go5
-rw-r--r--middleware/file/file.go94
-rw-r--r--middleware/file/file.md40
-rw-r--r--middleware/file/lookup.go75
-rw-r--r--middleware/file/lookup_test.go174
-rw-r--r--middleware/file/tree/tree.go585
-rw-r--r--middleware/file/zone.go26
-rw-r--r--middleware/file/zone_test.go30
-rw-r--r--middleware/loadbalance/loadbalance.go23
-rw-r--r--middleware/loadbalance/loadbalance.md4
-rw-r--r--middleware/loadbalance/loadbalance_test.go104
-rw-r--r--middleware/proxy/reverseproxy.go3
-rw-r--r--middleware/state.go48
15 files changed, 1239 insertions, 102 deletions
diff --git a/core/setup/file.go b/core/setup/file.go
index 0b85d84f3..858b784c2 100644
--- a/core/setup/file.go
+++ b/core/setup/file.go
@@ -1,13 +1,10 @@
package setup
import (
- "log"
"os"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/file"
-
- "github.com/miekg/dns"
)
// File sets up the file middleware.
@@ -23,8 +20,7 @@ func File(c *Controller) (middleware.Middleware, error) {
}
func fileParse(c *Controller) (file.Zones, error) {
- // Maybe multiple, each for each zone.
- z := make(map[string]file.Zone)
+ z := make(map[string]*file.Zone)
names := []string{}
for c.Next() {
if c.Val() == "file" {
@@ -42,7 +38,12 @@ func fileParse(c *Controller) (file.Zones, error) {
// normalize this origin
origin = middleware.Host(origin).StandardHost()
- zone, err := parseZone(origin, fileName)
+ reader, err := os.Open(fileName)
+ if err != nil {
+ return file.Zones{}, err
+ }
+
+ zone, err := file.Parse(reader, origin, fileName)
if err == nil {
z[origin] = zone
}
@@ -51,24 +52,3 @@ func fileParse(c *Controller) (file.Zones, error) {
}
return file.Zones{Z: z, Names: names}, nil
}
-
-//
-// parsrZone parses the zone in filename and returns a []RR or an error.
-func parseZone(origin, fileName string) (file.Zone, error) {
- f, err := os.Open(fileName)
- if err != nil {
- return nil, err
- }
- tokens := dns.ParseZone(f, origin, fileName)
- zone := make([]dns.RR, 0, defaultZoneSize)
- for x := range tokens {
- if x.Error != nil {
- log.Printf("[ERROR] failed to parse %s: %v", origin, x.Error)
- return nil, x.Error
- }
- zone = append(zone, x.RR)
- }
- return file.Zone(zone), nil
-}
-
-const defaultZoneSize = 20 // A made up number.
diff --git a/middleware/etcd/cname_test.go b/middleware/etcd/cname_test.go
new file mode 100644
index 000000000..7d53bfef6
--- /dev/null
+++ b/middleware/etcd/cname_test.go
@@ -0,0 +1,96 @@
+// +build etcd
+
+package etcd
+
+// etcd needs to be running on http://127.0.0.1:2379
+
+import (
+ "testing"
+
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/coredns/middleware/etcd/msg"
+
+ "github.com/miekg/dns"
+)
+
+// Check the ordering of returned cname.
+func TestCnameLookup(t *testing.T) {
+ for _, serv := range servicesCname {
+ set(t, etc, serv.Key, 0, serv)
+ defer delete(t, etc, serv.Key)
+ }
+ for _, tc := range dnsTestCasesCname {
+ m := new(dns.Msg)
+ m.SetQuestion(dns.Fqdn(tc.Qname), tc.Qtype)
+
+ rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{})
+ _, err := etc.ServeDNS(ctx, rec, m)
+ if err != nil {
+ t.Errorf("expected no error, got %v\n", err)
+ return
+ }
+ resp := rec.Msg()
+
+ if resp.Rcode != tc.Rcode {
+ t.Errorf("rcode is %q, expected %q", dns.RcodeToString[resp.Rcode], dns.RcodeToString[tc.Rcode])
+ t.Logf("%v\n", resp)
+ continue
+ }
+
+ if len(resp.Answer) != len(tc.Answer) {
+ t.Errorf("answer for %q contained %d results, %d expected", tc.Qname, len(resp.Answer), len(tc.Answer))
+ t.Logf("%v\n", resp)
+ continue
+ }
+ if len(resp.Ns) != len(tc.Ns) {
+ t.Errorf("authority for %q contained %d results, %d expected", tc.Qname, len(resp.Ns), len(tc.Ns))
+ t.Logf("%v\n", resp)
+ continue
+ }
+ if len(resp.Extra) != len(tc.Extra) {
+ t.Errorf("additional for %q contained %d results, %d expected", tc.Qname, len(resp.Extra), len(tc.Extra))
+ t.Logf("%v\n", resp)
+ continue
+ }
+
+ if !checkSection(t, tc, Answer, resp.Answer) {
+ t.Logf("%v\n", resp)
+ }
+ if !checkSection(t, tc, Ns, resp.Ns) {
+ t.Logf("%v\n", resp)
+
+ }
+ if !checkSection(t, tc, Extra, resp.Extra) {
+ t.Logf("%v\n", resp)
+ }
+ }
+}
+
+var servicesCname = []*msg.Service{
+ {Host: "cname1.region2.skydns.test", Key: "a.server1.dev.region1.skydns.test."},
+ {Host: "cname2.region2.skydns.test", Key: "cname1.region2.skydns.test."},
+ {Host: "cname3.region2.skydns.test", Key: "cname2.region2.skydns.test."},
+ {Host: "cname4.region2.skydns.test", Key: "cname3.region2.skydns.test."},
+ {Host: "cname5.region2.skydns.test", Key: "cname4.region2.skydns.test."},
+ {Host: "cname6.region2.skydns.test", Key: "cname5.region2.skydns.test."},
+ {Host: "endpoint.region2.skydns.test", Key: "cname6.region2.skydns.test."},
+ {Host: "10.240.0.1", Key: "endpoint.region2.skydns.test."},
+}
+
+var dnsTestCasesCname = []dnsTestCase{
+ {
+ Qname: "a.server1.dev.region1.skydns.test.", Qtype: dns.TypeSRV,
+ Answer: []dns.RR{
+ newSRV("a.server1.dev.region1.skydns.test. 300 IN SRV 10 100 0 cname1.region2.skydns.test."),
+ },
+ Extra: []dns.RR{
+ newCNAME("cname1.region2.skydns.test. 300 IN CNAME cname2.region2.skydns.test."),
+ newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."),
+ newCNAME("cname3.region2.skydns.test. 300 IN CNAME cname4.region2.skydns.test."),
+ newCNAME("cname4.region2.skydns.test. 300 IN CNAME cname5.region2.skydns.test."),
+ newCNAME("cname5.region2.skydns.test. 300 IN CNAME cname6.region2.skydns.test."),
+ newCNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."),
+ newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
+ },
+ },
+}
diff --git a/middleware/etcd/handler.go b/middleware/etcd/handler.go
index 552243fa4..bd5df5e13 100644
--- a/middleware/etcd/handler.go
+++ b/middleware/etcd/handler.go
@@ -30,7 +30,8 @@ func (e Etcd) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (i
return e.Next.ServeDNS(ctx, w, r)
}
- m := state.AnswerMessage()
+ m := new(dns.Msg)
+ m.SetReply(r)
m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true
var (
@@ -88,7 +89,7 @@ func (e Etcd) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (i
}
m = dedup(m)
-
+ m, _ = state.Scrub(m)
state.W.WriteMsg(m)
return 0, nil
}
diff --git a/middleware/file/file.go b/middleware/file/file.go
index 3827664bc..b9f53003a 100644
--- a/middleware/file/file.go
+++ b/middleware/file/file.go
@@ -6,12 +6,13 @@ package file
// have some fluff for DNSSEC (and be memory efficient).
import (
- "strings"
-
- "golang.org/x/net/context"
+ "io"
+ "log"
"github.com/miekg/coredns/middleware"
+
"github.com/miekg/dns"
+ "golang.org/x/net/context"
)
type (
@@ -21,9 +22,8 @@ type (
// Maybe a list of all zones as well, as a []string?
}
- Zone []dns.RR
Zones struct {
- Z map[string]Zone // utterly braindead impl. TODO(miek): fix
+ Z map[string]*Zone
Names []string
}
)
@@ -35,57 +35,51 @@ func (f File) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (i
if zone == "" {
return f.Next.ServeDNS(ctx, w, r)
}
+ z, ok := f.Zones.Z[zone]
+ if !ok {
+ return f.Next.ServeDNS(ctx, w, r)
+ }
- names, nodata := f.Zones.Z[zone].lookup(qname, state.QType())
- var answer *dns.Msg
- switch {
- case nodata:
- answer = state.AnswerMessage()
- answer.Ns = names
- case len(names) == 0:
- answer = state.AnswerMessage()
- answer.Ns = names
- answer.Rcode = dns.RcodeNameError
- case len(names) > 0:
- answer = state.AnswerMessage()
- answer.Answer = names
+ rrs, extra, result := z.Lookup(qname, state.QType(), state.Do())
+
+ m := new(dns.Msg)
+ m.SetReply(r)
+ m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true
+
+ switch result {
+ case Success:
+ // case?
+ m.Answer = rrs
+ m.Extra = extra
+ // Ns section
+ case NameError:
+ m.Rcode = dns.RcodeNameError
+ fallthrough
+ case NoData:
+ // case?
+ m.Ns = rrs
default:
- answer = state.ErrorMessage(dns.RcodeServerFailure)
+ // TODO
}
- // Check return size, etc. TODO(miek)
- w.WriteMsg(answer)
- return 0, nil
+ // sizing and Do bit RRSIG
+ w.WriteMsg(m)
+ return dns.RcodeSuccess, nil
}
-// Lookup will try to find qname and qtype in z. It returns the
-// records found *or* a boolean saying NODATA. If the answer
-// is NODATA then the RR returned is the SOA record.
-//
-// TODO(miek): EXTREMELY STUPID IMPLEMENTATION.
-// Doesn't do much, no delegation, no cname, nothing really, etc.
-// TODO(miek): even NODATA looks broken
-func (z Zone) lookup(qname string, qtype uint16) ([]dns.RR, bool) {
- var (
- nodata bool
- rep []dns.RR
- soa dns.RR
- )
-
- for _, rr := range z {
- if rr.Header().Rrtype == dns.TypeSOA {
- soa = rr
+// Parse parses the zone in filename and returns a new Zone or an error.
+func Parse(f io.Reader, origin, fileName string) (*Zone, error) {
+ tokens := dns.ParseZone(f, dns.Fqdn(origin), fileName)
+ z := NewZone(origin)
+ for x := range tokens {
+ if x.Error != nil {
+ log.Printf("[ERROR] failed to parse %s: %v", origin, x.Error)
+ return nil, x.Error
}
- // Match function in Go DNS?
- if strings.ToLower(rr.Header().Name) == qname {
- if rr.Header().Rrtype == qtype {
- rep = append(rep, rr)
- nodata = false
- }
-
+ if x.RR.Header().Rrtype == dns.TypeSOA {
+ z.SOA = x.RR.(*dns.SOA)
+ continue
}
+ z.Insert(x.RR)
}
- if nodata {
- return []dns.RR{soa}, true
- }
- return rep, false
+ return z, nil
}
diff --git a/middleware/file/file.md b/middleware/file/file.md
new file mode 100644
index 000000000..2e23b0332
--- /dev/null
+++ b/middleware/file/file.md
@@ -0,0 +1,40 @@
+# file
+
+`file` enabled reading zone data from a RFC-1035 styled file.
+
+The etcd middleware makes extensive use of the proxy middleware to forward and query
+other servers in the network.
+
+## Syntax
+
+~~~
+file dbfile [zones...]
+~~~
+
+* `dbfile` the database file to read and parse.
+* `zones` zones it should be authoritative for. If empty the zones from the configuration block
+ are used.
+
+If you want to `round robin` A and AAAA responses look at the `loadbalance` middleware.
+
+~~~
+file {
+ db <dsds>
+ masters [...masters...]
+}
+~~~
+
+
+
+
+
+* `path` /skydns
+* `endpoint` endpoints...
+* `stubzones`
+
+## Examples
+
+dnssec {
+ file blaat, transparant allow already signed responses
+ ksk bliep.dsdsk
+}
diff --git a/middleware/file/lookup.go b/middleware/file/lookup.go
new file mode 100644
index 000000000..2798b0b23
--- /dev/null
+++ b/middleware/file/lookup.go
@@ -0,0 +1,75 @@
+package file
+
+import "github.com/miekg/dns"
+
+// Result is the result of a Lookup
+type Result int
+
+const (
+ Success Result = iota
+ NameError
+ NoData // aint no offical NoData return code.
+)
+
+// Lookup looks up qname and qtype in the zone, when do is true DNSSEC are included as well.
+// Two sets of records are returned, one for the answer and one for the additional section.
+func (z *Zone) Lookup(qname string, qtype uint16, do bool) ([]dns.RR, []dns.RR, Result) {
+ // TODO(miek): implement DNSSEC
+ var rr dns.RR
+ mk, known := dns.TypeToRR[qtype]
+ if !known {
+ return nil, nil, NameError
+ // Uhm...?
+ // rr = new(RFC3597)
+ } else {
+ rr = mk()
+ }
+ if qtype == dns.TypeSOA {
+ return z.lookupSOA(do)
+ }
+
+ rr.Header().Name = qname
+ elem := z.Tree.Get(rr)
+ if elem == nil {
+ return []dns.RR{z.SOA}, nil, NameError
+ }
+ rrs := elem.Types(dns.TypeCNAME)
+ if len(rrs) > 0 { // should only ever be 1 actually; TODO(miek) check for this?
+ // lookup target from the cname
+ rr.Header().Name = rrs[0].(*dns.CNAME).Target
+ elem := z.Tree.Get(rr)
+ if elem == nil {
+ return rrs, nil, Success
+ }
+ return rrs, elem.All(), Success
+ }
+
+ rrs = elem.Types(qtype)
+ if len(rrs) == 0 {
+ return []dns.RR{z.SOA}, nil, NoData
+ }
+ // Need to check sub-type on RRSIG records to only include the correctly
+ // typed ones.
+ return rrs, nil, Success
+}
+
+func (z *Zone) lookupSOA(do bool) ([]dns.RR, []dns.RR, Result) {
+ return []dns.RR{z.SOA}, nil, Success
+}
+
+// signatureForSubType range through the signature and return the correct
+// ones for the subtype.
+func (z *Zone) signatureForSubType(rrs []dns.RR, subtype uint16, do bool) []dns.RR {
+ if !do {
+ return nil
+ }
+ sigs := []dns.RR{}
+ for _, sig := range rrs {
+ if s, ok := sig.(*dns.RRSIG); ok {
+ if s.TypeCovered == subtype {
+ sigs = append(sigs, s)
+ }
+ }
+ }
+ return sigs
+}
diff --git a/middleware/file/lookup_test.go b/middleware/file/lookup_test.go
new file mode 100644
index 000000000..1e40d52a5
--- /dev/null
+++ b/middleware/file/lookup_test.go
@@ -0,0 +1,174 @@
+package file
+
+import (
+ "sort"
+ "strings"
+ "testing"
+
+ "github.com/miekg/coredns/middleware"
+
+ "github.com/miekg/dns"
+ "golang.org/x/net/context"
+)
+
+var dnsTestCases = []dnsTestCase{
+ {
+ Qname: "miek.nl.", Qtype: dns.TypeSOA,
+ Answer: []dns.RR{
+ newSOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1282630057 14400 3600 604800 14400"),
+ },
+ },
+ {
+ Qname: "miek.nl.", Qtype: dns.TypeAAAA,
+ Answer: []dns.RR{
+ newAAAA("miek.nl. 1800 IN AAAA 2a01:7e00::f03c:91ff:fef1:6735"),
+ },
+ },
+ {
+ Qname: "miek.nl.", Qtype: dns.TypeMX,
+ Answer: []dns.RR{
+ newMX("miek.nl. 1800 IN MX 1 aspmx.l.google.com."),
+ newMX("miek.nl. 1800 IN MX 10 aspmx2.googlemail.com."),
+ newMX("miek.nl. 1800 IN MX 10 aspmx3.googlemail.com."),
+ newMX("miek.nl. 1800 IN MX 5 alt1.aspmx.l.google.com."),
+ newMX("miek.nl. 1800 IN MX 5 alt2.aspmx.l.google.com."),
+ },
+ },
+ {
+ Qname: "www.miek.nl.", Qtype: dns.TypeA,
+ Answer: []dns.RR{
+ newCNAME("www.miek.nl. 1800 IN CNAME a.miek.nl."),
+ },
+
+ Extra: []dns.RR{
+ newA("a.miek.nl. 1800 IN A 139.162.196.78"),
+ newAAAA("a.miek.nl. 1800 IN AAAA 2a01:7e00::f03c:91ff:fef1:6735"),
+ },
+ },
+ {
+ Qname: "a.miek.nl.", Qtype: dns.TypeSRV,
+ Ns: []dns.RR{
+ newSOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1282630057 14400 3600 604800 14400"),
+ },
+ },
+ {
+ Qname: "b.miek.nl.", Qtype: dns.TypeA,
+ Rcode: dns.RcodeNameError,
+ Ns: []dns.RR{
+ newSOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1282630057 14400 3600 604800 14400"),
+ },
+ },
+}
+
+type rrSet []dns.RR
+
+func (p rrSet) Len() int { return len(p) }
+func (p rrSet) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
+func (p rrSet) Less(i, j int) bool { return p[i].String() < p[j].String() }
+
+const testzone = "miek.nl."
+
+func TestLookup(t *testing.T) {
+ zone, err := Parse(strings.NewReader(dbMiekNL), testzone, "stdin")
+ if err != nil {
+ t.Fatalf("expect no error when reading zone, got %q", err)
+ }
+
+ fm := File{Next: handler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}}
+ ctx := context.TODO()
+
+ for _, tc := range dnsTestCases {
+ m := new(dns.Msg)
+ m.SetQuestion(dns.Fqdn(tc.Qname), tc.Qtype)
+
+ rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{})
+ _, err := fm.ServeDNS(ctx, rec, m)
+ if err != nil {
+ t.Errorf("expected no error, got %v\n", err)
+ return
+ }
+ resp := rec.Msg()
+
+ sort.Sort(rrSet(resp.Answer))
+ sort.Sort(rrSet(resp.Ns))
+ sort.Sort(rrSet(resp.Extra))
+
+ if resp.Rcode != tc.Rcode {
+ t.Errorf("rcode is %q, expected %q", dns.RcodeToString[resp.Rcode], dns.RcodeToString[tc.Rcode])
+ t.Logf("%v\n", resp)
+ continue
+ }
+
+ if len(resp.Answer) != len(tc.Answer) {
+ t.Errorf("answer for %q contained %d results, %d expected", tc.Qname, len(resp.Answer), len(tc.Answer))
+ t.Logf("%v\n", resp)
+ continue
+ }
+ if len(resp.Ns) != len(tc.Ns) {
+ t.Errorf("authority for %q contained %d results, %d expected", tc.Qname, len(resp.Ns), len(tc.Ns))
+ t.Logf("%v\n", resp)
+ continue
+ }
+ if len(resp.Extra) != len(tc.Extra) {
+ t.Errorf("additional for %q contained %d results, %d expected", tc.Qname, len(resp.Extra), len(tc.Extra))
+ t.Logf("%v\n", resp)
+ continue
+ }
+ }
+}
+
+type dnsTestCase struct {
+ Qname string
+ Qtype uint16
+ Rcode int
+ Answer []dns.RR
+ Ns []dns.RR
+ Extra []dns.RR
+}
+
+func newA(rr string) *dns.A { r, _ := dns.NewRR(rr); return r.(*dns.A) }
+func newAAAA(rr string) *dns.AAAA { r, _ := dns.NewRR(rr); return r.(*dns.AAAA) }
+func newCNAME(rr string) *dns.CNAME { r, _ := dns.NewRR(rr); return r.(*dns.CNAME) }
+func newSRV(rr string) *dns.SRV { r, _ := dns.NewRR(rr); return r.(*dns.SRV) }
+func newSOA(rr string) *dns.SOA { r, _ := dns.NewRR(rr); return r.(*dns.SOA) }
+func newNS(rr string) *dns.NS { r, _ := dns.NewRR(rr); return r.(*dns.NS) }
+func newPTR(rr string) *dns.PTR { r, _ := dns.NewRR(rr); return r.(*dns.PTR) }
+func newTXT(rr string) *dns.TXT { r, _ := dns.NewRR(rr); return r.(*dns.TXT) }
+func newMX(rr string) *dns.MX { r, _ := dns.NewRR(rr); return r.(*dns.MX) }
+
+const dbMiekNL = `
+$TTL 30M
+$ORIGIN miek.nl.
+@ IN SOA linode.atoom.net. miek.miek.nl. (
+ 1282630057 ; Serial
+ 4H ; Refresh
+ 1H ; Retry
+ 7D ; Expire
+ 4H ) ; Negative Cache TTL
+ IN NS linode.atoom.net.
+ IN NS ns-ext.nlnetlabs.nl.
+ IN NS omval.tednet.nl.
+ IN NS ext.ns.whyscream.net.
+
+ IN MX 1 aspmx.l.google.com.
+ IN MX 5 alt1.aspmx.l.google.com.
+ IN MX 5 alt2.aspmx.l.google.com.
+ IN MX 10 aspmx2.googlemail.com.
+ IN MX 10 aspmx3.googlemail.com.
+
+ IN A 139.162.196.78
+ IN AAAA 2a01:7e00::f03c:91ff:fef1:6735
+
+a IN A 139.162.196.78
+ IN AAAA 2a01:7e00::f03c:91ff:fef1:6735
+www IN CNAME a
+archive IN CNAME a`
+
+func handler() middleware.Handler {
+ return middleware.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ m := new(dns.Msg)
+ m.SetRcode(r, dns.RcodeServerFailure)
+ w.WriteMsg(m)
+ return dns.RcodeServerFailure, nil
+ })
+}
diff --git a/middleware/file/tree/tree.go b/middleware/file/tree/tree.go
new file mode 100644
index 000000000..db57c2092
--- /dev/null
+++ b/middleware/file/tree/tree.go
@@ -0,0 +1,585 @@
+// Copyright ©2012 The bíogo Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found at the end of this file.
+
+// Package tree implements Left-Leaning Red Black trees as described by Robert Sedgewick.
+//
+// More details relating to the implementation are available at the following locations:
+//
+// http://www.cs.princeton.edu/~rs/talks/LLRB/LLRB.pdf
+// http://www.cs.princeton.edu/~rs/talks/LLRB/Java/RedBlackBST.java
+// http://www.teachsolaisgames.com/articles/balanced_left_leaning.html
+//
+// Heavily modified by Miek Gieben for use in DNS zones.
+package tree
+
+// TODO(miek): locking? lockfree
+// TODO(miek): fix docs
+
+import (
+ "strings"
+
+ "github.com/miekg/dns"
+)
+
+const (
+ TD234 = iota
+ BU23
+)
+
+// Operation mode of the LLRB tree.
+const Mode = BU23
+
+func init() {
+ if Mode != TD234 && Mode != BU23 {
+ panic("tree: unknown mode")
+ }
+}
+
+type Elem struct {
+ m map[uint16][]dns.RR
+}
+
+// newElem returns a new elem
+func newElem(rr dns.RR) *Elem {
+ e := Elem{m: make(map[uint16][]dns.RR)}
+ e.m[rr.Header().Rrtype] = []dns.RR{rr}
+ return &e
+}
+
+// Types returns the types from with type qtype from e.
+func (e *Elem) Types(qtype uint16) []dns.RR {
+ if rrs, ok := e.m[qtype]; ok {
+ // TODO(miek): length should never be zero here.
+ return rrs
+ }
+ return nil
+}
+
+func (e *Elem) All() []dns.RR {
+ list := []dns.RR{}
+ for _, rrs := range e.m {
+ list = append(list, rrs...)
+ }
+ return list
+}
+
+func (e *Elem) Insert(rr dns.RR) {
+ t := rr.Header().Rrtype
+ if e.m == nil {
+ e.m = make(map[uint16][]dns.RR)
+ e.m[t] = []dns.RR{rr}
+ return
+ }
+ rrs, ok := e.m[t]
+ if !ok {
+ e.m[t] = []dns.RR{rr}
+ return
+ }
+ for _, er := range rrs {
+ if equalRdata(er, rr) {
+ return
+ }
+ }
+
+ rrs = append(rrs, rr)
+ e.m[t] = rrs
+}
+
+// Delete removes rr from e. When e is empty after the removal the returned bool is true.
+func (e *Elem) Delete(rr dns.RR) (empty bool) {
+ t := rr.Header().Rrtype
+ if e.m == nil {
+ return
+ }
+ rrs, ok := e.m[t]
+ if !ok {
+ return
+ }
+ for i, er := range rrs {
+ if equalRdata(er, rr) {
+ rrs = removeFromSlice(rrs, i)
+ e.m[t] = rrs
+ empty = len(rrs) == 0
+ if empty {
+ delete(e.m, t)
+ }
+ return
+ }
+ }
+ return
+}
+
+// TODO(miek): need case ignore compare that is more efficient.
+func Less(a *Elem, rr dns.RR) int {
+ aname := ""
+ for _, ar := range a.m {
+ aname = strings.ToLower(ar[0].Header().Name)
+ break
+ }
+ rname := strings.ToLower(rr.Header().Name)
+ if aname == rname {
+ return 0
+ }
+ if aname < rname {
+ return -1
+ }
+ return 1
+}
+
+// Assuming the same type and name this will check if the rdata is equal as well.
+func equalRdata(a, b dns.RR) bool {
+ switch x := a.(type) {
+ case *dns.A:
+ return x.A.Equal(b.(*dns.A).A)
+ case *dns.AAAA:
+ return x.AAAA.Equal(b.(*dns.AAAA).AAAA)
+ case *dns.MX:
+ if x.Mx == b.(*dns.MX).Mx && x.Preference == b.(*dns.MX).Preference {
+ return true
+ }
+ }
+ return false
+}
+
+// removeFromSlice removes index i from the slice.
+func removeFromSlice(rrs []dns.RR, i int) []dns.RR {
+ if i >= len(rrs) {
+ return rrs
+ }
+ rrs = append(rrs[:i], rrs[i+1:]...)
+ return rrs
+}
+
+// A Color represents the color of a Node.
+type Color bool
+
+const (
+ // Red as false give us the defined behaviour that new nodes are red. Although this
+ // is incorrect for the root node, that is resolved on the first insertion.
+ Red Color = false
+ Black Color = true
+)
+
+// A Node represents a node in the LLRB tree.
+type Node struct {
+ Elem *Elem
+ Left, Right *Node
+ Color Color
+}
+
+// A Tree manages the root node of an LLRB tree. Public methods are exposed through this type.
+type Tree struct {
+ Root *Node // Root node of the tree.
+ Count int // Number of elements stored.
+}
+
+// Helper methods
+
+// color returns the effect color of a Node. A nil node returns black.
+func (n *Node) color() Color {
+ if n == nil {
+ return Black
+ }
+ return n.Color
+}
+
+// (a,c)b -rotL-> ((a,)b,)c
+func (n *Node) rotateLeft() (root *Node) {
+ // Assumes: n has two children.
+ root = n.Right
+ n.Right = root.Left
+ root.Left = n
+ root.Color = n.Color
+ n.Color = Red
+ return
+}
+
+// (a,c)b -rotR-> (,(,c)b)a
+func (n *Node) rotateRight() (root *Node) {
+ // Assumes: n has two children.
+ root = n.Left
+ n.Left = root.Right
+ root.Right = n
+ root.Color = n.Color
+ n.Color = Red
+ return
+}
+
+// (aR,cR)bB -flipC-> (aB,cB)bR | (aB,cB)bR -flipC-> (aR,cR)bB
+func (n *Node) flipColors() {
+ // Assumes: n has two children.
+ n.Color = !n.Color
+ n.Left.Color = !n.Left.Color
+ n.Right.Color = !n.Right.Color
+}
+
+// fixUp ensures that black link balance is correct, that red nodes lean left,
+// and that 4 nodes are split in the case of BU23 and properly balanced in TD234.
+func (n *Node) fixUp() *Node {
+ if n.Right.color() == Red {
+ if Mode == TD234 && n.Right.Left.color() == Red {
+ n.Right = n.Right.rotateRight()
+ }
+ n = n.rotateLeft()
+ }
+ if n.Left.color() == Red && n.Left.Left.color() == Red {
+ n = n.rotateRight()
+ }
+ if Mode == BU23 && n.Left.color() == Red && n.Right.color() == Red {
+ n.flipColors()
+ }
+ return n
+}
+
+func (n *Node) moveRedLeft() *Node {
+ n.flipColors()
+ if n.Right.Left.color() == Red {
+ n.Right = n.Right.rotateRight()
+ n = n.rotateLeft()
+ n.flipColors()
+ if Mode == TD234 && n.Right.Right.color() == Red {
+ n.Right = n.Right.rotateLeft()
+ }
+ }
+ return n
+}
+
+func (n *Node) moveRedRight() *Node {
+ n.flipColors()
+ if n.Left.Left.color() == Red {
+ n = n.rotateRight()
+ n.flipColors()
+ }
+ return n
+}
+
+// Len returns the number of elements stored in the Tree.
+func (t *Tree) Len() int {
+ return t.Count
+}
+
+// Get returns the first match of q in the Tree. If insertion without
+// replacement is used, this is probably not what you want.
+func (t *Tree) Get(rr dns.RR) *Elem {
+ if t.Root == nil {
+ return nil
+ }
+ n := t.Root.search(rr)
+ if n == nil {
+ return nil
+ }
+ return n.Elem
+}
+
+func (n *Node) search(rr dns.RR) *Node {
+ for n != nil {
+ switch c := Less(n.Elem, rr); {
+ case c == 0:
+ return n
+ case c < 0:
+ n = n.Left
+ default:
+ n = n.Right
+ }
+ }
+
+ return n
+}
+
+// Insert inserts the Comparable e into the Tree at the first match found
+// with e or when a nil node is reached. Insertion without replacement can
+// specified by ensuring that e.Compare() never returns 0. If insert without
+// replacement is performed, a distinct query Comparable must be used that
+// can return 0 with a Compare() call.
+func (t *Tree) Insert(rr dns.RR) {
+ var d int
+ t.Root, d = t.Root.insert(rr)
+ t.Count += d
+ t.Root.Color = Black
+}
+
+func (n *Node) insert(rr dns.RR) (root *Node, d int) {
+ if n == nil {
+ return &Node{Elem: newElem(rr)}, 1
+ } else if n.Elem == nil {
+ n.Elem = newElem(rr)
+ return n, 1
+ }
+
+ if Mode == TD234 {
+ if n.Left.color() == Red && n.Right.color() == Red {
+ n.flipColors()
+ }
+ }
+
+ switch c := Less(n.Elem, rr); {
+ case c == 0:
+ n.Elem.Insert(rr)
+ case c < 0:
+ n.Left, d = n.Left.insert(rr)
+ default:
+ n.Right, d = n.Right.insert(rr)
+ }
+
+ if n.Right.color() == Red && n.Left.color() == Black {
+ n = n.rotateLeft()
+ }
+ if n.Left.color() == Red && n.Left.Left.color() == Red {
+ n = n.rotateRight()
+ }
+
+ if Mode == BU23 {
+ if n.Left.color() == Red && n.Right.color() == Red {
+ n.flipColors()
+ }
+ }
+
+ root = n
+
+ return
+}
+
+// DeleteMin deletes the node with the minimum value in the tree. If insertion without
+// replacement has been used, the left-most minimum will be deleted.
+func (t *Tree) DeleteMin() {
+ if t.Root == nil {
+ return
+ }
+ var d int
+ t.Root, d = t.Root.deleteMin()
+ t.Count += d
+ if t.Root == nil {
+ return
+ }
+ t.Root.Color = Black
+}
+
+func (n *Node) deleteMin() (root *Node, d int) {
+ if n.Left == nil {
+ return nil, -1
+ }
+ if n.Left.color() == Black && n.Left.Left.color() == Black {
+ n = n.moveRedLeft()
+ }
+ n.Left, d = n.Left.deleteMin()
+
+ root = n.fixUp()
+
+ return
+}
+
+// DeleteMax deletes the node with the maximum value in the tree. If insertion without
+// replacement has been used, the right-most maximum will be deleted.
+func (t *Tree) DeleteMax() {
+ if t.Root == nil {
+ return
+ }
+ var d int
+ t.Root, d = t.Root.deleteMax()
+ t.Count += d
+ if t.Root == nil {
+ return
+ }
+ t.Root.Color = Black
+}
+
+func (n *Node) deleteMax() (root *Node, d int) {
+ if n.Left != nil && n.Left.color() == Red {
+ n = n.rotateRight()
+ }
+ if n.Right == nil {
+ return nil, -1
+ }
+ if n.Right.color() == Black && n.Right.Left.color() == Black {
+ n = n.moveRedRight()
+ }
+ n.Right, d = n.Right.deleteMax()
+
+ root = n.fixUp()
+
+ return
+}
+
+// Delete removes rr from the tree, is the node turns empty, that node is return with DeleteNode.
+func (t *Tree) Delete(rr dns.RR) {
+ if t.Root == nil {
+ return
+ }
+ // If there is an element, remove the rr from it
+ el := t.Get(rr)
+ if el == nil {
+ t.DeleteNode(rr)
+ return
+ }
+ // delete from this element
+ empty := el.Delete(rr)
+ if empty {
+ t.DeleteNode(rr)
+ return
+ }
+}
+
+// DeleteNode deletes the node that matches e according to Compare(). Note that Compare must
+// identify the target node uniquely and in cases where non-unique keys are used,
+// attributes used to break ties must be used to determine tree ordering during insertion.
+func (t *Tree) DeleteNode(rr dns.RR) {
+ if t.Root == nil {
+ return
+ }
+ var d int
+ t.Root, d = t.Root.delete(rr)
+ t.Count += d
+ if t.Root == nil {
+ return
+ }
+ t.Root.Color = Black
+}
+
+func (n *Node) delete(rr dns.RR) (root *Node, d int) {
+ if Less(n.Elem, rr) < 0 {
+ if n.Left != nil {
+ if n.Left.color() == Black && n.Left.Left.color() == Black {
+ n = n.moveRedLeft()
+ }
+ n.Left, d = n.Left.delete(rr)
+ }
+ } else {
+ if n.Left.color() == Red {
+ n = n.rotateRight()
+ }
+ if n.Right == nil && Less(n.Elem, rr) == 0 {
+ return nil, -1
+ }
+ if n.Right != nil {
+ if n.Right.color() == Black && n.Right.Left.color() == Black {
+ n = n.moveRedRight()
+ }
+ if Less(n.Elem, rr) == 0 {
+ n.Elem = n.Right.min().Elem
+ n.Right, d = n.Right.deleteMin()
+ } else {
+ n.Right, d = n.Right.delete(rr)
+ }
+ }
+ }
+
+ root = n.fixUp()
+
+ return
+}
+
+// Return the minimum value stored in the tree. This will be the left-most minimum value if
+// insertion without replacement has been used.
+func (t *Tree) Min() *Elem {
+ if t.Root == nil {
+ return nil
+ }
+ return t.Root.min().Elem
+}
+
+func (n *Node) min() *Node {
+ for ; n.Left != nil; n = n.Left {
+ }
+ return n
+}
+
+// Return the maximum value stored in the tree. This will be the right-most maximum value if
+// insertion without replacement has been used.
+func (t *Tree) Max() *Elem {
+ if t.Root == nil {
+ return nil
+ }
+ return t.Root.max().Elem
+}
+
+func (n *Node) max() *Node {
+ for ; n.Right != nil; n = n.Right {
+ }
+ return n
+}
+
+// Floor returns the greatest value equal to or less than the query q according to q.Compare().
+func (t *Tree) Floor(rr dns.RR) *Elem {
+ if t.Root == nil {
+ return nil
+ }
+ n := t.Root.floor(rr)
+ if n == nil {
+ return nil
+ }
+ return n.Elem
+}
+
+func (n *Node) floor(rr dns.RR) *Node {
+ if n == nil {
+ return nil
+ }
+ switch c := Less(n.Elem, rr); {
+ case c == 0:
+ return n
+ case c < 0:
+ return n.Left.floor(rr)
+ default:
+ if r := n.Right.floor(rr); r != nil {
+ return r
+ }
+ }
+ return n
+}
+
+// Ceil returns the smallest value equal to or greater than the query q according to q.Compare().
+func (t *Tree) Ceil(rr dns.RR) *Elem {
+ if t.Root == nil {
+ return nil
+ }
+ n := t.Root.ceil(rr)
+ if n == nil {
+ return nil
+ }
+ return n.Elem
+}
+
+func (n *Node) ceil(rr dns.RR) *Node {
+ if n == nil {
+ return nil
+ }
+ switch c := Less(n.Elem, rr); {
+ case c == 0:
+ return n
+ case c > 0:
+ return n.Right.ceil(rr)
+ default:
+ if l := n.Left.ceil(rr); l != nil {
+ return l
+ }
+ }
+ return n
+}
+
+/*
+Copyright ©2012 The bíogo Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+* Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+* Neither the name of the bíogo project nor the names of its authors and
+ contributors may be used to endorse or promote products derived from this
+ software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
diff --git a/middleware/file/zone.go b/middleware/file/zone.go
new file mode 100644
index 000000000..57eb8d997
--- /dev/null
+++ b/middleware/file/zone.go
@@ -0,0 +1,26 @@
+package file
+
+import (
+ "github.com/miekg/coredns/middleware/file/tree"
+
+ "github.com/miekg/dns"
+)
+
+type Zone struct {
+ SOA *dns.SOA
+ SIG []*dns.RRSIG
+ name string
+ *tree.Tree
+}
+
+func NewZone(name string) *Zone {
+ return &Zone{name: dns.Fqdn(name), Tree: &tree.Tree{}}
+}
+
+func (z *Zone) Insert(r dns.RR) {
+ z.Tree.Insert(r)
+}
+
+func (z *Zone) Delete(r dns.RR) {
+ z.Tree.Delete(r)
+}
diff --git a/middleware/file/zone_test.go b/middleware/file/zone_test.go
new file mode 100644
index 000000000..4e3997c46
--- /dev/null
+++ b/middleware/file/zone_test.go
@@ -0,0 +1,30 @@
+package file
+
+import (
+ "testing"
+
+ "github.com/miekg/dns"
+)
+
+func TestZoneInsert(t *testing.T) {
+ z := NewZone("miek.nl")
+ rr, _ := dns.NewRR("miek.nl. IN A 127.0.0.1")
+ z.Insert(rr)
+
+ t.Logf("%+v\n", z)
+
+ elem := z.Get(rr)
+ t.Logf("%+v\n", elem)
+ if elem != nil {
+ t.Logf("%+v\n", elem.Types(dns.TypeA))
+ }
+ z.Delete(rr)
+
+ t.Logf("%+v\n", z)
+
+ elem = z.Get(rr)
+ t.Logf("%+v\n", elem)
+ if elem != nil {
+ t.Logf("%+v\n", elem.Types(dns.TypeA))
+ }
+}
diff --git a/middleware/loadbalance/loadbalance.go b/middleware/loadbalance/loadbalance.go
index c81ad0c8a..e1bee25fd 100644
--- a/middleware/loadbalance/loadbalance.go
+++ b/middleware/loadbalance/loadbalance.go
@@ -14,18 +14,21 @@ func (r *RoundRobinResponseWriter) WriteMsg(res *dns.Msg) error {
if res.Rcode != dns.RcodeSuccess {
return r.ResponseWriter.WriteMsg(res)
}
- if len(res.Answer) < 2 { // don't even bother
- return r.ResponseWriter.WriteMsg(res)
- }
- // put CNAMEs first, randomize a/aaaa's and put packet back together.
- // TODO(miek): check family and give v6 more prio?
+ res.Answer = roundRobin(res.Answer)
+ res.Extra = roundRobin(res.Extra)
+
+ return r.ResponseWriter.WriteMsg(res)
+}
+
+func roundRobin(in []dns.RR) []dns.RR {
cname := []dns.RR{}
address := []dns.RR{}
rest := []dns.RR{}
- for _, r := range res.Answer {
+ for _, r := range in {
switch r.Header().Rrtype {
case dns.TypeCNAME:
+ // d d d d DNAME and friends here as well?
cname = append(cname, r)
case dns.TypeA, dns.TypeAAAA:
address = append(address, r)
@@ -36,7 +39,7 @@ func (r *RoundRobinResponseWriter) WriteMsg(res *dns.Msg) error {
switch l := len(address); l {
case 0, 1:
- return r.ResponseWriter.WriteMsg(res)
+ break
case 2:
if dns.Id()%2 == 0 {
address[0], address[1] = address[1], address[0]
@@ -51,9 +54,9 @@ func (r *RoundRobinResponseWriter) WriteMsg(res *dns.Msg) error {
address[q], address[p] = address[p], address[q]
}
}
- res.Answer = append(cname, rest...)
- res.Answer = append(res.Answer, address...)
- return r.ResponseWriter.WriteMsg(res)
+ out := append(cname, rest...)
+ out = append(out, address...)
+ return out
}
// Should we pack and unpack here to fiddle with the packet... Not likely.
diff --git a/middleware/loadbalance/loadbalance.md b/middleware/loadbalance/loadbalance.md
index 0e931fb53..5c381135d 100644
--- a/middleware/loadbalance/loadbalance.md
+++ b/middleware/loadbalance/loadbalance.md
@@ -4,13 +4,15 @@
message. See [Wikipedia](https://en.wikipedia.org/wiki/Round-robin_DNS) about the pros and cons
on this setup.
+It will take care to sort any CNAMEs before any address records.
+
## Syntax
~~~
loadbalance [policy]
~~~
-* policy is how to balance, the default is "round_robin"
+* `policy` is how to balance, the default is "round_robin"
## Examples
diff --git a/middleware/loadbalance/loadbalance_test.go b/middleware/loadbalance/loadbalance_test.go
new file mode 100644
index 000000000..dc027607c
--- /dev/null
+++ b/middleware/loadbalance/loadbalance_test.go
@@ -0,0 +1,104 @@
+package loadbalance
+
+import (
+ "testing"
+
+ "github.com/miekg/coredns/middleware"
+
+ "github.com/miekg/dns"
+ "golang.org/x/net/context"
+)
+
+func TestLoadBalance(t *testing.T) {
+ rm := RoundRobin{Next: handler()}
+
+ // the first X records must be cnames after this test
+ tests := []struct {
+ answer []dns.RR
+ extra []dns.RR
+ cnameAnswer int
+ cnameExtra int
+ }{
+ {
+ answer: []dns.RR{
+ newCNAME("cname1.region2.skydns.test. 300 IN CNAME cname2.region2.skydns.test."),
+ newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."),
+ newCNAME("cname5.region2.skydns.test. 300 IN CNAME cname6.region2.skydns.test."),
+ newCNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."),
+ newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
+ },
+ cnameAnswer: 4,
+ },
+ {
+ answer: []dns.RR{
+ newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
+ newCNAME("cname.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."),
+ },
+ cnameAnswer: 1,
+ },
+ {
+ answer: []dns.RR{
+ newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
+ newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"),
+ newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."),
+ newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"),
+ },
+ extra: []dns.RR{
+ newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
+ newAAAA("endpoint.region2.skydns.test. 300 IN AAAA ::1"),
+ newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."),
+ newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"),
+ newAAAA("endpoint.region2.skydns.test. 300 IN AAAA ::2"),
+ },
+ cnameAnswer: 1,
+ cnameExtra: 1,
+ },
+ }
+
+ rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{})
+
+ for i, test := range tests {
+ req := new(dns.Msg)
+ req.SetQuestion("region2.skydns.test.", dns.TypeSRV)
+ req.Answer = test.answer
+ req.Extra = test.extra
+
+ _, err := rm.ServeDNS(context.TODO(), rec, req)
+ if err != nil {
+ t.Errorf("Test %d: Expected no error, but got %s", i, err)
+ continue
+
+ }
+ cname := 0
+ for _, r := range rec.Msg().Answer {
+ if r.Header().Rrtype != dns.TypeCNAME {
+ break
+ }
+ cname++
+ }
+ if cname != test.cnameAnswer {
+ t.Errorf("Test %d: Expected %d cnames in Answer, but got %d", i, test.cnameAnswer, cname)
+ }
+ cname = 0
+ for _, r := range rec.Msg().Extra {
+ if r.Header().Rrtype != dns.TypeCNAME {
+ break
+ }
+ cname++
+ }
+ if cname != test.cnameExtra {
+ t.Errorf("Test %d: Expected %d cname in Extra, but got %d", i, test.cnameExtra, cname)
+ }
+ }
+}
+
+func handler() middleware.Handler {
+ return middleware.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ w.WriteMsg(r)
+ return dns.RcodeSuccess, nil
+ })
+}
+
+func newA(rr string) *dns.A { r, _ := dns.NewRR(rr); return r.(*dns.A) }
+func newAAAA(rr string) *dns.AAAA { r, _ := dns.NewRR(rr); return r.(*dns.AAAA) }
+func newCNAME(rr string) *dns.CNAME { r, _ := dns.NewRR(rr); return r.(*dns.CNAME) }
diff --git a/middleware/proxy/reverseproxy.go b/middleware/proxy/reverseproxy.go
index efe6f47d2..83221ba78 100644
--- a/middleware/proxy/reverseproxy.go
+++ b/middleware/proxy/reverseproxy.go
@@ -3,6 +3,7 @@ package proxy
import (
"github.com/miekg/coredns/middleware"
+
"github.com/miekg/dns"
)
@@ -18,7 +19,7 @@ func (p ReverseProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg, extra []dns.RR)
)
state := middleware.State{W: w, Req: r}
- // tls+tcp ?
+ // We forward the original request, no need to fiddle with EDNS0 opt sizes.
if state.Proto() == "tcp" {
reply, err = middleware.Exchange(p.Client.TCP, r, p.Host)
} else {
diff --git a/middleware/state.go b/middleware/state.go
index f4cd43542..7ae89ef1c 100644
--- a/middleware/state.go
+++ b/middleware/state.go
@@ -102,9 +102,9 @@ func (s State) Size() int {
return dns.MinMsgSize
}
-// SizeAndDo returns a ready made OPT record that the reflects the intent
-// from the state. This can be added to upstream requests that will then
-// hopefully return a message that is understandable by the original client.
+// SizeAndDo returns a ready made OPT record that the reflects the intent from
+// state. This can be added to upstream requests that will then hopefully
+// return a message that is fits the buffer in the client.
func (s State) SizeAndDo() *dns.OPT {
size := s.Size()
Do := s.Do()
@@ -119,6 +119,40 @@ func (s State) SizeAndDo() *dns.OPT {
return o
}
+// Result is the result of Fit.
+type Result int
+
+const (
+ // ScrubIgnored is returned when Scrub did nothing to the message.
+ ScrubIgnored Result = iota
+ // ScrubDone is returned when the reply has been scrubbed.
+ ScrubDone
+)
+
+// Scrub scrubs the reply message so that it will fit the client's buffer. If even after dropping
+// the additional section, it still does not fit the TC bit will be set on the message. Note,
+// the TC bit will be set regardless of protocol, even TCP message will get the bit, the client
+// should then retry with pigeons.
+// TODO(referral).
+func (s State) Scrub(reply *dns.Msg) (*dns.Msg, Result) {
+ size := s.Size()
+ l := reply.Len()
+ if size >= l {
+ return reply, ScrubIgnored
+ }
+ // If not delegation, drop additional section.
+ // TODO(miek): check for delegation
+ reply.Extra = nil
+ l = reply.Len()
+ if size >= l {
+ return reply, ScrubDone
+ }
+ // Still?!! does not fit.
+ reply.Truncated = true
+ return reply, ScrubDone
+
+}
+
// Type returns the type of the question as a string.
func (s State) Type() string { return dns.Type(s.Req.Question[0].Qtype).String() }
@@ -145,11 +179,3 @@ func (s State) ErrorMessage(rcode int) *dns.Msg {
m.SetRcode(s.Req, rcode)
return m
}
-
-// AnswerMessage returns an error message suitable for sending
-// back to the client.
-func (s State) AnswerMessage() *dns.Msg {
- m := new(dns.Msg)
- m.SetReply(s.Req)
- return m
-}