aboutsummaryrefslogtreecommitdiff
path: root/vendor/google.golang.org/grpc/stream.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/google.golang.org/grpc/stream.go')
-rw-r--r--vendor/google.golang.org/grpc/stream.go155
1 files changed, 106 insertions, 49 deletions
diff --git a/vendor/google.golang.org/grpc/stream.go b/vendor/google.golang.org/grpc/stream.go
index 75eab40b1..9eeaafef8 100644
--- a/vendor/google.golang.org/grpc/stream.go
+++ b/vendor/google.golang.org/grpc/stream.go
@@ -19,7 +19,6 @@
package grpc
import (
- "bytes"
"errors"
"io"
"sync"
@@ -29,6 +28,7 @@ import (
"golang.org/x/net/trace"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/codes"
+ "google.golang.org/grpc/encoding"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
@@ -94,15 +94,23 @@ type ClientStream interface {
Stream
}
-// NewClientStream creates a new Stream for the client side. This is called
-// by generated code.
-func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
+// NewStream creates a new Stream for the client side. This is typically
+// called by generated code.
+func (cc *ClientConn) NewStream(ctx context.Context, desc *StreamDesc, method string, opts ...CallOption) (ClientStream, error) {
if cc.dopts.streamInt != nil {
return cc.dopts.streamInt(ctx, desc, cc, method, newClientStream, opts...)
}
return newClientStream(ctx, desc, cc, method, opts...)
}
+// NewClientStream creates a new Stream for the client side. This is typically
+// called by generated code.
+//
+// DEPRECATED: Use ClientConn.NewStream instead.
+func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
+ return cc.NewStream(ctx, desc, method, opts...)
+}
+
func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
var (
t transport.ClientTransport
@@ -116,7 +124,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
c.failFast = !*mc.WaitForReady
}
- if mc.Timeout != nil {
+ if mc.Timeout != nil && *mc.Timeout >= 0 {
ctx, cancel = context.WithTimeout(ctx, *mc.Timeout)
defer func() {
if err != nil {
@@ -143,8 +151,24 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
// time soon, so we ask the transport to flush the header.
Flush: desc.ClientStreams,
}
- if cc.dopts.cp != nil {
+
+ // Set our outgoing compression according to the UseCompressor CallOption, if
+ // set. In that case, also find the compressor from the encoding package.
+ // Otherwise, use the compressor configured by the WithCompressor DialOption,
+ // if set.
+ var cp Compressor
+ var comp encoding.Compressor
+ if ct := c.compressorType; ct != "" {
+ callHdr.SendCompress = ct
+ if ct != encoding.Identity {
+ comp = encoding.GetCompressor(ct)
+ if comp == nil {
+ return nil, Errorf(codes.Internal, "grpc: Compressor is not installed for requested grpc-encoding %q", ct)
+ }
+ }
+ } else if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type()
+ cp = cc.dopts.cp
}
if c.creds != nil {
callHdr.Creds = c.creds
@@ -189,42 +213,39 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
}
}()
}
+
for {
+ // Check to make sure the context has expired. This will prevent us from
+ // looping forever if an error occurs for wait-for-ready RPCs where no data
+ // is sent on the wire.
+ select {
+ case <-ctx.Done():
+ return nil, toRPCErr(ctx.Err())
+ default:
+ }
+
t, done, err = cc.getTransport(ctx, c.failFast)
if err != nil {
- // TODO(zhaoq): Probably revisit the error handling.
- if _, ok := status.FromError(err); ok {
- return nil, err
- }
- if err == errConnClosing || err == errConnUnavailable {
- if c.failFast {
- return nil, Errorf(codes.Unavailable, "%v", err)
- }
- continue
- }
- // All the other errors are treated as Internal errors.
- return nil, Errorf(codes.Internal, "%v", err)
+ return nil, err
}
s, err = t.NewStream(ctx, callHdr)
if err != nil {
- if _, ok := err.(transport.ConnectionError); ok && done != nil {
- // If error is connection error, transport was sending data on wire,
- // and we are not sure if anything has been sent on wire.
- // If error is not connection error, we are sure nothing has been sent.
- updateRPCInfoInContext(ctx, rpcInfo{bytesSent: true, bytesReceived: false})
- }
if done != nil {
done(balancer.DoneInfo{Err: err})
done = nil
}
- if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast {
+ // In the event of any error from NewStream, we never attempted to write
+ // anything to the wire, so we can retry indefinitely for non-fail-fast
+ // RPCs.
+ if !c.failFast {
continue
}
return nil, toRPCErr(err)
}
break
}
+
// Set callInfo.peer object from stream's context.
if peer, ok := peer.FromContext(s.Context()); ok {
c.peer = peer
@@ -234,8 +255,9 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
c: c,
desc: desc,
codec: cc.dopts.codec,
- cp: cc.dopts.cp,
+ cp: cp,
dc: cc.dopts.dc,
+ comp: comp,
cancel: cancel,
done: done,
@@ -249,8 +271,8 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
statsCtx: ctx,
statsHandler: cc.dopts.copts.StatsHandler,
}
- // Listen on ctx.Done() to detect cancellation and s.Done() to detect normal termination
- // when there is no pending I/O operations on this stream.
+ // Listen on s.Context().Done() to detect cancellation and s.Done() to detect
+ // normal termination when there is no pending I/O operations on this stream.
go func() {
select {
case <-t.Error():
@@ -277,15 +299,20 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
// clientStream implements a client side Stream.
type clientStream struct {
- opts []CallOption
- c *callInfo
- t transport.ClientTransport
- s *transport.Stream
- p *parser
- desc *StreamDesc
- codec Codec
- cp Compressor
- dc Decompressor
+ opts []CallOption
+ c *callInfo
+ t transport.ClientTransport
+ s *transport.Stream
+ p *parser
+ desc *StreamDesc
+
+ codec Codec
+ cp Compressor
+ dc Decompressor
+ comp encoding.Compressor
+ decomp encoding.Compressor
+ decompSet bool
+
cancel context.CancelFunc
tracing bool // set to EnableTracing when the clientStream is created.
@@ -361,7 +388,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
Client: true,
}
}
- hdr, data, err := encode(cs.codec, m, cs.cp, bytes.NewBuffer([]byte{}), outPayload)
+ hdr, data, err := encode(cs.codec, m, cs.cp, outPayload, cs.comp)
if err != nil {
return err
}
@@ -389,7 +416,23 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
if cs.c.maxReceiveMessageSize == nil {
return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)")
}
- err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, inPayload)
+ if !cs.decompSet {
+ // Block until we receive headers containing received message encoding.
+ if ct := cs.s.RecvCompress(); ct != "" && ct != encoding.Identity {
+ if cs.dc == nil || cs.dc.Type() != ct {
+ // No configured decompressor, or it does not match the incoming
+ // message encoding; attempt to find a registered compressor that does.
+ cs.dc = nil
+ cs.decomp = encoding.GetCompressor(ct)
+ }
+ } else {
+ // No compression is used; disable our decompressor.
+ cs.dc = nil
+ }
+ // Only initialize this state once per stream.
+ cs.decompSet = true
+ }
+ err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, inPayload, cs.decomp)
defer func() {
// err != nil indicates the termination of the stream.
if err != nil {
@@ -415,7 +458,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
if cs.c.maxReceiveMessageSize == nil {
return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)")
}
- err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil)
+ err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil, cs.decomp)
cs.closeTransportStream(err)
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
@@ -487,7 +530,7 @@ func (cs *clientStream) finish(err error) {
}
if cs.done != nil {
updateRPCInfoInContext(cs.s.Context(), rpcInfo{
- bytesSent: cs.s.BytesSent(),
+ bytesSent: true,
bytesReceived: cs.s.BytesReceived(),
})
cs.done(balancer.DoneInfo{Err: err})
@@ -540,12 +583,16 @@ type ServerStream interface {
// serverStream implements a server side Stream.
type serverStream struct {
- t transport.ServerTransport
- s *transport.Stream
- p *parser
- codec Codec
- cp Compressor
- dc Decompressor
+ t transport.ServerTransport
+ s *transport.Stream
+ p *parser
+ codec Codec
+
+ cp Compressor
+ dc Decompressor
+ comp encoding.Compressor
+ decomp encoding.Compressor
+
maxReceiveMessageSize int
maxSendMessageSize int
trInfo *traceInfo
@@ -601,7 +648,7 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
if ss.statsHandler != nil {
outPayload = &stats.OutPayload{}
}
- hdr, data, err := encode(ss.codec, m, ss.cp, bytes.NewBuffer([]byte{}), outPayload)
+ hdr, data, err := encode(ss.codec, m, ss.cp, outPayload, ss.comp)
if err != nil {
return err
}
@@ -641,7 +688,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
if ss.statsHandler != nil {
inPayload = &stats.InPayload{}
}
- if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload); err != nil {
+ if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload, ss.decomp); err != nil {
if err == io.EOF {
return err
}
@@ -655,3 +702,13 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
}
return nil
}
+
+// MethodFromServerStream returns the method string for the input stream.
+// The returned string is in the format of "/service/method".
+func MethodFromServerStream(stream ServerStream) (string, bool) {
+ s, ok := transport.StreamFromContext(stream.Context())
+ if !ok {
+ return "", ok
+ }
+ return s.Method(), ok
+}