aboutsummaryrefslogtreecommitdiff
path: root/plugin/acl/setup.go
blob: 189acc6c494e19151990750a1e04aec808540e6b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
package acl

import (
	"net"
	"strings"

	"github.com/coredns/caddy"
	"github.com/coredns/coredns/core/dnsserver"
	"github.com/coredns/coredns/plugin"

	"github.com/infobloxopen/go-trees/iptree"
	"github.com/miekg/dns"
)

// pluginName is the directive name this plugin registers under; setup also
// passes it to plugin.Error when configuration parsing fails.
const pluginName = "acl"

// init registers setup as the setup function for the acl directive.
func init() {
	plugin.Register(pluginName, setup)
}

// newDefaultFilter builds an iptree.Tree holding the IPv4 catch-all network
// 0.0.0.0/0 and the IPv6 catch-all network ::/0. parse uses it when a policy
// has no `net` section, or uses `net *`, to mean "any source address".
func newDefaultFilter() *iptree.Tree {
	tree := iptree.NewTree()
	for _, cidr := range []string{"0.0.0.0/0", "::/0"} {
		// Both literals are well-formed CIDRs, so the parse error is ignored.
		_, network, _ := net.ParseCIDR(cidr)
		tree.InplaceInsertNet(network, struct{}{})
	}
	return tree
}

// setup is the caddy setup function for the acl directive.
//
// It parses the directive's configuration with parse. On failure it returns
// the error wrapped by plugin.Error under pluginName. On success it adds the
// resulting ACL to the server config's plugin chain, wiring the next handler
// into its Next field.
func setup(c *caddy.Controller) error {
	h, err := parse(c)
	if err != nil {
		return plugin.Error(pluginName, err)
	}

	cfg := dnsserver.GetConfig(c)
	cfg.AddPlugin(func(next plugin.Handler) plugin.Handler {
		h.Next = next
		return h
	})
	return nil
}

// parse reads every acl directive from the controller into an ACL.
//
// Each directive has the shape
//
//	acl [ZONES...] {
//	    ACTION [type QTYPE...] [net SOURCE...]
//	    ...
//	}
//
// Each directive becomes one rule. Its zones are resolved by
// plugin.OriginsFromArgsOrServerBlock from the directive arguments and the
// server block keys.
//
// Each line inside the block becomes one policy of that rule, in order:
//   - ACTION is allow, block, filter or drop (case-insensitive).
//   - The `type` and `net` section keywords are case-insensitive, may appear
//     in any order, and may repeat.
//   - An omitted `type` section, or `type *`, matches every query type.
//   - An omitted `net` section, or `net *`, matches every source address.
//   - Bare IP addresses in `net` are widened to host CIDRs by normalize.
//
// On the first malformed token a non-nil error is returned. The ACL returned
// alongside it is incomplete and must be discarded.
func parse(c *caddy.Controller) (ACL, error) {
	a := ACL{}
	for c.Next() {
		r := rule{}
		args := c.RemainingArgs()
		r.zones = plugin.OriginsFromArgsOrServerBlock(args, c.ServerBlockKeys)

		for c.NextBlock() {
			p := policy{}

			switch strings.ToLower(c.Val()) {
			case "allow":
				p.action = actionAllow
			case "block":
				p.action = actionBlock
			case "filter":
				p.action = actionFilter
			case "drop":
				p.action = actionDrop
			default:
				return a, c.Errf("unexpected token %q; expect 'allow', 'block', 'filter' or 'drop'", c.Val())
			}

			p.qtypes = make(map[uint16]struct{})
			p.filter = iptree.NewTree()

			hasTypeSection := false
			hasNetSection := false

			remainingTokens := c.RemainingArgs()
			for len(remainingTokens) > 0 {
				if !isPreservedIdentifier(remainingTokens[0]) {
					return a, c.Errf("unexpected token %q; expect 'type | net'", remainingTokens[0])
				}
				section := strings.ToLower(remainingTokens[0])

				// A section's values run up to the next section keyword or the
				// end of the line.
				i := 1
				for i < len(remainingTokens) && !isPreservedIdentifier(remainingTokens[i]) {
					i++
				}
				tokens := remainingTokens[1:i]
				remainingTokens = remainingTokens[i:]

				if len(tokens) == 0 {
					return a, c.Errf("no token specified in %q section", section)
				}

				switch section {
				case "type":
					hasTypeSection = true
					for _, token := range tokens {
						if token == "*" {
							// Wildcard: match every query type. Keep going so
							// that any remaining tokens are still validated
							// instead of being silently ignored.
							p.qtypes[dns.TypeNone] = struct{}{}
							continue
						}
						qtype, ok := dns.StringToType[token]
						if !ok {
							return a, c.Errf("unexpected token %q; expect legal QTYPE", token)
						}
						p.qtypes[qtype] = struct{}{}
					}
				case "net":
					hasNetSection = true
					for _, token := range tokens {
						if token == "*" {
							// Wildcard: match every source address. Later
							// tokens are still validated; inserting them into
							// the catch-all tree does not change what matches.
							p.filter = newDefaultFilter()
							continue
						}
						token = normalize(token)
						_, source, err := net.ParseCIDR(token)
						if err != nil {
							return a, c.Errf("illegal CIDR notation %q", token)
						}
						p.filter.InplaceInsertNet(source, struct{}{})
					}
				default:
					// Unreachable while isPreservedIdentifier only admits
					// "type" and "net"; kept as a guard if keywords are added.
					return a, c.Errf("unexpected token %q; expect 'type | net'", section)
				}
			}

			// optional `type` section means all record types.
			if !hasTypeSection {
				p.qtypes[dns.TypeNone] = struct{}{}
			}

			// optional `net` means all ip addresses.
			if !hasNetSection {
				p.filter = newDefaultFilter()
			}

			r.policies = append(r.policies, p)
		}
		a.Rules = append(a.Rules, r)
	}
	return a, nil
}

// isPreservedIdentifier reports whether token is one of the reserved section
// keywords that may follow an action ("type" or "net"), ignoring case.
func isPreservedIdentifier(token string) bool {
	switch strings.ToLower(token) {
	case "type", "net":
		return true
	default:
		return false
	}
}

// normalize converts a bare IP address into CIDR notation with a host-length
// prefix. Input containing a ':' is treated as IPv6 and gets "/128"; anything
// else gets "/32". Input that already contains a '/' is returned unchanged.
// The result is not validated here; callers pass it to net.ParseCIDR.
func normalize(rawNet string) string {
	switch {
	case strings.Contains(rawNet, "/"):
		return rawNet
	case strings.Contains(rawNet, ":"):
		return rawNet + "/128"
	default:
		return rawNet + "/32"
	}
}