aboutsummaryrefslogtreecommitdiff
path: root/middleware/cache/setup.go
diff options
context:
space:
mode:
Diffstat (limited to 'middleware/cache/setup.go')
-rw-r--r--middleware/cache/setup.go110
1 files changed, 82 insertions, 28 deletions
diff --git a/middleware/cache/setup.go b/middleware/cache/setup.go
index c0b09024b..de90d3acb 100644
--- a/middleware/cache/setup.go
+++ b/middleware/cache/setup.go
@@ -2,10 +2,12 @@ package cache
import (
"strconv"
+ "time"
"github.com/miekg/coredns/core/dnsserver"
"github.com/miekg/coredns/middleware"
+ "github.com/hashicorp/golang-lru"
"github.com/mholt/caddy"
)
@@ -16,51 +18,103 @@ func init() {
})
}
-// Cache sets up the root file path of the server.
func setup(c *caddy.Controller) error {
- ttl, zones, err := cacheParse(c)
+ ca, err := cacheParse(c)
if err != nil {
return middleware.Error("cache", err)
}
dnsserver.GetConfig(c).AddMiddleware(func(next middleware.Handler) middleware.Handler {
- return NewCache(ttl, zones, next)
+ ca.Next = next
+ return ca
})
return nil
}
-func cacheParse(c *caddy.Controller) (int, []string, error) {
- var (
- err error
- ttl int
- origins []string
- )
+func cacheParse(c *caddy.Controller) (*Cache, error) {
+
+ ca := &Cache{pcap: defaultCap, ncap: defaultCap, pttl: maxTTL, nttl: maxNTTL}
for c.Next() {
- if c.Val() == "cache" {
- // cache [ttl] [zones..]
- origins = make([]string, len(c.ServerBlockKeys))
- copy(origins, c.ServerBlockKeys)
- args := c.RemainingArgs()
+ // cache [ttl] [zones..]
+ origins := make([]string, len(c.ServerBlockKeys))
+ copy(origins, c.ServerBlockKeys)
+ args := c.RemainingArgs()
+
+ if len(args) > 0 {
+ // first args may be just a number, then it is the ttl, if not it is a zone
+ ttl, err := strconv.Atoi(args[0])
+ if err == nil {
+ ca.pttl = time.Duration(ttl) * time.Second
+ ca.nttl = time.Duration(ttl) * time.Second
+ args = args[1:]
+ }
if len(args) > 0 {
- origins = args
- // first args may be just a number, then it is the ttl, if not it is a zone
- t := origins[0]
- ttl, err = strconv.Atoi(t)
- if err == nil {
- origins = origins[1:]
- if len(origins) == 0 {
- // There was *only* the ttl, revert back to server block
- copy(origins, c.ServerBlockKeys)
+ copy(origins, args)
+ }
+ }
+
+ // Refinements? In an extra block.
+ for c.NextBlock() {
+ switch c.Val() {
+ // first number is cap, second is an new ttl
+ case "positive":
+ args := c.RemainingArgs()
+ if len(args) == 0 {
+ return nil, c.ArgErr()
+ }
+ pcap, err := strconv.Atoi(args[0])
+ if err != nil {
+ return nil, err
+ }
+ ca.pcap = pcap
+ if len(args) > 1 {
+ pttl, err := strconv.Atoi(args[1])
+ if err != nil {
+ return nil, err
+ }
+ ca.pttl = time.Duration(pttl) * time.Second
+ }
+ case "negative":
+ args := c.RemainingArgs()
+ if len(args) == 0 {
+ return nil, c.ArgErr()
+ }
+ ncap, err := strconv.Atoi(args[0])
+ if err != nil {
+ return nil, err
+ }
+ ca.ncap = ncap
+ if len(args) > 1 {
+ nttl, err := strconv.Atoi(args[1])
+ if err != nil {
+ return nil, err
}
+ ca.nttl = time.Duration(nttl) * time.Second
}
+ default:
+ return nil, c.ArgErr()
}
+ }
- for i := range origins {
- origins[i] = middleware.Host(origins[i]).Normalize()
- }
- return ttl, origins, nil
+ for i := range origins {
+ origins[i] = middleware.Host(origins[i]).Normalize()
+ }
+
+ var err error
+ ca.Zones = origins
+
+ ca.pcache, err = lru.New(ca.pcap)
+ if err != nil {
+ return nil, err
+ }
+ ca.ncache, err = lru.New(ca.ncap)
+ if err != nil {
+ return nil, err
}
+
+ return ca, nil
}
- return 0, nil, nil
+
+ return nil, nil
}