aboutsummaryrefslogtreecommitdiff
path: root/plugin/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'plugin/pkg')
-rw-r--r--plugin/pkg/cache/cache.go129
-rw-r--r--plugin/pkg/cache/cache_test.go31
-rw-r--r--plugin/pkg/cache/shard_test.go60
-rw-r--r--plugin/pkg/dnsrecorder/recorder.go58
-rw-r--r--plugin/pkg/dnsrecorder/recorder_test.go28
-rw-r--r--plugin/pkg/dnsutil/cname.go15
-rw-r--r--plugin/pkg/dnsutil/cname_test.go55
-rw-r--r--plugin/pkg/dnsutil/dedup.go12
-rw-r--r--plugin/pkg/dnsutil/doc.go2
-rw-r--r--plugin/pkg/dnsutil/host.go82
-rw-r--r--plugin/pkg/dnsutil/host_test.go85
-rw-r--r--plugin/pkg/dnsutil/join.go19
-rw-r--r--plugin/pkg/dnsutil/join_test.go20
-rw-r--r--plugin/pkg/dnsutil/reverse.go68
-rw-r--r--plugin/pkg/dnsutil/reverse_test.go51
-rw-r--r--plugin/pkg/dnsutil/zone.go20
-rw-r--r--plugin/pkg/dnsutil/zone_test.go39
-rw-r--r--plugin/pkg/edns/edns.go46
-rw-r--r--plugin/pkg/edns/edns_test.go37
-rw-r--r--plugin/pkg/healthcheck/healthcheck.go243
-rw-r--r--plugin/pkg/healthcheck/policy.go120
-rw-r--r--plugin/pkg/healthcheck/policy_test.go143
-rw-r--r--plugin/pkg/nonwriter/nonwriter.go23
-rw-r--r--plugin/pkg/nonwriter/nonwriter_test.go19
-rw-r--r--plugin/pkg/rcode/rcode.go16
-rw-r--r--plugin/pkg/rcode/rcode_test.go29
-rw-r--r--plugin/pkg/replacer/replacer.go161
-rw-r--r--plugin/pkg/replacer/replacer_test.go61
-rw-r--r--plugin/pkg/response/classify.go61
-rw-r--r--plugin/pkg/response/typify.go146
-rw-r--r--plugin/pkg/response/typify_test.go84
-rw-r--r--plugin/pkg/singleflight/singleflight.go64
-rw-r--r--plugin/pkg/singleflight/singleflight_test.go85
-rw-r--r--plugin/pkg/tls/tls.go128
-rw-r--r--plugin/pkg/tls/tls_test.go101
-rw-r--r--plugin/pkg/trace/trace.go12
36 files changed, 2353 insertions, 0 deletions
diff --git a/plugin/pkg/cache/cache.go b/plugin/pkg/cache/cache.go
new file mode 100644
index 000000000..56cae2180
--- /dev/null
+++ b/plugin/pkg/cache/cache.go
@@ -0,0 +1,129 @@
+// Package cache implements a cache. The cache hold 256 shards, each shard
+// holds a cache: a map with a mutex. There is no fancy expunge algorithm, it
+// just randomly evicts elements when it gets full.
+package cache
+
+import (
+ "hash/fnv"
+ "sync"
+)
+
+// Hash returns the FNV hash of what.
+func Hash(what []byte) uint32 {
+ h := fnv.New32()
+ h.Write(what)
+ return h.Sum32()
+}
+
+// Cache is cache.
+type Cache struct {
+ shards [shardSize]*shard
+}
+
+// shard is a cache with random eviction.
+type shard struct {
+ items map[uint32]interface{}
+ size int
+
+ sync.RWMutex
+}
+
+// New returns a new cache.
+func New(size int) *Cache {
+ ssize := size / shardSize
+ if ssize < 512 {
+ ssize = 512
+ }
+
+ c := &Cache{}
+
+ // Initialize all the shards
+ for i := 0; i < shardSize; i++ {
+ c.shards[i] = newShard(ssize)
+ }
+ return c
+}
+
+// Add adds a new element to the cache. If the element already exists it is overwritten.
+func (c *Cache) Add(key uint32, el interface{}) {
+ shard := key & (shardSize - 1)
+ c.shards[shard].Add(key, el)
+}
+
+// Get looks up element index under key.
+func (c *Cache) Get(key uint32) (interface{}, bool) {
+ shard := key & (shardSize - 1)
+ return c.shards[shard].Get(key)
+}
+
+// Remove removes the element indexed with key.
+func (c *Cache) Remove(key uint32) {
+ shard := key & (shardSize - 1)
+ c.shards[shard].Remove(key)
+}
+
+// Len returns the number of elements in the cache.
+func (c *Cache) Len() int {
+ l := 0
+ for _, s := range c.shards {
+ l += s.Len()
+ }
+ return l
+}
+
+// newShard returns a new shard with size.
+func newShard(size int) *shard { return &shard{items: make(map[uint32]interface{}), size: size} }
+
+// Add adds element indexed by key into the cache. Any existing element is overwritten
+func (s *shard) Add(key uint32, el interface{}) {
+ l := s.Len()
+ if l+1 > s.size {
+ s.Evict()
+ }
+
+ s.Lock()
+ s.items[key] = el
+ s.Unlock()
+}
+
+// Remove removes the element indexed by key from the cache.
+func (s *shard) Remove(key uint32) {
+ s.Lock()
+ delete(s.items, key)
+ s.Unlock()
+}
+
+// Evict removes a random element from the cache.
+func (s *shard) Evict() {
+ s.Lock()
+ defer s.Unlock()
+
+ key := -1
+ for k := range s.items {
+ key = int(k)
+ break
+ }
+ if key == -1 {
+ // empty cache
+ return
+ }
+ delete(s.items, uint32(key))
+}
+
+// Get looks up the element indexed under key.
+func (s *shard) Get(key uint32) (interface{}, bool) {
+ s.RLock()
+ el, found := s.items[key]
+ s.RUnlock()
+ return el, found
+}
+
+// Len returns the current length of the cache.
+func (s *shard) Len() int {
+ s.RLock()
+ l := len(s.items)
+ s.RUnlock()
+ return l
+}
+
+const shardSize = 256
diff --git a/plugin/pkg/cache/cache_test.go b/plugin/pkg/cache/cache_test.go
new file mode 100644
index 000000000..2c92bf438
--- /dev/null
+++ b/plugin/pkg/cache/cache_test.go
@@ -0,0 +1,31 @@
+package cache
+
+import "testing"
+
+func TestCacheAddAndGet(t *testing.T) {
+ c := New(4)
+ c.Add(1, 1)
+
+ if _, found := c.Get(1); !found {
+ t.Fatal("Failed to find inserted record")
+ }
+}
+
+func TestCacheLen(t *testing.T) {
+ c := New(4)
+
+ c.Add(1, 1)
+ if l := c.Len(); l != 1 {
+ t.Fatalf("Cache size should %d, got %d", 1, l)
+ }
+
+ c.Add(1, 1)
+ if l := c.Len(); l != 1 {
+ t.Fatalf("Cache size should %d, got %d", 1, l)
+ }
+
+ c.Add(2, 2)
+ if l := c.Len(); l != 2 {
+ t.Fatalf("Cache size should %d, got %d", 2, l)
+ }
+}
diff --git a/plugin/pkg/cache/shard_test.go b/plugin/pkg/cache/shard_test.go
new file mode 100644
index 000000000..26675cee1
--- /dev/null
+++ b/plugin/pkg/cache/shard_test.go
@@ -0,0 +1,60 @@
+package cache
+
+import "testing"
+
+func TestShardAddAndGet(t *testing.T) {
+ s := newShard(4)
+ s.Add(1, 1)
+
+ if _, found := s.Get(1); !found {
+ t.Fatal("Failed to find inserted record")
+ }
+}
+
+func TestShardLen(t *testing.T) {
+ s := newShard(4)
+
+ s.Add(1, 1)
+ if l := s.Len(); l != 1 {
+ t.Fatalf("Shard size should %d, got %d", 1, l)
+ }
+
+ s.Add(1, 1)
+ if l := s.Len(); l != 1 {
+ t.Fatalf("Shard size should %d, got %d", 1, l)
+ }
+
+ s.Add(2, 2)
+ if l := s.Len(); l != 2 {
+ t.Fatalf("Shard size should %d, got %d", 2, l)
+ }
+}
+
+func TestShardEvict(t *testing.T) {
+ s := newShard(1)
+ s.Add(1, 1)
+ s.Add(2, 2)
+ // 1 should be gone
+
+ if _, found := s.Get(1); found {
+ t.Fatal("Found item that should have been evicted")
+ }
+}
+
+func TestShardLenEvict(t *testing.T) {
+ s := newShard(4)
+ s.Add(1, 1)
+ s.Add(2, 1)
+ s.Add(3, 1)
+ s.Add(4, 1)
+
+ if l := s.Len(); l != 4 {
+ t.Fatalf("Shard size should %d, got %d", 4, l)
+ }
+
+ // This should evict one element
+ s.Add(5, 1)
+ if l := s.Len(); l != 4 {
+ t.Fatalf("Shard size should %d, got %d", 4, l)
+ }
+}
diff --git a/plugin/pkg/dnsrecorder/recorder.go b/plugin/pkg/dnsrecorder/recorder.go
new file mode 100644
index 000000000..3ca5f00d0
--- /dev/null
+++ b/plugin/pkg/dnsrecorder/recorder.go
@@ -0,0 +1,58 @@
+// Package dnsrecorder allows you to record a DNS response when it is send to the client.
+package dnsrecorder
+
+import (
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+// Recorder is a type of ResponseWriter that captures
+// the rcode code written to it and also the size of the message
+// written in the response. A rcode code does not have
+// to be written, however, in which case 0 must be assumed.
+// It is best to have the constructor initialize this type
+// with that default status code.
+type Recorder struct {
+ dns.ResponseWriter
+ Rcode int
+ Len int
+ Msg *dns.Msg
+ Start time.Time
+}
+
+// New makes and returns a new Recorder,
+// which captures the DNS rcode from the ResponseWriter
+// and also the length of the response message written through it.
+func New(w dns.ResponseWriter) *Recorder {
+ return &Recorder{
+ ResponseWriter: w,
+ Rcode: 0,
+ Msg: nil,
+ Start: time.Now(),
+ }
+}
+
+// WriteMsg records the status code and calls the
+// underlying ResponseWriter's WriteMsg method.
+func (r *Recorder) WriteMsg(res *dns.Msg) error {
+ r.Rcode = res.Rcode
+ // We may get called multiple times (axfr for instance).
+ // Save the last message, but add the sizes.
+ r.Len += res.Len()
+ r.Msg = res
+ return r.ResponseWriter.WriteMsg(res)
+}
+
+// Write is a wrapper that records the length of the message that gets written.
+func (r *Recorder) Write(buf []byte) (int, error) {
+ n, err := r.ResponseWriter.Write(buf)
+ if err == nil {
+ r.Len += n
+ }
+ return n, err
+}
+
+// Hijack implements dns.Hijacker. It simply wraps the underlying
+// ResponseWriter's Hijack method if there is one, or returns an error.
+func (r *Recorder) Hijack() { r.ResponseWriter.Hijack(); return }
diff --git a/plugin/pkg/dnsrecorder/recorder_test.go b/plugin/pkg/dnsrecorder/recorder_test.go
new file mode 100644
index 000000000..c9c2f6ce4
--- /dev/null
+++ b/plugin/pkg/dnsrecorder/recorder_test.go
@@ -0,0 +1,28 @@
+package dnsrecorder
+
+/*
+func TestNewResponseRecorder(t *testing.T) {
+ w := httptest.NewRecorder()
+ recordRequest := NewResponseRecorder(w)
+ if !(recordRequest.ResponseWriter == w) {
+ t.Fatalf("Expected Response writer in the Recording to be same as the one sent\n")
+ }
+ if recordRequest.status != http.StatusOK {
+ t.Fatalf("Expected recorded status to be http.StatusOK (%d) , but found %d\n ", http.StatusOK, recordRequest.status)
+ }
+}
+
+func TestWrite(t *testing.T) {
+ w := httptest.NewRecorder()
+ responseTestString := "test"
+ recordRequest := NewResponseRecorder(w)
+ buf := []byte(responseTestString)
+ recordRequest.Write(buf)
+ if recordRequest.size != len(buf) {
+ t.Fatalf("Expected the bytes written counter to be %d, but instead found %d\n", len(buf), recordRequest.size)
+ }
+ if w.Body.String() != responseTestString {
+ t.Fatalf("Expected Response Body to be %s , but found %s\n", responseTestString, w.Body.String())
+ }
+}
+*/
diff --git a/plugin/pkg/dnsutil/cname.go b/plugin/pkg/dnsutil/cname.go
new file mode 100644
index 000000000..281e03218
--- /dev/null
+++ b/plugin/pkg/dnsutil/cname.go
@@ -0,0 +1,15 @@
+package dnsutil
+
+import "github.com/miekg/dns"
+
+// DuplicateCNAME returns true if r already exists in records.
+func DuplicateCNAME(r *dns.CNAME, records []dns.RR) bool {
+ for _, rec := range records {
+ if v, ok := rec.(*dns.CNAME); ok {
+ if v.Target == r.Target {
+ return true
+ }
+ }
+ }
+ return false
+}
diff --git a/plugin/pkg/dnsutil/cname_test.go b/plugin/pkg/dnsutil/cname_test.go
new file mode 100644
index 000000000..5fb8d3029
--- /dev/null
+++ b/plugin/pkg/dnsutil/cname_test.go
@@ -0,0 +1,55 @@
+package dnsutil
+
+import (
+ "testing"
+
+ "github.com/miekg/dns"
+)
+
+func TestDuplicateCNAME(t *testing.T) {
+ tests := []struct {
+ cname string
+ records []string
+ expected bool
+ }{
+ {
+ "1.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 1.0.0.0.192.IN-ADDR.ARPA.",
+ []string{
+ "US. 86400 IN NSEC 0-.us. NS SOA RRSIG NSEC DNSKEY TYPE65534",
+ "1.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 1.0.0.0.192.IN-ADDR.ARPA.",
+ },
+ true,
+ },
+ {
+ "1.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 1.0.0.0.192.IN-ADDR.ARPA.",
+ []string{
+ "US. 86400 IN NSEC 0-.us. NS SOA RRSIG NSEC DNSKEY TYPE65534",
+ },
+ false,
+ },
+ {
+ "1.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 1.0.0.0.192.IN-ADDR.ARPA.",
+ []string{},
+ false,
+ },
+ }
+ for i, test := range tests {
+ cnameRR, err := dns.NewRR(test.cname)
+ if err != nil {
+ t.Fatalf("Test %d, cname ('%s') error (%s)!", i, test.cname, err)
+ }
+ cname := cnameRR.(*dns.CNAME)
+ records := []dns.RR{}
+ for j, r := range test.records {
+ rr, err := dns.NewRR(r)
+ if err != nil {
+ t.Fatalf("Test %d, record %d ('%s') error (%s)!", i, j, r, err)
+ }
+ records = append(records, rr)
+ }
+ got := DuplicateCNAME(cname, records)
+ if got != test.expected {
+ t.Errorf("Test %d, expected '%v', got '%v' for CNAME ('%s') and RECORDS (%v)", i, test.expected, got, test.cname, test.records)
+ }
+ }
+}
diff --git a/plugin/pkg/dnsutil/dedup.go b/plugin/pkg/dnsutil/dedup.go
new file mode 100644
index 000000000..dae656a01
--- /dev/null
+++ b/plugin/pkg/dnsutil/dedup.go
@@ -0,0 +1,12 @@
+package dnsutil
+
+import "github.com/miekg/dns"
+
+// Dedup de-duplicates a message.
+func Dedup(m *dns.Msg) *dns.Msg {
+ // TODO(miek): expensive!
+ m.Answer = dns.Dedup(m.Answer, nil)
+ m.Ns = dns.Dedup(m.Ns, nil)
+ m.Extra = dns.Dedup(m.Extra, nil)
+ return m
+}
diff --git a/plugin/pkg/dnsutil/doc.go b/plugin/pkg/dnsutil/doc.go
new file mode 100644
index 000000000..75d1e8c7a
--- /dev/null
+++ b/plugin/pkg/dnsutil/doc.go
@@ -0,0 +1,2 @@
+// Package dnsutil contains DNS related helper functions.
+package dnsutil
diff --git a/plugin/pkg/dnsutil/host.go b/plugin/pkg/dnsutil/host.go
new file mode 100644
index 000000000..aaab586e8
--- /dev/null
+++ b/plugin/pkg/dnsutil/host.go
@@ -0,0 +1,82 @@
+package dnsutil
+
+import (
+ "fmt"
+ "net"
+ "os"
+
+ "github.com/miekg/dns"
+)
+
+// ParseHostPortOrFile parses the strings in s, each string can either be a address,
+// address:port or a filename. The address part is checked and the filename case a
+// resolv.conf like file is parsed and the nameserver found are returned.
+func ParseHostPortOrFile(s ...string) ([]string, error) {
+ var servers []string
+ for _, host := range s {
+ addr, _, err := net.SplitHostPort(host)
+ if err != nil {
+ // Parse didn't work, it is not a addr:port combo
+ if net.ParseIP(host) == nil {
+ // Not an IP address.
+ ss, err := tryFile(host)
+ if err == nil {
+ servers = append(servers, ss...)
+ continue
+ }
+ return servers, fmt.Errorf("not an IP address or file: %q", host)
+ }
+ ss := net.JoinHostPort(host, "53")
+ servers = append(servers, ss)
+ continue
+ }
+
+ if net.ParseIP(addr) == nil {
+ // No an IP address.
+ ss, err := tryFile(host)
+ if err == nil {
+ servers = append(servers, ss...)
+ continue
+ }
+ return servers, fmt.Errorf("not an IP address or file: %q", host)
+ }
+ servers = append(servers, host)
+ }
+ return servers, nil
+}
+
+// Try to open this is a file first.
+func tryFile(s string) ([]string, error) {
+ c, err := dns.ClientConfigFromFile(s)
+ if err == os.ErrNotExist {
+ return nil, fmt.Errorf("failed to open file %q: %q", s, err)
+ } else if err != nil {
+ return nil, err
+ }
+
+ servers := []string{}
+ for _, s := range c.Servers {
+ servers = append(servers, net.JoinHostPort(s, c.Port))
+ }
+ return servers, nil
+}
+
+// ParseHostPort will check if the host part is a valid IP address, if the
+// IP address is valid, but no port is found, defaultPort is added.
+func ParseHostPort(s, defaultPort string) (string, error) {
+ addr, port, err := net.SplitHostPort(s)
+ if port == "" {
+ port = defaultPort
+ }
+ if err != nil {
+ if net.ParseIP(s) == nil {
+ return "", fmt.Errorf("must specify an IP address: `%s'", s)
+ }
+ return net.JoinHostPort(s, port), nil
+ }
+
+ if net.ParseIP(addr) == nil {
+ return "", fmt.Errorf("must specify an IP address: `%s'", addr)
+ }
+ return net.JoinHostPort(addr, port), nil
+}
diff --git a/plugin/pkg/dnsutil/host_test.go b/plugin/pkg/dnsutil/host_test.go
new file mode 100644
index 000000000..cc55f4570
--- /dev/null
+++ b/plugin/pkg/dnsutil/host_test.go
@@ -0,0 +1,85 @@
+package dnsutil
+
+import (
+ "io/ioutil"
+ "os"
+ "testing"
+)
+
+func TestParseHostPortOrFile(t *testing.T) {
+ tests := []struct {
+ in string
+ expected string
+ shouldErr bool
+ }{
+ {
+ "8.8.8.8",
+ "8.8.8.8:53",
+ false,
+ },
+ {
+ "8.8.8.8:153",
+ "8.8.8.8:153",
+ false,
+ },
+ {
+ "/etc/resolv.conf:53",
+ "",
+ true,
+ },
+ {
+ "resolv.conf",
+ "127.0.0.1:53",
+ false,
+ },
+ }
+
+ err := ioutil.WriteFile("resolv.conf", []byte("nameserver 127.0.0.1\n"), 0600)
+ if err != nil {
+ t.Fatalf("Failed to write test resolv.conf")
+ }
+ defer os.Remove("resolv.conf")
+
+ for i, tc := range tests {
+ got, err := ParseHostPortOrFile(tc.in)
+ if err == nil && tc.shouldErr {
+ t.Errorf("Test %d, expected error, got nil", i)
+ continue
+ }
+ if err != nil && tc.shouldErr {
+ continue
+ }
+ if got[0] != tc.expected {
+ t.Errorf("Test %d, expected %q, got %q", i, tc.expected, got[0])
+ }
+ }
+}
+
+func TestParseHostPort(t *testing.T) {
+ tests := []struct {
+ in string
+ expected string
+ shouldErr bool
+ }{
+ {"8.8.8.8:53", "8.8.8.8:53", false},
+ {"a.a.a.a:153", "", true},
+ {"8.8.8.8", "8.8.8.8:53", false},
+ {"8.8.8.8:", "8.8.8.8:53", false},
+ {"8.8.8.8::53", "", true},
+ {"resolv.conf", "", true},
+ }
+
+ for i, tc := range tests {
+ got, err := ParseHostPort(tc.in, "53")
+ if err == nil && tc.shouldErr {
+ t.Errorf("Test %d, expected error, got nil", i)
+ continue
+ }
+ if err != nil && !tc.shouldErr {
+ t.Errorf("Test %d, expected no error, got %q", i, err)
+ }
+ if got != tc.expected {
+ t.Errorf("Test %d, expected %q, got %q", i, tc.expected, got)
+ }
+ }
+}
diff --git a/plugin/pkg/dnsutil/join.go b/plugin/pkg/dnsutil/join.go
new file mode 100644
index 000000000..515bf3dad
--- /dev/null
+++ b/plugin/pkg/dnsutil/join.go
@@ -0,0 +1,19 @@
+package dnsutil
+
+import (
+ "strings"
+
+ "github.com/miekg/dns"
+)
+
+// Join joins labels to form a fully qualified domain name. If the last label is
+// the root label it is ignored. Not other syntax checks are performed.
+func Join(labels []string) string {
+ ll := len(labels)
+ if labels[ll-1] == "." {
+ s := strings.Join(labels[:ll-1], ".")
+ return dns.Fqdn(s)
+ }
+ s := strings.Join(labels, ".")
+ return dns.Fqdn(s)
+}
diff --git a/plugin/pkg/dnsutil/join_test.go b/plugin/pkg/dnsutil/join_test.go
new file mode 100644
index 000000000..26eeb5897
--- /dev/null
+++ b/plugin/pkg/dnsutil/join_test.go
@@ -0,0 +1,20 @@
+package dnsutil
+
+import "testing"
+
+func TestJoin(t *testing.T) {
+ tests := []struct {
+ in []string
+ out string
+ }{
+ {[]string{"bla", "bliep", "example", "org"}, "bla.bliep.example.org."},
+ {[]string{"example", "."}, "example."},
+ {[]string{"."}, "."},
+ }
+
+ for i, tc := range tests {
+ if x := Join(tc.in); x != tc.out {
+ t.Errorf("Test %d, expected %s, got %s", i, tc.out, x)
+ }
+ }
+}
diff --git a/plugin/pkg/dnsutil/reverse.go b/plugin/pkg/dnsutil/reverse.go
new file mode 100644
index 000000000..daf9cc600
--- /dev/null
+++ b/plugin/pkg/dnsutil/reverse.go
@@ -0,0 +1,68 @@
+package dnsutil
+
+import (
+ "net"
+ "strings"
+)
+
+// ExtractAddressFromReverse turns a standard PTR reverse record name
+// into an IP address. This works for ipv4 or ipv6.
+//
+// 54.119.58.176.in-addr.arpa. becomes 176.58.119.54. If the conversion
+// failes the empty string is returned.
+func ExtractAddressFromReverse(reverseName string) string {
+ search := ""
+
+ f := reverse
+
+ switch {
+ case strings.HasSuffix(reverseName, v4arpaSuffix):
+ search = strings.TrimSuffix(reverseName, v4arpaSuffix)
+ case strings.HasSuffix(reverseName, v6arpaSuffix):
+ search = strings.TrimSuffix(reverseName, v6arpaSuffix)
+ f = reverse6
+ default:
+ return ""
+ }
+
+ // Reverse the segments and then combine them.
+ return f(strings.Split(search, "."))
+}
+
+func reverse(slice []string) string {
+ for i := 0; i < len(slice)/2; i++ {
+ j := len(slice) - i - 1
+ slice[i], slice[j] = slice[j], slice[i]
+ }
+ ip := net.ParseIP(strings.Join(slice, ".")).To4()
+ if ip == nil {
+ return ""
+ }
+ return ip.String()
+}
+
+// reverse6 reverse the segments and combine them according to RFC3596:
+// b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2
+// is reversed to 2001:db8::567:89ab
+func reverse6(slice []string) string {
+ for i := 0; i < len(slice)/2; i++ {
+ j := len(slice) - i - 1
+ slice[i], slice[j] = slice[j], slice[i]
+ }
+ slice6 := []string{}
+ for i := 0; i < len(slice)/4; i++ {
+ slice6 = append(slice6, strings.Join(slice[i*4:i*4+4], ""))
+ }
+ ip := net.ParseIP(strings.Join(slice6, ":")).To16()
+ if ip == nil {
+ return ""
+ }
+ return ip.String()
+}
+
+const (
+ // v4arpaSuffix is the reverse tree suffix for v4 IP addresses.
+ v4arpaSuffix = ".in-addr.arpa."
+ // v6arpaSuffix is the reverse tree suffix for v6 IP addresses.
+ v6arpaSuffix = ".ip6.arpa."
+)
diff --git a/plugin/pkg/dnsutil/reverse_test.go b/plugin/pkg/dnsutil/reverse_test.go
new file mode 100644
index 000000000..25bd897ac
--- /dev/null
+++ b/plugin/pkg/dnsutil/reverse_test.go
@@ -0,0 +1,51 @@
+package dnsutil
+
+import (
+ "testing"
+)
+
+func TestExtractAddressFromReverse(t *testing.T) {
+ tests := []struct {
+ reverseName string
+ expectedAddress string
+ }{
+ {
+ "54.119.58.176.in-addr.arpa.",
+ "176.58.119.54",
+ },
+ {
+ ".58.176.in-addr.arpa.",
+ "",
+ },
+ {
+ "b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.in-addr.arpa.",
+ "",
+ },
+ {
+ "b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
+ "2001:db8::567:89ab",
+ },
+ {
+ "d.0.1.0.0.2.ip6.arpa.",
+ "",
+ },
+ {
+ "54.119.58.176.ip6.arpa.",
+ "",
+ },
+ {
+ "NONAME",
+ "",
+ },
+ {
+ "",
+ "",
+ },
+ }
+ for i, test := range tests {
+ got := ExtractAddressFromReverse(test.reverseName)
+ if got != test.expectedAddress {
+ t.Errorf("Test %d, expected '%s', got '%s'", i, test.expectedAddress, got)
+ }
+ }
+}
diff --git a/plugin/pkg/dnsutil/zone.go b/plugin/pkg/dnsutil/zone.go
new file mode 100644
index 000000000..579fef1ba
--- /dev/null
+++ b/plugin/pkg/dnsutil/zone.go
@@ -0,0 +1,20 @@
+package dnsutil
+
+import (
+ "errors"
+
+ "github.com/miekg/dns"
+)
+
+// TrimZone removes the zone component from q. It returns the trimmed
+// name or an error is zone is longer then qname. The trimmed name will be returned
+// without a trailing dot.
+func TrimZone(q string, z string) (string, error) {
+ zl := dns.CountLabel(z)
+ i, ok := dns.PrevLabel(q, zl)
+ if ok || i-1 < 0 {
+ return "", errors.New("trimzone: overshot qname: " + q + "for zone " + z)
+ }
+ // This includes the '.', remove on return
+ return q[:i-1], nil
+}
diff --git a/plugin/pkg/dnsutil/zone_test.go b/plugin/pkg/dnsutil/zone_test.go
new file mode 100644
index 000000000..81cd1adad
--- /dev/null
+++ b/plugin/pkg/dnsutil/zone_test.go
@@ -0,0 +1,39 @@
+package dnsutil
+
+import (
+ "errors"
+ "testing"
+
+ "github.com/miekg/dns"
+)
+
+func TestTrimZone(t *testing.T) {
+ tests := []struct {
+ qname string
+ zone string
+ expected string
+ err error
+ }{
+ {"a.example.org", "example.org", "a", nil},
+ {"a.b.example.org", "example.org", "a.b", nil},
+ {"b.", ".", "b", nil},
+ {"example.org", "example.org", "", errors.New("should err")},
+ {"org", "example.org", "", errors.New("should err")},
+ }
+
+ for i, tc := range tests {
+ got, err := TrimZone(dns.Fqdn(tc.qname), dns.Fqdn(tc.zone))
+ if tc.err != nil && err == nil {
+ t.Errorf("Test %d, expected error got nil", i)
+ continue
+ }
+ if tc.err == nil && err != nil {
+ t.Errorf("Test %d, expected no error got %v", i, err)
+ continue
+ }
+ if got != tc.expected {
+ t.Errorf("Test %d, expected %s, got %s", i, tc.expected, got)
+ continue
+ }
+ }
+}
diff --git a/plugin/pkg/edns/edns.go b/plugin/pkg/edns/edns.go
new file mode 100644
index 000000000..3f0ea5e16
--- /dev/null
+++ b/plugin/pkg/edns/edns.go
@@ -0,0 +1,46 @@
+// Package edns provides function useful for adding/inspecting OPT records to/in messages.
+package edns
+
+import (
+ "errors"
+
+ "github.com/miekg/dns"
+)
+
+// Version checks the EDNS version in the request. If error
+// is nil everything is OK and we can invoke the plugin. If non-nil, the
+// returned Msg is valid to be returned to the client (and should). For some
+// reason this response should not contain a question RR in the question section.
+func Version(req *dns.Msg) (*dns.Msg, error) {
+ opt := req.IsEdns0()
+ if opt == nil {
+ return nil, nil
+ }
+ if opt.Version() == 0 {
+ return nil, nil
+ }
+ m := new(dns.Msg)
+ m.SetReply(req)
+ // zero out question section, wtf.
+ m.Question = nil
+
+ o := new(dns.OPT)
+ o.Hdr.Name = "."
+ o.Hdr.Rrtype = dns.TypeOPT
+ o.SetVersion(0)
+ o.SetExtendedRcode(dns.RcodeBadVers)
+ m.Extra = []dns.RR{o}
+
+ return m, errors.New("EDNS0 BADVERS")
+}
+
+// Size returns a normalized size based on proto.
+func Size(proto string, size int) int {
+ if proto == "tcp" {
+ return dns.MaxMsgSize
+ }
+ if size < dns.MinMsgSize {
+ return dns.MinMsgSize
+ }
+ return size
+}
diff --git a/plugin/pkg/edns/edns_test.go b/plugin/pkg/edns/edns_test.go
new file mode 100644
index 000000000..89ac6d2ec
--- /dev/null
+++ b/plugin/pkg/edns/edns_test.go
@@ -0,0 +1,37 @@
+package edns
+
+import (
+ "testing"
+
+ "github.com/miekg/dns"
+)
+
+func TestVersion(t *testing.T) {
+ m := ednsMsg()
+ m.Extra[0].(*dns.OPT).SetVersion(2)
+
+ _, err := Version(m)
+ if err == nil {
+ t.Errorf("expected wrong version, but got OK")
+ }
+}
+
+func TestVersionNoEdns(t *testing.T) {
+ m := ednsMsg()
+ m.Extra = nil
+
+ _, err := Version(m)
+ if err != nil {
+ t.Errorf("expected no error, but got one: %s", err)
+ }
+}
+
+func ednsMsg() *dns.Msg {
+ m := new(dns.Msg)
+ m.SetQuestion("example.com.", dns.TypeA)
+ o := new(dns.OPT)
+ o.Hdr.Name = "."
+ o.Hdr.Rrtype = dns.TypeOPT
+ m.Extra = append(m.Extra, o)
+ return m
+}
diff --git a/plugin/pkg/healthcheck/healthcheck.go b/plugin/pkg/healthcheck/healthcheck.go
new file mode 100644
index 000000000..18f09087c
--- /dev/null
+++ b/plugin/pkg/healthcheck/healthcheck.go
@@ -0,0 +1,243 @@
+package healthcheck
+
+import (
+ "io"
+ "io/ioutil"
+ "log"
+ "net"
+ "net/http"
+ "net/url"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+// UpstreamHostDownFunc can be used to customize how Down behaves.
+type UpstreamHostDownFunc func(*UpstreamHost) bool
+
+// UpstreamHost represents a single proxy upstream
+type UpstreamHost struct {
+ Conns int64 // must be first field to be 64-bit aligned on 32-bit systems
+ Name string // IP address (and port) of this upstream host
+ Network string // Network (tcp, unix, etc) of the host, default "" is "tcp"
+ Fails int32
+ FailTimeout time.Duration
+ OkUntil time.Time
+ CheckDown UpstreamHostDownFunc
+ CheckURL string
+ WithoutPathPrefix string
+ Checking bool
+ CheckMu sync.Mutex
+}
+
+// Down checks whether the upstream host is down or not.
+// Down will try to use uh.CheckDown first, and will fall
+// back to some default criteria if necessary.
+func (uh *UpstreamHost) Down() bool {
+ if uh.CheckDown == nil {
+ // Default settings
+ fails := atomic.LoadInt32(&uh.Fails)
+ after := false
+
+ uh.CheckMu.Lock()
+ until := uh.OkUntil
+ uh.CheckMu.Unlock()
+
+ if !until.IsZero() && time.Now().After(until) {
+ after = true
+ }
+
+ return after || fails > 0
+ }
+ return uh.CheckDown(uh)
+}
+
+// HostPool is a collection of UpstreamHosts.
+type HostPool []*UpstreamHost
+
+// HealthCheck is used for performing healthcheck
+// on a collection of upstream hosts and select
+// one based on the policy.
+type HealthCheck struct {
+ wg sync.WaitGroup // Used to wait for running goroutines to stop.
+ stop chan struct{} // Signals running goroutines to stop.
+ Hosts HostPool
+ Policy Policy
+ Spray Policy
+ FailTimeout time.Duration
+ MaxFails int32
+ Future time.Duration
+ Path string
+ Port string
+ Interval time.Duration
+}
+
+// Start starts the healthcheck
+func (u *HealthCheck) Start() {
+ u.stop = make(chan struct{})
+ if u.Path != "" {
+ u.wg.Add(1)
+ go func() {
+ defer u.wg.Done()
+ u.healthCheckWorker(u.stop)
+ }()
+ }
+}
+
+// Stop sends a signal to all goroutines started by this staticUpstream to exit
+// and waits for them to finish before returning.
+func (u *HealthCheck) Stop() error {
+ close(u.stop)
+ u.wg.Wait()
+ return nil
+}
+
+// This was moved into a thread so that each host could throw a health
+// check at the same time. The reason for this is that if we are checking
+// 3 hosts, and the first one is gone, and we spend minutes timing out to
+// fail it, we would not have been doing any other health checks in that
+// time. So we now have a per-host lock and a threaded health check.
+//
+// We use the Checking bool to avoid concurrent checks against the same
+// host; if one is taking a long time, the next one will find a check in
+// progress and simply return before trying.
+//
+// We are carefully avoiding having the mutex locked while we check,
+// otherwise checks will back up, potentially a lot of them if a host is
+// absent for a long time. This arrangement makes checks quickly see if
+// they are the only one running and abort otherwise.
+func healthCheckURL(nextTs time.Time, host *UpstreamHost) {
+
+ // lock for our bool check. We don't just defer the unlock because
+ // we don't want the lock held while http.Get runs
+ host.CheckMu.Lock()
+
+ // are we mid check? Don't run another one
+ if host.Checking {
+ host.CheckMu.Unlock()
+ return
+ }
+
+ host.Checking = true
+ host.CheckMu.Unlock()
+
+ //log.Printf("[DEBUG] Healthchecking %s, nextTs is %s\n", url, nextTs.Local())
+
+ // fetch that url. This has been moved into a go func because
+ // when the remote host is not merely not serving, but actually
+ // absent, then tcp syn timeouts can be very long, and so one
+ // fetch could last several check intervals
+ if r, err := http.Get(host.CheckURL); err == nil {
+ io.Copy(ioutil.Discard, r.Body)
+ r.Body.Close()
+
+ if r.StatusCode < 200 || r.StatusCode >= 400 {
+ log.Printf("[WARNING] Host %s health check returned HTTP code %d\n",
+ host.Name, r.StatusCode)
+ nextTs = time.Unix(0, 0)
+ }
+ } else {
+ log.Printf("[WARNING] Host %s health check probe failed: %v\n", host.Name, err)
+ nextTs = time.Unix(0, 0)
+ }
+
+ host.CheckMu.Lock()
+ host.Checking = false
+ host.OkUntil = nextTs
+ host.CheckMu.Unlock()
+}
+
+func (u *HealthCheck) healthCheck() {
+ for _, host := range u.Hosts {
+
+ if host.CheckURL == "" {
+ var hostName, checkPort string
+
+ // The DNS server might be an HTTP server. If so, extract its name.
+ ret, err := url.Parse(host.Name)
+ if err == nil && len(ret.Host) > 0 {
+ hostName = ret.Host
+ } else {
+ hostName = host.Name
+ }
+
+ // Extract the port number from the parsed server name.
+ checkHostName, checkPort, err := net.SplitHostPort(hostName)
+ if err != nil {
+ checkHostName = hostName
+ }
+
+ if u.Port != "" {
+ checkPort = u.Port
+ }
+
+ host.CheckURL = "http://" + net.JoinHostPort(checkHostName, checkPort) + u.Path
+ }
+
+ // calculate this before the get
+ nextTs := time.Now().Add(u.Future)
+
+ // locks/bools should prevent requests backing up
+ go healthCheckURL(nextTs, host)
+ }
+}
+
+func (u *HealthCheck) healthCheckWorker(stop chan struct{}) {
+ ticker := time.NewTicker(u.Interval)
+ u.healthCheck()
+ for {
+ select {
+ case <-ticker.C:
+ u.healthCheck()
+ case <-stop:
+ ticker.Stop()
+ return
+ }
+ }
+}
+
+// Select selects an upstream host based on the policy
+// and the healthcheck result.
+func (u *HealthCheck) Select() *UpstreamHost {
+ pool := u.Hosts
+ if len(pool) == 1 {
+ if pool[0].Down() && u.Spray == nil {
+ return nil
+ }
+ return pool[0]
+ }
+ allDown := true
+ for _, host := range pool {
+ if !host.Down() {
+ allDown = false
+ break
+ }
+ }
+ if allDown {
+ if u.Spray == nil {
+ return nil
+ }
+ return u.Spray.Select(pool)
+ }
+
+ if u.Policy == nil {
+ h := (&Random{}).Select(pool)
+ if h != nil {
+ return h
+ }
+ if h == nil && u.Spray == nil {
+ return nil
+ }
+ return u.Spray.Select(pool)
+ }
+
+ h := u.Policy.Select(pool)
+ if h != nil {
+ return h
+ }
+
+ if u.Spray == nil {
+ return nil
+ }
+ return u.Spray.Select(pool)
+}
diff --git a/plugin/pkg/healthcheck/policy.go b/plugin/pkg/healthcheck/policy.go
new file mode 100644
index 000000000..6a828fc4d
--- /dev/null
+++ b/plugin/pkg/healthcheck/policy.go
@@ -0,0 +1,120 @@
+package healthcheck
+
+import (
+ "log"
+ "math/rand"
+ "sync/atomic"
+)
+
+var (
+ // SupportedPolicies is the collection of policies registered
+ SupportedPolicies = make(map[string]func() Policy)
+)
+
+// RegisterPolicy adds a custom policy to the proxy.
+func RegisterPolicy(name string, policy func() Policy) {
+ SupportedPolicies[name] = policy
+}
+
+// Policy decides how a host will be selected from a pool. When all hosts are unhealthy, it is assumed the
+// healthchecking failed. In this case each policy will *randomly* return a host from the pool to prevent
+// no traffic to go through at all.
+type Policy interface {
+ Select(pool HostPool) *UpstreamHost
+}
+
+func init() {
+ RegisterPolicy("random", func() Policy { return &Random{} })
+ RegisterPolicy("least_conn", func() Policy { return &LeastConn{} })
+ RegisterPolicy("round_robin", func() Policy { return &RoundRobin{} })
+}
+
+// Random is a policy that selects up hosts from a pool at random.
+type Random struct{}
+
+// Select selects an up host at random from the specified pool.
+func (r *Random) Select(pool HostPool) *UpstreamHost {
+ // instead of just generating a random index
+ // this is done to prevent selecting a down host
+ var randHost *UpstreamHost
+ count := 0
+ for _, host := range pool {
+ if host.Down() {
+ continue
+ }
+ count++
+ if count == 1 {
+ randHost = host
+ } else {
+ r := rand.Int() % count
+ if r == (count - 1) {
+ randHost = host
+ }
+ }
+ }
+ return randHost
+}
+
+// Spray is a policy that selects a host from a pool at random. This should be used as a last ditch
+// attempt to get a host when all hosts are reporting unhealthy.
+type Spray struct{}
+
+// Select selects an up host at random from the specified pool.
+func (r *Spray) Select(pool HostPool) *UpstreamHost {
+ rnd := rand.Int() % len(pool)
+ randHost := pool[rnd]
+ log.Printf("[WARNING] All hosts reported as down, spraying to target: %s", randHost.Name)
+ return randHost
+}
+
+// LeastConn is a policy that selects the host with the least connections.
+type LeastConn struct{}
+
+// Select selects the up host with the least number of connections in the
+// pool. If more than one host has the same least number of connections,
+// one of the hosts is chosen at random.
+func (r *LeastConn) Select(pool HostPool) *UpstreamHost {
+ var bestHost *UpstreamHost
+ count := 0
+ leastConn := int64(1<<63 - 1)
+ for _, host := range pool {
+ if host.Down() {
+ continue
+ }
+ hostConns := host.Conns
+ if hostConns < leastConn {
+ bestHost = host
+ leastConn = hostConns
+ count = 1
+ } else if hostConns == leastConn {
+ // randomly select host among hosts with least connections
+ count++
+ if count == 1 {
+ bestHost = host
+ } else {
+ r := rand.Int() % count
+ if r == (count - 1) {
+ bestHost = host
+ }
+ }
+ }
+ }
+ return bestHost
+}
+
+// RoundRobin is a policy that selects hosts based on round robin ordering.
+type RoundRobin struct {
+ Robin uint32
+}
+
+// Select selects an up host from the pool using a round robin ordering scheme.
+func (r *RoundRobin) Select(pool HostPool) *UpstreamHost {
+ poolLen := uint32(len(pool))
+ selection := atomic.AddUint32(&r.Robin, 1) % poolLen
+ host := pool[selection]
+ // if the currently selected host is down, just ffwd to up host
+ for i := uint32(1); host.Down() && i < poolLen; i++ {
+ host = pool[(selection+i)%poolLen]
+ }
+ return host
+}
diff --git a/plugin/pkg/healthcheck/policy_test.go b/plugin/pkg/healthcheck/policy_test.go
new file mode 100644
index 000000000..4c667952c
--- /dev/null
+++ b/plugin/pkg/healthcheck/policy_test.go
@@ -0,0 +1,143 @@
+package healthcheck
+
+import (
+ "io/ioutil"
+ "log"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "testing"
+ "time"
+)
+
+var workableServer *httptest.Server
+
+func TestMain(m *testing.M) {
+ workableServer = httptest.NewServer(http.HandlerFunc(
+ func(w http.ResponseWriter, r *http.Request) {
+ // do nothing
+ }))
+ r := m.Run()
+ workableServer.Close()
+ os.Exit(r)
+}
+
+type customPolicy struct{}
+
+func (r *customPolicy) Select(pool HostPool) *UpstreamHost {
+ return pool[0]
+}
+
+func testPool() HostPool {
+ pool := []*UpstreamHost{
+ {
+ Name: workableServer.URL, // this should resolve (healthcheck test)
+ },
+ {
+ Name: "http://shouldnot.resolve", // this shouldn't
+ },
+ {
+ Name: "http://C",
+ },
+ }
+ return HostPool(pool)
+}
+
+func TestRegisterPolicy(t *testing.T) {
+ name := "custom"
+ customPolicy := &customPolicy{}
+ RegisterPolicy(name, func() Policy { return customPolicy })
+ if _, ok := SupportedPolicies[name]; !ok {
+ t.Error("Expected supportedPolicies to have a custom policy.")
+ }
+
+}
+
+// TODO(miek): Disabled for now, we should get out of the habit of using
+// realtime in these tests .
+func testHealthCheck(t *testing.T) {
+ log.SetOutput(ioutil.Discard)
+
+ u := &HealthCheck{
+ Hosts: testPool(),
+ FailTimeout: 10 * time.Second,
+ Future: 60 * time.Second,
+ MaxFails: 1,
+ }
+
+ u.healthCheck()
+ // sleep a bit, it's async now
+ time.Sleep(time.Duration(2 * time.Second))
+
+ if u.Hosts[0].Down() {
+ t.Error("Expected first host in testpool to not fail healthcheck.")
+ }
+ if !u.Hosts[1].Down() {
+ t.Error("Expected second host in testpool to fail healthcheck.")
+ }
+}
+
+func TestSelect(t *testing.T) {
+ u := &HealthCheck{
+ Hosts: testPool()[:3],
+ FailTimeout: 10 * time.Second,
+ Future: 60 * time.Second,
+ MaxFails: 1,
+ }
+ u.Hosts[0].OkUntil = time.Unix(0, 0)
+ u.Hosts[1].OkUntil = time.Unix(0, 0)
+ u.Hosts[2].OkUntil = time.Unix(0, 0)
+ if h := u.Select(); h != nil {
+ t.Error("Expected select to return nil as all host are down")
+ }
+ u.Hosts[2].OkUntil = time.Time{}
+ if h := u.Select(); h == nil {
+ t.Error("Expected select to not return nil")
+ }
+}
+
+func TestRoundRobinPolicy(t *testing.T) {
+ pool := testPool()
+ rrPolicy := &RoundRobin{}
+ h := rrPolicy.Select(pool)
+ // First selected host is 1, because counter starts at 0
+ // and increments before host is selected
+ if h != pool[1] {
+ t.Error("Expected first round robin host to be second host in the pool.")
+ }
+ h = rrPolicy.Select(pool)
+ if h != pool[2] {
+ t.Error("Expected second round robin host to be third host in the pool.")
+ }
+ // mark host as down
+ pool[0].OkUntil = time.Unix(0, 0)
+ h = rrPolicy.Select(pool)
+ if h != pool[1] {
+ t.Error("Expected third round robin host to be first host in the pool.")
+ }
+}
+
+func TestLeastConnPolicy(t *testing.T) {
+ pool := testPool()
+ lcPolicy := &LeastConn{}
+ pool[0].Conns = 10
+ pool[1].Conns = 10
+ h := lcPolicy.Select(pool)
+ if h != pool[2] {
+ t.Error("Expected least connection host to be third host.")
+ }
+ pool[2].Conns = 100
+ h = lcPolicy.Select(pool)
+ if h != pool[0] && h != pool[1] {
+ t.Error("Expected least connection host to be first or second host.")
+ }
+}
+
+func TestCustomPolicy(t *testing.T) {
+ pool := testPool()
+ customPolicy := &customPolicy{}
+ h := customPolicy.Select(pool)
+ if h != pool[0] {
+ t.Error("Expected custom policy host to be the first host.")
+ }
+}
diff --git a/plugin/pkg/nonwriter/nonwriter.go b/plugin/pkg/nonwriter/nonwriter.go
new file mode 100644
index 000000000..7819a320f
--- /dev/null
+++ b/plugin/pkg/nonwriter/nonwriter.go
@@ -0,0 +1,23 @@
+// Package nonwriter implements a dns.ResponseWriter that never writes, but captures the dns.Msg being written.
+package nonwriter
+
+import (
+ "github.com/miekg/dns"
+)
+
+// Writer is a type of ResponseWriter that captures the message, but never writes to the client.
+type Writer struct {
+ dns.ResponseWriter
+ Msg *dns.Msg
+}
+
+// New makes and returns a new NonWriter.
+func New(w dns.ResponseWriter) *Writer { return &Writer{ResponseWriter: w} }
+
+// WriteMsg records the message, but doesn't write it itself.
+func (w *Writer) WriteMsg(res *dns.Msg) error {
+ w.Msg = res
+ return nil
+}
+
+func (w *Writer) Write(buf []byte) (int, error) { return len(buf), nil }
diff --git a/plugin/pkg/nonwriter/nonwriter_test.go b/plugin/pkg/nonwriter/nonwriter_test.go
new file mode 100644
index 000000000..d8433af55
--- /dev/null
+++ b/plugin/pkg/nonwriter/nonwriter_test.go
@@ -0,0 +1,19 @@
+package nonwriter
+
+import (
+ "testing"
+
+ "github.com/miekg/dns"
+)
+
+func TestNonWriter(t *testing.T) {
+ nw := New(nil)
+ m := new(dns.Msg)
+ m.SetQuestion("example.org.", dns.TypeA)
+ if err := nw.WriteMsg(m); err != nil {
+ t.Errorf("Got error when writing to nonwriter: %s", err)
+ }
+ if x := nw.Msg.Question[0].Name; x != "example.org." {
+ t.Errorf("Expacted 'example.org.' got %q:", x)
+ }
+}
diff --git a/plugin/pkg/rcode/rcode.go b/plugin/pkg/rcode/rcode.go
new file mode 100644
index 000000000..32863f0b2
--- /dev/null
+++ b/plugin/pkg/rcode/rcode.go
@@ -0,0 +1,16 @@
+package rcode
+
+import (
+ "strconv"
+
+ "github.com/miekg/dns"
+)
+
+// ToString convert the rcode to the official DNS string, or to "RCODE"+value if the RCODE
+// value is unknown.
+func ToString(rcode int) string {
+ if str, ok := dns.RcodeToString[rcode]; ok {
+ return str
+ }
+ return "RCODE" + strconv.Itoa(rcode)
+}
diff --git a/plugin/pkg/rcode/rcode_test.go b/plugin/pkg/rcode/rcode_test.go
new file mode 100644
index 000000000..bfca32f1d
--- /dev/null
+++ b/plugin/pkg/rcode/rcode_test.go
@@ -0,0 +1,29 @@
+package rcode
+
+import (
+ "testing"
+
+ "github.com/miekg/dns"
+)
+
+func TestToString(t *testing.T) {
+ tests := []struct {
+ in int
+ expected string
+ }{
+ {
+ dns.RcodeSuccess,
+ "NOERROR",
+ },
+ {
+ 28,
+ "RCODE28",
+ },
+ }
+ for i, test := range tests {
+ got := ToString(test.in)
+ if got != test.expected {
+ t.Errorf("Test %d, expected %s, got %s", i, test.expected, got)
+ }
+ }
+}
diff --git a/plugin/pkg/replacer/replacer.go b/plugin/pkg/replacer/replacer.go
new file mode 100644
index 000000000..fc98e5d29
--- /dev/null
+++ b/plugin/pkg/replacer/replacer.go
@@ -0,0 +1,161 @@
+package replacer
+
+import (
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/coredns/coredns/plugin/pkg/dnsrecorder"
+ "github.com/coredns/coredns/request"
+
+ "github.com/miekg/dns"
+)
+
+// Replacer is a type which can replace placeholder
+// substrings in a string with actual values from a
+// dns.Msg and responseRecorder. Always use
+// NewReplacer to get one of these.
+type Replacer interface {
+ Replace(string) string
+ Set(key, value string)
+}
+
+type replacer struct {
+ replacements map[string]string
+ emptyValue string
+}
+
+// New makes a new replacer based on r and rr.
+// Do not create a new replacer until r and rr have all
+// the needed values, because this function copies those
+// values into the replacer. rr may be nil if it is not
+// available. emptyValue should be the string that is used
+// in place of empty string (can still be empty string).
+func New(r *dns.Msg, rr *dnsrecorder.Recorder, emptyValue string) Replacer {
+ req := request.Request{W: rr, Req: r}
+ rep := replacer{
+ replacements: map[string]string{
+ "{type}": req.Type(),
+ "{name}": req.Name(),
+ "{class}": req.Class(),
+ "{proto}": req.Proto(),
+ "{when}": func() string {
+ return time.Now().Format(timeFormat)
+ }(),
+ "{size}": strconv.Itoa(req.Len()),
+ "{remote}": req.IP(),
+ "{port}": req.Port(),
+ },
+ emptyValue: emptyValue,
+ }
+ if rr != nil {
+ rcode := dns.RcodeToString[rr.Rcode]
+ if rcode == "" {
+ rcode = strconv.Itoa(rr.Rcode)
+ }
+ rep.replacements["{rcode}"] = rcode
+ rep.replacements["{rsize}"] = strconv.Itoa(rr.Len)
+ rep.replacements["{duration}"] = time.Since(rr.Start).String()
+ if rr.Msg != nil {
+ rep.replacements[headerReplacer+"rflags}"] = flagsToString(rr.Msg.MsgHdr)
+ }
+ }
+
+ // Header placeholders (case-insensitive)
+ rep.replacements[headerReplacer+"id}"] = strconv.Itoa(int(r.Id))
+ rep.replacements[headerReplacer+"opcode}"] = strconv.Itoa(r.Opcode)
+ rep.replacements[headerReplacer+"do}"] = boolToString(req.Do())
+ rep.replacements[headerReplacer+"bufsize}"] = strconv.Itoa(req.Size())
+
+ return rep
+}
+
+// Replace performs a replacement of values on s and returns
+// the string with the replaced values.
+func (r replacer) Replace(s string) string {
+ // Header replacements - these are case-insensitive, so we can't just use strings.Replace()
+ for strings.Contains(s, headerReplacer) {
+ idxStart := strings.Index(s, headerReplacer)
+ endOffset := idxStart + len(headerReplacer)
+ idxEnd := strings.Index(s[endOffset:], "}")
+ if idxEnd > -1 {
+ placeholder := strings.ToLower(s[idxStart : endOffset+idxEnd+1])
+ replacement := r.replacements[placeholder]
+ if replacement == "" {
+ replacement = r.emptyValue
+ }
+ s = s[:idxStart] + replacement + s[endOffset+idxEnd+1:]
+ } else {
+ break
+ }
+ }
+
+ // Regular replacements - these are easier because they're case-sensitive
+ for placeholder, replacement := range r.replacements {
+ if replacement == "" {
+ replacement = r.emptyValue
+ }
+ s = strings.Replace(s, placeholder, replacement, -1)
+ }
+
+ return s
+}
+
+// Set sets key to value in the replacements map.
+func (r replacer) Set(key, value string) {
+ r.replacements["{"+key+"}"] = value
+}
+
+func boolToString(b bool) string {
+ if b {
+ return "true"
+ }
+ return "false"
+}
+
+// flagsToString checks all header flags and returns those
+// that are set as a string separated with commas
+func flagsToString(h dns.MsgHdr) string {
+ flags := make([]string, 7)
+ i := 0
+
+ if h.Response {
+ flags[i] = "qr"
+ i++
+ }
+
+ if h.Authoritative {
+ flags[i] = "aa"
+ i++
+ }
+ if h.Truncated {
+ flags[i] = "tc"
+ i++
+ }
+ if h.RecursionDesired {
+ flags[i] = "rd"
+ i++
+ }
+ if h.RecursionAvailable {
+ flags[i] = "ra"
+ i++
+ }
+ if h.Zero {
+ flags[i] = "z"
+ i++
+ }
+ if h.AuthenticatedData {
+ flags[i] = "ad"
+ i++
+ }
+ if h.CheckingDisabled {
+ flags[i] = "cd"
+ i++
+ }
+ return strings.Join(flags[:i], ",")
+}
+
+const (
+ timeFormat = "02/Jan/2006:15:04:05 -0700"
+ headerReplacer = "{>"
+)
diff --git a/plugin/pkg/replacer/replacer_test.go b/plugin/pkg/replacer/replacer_test.go
new file mode 100644
index 000000000..95c3bbd52
--- /dev/null
+++ b/plugin/pkg/replacer/replacer_test.go
@@ -0,0 +1,61 @@
+package replacer
+
+import (
+ "testing"
+
+ "github.com/coredns/coredns/plugin/pkg/dnsrecorder"
+ "github.com/coredns/coredns/plugin/test"
+
+ "github.com/miekg/dns"
+)
+
+func TestNewReplacer(t *testing.T) {
+ w := dnsrecorder.New(&test.ResponseWriter{})
+
+ r := new(dns.Msg)
+ r.SetQuestion("example.org.", dns.TypeHINFO)
+ r.MsgHdr.AuthenticatedData = true
+
+ replaceValues := New(r, w, "")
+
+ switch v := replaceValues.(type) {
+ case replacer:
+
+ if v.replacements["{type}"] != "HINFO" {
+ t.Errorf("Expected type to be HINFO, got %q", v.replacements["{type}"])
+ }
+ if v.replacements["{name}"] != "example.org." {
+ t.Errorf("Expected request name to be example.org., got %q", v.replacements["{name}"])
+ }
+ if v.replacements["{size}"] != "29" { // size of request
+ t.Errorf("Expected size to be 29, got %q", v.replacements["{size}"])
+ }
+
+ default:
+ t.Fatal("Return Value from New Replacer expected pass type assertion into a replacer type\n")
+ }
+}
+
+func TestSet(t *testing.T) {
+ w := dnsrecorder.New(&test.ResponseWriter{})
+
+ r := new(dns.Msg)
+ r.SetQuestion("example.org.", dns.TypeHINFO)
+ r.MsgHdr.AuthenticatedData = true
+
+ repl := New(r, w, "")
+
+ repl.Set("name", "coredns.io.")
+ repl.Set("type", "A")
+ repl.Set("size", "20")
+
+ if repl.Replace("This name is {name}") != "This name is coredns.io." {
+ t.Error("Expected name replacement failed")
+ }
+ if repl.Replace("This type is {type}") != "This type is A" {
+ t.Error("Expected type replacement failed")
+ }
+ if repl.Replace("The request size is {size}") != "The request size is 20" {
+ t.Error("Expected size replacement failed")
+ }
+}
diff --git a/plugin/pkg/response/classify.go b/plugin/pkg/response/classify.go
new file mode 100644
index 000000000..2e705cb0b
--- /dev/null
+++ b/plugin/pkg/response/classify.go
@@ -0,0 +1,61 @@
+package response
+
+import "fmt"
+
+// Class holds sets of Types
+type Class int
+
+const (
+ // All is a meta class encompassing all the classes.
+ All Class = iota
+ // Success is a class for a successful response.
+ Success
+ // Denial is a class for denying existence (NXDOMAIN, or a nodata: type does not exist)
+ Denial
+ // Error is a class for errors, right now defined as not Success and not Denial
+ Error
+)
+
+func (c Class) String() string {
+ switch c {
+ case All:
+ return "all"
+ case Success:
+ return "success"
+ case Denial:
+ return "denial"
+ case Error:
+ return "error"
+ }
+ return ""
+}
+
+// ClassFromString returns the class from the string s. If not class matches
+// the All class and an error are returned
+func ClassFromString(s string) (Class, error) {
+ switch s {
+ case "all":
+ return All, nil
+ case "success":
+ return Success, nil
+ case "denial":
+ return Denial, nil
+ case "error":
+ return Error, nil
+ }
+ return All, fmt.Errorf("invalid Class: %s", s)
+}
+
+// Classify classifies the Type t, it returns its Class.
+func Classify(t Type) Class {
+ switch t {
+ case NoError, Delegation:
+ return Success
+ case NameError, NoData:
+ return Denial
+ case OtherError:
+ fallthrough
+ default:
+ return Error
+ }
+}
diff --git a/plugin/pkg/response/typify.go b/plugin/pkg/response/typify.go
new file mode 100644
index 000000000..7cfaab497
--- /dev/null
+++ b/plugin/pkg/response/typify.go
@@ -0,0 +1,146 @@
+package response
+
+import (
+ "fmt"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+// Type is the type of the message.
+type Type int
+
+const (
+ // NoError indicates a positive reply
+ NoError Type = iota
+ // NameError is a NXDOMAIN in header, SOA in auth.
+ NameError
+ // NoData indicates name found, but not the type: NOERROR in header, SOA in auth.
+ NoData
+ // Delegation is a msg with a pointer to another nameserver: NOERROR in header, NS in auth, optionally fluff in additional (not checked).
+ Delegation
+ // Meta indicates a meta message, NOTIFY, or a transfer: qType is IXFR or AXFR.
+ Meta
+ // Update is an dynamic update message.
+ Update
+ // OtherError indicates any other error: don't cache these.
+ OtherError
+)
+
+var toString = map[Type]string{
+ NoError: "NOERROR",
+ NameError: "NXDOMAIN",
+ NoData: "NODATA",
+ Delegation: "DELEGATION",
+ Meta: "META",
+ Update: "UPDATE",
+ OtherError: "OTHERERROR",
+}
+
+func (t Type) String() string { return toString[t] }
+
+// TypeFromString returns the type from the string s. If not type matches
+// the OtherError type and an error are returned.
+func TypeFromString(s string) (Type, error) {
+ for t, str := range toString {
+ if s == str {
+ return t, nil
+ }
+ }
+ return NoError, fmt.Errorf("invalid Type: %s", s)
+}
+
+// Typify classifies a message, it returns the Type.
+func Typify(m *dns.Msg, t time.Time) (Type, *dns.OPT) {
+ if m == nil {
+ return OtherError, nil
+ }
+ opt := m.IsEdns0()
+ do := false
+ if opt != nil {
+ do = opt.Do()
+ }
+
+ if m.Opcode == dns.OpcodeUpdate {
+ return Update, opt
+ }
+
+ // Check transfer and update first
+ if m.Opcode == dns.OpcodeNotify {
+ return Meta, opt
+ }
+
+ if len(m.Question) > 0 {
+ if m.Question[0].Qtype == dns.TypeAXFR || m.Question[0].Qtype == dns.TypeIXFR {
+ return Meta, opt
+ }
+ }
+
+ // If our message contains any expired sigs and we care about that, we should return expired
+ if do {
+ if expired := typifyExpired(m, t); expired {
+ return OtherError, opt
+ }
+ }
+
+ if len(m.Answer) > 0 && m.Rcode == dns.RcodeSuccess {
+ return NoError, opt
+ }
+
+ soa := false
+ ns := 0
+ for _, r := range m.Ns {
+ if r.Header().Rrtype == dns.TypeSOA {
+ soa = true
+ continue
+ }
+ if r.Header().Rrtype == dns.TypeNS {
+ ns++
+ }
+ }
+
+ // Check length of different sections, and drop stuff that is just to large? TODO(miek).
+
+ if soa && m.Rcode == dns.RcodeSuccess {
+ return NoData, opt
+ }
+ if soa && m.Rcode == dns.RcodeNameError {
+ return NameError, opt
+ }
+
+ if ns > 0 && m.Rcode == dns.RcodeSuccess {
+ return Delegation, opt
+ }
+
+ if m.Rcode == dns.RcodeSuccess {
+ return NoError, opt
+ }
+
+ return OtherError, opt
+}
+
+func typifyExpired(m *dns.Msg, t time.Time) bool {
+ if expired := typifyExpiredRRSIG(m.Answer, t); expired {
+ return true
+ }
+ if expired := typifyExpiredRRSIG(m.Ns, t); expired {
+ return true
+ }
+ if expired := typifyExpiredRRSIG(m.Extra, t); expired {
+ return true
+ }
+ return false
+}
+
+func typifyExpiredRRSIG(rrs []dns.RR, t time.Time) bool {
+ for _, r := range rrs {
+ if r.Header().Rrtype != dns.TypeRRSIG {
+ continue
+ }
+ ok := r.(*dns.RRSIG).ValidityPeriod(t)
+ if !ok {
+ return true
+ }
+ }
+ return false
+}
diff --git a/plugin/pkg/response/typify_test.go b/plugin/pkg/response/typify_test.go
new file mode 100644
index 000000000..faeaf3579
--- /dev/null
+++ b/plugin/pkg/response/typify_test.go
@@ -0,0 +1,84 @@
+package response
+
+import (
+ "testing"
+ "time"
+
+ "github.com/coredns/coredns/plugin/test"
+
+ "github.com/miekg/dns"
+)
+
+func TestTypifyNilMsg(t *testing.T) {
+ var m *dns.Msg
+
+ ty, _ := Typify(m, time.Now().UTC())
+ if ty != OtherError {
+ t.Errorf("message wrongly typified, expected OtherError, got %s", ty)
+ }
+}
+
+func TestTypifyDelegation(t *testing.T) {
+ m := delegationMsg()
+ mt, _ := Typify(m, time.Now().UTC())
+ if mt != Delegation {
+ t.Errorf("message is wrongly typified, expected Delegation, got %s", mt)
+ }
+}
+
+func TestTypifyRRSIG(t *testing.T) {
+ now, _ := time.Parse(time.UnixDate, "Fri Apr 21 10:51:21 BST 2017")
+ utc := now.UTC()
+
+ m := delegationMsgRRSIGOK()
+ if mt, _ := Typify(m, utc); mt != Delegation {
+ t.Errorf("message is wrongly typified, expected Delegation, got %s", mt)
+ }
+
+ // Still a Delegation because EDNS0 OPT DO bool is not set, so we won't check the sigs.
+ m = delegationMsgRRSIGFail()
+ if mt, _ := Typify(m, utc); mt != Delegation {
+ t.Errorf("message is wrongly typified, expected Delegation, got %s", mt)
+ }
+
+ m = delegationMsgRRSIGFail()
+ m = addOpt(m)
+ if mt, _ := Typify(m, utc); mt != OtherError {
+ t.Errorf("message is wrongly typified, expected OtherError, got %s", mt)
+ }
+}
+
+func delegationMsg() *dns.Msg {
+ return &dns.Msg{
+ Ns: []dns.RR{
+ test.NS("miek.nl. 3600 IN NS linode.atoom.net."),
+ test.NS("miek.nl. 3600 IN NS ns-ext.nlnetlabs.nl."),
+ test.NS("miek.nl. 3600 IN NS omval.tednet.nl."),
+ },
+ Extra: []dns.RR{
+ test.A("omval.tednet.nl. 3600 IN A 185.49.141.42"),
+ test.AAAA("omval.tednet.nl. 3600 IN AAAA 2a04:b900:0:100::42"),
+ },
+ }
+}
+
+func delegationMsgRRSIGOK() *dns.Msg {
+ del := delegationMsg()
+ del.Ns = append(del.Ns,
+ test.RRSIG("miek.nl. 1800 IN RRSIG NS 8 2 1800 20170521031301 20170421031301 12051 miek.nl. PIUu3TKX/sB/N1n1E1yWxHHIcPnc2q6Wq9InShk+5ptRqChqKdZNMLDm gCq+1bQAZ7jGvn2PbwTwE65JzES7T+hEiqR5PU23DsidvZyClbZ9l0xG JtKwgzGXLtUHxp4xv/Plq+rq/7pOG61bNCxRyS7WS7i7QcCCWT1BCcv+ wZ0="),
+ )
+ return del
+}
+
+func delegationMsgRRSIGFail() *dns.Msg {
+ del := delegationMsg()
+ del.Ns = append(del.Ns,
+ test.RRSIG("miek.nl. 1800 IN RRSIG NS 8 2 1800 20160521031301 20160421031301 12051 miek.nl. PIUu3TKX/sB/N1n1E1yWxHHIcPnc2q6Wq9InShk+5ptRqChqKdZNMLDm gCq+1bQAZ7jGvn2PbwTwE65JzES7T+hEiqR5PU23DsidvZyClbZ9l0xG JtKwgzGXLtUHxp4xv/Plq+rq/7pOG61bNCxRyS7WS7i7QcCCWT1BCcv+ wZ0="),
+ )
+ return del
+}
+
+func addOpt(m *dns.Msg) *dns.Msg {
+ m.Extra = append(m.Extra, test.OPT(4096, true))
+ return m
+}
diff --git a/plugin/pkg/singleflight/singleflight.go b/plugin/pkg/singleflight/singleflight.go
new file mode 100644
index 000000000..365e3ef58
--- /dev/null
+++ b/plugin/pkg/singleflight/singleflight.go
@@ -0,0 +1,64 @@
+/*
+Copyright 2012 Google Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package singleflight provides a duplicate function call suppression
+// mechanism.
+package singleflight
+
+import "sync"
+
+// call is an in-flight or completed Do call
+type call struct {
+ wg sync.WaitGroup
+ val interface{}
+ err error
+}
+
+// Group represents a class of work and forms a namespace in which
+// units of work can be executed with duplicate suppression.
+type Group struct {
+ mu sync.Mutex // protects m
+ m map[uint32]*call // lazily initialized
+}
+
+// Do executes and returns the results of the given function, making
+// sure that only one execution is in-flight for a given key at a
+// time. If a duplicate comes in, the duplicate caller waits for the
+// original to complete and receives the same results.
+func (g *Group) Do(key uint32, fn func() (interface{}, error)) (interface{}, error) {
+ g.mu.Lock()
+ if g.m == nil {
+ g.m = make(map[uint32]*call)
+ }
+ if c, ok := g.m[key]; ok {
+ g.mu.Unlock()
+ c.wg.Wait()
+ return c.val, c.err
+ }
+ c := new(call)
+ c.wg.Add(1)
+ g.m[key] = c
+ g.mu.Unlock()
+
+ c.val, c.err = fn()
+ c.wg.Done()
+
+ g.mu.Lock()
+ delete(g.m, key)
+ g.mu.Unlock()
+
+ return c.val, c.err
+}
diff --git a/plugin/pkg/singleflight/singleflight_test.go b/plugin/pkg/singleflight/singleflight_test.go
new file mode 100644
index 000000000..d1d406e0b
--- /dev/null
+++ b/plugin/pkg/singleflight/singleflight_test.go
@@ -0,0 +1,85 @@
+/*
+Copyright 2012 Google Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package singleflight
+
+import (
+ "errors"
+ "fmt"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+func TestDo(t *testing.T) {
+ var g Group
+ v, err := g.Do(1, func() (interface{}, error) {
+ return "bar", nil
+ })
+ if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
+ t.Errorf("Do = %v; want %v", got, want)
+ }
+ if err != nil {
+ t.Errorf("Do error = %v", err)
+ }
+}
+
+func TestDoErr(t *testing.T) {
+ var g Group
+ someErr := errors.New("Some error")
+ v, err := g.Do(1, func() (interface{}, error) {
+ return nil, someErr
+ })
+ if err != someErr {
+ t.Errorf("Do error = %v; want someErr", err)
+ }
+ if v != nil {
+ t.Errorf("unexpected non-nil value %#v", v)
+ }
+}
+
+func TestDoDupSuppress(t *testing.T) {
+ var g Group
+ c := make(chan string)
+ var calls int32
+ fn := func() (interface{}, error) {
+ atomic.AddInt32(&calls, 1)
+ return <-c, nil
+ }
+
+ const n = 10
+ var wg sync.WaitGroup
+ for i := 0; i < n; i++ {
+ wg.Add(1)
+ go func() {
+ v, err := g.Do(1, fn)
+ if err != nil {
+ t.Errorf("Do error: %v", err)
+ }
+ if v.(string) != "bar" {
+ t.Errorf("got %q; want %q", v, "bar")
+ }
+ wg.Done()
+ }()
+ }
+ time.Sleep(100 * time.Millisecond) // let goroutines above block
+ c <- "bar"
+ wg.Wait()
+ if got := atomic.LoadInt32(&calls); got != 1 {
+ t.Errorf("number of calls = %d; want 1", got)
+ }
+}
diff --git a/plugin/pkg/tls/tls.go b/plugin/pkg/tls/tls.go
new file mode 100644
index 000000000..6fc10dd8e
--- /dev/null
+++ b/plugin/pkg/tls/tls.go
@@ -0,0 +1,128 @@
+package tls
+
+import (
+ "crypto/tls"
+ "crypto/x509"
+ "fmt"
+ "io/ioutil"
+ "net"
+ "net/http"
+ "time"
+)
+
+// NewTLSConfigFromArgs returns a TLS config based upon the passed
+// in list of arguments. Typically these come straight from the
+// Corefile.
+// no args
+// - creates a Config with no cert and using system CAs
+// - use for a client that talks to a server with a public signed cert (CA installed in system)
+// - the client will not be authenticated by the server since there is no cert
+// one arg: the path to CA PEM file
+// - creates a Config with no cert using a specific CA
+// - use for a client that talks to a server with a private signed cert (CA not installed in system)
+// - the client will not be authenticated by the server since there is no cert
+// two args: path to cert PEM file, the path to private key PEM file
+// - creates a Config with a cert, using system CAs to validate the other end
+// - use for:
+// - a server; or,
+// - a client that talks to a server with a public cert and needs certificate-based authentication
+// - the other end will authenticate this end via the provided cert
+// - the cert of the other end will be verified via system CAs
+// three args: path to cert PEM file, path to client private key PEM file, path to CA PEM file
+// - creates a Config with the cert, using specified CA to validate the other end
+// - use for:
+// - a server; or,
+// - a client that talks to a server with a privately signed cert and needs certificate-based
+// authentication
+// - the other end will authenticate this end via the provided cert
+// - this end will verify the other end's cert using the specified CA
+func NewTLSConfigFromArgs(args ...string) (*tls.Config, error) {
+ var err error
+ var c *tls.Config
+ switch len(args) {
+ case 0:
+ // No client cert, use system CA
+ c, err = NewTLSClientConfig("")
+ case 1:
+ // No client cert, use specified CA
+ c, err = NewTLSClientConfig(args[0])
+ case 2:
+ // Client cert, use system CA
+ c, err = NewTLSConfig(args[0], args[1], "")
+ case 3:
+ // Client cert, use specified CA
+ c, err = NewTLSConfig(args[0], args[1], args[2])
+ default:
+ err = fmt.Errorf("maximum of three arguments allowed for TLS config, found %d", len(args))
+ }
+ if err != nil {
+ return nil, err
+ }
+ return c, nil
+}
+
+// NewTLSConfig returns a TLS config that includes a certificate
+// Use for server TLS config or when using a client certificate
+// If caPath is empty, system CAs will be used
+func NewTLSConfig(certPath, keyPath, caPath string) (*tls.Config, error) {
+ cert, err := tls.LoadX509KeyPair(certPath, keyPath)
+ if err != nil {
+ return nil, fmt.Errorf("could not load TLS cert: %s", err)
+ }
+
+ roots, err := loadRoots(caPath)
+ if err != nil {
+ return nil, err
+ }
+
+ return &tls.Config{Certificates: []tls.Certificate{cert}, RootCAs: roots}, nil
+}
+
+// NewTLSClientConfig returns a TLS config for a client connection
+// If caPath is empty, system CAs will be used
+func NewTLSClientConfig(caPath string) (*tls.Config, error) {
+ roots, err := loadRoots(caPath)
+ if err != nil {
+ return nil, err
+ }
+
+ return &tls.Config{RootCAs: roots}, nil
+}
+
+func loadRoots(caPath string) (*x509.CertPool, error) {
+ if caPath == "" {
+ return nil, nil
+ }
+
+ roots := x509.NewCertPool()
+ pem, err := ioutil.ReadFile(caPath)
+ if err != nil {
+ return nil, fmt.Errorf("error reading %s: %s", caPath, err)
+ }
+ ok := roots.AppendCertsFromPEM(pem)
+ if !ok {
+ return nil, fmt.Errorf("could not read root certs: %s", err)
+ }
+ return roots, nil
+}
+
+// NewHTTPSTransport returns an HTTP transport configured using tls.Config
+func NewHTTPSTransport(cc *tls.Config) *http.Transport {
+ // this seems like a bad idea but was here in the previous version
+ if cc != nil {
+ cc.InsecureSkipVerify = true
+ }
+
+ tr := &http.Transport{
+ Proxy: http.ProxyFromEnvironment,
+ Dial: (&net.Dialer{
+ Timeout: 30 * time.Second,
+ KeepAlive: 30 * time.Second,
+ }).Dial,
+ TLSHandshakeTimeout: 10 * time.Second,
+ TLSClientConfig: cc,
+ MaxIdleConnsPerHost: 25,
+ }
+
+ return tr
+}
diff --git a/plugin/pkg/tls/tls_test.go b/plugin/pkg/tls/tls_test.go
new file mode 100644
index 000000000..8c88bfcc4
--- /dev/null
+++ b/plugin/pkg/tls/tls_test.go
@@ -0,0 +1,101 @@
+package tls
+
+import (
+ "path/filepath"
+ "testing"
+
+ "github.com/coredns/coredns/plugin/test"
+)
+
+func getPEMFiles(t *testing.T) (rmFunc func(), cert, key, ca string) {
+ tempDir, rmFunc, err := test.WritePEMFiles("")
+ if err != nil {
+ t.Fatalf("Could not write PEM files: %s", err)
+ }
+
+ cert = filepath.Join(tempDir, "cert.pem")
+ key = filepath.Join(tempDir, "key.pem")
+ ca = filepath.Join(tempDir, "ca.pem")
+
+ return
+}
+
+func TestNewTLSConfig(t *testing.T) {
+ rmFunc, cert, key, ca := getPEMFiles(t)
+ defer rmFunc()
+
+ _, err := NewTLSConfig(cert, key, ca)
+ if err != nil {
+ t.Errorf("Failed to create TLSConfig: %s", err)
+ }
+}
+
+func TestNewTLSClientConfig(t *testing.T) {
+ rmFunc, _, _, ca := getPEMFiles(t)
+ defer rmFunc()
+
+ _, err := NewTLSClientConfig(ca)
+ if err != nil {
+ t.Errorf("Failed to create TLSConfig: %s", err)
+ }
+}
+
+func TestNewTLSConfigFromArgs(t *testing.T) {
+ rmFunc, cert, key, ca := getPEMFiles(t)
+ defer rmFunc()
+
+ _, err := NewTLSConfigFromArgs()
+ if err != nil {
+ t.Errorf("Failed to create TLSConfig: %s", err)
+ }
+
+ c, err := NewTLSConfigFromArgs(ca)
+ if err != nil {
+ t.Errorf("Failed to create TLSConfig: %s", err)
+ }
+ if c.RootCAs == nil {
+ t.Error("RootCAs should not be nil when one arg passed")
+ }
+
+ c, err = NewTLSConfigFromArgs(cert, key)
+ if err != nil {
+ t.Errorf("Failed to create TLSConfig: %s", err)
+ }
+ if c.RootCAs != nil {
+ t.Error("RootCAs should be nil when two args passed")
+ }
+ if len(c.Certificates) != 1 {
+ t.Error("Certificates should have a single entry when two args passed")
+ }
+ args := []string{cert, key, ca}
+ c, err = NewTLSConfigFromArgs(args...)
+ if err != nil {
+ t.Errorf("Failed to create TLSConfig: %s", err)
+ }
+ if c.RootCAs == nil {
+ t.Error("RootCAs should not be nil when three args passed")
+ }
+ if len(c.Certificates) != 1 {
+ t.Error("Certificateis should have a single entry when three args passed")
+ }
+}
+
+func TestNewHTTPSTransport(t *testing.T) {
+ rmFunc, _, _, ca := getPEMFiles(t)
+ defer rmFunc()
+
+ cc, err := NewTLSClientConfig(ca)
+ if err != nil {
+ t.Errorf("Failed to create TLSConfig: %s", err)
+ }
+
+ tr := NewHTTPSTransport(cc)
+ if tr == nil {
+ t.Errorf("Failed to create https transport with cc")
+ }
+
+ tr = NewHTTPSTransport(nil)
+ if tr == nil {
+ t.Errorf("Failed to create https transport without cc")
+ }
+}
diff --git a/plugin/pkg/trace/trace.go b/plugin/pkg/trace/trace.go
new file mode 100644
index 000000000..35a8ddabd
--- /dev/null
+++ b/plugin/pkg/trace/trace.go
@@ -0,0 +1,12 @@
+package trace
+
+import (
+ "github.com/coredns/coredns/plugin"
+ ot "github.com/opentracing/opentracing-go"
+)
+
+// Trace holds the tracer and endpoint info
+type Trace interface {
+ plugin.Handler
+ Tracer() ot.Tracer
+}