aboutsummaryrefslogtreecommitdiff
path: root/plugin
diff options
context:
space:
mode:
Diffstat (limited to 'plugin')
-rw-r--r--plugin/template/setup.go2
-rw-r--r--plugin/template/template.go6
-rw-r--r--plugin/template/template_test.go22
3 files changed, 26 insertions, 4 deletions
diff --git a/plugin/template/setup.go b/plugin/template/setup.go
index cd4cc1b90..841d2944f 100644
--- a/plugin/template/setup.go
+++ b/plugin/template/setup.go
@@ -149,7 +149,7 @@ func templateParse(c *caddy.Controller) (handler Handler, err error) {
if err != nil {
return handler, err
}
- t.upstream = u
+ t.upstream = &u
default:
return handler, c.ArgErr()
}
diff --git a/plugin/template/template.go b/plugin/template/template.go
index 9bf5f8dd7..0bec0000f 100644
--- a/plugin/template/template.go
+++ b/plugin/template/template.go
@@ -34,7 +34,7 @@ type template struct {
qclass uint16
qtype uint16
fall fall.F
- upstream upstream.Upstream
+ upstream *upstream.Upstream
}
type templateData struct {
@@ -84,8 +84,8 @@ func (h Handler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg)
return dns.RcodeServerFailure, err
}
msg.Answer = append(msg.Answer, rr)
- if rr.Header().Rrtype == dns.TypeCNAME {
- up, _ := template.upstream.Lookup(state, rr.(*dns.CNAME).Target, dns.TypeA)
+ if template.upstream != nil && (state.QType() == dns.TypeA || state.QType() == dns.TypeAAAA) && rr.Header().Rrtype == dns.TypeCNAME {
+ up, _ := template.upstream.Lookup(state, rr.(*dns.CNAME).Target, state.QType())
msg.Answer = append(msg.Answer, up.Answer...)
}
}
diff --git a/plugin/template/template_test.go b/plugin/template/template_test.go
index 288d833ec..045eba772 100644
--- a/plugin/template/template_test.go
+++ b/plugin/template/template_test.go
@@ -92,6 +92,14 @@ func TestHandler(t *testing.T) {
fall: fall.Root,
zones: []string{"."},
}
+ cnameTemplate := template{
+ regex: []*regexp.Regexp{regexp.MustCompile("example[.]net[.]")},
+ answer: []*gotmpl.Template{gotmpl.Must(gotmpl.New("answer").Parse("example.net 60 IN CNAME target.example.com"))},
+ qclass: dns.ClassANY,
+ qtype: dns.TypeANY,
+ fall: fall.Root,
+ zones: []string{"."},
+ }
tests := []struct {
tmpl template
@@ -254,6 +262,20 @@ func TestHandler(t *testing.T) {
return nil
},
},
+ {
+ name: "CNAMEWithoutUpstream",
+ tmpl: cnameTemplate,
+ qclass: dns.ClassINET,
+ qtype: dns.TypeA,
+ qname: "example.net.",
+ expectedCode: dns.RcodeSuccess,
+ verifyResponse: func(r *dns.Msg) error {
+ if len(r.Answer) != 1 {
+ return fmt.Errorf("expected 1 answer, got %v", len(r.Answer))
+ }
+ return nil
+ },
+ },
}
ctx := context.TODO()