aboutsummaryrefslogtreecommitdiff
path: root/plugin/loop/loop.go
blob: 8d29798adb6e0a458bff2c5eadcdb78e8ca4a202 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
package loop

import (
	"context"
	"sync"

	"github.com/coredns/coredns/plugin"
	clog "github.com/coredns/coredns/plugin/pkg/log"
	"github.com/coredns/coredns/request"

	"github.com/miekg/dns"
)

// log is this package's logger, created with clog.NewWithPlugin for the
// "loop" plugin.
var (
	log = clog.NewWithPlugin("loop")
)

// Loop is a plugin that implements loop detection by sending a "random" query.
//
// The embedded RWMutex guards the mutable fields i, off and addr. In this file
// zone and qname are only assigned in New and are read without locking.
type Loop struct {
	Next plugin.Handler

	// zone is the watched zone, qname the HINFO probe name used by exchange,
	// and addr the address recorded via setAddress.
	zone, qname, addr string

	sync.RWMutex
	i   int  // counter bumped by inc, read by seen, cleared by reset
	off bool // set by setDisabled, read by disabled
}

// New returns a Loop watching zone, with its probe query name derived from
// zone by qname.
func New(zone string) *Loop {
	l := &Loop{zone: zone}
	l.qname = qname(zone)
	return l
}

// ServeDNS implements the plugin.Handler interface.
//
// Only HINFO queries are examined. Any other query, any HINFO query once
// disabled() reports true, and any HINFO query whose name is outside l.zone
// is handed to the next plugin unchanged. A HINFO query for exactly l.qname
// increments the loop counter; when the counter exceeds 2 a loop is reported
// via log.Fatalf (NOTE(review): expected to terminate the process — confirm
// against clog). In every non-fatal case the query continues down the chain.
func (l *Loop) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
	// A message without a question section cannot be our probe. Without this
	// guard r.Question[0] would panic with an index out of range when the
	// handler is invoked on such a message (e.g. called directly, bypassing
	// any server-side validation).
	if len(r.Question) == 0 || r.Question[0].Qtype != dns.TypeHINFO {
		return plugin.NextOrFailure(l.Name(), l.Next, ctx, w, r)
	}
	if l.disabled() {
		return plugin.NextOrFailure(l.Name(), l.Next, ctx, w, r)
	}

	state := request.Request{W: w, Req: r}
	name := state.Name()

	if plugin.Zones([]string{l.zone}).Matches(name) == "" {
		return plugin.NextOrFailure(l.Name(), l.Next, ctx, w, r)
	}

	if name == l.qname {
		l.inc()
	}

	if l.seen() > 2 {
		log.Fatalf(`Loop (%s -> %s) detected for zone %q, see https://coredns.io/plugins/loop#troubleshooting. Query: "HINFO %s"`, state.RemoteAddr(), l.address(), l.zone, l.qname)
	}

	return plugin.NextOrFailure(l.Name(), l.Next, ctx, w, r)
}

// Name implements the plugin.Handler interface; it always returns "loop".
func (l *Loop) Name() string {
	const pluginName = "loop"
	return pluginName
}

// exchange sends a single HINFO query for l.qname to addr using
// dns.Exchange and returns whatever reply and error it produces.
func (l *Loop) exchange(addr string) (*dns.Msg, error) {
	probe := new(dns.Msg).SetQuestion(l.qname, dns.TypeHINFO)
	return dns.Exchange(probe, addr)
}

// seen returns the current value of the loop counter, read under the read lock.
func (l *Loop) seen() int {
	l.RLock()
	count := l.i
	l.RUnlock()
	return count
}

// inc adds one to the loop counter while holding the write lock.
func (l *Loop) inc() {
	l.Lock()
	l.i++
	l.Unlock()
}

// reset sets the loop counter back to zero while holding the write lock.
func (l *Loop) reset() {
	l.Lock()
	l.i = 0
	l.Unlock()
}

// setDisabled marks loop detection as off; there is no corresponding
// re-enable in this file.
func (l *Loop) setDisabled() {
	l.Lock()
	l.off = true
	l.Unlock()
}

// disabled reports whether setDisabled has switched loop detection off.
func (l *Loop) disabled() bool {
	l.RLock()
	isOff := l.off
	l.RUnlock()
	return isOff
}

// setAddress records addr (the value later returned by address) under the
// write lock.
func (l *Loop) setAddress(addr string) {
	l.Lock()
	l.addr = addr
	l.Unlock()
}

// address returns the value most recently stored by setAddress, read under
// the read lock.
func (l *Loop) address() string {
	l.RLock()
	current := l.addr
	l.RUnlock()
	return current
}