diff options
Diffstat (limited to 'middleware/normalize.go')
-rw-r--r-- | middleware/normalize.go | 68 |
1 files changed, 52 insertions, 16 deletions
diff --git a/middleware/normalize.go b/middleware/normalize.go index 77ef97993..18fe58f61 100644 --- a/middleware/normalize.go +++ b/middleware/normalize.go @@ -1,7 +1,9 @@ package middleware import ( + "fmt" "net" + "strconv" "strings" "github.com/miekg/dns" @@ -53,8 +55,6 @@ func (n Name) Normalize() string { return strings.ToLower(dns.Fqdn(string(n))) } type ( // Host represents a host from the Corefile, may contain port. Host string // Host represents a host from the Corefile, may contain port. - // Addr represents an address in the Corefile. - Addr string // Addr resprents an address in the Corefile. ) // Normalize will return the host portion of host, stripping @@ -72,24 +72,60 @@ func (h Host) Normalize() string { s = s[len(TransportGRPC+"://"):] } - // separate host and port - host, _, err := net.SplitHostPort(s) - if err != nil { - host, _, _ = net.SplitHostPort(s + ":") - } + // The error can be ignore here, because this function is called after the corefile + // has already been vetted. + host, _, _ := SplitHostPort(s) return Name(host).Normalize() } -// Normalize will return a normalized address, if not port is specified -// port 53 is added, otherwise the port will be left as is. -func (a Addr) Normalize() string { - // separate host and port - addr, port, err := net.SplitHostPort(string(a)) - if err != nil { - addr, port, _ = net.SplitHostPort(string(a) + ":53") +// SplitHostPort splits s up in a host and port portion, taking reverse address notation into account. +// String the string s should *not* be prefixed with any protocols, i.e. dns:// +func SplitHostPort(s string) (host, port string, err error) { + // If there is: :[0-9]+ on the end we assume this is the port. This works for (ascii) domain + // names and our reverse syntax, which always needs a /mask *before* the port. + // So from the back, find first colon, and then check if its a number. + host = s + + colon := strings.LastIndex(s, ":") + if colon == len(s)-1 { + return "", "", fmt.Errorf("expecting data after last colon: %q", s) + } + if colon != -1 { + if p, err := strconv.Atoi(s[colon+1:]); err == nil { + port = strconv.Itoa(p) + host = s[:colon] + } + } + + // TODO(miek): this should take escaping into account. + if len(host) > 255 { + return "", "", fmt.Errorf("specified zone is too long: %d > 255", len(host)) + } + + _, d := dns.IsDomainName(host) + if !d { + return "", "", fmt.Errorf("zone is not a valid domain name: %s", host) + } + + // Check if it parses as a reverse zone, if so we use that. Must be fully + // specified IP and mask and mask % 8 = 0. + ip, net, err := net.ParseCIDR(host) + if err == nil { + if rev, e := dns.ReverseAddr(ip.String()); e == nil { + ones, bits := net.Mask.Size() + if (bits-ones)%8 == 0 { + offset, end := 0, false + for i := 0; i < (bits-ones)/8; i++ { + offset, end = dns.NextLabel(rev, offset) + if end { + break + } + } + host = rev[offset:] + } + } } - // TODO(miek): lowercase it? - return net.JoinHostPort(addr, port) + return host, port, nil } // Duplicated from core/dnsserver/address.go ! |