aboutsummaryrefslogtreecommitdiff
path: root/plugin/rewrite/ttl.go
diff options
context:
space:
mode:
Diffstat (limited to 'plugin/rewrite/ttl.go')
-rw-r--r--plugin/rewrite/ttl.go75
1 files changed, 54 insertions, 21 deletions
diff --git a/plugin/rewrite/ttl.go b/plugin/rewrite/ttl.go
index 364583dc7..1791301d6 100644
--- a/plugin/rewrite/ttl.go
+++ b/plugin/rewrite/ttl.go
@@ -14,11 +14,16 @@ import (
)
type ttlResponseRule struct {
- TTL uint32
+ minTTL uint32
+ maxTTL uint32
}
func (r *ttlResponseRule) RewriteResponse(rr dns.RR) {
- rr.Header().Ttl = r.TTL
+ if rr.Header().Ttl < r.minTTL {
+ rr.Header().Ttl = r.minTTL
+ } else if rr.Header().Ttl > r.maxTTL {
+ rr.Header().Ttl = r.maxTTL
+ }
}
type ttlRuleBase struct {
@@ -26,10 +31,10 @@ type ttlRuleBase struct {
response ttlResponseRule
}
-func newTTLRuleBase(nextAction string, ttl uint32) ttlRuleBase {
+func newTTLRuleBase(nextAction string, minTtl, maxTtl uint32) ttlRuleBase {
return ttlRuleBase{
nextAction: nextAction,
- response: ttlResponseRule{TTL: ttl},
+ response: ttlResponseRule{minTTL: minTtl, maxTTL: maxTtl},
}
}
@@ -108,7 +113,7 @@ func newTTLRule(nextAction string, args ...string) (Rule, error) {
if len(args) == 3 {
s = args[2]
}
- ttl, valid := isValidTTL(s)
+ minTtl, maxTtl, valid := isValidTTL(s)
if !valid {
return nil, fmt.Errorf("invalid TTL '%s' for a ttl rule", s)
}
@@ -116,22 +121,22 @@ func newTTLRule(nextAction string, args ...string) (Rule, error) {
switch strings.ToLower(args[0]) {
case ExactMatch:
return &exactTTLRule{
- newTTLRuleBase(nextAction, ttl),
+ newTTLRuleBase(nextAction, minTtl, maxTtl),
plugin.Name(args[1]).Normalize(),
}, nil
case PrefixMatch:
return &prefixTTLRule{
- newTTLRuleBase(nextAction, ttl),
+ newTTLRuleBase(nextAction, minTtl, maxTtl),
plugin.Name(args[1]).Normalize(),
}, nil
case SuffixMatch:
return &suffixTTLRule{
- newTTLRuleBase(nextAction, ttl),
+ newTTLRuleBase(nextAction, minTtl, maxTtl),
plugin.Name(args[1]).Normalize(),
}, nil
case SubstringMatch:
return &substringTTLRule{
- newTTLRuleBase(nextAction, ttl),
+ newTTLRuleBase(nextAction, minTtl, maxTtl),
plugin.Name(args[1]).Normalize(),
}, nil
case RegexMatch:
@@ -140,7 +145,7 @@ func newTTLRule(nextAction string, args ...string) (Rule, error) {
return nil, fmt.Errorf("invalid regex pattern in a ttl rule: %s", args[1])
}
return &regexTTLRule{
- newTTLRuleBase(nextAction, ttl),
+ newTTLRuleBase(nextAction, minTtl, maxTtl),
regexPattern,
}, nil
default:
@@ -151,22 +156,50 @@ func newTTLRule(nextAction string, args ...string) (Rule, error) {
return nil, fmt.Errorf("many few arguments for a ttl rule")
}
return &exactTTLRule{
- newTTLRuleBase(nextAction, ttl),
+ newTTLRuleBase(nextAction, minTtl, maxTtl),
plugin.Name(args[0]).Normalize(),
}, nil
}
// validTTL returns true if v is valid TTL value.
-func isValidTTL(v string) (uint32, bool) {
- i, err := strconv.Atoi(v)
- if err != nil {
- return uint32(0), false
- }
- if i > 2147483647 {
- return uint32(0), false
+func isValidTTL(v string) (uint32, uint32, bool) {
+ s := strings.Split(v, "-")
+ if len(s) == 1 {
+ i, err := strconv.ParseUint(s[0], 10, 32)
+ if err != nil {
+ return 0, 0, false
+ }
+ return uint32(i), uint32(i), true
}
- if i < 0 {
- return uint32(0), false
+ if len(s) == 2 {
+ var min, max uint64
+ var err error
+ if s[0] == "" {
+ min = 0
+ } else {
+ min, err = strconv.ParseUint(s[0], 10, 32)
+ if err != nil {
+ return 0, 0, false
+ }
+ }
+ if s[1] == "" {
+ if s[0] == "" {
+ // explicitly reject ttl directive "-" that would otherwise be interpreted
+ // as 0-2147483647 which is pretty useless
+ return 0, 0, false
+ }
+ max = 2147483647
+ } else {
+ max, err = strconv.ParseUint(s[1], 10, 32)
+ if err != nil {
+ return 0, 0, false
+ }
+ }
+ if min > max {
+ // reject invalid range
+ return 0, 0, false
+ }
+ return uint32(min), uint32(max), true
}
- return uint32(i), true
+ return 0, 0, false
}