diff options
Diffstat (limited to 'plugin/rewrite/ttl.go')
-rw-r--r-- | plugin/rewrite/ttl.go | 75 |
1 files changed, 54 insertions, 21 deletions
diff --git a/plugin/rewrite/ttl.go b/plugin/rewrite/ttl.go index 364583dc7..1791301d6 100644 --- a/plugin/rewrite/ttl.go +++ b/plugin/rewrite/ttl.go @@ -14,11 +14,16 @@ import ( ) type ttlResponseRule struct { - TTL uint32 + minTTL uint32 + maxTTL uint32 } func (r *ttlResponseRule) RewriteResponse(rr dns.RR) { - rr.Header().Ttl = r.TTL + if rr.Header().Ttl < r.minTTL { + rr.Header().Ttl = r.minTTL + } else if rr.Header().Ttl > r.maxTTL { + rr.Header().Ttl = r.maxTTL + } } type ttlRuleBase struct { @@ -26,10 +31,10 @@ type ttlRuleBase struct { response ttlResponseRule } -func newTTLRuleBase(nextAction string, ttl uint32) ttlRuleBase { +func newTTLRuleBase(nextAction string, minTtl, maxTtl uint32) ttlRuleBase { return ttlRuleBase{ nextAction: nextAction, - response: ttlResponseRule{TTL: ttl}, + response: ttlResponseRule{minTTL: minTtl, maxTTL: maxTtl}, } } @@ -108,7 +113,7 @@ func newTTLRule(nextAction string, args ...string) (Rule, error) { if len(args) == 3 { s = args[2] } - ttl, valid := isValidTTL(s) + minTtl, maxTtl, valid := isValidTTL(s) if !valid { return nil, fmt.Errorf("invalid TTL '%s' for a ttl rule", s) } @@ -116,22 +121,22 @@ func newTTLRule(nextAction string, args ...string) (Rule, error) { switch strings.ToLower(args[0]) { case ExactMatch: return &exactTTLRule{ - newTTLRuleBase(nextAction, ttl), + newTTLRuleBase(nextAction, minTtl, maxTtl), plugin.Name(args[1]).Normalize(), }, nil case PrefixMatch: return &prefixTTLRule{ - newTTLRuleBase(nextAction, ttl), + newTTLRuleBase(nextAction, minTtl, maxTtl), plugin.Name(args[1]).Normalize(), }, nil case SuffixMatch: return &suffixTTLRule{ - newTTLRuleBase(nextAction, ttl), + newTTLRuleBase(nextAction, minTtl, maxTtl), plugin.Name(args[1]).Normalize(), }, nil case SubstringMatch: return &substringTTLRule{ - newTTLRuleBase(nextAction, ttl), + newTTLRuleBase(nextAction, minTtl, maxTtl), plugin.Name(args[1]).Normalize(), }, nil case RegexMatch: @@ -140,7 +145,7 @@ func newTTLRule(nextAction string, args ...string) (Rule, error) { return nil, fmt.Errorf("invalid regex pattern in a ttl rule: %s", args[1]) } return ®exTTLRule{ - newTTLRuleBase(nextAction, ttl), + newTTLRuleBase(nextAction, minTtl, maxTtl), regexPattern, }, nil default: @@ -151,22 +156,50 @@ func newTTLRule(nextAction string, args ...string) (Rule, error) { return nil, fmt.Errorf("many few arguments for a ttl rule") } return &exactTTLRule{ - newTTLRuleBase(nextAction, ttl), + newTTLRuleBase(nextAction, minTtl, maxTtl), plugin.Name(args[0]).Normalize(), }, nil } // validTTL returns true if v is valid TTL value. -func isValidTTL(v string) (uint32, bool) { - i, err := strconv.Atoi(v) - if err != nil { - return uint32(0), false - } - if i > 2147483647 { - return uint32(0), false +func isValidTTL(v string) (uint32, uint32, bool) { + s := strings.Split(v, "-") + if len(s) == 1 { + i, err := strconv.ParseUint(s[0], 10, 32) + if err != nil { + return 0, 0, false + } + return uint32(i), uint32(i), true } - if i < 0 { - return uint32(0), false + if len(s) == 2 { + var min, max uint64 + var err error + if s[0] == "" { + min = 0 + } else { + min, err = strconv.ParseUint(s[0], 10, 32) + if err != nil { + return 0, 0, false + } + } + if s[1] == "" { + if s[0] == "" { + // explicitly reject ttl directive "-" that would otherwise be interpreted + // as 0-2147483647 which is pretty useless + return 0, 0, false + } + max = 2147483647 + } else { + max, err = strconv.ParseUint(s[1], 10, 32) + if err != nil { + return 0, 0, false + } + } + if min > max { + // reject invalid range + return 0, 0, false + } + return uint32(min), uint32(max), true } - return uint32(i), true + return 0, 0, false } |