aboutsummaryrefslogtreecommitdiff
path: root/plugin/hosts
diff options
context:
space:
mode:
Diffstat (limited to 'plugin/hosts')
-rw-r--r--plugin/hosts/README.md45
-rw-r--r--plugin/hosts/hosts.go136
-rw-r--r--plugin/hosts/hosts_test.go75
-rw-r--r--plugin/hosts/hostsfile.go193
-rw-r--r--plugin/hosts/hostsfile_test.go239
-rw-r--r--plugin/hosts/setup.go88
-rw-r--r--plugin/hosts/setup_test.go86
7 files changed, 862 insertions, 0 deletions
diff --git a/plugin/hosts/README.md b/plugin/hosts/README.md
new file mode 100644
index 000000000..60c738077
--- /dev/null
+++ b/plugin/hosts/README.md
@@ -0,0 +1,45 @@
+# hosts
+
+*hosts* enables serving zone data from a `/etc/hosts` style file.
+
+The hosts plugin is useful for serving zones from a /etc/hosts file. It serves from a preloaded
+file that exists on disk. It checks the file for changes and updates the zones accordingly. This
+plugin only supports A, AAAA, and PTR records. The hosts plugin can be used with readily
+available hosts files that block access to advertising servers.
+
+## Syntax
+
+~~~
+hosts [FILE [ZONES...]] {
+ fallthrough
+}
+~~~
+
+* **FILE** the hosts file to read and parse. If the path is relative the path from the *root*
+ directive will be prepended to it. Defaults to /etc/hosts if omitted
+* **ZONES** zones it should be authoritative for. If empty, the zones from the configuration block
+ are used.
+* `fallthrough` If zone matches and no record can be generated, pass request to the next plugin.
+
+## Examples
+
+Load `/etc/hosts` file.
+
+~~~
+hosts
+~~~
+
+Load `example.hosts` file in the current directory.
+
+~~~
+hosts example.hosts
+~~~
+
+Load example.hosts file and only serve example.org and example.net from it and fall through to the
+next plugin if query doesn't match.
+
+~~~
+hosts example.hosts example.org example.net {
+ fallthrough
+}
+~~~
diff --git a/plugin/hosts/hosts.go b/plugin/hosts/hosts.go
new file mode 100644
index 000000000..09dedbb64
--- /dev/null
+++ b/plugin/hosts/hosts.go
@@ -0,0 +1,136 @@
+package hosts
+
+import (
+ "net"
+
+ "golang.org/x/net/context"
+
+ "github.com/coredns/coredns/plugin"
+ "github.com/coredns/coredns/plugin/pkg/dnsutil"
+ "github.com/coredns/coredns/request"
+ "github.com/miekg/dns"
+)
+
+// Hosts is the plugin handler
+type Hosts struct {
+ Next plugin.Handler
+ *Hostsfile
+
+ Fallthrough bool
+}
+
+// ServeDNS implements the plugin.Handle interface.
+func (h Hosts) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ state := request.Request{W: w, Req: r}
+ qname := state.Name()
+
+ answers := []dns.RR{}
+
+ zone := plugin.Zones(h.Origins).Matches(qname)
+ if zone == "" {
+ // PTR zones don't need to be specified in Origins
+ if state.Type() != "PTR" {
+ // If this doesn't match we need to fall through regardless of h.Fallthrough
+ return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r)
+ }
+ }
+
+ switch state.QType() {
+ case dns.TypePTR:
+ names := h.LookupStaticAddr(dnsutil.ExtractAddressFromReverse(qname))
+ if len(names) == 0 {
+ // If this doesn't match we need to fall through regardless of h.Fallthrough
+ return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r)
+ }
+ answers = h.ptr(qname, names)
+ case dns.TypeA:
+ ips := h.LookupStaticHostV4(qname)
+ answers = a(qname, ips)
+ case dns.TypeAAAA:
+ ips := h.LookupStaticHostV6(qname)
+ answers = aaaa(qname, ips)
+ }
+
+ if len(answers) == 0 {
+ if h.Fallthrough {
+ return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r)
+ }
+ if !h.otherRecordsExist(state.QType(), qname) {
+ return dns.RcodeNameError, nil
+ }
+ }
+
+ m := new(dns.Msg)
+ m.SetReply(r)
+ m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true
+ m.Answer = answers
+
+ state.SizeAndDo(m)
+ m, _ = state.Scrub(m)
+ w.WriteMsg(m)
+ return dns.RcodeSuccess, nil
+}
+
+func (h Hosts) otherRecordsExist(qtype uint16, qname string) bool {
+ switch qtype {
+ case dns.TypeA:
+ if len(h.LookupStaticHostV6(qname)) > 0 {
+ return true
+ }
+ case dns.TypeAAAA:
+ if len(h.LookupStaticHostV4(qname)) > 0 {
+ return true
+ }
+ default:
+ if len(h.LookupStaticHostV4(qname)) > 0 {
+ return true
+ }
+ if len(h.LookupStaticHostV6(qname)) > 0 {
+ return true
+ }
+ }
+ return false
+
+}
+
+// Name implements the plugin.Handle interface.
+func (h Hosts) Name() string { return "hosts" }
+
+// a takes a slice of net.IPs and returns a slice of A RRs.
+func a(zone string, ips []net.IP) []dns.RR {
+ answers := []dns.RR{}
+ for _, ip := range ips {
+ r := new(dns.A)
+ r.Hdr = dns.RR_Header{Name: zone, Rrtype: dns.TypeA,
+ Class: dns.ClassINET, Ttl: 3600}
+ r.A = ip
+ answers = append(answers, r)
+ }
+ return answers
+}
+
+// aaaa takes a slice of net.IPs and returns a slice of AAAA RRs.
+func aaaa(zone string, ips []net.IP) []dns.RR {
+ answers := []dns.RR{}
+ for _, ip := range ips {
+ r := new(dns.AAAA)
+ r.Hdr = dns.RR_Header{Name: zone, Rrtype: dns.TypeAAAA,
+ Class: dns.ClassINET, Ttl: 3600}
+ r.AAAA = ip
+ answers = append(answers, r)
+ }
+ return answers
+}
+
+// ptr takes a slice of host names and filters out the ones that aren't in Origins, if specified, and returns a slice of PTR RRs.
+func (h *Hosts) ptr(zone string, names []string) []dns.RR {
+ answers := []dns.RR{}
+ for _, n := range names {
+ r := new(dns.PTR)
+ r.Hdr = dns.RR_Header{Name: zone, Rrtype: dns.TypePTR,
+ Class: dns.ClassINET, Ttl: 3600}
+ r.Ptr = dns.Fqdn(n)
+ answers = append(answers, r)
+ }
+ return answers
+}
diff --git a/plugin/hosts/hosts_test.go b/plugin/hosts/hosts_test.go
new file mode 100644
index 000000000..68b91b8c2
--- /dev/null
+++ b/plugin/hosts/hosts_test.go
@@ -0,0 +1,75 @@
+package hosts
+
+import (
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/coredns/coredns/plugin/pkg/dnsrecorder"
+ "github.com/coredns/coredns/plugin/test"
+
+ "github.com/miekg/dns"
+ "golang.org/x/net/context"
+)
+
+func TestLookupA(t *testing.T) {
+ h := Hosts{Next: test.ErrorHandler(), Hostsfile: &Hostsfile{expire: time.Now().Add(1 * time.Hour), Origins: []string{"."}}}
+ h.Parse(strings.NewReader(hostsExample))
+
+ ctx := context.TODO()
+
+ for _, tc := range hostsTestCases {
+ m := tc.Msg()
+
+ rec := dnsrecorder.New(&test.ResponseWriter{})
+ _, err := h.ServeDNS(ctx, rec, m)
+ if err != nil {
+ t.Errorf("Expected no error, got %v\n", err)
+ return
+ }
+
+ resp := rec.Msg
+ test.SortAndCheck(t, resp, tc)
+ }
+}
+
+var hostsTestCases = []test.Case{
+ {
+ Qname: "example.org.", Qtype: dns.TypeA,
+ Answer: []dns.RR{
+ test.A("example.org. 3600 IN A 10.0.0.1"),
+ },
+ },
+ {
+ Qname: "localhost.", Qtype: dns.TypeAAAA,
+ Answer: []dns.RR{
+ test.AAAA("localhost. 3600 IN AAAA ::1"),
+ },
+ },
+ {
+ Qname: "1.0.0.10.in-addr.arpa.", Qtype: dns.TypePTR,
+ Answer: []dns.RR{
+ test.PTR("1.0.0.10.in-addr.arpa. 3600 PTR example.org."),
+ },
+ },
+ {
+ Qname: "1.0.0.127.in-addr.arpa.", Qtype: dns.TypePTR,
+ Answer: []dns.RR{
+ test.PTR("1.0.0.127.in-addr.arpa. 3600 PTR localhost."),
+ test.PTR("1.0.0.127.in-addr.arpa. 3600 PTR localhost.domain."),
+ },
+ },
+ {
+ Qname: "example.org.", Qtype: dns.TypeAAAA,
+ Answer: []dns.RR{},
+ },
+ {
+ Qname: "example.org.", Qtype: dns.TypeMX,
+ Answer: []dns.RR{},
+ },
+}
+
+const hostsExample = `
+127.0.0.1 localhost localhost.domain
+::1 localhost localhost.domain
+10.0.0.1 example.org`
diff --git a/plugin/hosts/hostsfile.go b/plugin/hosts/hostsfile.go
new file mode 100644
index 000000000..91e828099
--- /dev/null
+++ b/plugin/hosts/hostsfile.go
@@ -0,0 +1,193 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file is a modified version of net/hosts.go from the golang repo
+
+package hosts
+
+import (
+ "bufio"
+ "bytes"
+ "io"
+ "net"
+ "os"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/coredns/coredns/plugin"
+)
+
+const cacheMaxAge = 5 * time.Second
+
+func parseLiteralIP(addr string) net.IP {
+ if i := strings.Index(addr, "%"); i >= 0 {
+ // discard ipv6 zone
+ addr = addr[0:i]
+ }
+
+ return net.ParseIP(addr)
+}
+
+func absDomainName(b string) string {
+ return plugin.Name(b).Normalize()
+}
+
+// Hostsfile contains known host entries.
+type Hostsfile struct {
+ sync.Mutex
+
+ // list of zones we are authoritive for
+ Origins []string
+
+ // Key for the list of literal IP addresses must be a host
+ // name. It would be part of DNS labels, a FQDN or an absolute
+ // FQDN.
+ // For now the key is converted to lower case for convenience.
+ byNameV4 map[string][]net.IP
+ byNameV6 map[string][]net.IP
+
+ // Key for the list of host names must be a literal IP address
+ // including IPv6 address with zone identifier.
+ // We don't support old-classful IP address notation.
+ byAddr map[string][]string
+
+ expire time.Time
+ path string
+ mtime time.Time
+ size int64
+}
+
+// ReadHosts determines if the cached data needs to be updated based on the size and modification time of the hostsfile.
+func (h *Hostsfile) ReadHosts() {
+ now := time.Now()
+
+ if now.Before(h.expire) && len(h.byAddr) > 0 {
+ return
+ }
+ stat, err := os.Stat(h.path)
+ if err == nil && h.mtime.Equal(stat.ModTime()) && h.size == stat.Size() {
+ h.expire = now.Add(cacheMaxAge)
+ return
+ }
+
+ var file *os.File
+ if file, _ = os.Open(h.path); file == nil {
+ return
+ }
+ defer file.Close()
+
+ h.Parse(file)
+
+ // Update the data cache.
+ h.expire = now.Add(cacheMaxAge)
+ h.mtime = stat.ModTime()
+ h.size = stat.Size()
+}
+
+// Parse reads the hostsfile and populates the byName and byAddr maps.
+func (h *Hostsfile) Parse(file io.Reader) {
+ hsv4 := make(map[string][]net.IP)
+ hsv6 := make(map[string][]net.IP)
+ is := make(map[string][]string)
+
+ scanner := bufio.NewScanner(file)
+ for scanner.Scan() {
+ line := scanner.Bytes()
+ if i := bytes.Index(line, []byte{'#'}); i >= 0 {
+ // Discard comments.
+ line = line[0:i]
+ }
+ f := bytes.Fields(line)
+ if len(f) < 2 {
+ continue
+ }
+ addr := parseLiteralIP(string(f[0]))
+ if addr == nil {
+ continue
+ }
+ ver := ipVersion(string(f[0]))
+ for i := 1; i < len(f); i++ {
+ name := absDomainName(string(f[i]))
+ if plugin.Zones(h.Origins).Matches(name) == "" {
+ // name is not in Origins
+ continue
+ }
+ switch ver {
+ case 4:
+ hsv4[name] = append(hsv4[name], addr)
+ case 6:
+ hsv6[name] = append(hsv6[name], addr)
+ default:
+ continue
+ }
+ is[addr.String()] = append(is[addr.String()], name)
+ }
+ }
+ h.byNameV4 = hsv4
+ h.byNameV6 = hsv6
+ h.byAddr = is
+}
+
+// ipVersion returns what IP version was used textually
+func ipVersion(s string) int {
+ for i := 0; i < len(s); i++ {
+ switch s[i] {
+ case '.':
+ return 4
+ case ':':
+ return 6
+ }
+ }
+ return 0
+}
+
+// LookupStaticHostV4 looks up the IPv4 addresses for the given host from the hosts file.
+func (h *Hostsfile) LookupStaticHostV4(host string) []net.IP {
+ h.Lock()
+ defer h.Unlock()
+ h.ReadHosts()
+ if len(h.byNameV4) != 0 {
+ if ips, ok := h.byNameV4[absDomainName(host)]; ok {
+ ipsCp := make([]net.IP, len(ips))
+ copy(ipsCp, ips)
+ return ipsCp
+ }
+ }
+ return nil
+}
+
+// LookupStaticHostV6 looks up the IPv6 addresses for the given host from the hosts file.
+func (h *Hostsfile) LookupStaticHostV6(host string) []net.IP {
+ h.Lock()
+ defer h.Unlock()
+ h.ReadHosts()
+ if len(h.byNameV6) != 0 {
+ if ips, ok := h.byNameV6[absDomainName(host)]; ok {
+ ipsCp := make([]net.IP, len(ips))
+ copy(ipsCp, ips)
+ return ipsCp
+ }
+ }
+ return nil
+}
+
+// LookupStaticAddr looks up the hosts for the given address from the hosts file.
+func (h *Hostsfile) LookupStaticAddr(addr string) []string {
+ h.Lock()
+ defer h.Unlock()
+ h.ReadHosts()
+ addr = parseLiteralIP(addr).String()
+ if addr == "" {
+ return nil
+ }
+ if len(h.byAddr) != 0 {
+ if hosts, ok := h.byAddr[addr]; ok {
+ hostsCp := make([]string, len(hosts))
+ copy(hostsCp, hosts)
+ return hostsCp
+ }
+ }
+ return nil
+}
diff --git a/plugin/hosts/hostsfile_test.go b/plugin/hosts/hostsfile_test.go
new file mode 100644
index 000000000..65841fa42
--- /dev/null
+++ b/plugin/hosts/hostsfile_test.go
@@ -0,0 +1,239 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package hosts
+
+import (
+ "net"
+ "reflect"
+ "strings"
+ "testing"
+ "time"
+)
+
+func testHostsfile(file string) *Hostsfile {
+ h := &Hostsfile{expire: time.Now().Add(1 * time.Hour), Origins: []string{"."}}
+ h.Parse(strings.NewReader(file))
+ return h
+}
+
+type staticHostEntry struct {
+ in string
+ v4 []string
+ v6 []string
+}
+
+var (
+ hosts = `255.255.255.255 broadcasthost
+ 127.0.0.2 odin
+ 127.0.0.3 odin # inline comment
+ ::2 odin
+ 127.1.1.1 thor
+ # aliases
+ 127.1.1.2 ullr ullrhost
+ fe80::1%lo0 localhost
+ # Bogus entries that must be ignored.
+ 123.123.123 loki
+ 321.321.321.321`
+ singlelinehosts = `127.0.0.2 odin`
+ ipv4hosts = `# See https://tools.ietf.org/html/rfc1123.
+ #
+ # The literal IPv4 address parser in the net package is a relaxed
+ # one. It may accept a literal IPv4 address in dotted-decimal notation
+ # with leading zeros such as "001.2.003.4".
+
+ # internet address and host name
+ 127.0.0.1 localhost # inline comment separated by tab
+ 127.000.000.002 localhost # inline comment separated by space
+
+ # internet address, host name and aliases
+ 127.000.000.003 localhost localhost.localdomain`
+ ipv6hosts = `# See https://tools.ietf.org/html/rfc5952, https://tools.ietf.org/html/rfc4007.
+
+ # internet address and host name
+ ::1 localhost # inline comment separated by tab
+ fe80:0000:0000:0000:0000:0000:0000:0001 localhost # inline comment separated by space
+
+ # internet address with zone identifier and host name
+ fe80:0000:0000:0000:0000:0000:0000:0002%lo0 localhost
+
+ # internet address, host name and aliases
+ fe80::3%lo0 localhost localhost.localdomain`
+ casehosts = `127.0.0.1 PreserveMe PreserveMe.local
+ ::1 PreserveMe PreserveMe.local`
+)
+
+var lookupStaticHostTests = []struct {
+ file string
+ ents []staticHostEntry
+}{
+ {
+ hosts,
+ []staticHostEntry{
+ {"odin", []string{"127.0.0.2", "127.0.0.3"}, []string{"::2"}},
+ {"thor", []string{"127.1.1.1"}, []string{}},
+ {"ullr", []string{"127.1.1.2"}, []string{}},
+ {"ullrhost", []string{"127.1.1.2"}, []string{}},
+ {"localhost", []string{}, []string{"fe80::1"}},
+ },
+ },
+ {
+ singlelinehosts, // see golang.org/issue/6646
+ []staticHostEntry{
+ {"odin", []string{"127.0.0.2"}, []string{}},
+ },
+ },
+ {
+ ipv4hosts,
+ []staticHostEntry{
+ {"localhost", []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"}, []string{}},
+ {"localhost.localdomain", []string{"127.0.0.3"}, []string{}},
+ },
+ },
+ {
+ ipv6hosts,
+ []staticHostEntry{
+ {"localhost", []string{}, []string{"::1", "fe80::1", "fe80::2", "fe80::3"}},
+ {"localhost.localdomain", []string{}, []string{"fe80::3"}},
+ },
+ },
+ {
+ casehosts,
+ []staticHostEntry{
+ {"PreserveMe", []string{"127.0.0.1"}, []string{"::1"}},
+ {"PreserveMe.local", []string{"127.0.0.1"}, []string{"::1"}},
+ },
+ },
+}
+
+func TestLookupStaticHost(t *testing.T) {
+
+ for _, tt := range lookupStaticHostTests {
+ h := testHostsfile(tt.file)
+ for _, ent := range tt.ents {
+ testStaticHost(t, ent, h)
+ }
+ }
+}
+
+func testStaticHost(t *testing.T, ent staticHostEntry, h *Hostsfile) {
+ ins := []string{ent.in, absDomainName(ent.in), strings.ToLower(ent.in), strings.ToUpper(ent.in)}
+ for k, in := range ins {
+ addrsV4 := h.LookupStaticHostV4(in)
+ if len(addrsV4) != len(ent.v4) {
+ t.Fatalf("%d, lookupStaticHostV4(%s) = %v; want %v", k, in, addrsV4, ent.v4)
+ }
+ for i, v4 := range addrsV4 {
+ if v4.String() != ent.v4[i] {
+ t.Fatalf("%d, lookupStaticHostV4(%s) = %v; want %v", k, in, addrsV4, ent.v4)
+ }
+ }
+ addrsV6 := h.LookupStaticHostV6(in)
+ if len(addrsV6) != len(ent.v6) {
+ t.Fatalf("%d, lookupStaticHostV6(%s) = %v; want %v", k, in, addrsV6, ent.v6)
+ }
+ for i, v6 := range addrsV6 {
+ if v6.String() != ent.v6[i] {
+ t.Fatalf("%d, lookupStaticHostV6(%s) = %v; want %v", k, in, addrsV6, ent.v6)
+ }
+ }
+ }
+}
+
+type staticIPEntry struct {
+ in string
+ out []string
+}
+
+var lookupStaticAddrTests = []struct {
+ file string
+ ents []staticIPEntry
+}{
+ {
+ hosts,
+ []staticIPEntry{
+ {"255.255.255.255", []string{"broadcasthost"}},
+ {"127.0.0.2", []string{"odin"}},
+ {"127.0.0.3", []string{"odin"}},
+ {"::2", []string{"odin"}},
+ {"127.1.1.1", []string{"thor"}},
+ {"127.1.1.2", []string{"ullr", "ullrhost"}},
+ {"fe80::1", []string{"localhost"}},
+ },
+ },
+ {
+ singlelinehosts, // see golang.org/issue/6646
+ []staticIPEntry{
+ {"127.0.0.2", []string{"odin"}},
+ },
+ },
+ {
+ ipv4hosts, // see golang.org/issue/8996
+ []staticIPEntry{
+ {"127.0.0.1", []string{"localhost"}},
+ {"127.0.0.2", []string{"localhost"}},
+ {"127.0.0.3", []string{"localhost", "localhost.localdomain"}},
+ },
+ },
+ {
+ ipv6hosts, // see golang.org/issue/8996
+ []staticIPEntry{
+ {"::1", []string{"localhost"}},
+ {"fe80::1", []string{"localhost"}},
+ {"fe80::2", []string{"localhost"}},
+ {"fe80::3", []string{"localhost", "localhost.localdomain"}},
+ },
+ },
+ {
+ casehosts, // see golang.org/issue/12806
+ []staticIPEntry{
+ {"127.0.0.1", []string{"PreserveMe", "PreserveMe.local"}},
+ {"::1", []string{"PreserveMe", "PreserveMe.local"}},
+ },
+ },
+}
+
+func TestLookupStaticAddr(t *testing.T) {
+ for _, tt := range lookupStaticAddrTests {
+ h := testHostsfile(tt.file)
+ for _, ent := range tt.ents {
+ testStaticAddr(t, ent, h)
+ }
+ }
+}
+
+func testStaticAddr(t *testing.T, ent staticIPEntry, h *Hostsfile) {
+ hosts := h.LookupStaticAddr(ent.in)
+ for i := range ent.out {
+ ent.out[i] = absDomainName(ent.out[i])
+ }
+ if !reflect.DeepEqual(hosts, ent.out) {
+ t.Errorf("%s, lookupStaticAddr(%s) = %v; want %v", h.path, ent.in, hosts, h)
+ }
+}
+
+func TestHostCacheModification(t *testing.T) {
+ // Ensure that programs can't modify the internals of the host cache.
+ // See https://github.com/golang/go/issues/14212.
+
+ h := testHostsfile(ipv4hosts)
+ ent := staticHostEntry{"localhost", []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"}, []string{}}
+ testStaticHost(t, ent, h)
+ // Modify the addresses return by lookupStaticHost.
+ addrs := h.LookupStaticHostV6(ent.in)
+ for i := range addrs {
+ addrs[i] = net.IPv4zero
+ }
+ testStaticHost(t, ent, h)
+
+ h = testHostsfile(ipv6hosts)
+ entip := staticIPEntry{"::1", []string{"localhost"}}
+ testStaticAddr(t, entip, h)
+ // Modify the hosts return by lookupStaticAddr.
+ hosts := h.LookupStaticAddr(entip.in)
+ for i := range hosts {
+ hosts[i] += "junk"
+ }
+ testStaticAddr(t, entip, h)
+}
diff --git a/plugin/hosts/setup.go b/plugin/hosts/setup.go
new file mode 100644
index 000000000..c7c0c728a
--- /dev/null
+++ b/plugin/hosts/setup.go
@@ -0,0 +1,88 @@
+package hosts
+
+import (
+ "log"
+ "os"
+ "path"
+
+ "github.com/coredns/coredns/core/dnsserver"
+ "github.com/coredns/coredns/plugin"
+
+ "github.com/mholt/caddy"
+)
+
+func init() {
+ caddy.RegisterPlugin("hosts", caddy.Plugin{
+ ServerType: "dns",
+ Action: setup,
+ })
+}
+
+func setup(c *caddy.Controller) error {
+ h, err := hostsParse(c)
+ if err != nil {
+ return plugin.Error("hosts", err)
+ }
+
+ dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
+ h.Next = next
+ return h
+ })
+
+ return nil
+}
+
+func hostsParse(c *caddy.Controller) (Hosts, error) {
+ var h = Hosts{
+ Hostsfile: &Hostsfile{path: "/etc/hosts"},
+ }
+ defer h.ReadHosts()
+
+ config := dnsserver.GetConfig(c)
+
+ for c.Next() {
+ args := c.RemainingArgs()
+ if len(args) >= 1 {
+ h.path = args[0]
+ args = args[1:]
+
+ if !path.IsAbs(h.path) && config.Root != "" {
+ h.path = path.Join(config.Root, h.path)
+ }
+ _, err := os.Stat(h.path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ log.Printf("[WARNING] File does not exist: %s", h.path)
+ } else {
+ return h, c.Errf("unable to access hosts file '%s': %v", h.path, err)
+ }
+ }
+ }
+
+ origins := make([]string, len(c.ServerBlockKeys))
+ copy(origins, c.ServerBlockKeys)
+ if len(args) > 0 {
+ origins = args
+ }
+
+ for i := range origins {
+ origins[i] = plugin.Host(origins[i]).Normalize()
+ }
+ h.Origins = origins
+
+ for c.NextBlock() {
+ switch c.Val() {
+ case "fallthrough":
+ args := c.RemainingArgs()
+ if len(args) == 0 {
+ h.Fallthrough = true
+ continue
+ }
+ return h, c.ArgErr()
+ default:
+ return h, c.Errf("unknown property '%s'", c.Val())
+ }
+ }
+ }
+ return h, nil
+}
diff --git a/plugin/hosts/setup_test.go b/plugin/hosts/setup_test.go
new file mode 100644
index 000000000..a4c95b1c6
--- /dev/null
+++ b/plugin/hosts/setup_test.go
@@ -0,0 +1,86 @@
+package hosts
+
+import (
+ "testing"
+
+ "github.com/mholt/caddy"
+)
+
+func TestHostsParse(t *testing.T) {
+ tests := []struct {
+ inputFileRules string
+ shouldErr bool
+ expectedPath string
+ expectedOrigins []string
+ expectedFallthrough bool
+ }{
+ {
+ `hosts
+`,
+ false, "/etc/hosts", nil, false,
+ },
+ {
+ `hosts /tmp`,
+ false, "/tmp", nil, false,
+ },
+ {
+ `hosts /etc/hosts miek.nl.`,
+ false, "/etc/hosts", []string{"miek.nl."}, false,
+ },
+ {
+ `hosts /etc/hosts miek.nl. pun.gent.`,
+ false, "/etc/hosts", []string{"miek.nl.", "pun.gent."}, false,
+ },
+ {
+ `hosts {
+ fallthrough
+ }`,
+ false, "/etc/hosts", nil, true,
+ },
+ {
+ `hosts /tmp {
+ fallthrough
+ }`,
+ false, "/tmp", nil, true,
+ },
+ {
+ `hosts /etc/hosts miek.nl. {
+ fallthrough
+ }`,
+ false, "/etc/hosts", []string{"miek.nl."}, true,
+ },
+ {
+ `hosts /etc/hosts miek.nl 10.0.0.9/8 {
+ fallthrough
+ }`,
+ false, "/etc/hosts", []string{"miek.nl.", "10.in-addr.arpa."}, true,
+ },
+ }
+
+ for i, test := range tests {
+ c := caddy.NewTestController("dns", test.inputFileRules)
+ h, err := hostsParse(c)
+
+ if err == nil && test.shouldErr {
+ t.Fatalf("Test %d expected errors, but got no error", i)
+ } else if err != nil && !test.shouldErr {
+ t.Fatalf("Test %d expected no errors, but got '%v'", i, err)
+ } else if !test.shouldErr {
+ if h.path != test.expectedPath {
+ t.Fatalf("Test %d expected %v, got %v", i, test.expectedPath, h.path)
+ }
+ } else {
+ if h.Fallthrough != test.expectedFallthrough {
+ t.Fatalf("Test %d expected fallthrough of %v, got %v", i, test.expectedFallthrough, h.Fallthrough)
+ }
+ if len(h.Origins) != len(test.expectedOrigins) {
+ t.Fatalf("Test %d expected %v, got %v", i, test.expectedOrigins, h.Origins)
+ }
+ for j, name := range test.expectedOrigins {
+ if h.Origins[j] != name {
+ t.Fatalf("Test %d expected %v for %d th zone, got %v", i, name, j, h.Origins[j])
+ }
+ }
+ }
+ }
+}