diff options
-rw-r--r-- | plugin/hosts/README.md | 6 | ||||
-rw-r--r-- | plugin/hosts/hosts.go | 18 | ||||
-rw-r--r-- | plugin/hosts/hosts_test.go | 15 | ||||
-rw-r--r-- | plugin/hosts/hostsfile.go | 110 | ||||
-rw-r--r-- | plugin/hosts/hostsfile_test.go | 6 | ||||
-rw-r--r-- | plugin/hosts/setup.go | 78 |
6 files changed, 163 insertions, 70 deletions
diff --git a/plugin/hosts/README.md b/plugin/hosts/README.md index bdd1955c9..e525b9a65 100644 --- a/plugin/hosts/README.md +++ b/plugin/hosts/README.md @@ -41,6 +41,9 @@ PTR records for reverse lookups are generated automatically by CoreDNS (based on ~~~ hosts [FILE [ZONES...]] { [INLINE] + ttl SECONDS + no_reverse + reload DURATION fallthrough [ZONES...] } ~~~ @@ -53,6 +56,9 @@ hosts [FILE [ZONES...]] { * **INLINE** the hosts file contents inlined in Corefile. If there are any lines before fallthrough then all of them will be treated as the additional content for hosts file. The specified hosts file path will still be read but entries will be overrided. +* `ttl` change the DNS TTL of the records generated (forward and reverse). The default is 3600 seonds (1 hour). +* `reload` change the period between each hostsfile reload. A time of zero seconds disable the feature. Examples of valid durations: "300ms", "1.5h" or "2h45m" are valid duration with units "ns" (nanosecond), "us" (or "µs" for microsecond), "ms" (millisecond), "s" (second), "m" (minute), "h" (hour). +* `no_reverse` disable the automatic generation of the the `in-addr.arpa` or `ip6.arpa` entries for the hosts * `fallthrough` If zone matches and no record can be generated, pass request to the next plugin. If **[ZONES...]** is omitted, then fallthrough happens for all zones for which the plugin is authoritative. If specific zones are listed (for example `in-addr.arpa` and `ip6.arpa`), then only diff --git a/plugin/hosts/hosts.go b/plugin/hosts/hosts.go index a23a93801..8650053c0 100644 --- a/plugin/hosts/hosts.go +++ b/plugin/hosts/hosts.go @@ -43,13 +43,13 @@ func (h Hosts) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) ( // If this doesn't match we need to fall through regardless of h.Fallthrough return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r) } - answers = h.ptr(qname, names) + answers = h.ptr(qname, h.options.ttl, names) case dns.TypeA: ips := h.LookupStaticHostV4(qname) - answers = a(qname, ips) + answers = a(qname, h.options.ttl, ips) case dns.TypeAAAA: ips := h.LookupStaticHostV6(qname) - answers = aaaa(qname, ips) + answers = aaaa(qname, h.options.ttl, ips) } if len(answers) == 0 { @@ -96,12 +96,12 @@ func (h Hosts) otherRecordsExist(qtype uint16, qname string) bool { func (h Hosts) Name() string { return "hosts" } // a takes a slice of net.IPs and returns a slice of A RRs. -func a(zone string, ips []net.IP) []dns.RR { +func a(zone string, ttl uint32, ips []net.IP) []dns.RR { answers := []dns.RR{} for _, ip := range ips { r := new(dns.A) r.Hdr = dns.RR_Header{Name: zone, Rrtype: dns.TypeA, - Class: dns.ClassINET, Ttl: 3600} + Class: dns.ClassINET, Ttl: ttl} r.A = ip answers = append(answers, r) } @@ -109,12 +109,12 @@ func a(zone string, ips []net.IP) []dns.RR { } // aaaa takes a slice of net.IPs and returns a slice of AAAA RRs. -func aaaa(zone string, ips []net.IP) []dns.RR { +func aaaa(zone string, ttl uint32, ips []net.IP) []dns.RR { answers := []dns.RR{} for _, ip := range ips { r := new(dns.AAAA) r.Hdr = dns.RR_Header{Name: zone, Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, Ttl: 3600} + Class: dns.ClassINET, Ttl: ttl} r.AAAA = ip answers = append(answers, r) } @@ -122,12 +122,12 @@ func aaaa(zone string, ips []net.IP) []dns.RR { } // ptr takes a slice of host names and filters out the ones that aren't in Origins, if specified, and returns a slice of PTR RRs. -func (h *Hosts) ptr(zone string, names []string) []dns.RR { +func (h *Hosts) ptr(zone string, ttl uint32, names []string) []dns.RR { answers := []dns.RR{} for _, n := range names { r := new(dns.PTR) r.Hdr = dns.RR_Header{Name: zone, Rrtype: dns.TypePTR, - Class: dns.ClassINET, Ttl: 3600} + Class: dns.ClassINET, Ttl: ttl} r.Ptr = dns.Fqdn(n) answers = append(answers, r) } diff --git a/plugin/hosts/hosts_test.go b/plugin/hosts/hosts_test.go index e412b7ae5..975710bb3 100644 --- a/plugin/hosts/hosts_test.go +++ b/plugin/hosts/hosts_test.go @@ -12,10 +12,19 @@ import ( "github.com/miekg/dns" ) -func (h *Hostsfile) parseReader(r io.Reader) { h.hmap = h.parse(r, h.inline) } +func (h *Hostsfile) parseReader(r io.Reader) { + h.hmap = h.parse(r) +} func TestLookupA(t *testing.T) { - h := Hosts{Next: test.ErrorHandler(), Hostsfile: &Hostsfile{Origins: []string{"."}}} + h := Hosts{ + Next: test.ErrorHandler(), + Hostsfile: &Hostsfile{ + Origins: []string{"."}, + hmap: newHostsMap(), + options: newOptions(), + }, + } h.parseReader(strings.NewReader(hostsExample)) ctx := context.TODO() @@ -90,4 +99,6 @@ const hostsExample = ` ::1 localhost localhost.domain 10.0.0.1 example.org ::FFFF:10.0.0.2 example.com +reload 5s +timeout 3600 ` diff --git a/plugin/hosts/hostsfile.go b/plugin/hosts/hostsfile.go index 1cce850c3..5e0fd5bf6 100644 --- a/plugin/hosts/hostsfile.go +++ b/plugin/hosts/hostsfile.go @@ -32,6 +32,26 @@ func absDomainName(b string) string { return plugin.Name(b).Normalize() } +type options struct { + // automatically generate IP to Hostname PTR entries + // for host entries we parse + autoReverse bool + + // The TTL of the record we generate + ttl uint32 + + // The time between two reload of the configuration + reload time.Duration +} + +func newOptions() *options { + return &options{ + autoReverse: true, + ttl: 3600, + reload: durationOf5s, + } +} + type hostsMap struct { // Key for the list of literal IP addresses must be a host // name. It would be part of DNS labels, a FQDN or an absolute @@ -46,6 +66,11 @@ type hostsMap struct { byAddr map[string][]string } +const ( + durationOf0s = time.Duration(0) + durationOf5s = time.Duration(5 * time.Second) +) + func newHostsMap() *hostsMap { return &hostsMap{ byNameV4: make(map[string][]net.IP), @@ -90,6 +115,8 @@ type Hostsfile struct { // mtime and size are only read and modified by a single goroutine mtime time.Time size int64 + + options *options } // readHosts determines if the cached data needs to be updated based on the size and modification time of the hostsfile. @@ -106,7 +133,7 @@ func (h *Hostsfile) readHosts() { return } - newMap := h.parse(file, h.inline) + newMap := h.parse(file) log.Debugf("Parsed hosts file into %d entries", newMap.Len()) h.Lock() @@ -124,13 +151,12 @@ func (h *Hostsfile) initInline(inline []string) { return } - hmap := newHostsMap() - h.inline = h.parse(strings.NewReader(strings.Join(inline, "\n")), hmap) + h.inline = h.parse(strings.NewReader(strings.Join(inline, "\n"))) *h.hmap = *h.inline } // Parse reads the hostsfile and populates the byName and byAddr maps. -func (h *Hostsfile) parse(r io.Reader, override *hostsMap) *hostsMap { +func (h *Hostsfile) parse(r io.Reader) *hostsMap { hmap := newHostsMap() scanner := bufio.NewScanner(r) @@ -163,22 +189,22 @@ func (h *Hostsfile) parse(r io.Reader, override *hostsMap) *hostsMap { default: continue } + if !h.options.autoReverse { + continue + } hmap.byAddr[addr.String()] = append(hmap.byAddr[addr.String()], name) } } - if override == nil { - return hmap - } - - for name := range override.byNameV4 { - hmap.byNameV4[name] = append(hmap.byNameV4[name], override.byNameV4[name]...) + for name := range h.hmap.byNameV4 { + hmap.byNameV4[name] = append(hmap.byNameV4[name], h.hmap.byNameV4[name]...) } - for name := range override.byNameV4 { - hmap.byNameV6[name] = append(hmap.byNameV6[name], override.byNameV6[name]...) + for name := range h.hmap.byNameV4 { + hmap.byNameV6[name] = append(hmap.byNameV6[name], h.hmap.byNameV6[name]...) } - for addr := range override.byAddr { - hmap.byAddr[addr] = append(hmap.byAddr[addr], override.byAddr[addr]...) + + for addr := range h.hmap.byAddr { + hmap.byAddr[addr] = append(hmap.byAddr[addr], h.hmap.byAddr[addr]...) } return hmap @@ -199,32 +225,34 @@ func ipVersion(s string) int { return 0 } -// LookupStaticHostV4 looks up the IPv4 addresses for the given host from the hosts file. -func (h *Hostsfile) LookupStaticHostV4(host string) []net.IP { +// LookupStaticHost looks up the IP addresses for the given host from the hosts file. +func (h *Hostsfile) lookupStaticHost(hmapByName map[string][]net.IP, host string) []net.IP { + fqhost := absDomainName(host) + h.RLock() defer h.RUnlock() - if len(h.hmap.byNameV4) != 0 { - if ips, ok := h.hmap.byNameV4[absDomainName(host)]; ok { - ipsCp := make([]net.IP, len(ips)) - copy(ipsCp, ips) - return ipsCp - } + + if len(hmapByName) == 0 { + return nil } - return nil + + ips, ok := hmapByName[fqhost] + if !ok { + return nil + } + ipsCp := make([]net.IP, len(ips)) + copy(ipsCp, ips) + return ipsCp +} + +// LookupStaticHostV4 looks up the IPv4 addresses for the given host from the hosts file. +func (h *Hostsfile) LookupStaticHostV4(host string) []net.IP { + return h.lookupStaticHost(h.hmap.byNameV4, host) } // LookupStaticHostV6 looks up the IPv6 addresses for the given host from the hosts file. func (h *Hostsfile) LookupStaticHostV6(host string) []net.IP { - h.RLock() - defer h.RUnlock() - if len(h.hmap.byNameV6) != 0 { - if ips, ok := h.hmap.byNameV6[absDomainName(host)]; ok { - ipsCp := make([]net.IP, len(ips)) - copy(ipsCp, ips) - return ipsCp - } - } - return nil + return h.lookupStaticHost(h.hmap.byNameV6, host) } // LookupStaticAddr looks up the hosts for the given address from the hosts file. @@ -235,12 +263,14 @@ func (h *Hostsfile) LookupStaticAddr(addr string) []string { if addr == "" { return nil } - if len(h.hmap.byAddr) != 0 { - if hosts, ok := h.hmap.byAddr[addr]; ok { - hostsCp := make([]string, len(hosts)) - copy(hostsCp, hosts) - return hostsCp - } + if len(h.hmap.byAddr) == 0 { + return nil + } + hosts, ok := h.hmap.byAddr[addr] + if !ok { + return nil } - return nil + hostsCp := make([]string, len(hosts)) + copy(hostsCp, hosts) + return hostsCp } diff --git a/plugin/hosts/hostsfile_test.go b/plugin/hosts/hostsfile_test.go index 26a2916f0..db0e63d75 100644 --- a/plugin/hosts/hostsfile_test.go +++ b/plugin/hosts/hostsfile_test.go @@ -12,7 +12,11 @@ import ( ) func testHostsfile(file string) *Hostsfile { - h := &Hostsfile{Origins: []string{"."}} + h := &Hostsfile{ + Origins: []string{"."}, + hmap: newHostsMap(), + options: newOptions(), + } h.parseReader(strings.NewReader(file)) return h } diff --git a/plugin/hosts/setup.go b/plugin/hosts/setup.go index 3dafb5be7..945a53dfd 100644 --- a/plugin/hosts/setup.go +++ b/plugin/hosts/setup.go @@ -3,6 +3,7 @@ package hosts import ( "os" "path/filepath" + "strconv" "strings" "time" @@ -22,28 +23,37 @@ func init() { }) } +func periodicHostsUpdate(h *Hosts) chan bool { + parseChan := make(chan bool) + + if h.options.reload == durationOf0s { + return parseChan + } + + go func() { + ticker := time.NewTicker(h.options.reload) + for { + select { + case <-parseChan: + return + case <-ticker.C: + h.readHosts() + } + } + }() + return parseChan +} + func setup(c *caddy.Controller) error { h, err := hostsParse(c) if err != nil { return plugin.Error("hosts", err) } - parseChan := make(chan bool) + parseChan := periodicHostsUpdate(&h) c.OnStartup(func() error { h.readHosts() - - go func() { - ticker := time.NewTicker(5 * time.Second) - for { - select { - case <-parseChan: - return - case <-ticker.C: - h.readHosts() - } - } - }() return nil }) @@ -61,15 +71,18 @@ func setup(c *caddy.Controller) error { } func hostsParse(c *caddy.Controller) (Hosts, error) { - var h = Hosts{ + config := dnsserver.GetConfig(c) + + options := newOptions() + + h := Hosts{ Hostsfile: &Hostsfile{ - path: "/etc/hosts", - hmap: newHostsMap(), + path: "/etc/hosts", + hmap: newHostsMap(), + options: options, }, } - config := dnsserver.GetConfig(c) - inline := []string{} i := 0 for c.Next() { @@ -79,6 +92,7 @@ func hostsParse(c *caddy.Controller) (Hosts, error) { i++ args := c.RemainingArgs() + if len(args) >= 1 { h.path = args[0] args = args[1:] @@ -114,6 +128,34 @@ func hostsParse(c *caddy.Controller) (Hosts, error) { switch c.Val() { case "fallthrough": h.Fall.SetZonesFromArgs(c.RemainingArgs()) + case "no_reverse": + options.autoReverse = false + case "ttl": + remaining := c.RemainingArgs() + if len(remaining) < 1 { + return h, c.Errf("ttl needs a time in second") + } + ttl, err := strconv.Atoi(remaining[0]) + if err != nil { + return h, c.Errf("ttl needs a number of second") + } + if ttl <= 0 || ttl > 65535 { + return h, c.Errf("ttl provided is invalid") + } + options.ttl = uint32(ttl) + case "reload": + remaining := c.RemainingArgs() + if len(remaining) != 1 { + return h, c.Errf("reload needs a duration (zero seconds to disable)") + } + reload, err := time.ParseDuration(remaining[0]) + if err != nil { + return h, c.Errf("invalid duration for reload '%s'", remaining[0]) + } + if reload < durationOf0s { + return h, c.Errf("invalid negative duration for reload '%s'", remaining[0]) + } + options.reload = reload default: if len(h.Fall.Zones) == 0 { line := strings.Join(append([]string{c.Val()}, c.RemainingArgs()...), " ") |