aboutsummaryrefslogtreecommitdiff
path: root/plugin
diff options
context:
space:
mode:
Diffstat (limited to 'plugin')
-rw-r--r--plugin/transfer/failed_write_test.go31
-rw-r--r--plugin/transfer/transfer.go35
2 files changed, 56 insertions, 10 deletions
diff --git a/plugin/transfer/failed_write_test.go b/plugin/transfer/failed_write_test.go
new file mode 100644
index 000000000..90b5c4de2
--- /dev/null
+++ b/plugin/transfer/failed_write_test.go
@@ -0,0 +1,31 @@
+package transfer
+
+import (
+ "context"
+ "fmt"
+ "testing"
+
+ "github.com/coredns/coredns/plugin/test"
+
+ "github.com/miekg/dns"
+)
+
+type badwriter struct {
+ dns.ResponseWriter
+ count int
+}
+
+func (w *badwriter) WriteMsg(res *dns.Msg) error { return fmt.Errorf("failed to write msg") }
+
+func TestWriteMessageFailed(t *testing.T) {
+ transfer := newTestTransfer()
+ ctx := context.TODO()
+ w := &badwriter{ResponseWriter: &test.ResponseWriter{}}
+ m := &dns.Msg{}
+ m.SetAxfr("example.org.")
+
+ _, err := transfer.ServeDNS(ctx, w, m)
+ if err == nil {
+ t.Error("Expected error, got none")
+ }
+}
diff --git a/plugin/transfer/transfer.go b/plugin/transfer/transfer.go
index 3558f2e0f..45251cda0 100644
--- a/plugin/transfer/transfer.go
+++ b/plugin/transfer/transfer.go
@@ -4,7 +4,6 @@ import (
"context"
"errors"
"net"
- "sync"
"github.com/coredns/coredns/plugin"
clog "github.com/coredns/coredns/plugin/pkg/log"
@@ -107,11 +106,12 @@ func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms
// Send response to client
ch := make(chan *dns.Envelope)
tr := new(dns.Transfer)
- wg := new(sync.WaitGroup)
- wg.Add(1)
+ errCh := make(chan error)
go func() {
- tr.Out(w, r, ch)
- wg.Done()
+ if err := tr.Out(w, r, ch); err != nil {
+ errCh <- err
+ }
+ close(errCh)
}()
rrs := []dns.RR{}
@@ -123,7 +123,11 @@ func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms
}
rrs = append(rrs, records...)
if len(rrs) > 500 {
- ch <- &dns.Envelope{RR: rrs}
+ select {
+ case ch <- &dns.Envelope{RR: rrs}:
+ case err := <-errCh:
+ return dns.RcodeServerFailure, err
+ }
l += len(rrs)
rrs = []dns.RR{}
}
@@ -134,7 +138,10 @@ func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms
// need to return the SOA back to the client and return.
if len(rrs) == 1 && soa != nil { // soa should never be nil...
close(ch)
- wg.Wait()
+ err := <-errCh
+ if err != nil {
+ return dns.RcodeServerFailure, err
+ }
m := new(dns.Msg)
m.SetReply(r)
@@ -146,12 +153,20 @@ func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms
}
if len(rrs) > 0 {
- ch <- &dns.Envelope{RR: rrs}
+ select {
+ case ch <- &dns.Envelope{RR: rrs}:
+ case err := <-errCh:
+ return dns.RcodeServerFailure, err
+ }
l += len(rrs)
+
}
- close(ch) // Even though we close the channel here, we still have
- wg.Wait() // to wait before we can return and close the connection.
+ close(ch) // Even though we close the channel here, we still have
+ err = <-errCh // to wait before we can return and close the connection.
+ if err != nil {
+ return dns.RcodeServerFailure, err
+ }
logserial := uint32(0)
if soa != nil {