diff options
Diffstat (limited to 'plugin')
-rw-r--r-- | plugin/federation/kubernetes_api_test.go | 4 | ||||
-rw-r--r-- | plugin/kubernetes/README.md | 5 | ||||
-rw-r--r-- | plugin/kubernetes/controller.go | 222 | ||||
-rw-r--r-- | plugin/kubernetes/controller_test.go | 53 | ||||
-rw-r--r-- | plugin/kubernetes/handler_test.go | 4 | ||||
-rw-r--r-- | plugin/kubernetes/kubernetes.go | 25 | ||||
-rw-r--r-- | plugin/kubernetes/kubernetes_test.go | 86 | ||||
-rw-r--r-- | plugin/kubernetes/ns_test.go | 5 | ||||
-rw-r--r-- | plugin/kubernetes/reverse_test.go | 4 | ||||
-rw-r--r-- | plugin/kubernetes/watch.go | 20 | ||||
-rw-r--r-- | plugin/kubernetes/watch_test.go | 15 | ||||
-rw-r--r-- | plugin/pkg/watch/watch.go | 23 | ||||
-rw-r--r-- | plugin/pkg/watch/watcher.go | 178 |
13 files changed, 601 insertions, 43 deletions
diff --git a/plugin/federation/kubernetes_api_test.go b/plugin/federation/kubernetes_api_test.go index ee4757d22..b468510a5 100644 --- a/plugin/federation/kubernetes_api_test.go +++ b/plugin/federation/kubernetes_api_test.go @@ -2,6 +2,7 @@ package federation import ( "github.com/coredns/coredns/plugin/kubernetes" + "github.com/coredns/coredns/plugin/pkg/watch" api "k8s.io/api/core/v1" meta "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -15,6 +16,9 @@ func (APIConnFederationTest) Stop() error { return ni func (APIConnFederationTest) SvcIndexReverse(string) []*api.Service { return nil } func (APIConnFederationTest) EpIndexReverse(string) []*api.Endpoints { return nil } func (APIConnFederationTest) Modified() int64 { return 0 } +func (APIConnFederationTest) SetWatchChan(watch.Chan) {} +func (APIConnFederationTest) Watch(string) error { return nil } +func (APIConnFederationTest) StopWatching(string) {} func (APIConnFederationTest) PodIndex(string) []*api.Pod { a := []*api.Pod{{ diff --git a/plugin/kubernetes/README.md b/plugin/kubernetes/README.md index 128e843e2..24965126a 100644 --- a/plugin/kubernetes/README.md +++ b/plugin/kubernetes/README.md @@ -110,6 +110,11 @@ kubernetes [ZONES...] { This plugin implements dynamic health checking. Currently this is limited to reporting healthy when the API has synced. +## Watch + +This plugin implements watch. A client that connects to CoreDNS using `coredns/client` can be notified +of changes to A, AAAA, and SRV records for Kubernetes services and endpoints. + ## Examples Handle all queries in the `cluster.local` zone. Connect to Kubernetes in-cluster. Also handle all diff --git a/plugin/kubernetes/controller.go b/plugin/kubernetes/controller.go index 0d7370a56..286f87d8e 100644 --- a/plugin/kubernetes/controller.go +++ b/plugin/kubernetes/controller.go @@ -7,6 +7,8 @@ import ( "sync/atomic" "time" + dnswatch "github.com/coredns/coredns/plugin/pkg/watch" + api "k8s.io/api/core/v1" "k8s.io/client-go/kubernetes" "k8s.io/client-go/tools/cache" @@ -45,6 +47,11 @@ type dnsController interface { // Modified returns the timestamp of the most recent changes Modified() int64 + + // Watch-related items + SetWatchChan(dnswatch.Chan) + Watch(string) error + StopWatching(string) } type dnsControl struct { @@ -73,6 +80,12 @@ type dnsControl struct { stopLock sync.Mutex shutdown bool stopCh chan struct{} + + // watch-related items channel + watchChan dnswatch.Chan + watched map[string]bool + zones []string + endpointNameMode bool } type dnsControlOpts struct { @@ -83,14 +96,20 @@ type dnsControlOpts struct { // Label handling. labelSelector *meta.LabelSelector selector labels.Selector + + zones []string + endpointNameMode bool } // newDNSController creates a controller for CoreDNS. func newdnsController(kubeClient *kubernetes.Clientset, opts dnsControlOpts) *dnsControl { dns := dnsControl{ - client: kubeClient, - selector: opts.selector, - stopCh: make(chan struct{}), + client: kubeClient, + selector: opts.selector, + stopCh: make(chan struct{}), + watched: make(map[string]bool), + zones: opts.zones, + endpointNameMode: opts.endpointNameMode, } dns.svcLister, dns.svcController = cache.NewIndexerInformer( @@ -292,6 +311,22 @@ func namespaceWatchFunc(c *kubernetes.Clientset, s labels.Selector) func(options } } +func (dns *dnsControl) SetWatchChan(c dnswatch.Chan) { + dns.watchChan = c +} + +func (dns *dnsControl) Watch(qname string) error { + if dns.watchChan == nil { + return fmt.Errorf("cannot start watch because the channel has not been set") + } + dns.watched[qname] = true + return nil +} + +func (dns *dnsControl) StopWatching(qname string) { + delete(dns.watched, qname) +} + // Stop stops the controller. func (dns *dnsControl) Stop() error { dns.stopLock.Lock() @@ -492,63 +527,164 @@ func (dns *dnsControl) updateModifed() { atomic.StoreInt64(&dns.modified, unix) } -func (dns *dnsControl) Add(obj interface{}) { dns.updateModifed() } -func (dns *dnsControl) Delete(obj interface{}) { dns.updateModifed() } +func (dns *dnsControl) sendServiceUpdates(s *api.Service) { + for i := range dns.zones { + name := serviceFQDN(s, dns.zones[i]) + if _, ok := dns.watched[name]; ok { + dns.watchChan <- name + } + } +} + +func (dns *dnsControl) sendPodUpdates(p *api.Pod) { + for i := range dns.zones { + name := podFQDN(p, dns.zones[i]) + if _, ok := dns.watched[name]; ok { + dns.watchChan <- name + } + } +} + +func (dns *dnsControl) sendEndpointsUpdates(ep *api.Endpoints) { + for _, zone := range dns.zones { + names := append(endpointFQDN(ep, zone, dns.endpointNameMode), serviceFQDN(ep, zone)) + for _, name := range names { + if _, ok := dns.watched[name]; ok { + dns.watchChan <- name + } + } + } +} + +// endpointsSubsetDiffs returns an Endpoints struct containing the Subsets that have changed between a and b. +// When we notify clients of changed endpoints we only want to notify them of endpoints that have changed. +// The Endpoints API object holds more than one endpoint, held in a list of Subsets. Each Subset refers to +// an endpoint. So, here we create a new Endpoints struct, and populate it with only the endpoints that have changed. +// This new Endpoints object is later used to generate the list of endpoint FQDNs to send to the client. +// This function computes this literally by combining the sets (in a and not in b) union (in b and not in a). +func endpointsSubsetDiffs(a, b *api.Endpoints) *api.Endpoints { + c := b.DeepCopy() + c.Subsets = []api.EndpointSubset{} + + // In the following loop, the first iteration computes (in a but not in b). + // The second iteration then adds (in b but not in a) + // The end result is an Endpoints that only contains the subsets (endpoints) that are different between a and b. + for _, abba := range [][]*api.Endpoints{{a, b}, {b, a}} { + a := abba[0] + b := abba[1] + left: + for _, as := range a.Subsets { + for _, bs := range b.Subsets { + if subsetsEquivalent(as, bs) { + continue left + } + } + c.Subsets = append(c.Subsets, as) + } + } + return c +} -func (dns *dnsControl) Update(objOld, newObj interface{}) { - // endpoint updates can come frequently, make sure - // it's a change we care about - if o, ok := objOld.(*api.Endpoints); ok { - n := newObj.(*api.Endpoints) - if endpointsEquivalent(o, n) { +// sendUpdates sends a notification to the server if a watch +// is enabled for the qname +func (dns *dnsControl) sendUpdates(oldObj, newObj interface{}) { + // If both objects have the same resource version, they are identical. + if newObj != nil && oldObj != nil && (oldObj.(meta.Object).GetResourceVersion() == newObj.(meta.Object).GetResourceVersion()) { + return + } + obj := newObj + if obj == nil { + obj = oldObj + } + switch ob := obj.(type) { + case *api.Service: + dns.updateModifed() + dns.sendServiceUpdates(ob) + case *api.Endpoints: + if newObj == nil || oldObj == nil { + dns.updateModifed() + dns.sendEndpointsUpdates(ob) + return + } + p := oldObj.(*api.Endpoints) + // endpoint updates can come frequently, make sure it's a change we care about + if endpointsEquivalent(p, ob) { return } + dns.updateModifed() + dns.sendEndpointsUpdates(endpointsSubsetDiffs(p, ob)) + case *api.Pod: + dns.updateModifed() + dns.sendPodUpdates(ob) + default: + log.Warningf("Updates for %T not supported.", ob) } - dns.updateModifed() } -// endpointsEquivalent checks if the update to an endpoint is something -// that matters to us: ready addresses, host names, ports (including names for SRV) -func endpointsEquivalent(a, b *api.Endpoints) bool { - // supposedly we should be able to rely on - // these being sorted and able to be compared - // they are supposed to be in a canonical format +func (dns *dnsControl) Add(obj interface{}) { + dns.sendUpdates(nil, obj) +} +func (dns *dnsControl) Delete(obj interface{}) { + dns.sendUpdates(obj, nil) +} +func (dns *dnsControl) Update(oldObj, newObj interface{}) { + dns.sendUpdates(oldObj, newObj) +} - if len(a.Subsets) != len(b.Subsets) { +// subsetsEquivalent checks if two endpoint subsets are significantly equivalent +// I.e. that they have the same ready addresses, host names, ports (including protocol +// and service names for SRV) +func subsetsEquivalent(sa, sb api.EndpointSubset) bool { + if len(sa.Addresses) != len(sb.Addresses) { + return false + } + if len(sa.Ports) != len(sb.Ports) { return false } - for i, sa := range a.Subsets { - // check the Addresses and Ports. Ignore unready addresses. - sb := b.Subsets[i] - if len(sa.Addresses) != len(sb.Addresses) { + // in Addresses and Ports, we should be able to rely on + // these being sorted and able to be compared + // they are supposed to be in a canonical format + for addr, aaddr := range sa.Addresses { + baddr := sb.Addresses[addr] + if aaddr.IP != baddr.IP { return false } - if len(sa.Ports) != len(sb.Ports) { + if aaddr.Hostname != baddr.Hostname { return false } + } - for addr, aaddr := range sa.Addresses { - baddr := sb.Addresses[addr] - if aaddr.IP != baddr.IP { - return false - } - if aaddr.Hostname != baddr.Hostname { - return false - } + for port, aport := range sa.Ports { + bport := sb.Ports[port] + if aport.Name != bport.Name { + return false + } + if aport.Port != bport.Port { + return false + } + if aport.Protocol != bport.Protocol { + return false } + } + return true +} - for port, aport := range sa.Ports { - bport := sb.Ports[port] - if aport.Name != bport.Name { - return false - } - if aport.Port != bport.Port { - return false - } - if aport.Protocol != bport.Protocol { - return false - } +// endpointsEquivalent checks if the update to an endpoint is something +// that matters to us or if they are effectively equivalent. +func endpointsEquivalent(a, b *api.Endpoints) bool { + + if len(a.Subsets) != len(b.Subsets) { + return false + } + + // we should be able to rely on + // these being sorted and able to be compared + // they are supposed to be in a canonical format + for i, sa := range a.Subsets { + sb := b.Subsets[i] + if !subsetsEquivalent(sa, sb) { + return false } } return true diff --git a/plugin/kubernetes/controller_test.go b/plugin/kubernetes/controller_test.go new file mode 100644 index 000000000..02915fb51 --- /dev/null +++ b/plugin/kubernetes/controller_test.go @@ -0,0 +1,53 @@ +package kubernetes + +import ( + "strconv" + "strings" + "testing" + + api "k8s.io/api/core/v1" +) + +func endpointSubsets(addrs ...string) (eps []api.EndpointSubset) { + for _, ap := range addrs { + apa := strings.Split(ap, ":") + address := apa[0] + port, _ := strconv.Atoi(apa[1]) + eps = append(eps, api.EndpointSubset{Addresses: []api.EndpointAddress{{IP: address}}, Ports: []api.EndpointPort{{Port: int32(port)}}}) + } + return eps +} + +func TestEndpointsSubsetDiffs(t *testing.T) { + var tests = []struct { + a, b, expected api.Endpoints + }{ + { // From a->b: Nothing changes + api.Endpoints{Subsets: endpointSubsets("10.0.0.1:80", "10.0.0.2:8080")}, + api.Endpoints{Subsets: endpointSubsets("10.0.0.1:80", "10.0.0.2:8080")}, + api.Endpoints{}, + }, + { // From a->b: Everything goes away + api.Endpoints{Subsets: endpointSubsets("10.0.0.1:80", "10.0.0.2:8080")}, + api.Endpoints{}, + api.Endpoints{Subsets: endpointSubsets("10.0.0.1:80", "10.0.0.2:8080")}, + }, + { // From a->b: Everything is new + api.Endpoints{}, + api.Endpoints{Subsets: endpointSubsets("10.0.0.1:80", "10.0.0.2:8080")}, + api.Endpoints{Subsets: endpointSubsets("10.0.0.1:80", "10.0.0.2:8080")}, + }, + { // From a->b: One goes away, one is new + api.Endpoints{Subsets: endpointSubsets("10.0.0.2:8080")}, + api.Endpoints{Subsets: endpointSubsets("10.0.0.1:80")}, + api.Endpoints{Subsets: endpointSubsets("10.0.0.2:8080", "10.0.0.1:80")}, + }, + } + + for i, te := range tests { + got := endpointsSubsetDiffs(&te.a, &te.b) + if !endpointsEquivalent(got, &te.expected) { + t.Errorf("Expected '%v' for test %v, got '%v'.", te.expected, i, got) + } + } +} diff --git a/plugin/kubernetes/handler_test.go b/plugin/kubernetes/handler_test.go index 388903137..2edeb8e8e 100644 --- a/plugin/kubernetes/handler_test.go +++ b/plugin/kubernetes/handler_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/watch" "github.com/coredns/coredns/plugin/test" "github.com/miekg/dns" @@ -332,6 +333,9 @@ func (APIConnServeTest) Stop() error { return nil } func (APIConnServeTest) EpIndexReverse(string) []*api.Endpoints { return nil } func (APIConnServeTest) SvcIndexReverse(string) []*api.Service { return nil } func (APIConnServeTest) Modified() int64 { return time.Now().Unix() } +func (APIConnServeTest) SetWatchChan(watch.Chan) {} +func (APIConnServeTest) Watch(string) error { return nil } +func (APIConnServeTest) StopWatching(string) {} func (APIConnServeTest) PodIndex(string) []*api.Pod { a := []*api.Pod{{ diff --git a/plugin/kubernetes/kubernetes.go b/plugin/kubernetes/kubernetes.go index af0e64ee9..03b93748b 100644 --- a/plugin/kubernetes/kubernetes.go +++ b/plugin/kubernetes/kubernetes.go @@ -260,6 +260,8 @@ func (k *Kubernetes) InitKubeCache() (err error) { k.opts.initPodCache = k.podMode == podModeVerified + k.opts.zones = k.Zones + k.opts.endpointNameMode = k.endpointNameMode k.APIConn = newdnsController(kubeClient, k.opts) return err @@ -292,6 +294,29 @@ func (k *Kubernetes) Records(state request.Request, exact bool) ([]msg.Service, return services, err } +// serviceFQDN returns the k8s cluster dns spec service FQDN for the service (or endpoint) object. +func serviceFQDN(obj meta.Object, zone string) string { + return dnsutil.Join(append([]string{}, obj.GetName(), obj.GetNamespace(), Svc, zone)) +} + +// podFQDN returns the k8s cluster dns spec FQDN for the pod. +func podFQDN(p *api.Pod, zone string) string { + name := strings.Replace(p.Status.PodIP, ".", "-", -1) + name = strings.Replace(name, ":", "-", -1) + return dnsutil.Join(append([]string{}, name, p.GetNamespace(), Pod, zone)) +} + +// endpointFQDN returns a list of k8s cluster dns spec service FQDNs for each subset in the endpoint. +func endpointFQDN(ep *api.Endpoints, zone string, endpointNameMode bool) []string { + var names []string + for _, ss := range ep.Subsets { + for _, addr := range ss.Addresses { + names = append(names, dnsutil.Join(append([]string{}, endpointHostname(addr, endpointNameMode), serviceFQDN(ep, zone)))) + } + } + return names +} + func endpointHostname(addr api.EndpointAddress, endpointNameMode bool) string { if addr.Hostname != "" { return strings.ToLower(addr.Hostname) diff --git a/plugin/kubernetes/kubernetes_test.go b/plugin/kubernetes/kubernetes_test.go index e10fe894b..36d00a92f 100644 --- a/plugin/kubernetes/kubernetes_test.go +++ b/plugin/kubernetes/kubernetes_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/watch" "github.com/coredns/coredns/request" "github.com/miekg/dns" @@ -64,6 +65,9 @@ func (APIConnServiceTest) PodIndex(string) []*api.Pod { return nil } func (APIConnServiceTest) SvcIndexReverse(string) []*api.Service { return nil } func (APIConnServiceTest) EpIndexReverse(string) []*api.Endpoints { return nil } func (APIConnServiceTest) Modified() int64 { return 0 } +func (APIConnServiceTest) SetWatchChan(watch.Chan) {} +func (APIConnServiceTest) Watch(string) error { return nil } +func (APIConnServiceTest) StopWatching(string) {} func (APIConnServiceTest) SvcIndex(string) []*api.Service { svcs := []*api.Service{ @@ -390,3 +394,85 @@ func TestServices(t *testing.T) { } } } + +func TestServiceFQDN(t *testing.T) { + fqdn := serviceFQDN( + &api.Service{ + ObjectMeta: meta.ObjectMeta{ + Name: "svc1", + Namespace: "testns", + }, + }, "cluster.local") + + expected := "svc1.testns.svc.cluster.local." + if fqdn != expected { + t.Errorf("Expected '%v', got '%v'.", expected, fqdn) + } +} + +func TestPodFQDN(t *testing.T) { + fqdn := podFQDN( + &api.Pod{ + ObjectMeta: meta.ObjectMeta{ + Name: "pod1", + Namespace: "testns", + }, + Status: api.PodStatus{ + PodIP: "10.10.0.10", + }, + }, "cluster.local") + + expected := "10-10-0-10.testns.pod.cluster.local." + if fqdn != expected { + t.Errorf("Expected '%v', got '%v'.", expected, fqdn) + } + fqdn = podFQDN( + &api.Pod{ + ObjectMeta: meta.ObjectMeta{ + Name: "pod1", + Namespace: "testns", + }, + Status: api.PodStatus{ + PodIP: "aaaa:bbbb:cccc::zzzz", + }, + }, "cluster.local") + + expected = "aaaa-bbbb-cccc--zzzz.testns.pod.cluster.local." + if fqdn != expected { + t.Errorf("Expected '%v', got '%v'.", expected, fqdn) + } +} + +func TestEndpointFQDN(t *testing.T) { + fqdns := endpointFQDN( + &api.Endpoints{ + Subsets: []api.EndpointSubset{ + { + Addresses: []api.EndpointAddress{ + { + IP: "172.0.0.1", + Hostname: "ep1a", + }, + { + IP: "172.0.0.2", + }, + }, + }, + }, + ObjectMeta: meta.ObjectMeta{ + Name: "svc1", + Namespace: "testns", + }, + }, "cluster.local", false) + + expected := []string{ + "ep1a.svc1.testns.svc.cluster.local.", + "172-0-0-2.svc1.testns.svc.cluster.local.", + } + + for i := range fqdns { + if fqdns[i] != expected[i] { + t.Errorf("Expected '%v', got '%v'.", expected[i], fqdns[i]) + } + } +} diff --git a/plugin/kubernetes/ns_test.go b/plugin/kubernetes/ns_test.go index 7dcc83eeb..f331d3231 100644 --- a/plugin/kubernetes/ns_test.go +++ b/plugin/kubernetes/ns_test.go @@ -3,6 +3,8 @@ package kubernetes import ( "testing" + "github.com/coredns/coredns/plugin/pkg/watch" + api "k8s.io/api/core/v1" meta "k8s.io/apimachinery/pkg/apis/meta/v1" ) @@ -18,6 +20,9 @@ func (APIConnTest) SvcIndexReverse(string) []*api.Service { return nil } func (APIConnTest) EpIndex(string) []*api.Endpoints { return nil } func (APIConnTest) EndpointsList() []*api.Endpoints { return nil } func (APIConnTest) Modified() int64 { return 0 } +func (APIConnTest) SetWatchChan(watch.Chan) {} +func (APIConnTest) Watch(string) error { return nil } +func (APIConnTest) StopWatching(string) {} func (APIConnTest) ServiceList() []*api.Service { svcs := []*api.Service{ diff --git a/plugin/kubernetes/reverse_test.go b/plugin/kubernetes/reverse_test.go index 2cf41de1a..681172021 100644 --- a/plugin/kubernetes/reverse_test.go +++ b/plugin/kubernetes/reverse_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/watch" "github.com/coredns/coredns/plugin/test" "github.com/miekg/dns" @@ -22,6 +23,9 @@ func (APIConnReverseTest) EpIndex(string) []*api.Endpoints { return nil } func (APIConnReverseTest) EndpointsList() []*api.Endpoints { return nil } func (APIConnReverseTest) ServiceList() []*api.Service { return nil } func (APIConnReverseTest) Modified() int64 { return 0 } +func (APIConnReverseTest) SetWatchChan(watch.Chan) {} +func (APIConnReverseTest) Watch(string) error { return nil } +func (APIConnReverseTest) StopWatching(string) {} func (APIConnReverseTest) SvcIndex(svc string) []*api.Service { if svc != "svc1.testns" { diff --git a/plugin/kubernetes/watch.go b/plugin/kubernetes/watch.go new file mode 100644 index 000000000..488540444 --- /dev/null +++ b/plugin/kubernetes/watch.go @@ -0,0 +1,20 @@ +package kubernetes + +import ( + "github.com/coredns/coredns/plugin/pkg/watch" +) + +// SetWatchChan implements watch.Watchable +func (k *Kubernetes) SetWatchChan(c watch.Chan) { + k.APIConn.SetWatchChan(c) +} + +// Watch is called when a watch is started for a name. +func (k *Kubernetes) Watch(qname string) error { + return k.APIConn.Watch(qname) +} + +// StopWatching is called when no more watches remain for a name +func (k *Kubernetes) StopWatching(qname string) { + k.APIConn.StopWatching(qname) +} diff --git a/plugin/kubernetes/watch_test.go b/plugin/kubernetes/watch_test.go new file mode 100644 index 000000000..46b2e5dc4 --- /dev/null +++ b/plugin/kubernetes/watch_test.go @@ -0,0 +1,15 @@ +package kubernetes + +import ( + "testing" + + "github.com/coredns/coredns/plugin/pkg/watch" +) + +func TestIsWatchable(t *testing.T) { + k := &Kubernetes{} + var i interface{} = k + if _, ok := i.(watch.Watchable); !ok { + t.Error("Kubernetes should implement watch.Watchable and does not") + } +} diff --git a/plugin/pkg/watch/watch.go b/plugin/pkg/watch/watch.go new file mode 100644 index 000000000..7e77bb7b3 --- /dev/null +++ b/plugin/pkg/watch/watch.go @@ -0,0 +1,23 @@ +package watch + +// Chan is used to inform the server of a change. Whenever +// a watched FQDN has a change in data, that FQDN should be +// sent down this channel. +type Chan chan string + +// Watchable is the interface watchable plugins should implement +type Watchable interface { + // Name returns the plugin name. + Name() string + + // SetWatchChan is called when the watch channel is created. + SetWatchChan(Chan) + + // Watch is called whenever a watch is created for a FQDN. Plugins + // should send the FQDN down the watch channel when its data may have + // changed. This is an exact match only. + Watch(qname string) error + + // StopWatching is called whenever all watches are canceled for a FQDN. + StopWatching(qname string) +} diff --git a/plugin/pkg/watch/watcher.go b/plugin/pkg/watch/watcher.go new file mode 100644 index 000000000..59474a7bc --- /dev/null +++ b/plugin/pkg/watch/watcher.go @@ -0,0 +1,178 @@ +package watch + +import ( + "fmt" + "io" + "sync" + + "github.com/miekg/dns" + + "github.com/coredns/coredns/pb" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/request" +) + +// Watcher handles watch creation, cancellation, and processing. +type Watcher interface { + // Watch monitors a client stream and creates and cancels watches. + Watch(pb.DnsService_WatchServer) error + + // Stop cancels open watches and stops the watch processing go routine. + Stop() +} + +// Manager contains all the data needed to manage watches +type Manager struct { + changes Chan + stopper chan bool + counter int64 + watches map[string]watchlist + plugins []Watchable + mutex sync.Mutex +} + +type watchlist map[int64]pb.DnsService_WatchServer + +// NewWatcher creates a Watcher, which is used to manage watched names. +func NewWatcher(plugins []Watchable) *Manager { + w := &Manager{changes: make(Chan), stopper: make(chan bool), watches: make(map[string]watchlist), plugins: plugins} + + for _, p := range plugins { + p.SetWatchChan(w.changes) + } + + go w.process() + return w +} + +func (w *Manager) nextID() int64 { + w.mutex.Lock() + + w.counter++ + id := w.counter + + w.mutex.Unlock() + return id +} + +// Watch monitors a client stream and creates and cancels watches. +func (w *Manager) Watch(stream pb.DnsService_WatchServer) error { + for { + in, err := stream.Recv() + if err == io.EOF { + return nil + } + if err != nil { + return err + } + create := in.GetCreateRequest() + if create != nil { + msg := new(dns.Msg) + err := msg.Unpack(create.Query.Msg) + if err != nil { + log.Warningf("Could not decode watch request: %s\n", err) + stream.Send(&pb.WatchResponse{Err: "could not decode request"}) + continue + } + id := w.nextID() + if err := stream.Send(&pb.WatchResponse{WatchId: id, Created: true}); err != nil { + // if we fail to notify client of watch creation, don't create the watch + continue + } + + // Normalize qname + qname := (&request.Request{Req: msg}).Name() + + w.mutex.Lock() + if _, ok := w.watches[qname]; !ok { + w.watches[qname] = make(watchlist) + } + w.watches[qname][id] = stream + w.mutex.Unlock() + + for _, p := range w.plugins { + err := p.Watch(qname) + if err != nil { + log.Warningf("Failed to start watch for %s in plugin %s: %s\n", qname, p.Name(), err) + stream.Send(&pb.WatchResponse{Err: fmt.Sprintf("failed to start watch for %s in plugin %s", qname, p.Name())}) + } + } + continue + } + + cancel := in.GetCancelRequest() + if cancel != nil { + w.mutex.Lock() + for qname, wl := range w.watches { + ws, ok := wl[cancel.WatchId] + if !ok { + continue + } + + // only allow cancels from the client that started it + // TODO: test what happens if a stream tries to cancel a watchID that it doesn't own + if ws != stream { + continue + } + + delete(wl, cancel.WatchId) + + // if there are no more watches for this qname, we should tell the plugins + if len(wl) == 0 { + for _, p := range w.plugins { + p.StopWatching(qname) + } + delete(w.watches, qname) + } + + // let the client know we canceled the watch + stream.Send(&pb.WatchResponse{WatchId: cancel.WatchId, Canceled: true}) + } + w.mutex.Unlock() + continue + } + } +} + +func (w *Manager) process() { + for { + select { + case <-w.stopper: + return + case changed := <-w.changes: + w.mutex.Lock() + for qname, wl := range w.watches { + if plugin.Zones([]string{changed}).Matches(qname) == "" { + continue + } + for id, stream := range wl { + wr := pb.WatchResponse{WatchId: id, Qname: qname} + err := stream.Send(&wr) + if err != nil { + log.Warningf("Error sending change for %s to watch %d: %s. Removing watch.\n", qname, id, err) + delete(w.watches[qname], id) + } + } + } + w.mutex.Unlock() + } + } +} + +// Stop cancels open watches and stops the watch processing go routine. +func (w *Manager) Stop() { + w.stopper <- true + w.mutex.Lock() + for wn, wl := range w.watches { + for id, stream := range wl { + wr := pb.WatchResponse{WatchId: id, Canceled: true} + err := stream.Send(&wr) + if err != nil { + log.Warningf("Error notifiying client of cancellation: %s\n", err) + } + } + delete(w.watches, wn) + } + w.mutex.Unlock() +} |