diff options
Diffstat (limited to 'plugin/pkg')
36 files changed, 2353 insertions, 0 deletions
diff --git a/plugin/pkg/cache/cache.go b/plugin/pkg/cache/cache.go new file mode 100644 index 000000000..56cae2180 --- /dev/null +++ b/plugin/pkg/cache/cache.go @@ -0,0 +1,129 @@ +// Package cache implements a cache. The cache hold 256 shards, each shard +// holds a cache: a map with a mutex. There is no fancy expunge algorithm, it +// just randomly evicts elements when it gets full. +package cache + +import ( + "hash/fnv" + "sync" +) + +// Hash returns the FNV hash of what. +func Hash(what []byte) uint32 { + h := fnv.New32() + h.Write(what) + return h.Sum32() +} + +// Cache is cache. +type Cache struct { + shards [shardSize]*shard +} + +// shard is a cache with random eviction. +type shard struct { + items map[uint32]interface{} + size int + + sync.RWMutex +} + +// New returns a new cache. +func New(size int) *Cache { + ssize := size / shardSize + if ssize < 512 { + ssize = 512 + } + + c := &Cache{} + + // Initialize all the shards + for i := 0; i < shardSize; i++ { + c.shards[i] = newShard(ssize) + } + return c +} + +// Add adds a new element to the cache. If the element already exists it is overwritten. +func (c *Cache) Add(key uint32, el interface{}) { + shard := key & (shardSize - 1) + c.shards[shard].Add(key, el) +} + +// Get looks up element index under key. +func (c *Cache) Get(key uint32) (interface{}, bool) { + shard := key & (shardSize - 1) + return c.shards[shard].Get(key) +} + +// Remove removes the element indexed with key. +func (c *Cache) Remove(key uint32) { + shard := key & (shardSize - 1) + c.shards[shard].Remove(key) +} + +// Len returns the number of elements in the cache. +func (c *Cache) Len() int { + l := 0 + for _, s := range c.shards { + l += s.Len() + } + return l +} + +// newShard returns a new shard with size. +func newShard(size int) *shard { return &shard{items: make(map[uint32]interface{}), size: size} } + +// Add adds element indexed by key into the cache. Any existing element is overwritten +func (s *shard) Add(key uint32, el interface{}) { + l := s.Len() + if l+1 > s.size { + s.Evict() + } + + s.Lock() + s.items[key] = el + s.Unlock() +} + +// Remove removes the element indexed by key from the cache. +func (s *shard) Remove(key uint32) { + s.Lock() + delete(s.items, key) + s.Unlock() +} + +// Evict removes a random element from the cache. +func (s *shard) Evict() { + s.Lock() + defer s.Unlock() + + key := -1 + for k := range s.items { + key = int(k) + break + } + if key == -1 { + // empty cache + return + } + delete(s.items, uint32(key)) +} + +// Get looks up the element indexed under key. +func (s *shard) Get(key uint32) (interface{}, bool) { + s.RLock() + el, found := s.items[key] + s.RUnlock() + return el, found +} + +// Len returns the current length of the cache. +func (s *shard) Len() int { + s.RLock() + l := len(s.items) + s.RUnlock() + return l +} + +const shardSize = 256 diff --git a/plugin/pkg/cache/cache_test.go b/plugin/pkg/cache/cache_test.go new file mode 100644 index 000000000..2c92bf438 --- /dev/null +++ b/plugin/pkg/cache/cache_test.go @@ -0,0 +1,31 @@ +package cache + +import "testing" + +func TestCacheAddAndGet(t *testing.T) { + c := New(4) + c.Add(1, 1) + + if _, found := c.Get(1); !found { + t.Fatal("Failed to find inserted record") + } +} + +func TestCacheLen(t *testing.T) { + c := New(4) + + c.Add(1, 1) + if l := c.Len(); l != 1 { + t.Fatalf("Cache size should %d, got %d", 1, l) + } + + c.Add(1, 1) + if l := c.Len(); l != 1 { + t.Fatalf("Cache size should %d, got %d", 1, l) + } + + c.Add(2, 2) + if l := c.Len(); l != 2 { + t.Fatalf("Cache size should %d, got %d", 2, l) + } +} diff --git a/plugin/pkg/cache/shard_test.go b/plugin/pkg/cache/shard_test.go new file mode 100644 index 000000000..26675cee1 --- /dev/null +++ b/plugin/pkg/cache/shard_test.go @@ -0,0 +1,60 @@ +package cache + +import "testing" + +func TestShardAddAndGet(t *testing.T) { + s := newShard(4) + s.Add(1, 1) + + if _, found := s.Get(1); !found { + t.Fatal("Failed to find inserted record") + } +} + +func TestShardLen(t *testing.T) { + s := newShard(4) + + s.Add(1, 1) + if l := s.Len(); l != 1 { + t.Fatalf("Shard size should %d, got %d", 1, l) + } + + s.Add(1, 1) + if l := s.Len(); l != 1 { + t.Fatalf("Shard size should %d, got %d", 1, l) + } + + s.Add(2, 2) + if l := s.Len(); l != 2 { + t.Fatalf("Shard size should %d, got %d", 2, l) + } +} + +func TestShardEvict(t *testing.T) { + s := newShard(1) + s.Add(1, 1) + s.Add(2, 2) + // 1 should be gone + + if _, found := s.Get(1); found { + t.Fatal("Found item that should have been evicted") + } +} + +func TestShardLenEvict(t *testing.T) { + s := newShard(4) + s.Add(1, 1) + s.Add(2, 1) + s.Add(3, 1) + s.Add(4, 1) + + if l := s.Len(); l != 4 { + t.Fatalf("Shard size should %d, got %d", 4, l) + } + + // This should evict one element + s.Add(5, 1) + if l := s.Len(); l != 4 { + t.Fatalf("Shard size should %d, got %d", 4, l) + } +} diff --git a/plugin/pkg/dnsrecorder/recorder.go b/plugin/pkg/dnsrecorder/recorder.go new file mode 100644 index 000000000..3ca5f00d0 --- /dev/null +++ b/plugin/pkg/dnsrecorder/recorder.go @@ -0,0 +1,58 @@ +// Package dnsrecorder allows you to record a DNS response when it is send to the client. +package dnsrecorder + +import ( + "time" + + "github.com/miekg/dns" +) + +// Recorder is a type of ResponseWriter that captures +// the rcode code written to it and also the size of the message +// written in the response. A rcode code does not have +// to be written, however, in which case 0 must be assumed. +// It is best to have the constructor initialize this type +// with that default status code. +type Recorder struct { + dns.ResponseWriter + Rcode int + Len int + Msg *dns.Msg + Start time.Time +} + +// New makes and returns a new Recorder, +// which captures the DNS rcode from the ResponseWriter +// and also the length of the response message written through it. +func New(w dns.ResponseWriter) *Recorder { + return &Recorder{ + ResponseWriter: w, + Rcode: 0, + Msg: nil, + Start: time.Now(), + } +} + +// WriteMsg records the status code and calls the +// underlying ResponseWriter's WriteMsg method. +func (r *Recorder) WriteMsg(res *dns.Msg) error { + r.Rcode = res.Rcode + // We may get called multiple times (axfr for instance). + // Save the last message, but add the sizes. + r.Len += res.Len() + r.Msg = res + return r.ResponseWriter.WriteMsg(res) +} + +// Write is a wrapper that records the length of the message that gets written. +func (r *Recorder) Write(buf []byte) (int, error) { + n, err := r.ResponseWriter.Write(buf) + if err == nil { + r.Len += n + } + return n, err +} + +// Hijack implements dns.Hijacker. It simply wraps the underlying +// ResponseWriter's Hijack method if there is one, or returns an error. +func (r *Recorder) Hijack() { r.ResponseWriter.Hijack(); return } diff --git a/plugin/pkg/dnsrecorder/recorder_test.go b/plugin/pkg/dnsrecorder/recorder_test.go new file mode 100644 index 000000000..c9c2f6ce4 --- /dev/null +++ b/plugin/pkg/dnsrecorder/recorder_test.go @@ -0,0 +1,28 @@ +package dnsrecorder + +/* +func TestNewResponseRecorder(t *testing.T) { + w := httptest.NewRecorder() + recordRequest := NewResponseRecorder(w) + if !(recordRequest.ResponseWriter == w) { + t.Fatalf("Expected Response writer in the Recording to be same as the one sent\n") + } + if recordRequest.status != http.StatusOK { + t.Fatalf("Expected recorded status to be http.StatusOK (%d) , but found %d\n ", http.StatusOK, recordRequest.status) + } +} + +func TestWrite(t *testing.T) { + w := httptest.NewRecorder() + responseTestString := "test" + recordRequest := NewResponseRecorder(w) + buf := []byte(responseTestString) + recordRequest.Write(buf) + if recordRequest.size != len(buf) { + t.Fatalf("Expected the bytes written counter to be %d, but instead found %d\n", len(buf), recordRequest.size) + } + if w.Body.String() != responseTestString { + t.Fatalf("Expected Response Body to be %s , but found %s\n", responseTestString, w.Body.String()) + } +} +*/ diff --git a/plugin/pkg/dnsutil/cname.go b/plugin/pkg/dnsutil/cname.go new file mode 100644 index 000000000..281e03218 --- /dev/null +++ b/plugin/pkg/dnsutil/cname.go @@ -0,0 +1,15 @@ +package dnsutil + +import "github.com/miekg/dns" + +// DuplicateCNAME returns true if r already exists in records. +func DuplicateCNAME(r *dns.CNAME, records []dns.RR) bool { + for _, rec := range records { + if v, ok := rec.(*dns.CNAME); ok { + if v.Target == r.Target { + return true + } + } + } + return false +} diff --git a/plugin/pkg/dnsutil/cname_test.go b/plugin/pkg/dnsutil/cname_test.go new file mode 100644 index 000000000..5fb8d3029 --- /dev/null +++ b/plugin/pkg/dnsutil/cname_test.go @@ -0,0 +1,55 @@ +package dnsutil + +import ( + "testing" + + "github.com/miekg/dns" +) + +func TestDuplicateCNAME(t *testing.T) { + tests := []struct { + cname string + records []string + expected bool + }{ + { + "1.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 1.0.0.0.192.IN-ADDR.ARPA.", + []string{ + "US. 86400 IN NSEC 0-.us. NS SOA RRSIG NSEC DNSKEY TYPE65534", + "1.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 1.0.0.0.192.IN-ADDR.ARPA.", + }, + true, + }, + { + "1.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 1.0.0.0.192.IN-ADDR.ARPA.", + []string{ + "US. 86400 IN NSEC 0-.us. NS SOA RRSIG NSEC DNSKEY TYPE65534", + }, + false, + }, + { + "1.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 1.0.0.0.192.IN-ADDR.ARPA.", + []string{}, + false, + }, + } + for i, test := range tests { + cnameRR, err := dns.NewRR(test.cname) + if err != nil { + t.Fatalf("Test %d, cname ('%s') error (%s)!", i, test.cname, err) + } + cname := cnameRR.(*dns.CNAME) + records := []dns.RR{} + for j, r := range test.records { + rr, err := dns.NewRR(r) + if err != nil { + t.Fatalf("Test %d, record %d ('%s') error (%s)!", i, j, r, err) + } + records = append(records, rr) + } + got := DuplicateCNAME(cname, records) + if got != test.expected { + t.Errorf("Test %d, expected '%v', got '%v' for CNAME ('%s') and RECORDS (%v)", i, test.expected, got, test.cname, test.records) + } + } +} diff --git a/plugin/pkg/dnsutil/dedup.go b/plugin/pkg/dnsutil/dedup.go new file mode 100644 index 000000000..dae656a01 --- /dev/null +++ b/plugin/pkg/dnsutil/dedup.go @@ -0,0 +1,12 @@ +package dnsutil + +import "github.com/miekg/dns" + +// Dedup de-duplicates a message. +func Dedup(m *dns.Msg) *dns.Msg { + // TODO(miek): expensive! + m.Answer = dns.Dedup(m.Answer, nil) + m.Ns = dns.Dedup(m.Ns, nil) + m.Extra = dns.Dedup(m.Extra, nil) + return m +} diff --git a/plugin/pkg/dnsutil/doc.go b/plugin/pkg/dnsutil/doc.go new file mode 100644 index 000000000..75d1e8c7a --- /dev/null +++ b/plugin/pkg/dnsutil/doc.go @@ -0,0 +1,2 @@ +// Package dnsutil contains DNS related helper functions. +package dnsutil diff --git a/plugin/pkg/dnsutil/host.go b/plugin/pkg/dnsutil/host.go new file mode 100644 index 000000000..aaab586e8 --- /dev/null +++ b/plugin/pkg/dnsutil/host.go @@ -0,0 +1,82 @@ +package dnsutil + +import ( + "fmt" + "net" + "os" + + "github.com/miekg/dns" +) + +// ParseHostPortOrFile parses the strings in s, each string can either be a address, +// address:port or a filename. The address part is checked and the filename case a +// resolv.conf like file is parsed and the nameserver found are returned. +func ParseHostPortOrFile(s ...string) ([]string, error) { + var servers []string + for _, host := range s { + addr, _, err := net.SplitHostPort(host) + if err != nil { + // Parse didn't work, it is not a addr:port combo + if net.ParseIP(host) == nil { + // Not an IP address. + ss, err := tryFile(host) + if err == nil { + servers = append(servers, ss...) + continue + } + return servers, fmt.Errorf("not an IP address or file: %q", host) + } + ss := net.JoinHostPort(host, "53") + servers = append(servers, ss) + continue + } + + if net.ParseIP(addr) == nil { + // No an IP address. + ss, err := tryFile(host) + if err == nil { + servers = append(servers, ss...) + continue + } + return servers, fmt.Errorf("not an IP address or file: %q", host) + } + servers = append(servers, host) + } + return servers, nil +} + +// Try to open this is a file first. +func tryFile(s string) ([]string, error) { + c, err := dns.ClientConfigFromFile(s) + if err == os.ErrNotExist { + return nil, fmt.Errorf("failed to open file %q: %q", s, err) + } else if err != nil { + return nil, err + } + + servers := []string{} + for _, s := range c.Servers { + servers = append(servers, net.JoinHostPort(s, c.Port)) + } + return servers, nil +} + +// ParseHostPort will check if the host part is a valid IP address, if the +// IP address is valid, but no port is found, defaultPort is added. +func ParseHostPort(s, defaultPort string) (string, error) { + addr, port, err := net.SplitHostPort(s) + if port == "" { + port = defaultPort + } + if err != nil { + if net.ParseIP(s) == nil { + return "", fmt.Errorf("must specify an IP address: `%s'", s) + } + return net.JoinHostPort(s, port), nil + } + + if net.ParseIP(addr) == nil { + return "", fmt.Errorf("must specify an IP address: `%s'", addr) + } + return net.JoinHostPort(addr, port), nil +} diff --git a/plugin/pkg/dnsutil/host_test.go b/plugin/pkg/dnsutil/host_test.go new file mode 100644 index 000000000..cc55f4570 --- /dev/null +++ b/plugin/pkg/dnsutil/host_test.go @@ -0,0 +1,85 @@ +package dnsutil + +import ( + "io/ioutil" + "os" + "testing" +) + +func TestParseHostPortOrFile(t *testing.T) { + tests := []struct { + in string + expected string + shouldErr bool + }{ + { + "8.8.8.8", + "8.8.8.8:53", + false, + }, + { + "8.8.8.8:153", + "8.8.8.8:153", + false, + }, + { + "/etc/resolv.conf:53", + "", + true, + }, + { + "resolv.conf", + "127.0.0.1:53", + false, + }, + } + + err := ioutil.WriteFile("resolv.conf", []byte("nameserver 127.0.0.1\n"), 0600) + if err != nil { + t.Fatalf("Failed to write test resolv.conf") + } + defer os.Remove("resolv.conf") + + for i, tc := range tests { + got, err := ParseHostPortOrFile(tc.in) + if err == nil && tc.shouldErr { + t.Errorf("Test %d, expected error, got nil", i) + continue + } + if err != nil && tc.shouldErr { + continue + } + if got[0] != tc.expected { + t.Errorf("Test %d, expected %q, got %q", i, tc.expected, got[0]) + } + } +} + +func TestParseHostPort(t *testing.T) { + tests := []struct { + in string + expected string + shouldErr bool + }{ + {"8.8.8.8:53", "8.8.8.8:53", false}, + {"a.a.a.a:153", "", true}, + {"8.8.8.8", "8.8.8.8:53", false}, + {"8.8.8.8:", "8.8.8.8:53", false}, + {"8.8.8.8::53", "", true}, + {"resolv.conf", "", true}, + } + + for i, tc := range tests { + got, err := ParseHostPort(tc.in, "53") + if err == nil && tc.shouldErr { + t.Errorf("Test %d, expected error, got nil", i) + continue + } + if err != nil && !tc.shouldErr { + t.Errorf("Test %d, expected no error, got %q", i, err) + } + if got != tc.expected { + t.Errorf("Test %d, expected %q, got %q", i, tc.expected, got) + } + } +} diff --git a/plugin/pkg/dnsutil/join.go b/plugin/pkg/dnsutil/join.go new file mode 100644 index 000000000..515bf3dad --- /dev/null +++ b/plugin/pkg/dnsutil/join.go @@ -0,0 +1,19 @@ +package dnsutil + +import ( + "strings" + + "github.com/miekg/dns" +) + +// Join joins labels to form a fully qualified domain name. If the last label is +// the root label it is ignored. Not other syntax checks are performed. +func Join(labels []string) string { + ll := len(labels) + if labels[ll-1] == "." { + s := strings.Join(labels[:ll-1], ".") + return dns.Fqdn(s) + } + s := strings.Join(labels, ".") + return dns.Fqdn(s) +} diff --git a/plugin/pkg/dnsutil/join_test.go b/plugin/pkg/dnsutil/join_test.go new file mode 100644 index 000000000..26eeb5897 --- /dev/null +++ b/plugin/pkg/dnsutil/join_test.go @@ -0,0 +1,20 @@ +package dnsutil + +import "testing" + +func TestJoin(t *testing.T) { + tests := []struct { + in []string + out string + }{ + {[]string{"bla", "bliep", "example", "org"}, "bla.bliep.example.org."}, + {[]string{"example", "."}, "example."}, + {[]string{"."}, "."}, + } + + for i, tc := range tests { + if x := Join(tc.in); x != tc.out { + t.Errorf("Test %d, expected %s, got %s", i, tc.out, x) + } + } +} diff --git a/plugin/pkg/dnsutil/reverse.go b/plugin/pkg/dnsutil/reverse.go new file mode 100644 index 000000000..daf9cc600 --- /dev/null +++ b/plugin/pkg/dnsutil/reverse.go @@ -0,0 +1,68 @@ +package dnsutil + +import ( + "net" + "strings" +) + +// ExtractAddressFromReverse turns a standard PTR reverse record name +// into an IP address. This works for ipv4 or ipv6. +// +// 54.119.58.176.in-addr.arpa. becomes 176.58.119.54. If the conversion +// failes the empty string is returned. +func ExtractAddressFromReverse(reverseName string) string { + search := "" + + f := reverse + + switch { + case strings.HasSuffix(reverseName, v4arpaSuffix): + search = strings.TrimSuffix(reverseName, v4arpaSuffix) + case strings.HasSuffix(reverseName, v6arpaSuffix): + search = strings.TrimSuffix(reverseName, v6arpaSuffix) + f = reverse6 + default: + return "" + } + + // Reverse the segments and then combine them. + return f(strings.Split(search, ".")) +} + +func reverse(slice []string) string { + for i := 0; i < len(slice)/2; i++ { + j := len(slice) - i - 1 + slice[i], slice[j] = slice[j], slice[i] + } + ip := net.ParseIP(strings.Join(slice, ".")).To4() + if ip == nil { + return "" + } + return ip.String() +} + +// reverse6 reverse the segments and combine them according to RFC3596: +// b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2 +// is reversed to 2001:db8::567:89ab +func reverse6(slice []string) string { + for i := 0; i < len(slice)/2; i++ { + j := len(slice) - i - 1 + slice[i], slice[j] = slice[j], slice[i] + } + slice6 := []string{} + for i := 0; i < len(slice)/4; i++ { + slice6 = append(slice6, strings.Join(slice[i*4:i*4+4], "")) + } + ip := net.ParseIP(strings.Join(slice6, ":")).To16() + if ip == nil { + return "" + } + return ip.String() +} + +const ( + // v4arpaSuffix is the reverse tree suffix for v4 IP addresses. + v4arpaSuffix = ".in-addr.arpa." + // v6arpaSuffix is the reverse tree suffix for v6 IP addresses. + v6arpaSuffix = ".ip6.arpa." +) diff --git a/plugin/pkg/dnsutil/reverse_test.go b/plugin/pkg/dnsutil/reverse_test.go new file mode 100644 index 000000000..25bd897ac --- /dev/null +++ b/plugin/pkg/dnsutil/reverse_test.go @@ -0,0 +1,51 @@ +package dnsutil + +import ( + "testing" +) + +func TestExtractAddressFromReverse(t *testing.T) { + tests := []struct { + reverseName string + expectedAddress string + }{ + { + "54.119.58.176.in-addr.arpa.", + "176.58.119.54", + }, + { + ".58.176.in-addr.arpa.", + "", + }, + { + "b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.in-addr.arpa.", + "", + }, + { + "b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + "2001:db8::567:89ab", + }, + { + "d.0.1.0.0.2.ip6.arpa.", + "", + }, + { + "54.119.58.176.ip6.arpa.", + "", + }, + { + "NONAME", + "", + }, + { + "", + "", + }, + } + for i, test := range tests { + got := ExtractAddressFromReverse(test.reverseName) + if got != test.expectedAddress { + t.Errorf("Test %d, expected '%s', got '%s'", i, test.expectedAddress, got) + } + } +} diff --git a/plugin/pkg/dnsutil/zone.go b/plugin/pkg/dnsutil/zone.go new file mode 100644 index 000000000..579fef1ba --- /dev/null +++ b/plugin/pkg/dnsutil/zone.go @@ -0,0 +1,20 @@ +package dnsutil + +import ( + "errors" + + "github.com/miekg/dns" +) + +// TrimZone removes the zone component from q. It returns the trimmed +// name or an error is zone is longer then qname. The trimmed name will be returned +// without a trailing dot. +func TrimZone(q string, z string) (string, error) { + zl := dns.CountLabel(z) + i, ok := dns.PrevLabel(q, zl) + if ok || i-1 < 0 { + return "", errors.New("trimzone: overshot qname: " + q + "for zone " + z) + } + // This includes the '.', remove on return + return q[:i-1], nil +} diff --git a/plugin/pkg/dnsutil/zone_test.go b/plugin/pkg/dnsutil/zone_test.go new file mode 100644 index 000000000..81cd1adad --- /dev/null +++ b/plugin/pkg/dnsutil/zone_test.go @@ -0,0 +1,39 @@ +package dnsutil + +import ( + "errors" + "testing" + + "github.com/miekg/dns" +) + +func TestTrimZone(t *testing.T) { + tests := []struct { + qname string + zone string + expected string + err error + }{ + {"a.example.org", "example.org", "a", nil}, + {"a.b.example.org", "example.org", "a.b", nil}, + {"b.", ".", "b", nil}, + {"example.org", "example.org", "", errors.New("should err")}, + {"org", "example.org", "", errors.New("should err")}, + } + + for i, tc := range tests { + got, err := TrimZone(dns.Fqdn(tc.qname), dns.Fqdn(tc.zone)) + if tc.err != nil && err == nil { + t.Errorf("Test %d, expected error got nil", i) + continue + } + if tc.err == nil && err != nil { + t.Errorf("Test %d, expected no error got %v", i, err) + continue + } + if got != tc.expected { + t.Errorf("Test %d, expected %s, got %s", i, tc.expected, got) + continue + } + } +} diff --git a/plugin/pkg/edns/edns.go b/plugin/pkg/edns/edns.go new file mode 100644 index 000000000..3f0ea5e16 --- /dev/null +++ b/plugin/pkg/edns/edns.go @@ -0,0 +1,46 @@ +// Package edns provides function useful for adding/inspecting OPT records to/in messages. +package edns + +import ( + "errors" + + "github.com/miekg/dns" +) + +// Version checks the EDNS version in the request. If error +// is nil everything is OK and we can invoke the plugin. If non-nil, the +// returned Msg is valid to be returned to the client (and should). For some +// reason this response should not contain a question RR in the question section. +func Version(req *dns.Msg) (*dns.Msg, error) { + opt := req.IsEdns0() + if opt == nil { + return nil, nil + } + if opt.Version() == 0 { + return nil, nil + } + m := new(dns.Msg) + m.SetReply(req) + // zero out question section, wtf. + m.Question = nil + + o := new(dns.OPT) + o.Hdr.Name = "." + o.Hdr.Rrtype = dns.TypeOPT + o.SetVersion(0) + o.SetExtendedRcode(dns.RcodeBadVers) + m.Extra = []dns.RR{o} + + return m, errors.New("EDNS0 BADVERS") +} + +// Size returns a normalized size based on proto. +func Size(proto string, size int) int { + if proto == "tcp" { + return dns.MaxMsgSize + } + if size < dns.MinMsgSize { + return dns.MinMsgSize + } + return size +} diff --git a/plugin/pkg/edns/edns_test.go b/plugin/pkg/edns/edns_test.go new file mode 100644 index 000000000..89ac6d2ec --- /dev/null +++ b/plugin/pkg/edns/edns_test.go @@ -0,0 +1,37 @@ +package edns + +import ( + "testing" + + "github.com/miekg/dns" +) + +func TestVersion(t *testing.T) { + m := ednsMsg() + m.Extra[0].(*dns.OPT).SetVersion(2) + + _, err := Version(m) + if err == nil { + t.Errorf("expected wrong version, but got OK") + } +} + +func TestVersionNoEdns(t *testing.T) { + m := ednsMsg() + m.Extra = nil + + _, err := Version(m) + if err != nil { + t.Errorf("expected no error, but got one: %s", err) + } +} + +func ednsMsg() *dns.Msg { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + o := new(dns.OPT) + o.Hdr.Name = "." + o.Hdr.Rrtype = dns.TypeOPT + m.Extra = append(m.Extra, o) + return m +} diff --git a/plugin/pkg/healthcheck/healthcheck.go b/plugin/pkg/healthcheck/healthcheck.go new file mode 100644 index 000000000..18f09087c --- /dev/null +++ b/plugin/pkg/healthcheck/healthcheck.go @@ -0,0 +1,243 @@ +package healthcheck + +import ( + "io" + "io/ioutil" + "log" + "net" + "net/http" + "net/url" + "sync" + "sync/atomic" + "time" +) + +// UpstreamHostDownFunc can be used to customize how Down behaves. +type UpstreamHostDownFunc func(*UpstreamHost) bool + +// UpstreamHost represents a single proxy upstream +type UpstreamHost struct { + Conns int64 // must be first field to be 64-bit aligned on 32-bit systems + Name string // IP address (and port) of this upstream host + Network string // Network (tcp, unix, etc) of the host, default "" is "tcp" + Fails int32 + FailTimeout time.Duration + OkUntil time.Time + CheckDown UpstreamHostDownFunc + CheckURL string + WithoutPathPrefix string + Checking bool + CheckMu sync.Mutex +} + +// Down checks whether the upstream host is down or not. +// Down will try to use uh.CheckDown first, and will fall +// back to some default criteria if necessary. +func (uh *UpstreamHost) Down() bool { + if uh.CheckDown == nil { + // Default settings + fails := atomic.LoadInt32(&uh.Fails) + after := false + + uh.CheckMu.Lock() + until := uh.OkUntil + uh.CheckMu.Unlock() + + if !until.IsZero() && time.Now().After(until) { + after = true + } + + return after || fails > 0 + } + return uh.CheckDown(uh) +} + +// HostPool is a collection of UpstreamHosts. +type HostPool []*UpstreamHost + +// HealthCheck is used for performing healthcheck +// on a collection of upstream hosts and select +// one based on the policy. +type HealthCheck struct { + wg sync.WaitGroup // Used to wait for running goroutines to stop. + stop chan struct{} // Signals running goroutines to stop. + Hosts HostPool + Policy Policy + Spray Policy + FailTimeout time.Duration + MaxFails int32 + Future time.Duration + Path string + Port string + Interval time.Duration +} + +// Start starts the healthcheck +func (u *HealthCheck) Start() { + u.stop = make(chan struct{}) + if u.Path != "" { + u.wg.Add(1) + go func() { + defer u.wg.Done() + u.healthCheckWorker(u.stop) + }() + } +} + +// Stop sends a signal to all goroutines started by this staticUpstream to exit +// and waits for them to finish before returning. +func (u *HealthCheck) Stop() error { + close(u.stop) + u.wg.Wait() + return nil +} + +// This was moved into a thread so that each host could throw a health +// check at the same time. The reason for this is that if we are checking +// 3 hosts, and the first one is gone, and we spend minutes timing out to +// fail it, we would not have been doing any other health checks in that +// time. So we now have a per-host lock and a threaded health check. +// +// We use the Checking bool to avoid concurrent checks against the same +// host; if one is taking a long time, the next one will find a check in +// progress and simply return before trying. +// +// We are carefully avoiding having the mutex locked while we check, +// otherwise checks will back up, potentially a lot of them if a host is +// absent for a long time. This arrangement makes checks quickly see if +// they are the only one running and abort otherwise. +func healthCheckURL(nextTs time.Time, host *UpstreamHost) { + + // lock for our bool check. We don't just defer the unlock because + // we don't want the lock held while http.Get runs + host.CheckMu.Lock() + + // are we mid check? Don't run another one + if host.Checking { + host.CheckMu.Unlock() + return + } + + host.Checking = true + host.CheckMu.Unlock() + + //log.Printf("[DEBUG] Healthchecking %s, nextTs is %s\n", url, nextTs.Local()) + + // fetch that url. This has been moved into a go func because + // when the remote host is not merely not serving, but actually + // absent, then tcp syn timeouts can be very long, and so one + // fetch could last several check intervals + if r, err := http.Get(host.CheckURL); err == nil { + io.Copy(ioutil.Discard, r.Body) + r.Body.Close() + + if r.StatusCode < 200 || r.StatusCode >= 400 { + log.Printf("[WARNING] Host %s health check returned HTTP code %d\n", + host.Name, r.StatusCode) + nextTs = time.Unix(0, 0) + } + } else { + log.Printf("[WARNING] Host %s health check probe failed: %v\n", host.Name, err) + nextTs = time.Unix(0, 0) + } + + host.CheckMu.Lock() + host.Checking = false + host.OkUntil = nextTs + host.CheckMu.Unlock() +} + +func (u *HealthCheck) healthCheck() { + for _, host := range u.Hosts { + + if host.CheckURL == "" { + var hostName, checkPort string + + // The DNS server might be an HTTP server. If so, extract its name. + ret, err := url.Parse(host.Name) + if err == nil && len(ret.Host) > 0 { + hostName = ret.Host + } else { + hostName = host.Name + } + + // Extract the port number from the parsed server name. + checkHostName, checkPort, err := net.SplitHostPort(hostName) + if err != nil { + checkHostName = hostName + } + + if u.Port != "" { + checkPort = u.Port + } + + host.CheckURL = "http://" + net.JoinHostPort(checkHostName, checkPort) + u.Path + } + + // calculate this before the get + nextTs := time.Now().Add(u.Future) + + // locks/bools should prevent requests backing up + go healthCheckURL(nextTs, host) + } +} + +func (u *HealthCheck) healthCheckWorker(stop chan struct{}) { + ticker := time.NewTicker(u.Interval) + u.healthCheck() + for { + select { + case <-ticker.C: + u.healthCheck() + case <-stop: + ticker.Stop() + return + } + } +} + +// Select selects an upstream host based on the policy +// and the healthcheck result. +func (u *HealthCheck) Select() *UpstreamHost { + pool := u.Hosts + if len(pool) == 1 { + if pool[0].Down() && u.Spray == nil { + return nil + } + return pool[0] + } + allDown := true + for _, host := range pool { + if !host.Down() { + allDown = false + break + } + } + if allDown { + if u.Spray == nil { + return nil + } + return u.Spray.Select(pool) + } + + if u.Policy == nil { + h := (&Random{}).Select(pool) + if h != nil { + return h + } + if h == nil && u.Spray == nil { + return nil + } + return u.Spray.Select(pool) + } + + h := u.Policy.Select(pool) + if h != nil { + return h + } + + if u.Spray == nil { + return nil + } + return u.Spray.Select(pool) +} diff --git a/plugin/pkg/healthcheck/policy.go b/plugin/pkg/healthcheck/policy.go new file mode 100644 index 000000000..6a828fc4d --- /dev/null +++ b/plugin/pkg/healthcheck/policy.go @@ -0,0 +1,120 @@ +package healthcheck + +import ( + "log" + "math/rand" + "sync/atomic" +) + +var ( + // SupportedPolicies is the collection of policies registered + SupportedPolicies = make(map[string]func() Policy) +) + +// RegisterPolicy adds a custom policy to the proxy. +func RegisterPolicy(name string, policy func() Policy) { + SupportedPolicies[name] = policy +} + +// Policy decides how a host will be selected from a pool. When all hosts are unhealthy, it is assumed the +// healthchecking failed. In this case each policy will *randomly* return a host from the pool to prevent +// no traffic to go through at all. +type Policy interface { + Select(pool HostPool) *UpstreamHost +} + +func init() { + RegisterPolicy("random", func() Policy { return &Random{} }) + RegisterPolicy("least_conn", func() Policy { return &LeastConn{} }) + RegisterPolicy("round_robin", func() Policy { return &RoundRobin{} }) +} + +// Random is a policy that selects up hosts from a pool at random. +type Random struct{} + +// Select selects an up host at random from the specified pool. +func (r *Random) Select(pool HostPool) *UpstreamHost { + // instead of just generating a random index + // this is done to prevent selecting a down host + var randHost *UpstreamHost + count := 0 + for _, host := range pool { + if host.Down() { + continue + } + count++ + if count == 1 { + randHost = host + } else { + r := rand.Int() % count + if r == (count - 1) { + randHost = host + } + } + } + return randHost +} + +// Spray is a policy that selects a host from a pool at random. This should be used as a last ditch +// attempt to get a host when all hosts are reporting unhealthy. +type Spray struct{} + +// Select selects an up host at random from the specified pool. +func (r *Spray) Select(pool HostPool) *UpstreamHost { + rnd := rand.Int() % len(pool) + randHost := pool[rnd] + log.Printf("[WARNING] All hosts reported as down, spraying to target: %s", randHost.Name) + return randHost +} + +// LeastConn is a policy that selects the host with the least connections. +type LeastConn struct{} + +// Select selects the up host with the least number of connections in the +// pool. If more than one host has the same least number of connections, +// one of the hosts is chosen at random. +func (r *LeastConn) Select(pool HostPool) *UpstreamHost { + var bestHost *UpstreamHost + count := 0 + leastConn := int64(1<<63 - 1) + for _, host := range pool { + if host.Down() { + continue + } + hostConns := host.Conns + if hostConns < leastConn { + bestHost = host + leastConn = hostConns + count = 1 + } else if hostConns == leastConn { + // randomly select host among hosts with least connections + count++ + if count == 1 { + bestHost = host + } else { + r := rand.Int() % count + if r == (count - 1) { + bestHost = host + } + } + } + } + return bestHost +} + +// RoundRobin is a policy that selects hosts based on round robin ordering. +type RoundRobin struct { + Robin uint32 +} + +// Select selects an up host from the pool using a round robin ordering scheme. +func (r *RoundRobin) Select(pool HostPool) *UpstreamHost { + poolLen := uint32(len(pool)) + selection := atomic.AddUint32(&r.Robin, 1) % poolLen + host := pool[selection] + // if the currently selected host is down, just ffwd to up host + for i := uint32(1); host.Down() && i < poolLen; i++ { + host = pool[(selection+i)%poolLen] + } + return host +} diff --git a/plugin/pkg/healthcheck/policy_test.go b/plugin/pkg/healthcheck/policy_test.go new file mode 100644 index 000000000..4c667952c --- /dev/null +++ b/plugin/pkg/healthcheck/policy_test.go @@ -0,0 +1,143 @@ +package healthcheck + +import ( + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" +) + +var workableServer *httptest.Server + +func TestMain(m *testing.M) { + workableServer = httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + // do nothing + })) + r := m.Run() + workableServer.Close() + os.Exit(r) +} + +type customPolicy struct{} + +func (r *customPolicy) Select(pool HostPool) *UpstreamHost { + return pool[0] +} + +func testPool() HostPool { + pool := []*UpstreamHost{ + { + Name: workableServer.URL, // this should resolve (healthcheck test) + }, + { + Name: "http://shouldnot.resolve", // this shouldn't + }, + { + Name: "http://C", + }, + } + return HostPool(pool) +} + +func TestRegisterPolicy(t *testing.T) { + name := "custom" + customPolicy := &customPolicy{} + RegisterPolicy(name, func() Policy { return customPolicy }) + if _, ok := SupportedPolicies[name]; !ok { + t.Error("Expected supportedPolicies to have a custom policy.") + } + +} + +// TODO(miek): Disabled for now, we should get out of the habit of using +// realtime in these tests . +func testHealthCheck(t *testing.T) { + log.SetOutput(ioutil.Discard) + + u := &HealthCheck{ + Hosts: testPool(), + FailTimeout: 10 * time.Second, + Future: 60 * time.Second, + MaxFails: 1, + } + + u.healthCheck() + // sleep a bit, it's async now + time.Sleep(time.Duration(2 * time.Second)) + + if u.Hosts[0].Down() { + t.Error("Expected first host in testpool to not fail healthcheck.") + } + if !u.Hosts[1].Down() { + t.Error("Expected second host in testpool to fail healthcheck.") + } +} + +func TestSelect(t *testing.T) { + u := &HealthCheck{ + Hosts: testPool()[:3], + FailTimeout: 10 * time.Second, + Future: 60 * time.Second, + MaxFails: 1, + } + u.Hosts[0].OkUntil = time.Unix(0, 0) + u.Hosts[1].OkUntil = time.Unix(0, 0) + u.Hosts[2].OkUntil = time.Unix(0, 0) + if h := u.Select(); h != nil { + t.Error("Expected select to return nil as all host are down") + } + u.Hosts[2].OkUntil = time.Time{} + if h := u.Select(); h == nil { + t.Error("Expected select to not return nil") + } +} + +func TestRoundRobinPolicy(t *testing.T) { + pool := testPool() + rrPolicy := &RoundRobin{} + h := rrPolicy.Select(pool) + // First selected host is 1, because counter starts at 0 + // and increments before host is selected + if h != pool[1] { + t.Error("Expected first round robin host to be second host in the pool.") + } + h = rrPolicy.Select(pool) + if h != pool[2] { + t.Error("Expected second round robin host to be third host in the pool.") + } + // mark host as down + pool[0].OkUntil = time.Unix(0, 0) + h = rrPolicy.Select(pool) + if h != pool[1] { + t.Error("Expected third round robin host to be first host in the pool.") + } +} + +func TestLeastConnPolicy(t *testing.T) { + pool := testPool() + lcPolicy := &LeastConn{} + pool[0].Conns = 10 + pool[1].Conns = 10 + h := lcPolicy.Select(pool) + if h != pool[2] { + t.Error("Expected least connection host to be third host.") + } + pool[2].Conns = 100 + h = lcPolicy.Select(pool) + if h != pool[0] && h != pool[1] { + t.Error("Expected least connection host to be first or second host.") + } +} + +func TestCustomPolicy(t *testing.T) { + pool := testPool() + customPolicy := &customPolicy{} + h := customPolicy.Select(pool) + if h != pool[0] { + t.Error("Expected custom policy host to be the first host.") + } +} diff --git a/plugin/pkg/nonwriter/nonwriter.go b/plugin/pkg/nonwriter/nonwriter.go new file mode 100644 index 000000000..7819a320f --- /dev/null +++ b/plugin/pkg/nonwriter/nonwriter.go @@ -0,0 +1,23 @@ +// Package nonwriter implements a dns.ResponseWriter that never writes, but captures the dns.Msg being written. +package nonwriter + +import ( + "github.com/miekg/dns" +) + +// Writer is a type of ResponseWriter that captures the message, but never writes to the client. +type Writer struct { + dns.ResponseWriter + Msg *dns.Msg +} + +// New makes and returns a new NonWriter. +func New(w dns.ResponseWriter) *Writer { return &Writer{ResponseWriter: w} } + +// WriteMsg records the message, but doesn't write it itself. +func (w *Writer) WriteMsg(res *dns.Msg) error { + w.Msg = res + return nil +} + +func (w *Writer) Write(buf []byte) (int, error) { return len(buf), nil } diff --git a/plugin/pkg/nonwriter/nonwriter_test.go b/plugin/pkg/nonwriter/nonwriter_test.go new file mode 100644 index 000000000..d8433af55 --- /dev/null +++ b/plugin/pkg/nonwriter/nonwriter_test.go @@ -0,0 +1,19 @@ +package nonwriter + +import ( + "testing" + + "github.com/miekg/dns" +) + +func TestNonWriter(t *testing.T) { + nw := New(nil) + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeA) + if err := nw.WriteMsg(m); err != nil { + t.Errorf("Got error when writing to nonwriter: %s", err) + } + if x := nw.Msg.Question[0].Name; x != "example.org." { + t.Errorf("Expacted 'example.org.' got %q:", x) + } +} diff --git a/plugin/pkg/rcode/rcode.go b/plugin/pkg/rcode/rcode.go new file mode 100644 index 000000000..32863f0b2 --- /dev/null +++ b/plugin/pkg/rcode/rcode.go @@ -0,0 +1,16 @@ +package rcode + +import ( + "strconv" + + "github.com/miekg/dns" +) + +// ToString convert the rcode to the official DNS string, or to "RCODE"+value if the RCODE +// value is unknown. +func ToString(rcode int) string { + if str, ok := dns.RcodeToString[rcode]; ok { + return str + } + return "RCODE" + strconv.Itoa(rcode) +} diff --git a/plugin/pkg/rcode/rcode_test.go b/plugin/pkg/rcode/rcode_test.go new file mode 100644 index 000000000..bfca32f1d --- /dev/null +++ b/plugin/pkg/rcode/rcode_test.go @@ -0,0 +1,29 @@ +package rcode + +import ( + "testing" + + "github.com/miekg/dns" +) + +func TestToString(t *testing.T) { + tests := []struct { + in int + expected string + }{ + { + dns.RcodeSuccess, + "NOERROR", + }, + { + 28, + "RCODE28", + }, + } + for i, test := range tests { + got := ToString(test.in) + if got != test.expected { + t.Errorf("Test %d, expected %s, got %s", i, test.expected, got) + } + } +} diff --git a/plugin/pkg/replacer/replacer.go b/plugin/pkg/replacer/replacer.go new file mode 100644 index 000000000..fc98e5d29 --- /dev/null +++ b/plugin/pkg/replacer/replacer.go @@ -0,0 +1,161 @@ +package replacer + +import ( + "strconv" + "strings" + "time" + + "github.com/coredns/coredns/plugin/pkg/dnsrecorder" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// Replacer is a type which can replace placeholder +// substrings in a string with actual values from a +// dns.Msg and responseRecorder. Always use +// NewReplacer to get one of these. +type Replacer interface { + Replace(string) string + Set(key, value string) +} + +type replacer struct { + replacements map[string]string + emptyValue string +} + +// New makes a new replacer based on r and rr. +// Do not create a new replacer until r and rr have all +// the needed values, because this function copies those +// values into the replacer. rr may be nil if it is not +// available. emptyValue should be the string that is used +// in place of empty string (can still be empty string). +func New(r *dns.Msg, rr *dnsrecorder.Recorder, emptyValue string) Replacer { + req := request.Request{W: rr, Req: r} + rep := replacer{ + replacements: map[string]string{ + "{type}": req.Type(), + "{name}": req.Name(), + "{class}": req.Class(), + "{proto}": req.Proto(), + "{when}": func() string { + return time.Now().Format(timeFormat) + }(), + "{size}": strconv.Itoa(req.Len()), + "{remote}": req.IP(), + "{port}": req.Port(), + }, + emptyValue: emptyValue, + } + if rr != nil { + rcode := dns.RcodeToString[rr.Rcode] + if rcode == "" { + rcode = strconv.Itoa(rr.Rcode) + } + rep.replacements["{rcode}"] = rcode + rep.replacements["{rsize}"] = strconv.Itoa(rr.Len) + rep.replacements["{duration}"] = time.Since(rr.Start).String() + if rr.Msg != nil { + rep.replacements[headerReplacer+"rflags}"] = flagsToString(rr.Msg.MsgHdr) + } + } + + // Header placeholders (case-insensitive) + rep.replacements[headerReplacer+"id}"] = strconv.Itoa(int(r.Id)) + rep.replacements[headerReplacer+"opcode}"] = strconv.Itoa(r.Opcode) + rep.replacements[headerReplacer+"do}"] = boolToString(req.Do()) + rep.replacements[headerReplacer+"bufsize}"] = strconv.Itoa(req.Size()) + + return rep +} + +// Replace performs a replacement of values on s and returns +// the string with the replaced values. +func (r replacer) Replace(s string) string { + // Header replacements - these are case-insensitive, so we can't just use strings.Replace() + for strings.Contains(s, headerReplacer) { + idxStart := strings.Index(s, headerReplacer) + endOffset := idxStart + len(headerReplacer) + idxEnd := strings.Index(s[endOffset:], "}") + if idxEnd > -1 { + placeholder := strings.ToLower(s[idxStart : endOffset+idxEnd+1]) + replacement := r.replacements[placeholder] + if replacement == "" { + replacement = r.emptyValue + } + s = s[:idxStart] + replacement + s[endOffset+idxEnd+1:] + } else { + break + } + } + + // Regular replacements - these are easier because they're case-sensitive + for placeholder, replacement := range r.replacements { + if replacement == "" { + replacement = r.emptyValue + } + s = strings.Replace(s, placeholder, replacement, -1) + } + + return s +} + +// Set sets key to value in the replacements map. +func (r replacer) Set(key, value string) { + r.replacements["{"+key+"}"] = value +} + +func boolToString(b bool) string { + if b { + return "true" + } + return "false" +} + +// flagsToString checks all header flags and returns those +// that are set as a string separated with commas +func flagsToString(h dns.MsgHdr) string { + flags := make([]string, 7) + i := 0 + + if h.Response { + flags[i] = "qr" + i++ + } + + if h.Authoritative { + flags[i] = "aa" + i++ + } + if h.Truncated { + flags[i] = "tc" + i++ + } + if h.RecursionDesired { + flags[i] = "rd" + i++ + } + if h.RecursionAvailable { + flags[i] = "ra" + i++ + } + if h.Zero { + flags[i] = "z" + i++ + } + if h.AuthenticatedData { + flags[i] = "ad" + i++ + } + if h.CheckingDisabled { + flags[i] = "cd" + i++ + } + return strings.Join(flags[:i], ",") +} + +const ( + timeFormat = "02/Jan/2006:15:04:05 -0700" + headerReplacer = "{>" +) diff --git a/plugin/pkg/replacer/replacer_test.go b/plugin/pkg/replacer/replacer_test.go new file mode 100644 index 000000000..95c3bbd52 --- /dev/null +++ b/plugin/pkg/replacer/replacer_test.go @@ -0,0 +1,61 @@ +package replacer + +import ( + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnsrecorder" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestNewReplacer(t *testing.T) { + w := dnsrecorder.New(&test.ResponseWriter{}) + + r := new(dns.Msg) + r.SetQuestion("example.org.", dns.TypeHINFO) + r.MsgHdr.AuthenticatedData = true + + replaceValues := New(r, w, "") + + switch v := replaceValues.(type) { + case replacer: + + if v.replacements["{type}"] != "HINFO" { + t.Errorf("Expected type to be HINFO, got %q", v.replacements["{type}"]) + } + if v.replacements["{name}"] != "example.org." { + t.Errorf("Expected request name to be example.org., got %q", v.replacements["{name}"]) + } + if v.replacements["{size}"] != "29" { // size of request + t.Errorf("Expected size to be 29, got %q", v.replacements["{size}"]) + } + + default: + t.Fatal("Return Value from New Replacer expected pass type assertion into a replacer type\n") + } +} + +func TestSet(t *testing.T) { + w := dnsrecorder.New(&test.ResponseWriter{}) + + r := new(dns.Msg) + r.SetQuestion("example.org.", dns.TypeHINFO) + r.MsgHdr.AuthenticatedData = true + + repl := New(r, w, "") + + repl.Set("name", "coredns.io.") + repl.Set("type", "A") + repl.Set("size", "20") + + if repl.Replace("This name is {name}") != "This name is coredns.io." { + t.Error("Expected name replacement failed") + } + if repl.Replace("This type is {type}") != "This type is A" { + t.Error("Expected type replacement failed") + } + if repl.Replace("The request size is {size}") != "The request size is 20" { + t.Error("Expected size replacement failed") + } +} diff --git a/plugin/pkg/response/classify.go b/plugin/pkg/response/classify.go new file mode 100644 index 000000000..2e705cb0b --- /dev/null +++ b/plugin/pkg/response/classify.go @@ -0,0 +1,61 @@ +package response + +import "fmt" + +// Class holds sets of Types +type Class int + +const ( + // All is a meta class encompassing all the classes. + All Class = iota + // Success is a class for a successful response. + Success + // Denial is a class for denying existence (NXDOMAIN, or a nodata: type does not exist) + Denial + // Error is a class for errors, right now defined as not Success and not Denial + Error +) + +func (c Class) String() string { + switch c { + case All: + return "all" + case Success: + return "success" + case Denial: + return "denial" + case Error: + return "error" + } + return "" +} + +// ClassFromString returns the class from the string s. If not class matches +// the All class and an error are returned +func ClassFromString(s string) (Class, error) { + switch s { + case "all": + return All, nil + case "success": + return Success, nil + case "denial": + return Denial, nil + case "error": + return Error, nil + } + return All, fmt.Errorf("invalid Class: %s", s) +} + +// Classify classifies the Type t, it returns its Class. +func Classify(t Type) Class { + switch t { + case NoError, Delegation: + return Success + case NameError, NoData: + return Denial + case OtherError: + fallthrough + default: + return Error + } +} diff --git a/plugin/pkg/response/typify.go b/plugin/pkg/response/typify.go new file mode 100644 index 000000000..7cfaab497 --- /dev/null +++ b/plugin/pkg/response/typify.go @@ -0,0 +1,146 @@ +package response + +import ( + "fmt" + "time" + + "github.com/miekg/dns" +) + +// Type is the type of the message. +type Type int + +const ( + // NoError indicates a positive reply + NoError Type = iota + // NameError is a NXDOMAIN in header, SOA in auth. + NameError + // NoData indicates name found, but not the type: NOERROR in header, SOA in auth. + NoData + // Delegation is a msg with a pointer to another nameserver: NOERROR in header, NS in auth, optionally fluff in additional (not checked). + Delegation + // Meta indicates a meta message, NOTIFY, or a transfer: qType is IXFR or AXFR. + Meta + // Update is an dynamic update message. + Update + // OtherError indicates any other error: don't cache these. + OtherError +) + +var toString = map[Type]string{ + NoError: "NOERROR", + NameError: "NXDOMAIN", + NoData: "NODATA", + Delegation: "DELEGATION", + Meta: "META", + Update: "UPDATE", + OtherError: "OTHERERROR", +} + +func (t Type) String() string { return toString[t] } + +// TypeFromString returns the type from the string s. If not type matches +// the OtherError type and an error are returned. +func TypeFromString(s string) (Type, error) { + for t, str := range toString { + if s == str { + return t, nil + } + } + return NoError, fmt.Errorf("invalid Type: %s", s) +} + +// Typify classifies a message, it returns the Type. +func Typify(m *dns.Msg, t time.Time) (Type, *dns.OPT) { + if m == nil { + return OtherError, nil + } + opt := m.IsEdns0() + do := false + if opt != nil { + do = opt.Do() + } + + if m.Opcode == dns.OpcodeUpdate { + return Update, opt + } + + // Check transfer and update first + if m.Opcode == dns.OpcodeNotify { + return Meta, opt + } + + if len(m.Question) > 0 { + if m.Question[0].Qtype == dns.TypeAXFR || m.Question[0].Qtype == dns.TypeIXFR { + return Meta, opt + } + } + + // If our message contains any expired sigs and we care about that, we should return expired + if do { + if expired := typifyExpired(m, t); expired { + return OtherError, opt + } + } + + if len(m.Answer) > 0 && m.Rcode == dns.RcodeSuccess { + return NoError, opt + } + + soa := false + ns := 0 + for _, r := range m.Ns { + if r.Header().Rrtype == dns.TypeSOA { + soa = true + continue + } + if r.Header().Rrtype == dns.TypeNS { + ns++ + } + } + + // Check length of different sections, and drop stuff that is just to large? TODO(miek). + + if soa && m.Rcode == dns.RcodeSuccess { + return NoData, opt + } + if soa && m.Rcode == dns.RcodeNameError { + return NameError, opt + } + + if ns > 0 && m.Rcode == dns.RcodeSuccess { + return Delegation, opt + } + + if m.Rcode == dns.RcodeSuccess { + return NoError, opt + } + + return OtherError, opt +} + +func typifyExpired(m *dns.Msg, t time.Time) bool { + if expired := typifyExpiredRRSIG(m.Answer, t); expired { + return true + } + if expired := typifyExpiredRRSIG(m.Ns, t); expired { + return true + } + if expired := typifyExpiredRRSIG(m.Extra, t); expired { + return true + } + return false +} + +func typifyExpiredRRSIG(rrs []dns.RR, t time.Time) bool { + for _, r := range rrs { + if r.Header().Rrtype != dns.TypeRRSIG { + continue + } + ok := r.(*dns.RRSIG).ValidityPeriod(t) + if !ok { + return true + } + } + return false +} diff --git a/plugin/pkg/response/typify_test.go b/plugin/pkg/response/typify_test.go new file mode 100644 index 000000000..faeaf3579 --- /dev/null +++ b/plugin/pkg/response/typify_test.go @@ -0,0 +1,84 @@ +package response + +import ( + "testing" + "time" + + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestTypifyNilMsg(t *testing.T) { + var m *dns.Msg + + ty, _ := Typify(m, time.Now().UTC()) + if ty != OtherError { + t.Errorf("message wrongly typified, expected OtherError, got %s", ty) + } +} + +func TestTypifyDelegation(t *testing.T) { + m := delegationMsg() + mt, _ := Typify(m, time.Now().UTC()) + if mt != Delegation { + t.Errorf("message is wrongly typified, expected Delegation, got %s", mt) + } +} + +func TestTypifyRRSIG(t *testing.T) { + now, _ := time.Parse(time.UnixDate, "Fri Apr 21 10:51:21 BST 2017") + utc := now.UTC() + + m := delegationMsgRRSIGOK() + if mt, _ := Typify(m, utc); mt != Delegation { + t.Errorf("message is wrongly typified, expected Delegation, got %s", mt) + } + + // Still a Delegation because EDNS0 OPT DO bool is not set, so we won't check the sigs. + m = delegationMsgRRSIGFail() + if mt, _ := Typify(m, utc); mt != Delegation { + t.Errorf("message is wrongly typified, expected Delegation, got %s", mt) + } + + m = delegationMsgRRSIGFail() + m = addOpt(m) + if mt, _ := Typify(m, utc); mt != OtherError { + t.Errorf("message is wrongly typified, expected OtherError, got %s", mt) + } +} + +func delegationMsg() *dns.Msg { + return &dns.Msg{ + Ns: []dns.RR{ + test.NS("miek.nl. 3600 IN NS linode.atoom.net."), + test.NS("miek.nl. 3600 IN NS ns-ext.nlnetlabs.nl."), + test.NS("miek.nl. 3600 IN NS omval.tednet.nl."), + }, + Extra: []dns.RR{ + test.A("omval.tednet.nl. 3600 IN A 185.49.141.42"), + test.AAAA("omval.tednet.nl. 3600 IN AAAA 2a04:b900:0:100::42"), + }, + } +} + +func delegationMsgRRSIGOK() *dns.Msg { + del := delegationMsg() + del.Ns = append(del.Ns, + test.RRSIG("miek.nl. 1800 IN RRSIG NS 8 2 1800 20170521031301 20170421031301 12051 miek.nl. PIUu3TKX/sB/N1n1E1yWxHHIcPnc2q6Wq9InShk+5ptRqChqKdZNMLDm gCq+1bQAZ7jGvn2PbwTwE65JzES7T+hEiqR5PU23DsidvZyClbZ9l0xG JtKwgzGXLtUHxp4xv/Plq+rq/7pOG61bNCxRyS7WS7i7QcCCWT1BCcv+ wZ0="), + ) + return del +} + +func delegationMsgRRSIGFail() *dns.Msg { + del := delegationMsg() + del.Ns = append(del.Ns, + test.RRSIG("miek.nl. 1800 IN RRSIG NS 8 2 1800 20160521031301 20160421031301 12051 miek.nl. PIUu3TKX/sB/N1n1E1yWxHHIcPnc2q6Wq9InShk+5ptRqChqKdZNMLDm gCq+1bQAZ7jGvn2PbwTwE65JzES7T+hEiqR5PU23DsidvZyClbZ9l0xG JtKwgzGXLtUHxp4xv/Plq+rq/7pOG61bNCxRyS7WS7i7QcCCWT1BCcv+ wZ0="), + ) + return del +} + +func addOpt(m *dns.Msg) *dns.Msg { + m.Extra = append(m.Extra, test.OPT(4096, true)) + return m +} diff --git a/plugin/pkg/singleflight/singleflight.go b/plugin/pkg/singleflight/singleflight.go new file mode 100644 index 000000000..365e3ef58 --- /dev/null +++ b/plugin/pkg/singleflight/singleflight.go @@ -0,0 +1,64 @@ +/* +Copyright 2012 Google Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package singleflight provides a duplicate function call suppression +// mechanism. +package singleflight + +import "sync" + +// call is an in-flight or completed Do call +type call struct { + wg sync.WaitGroup + val interface{} + err error +} + +// Group represents a class of work and forms a namespace in which +// units of work can be executed with duplicate suppression. +type Group struct { + mu sync.Mutex // protects m + m map[uint32]*call // lazily initialized +} + +// Do executes and returns the results of the given function, making +// sure that only one execution is in-flight for a given key at a +// time. If a duplicate comes in, the duplicate caller waits for the +// original to complete and receives the same results. +func (g *Group) Do(key uint32, fn func() (interface{}, error)) (interface{}, error) { + g.mu.Lock() + if g.m == nil { + g.m = make(map[uint32]*call) + } + if c, ok := g.m[key]; ok { + g.mu.Unlock() + c.wg.Wait() + return c.val, c.err + } + c := new(call) + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + c.val, c.err = fn() + c.wg.Done() + + g.mu.Lock() + delete(g.m, key) + g.mu.Unlock() + + return c.val, c.err +} diff --git a/plugin/pkg/singleflight/singleflight_test.go b/plugin/pkg/singleflight/singleflight_test.go new file mode 100644 index 000000000..d1d406e0b --- /dev/null +++ b/plugin/pkg/singleflight/singleflight_test.go @@ -0,0 +1,85 @@ +/* +Copyright 2012 Google Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package singleflight + +import ( + "errors" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestDo(t *testing.T) { + var g Group + v, err := g.Do(1, func() (interface{}, error) { + return "bar", nil + }) + if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { + t.Errorf("Do = %v; want %v", got, want) + } + if err != nil { + t.Errorf("Do error = %v", err) + } +} + +func TestDoErr(t *testing.T) { + var g Group + someErr := errors.New("Some error") + v, err := g.Do(1, func() (interface{}, error) { + return nil, someErr + }) + if err != someErr { + t.Errorf("Do error = %v; want someErr", err) + } + if v != nil { + t.Errorf("unexpected non-nil value %#v", v) + } +} + +func TestDoDupSuppress(t *testing.T) { + var g Group + c := make(chan string) + var calls int32 + fn := func() (interface{}, error) { + atomic.AddInt32(&calls, 1) + return <-c, nil + } + + const n = 10 + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func() { + v, err := g.Do(1, fn) + if err != nil { + t.Errorf("Do error: %v", err) + } + if v.(string) != "bar" { + t.Errorf("got %q; want %q", v, "bar") + } + wg.Done() + }() + } + time.Sleep(100 * time.Millisecond) // let goroutines above block + c <- "bar" + wg.Wait() + if got := atomic.LoadInt32(&calls); got != 1 { + t.Errorf("number of calls = %d; want 1", got) + } +} diff --git a/plugin/pkg/tls/tls.go b/plugin/pkg/tls/tls.go new file mode 100644 index 000000000..6fc10dd8e --- /dev/null +++ b/plugin/pkg/tls/tls.go @@ -0,0 +1,128 @@ +package tls + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "net" + "net/http" + "time" +) + +// NewTLSConfigFromArgs returns a TLS config based upon the passed +// in list of arguments. Typically these come straight from the +// Corefile. +// no args +// - creates a Config with no cert and using system CAs +// - use for a client that talks to a server with a public signed cert (CA installed in system) +// - the client will not be authenticated by the server since there is no cert +// one arg: the path to CA PEM file +// - creates a Config with no cert using a specific CA +// - use for a client that talks to a server with a private signed cert (CA not installed in system) +// - the client will not be authenticated by the server since there is no cert +// two args: path to cert PEM file, the path to private key PEM file +// - creates a Config with a cert, using system CAs to validate the other end +// - use for: +// - a server; or, +// - a client that talks to a server with a public cert and needs certificate-based authentication +// - the other end will authenticate this end via the provided cert +// - the cert of the other end will be verified via system CAs +// three args: path to cert PEM file, path to client private key PEM file, path to CA PEM file +// - creates a Config with the cert, using specified CA to validate the other end +// - use for: +// - a server; or, +// - a client that talks to a server with a privately signed cert and needs certificate-based +// authentication +// - the other end will authenticate this end via the provided cert +// - this end will verify the other end's cert using the specified CA +func NewTLSConfigFromArgs(args ...string) (*tls.Config, error) { + var err error + var c *tls.Config + switch len(args) { + case 0: + // No client cert, use system CA + c, err = NewTLSClientConfig("") + case 1: + // No client cert, use specified CA + c, err = NewTLSClientConfig(args[0]) + case 2: + // Client cert, use system CA + c, err = NewTLSConfig(args[0], args[1], "") + case 3: + // Client cert, use specified CA + c, err = NewTLSConfig(args[0], args[1], args[2]) + default: + err = fmt.Errorf("maximum of three arguments allowed for TLS config, found %d", len(args)) + } + if err != nil { + return nil, err + } + return c, nil +} + +// NewTLSConfig returns a TLS config that includes a certificate +// Use for server TLS config or when using a client certificate +// If caPath is empty, system CAs will be used +func NewTLSConfig(certPath, keyPath, caPath string) (*tls.Config, error) { + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return nil, fmt.Errorf("could not load TLS cert: %s", err) + } + + roots, err := loadRoots(caPath) + if err != nil { + return nil, err + } + + return &tls.Config{Certificates: []tls.Certificate{cert}, RootCAs: roots}, nil +} + +// NewTLSClientConfig returns a TLS config for a client connection +// If caPath is empty, system CAs will be used +func NewTLSClientConfig(caPath string) (*tls.Config, error) { + roots, err := loadRoots(caPath) + if err != nil { + return nil, err + } + + return &tls.Config{RootCAs: roots}, nil +} + +func loadRoots(caPath string) (*x509.CertPool, error) { + if caPath == "" { + return nil, nil + } + + roots := x509.NewCertPool() + pem, err := ioutil.ReadFile(caPath) + if err != nil { + return nil, fmt.Errorf("error reading %s: %s", caPath, err) + } + ok := roots.AppendCertsFromPEM(pem) + if !ok { + return nil, fmt.Errorf("could not read root certs: %s", err) + } + return roots, nil +} + +// NewHTTPSTransport returns an HTTP transport configured using tls.Config +func NewHTTPSTransport(cc *tls.Config) *http.Transport { + // this seems like a bad idea but was here in the previous version + if cc != nil { + cc.InsecureSkipVerify = true + } + + tr := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + Dial: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).Dial, + TLSHandshakeTimeout: 10 * time.Second, + TLSClientConfig: cc, + MaxIdleConnsPerHost: 25, + } + + return tr +} diff --git a/plugin/pkg/tls/tls_test.go b/plugin/pkg/tls/tls_test.go new file mode 100644 index 000000000..8c88bfcc4 --- /dev/null +++ b/plugin/pkg/tls/tls_test.go @@ -0,0 +1,101 @@ +package tls + +import ( + "path/filepath" + "testing" + + "github.com/coredns/coredns/plugin/test" +) + +func getPEMFiles(t *testing.T) (rmFunc func(), cert, key, ca string) { + tempDir, rmFunc, err := test.WritePEMFiles("") + if err != nil { + t.Fatalf("Could not write PEM files: %s", err) + } + + cert = filepath.Join(tempDir, "cert.pem") + key = filepath.Join(tempDir, "key.pem") + ca = filepath.Join(tempDir, "ca.pem") + + return +} + +func TestNewTLSConfig(t *testing.T) { + rmFunc, cert, key, ca := getPEMFiles(t) + defer rmFunc() + + _, err := NewTLSConfig(cert, key, ca) + if err != nil { + t.Errorf("Failed to create TLSConfig: %s", err) + } +} + +func TestNewTLSClientConfig(t *testing.T) { + rmFunc, _, _, ca := getPEMFiles(t) + defer rmFunc() + + _, err := NewTLSClientConfig(ca) + if err != nil { + t.Errorf("Failed to create TLSConfig: %s", err) + } +} + +func TestNewTLSConfigFromArgs(t *testing.T) { + rmFunc, cert, key, ca := getPEMFiles(t) + defer rmFunc() + + _, err := NewTLSConfigFromArgs() + if err != nil { + t.Errorf("Failed to create TLSConfig: %s", err) + } + + c, err := NewTLSConfigFromArgs(ca) + if err != nil { + t.Errorf("Failed to create TLSConfig: %s", err) + } + if c.RootCAs == nil { + t.Error("RootCAs should not be nil when one arg passed") + } + + c, err = NewTLSConfigFromArgs(cert, key) + if err != nil { + t.Errorf("Failed to create TLSConfig: %s", err) + } + if c.RootCAs != nil { + t.Error("RootCAs should be nil when two args passed") + } + if len(c.Certificates) != 1 { + t.Error("Certificates should have a single entry when two args passed") + } + args := []string{cert, key, ca} + c, err = NewTLSConfigFromArgs(args...) + if err != nil { + t.Errorf("Failed to create TLSConfig: %s", err) + } + if c.RootCAs == nil { + t.Error("RootCAs should not be nil when three args passed") + } + if len(c.Certificates) != 1 { + t.Error("Certificateis should have a single entry when three args passed") + } +} + +func TestNewHTTPSTransport(t *testing.T) { + rmFunc, _, _, ca := getPEMFiles(t) + defer rmFunc() + + cc, err := NewTLSClientConfig(ca) + if err != nil { + t.Errorf("Failed to create TLSConfig: %s", err) + } + + tr := NewHTTPSTransport(cc) + if tr == nil { + t.Errorf("Failed to create https transport with cc") + } + + tr = NewHTTPSTransport(nil) + if tr == nil { + t.Errorf("Failed to create https transport without cc") + } +} diff --git a/plugin/pkg/trace/trace.go b/plugin/pkg/trace/trace.go new file mode 100644 index 000000000..35a8ddabd --- /dev/null +++ b/plugin/pkg/trace/trace.go @@ -0,0 +1,12 @@ +package trace + +import ( + "github.com/coredns/coredns/plugin" + ot "github.com/opentracing/opentracing-go" +) + +// Trace holds the tracer and endpoint info +type Trace interface { + plugin.Handler + Tracer() ot.Tracer +} |