diff options
author | 2017-08-08 11:33:51 +0000 | |
---|---|---|
committer | 2017-08-08 11:33:51 +0000 | |
commit | be551f21a09c1fdcd5f9f9c1f895555796fef8f4 (patch) | |
tree | aabc34575c3a2230b98b772568a2c6720c23dc8c | |
parent | 7ed44cb6cebc68109dc6fc919858ebbdfc2da066 (diff) | |
download | coredns-be551f21a09c1fdcd5f9f9c1f895555796fef8f4.tar.gz coredns-be551f21a09c1fdcd5f9f9c1f895555796fef8f4.tar.zst coredns-be551f21a09c1fdcd5f9f9c1f895555796fef8f4.zip |
core: add missing trimzone files
*ugh*: forgot to add these files. This add the dnsutil.TrimZone
function.
-rw-r--r-- | middleware/pkg/dnsutil/zone.go | 20 | ||||
-rw-r--r-- | middleware/pkg/dnsutil/zone_test.go | 39 |
2 files changed, 59 insertions, 0 deletions
diff --git a/middleware/pkg/dnsutil/zone.go b/middleware/pkg/dnsutil/zone.go new file mode 100644 index 000000000..579fef1ba --- /dev/null +++ b/middleware/pkg/dnsutil/zone.go @@ -0,0 +1,20 @@ +package dnsutil + +import ( + "errors" + + "github.com/miekg/dns" +) + +// TrimZone removes the zone component from q. It returns the trimmed +// name or an error is zone is longer then qname. The trimmed name will be returned +// without a trailing dot. +func TrimZone(q string, z string) (string, error) { + zl := dns.CountLabel(z) + i, ok := dns.PrevLabel(q, zl) + if ok || i-1 < 0 { + return "", errors.New("trimzone: overshot qname: " + q + "for zone " + z) + } + // This includes the '.', remove on return + return q[:i-1], nil +} diff --git a/middleware/pkg/dnsutil/zone_test.go b/middleware/pkg/dnsutil/zone_test.go new file mode 100644 index 000000000..334f3d9d2 --- /dev/null +++ b/middleware/pkg/dnsutil/zone_test.go @@ -0,0 +1,39 @@ +package dnsutil + +import ( + "errors" + "testing" + + "github.com/miekg/dns" +) + +func TestTrimZone(t *testing.T) { + tests := []struct { + qname string + zone string + expected string + err error + }{ + {"a.example.org", "example.org", "a", nil}, + {"a.b.example.org", "example.org", "a.b", nil}, + {"b.", ".", "b", nil}, + {"example.org", "example.org", "", errors.New("should err")}, + {"org", "example.org", "", errors.New("should err")}, + } + + for i, tc := range tests { + got, err := TrimZone(dns.Fqdn(tc.qname), dns.Fqdn(tc.zone)) + if tc.err != nil && err == nil { + t.Errorf("Test %d, expected error got nil") + continue + } + if tc.err == nil && err != nil { + t.Errorf("Test %d, expected no error got %v", i, err) + continue + } + if got != tc.expected { + t.Errorf("Test %d, expected %s, got %s", i, tc.expected, got) + continue + } + } +} |