diff options
Diffstat (limited to 'plugin/pkg/replacer')
-rw-r--r-- | plugin/pkg/replacer/replacer.go | 161 | ||||
-rw-r--r-- | plugin/pkg/replacer/replacer_test.go | 61 |
2 files changed, 222 insertions, 0 deletions
diff --git a/plugin/pkg/replacer/replacer.go b/plugin/pkg/replacer/replacer.go new file mode 100644 index 000000000..fc98e5d29 --- /dev/null +++ b/plugin/pkg/replacer/replacer.go @@ -0,0 +1,161 @@ +package replacer + +import ( + "strconv" + "strings" + "time" + + "github.com/coredns/coredns/plugin/pkg/dnsrecorder" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// Replacer is a type which can replace placeholder +// substrings in a string with actual values from a +// dns.Msg and responseRecorder. Always use +// NewReplacer to get one of these. +type Replacer interface { + Replace(string) string + Set(key, value string) +} + +type replacer struct { + replacements map[string]string + emptyValue string +} + +// New makes a new replacer based on r and rr. +// Do not create a new replacer until r and rr have all +// the needed values, because this function copies those +// values into the replacer. rr may be nil if it is not +// available. emptyValue should be the string that is used +// in place of empty string (can still be empty string). +func New(r *dns.Msg, rr *dnsrecorder.Recorder, emptyValue string) Replacer { + req := request.Request{W: rr, Req: r} + rep := replacer{ + replacements: map[string]string{ + "{type}": req.Type(), + "{name}": req.Name(), + "{class}": req.Class(), + "{proto}": req.Proto(), + "{when}": func() string { + return time.Now().Format(timeFormat) + }(), + "{size}": strconv.Itoa(req.Len()), + "{remote}": req.IP(), + "{port}": req.Port(), + }, + emptyValue: emptyValue, + } + if rr != nil { + rcode := dns.RcodeToString[rr.Rcode] + if rcode == "" { + rcode = strconv.Itoa(rr.Rcode) + } + rep.replacements["{rcode}"] = rcode + rep.replacements["{rsize}"] = strconv.Itoa(rr.Len) + rep.replacements["{duration}"] = time.Since(rr.Start).String() + if rr.Msg != nil { + rep.replacements[headerReplacer+"rflags}"] = flagsToString(rr.Msg.MsgHdr) + } + } + + // Header placeholders (case-insensitive) + rep.replacements[headerReplacer+"id}"] = strconv.Itoa(int(r.Id)) + rep.replacements[headerReplacer+"opcode}"] = strconv.Itoa(r.Opcode) + rep.replacements[headerReplacer+"do}"] = boolToString(req.Do()) + rep.replacements[headerReplacer+"bufsize}"] = strconv.Itoa(req.Size()) + + return rep +} + +// Replace performs a replacement of values on s and returns +// the string with the replaced values. +func (r replacer) Replace(s string) string { + // Header replacements - these are case-insensitive, so we can't just use strings.Replace() + for strings.Contains(s, headerReplacer) { + idxStart := strings.Index(s, headerReplacer) + endOffset := idxStart + len(headerReplacer) + idxEnd := strings.Index(s[endOffset:], "}") + if idxEnd > -1 { + placeholder := strings.ToLower(s[idxStart : endOffset+idxEnd+1]) + replacement := r.replacements[placeholder] + if replacement == "" { + replacement = r.emptyValue + } + s = s[:idxStart] + replacement + s[endOffset+idxEnd+1:] + } else { + break + } + } + + // Regular replacements - these are easier because they're case-sensitive + for placeholder, replacement := range r.replacements { + if replacement == "" { + replacement = r.emptyValue + } + s = strings.Replace(s, placeholder, replacement, -1) + } + + return s +} + +// Set sets key to value in the replacements map. +func (r replacer) Set(key, value string) { + r.replacements["{"+key+"}"] = value +} + +func boolToString(b bool) string { + if b { + return "true" + } + return "false" +} + +// flagsToString checks all header flags and returns those +// that are set as a string separated with commas +func flagsToString(h dns.MsgHdr) string { + flags := make([]string, 7) + i := 0 + + if h.Response { + flags[i] = "qr" + i++ + } + + if h.Authoritative { + flags[i] = "aa" + i++ + } + if h.Truncated { + flags[i] = "tc" + i++ + } + if h.RecursionDesired { + flags[i] = "rd" + i++ + } + if h.RecursionAvailable { + flags[i] = "ra" + i++ + } + if h.Zero { + flags[i] = "z" + i++ + } + if h.AuthenticatedData { + flags[i] = "ad" + i++ + } + if h.CheckingDisabled { + flags[i] = "cd" + i++ + } + return strings.Join(flags[:i], ",") +} + +const ( + timeFormat = "02/Jan/2006:15:04:05 -0700" + headerReplacer = "{>" +) diff --git a/plugin/pkg/replacer/replacer_test.go b/plugin/pkg/replacer/replacer_test.go new file mode 100644 index 000000000..95c3bbd52 --- /dev/null +++ b/plugin/pkg/replacer/replacer_test.go @@ -0,0 +1,61 @@ +package replacer + +import ( + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnsrecorder" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestNewReplacer(t *testing.T) { + w := dnsrecorder.New(&test.ResponseWriter{}) + + r := new(dns.Msg) + r.SetQuestion("example.org.", dns.TypeHINFO) + r.MsgHdr.AuthenticatedData = true + + replaceValues := New(r, w, "") + + switch v := replaceValues.(type) { + case replacer: + + if v.replacements["{type}"] != "HINFO" { + t.Errorf("Expected type to be HINFO, got %q", v.replacements["{type}"]) + } + if v.replacements["{name}"] != "example.org." { + t.Errorf("Expected request name to be example.org., got %q", v.replacements["{name}"]) + } + if v.replacements["{size}"] != "29" { // size of request + t.Errorf("Expected size to be 29, got %q", v.replacements["{size}"]) + } + + default: + t.Fatal("Return Value from New Replacer expected pass type assertion into a replacer type\n") + } +} + +func TestSet(t *testing.T) { + w := dnsrecorder.New(&test.ResponseWriter{}) + + r := new(dns.Msg) + r.SetQuestion("example.org.", dns.TypeHINFO) + r.MsgHdr.AuthenticatedData = true + + repl := New(r, w, "") + + repl.Set("name", "coredns.io.") + repl.Set("type", "A") + repl.Set("size", "20") + + if repl.Replace("This name is {name}") != "This name is coredns.io." { + t.Error("Expected name replacement failed") + } + if repl.Replace("This type is {type}") != "This type is A" { + t.Error("Expected type replacement failed") + } + if repl.Replace("The request size is {size}") != "The request size is 20" { + t.Error("Expected size replacement failed") + } +} |