aboutsummaryrefslogtreecommitdiff
path: root/vendor/google.golang.org/grpc/stream.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/google.golang.org/grpc/stream.go')
-rw-r--r--vendor/google.golang.org/grpc/stream.go90
1 files changed, 43 insertions, 47 deletions
diff --git a/vendor/google.golang.org/grpc/stream.go b/vendor/google.golang.org/grpc/stream.go
index 1c621ba87..75eab40b1 100644
--- a/vendor/google.golang.org/grpc/stream.go
+++ b/vendor/google.golang.org/grpc/stream.go
@@ -27,6 +27,7 @@ import (
"golang.org/x/net/context"
"golang.org/x/net/trace"
+ "google.golang.org/grpc/balancer"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
@@ -106,10 +107,10 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
var (
t transport.ClientTransport
s *transport.Stream
- put func()
+ done func(balancer.DoneInfo)
cancel context.CancelFunc
)
- c := defaultCallInfo
+ c := defaultCallInfo()
mc := cc.GetMethodConfig(method)
if mc.WaitForReady != nil {
c.failFast = !*mc.WaitForReady
@@ -117,11 +118,16 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
if mc.Timeout != nil {
ctx, cancel = context.WithTimeout(ctx, *mc.Timeout)
+ defer func() {
+ if err != nil {
+ cancel()
+ }
+ }()
}
opts = append(cc.dopts.callOptions, opts...)
for _, o := range opts {
- if err := o.before(&c); err != nil {
+ if err := o.before(c); err != nil {
return nil, toRPCErr(err)
}
}
@@ -162,7 +168,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
}
}()
}
- ctx = newContextWithRPCInfo(ctx)
+ ctx = newContextWithRPCInfo(ctx, c.failFast)
sh := cc.dopts.copts.StatsHandler
if sh != nil {
ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: c.failFast})
@@ -183,11 +189,8 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
}
}()
}
- gopts := BalancerGetOptions{
- BlockingWait: !c.failFast,
- }
for {
- t, put, err = cc.getTransport(ctx, gopts)
+ t, done, err = cc.getTransport(ctx, c.failFast)
if err != nil {
// TODO(zhaoq): Probably revisit the error handling.
if _, ok := status.FromError(err); ok {
@@ -205,15 +208,15 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
s, err = t.NewStream(ctx, callHdr)
if err != nil {
- if _, ok := err.(transport.ConnectionError); ok && put != nil {
+ if _, ok := err.(transport.ConnectionError); ok && done != nil {
// If error is connection error, transport was sending data on wire,
// and we are not sure if anything has been sent on wire.
// If error is not connection error, we are sure nothing has been sent.
updateRPCInfoInContext(ctx, rpcInfo{bytesSent: true, bytesReceived: false})
}
- if put != nil {
- put()
- put = nil
+ if done != nil {
+ done(balancer.DoneInfo{Err: err})
+ done = nil
}
if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast {
continue
@@ -235,10 +238,10 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
dc: cc.dopts.dc,
cancel: cancel,
- put: put,
- t: t,
- s: s,
- p: &parser{r: s},
+ done: done,
+ t: t,
+ s: s,
+ p: &parser{r: s},
tracing: EnableTracing,
trInfo: trInfo,
@@ -246,9 +249,6 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
statsCtx: ctx,
statsHandler: cc.dopts.copts.StatsHandler,
}
- if cc.dopts.cp != nil {
- cs.cbuf = new(bytes.Buffer)
- }
// Listen on ctx.Done() to detect cancellation and s.Done() to detect normal termination
// when there is no pending I/O operations on this stream.
go func() {
@@ -278,21 +278,20 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
// clientStream implements a client side Stream.
type clientStream struct {
opts []CallOption
- c callInfo
+ c *callInfo
t transport.ClientTransport
s *transport.Stream
p *parser
desc *StreamDesc
codec Codec
cp Compressor
- cbuf *bytes.Buffer
dc Decompressor
cancel context.CancelFunc
tracing bool // set to EnableTracing when the clientStream is created.
mu sync.Mutex
- put func()
+ done func(balancer.DoneInfo)
closed bool
finished bool
// trInfo.tr is set when the clientStream is created (if EnableTracing is true),
@@ -342,7 +341,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
return
}
if err == io.EOF {
- // Specialize the process for server streaming. SendMesg is only called
+ // Specialize the process for server streaming. SendMsg is only called
// once when creating the stream object. io.EOF needs to be skipped when
// the rpc is early finished (before the stream object is created.).
// TODO: It is probably better to move this into the generated code.
@@ -362,22 +361,17 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
Client: true,
}
}
- out, err := encode(cs.codec, m, cs.cp, cs.cbuf, outPayload)
- defer func() {
- if cs.cbuf != nil {
- cs.cbuf.Reset()
- }
- }()
+ hdr, data, err := encode(cs.codec, m, cs.cp, bytes.NewBuffer([]byte{}), outPayload)
if err != nil {
return err
}
if cs.c.maxSendMessageSize == nil {
return Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)")
}
- if len(out) > *cs.c.maxSendMessageSize {
- return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(out), *cs.c.maxSendMessageSize)
+ if len(data) > *cs.c.maxSendMessageSize {
+ return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), *cs.c.maxSendMessageSize)
}
- err = cs.t.Write(cs.s, out, &transport.Options{Last: false})
+ err = cs.t.Write(cs.s, hdr, data, &transport.Options{Last: false})
if err == nil && outPayload != nil {
outPayload.SentTime = time.Now()
cs.statsHandler.HandleRPC(cs.statsCtx, outPayload)
@@ -449,7 +443,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
}
func (cs *clientStream) CloseSend() (err error) {
- err = cs.t.Write(cs.s, nil, &transport.Options{Last: true})
+ err = cs.t.Write(cs.s, nil, nil, &transport.Options{Last: true})
defer func() {
if err != nil {
cs.finish(err)
@@ -489,15 +483,15 @@ func (cs *clientStream) finish(err error) {
}
}()
for _, o := range cs.opts {
- o.after(&cs.c)
+ o.after(cs.c)
}
- if cs.put != nil {
+ if cs.done != nil {
updateRPCInfoInContext(cs.s.Context(), rpcInfo{
bytesSent: cs.s.BytesSent(),
bytesReceived: cs.s.BytesReceived(),
})
- cs.put()
- cs.put = nil
+ cs.done(balancer.DoneInfo{Err: err})
+ cs.done = nil
}
if cs.statsHandler != nil {
end := &stats.End{
@@ -552,7 +546,6 @@ type serverStream struct {
codec Codec
cp Compressor
dc Decompressor
- cbuf *bytes.Buffer
maxReceiveMessageSize int
maxSendMessageSize int
trInfo *traceInfo
@@ -599,24 +592,23 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
}
ss.mu.Unlock()
}
+ if err != nil && err != io.EOF {
+ st, _ := status.FromError(toRPCErr(err))
+ ss.t.WriteStatus(ss.s, st)
+ }
}()
var outPayload *stats.OutPayload
if ss.statsHandler != nil {
outPayload = &stats.OutPayload{}
}
- out, err := encode(ss.codec, m, ss.cp, ss.cbuf, outPayload)
- defer func() {
- if ss.cbuf != nil {
- ss.cbuf.Reset()
- }
- }()
+ hdr, data, err := encode(ss.codec, m, ss.cp, bytes.NewBuffer([]byte{}), outPayload)
if err != nil {
return err
}
- if len(out) > ss.maxSendMessageSize {
- return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(out), ss.maxSendMessageSize)
+ if len(data) > ss.maxSendMessageSize {
+ return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), ss.maxSendMessageSize)
}
- if err := ss.t.Write(ss.s, out, &transport.Options{Last: false}); err != nil {
+ if err := ss.t.Write(ss.s, hdr, data, &transport.Options{Last: false}); err != nil {
return toRPCErr(err)
}
if outPayload != nil {
@@ -640,6 +632,10 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
}
ss.mu.Unlock()
}
+ if err != nil && err != io.EOF {
+ st, _ := status.FromError(toRPCErr(err))
+ ss.t.WriteStatus(ss.s, st)
+ }
}()
var inPayload *stats.InPayload
if ss.statsHandler != nil {