aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Miek Gieben <miek@miek.nl> 2017-08-08 11:33:51 +0000
committerGravatar Miek Gieben <miek@miek.nl> 2017-08-08 11:33:51 +0000
commitbe551f21a09c1fdcd5f9f9c1f895555796fef8f4 (patch)
treeaabc34575c3a2230b98b772568a2c6720c23dc8c
parent7ed44cb6cebc68109dc6fc919858ebbdfc2da066 (diff)
downloadcoredns-be551f21a09c1fdcd5f9f9c1f895555796fef8f4.tar.gz
coredns-be551f21a09c1fdcd5f9f9c1f895555796fef8f4.tar.zst
coredns-be551f21a09c1fdcd5f9f9c1f895555796fef8f4.zip
core: add missing trimzone files
*ugh*: forgot to add these files. This add the dnsutil.TrimZone function.
-rw-r--r--middleware/pkg/dnsutil/zone.go20
-rw-r--r--middleware/pkg/dnsutil/zone_test.go39
2 files changed, 59 insertions, 0 deletions
diff --git a/middleware/pkg/dnsutil/zone.go b/middleware/pkg/dnsutil/zone.go
new file mode 100644
index 000000000..579fef1ba
--- /dev/null
+++ b/middleware/pkg/dnsutil/zone.go
@@ -0,0 +1,20 @@
+package dnsutil
+
+import (
+ "errors"
+
+ "github.com/miekg/dns"
+)
+
+// TrimZone removes the zone component from q. It returns the trimmed
+// name or an error is zone is longer then qname. The trimmed name will be returned
+// without a trailing dot.
+func TrimZone(q string, z string) (string, error) {
+ zl := dns.CountLabel(z)
+ i, ok := dns.PrevLabel(q, zl)
+ if ok || i-1 < 0 {
+ return "", errors.New("trimzone: overshot qname: " + q + "for zone " + z)
+ }
+ // This includes the '.', remove on return
+ return q[:i-1], nil
+}
diff --git a/middleware/pkg/dnsutil/zone_test.go b/middleware/pkg/dnsutil/zone_test.go
new file mode 100644
index 000000000..334f3d9d2
--- /dev/null
+++ b/middleware/pkg/dnsutil/zone_test.go
@@ -0,0 +1,39 @@
+package dnsutil
+
+import (
+ "errors"
+ "testing"
+
+ "github.com/miekg/dns"
+)
+
+func TestTrimZone(t *testing.T) {
+ tests := []struct {
+ qname string
+ zone string
+ expected string
+ err error
+ }{
+ {"a.example.org", "example.org", "a", nil},
+ {"a.b.example.org", "example.org", "a.b", nil},
+ {"b.", ".", "b", nil},
+ {"example.org", "example.org", "", errors.New("should err")},
+ {"org", "example.org", "", errors.New("should err")},
+ }
+
+ for i, tc := range tests {
+ got, err := TrimZone(dns.Fqdn(tc.qname), dns.Fqdn(tc.zone))
+ if tc.err != nil && err == nil {
+ t.Errorf("Test %d, expected error got nil")
+ continue
+ }
+ if tc.err == nil && err != nil {
+ t.Errorf("Test %d, expected no error got %v", i, err)
+ continue
+ }
+ if got != tc.expected {
+ t.Errorf("Test %d, expected %s, got %s", i, tc.expected, got)
+ continue
+ }
+ }
+}