diff options
Diffstat (limited to 'vendor/google.golang.org/grpc/stream.go')
-rw-r--r-- | vendor/google.golang.org/grpc/stream.go | 155 |
1 files changed, 106 insertions, 49 deletions
diff --git a/vendor/google.golang.org/grpc/stream.go b/vendor/google.golang.org/grpc/stream.go index 75eab40b1..9eeaafef8 100644 --- a/vendor/google.golang.org/grpc/stream.go +++ b/vendor/google.golang.org/grpc/stream.go @@ -19,7 +19,6 @@ package grpc import ( - "bytes" "errors" "io" "sync" @@ -29,6 +28,7 @@ import ( "golang.org/x/net/trace" "google.golang.org/grpc/balancer" "google.golang.org/grpc/codes" + "google.golang.org/grpc/encoding" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/stats" @@ -94,15 +94,23 @@ type ClientStream interface { Stream } -// NewClientStream creates a new Stream for the client side. This is called -// by generated code. -func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) { +// NewStream creates a new Stream for the client side. This is typically +// called by generated code. +func (cc *ClientConn) NewStream(ctx context.Context, desc *StreamDesc, method string, opts ...CallOption) (ClientStream, error) { if cc.dopts.streamInt != nil { return cc.dopts.streamInt(ctx, desc, cc, method, newClientStream, opts...) } return newClientStream(ctx, desc, cc, method, opts...) } +// NewClientStream creates a new Stream for the client side. This is typically +// called by generated code. +// +// DEPRECATED: Use ClientConn.NewStream instead. +func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) { + return cc.NewStream(ctx, desc, method, opts...) +} + func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) { var ( t transport.ClientTransport @@ -116,7 +124,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth c.failFast = !*mc.WaitForReady } - if mc.Timeout != nil { + if mc.Timeout != nil && *mc.Timeout >= 0 { ctx, cancel = context.WithTimeout(ctx, *mc.Timeout) defer func() { if err != nil { @@ -143,8 +151,24 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth // time soon, so we ask the transport to flush the header. Flush: desc.ClientStreams, } - if cc.dopts.cp != nil { + + // Set our outgoing compression according to the UseCompressor CallOption, if + // set. In that case, also find the compressor from the encoding package. + // Otherwise, use the compressor configured by the WithCompressor DialOption, + // if set. + var cp Compressor + var comp encoding.Compressor + if ct := c.compressorType; ct != "" { + callHdr.SendCompress = ct + if ct != encoding.Identity { + comp = encoding.GetCompressor(ct) + if comp == nil { + return nil, Errorf(codes.Internal, "grpc: Compressor is not installed for requested grpc-encoding %q", ct) + } + } + } else if cc.dopts.cp != nil { callHdr.SendCompress = cc.dopts.cp.Type() + cp = cc.dopts.cp } if c.creds != nil { callHdr.Creds = c.creds @@ -189,42 +213,39 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth } }() } + for { + // Check to make sure the context has expired. This will prevent us from + // looping forever if an error occurs for wait-for-ready RPCs where no data + // is sent on the wire. + select { + case <-ctx.Done(): + return nil, toRPCErr(ctx.Err()) + default: + } + t, done, err = cc.getTransport(ctx, c.failFast) if err != nil { - // TODO(zhaoq): Probably revisit the error handling. - if _, ok := status.FromError(err); ok { - return nil, err - } - if err == errConnClosing || err == errConnUnavailable { - if c.failFast { - return nil, Errorf(codes.Unavailable, "%v", err) - } - continue - } - // All the other errors are treated as Internal errors. - return nil, Errorf(codes.Internal, "%v", err) + return nil, err } s, err = t.NewStream(ctx, callHdr) if err != nil { - if _, ok := err.(transport.ConnectionError); ok && done != nil { - // If error is connection error, transport was sending data on wire, - // and we are not sure if anything has been sent on wire. - // If error is not connection error, we are sure nothing has been sent. - updateRPCInfoInContext(ctx, rpcInfo{bytesSent: true, bytesReceived: false}) - } if done != nil { done(balancer.DoneInfo{Err: err}) done = nil } - if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast { + // In the event of any error from NewStream, we never attempted to write + // anything to the wire, so we can retry indefinitely for non-fail-fast + // RPCs. + if !c.failFast { continue } return nil, toRPCErr(err) } break } + // Set callInfo.peer object from stream's context. if peer, ok := peer.FromContext(s.Context()); ok { c.peer = peer @@ -234,8 +255,9 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth c: c, desc: desc, codec: cc.dopts.codec, - cp: cc.dopts.cp, + cp: cp, dc: cc.dopts.dc, + comp: comp, cancel: cancel, done: done, @@ -249,8 +271,8 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth statsCtx: ctx, statsHandler: cc.dopts.copts.StatsHandler, } - // Listen on ctx.Done() to detect cancellation and s.Done() to detect normal termination - // when there is no pending I/O operations on this stream. + // Listen on s.Context().Done() to detect cancellation and s.Done() to detect + // normal termination when there is no pending I/O operations on this stream. go func() { select { case <-t.Error(): @@ -277,15 +299,20 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth // clientStream implements a client side Stream. type clientStream struct { - opts []CallOption - c *callInfo - t transport.ClientTransport - s *transport.Stream - p *parser - desc *StreamDesc - codec Codec - cp Compressor - dc Decompressor + opts []CallOption + c *callInfo + t transport.ClientTransport + s *transport.Stream + p *parser + desc *StreamDesc + + codec Codec + cp Compressor + dc Decompressor + comp encoding.Compressor + decomp encoding.Compressor + decompSet bool + cancel context.CancelFunc tracing bool // set to EnableTracing when the clientStream is created. @@ -361,7 +388,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { Client: true, } } - hdr, data, err := encode(cs.codec, m, cs.cp, bytes.NewBuffer([]byte{}), outPayload) + hdr, data, err := encode(cs.codec, m, cs.cp, outPayload, cs.comp) if err != nil { return err } @@ -389,7 +416,23 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { if cs.c.maxReceiveMessageSize == nil { return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") } - err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, inPayload) + if !cs.decompSet { + // Block until we receive headers containing received message encoding. + if ct := cs.s.RecvCompress(); ct != "" && ct != encoding.Identity { + if cs.dc == nil || cs.dc.Type() != ct { + // No configured decompressor, or it does not match the incoming + // message encoding; attempt to find a registered compressor that does. + cs.dc = nil + cs.decomp = encoding.GetCompressor(ct) + } + } else { + // No compression is used; disable our decompressor. + cs.dc = nil + } + // Only initialize this state once per stream. + cs.decompSet = true + } + err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, inPayload, cs.decomp) defer func() { // err != nil indicates the termination of the stream. if err != nil { @@ -415,7 +458,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { if cs.c.maxReceiveMessageSize == nil { return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") } - err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil) + err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil, cs.decomp) cs.closeTransportStream(err) if err == nil { return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>")) @@ -487,7 +530,7 @@ func (cs *clientStream) finish(err error) { } if cs.done != nil { updateRPCInfoInContext(cs.s.Context(), rpcInfo{ - bytesSent: cs.s.BytesSent(), + bytesSent: true, bytesReceived: cs.s.BytesReceived(), }) cs.done(balancer.DoneInfo{Err: err}) @@ -540,12 +583,16 @@ type ServerStream interface { // serverStream implements a server side Stream. type serverStream struct { - t transport.ServerTransport - s *transport.Stream - p *parser - codec Codec - cp Compressor - dc Decompressor + t transport.ServerTransport + s *transport.Stream + p *parser + codec Codec + + cp Compressor + dc Decompressor + comp encoding.Compressor + decomp encoding.Compressor + maxReceiveMessageSize int maxSendMessageSize int trInfo *traceInfo @@ -601,7 +648,7 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { if ss.statsHandler != nil { outPayload = &stats.OutPayload{} } - hdr, data, err := encode(ss.codec, m, ss.cp, bytes.NewBuffer([]byte{}), outPayload) + hdr, data, err := encode(ss.codec, m, ss.cp, outPayload, ss.comp) if err != nil { return err } @@ -641,7 +688,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { if ss.statsHandler != nil { inPayload = &stats.InPayload{} } - if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload); err != nil { + if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload, ss.decomp); err != nil { if err == io.EOF { return err } @@ -655,3 +702,13 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { } return nil } + +// MethodFromServerStream returns the method string for the input stream. +// The returned string is in the format of "/service/method". +func MethodFromServerStream(stream ServerStream) (string, bool) { + s, ok := transport.StreamFromContext(stream.Context()) + if !ok { + return "", ok + } + return s.Method(), ok +} |