aboutsummaryrefslogtreecommitdiff
path: root/plugin/loop/loop.go
diff options
context:
space:
mode:
Diffstat (limited to 'plugin/loop/loop.go')
-rw-r--r--plugin/loop/loop.go90
1 files changed, 90 insertions, 0 deletions
diff --git a/plugin/loop/loop.go b/plugin/loop/loop.go
new file mode 100644
index 000000000..56e039c9c
--- /dev/null
+++ b/plugin/loop/loop.go
@@ -0,0 +1,90 @@
+package loop
+
+import (
+ "context"
+ "sync"
+
+ "github.com/coredns/coredns/plugin"
+ clog "github.com/coredns/coredns/plugin/pkg/log"
+ "github.com/coredns/coredns/request"
+
+ "github.com/miekg/dns"
+)
+
+var log = clog.NewWithPlugin("loop")
+
+// Loop is a plugin that implements loop detection by sending a "random" query.
+type Loop struct {
+ Next plugin.Handler
+
+ zone string
+ qname string
+
+ sync.RWMutex
+ i int
+ off bool
+}
+
+// New returns a new initialized Loop.
+func New(zone string) *Loop { return &Loop{zone: zone, qname: qname(zone)} }
+
+// ServeDNS implements the plugin.Handler interface.
+func (l *Loop) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ if r.Question[0].Qtype != dns.TypeHINFO {
+ return plugin.NextOrFailure(l.Name(), l.Next, ctx, w, r)
+ }
+ if l.disabled() {
+ return plugin.NextOrFailure(l.Name(), l.Next, ctx, w, r)
+ }
+
+ state := request.Request{W: w, Req: r}
+
+ zone := plugin.Zones([]string{l.zone}).Matches(state.Name())
+ if zone == "" {
+ return plugin.NextOrFailure(l.Name(), l.Next, ctx, w, r)
+ }
+
+ if state.Name() == l.qname {
+ l.inc()
+ }
+
+ if l.seen() > 2 {
+ log.Fatalf("Seen \"HINFO IN %s\" more than twice, loop detected", l.qname)
+ }
+
+ return plugin.NextOrFailure(l.Name(), l.Next, ctx, w, r)
+}
+
+// Name implements the plugin.Handler interface.
+func (l *Loop) Name() string { return "loop" }
+
+func (l *Loop) exchange(addr string) (*dns.Msg, error) {
+ m := new(dns.Msg)
+ m.SetQuestion(l.qname, dns.TypeHINFO)
+
+ return dns.Exchange(m, addr)
+}
+
+func (l *Loop) seen() int {
+ l.RLock()
+ defer l.RUnlock()
+ return l.i
+}
+
+func (l *Loop) inc() {
+ l.Lock()
+ defer l.Unlock()
+ l.i++
+}
+
+func (l *Loop) setDisabled() {
+ l.Lock()
+ defer l.Unlock()
+ l.off = true
+}
+
+func (l *Loop) disabled() bool {
+ l.RLock()
+ defer l.RUnlock()
+ return l.off
+}