aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/dnsserver/server_grpc.go13
-rw-r--r--core/dnsserver/watch.go18
-rw-r--r--pb/dns.pb.go322
-rw-r--r--pb/dns.proto37
-rw-r--r--plugin/federation/kubernetes_api_test.go4
-rw-r--r--plugin/kubernetes/README.md5
-rw-r--r--plugin/kubernetes/controller.go222
-rw-r--r--plugin/kubernetes/controller_test.go53
-rw-r--r--plugin/kubernetes/handler_test.go4
-rw-r--r--plugin/kubernetes/kubernetes.go25
-rw-r--r--plugin/kubernetes/kubernetes_test.go86
-rw-r--r--plugin/kubernetes/ns_test.go5
-rw-r--r--plugin/kubernetes/reverse_test.go4
-rw-r--r--plugin/kubernetes/watch.go20
-rw-r--r--plugin/kubernetes/watch_test.go15
-rw-r--r--plugin/pkg/watch/watch.go23
-rw-r--r--plugin/pkg/watch/watcher.go178
17 files changed, 977 insertions, 57 deletions
diff --git a/core/dnsserver/server_grpc.go b/core/dnsserver/server_grpc.go
index db81b317f..e5b87749d 100644
--- a/core/dnsserver/server_grpc.go
+++ b/core/dnsserver/server_grpc.go
@@ -8,6 +8,7 @@ import (
"net"
"github.com/coredns/coredns/pb"
+ "github.com/coredns/coredns/plugin/pkg/watch"
"github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc"
"github.com/miekg/dns"
@@ -22,6 +23,7 @@ type ServergRPC struct {
grpcServer *grpc.Server
listenAddr net.Addr
tlsConfig *tls.Config
+ watch watch.Watcher
}
// NewServergRPC returns a new CoreDNS GRPC server and compiles all plugin in to it.
@@ -38,7 +40,7 @@ func NewServergRPC(addr string, group []*Config) (*ServergRPC, error) {
tlsConfig = conf.TLSConfig
}
- return &ServergRPC{Server: s, tlsConfig: tlsConfig}, nil
+ return &ServergRPC{Server: s, tlsConfig: tlsConfig, watch: watch.NewWatcher(watchables(s.zones))}, nil
}
// Serve implements caddy.TCPServer interface.
@@ -100,6 +102,9 @@ func (s *ServergRPC) OnStartupComplete() {
func (s *ServergRPC) Stop() (err error) {
s.m.Lock()
defer s.m.Unlock()
+ if s.watch != nil {
+ s.watch.Stop()
+ }
if s.grpcServer != nil {
s.grpcServer.GracefulStop()
}
@@ -138,6 +143,12 @@ func (s *ServergRPC) Query(ctx context.Context, in *pb.DnsPacket) (*pb.DnsPacket
return &pb.DnsPacket{Msg: packed}, nil
}
+// Watch is the entrypoint called by the gRPC layer when the user asks
+// to watch a query.
+func (s *ServergRPC) Watch(stream pb.DnsService_WatchServer) error {
+ return s.watch.Watch(stream)
+}
+
// Shutdown stops the server (non gracefully).
func (s *ServergRPC) Shutdown() error {
if s.grpcServer != nil {
diff --git a/core/dnsserver/watch.go b/core/dnsserver/watch.go
new file mode 100644
index 000000000..590bac144
--- /dev/null
+++ b/core/dnsserver/watch.go
@@ -0,0 +1,18 @@
+package dnsserver
+
+import (
+ "github.com/coredns/coredns/plugin/pkg/watch"
+)
+
+func watchables(zones map[string]*Config) []watch.Watchable {
+ var w []watch.Watchable
+ for _, config := range zones {
+ plugins := config.Handlers()
+ for _, p := range plugins {
+ if x, ok := p.(watch.Watchable); ok {
+ w = append(w, x)
+ }
+ }
+ }
+ return w
+}
diff --git a/pb/dns.pb.go b/pb/dns.pb.go
index 0c75de94a..d79e24f6d 100644
--- a/pb/dns.pb.go
+++ b/pb/dns.pb.go
@@ -1,6 +1,5 @@
-// Code generated by protoc-gen-go.
+// Code generated by protoc-gen-go. DO NOT EDIT.
// source: dns.proto
-// DO NOT EDIT!
/*
Package pb is a generated protocol buffer package.
@@ -10,6 +9,10 @@ It is generated from these files:
It has these top-level messages:
DnsPacket
+ WatchRequest
+ WatchCreateRequest
+ WatchCancelRequest
+ WatchResponse
*/
package pb
@@ -19,7 +22,6 @@ import math "math"
import (
context "context"
-
grpc "google.golang.org/grpc"
)
@@ -50,8 +52,223 @@ func (m *DnsPacket) GetMsg() []byte {
return nil
}
+type WatchRequest struct {
+ // request_union is a request to either create a new watcher or cancel an existing watcher.
+ //
+ // Types that are valid to be assigned to RequestUnion:
+ // *WatchRequest_CreateRequest
+ // *WatchRequest_CancelRequest
+ RequestUnion isWatchRequest_RequestUnion `protobuf_oneof:"request_union"`
+}
+
+func (m *WatchRequest) Reset() { *m = WatchRequest{} }
+func (m *WatchRequest) String() string { return proto.CompactTextString(m) }
+func (*WatchRequest) ProtoMessage() {}
+func (*WatchRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
+
+type isWatchRequest_RequestUnion interface {
+ isWatchRequest_RequestUnion()
+}
+
+type WatchRequest_CreateRequest struct {
+ CreateRequest *WatchCreateRequest `protobuf:"bytes,1,opt,name=create_request,json=createRequest,oneof"`
+}
+type WatchRequest_CancelRequest struct {
+ CancelRequest *WatchCancelRequest `protobuf:"bytes,2,opt,name=cancel_request,json=cancelRequest,oneof"`
+}
+
+func (*WatchRequest_CreateRequest) isWatchRequest_RequestUnion() {}
+func (*WatchRequest_CancelRequest) isWatchRequest_RequestUnion() {}
+
+func (m *WatchRequest) GetRequestUnion() isWatchRequest_RequestUnion {
+ if m != nil {
+ return m.RequestUnion
+ }
+ return nil
+}
+
+func (m *WatchRequest) GetCreateRequest() *WatchCreateRequest {
+ if x, ok := m.GetRequestUnion().(*WatchRequest_CreateRequest); ok {
+ return x.CreateRequest
+ }
+ return nil
+}
+
+func (m *WatchRequest) GetCancelRequest() *WatchCancelRequest {
+ if x, ok := m.GetRequestUnion().(*WatchRequest_CancelRequest); ok {
+ return x.CancelRequest
+ }
+ return nil
+}
+
+// XXX_OneofFuncs is for the internal use of the proto package.
+func (*WatchRequest) XXX_OneofFuncs() (func(msg proto.Message, b *proto.Buffer) error, func(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error), func(msg proto.Message) (n int), []interface{}) {
+ return _WatchRequest_OneofMarshaler, _WatchRequest_OneofUnmarshaler, _WatchRequest_OneofSizer, []interface{}{
+ (*WatchRequest_CreateRequest)(nil),
+ (*WatchRequest_CancelRequest)(nil),
+ }
+}
+
+func _WatchRequest_OneofMarshaler(msg proto.Message, b *proto.Buffer) error {
+ m := msg.(*WatchRequest)
+ // request_union
+ switch x := m.RequestUnion.(type) {
+ case *WatchRequest_CreateRequest:
+ b.EncodeVarint(1<<3 | proto.WireBytes)
+ if err := b.EncodeMessage(x.CreateRequest); err != nil {
+ return err
+ }
+ case *WatchRequest_CancelRequest:
+ b.EncodeVarint(2<<3 | proto.WireBytes)
+ if err := b.EncodeMessage(x.CancelRequest); err != nil {
+ return err
+ }
+ case nil:
+ default:
+ return fmt.Errorf("WatchRequest.RequestUnion has unexpected type %T", x)
+ }
+ return nil
+}
+
+func _WatchRequest_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error) {
+ m := msg.(*WatchRequest)
+ switch tag {
+ case 1: // request_union.create_request
+ if wire != proto.WireBytes {
+ return true, proto.ErrInternalBadWireType
+ }
+ msg := new(WatchCreateRequest)
+ err := b.DecodeMessage(msg)
+ m.RequestUnion = &WatchRequest_CreateRequest{msg}
+ return true, err
+ case 2: // request_union.cancel_request
+ if wire != proto.WireBytes {
+ return true, proto.ErrInternalBadWireType
+ }
+ msg := new(WatchCancelRequest)
+ err := b.DecodeMessage(msg)
+ m.RequestUnion = &WatchRequest_CancelRequest{msg}
+ return true, err
+ default:
+ return false, nil
+ }
+}
+
+func _WatchRequest_OneofSizer(msg proto.Message) (n int) {
+ m := msg.(*WatchRequest)
+ // request_union
+ switch x := m.RequestUnion.(type) {
+ case *WatchRequest_CreateRequest:
+ s := proto.Size(x.CreateRequest)
+ n += proto.SizeVarint(1<<3 | proto.WireBytes)
+ n += proto.SizeVarint(uint64(s))
+ n += s
+ case *WatchRequest_CancelRequest:
+ s := proto.Size(x.CancelRequest)
+ n += proto.SizeVarint(2<<3 | proto.WireBytes)
+ n += proto.SizeVarint(uint64(s))
+ n += s
+ case nil:
+ default:
+ panic(fmt.Sprintf("proto: unexpected type %T in oneof", x))
+ }
+ return n
+}
+
+type WatchCreateRequest struct {
+ Query *DnsPacket `protobuf:"bytes,1,opt,name=query" json:"query,omitempty"`
+}
+
+func (m *WatchCreateRequest) Reset() { *m = WatchCreateRequest{} }
+func (m *WatchCreateRequest) String() string { return proto.CompactTextString(m) }
+func (*WatchCreateRequest) ProtoMessage() {}
+func (*WatchCreateRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
+
+func (m *WatchCreateRequest) GetQuery() *DnsPacket {
+ if m != nil {
+ return m.Query
+ }
+ return nil
+}
+
+type WatchCancelRequest struct {
+ // watch_id is the watcher id to cancel
+ WatchId int64 `protobuf:"varint,1,opt,name=watch_id,json=watchId" json:"watch_id,omitempty"`
+}
+
+func (m *WatchCancelRequest) Reset() { *m = WatchCancelRequest{} }
+func (m *WatchCancelRequest) String() string { return proto.CompactTextString(m) }
+func (*WatchCancelRequest) ProtoMessage() {}
+func (*WatchCancelRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }
+
+func (m *WatchCancelRequest) GetWatchId() int64 {
+ if m != nil {
+ return m.WatchId
+ }
+ return 0
+}
+
+type WatchResponse struct {
+ // watch_id is the ID of the watcher that corresponds to the response.
+ WatchId int64 `protobuf:"varint,1,opt,name=watch_id,json=watchId" json:"watch_id,omitempty"`
+ // created is set to true if the response is for a create watch request.
+ // The client should record the watch_id and expect to receive DNS replies
+ // from the same stream.
+ // All replies sent to the created watcher will attach with the same watch_id.
+ Created bool `protobuf:"varint,2,opt,name=created" json:"created,omitempty"`
+ // canceled is set to true if the response is for a cancel watch request.
+ // No further events will be sent to the canceled watcher.
+ Canceled bool `protobuf:"varint,3,opt,name=canceled" json:"canceled,omitempty"`
+ Qname string `protobuf:"bytes,4,opt,name=qname" json:"qname,omitempty"`
+ Err string `protobuf:"bytes,5,opt,name=err" json:"err,omitempty"`
+}
+
+func (m *WatchResponse) Reset() { *m = WatchResponse{} }
+func (m *WatchResponse) String() string { return proto.CompactTextString(m) }
+func (*WatchResponse) ProtoMessage() {}
+func (*WatchResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} }
+
+func (m *WatchResponse) GetWatchId() int64 {
+ if m != nil {
+ return m.WatchId
+ }
+ return 0
+}
+
+func (m *WatchResponse) GetCreated() bool {
+ if m != nil {
+ return m.Created
+ }
+ return false
+}
+
+func (m *WatchResponse) GetCanceled() bool {
+ if m != nil {
+ return m.Canceled
+ }
+ return false
+}
+
+func (m *WatchResponse) GetQname() string {
+ if m != nil {
+ return m.Qname
+ }
+ return ""
+}
+
+func (m *WatchResponse) GetErr() string {
+ if m != nil {
+ return m.Err
+ }
+ return ""
+}
+
func init() {
proto.RegisterType((*DnsPacket)(nil), "coredns.dns.DnsPacket")
+ proto.RegisterType((*WatchRequest)(nil), "coredns.dns.WatchRequest")
+ proto.RegisterType((*WatchCreateRequest)(nil), "coredns.dns.WatchCreateRequest")
+ proto.RegisterType((*WatchCancelRequest)(nil), "coredns.dns.WatchCancelRequest")
+ proto.RegisterType((*WatchResponse)(nil), "coredns.dns.WatchResponse")
}
// Reference imports to suppress errors if they are not otherwise used.
@@ -66,6 +283,7 @@ const _ = grpc.SupportPackageIsVersion4
type DnsServiceClient interface {
Query(ctx context.Context, in *DnsPacket, opts ...grpc.CallOption) (*DnsPacket, error)
+ Watch(ctx context.Context, opts ...grpc.CallOption) (DnsService_WatchClient, error)
}
type dnsServiceClient struct {
@@ -85,10 +303,42 @@ func (c *dnsServiceClient) Query(ctx context.Context, in *DnsPacket, opts ...grp
return out, nil
}
+func (c *dnsServiceClient) Watch(ctx context.Context, opts ...grpc.CallOption) (DnsService_WatchClient, error) {
+ stream, err := grpc.NewClientStream(ctx, &_DnsService_serviceDesc.Streams[0], c.cc, "/coredns.dns.DnsService/Watch", opts...)
+ if err != nil {
+ return nil, err
+ }
+ x := &dnsServiceWatchClient{stream}
+ return x, nil
+}
+
+type DnsService_WatchClient interface {
+ Send(*WatchRequest) error
+ Recv() (*WatchResponse, error)
+ grpc.ClientStream
+}
+
+type dnsServiceWatchClient struct {
+ grpc.ClientStream
+}
+
+func (x *dnsServiceWatchClient) Send(m *WatchRequest) error {
+ return x.ClientStream.SendMsg(m)
+}
+
+func (x *dnsServiceWatchClient) Recv() (*WatchResponse, error) {
+ m := new(WatchResponse)
+ if err := x.ClientStream.RecvMsg(m); err != nil {
+ return nil, err
+ }
+ return m, nil
+}
+
// Server API for DnsService service
type DnsServiceServer interface {
Query(context.Context, *DnsPacket) (*DnsPacket, error)
+ Watch(DnsService_WatchServer) error
}
func RegisterDnsServiceServer(s *grpc.Server, srv DnsServiceServer) {
@@ -113,6 +363,32 @@ func _DnsService_Query_Handler(srv interface{}, ctx context.Context, dec func(in
return interceptor(ctx, in, info, handler)
}
+func _DnsService_Watch_Handler(srv interface{}, stream grpc.ServerStream) error {
+ return srv.(DnsServiceServer).Watch(&dnsServiceWatchServer{stream})
+}
+
+type DnsService_WatchServer interface {
+ Send(*WatchResponse) error
+ Recv() (*WatchRequest, error)
+ grpc.ServerStream
+}
+
+type dnsServiceWatchServer struct {
+ grpc.ServerStream
+}
+
+func (x *dnsServiceWatchServer) Send(m *WatchResponse) error {
+ return x.ServerStream.SendMsg(m)
+}
+
+func (x *dnsServiceWatchServer) Recv() (*WatchRequest, error) {
+ m := new(WatchRequest)
+ if err := x.ServerStream.RecvMsg(m); err != nil {
+ return nil, err
+ }
+ return m, nil
+}
+
var _DnsService_serviceDesc = grpc.ServiceDesc{
ServiceName: "coredns.dns.DnsService",
HandlerType: (*DnsServiceServer)(nil),
@@ -122,20 +398,40 @@ var _DnsService_serviceDesc = grpc.ServiceDesc{
Handler: _DnsService_Query_Handler,
},
},
- Streams: []grpc.StreamDesc{},
+ Streams: []grpc.StreamDesc{
+ {
+ StreamName: "Watch",
+ Handler: _DnsService_Watch_Handler,
+ ServerStreams: true,
+ ClientStreams: true,
+ },
+ },
Metadata: "dns.proto",
}
func init() { proto.RegisterFile("dns.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{
- // 120 bytes of a gzipped FileDescriptorProto
- 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x4c, 0xc9, 0x2b, 0xd6,
- 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x4e, 0xce, 0x2f, 0x4a, 0x05, 0x71, 0x53, 0xf2, 0x8a,
- 0x95, 0x64, 0xb9, 0x38, 0x5d, 0xf2, 0x8a, 0x03, 0x12, 0x93, 0xb3, 0x53, 0x4b, 0x84, 0x04, 0xb8,
- 0x98, 0x73, 0x8b, 0xd3, 0x25, 0x18, 0x15, 0x18, 0x35, 0x78, 0x82, 0x40, 0x4c, 0x23, 0x57, 0x2e,
- 0x2e, 0x97, 0xbc, 0xe2, 0xe0, 0xd4, 0xa2, 0xb2, 0xcc, 0xe4, 0x54, 0x21, 0x73, 0x2e, 0xd6, 0xc0,
- 0xd2, 0xd4, 0xa2, 0x4a, 0x21, 0x31, 0x3d, 0x24, 0x33, 0xf4, 0xe0, 0x06, 0x48, 0xe1, 0x10, 0x77,
- 0x62, 0x89, 0x62, 0x2a, 0x48, 0x4a, 0x62, 0x03, 0xdb, 0x6f, 0x0c, 0x08, 0x00, 0x00, 0xff, 0xff,
- 0xf5, 0xd1, 0x3f, 0x26, 0x8c, 0x00, 0x00, 0x00,
+ // 333 bytes of a gzipped FileDescriptorProto
+ 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x7c, 0x52, 0x41, 0x4f, 0x32, 0x31,
+ 0x14, 0xfc, 0x0a, 0xec, 0x07, 0x3c, 0x40, 0xcd, 0x8b, 0x31, 0xcb, 0x26, 0x46, 0xb2, 0x27, 0x0e,
+ 0x06, 0x0d, 0x1e, 0xbc, 0xaf, 0x1c, 0xf0, 0xa6, 0xf5, 0x60, 0xe2, 0x85, 0x2c, 0xdd, 0x17, 0x25,
+ 0x4a, 0x17, 0xda, 0x45, 0xe3, 0x3f, 0xd0, 0xdf, 0xe3, 0x1f, 0x34, 0xfb, 0xba, 0x90, 0x35, 0x88,
+ 0xb7, 0xce, 0x74, 0x3a, 0xed, 0xcc, 0x2b, 0x34, 0x13, 0x6d, 0x07, 0x0b, 0x93, 0x66, 0x29, 0xb6,
+ 0x54, 0x6a, 0x28, 0x87, 0x89, 0xb6, 0xe1, 0x31, 0x34, 0x47, 0xda, 0xde, 0xc4, 0xea, 0x99, 0x32,
+ 0x3c, 0x80, 0xea, 0xdc, 0x3e, 0xfa, 0xa2, 0x27, 0xfa, 0x6d, 0x99, 0x2f, 0xc3, 0x2f, 0x01, 0xed,
+ 0xfb, 0x38, 0x53, 0x4f, 0x92, 0x96, 0x2b, 0xb2, 0x19, 0x8e, 0x61, 0x4f, 0x19, 0x8a, 0x33, 0x9a,
+ 0x18, 0xc7, 0xb0, 0xba, 0x35, 0x3c, 0x19, 0x94, 0x5c, 0x07, 0x7c, 0xe4, 0x8a, 0x75, 0xc5, 0xc1,
+ 0xf1, 0x3f, 0xd9, 0x51, 0x65, 0x82, 0x9d, 0x62, 0xad, 0xe8, 0x65, 0xe3, 0x54, 0xd9, 0xe9, 0xc4,
+ 0xba, 0xb2, 0x53, 0x99, 0x88, 0xf6, 0xa1, 0x53, 0x58, 0x4c, 0x56, 0x7a, 0x96, 0xea, 0x30, 0x02,
+ 0xdc, 0x7e, 0x01, 0x9e, 0x82, 0xb7, 0x5c, 0x91, 0x79, 0x2f, 0x5e, 0x7c, 0xf4, 0xe3, 0x9e, 0x4d,
+ 0x09, 0xd2, 0x89, 0xc2, 0xb3, 0xb5, 0x47, 0xf9, 0x2a, 0xec, 0x42, 0xe3, 0x2d, 0x67, 0x27, 0xb3,
+ 0x84, 0x6d, 0xaa, 0xb2, 0xce, 0xf8, 0x3a, 0x09, 0x3f, 0x04, 0x74, 0x8a, 0xaa, 0xec, 0x22, 0xd5,
+ 0x96, 0xfe, 0x10, 0xa3, 0x0f, 0x75, 0xd7, 0x46, 0xc2, 0xa9, 0x1b, 0x72, 0x0d, 0x31, 0x80, 0x86,
+ 0x4b, 0x47, 0x89, 0x5f, 0xe5, 0xad, 0x0d, 0xc6, 0x43, 0xf0, 0x96, 0x3a, 0x9e, 0x93, 0x5f, 0xeb,
+ 0x89, 0x7e, 0x53, 0x3a, 0x90, 0x4f, 0x8d, 0x8c, 0xf1, 0x3d, 0xe6, 0xf2, 0xe5, 0xf0, 0x53, 0x00,
+ 0x8c, 0xb4, 0xbd, 0x23, 0xf3, 0x3a, 0x53, 0x84, 0x97, 0xe0, 0xdd, 0xe6, 0x99, 0x70, 0x47, 0xe4,
+ 0x60, 0x07, 0x8f, 0x11, 0x78, 0x9c, 0x08, 0xbb, 0xdb, 0x33, 0x29, 0x1a, 0x09, 0x82, 0xdf, 0xb6,
+ 0x5c, 0x01, 0x7d, 0x71, 0x2e, 0xa2, 0xda, 0x43, 0x65, 0x31, 0x9d, 0xfe, 0xe7, 0xaf, 0x77, 0xf1,
+ 0x1d, 0x00, 0x00, 0xff, 0xff, 0xd2, 0x5b, 0x8c, 0xe1, 0x87, 0x02, 0x00, 0x00,
}
diff --git a/pb/dns.proto b/pb/dns.proto
index 8461f01e6..e4ac2eb2c 100644
--- a/pb/dns.proto
+++ b/pb/dns.proto
@@ -9,4 +9,41 @@ message DnsPacket {
service DnsService {
rpc Query (DnsPacket) returns (DnsPacket);
+ rpc Watch (stream WatchRequest) returns (stream WatchResponse);
+}
+
+message WatchRequest {
+ // request_union is a request to either create a new watcher or cancel an existing watcher.
+ oneof request_union {
+ WatchCreateRequest create_request = 1;
+ WatchCancelRequest cancel_request = 2;
+ }
+}
+
+message WatchCreateRequest {
+ DnsPacket query = 1;
+}
+
+message WatchCancelRequest {
+ // watch_id is the watcher id to cancel
+ int64 watch_id = 1;
+}
+
+message WatchResponse {
+ // watch_id is the ID of the watcher that corresponds to the response.
+ int64 watch_id = 1;
+
+ // created is set to true if the response is for a create watch request.
+ // The client should record the watch_id and expect to receive DNS replies
+ // from the same stream.
+ // All replies sent to the created watcher will attach with the same watch_id.
+ bool created = 2;
+
+ // canceled is set to true if the response is for a cancel watch request.
+ // No further events will be sent to the canceled watcher.
+ bool canceled = 3;
+
+ string qname = 4;
+
+ string err = 5;
}
diff --git a/plugin/federation/kubernetes_api_test.go b/plugin/federation/kubernetes_api_test.go
index ee4757d22..b468510a5 100644
--- a/plugin/federation/kubernetes_api_test.go
+++ b/plugin/federation/kubernetes_api_test.go
@@ -2,6 +2,7 @@ package federation
import (
"github.com/coredns/coredns/plugin/kubernetes"
+ "github.com/coredns/coredns/plugin/pkg/watch"
api "k8s.io/api/core/v1"
meta "k8s.io/apimachinery/pkg/apis/meta/v1"
@@ -15,6 +16,9 @@ func (APIConnFederationTest) Stop() error { return ni
func (APIConnFederationTest) SvcIndexReverse(string) []*api.Service { return nil }
func (APIConnFederationTest) EpIndexReverse(string) []*api.Endpoints { return nil }
func (APIConnFederationTest) Modified() int64 { return 0 }
+func (APIConnFederationTest) SetWatchChan(watch.Chan) {}
+func (APIConnFederationTest) Watch(string) error { return nil }
+func (APIConnFederationTest) StopWatching(string) {}
func (APIConnFederationTest) PodIndex(string) []*api.Pod {
a := []*api.Pod{{
diff --git a/plugin/kubernetes/README.md b/plugin/kubernetes/README.md
index 128e843e2..24965126a 100644
--- a/plugin/kubernetes/README.md
+++ b/plugin/kubernetes/README.md
@@ -110,6 +110,11 @@ kubernetes [ZONES...] {
This plugin implements dynamic health checking. Currently this is limited to reporting healthy when
the API has synced.
+## Watch
+
+This plugin implements watch. A client that connects to CoreDNS using `coredns/client` can be notified
+of changes to A, AAAA, and SRV records for Kubernetes services and endpoints.
+
## Examples
Handle all queries in the `cluster.local` zone. Connect to Kubernetes in-cluster. Also handle all
diff --git a/plugin/kubernetes/controller.go b/plugin/kubernetes/controller.go
index 0d7370a56..286f87d8e 100644
--- a/plugin/kubernetes/controller.go
+++ b/plugin/kubernetes/controller.go
@@ -7,6 +7,8 @@ import (
"sync/atomic"
"time"
+ dnswatch "github.com/coredns/coredns/plugin/pkg/watch"
+
api "k8s.io/api/core/v1"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/tools/cache"
@@ -45,6 +47,11 @@ type dnsController interface {
// Modified returns the timestamp of the most recent changes
Modified() int64
+
+ // Watch-related items
+ SetWatchChan(dnswatch.Chan)
+ Watch(string) error
+ StopWatching(string)
}
type dnsControl struct {
@@ -73,6 +80,12 @@ type dnsControl struct {
stopLock sync.Mutex
shutdown bool
stopCh chan struct{}
+
+ // watch-related items channel
+ watchChan dnswatch.Chan
+ watched map[string]bool
+ zones []string
+ endpointNameMode bool
}
type dnsControlOpts struct {
@@ -83,14 +96,20 @@ type dnsControlOpts struct {
// Label handling.
labelSelector *meta.LabelSelector
selector labels.Selector
+
+ zones []string
+ endpointNameMode bool
}
// newDNSController creates a controller for CoreDNS.
func newdnsController(kubeClient *kubernetes.Clientset, opts dnsControlOpts) *dnsControl {
dns := dnsControl{
- client: kubeClient,
- selector: opts.selector,
- stopCh: make(chan struct{}),
+ client: kubeClient,
+ selector: opts.selector,
+ stopCh: make(chan struct{}),
+ watched: make(map[string]bool),
+ zones: opts.zones,
+ endpointNameMode: opts.endpointNameMode,
}
dns.svcLister, dns.svcController = cache.NewIndexerInformer(
@@ -292,6 +311,22 @@ func namespaceWatchFunc(c *kubernetes.Clientset, s labels.Selector) func(options
}
}
+func (dns *dnsControl) SetWatchChan(c dnswatch.Chan) {
+ dns.watchChan = c
+}
+
+func (dns *dnsControl) Watch(qname string) error {
+ if dns.watchChan == nil {
+ return fmt.Errorf("cannot start watch because the channel has not been set")
+ }
+ dns.watched[qname] = true
+ return nil
+}
+
+func (dns *dnsControl) StopWatching(qname string) {
+ delete(dns.watched, qname)
+}
+
// Stop stops the controller.
func (dns *dnsControl) Stop() error {
dns.stopLock.Lock()
@@ -492,63 +527,164 @@ func (dns *dnsControl) updateModifed() {
atomic.StoreInt64(&dns.modified, unix)
}
-func (dns *dnsControl) Add(obj interface{}) { dns.updateModifed() }
-func (dns *dnsControl) Delete(obj interface{}) { dns.updateModifed() }
+func (dns *dnsControl) sendServiceUpdates(s *api.Service) {
+ for i := range dns.zones {
+ name := serviceFQDN(s, dns.zones[i])
+ if _, ok := dns.watched[name]; ok {
+ dns.watchChan <- name
+ }
+ }
+}
+
+func (dns *dnsControl) sendPodUpdates(p *api.Pod) {
+ for i := range dns.zones {
+ name := podFQDN(p, dns.zones[i])
+ if _, ok := dns.watched[name]; ok {
+ dns.watchChan <- name
+ }
+ }
+}
+
+func (dns *dnsControl) sendEndpointsUpdates(ep *api.Endpoints) {
+ for _, zone := range dns.zones {
+ names := append(endpointFQDN(ep, zone, dns.endpointNameMode), serviceFQDN(ep, zone))
+ for _, name := range names {
+ if _, ok := dns.watched[name]; ok {
+ dns.watchChan <- name
+ }
+ }
+ }
+}
+
+// endpointsSubsetDiffs returns an Endpoints struct containing the Subsets that have changed between a and b.
+// When we notify clients of changed endpoints we only want to notify them of endpoints that have changed.
+// The Endpoints API object holds more than one endpoint, held in a list of Subsets. Each Subset refers to
+// an endpoint. So, here we create a new Endpoints struct, and populate it with only the endpoints that have changed.
+// This new Endpoints object is later used to generate the list of endpoint FQDNs to send to the client.
+// This function computes this literally by combining the sets (in a and not in b) union (in b and not in a).
+func endpointsSubsetDiffs(a, b *api.Endpoints) *api.Endpoints {
+ c := b.DeepCopy()
+ c.Subsets = []api.EndpointSubset{}
+
+ // In the following loop, the first iteration computes (in a but not in b).
+ // The second iteration then adds (in b but not in a)
+ // The end result is an Endpoints that only contains the subsets (endpoints) that are different between a and b.
+ for _, abba := range [][]*api.Endpoints{{a, b}, {b, a}} {
+ a := abba[0]
+ b := abba[1]
+ left:
+ for _, as := range a.Subsets {
+ for _, bs := range b.Subsets {
+ if subsetsEquivalent(as, bs) {
+ continue left
+ }
+ }
+ c.Subsets = append(c.Subsets, as)
+ }
+ }
+ return c
+}
-func (dns *dnsControl) Update(objOld, newObj interface{}) {
- // endpoint updates can come frequently, make sure
- // it's a change we care about
- if o, ok := objOld.(*api.Endpoints); ok {
- n := newObj.(*api.Endpoints)
- if endpointsEquivalent(o, n) {
+// sendUpdates sends a notification to the server if a watch
+// is enabled for the qname
+func (dns *dnsControl) sendUpdates(oldObj, newObj interface{}) {
+ // If both objects have the same resource version, they are identical.
+ if newObj != nil && oldObj != nil && (oldObj.(meta.Object).GetResourceVersion() == newObj.(meta.Object).GetResourceVersion()) {
+ return
+ }
+ obj := newObj
+ if obj == nil {
+ obj = oldObj
+ }
+ switch ob := obj.(type) {
+ case *api.Service:
+ dns.updateModifed()
+ dns.sendServiceUpdates(ob)
+ case *api.Endpoints:
+ if newObj == nil || oldObj == nil {
+ dns.updateModifed()
+ dns.sendEndpointsUpdates(ob)
+ return
+ }
+ p := oldObj.(*api.Endpoints)
+ // endpoint updates can come frequently, make sure it's a change we care about
+ if endpointsEquivalent(p, ob) {
return
}
+ dns.updateModifed()
+ dns.sendEndpointsUpdates(endpointsSubsetDiffs(p, ob))
+ case *api.Pod:
+ dns.updateModifed()
+ dns.sendPodUpdates(ob)
+ default:
+ log.Warningf("Updates for %T not supported.", ob)
}
- dns.updateModifed()
}
-// endpointsEquivalent checks if the update to an endpoint is something
-// that matters to us: ready addresses, host names, ports (including names for SRV)
-func endpointsEquivalent(a, b *api.Endpoints) bool {
- // supposedly we should be able to rely on
- // these being sorted and able to be compared
- // they are supposed to be in a canonical format
+func (dns *dnsControl) Add(obj interface{}) {
+ dns.sendUpdates(nil, obj)
+}
+func (dns *dnsControl) Delete(obj interface{}) {
+ dns.sendUpdates(obj, nil)
+}
+func (dns *dnsControl) Update(oldObj, newObj interface{}) {
+ dns.sendUpdates(oldObj, newObj)
+}
- if len(a.Subsets) != len(b.Subsets) {
+// subsetsEquivalent checks if two endpoint subsets are significantly equivalent
+// I.e. that they have the same ready addresses, host names, ports (including protocol
+// and service names for SRV)
+func subsetsEquivalent(sa, sb api.EndpointSubset) bool {
+ if len(sa.Addresses) != len(sb.Addresses) {
+ return false
+ }
+ if len(sa.Ports) != len(sb.Ports) {
return false
}
- for i, sa := range a.Subsets {
- // check the Addresses and Ports. Ignore unready addresses.
- sb := b.Subsets[i]
- if len(sa.Addresses) != len(sb.Addresses) {
+ // in Addresses and Ports, we should be able to rely on
+ // these being sorted and able to be compared
+ // they are supposed to be in a canonical format
+ for addr, aaddr := range sa.Addresses {
+ baddr := sb.Addresses[addr]
+ if aaddr.IP != baddr.IP {
return false
}
- if len(sa.Ports) != len(sb.Ports) {
+ if aaddr.Hostname != baddr.Hostname {
return false
}
+ }
- for addr, aaddr := range sa.Addresses {
- baddr := sb.Addresses[addr]
- if aaddr.IP != baddr.IP {
- return false
- }
- if aaddr.Hostname != baddr.Hostname {
- return false
- }
+ for port, aport := range sa.Ports {
+ bport := sb.Ports[port]
+ if aport.Name != bport.Name {
+ return false
+ }
+ if aport.Port != bport.Port {
+ return false
+ }
+ if aport.Protocol != bport.Protocol {
+ return false
}
+ }
+ return true
+}
- for port, aport := range sa.Ports {
- bport := sb.Ports[port]
- if aport.Name != bport.Name {
- return false
- }
- if aport.Port != bport.Port {
- return false
- }
- if aport.Protocol != bport.Protocol {
- return false
- }
+// endpointsEquivalent checks if the update to an endpoint is something
+// that matters to us or if they are effectively equivalent.
+func endpointsEquivalent(a, b *api.Endpoints) bool {
+
+ if len(a.Subsets) != len(b.Subsets) {
+ return false
+ }
+
+ // we should be able to rely on
+ // these being sorted and able to be compared
+ // they are supposed to be in a canonical format
+ for i, sa := range a.Subsets {
+ sb := b.Subsets[i]
+ if !subsetsEquivalent(sa, sb) {
+ return false
}
}
return true
diff --git a/plugin/kubernetes/controller_test.go b/plugin/kubernetes/controller_test.go
new file mode 100644
index 000000000..02915fb51
--- /dev/null
+++ b/plugin/kubernetes/controller_test.go
@@ -0,0 +1,53 @@
+package kubernetes
+
+import (
+ "strconv"
+ "strings"
+ "testing"
+
+ api "k8s.io/api/core/v1"
+)
+
+func endpointSubsets(addrs ...string) (eps []api.EndpointSubset) {
+ for _, ap := range addrs {
+ apa := strings.Split(ap, ":")
+ address := apa[0]
+ port, _ := strconv.Atoi(apa[1])
+ eps = append(eps, api.EndpointSubset{Addresses: []api.EndpointAddress{{IP: address}}, Ports: []api.EndpointPort{{Port: int32(port)}}})
+ }
+ return eps
+}
+
+func TestEndpointsSubsetDiffs(t *testing.T) {
+ var tests = []struct {
+ a, b, expected api.Endpoints
+ }{
+ { // From a->b: Nothing changes
+ api.Endpoints{Subsets: endpointSubsets("10.0.0.1:80", "10.0.0.2:8080")},
+ api.Endpoints{Subsets: endpointSubsets("10.0.0.1:80", "10.0.0.2:8080")},
+ api.Endpoints{},
+ },
+ { // From a->b: Everything goes away
+ api.Endpoints{Subsets: endpointSubsets("10.0.0.1:80", "10.0.0.2:8080")},
+ api.Endpoints{},
+ api.Endpoints{Subsets: endpointSubsets("10.0.0.1:80", "10.0.0.2:8080")},
+ },
+ { // From a->b: Everything is new
+ api.Endpoints{},
+ api.Endpoints{Subsets: endpointSubsets("10.0.0.1:80", "10.0.0.2:8080")},
+ api.Endpoints{Subsets: endpointSubsets("10.0.0.1:80", "10.0.0.2:8080")},
+ },
+ { // From a->b: One goes away, one is new
+ api.Endpoints{Subsets: endpointSubsets("10.0.0.2:8080")},
+ api.Endpoints{Subsets: endpointSubsets("10.0.0.1:80")},
+ api.Endpoints{Subsets: endpointSubsets("10.0.0.2:8080", "10.0.0.1:80")},
+ },
+ }
+
+ for i, te := range tests {
+ got := endpointsSubsetDiffs(&te.a, &te.b)
+ if !endpointsEquivalent(got, &te.expected) {
+ t.Errorf("Expected '%v' for test %v, got '%v'.", te.expected, i, got)
+ }
+ }
+}
diff --git a/plugin/kubernetes/handler_test.go b/plugin/kubernetes/handler_test.go
index 388903137..2edeb8e8e 100644
--- a/plugin/kubernetes/handler_test.go
+++ b/plugin/kubernetes/handler_test.go
@@ -6,6 +6,7 @@ import (
"time"
"github.com/coredns/coredns/plugin/pkg/dnstest"
+ "github.com/coredns/coredns/plugin/pkg/watch"
"github.com/coredns/coredns/plugin/test"
"github.com/miekg/dns"
@@ -332,6 +333,9 @@ func (APIConnServeTest) Stop() error { return nil }
func (APIConnServeTest) EpIndexReverse(string) []*api.Endpoints { return nil }
func (APIConnServeTest) SvcIndexReverse(string) []*api.Service { return nil }
func (APIConnServeTest) Modified() int64 { return time.Now().Unix() }
+func (APIConnServeTest) SetWatchChan(watch.Chan) {}
+func (APIConnServeTest) Watch(string) error { return nil }
+func (APIConnServeTest) StopWatching(string) {}
func (APIConnServeTest) PodIndex(string) []*api.Pod {
a := []*api.Pod{{
diff --git a/plugin/kubernetes/kubernetes.go b/plugin/kubernetes/kubernetes.go
index af0e64ee9..03b93748b 100644
--- a/plugin/kubernetes/kubernetes.go
+++ b/plugin/kubernetes/kubernetes.go
@@ -260,6 +260,8 @@ func (k *Kubernetes) InitKubeCache() (err error) {
k.opts.initPodCache = k.podMode == podModeVerified
+ k.opts.zones = k.Zones
+ k.opts.endpointNameMode = k.endpointNameMode
k.APIConn = newdnsController(kubeClient, k.opts)
return err
@@ -292,6 +294,29 @@ func (k *Kubernetes) Records(state request.Request, exact bool) ([]msg.Service,
return services, err
}
+// serviceFQDN returns the k8s cluster dns spec service FQDN for the service (or endpoint) object.
+func serviceFQDN(obj meta.Object, zone string) string {
+ return dnsutil.Join(append([]string{}, obj.GetName(), obj.GetNamespace(), Svc, zone))
+}
+
+// podFQDN returns the k8s cluster dns spec FQDN for the pod.
+func podFQDN(p *api.Pod, zone string) string {
+ name := strings.Replace(p.Status.PodIP, ".", "-", -1)
+ name = strings.Replace(name, ":", "-", -1)
+ return dnsutil.Join(append([]string{}, name, p.GetNamespace(), Pod, zone))
+}
+
+// endpointFQDN returns a list of k8s cluster dns spec service FQDNs for each subset in the endpoint.
+func endpointFQDN(ep *api.Endpoints, zone string, endpointNameMode bool) []string {
+ var names []string
+ for _, ss := range ep.Subsets {
+ for _, addr := range ss.Addresses {
+ names = append(names, dnsutil.Join(append([]string{}, endpointHostname(addr, endpointNameMode), serviceFQDN(ep, zone))))
+ }
+ }
+ return names
+}
+
func endpointHostname(addr api.EndpointAddress, endpointNameMode bool) string {
if addr.Hostname != "" {
return strings.ToLower(addr.Hostname)
diff --git a/plugin/kubernetes/kubernetes_test.go b/plugin/kubernetes/kubernetes_test.go
index e10fe894b..36d00a92f 100644
--- a/plugin/kubernetes/kubernetes_test.go
+++ b/plugin/kubernetes/kubernetes_test.go
@@ -4,6 +4,7 @@ import (
"testing"
"github.com/coredns/coredns/plugin"
+ "github.com/coredns/coredns/plugin/pkg/watch"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
@@ -64,6 +65,9 @@ func (APIConnServiceTest) PodIndex(string) []*api.Pod { return nil }
func (APIConnServiceTest) SvcIndexReverse(string) []*api.Service { return nil }
func (APIConnServiceTest) EpIndexReverse(string) []*api.Endpoints { return nil }
func (APIConnServiceTest) Modified() int64 { return 0 }
+func (APIConnServiceTest) SetWatchChan(watch.Chan) {}
+func (APIConnServiceTest) Watch(string) error { return nil }
+func (APIConnServiceTest) StopWatching(string) {}
func (APIConnServiceTest) SvcIndex(string) []*api.Service {
svcs := []*api.Service{
@@ -390,3 +394,85 @@ func TestServices(t *testing.T) {
}
}
}
+
+func TestServiceFQDN(t *testing.T) {
+ fqdn := serviceFQDN(
+ &api.Service{
+ ObjectMeta: meta.ObjectMeta{
+ Name: "svc1",
+ Namespace: "testns",
+ },
+ }, "cluster.local")
+
+ expected := "svc1.testns.svc.cluster.local."
+ if fqdn != expected {
+ t.Errorf("Expected '%v', got '%v'.", expected, fqdn)
+ }
+}
+
+func TestPodFQDN(t *testing.T) {
+ fqdn := podFQDN(
+ &api.Pod{
+ ObjectMeta: meta.ObjectMeta{
+ Name: "pod1",
+ Namespace: "testns",
+ },
+ Status: api.PodStatus{
+ PodIP: "10.10.0.10",
+ },
+ }, "cluster.local")
+
+ expected := "10-10-0-10.testns.pod.cluster.local."
+ if fqdn != expected {
+ t.Errorf("Expected '%v', got '%v'.", expected, fqdn)
+ }
+ fqdn = podFQDN(
+ &api.Pod{
+ ObjectMeta: meta.ObjectMeta{
+ Name: "pod1",
+ Namespace: "testns",
+ },
+ Status: api.PodStatus{
+ PodIP: "aaaa:bbbb:cccc::zzzz",
+ },
+ }, "cluster.local")
+
+ expected = "aaaa-bbbb-cccc--zzzz.testns.pod.cluster.local."
+ if fqdn != expected {
+ t.Errorf("Expected '%v', got '%v'.", expected, fqdn)
+ }
+}
+
+func TestEndpointFQDN(t *testing.T) {
+ fqdns := endpointFQDN(
+ &api.Endpoints{
+ Subsets: []api.EndpointSubset{
+ {
+ Addresses: []api.EndpointAddress{
+ {
+ IP: "172.0.0.1",
+ Hostname: "ep1a",
+ },
+ {
+ IP: "172.0.0.2",
+ },
+ },
+ },
+ },
+ ObjectMeta: meta.ObjectMeta{
+ Name: "svc1",
+ Namespace: "testns",
+ },
+ }, "cluster.local", false)
+
+ expected := []string{
+ "ep1a.svc1.testns.svc.cluster.local.",
+ "172-0-0-2.svc1.testns.svc.cluster.local.",
+ }
+
+ for i := range fqdns {
+ if fqdns[i] != expected[i] {
+ t.Errorf("Expected '%v', got '%v'.", expected[i], fqdns[i])
+ }
+ }
+}
diff --git a/plugin/kubernetes/ns_test.go b/plugin/kubernetes/ns_test.go
index 7dcc83eeb..f331d3231 100644
--- a/plugin/kubernetes/ns_test.go
+++ b/plugin/kubernetes/ns_test.go
@@ -3,6 +3,8 @@ package kubernetes
import (
"testing"
+ "github.com/coredns/coredns/plugin/pkg/watch"
+
api "k8s.io/api/core/v1"
meta "k8s.io/apimachinery/pkg/apis/meta/v1"
)
@@ -18,6 +20,9 @@ func (APIConnTest) SvcIndexReverse(string) []*api.Service { return nil }
func (APIConnTest) EpIndex(string) []*api.Endpoints { return nil }
func (APIConnTest) EndpointsList() []*api.Endpoints { return nil }
func (APIConnTest) Modified() int64 { return 0 }
+func (APIConnTest) SetWatchChan(watch.Chan) {}
+func (APIConnTest) Watch(string) error { return nil }
+func (APIConnTest) StopWatching(string) {}
func (APIConnTest) ServiceList() []*api.Service {
svcs := []*api.Service{
diff --git a/plugin/kubernetes/reverse_test.go b/plugin/kubernetes/reverse_test.go
index 2cf41de1a..681172021 100644
--- a/plugin/kubernetes/reverse_test.go
+++ b/plugin/kubernetes/reverse_test.go
@@ -5,6 +5,7 @@ import (
"testing"
"github.com/coredns/coredns/plugin/pkg/dnstest"
+ "github.com/coredns/coredns/plugin/pkg/watch"
"github.com/coredns/coredns/plugin/test"
"github.com/miekg/dns"
@@ -22,6 +23,9 @@ func (APIConnReverseTest) EpIndex(string) []*api.Endpoints { return nil }
func (APIConnReverseTest) EndpointsList() []*api.Endpoints { return nil }
func (APIConnReverseTest) ServiceList() []*api.Service { return nil }
func (APIConnReverseTest) Modified() int64 { return 0 }
+func (APIConnReverseTest) SetWatchChan(watch.Chan) {}
+func (APIConnReverseTest) Watch(string) error { return nil }
+func (APIConnReverseTest) StopWatching(string) {}
func (APIConnReverseTest) SvcIndex(svc string) []*api.Service {
if svc != "svc1.testns" {
diff --git a/plugin/kubernetes/watch.go b/plugin/kubernetes/watch.go
new file mode 100644
index 000000000..488540444
--- /dev/null
+++ b/plugin/kubernetes/watch.go
@@ -0,0 +1,20 @@
+package kubernetes
+
+import (
+ "github.com/coredns/coredns/plugin/pkg/watch"
+)
+
+// SetWatchChan implements watch.Watchable
+func (k *Kubernetes) SetWatchChan(c watch.Chan) {
+ k.APIConn.SetWatchChan(c)
+}
+
+// Watch is called when a watch is started for a name.
+func (k *Kubernetes) Watch(qname string) error {
+ return k.APIConn.Watch(qname)
+}
+
+// StopWatching is called when no more watches remain for a name
+func (k *Kubernetes) StopWatching(qname string) {
+ k.APIConn.StopWatching(qname)
+}
diff --git a/plugin/kubernetes/watch_test.go b/plugin/kubernetes/watch_test.go
new file mode 100644
index 000000000..46b2e5dc4
--- /dev/null
+++ b/plugin/kubernetes/watch_test.go
@@ -0,0 +1,15 @@
+package kubernetes
+
+import (
+ "testing"
+
+ "github.com/coredns/coredns/plugin/pkg/watch"
+)
+
+func TestIsWatchable(t *testing.T) {
+ k := &Kubernetes{}
+ var i interface{} = k
+ if _, ok := i.(watch.Watchable); !ok {
+ t.Error("Kubernetes should implement watch.Watchable and does not")
+ }
+}
diff --git a/plugin/pkg/watch/watch.go b/plugin/pkg/watch/watch.go
new file mode 100644
index 000000000..7e77bb7b3
--- /dev/null
+++ b/plugin/pkg/watch/watch.go
@@ -0,0 +1,23 @@
+package watch
+
+// Chan is used to inform the server of a change. Whenever
+// a watched FQDN has a change in data, that FQDN should be
+// sent down this channel.
+type Chan chan string
+
+// Watchable is the interface watchable plugins should implement
+type Watchable interface {
+ // Name returns the plugin name.
+ Name() string
+
+ // SetWatchChan is called when the watch channel is created.
+ SetWatchChan(Chan)
+
+ // Watch is called whenever a watch is created for a FQDN. Plugins
+ // should send the FQDN down the watch channel when its data may have
+ // changed. This is an exact match only.
+ Watch(qname string) error
+
+ // StopWatching is called whenever all watches are canceled for a FQDN.
+ StopWatching(qname string)
+}
diff --git a/plugin/pkg/watch/watcher.go b/plugin/pkg/watch/watcher.go
new file mode 100644
index 000000000..59474a7bc
--- /dev/null
+++ b/plugin/pkg/watch/watcher.go
@@ -0,0 +1,178 @@
+package watch
+
+import (
+ "fmt"
+ "io"
+ "sync"
+
+ "github.com/miekg/dns"
+
+ "github.com/coredns/coredns/pb"
+ "github.com/coredns/coredns/plugin"
+ "github.com/coredns/coredns/plugin/pkg/log"
+ "github.com/coredns/coredns/request"
+)
+
+// Watcher handles watch creation, cancellation, and processing.
+type Watcher interface {
+ // Watch monitors a client stream and creates and cancels watches.
+ Watch(pb.DnsService_WatchServer) error
+
+ // Stop cancels open watches and stops the watch processing go routine.
+ Stop()
+}
+
+// Manager contains all the data needed to manage watches
+type Manager struct {
+ changes Chan
+ stopper chan bool
+ counter int64
+ watches map[string]watchlist
+ plugins []Watchable
+ mutex sync.Mutex
+}
+
+type watchlist map[int64]pb.DnsService_WatchServer
+
+// NewWatcher creates a Watcher, which is used to manage watched names.
+func NewWatcher(plugins []Watchable) *Manager {
+ w := &Manager{changes: make(Chan), stopper: make(chan bool), watches: make(map[string]watchlist), plugins: plugins}
+
+ for _, p := range plugins {
+ p.SetWatchChan(w.changes)
+ }
+
+ go w.process()
+ return w
+}
+
+func (w *Manager) nextID() int64 {
+ w.mutex.Lock()
+
+ w.counter++
+ id := w.counter
+
+ w.mutex.Unlock()
+ return id
+}
+
+// Watch monitors a client stream and creates and cancels watches.
+func (w *Manager) Watch(stream pb.DnsService_WatchServer) error {
+ for {
+ in, err := stream.Recv()
+ if err == io.EOF {
+ return nil
+ }
+ if err != nil {
+ return err
+ }
+ create := in.GetCreateRequest()
+ if create != nil {
+ msg := new(dns.Msg)
+ err := msg.Unpack(create.Query.Msg)
+ if err != nil {
+ log.Warningf("Could not decode watch request: %s\n", err)
+ stream.Send(&pb.WatchResponse{Err: "could not decode request"})
+ continue
+ }
+ id := w.nextID()
+ if err := stream.Send(&pb.WatchResponse{WatchId: id, Created: true}); err != nil {
+ // if we fail to notify client of watch creation, don't create the watch
+ continue
+ }
+
+ // Normalize qname
+ qname := (&request.Request{Req: msg}).Name()
+
+ w.mutex.Lock()
+ if _, ok := w.watches[qname]; !ok {
+ w.watches[qname] = make(watchlist)
+ }
+ w.watches[qname][id] = stream
+ w.mutex.Unlock()
+
+ for _, p := range w.plugins {
+ err := p.Watch(qname)
+ if err != nil {
+ log.Warningf("Failed to start watch for %s in plugin %s: %s\n", qname, p.Name(), err)
+ stream.Send(&pb.WatchResponse{Err: fmt.Sprintf("failed to start watch for %s in plugin %s", qname, p.Name())})
+ }
+ }
+ continue
+ }
+
+ cancel := in.GetCancelRequest()
+ if cancel != nil {
+ w.mutex.Lock()
+ for qname, wl := range w.watches {
+ ws, ok := wl[cancel.WatchId]
+ if !ok {
+ continue
+ }
+
+ // only allow cancels from the client that started it
+ // TODO: test what happens if a stream tries to cancel a watchID that it doesn't own
+ if ws != stream {
+ continue
+ }
+
+ delete(wl, cancel.WatchId)
+
+ // if there are no more watches for this qname, we should tell the plugins
+ if len(wl) == 0 {
+ for _, p := range w.plugins {
+ p.StopWatching(qname)
+ }
+ delete(w.watches, qname)
+ }
+
+ // let the client know we canceled the watch
+ stream.Send(&pb.WatchResponse{WatchId: cancel.WatchId, Canceled: true})
+ }
+ w.mutex.Unlock()
+ continue
+ }
+ }
+}
+
+func (w *Manager) process() {
+ for {
+ select {
+ case <-w.stopper:
+ return
+ case changed := <-w.changes:
+ w.mutex.Lock()
+ for qname, wl := range w.watches {
+ if plugin.Zones([]string{changed}).Matches(qname) == "" {
+ continue
+ }
+ for id, stream := range wl {
+ wr := pb.WatchResponse{WatchId: id, Qname: qname}
+ err := stream.Send(&wr)
+ if err != nil {
+ log.Warningf("Error sending change for %s to watch %d: %s. Removing watch.\n", qname, id, err)
+ delete(w.watches[qname], id)
+ }
+ }
+ }
+ w.mutex.Unlock()
+ }
+ }
+}
+
+// Stop cancels open watches and stops the watch processing go routine.
+func (w *Manager) Stop() {
+ w.stopper <- true
+ w.mutex.Lock()
+ for wn, wl := range w.watches {
+ for id, stream := range wl {
+ wr := pb.WatchResponse{WatchId: id, Canceled: true}
+ err := stream.Send(&wr)
+ if err != nil {
+ log.Warningf("Error notifiying client of cancellation: %s\n", err)
+ }
+ }
+ delete(w.watches, wn)
+ }
+ w.mutex.Unlock()
+}