diff options
Diffstat (limited to 'vendor/google.golang.org')
48 files changed, 3060 insertions, 1407 deletions
diff --git a/vendor/google.golang.org/grpc/CONTRIBUTING.md b/vendor/google.golang.org/grpc/CONTRIBUTING.md index a5c6e06e2..8ec6c9574 100644 --- a/vendor/google.golang.org/grpc/CONTRIBUTING.md +++ b/vendor/google.golang.org/grpc/CONTRIBUTING.md @@ -7,7 +7,7 @@ If you are new to github, please start by reading [Pull Request howto](https://h ## Legal requirements In order to protect both you and ourselves, you will need to sign the -[Contributor License Agreement](https://cla.developers.google.com/clas). +[Contributor License Agreement](https://identity.linuxfoundation.org/projects/cncf). ## Guidelines for Pull Requests How to get your contributions merged smoothly and quickly. diff --git a/vendor/google.golang.org/grpc/balancer.go b/vendor/google.golang.org/grpc/balancer.go index ab65049dd..300da6c5e 100644 --- a/vendor/google.golang.org/grpc/balancer.go +++ b/vendor/google.golang.org/grpc/balancer.go @@ -28,6 +28,7 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/naming" + "google.golang.org/grpc/status" ) // Address represents a server the client connects to. @@ -310,7 +311,7 @@ func (rr *roundRobin) Get(ctx context.Context, opts BalancerGetOptions) (addr Ad if !opts.BlockingWait { if len(rr.addrs) == 0 { rr.mu.Unlock() - err = Errorf(codes.Unavailable, "there is no address available") + err = status.Errorf(codes.Unavailable, "there is no address available") return } // Returns the next addr on rr.addrs for failfast RPCs. diff --git a/vendor/google.golang.org/grpc/balancer/balancer.go b/vendor/google.golang.org/grpc/balancer/balancer.go index cd2682f5f..219a2940c 100644 --- a/vendor/google.golang.org/grpc/balancer/balancer.go +++ b/vendor/google.golang.org/grpc/balancer/balancer.go @@ -23,6 +23,7 @@ package balancer import ( "errors" "net" + "strings" "golang.org/x/net/context" "google.golang.org/grpc/connectivity" @@ -36,15 +37,17 @@ var ( ) // Register registers the balancer builder to the balancer map. -// b.Name will be used as the name registered with this builder. +// b.Name (lowercased) will be used as the name registered with +// this builder. func Register(b Builder) { - m[b.Name()] = b + m[strings.ToLower(b.Name())] = b } // Get returns the resolver builder registered with the given name. +// Note that the compare is done in a case-insenstive fashion. // If no builder is register with the name, nil will be returned. func Get(name string) Builder { - if b, ok := m[name]; ok { + if b, ok := m[strings.ToLower(name)]; ok { return b } return nil @@ -63,6 +66,11 @@ func Get(name string) Builder { // When the connection encounters an error, it will reconnect immediately. // When the connection becomes IDLE, it will not reconnect unless Connect is // called. +// +// This interface is to be implemented by gRPC. Users should not need a +// brand new implementation of this interface. For the situations like +// testing, the new implementation should embed this interface. This allows +// gRPC to add new methods to this interface. type SubConn interface { // UpdateAddresses updates the addresses used in this SubConn. // gRPC checks if currently-connected address is still in the new list. @@ -80,6 +88,11 @@ type SubConn interface { type NewSubConnOptions struct{} // ClientConn represents a gRPC ClientConn. +// +// This interface is to be implemented by gRPC. Users should not need a +// brand new implementation of this interface. For the situations like +// testing, the new implementation should embed this interface. This allows +// gRPC to add new methods to this interface. type ClientConn interface { // NewSubConn is called by balancer to create a new SubConn. // It doesn't block and wait for the connections to be established. @@ -96,6 +109,9 @@ type ClientConn interface { // on the new picker to pick new SubConn. UpdateBalancerState(s connectivity.State, p Picker) + // ResolveNow is called by balancer to notify gRPC to do a name resolving. + ResolveNow(resolver.ResolveNowOption) + // Target returns the dial target for this ClientConn. Target() string } @@ -128,6 +144,10 @@ type PickOptions struct{} type DoneInfo struct { // Err is the rpc error the RPC finished with. It could be nil. Err error + // BytesSent indicates if any bytes have been sent to the server. + BytesSent bool + // BytesReceived indicates if any byte has been received from the server. + BytesReceived bool } var ( diff --git a/vendor/google.golang.org/grpc/balancer/base/balancer.go b/vendor/google.golang.org/grpc/balancer/base/balancer.go new file mode 100644 index 000000000..1e962b724 --- /dev/null +++ b/vendor/google.golang.org/grpc/balancer/base/balancer.go @@ -0,0 +1,209 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package base + +import ( + "golang.org/x/net/context" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/resolver" +) + +type baseBuilder struct { + name string + pickerBuilder PickerBuilder +} + +func (bb *baseBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer { + return &baseBalancer{ + cc: cc, + pickerBuilder: bb.pickerBuilder, + + subConns: make(map[resolver.Address]balancer.SubConn), + scStates: make(map[balancer.SubConn]connectivity.State), + csEvltr: &connectivityStateEvaluator{}, + // Initialize picker to a picker that always return + // ErrNoSubConnAvailable, because when state of a SubConn changes, we + // may call UpdateBalancerState with this picker. + picker: NewErrPicker(balancer.ErrNoSubConnAvailable), + } +} + +func (bb *baseBuilder) Name() string { + return bb.name +} + +type baseBalancer struct { + cc balancer.ClientConn + pickerBuilder PickerBuilder + + csEvltr *connectivityStateEvaluator + state connectivity.State + + subConns map[resolver.Address]balancer.SubConn + scStates map[balancer.SubConn]connectivity.State + picker balancer.Picker +} + +func (b *baseBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) { + if err != nil { + grpclog.Infof("base.baseBalancer: HandleResolvedAddrs called with error %v", err) + return + } + grpclog.Infoln("base.baseBalancer: got new resolved addresses: ", addrs) + // addrsSet is the set converted from addrs, it's used for quick lookup of an address. + addrsSet := make(map[resolver.Address]struct{}) + for _, a := range addrs { + addrsSet[a] = struct{}{} + if _, ok := b.subConns[a]; !ok { + // a is a new address (not existing in b.subConns). + sc, err := b.cc.NewSubConn([]resolver.Address{a}, balancer.NewSubConnOptions{}) + if err != nil { + grpclog.Warningf("base.baseBalancer: failed to create new SubConn: %v", err) + continue + } + b.subConns[a] = sc + b.scStates[sc] = connectivity.Idle + sc.Connect() + } + } + for a, sc := range b.subConns { + // a was removed by resolver. + if _, ok := addrsSet[a]; !ok { + b.cc.RemoveSubConn(sc) + delete(b.subConns, a) + // Keep the state of this sc in b.scStates until sc's state becomes Shutdown. + // The entry will be deleted in HandleSubConnStateChange. + } + } +} + +// regeneratePicker takes a snapshot of the balancer, and generates a picker +// from it. The picker is +// - errPicker with ErrTransientFailure if the balancer is in TransientFailure, +// - built by the pickerBuilder with all READY SubConns otherwise. +func (b *baseBalancer) regeneratePicker() { + if b.state == connectivity.TransientFailure { + b.picker = NewErrPicker(balancer.ErrTransientFailure) + return + } + readySCs := make(map[resolver.Address]balancer.SubConn) + + // Filter out all ready SCs from full subConn map. + for addr, sc := range b.subConns { + if st, ok := b.scStates[sc]; ok && st == connectivity.Ready { + readySCs[addr] = sc + } + } + b.picker = b.pickerBuilder.Build(readySCs) +} + +func (b *baseBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) { + grpclog.Infof("base.baseBalancer: handle SubConn state change: %p, %v", sc, s) + oldS, ok := b.scStates[sc] + if !ok { + grpclog.Infof("base.baseBalancer: got state changes for an unknown SubConn: %p, %v", sc, s) + return + } + b.scStates[sc] = s + switch s { + case connectivity.Idle: + sc.Connect() + case connectivity.Shutdown: + // When an address was removed by resolver, b called RemoveSubConn but + // kept the sc's state in scStates. Remove state for this sc here. + delete(b.scStates, sc) + } + + oldAggrState := b.state + b.state = b.csEvltr.recordTransition(oldS, s) + + // Regenerate picker when one of the following happens: + // - this sc became ready from not-ready + // - this sc became not-ready from ready + // - the aggregated state of balancer became TransientFailure from non-TransientFailure + // - the aggregated state of balancer became non-TransientFailure from TransientFailure + if (s == connectivity.Ready) != (oldS == connectivity.Ready) || + (b.state == connectivity.TransientFailure) != (oldAggrState == connectivity.TransientFailure) { + b.regeneratePicker() + } + + b.cc.UpdateBalancerState(b.state, b.picker) + return +} + +// Close is a nop because base balancer doesn't have internal state to clean up, +// and it doesn't need to call RemoveSubConn for the SubConns. +func (b *baseBalancer) Close() { +} + +// NewErrPicker returns a picker that always returns err on Pick(). +func NewErrPicker(err error) balancer.Picker { + return &errPicker{err: err} +} + +type errPicker struct { + err error // Pick() always returns this err. +} + +func (p *errPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) { + return nil, nil, p.err +} + +// connectivityStateEvaluator gets updated by addrConns when their +// states transition, based on which it evaluates the state of +// ClientConn. +type connectivityStateEvaluator struct { + numReady uint64 // Number of addrConns in ready state. + numConnecting uint64 // Number of addrConns in connecting state. + numTransientFailure uint64 // Number of addrConns in transientFailure. +} + +// recordTransition records state change happening in every subConn and based on +// that it evaluates what aggregated state should be. +// It can only transition between Ready, Connecting and TransientFailure. Other states, +// Idle and Shutdown are transitioned into by ClientConn; in the beginning of the connection +// before any subConn is created ClientConn is in idle state. In the end when ClientConn +// closes it is in Shutdown state. +// +// recordTransition should only be called synchronously from the same goroutine. +func (cse *connectivityStateEvaluator) recordTransition(oldState, newState connectivity.State) connectivity.State { + // Update counters. + for idx, state := range []connectivity.State{oldState, newState} { + updateVal := 2*uint64(idx) - 1 // -1 for oldState and +1 for new. + switch state { + case connectivity.Ready: + cse.numReady += updateVal + case connectivity.Connecting: + cse.numConnecting += updateVal + case connectivity.TransientFailure: + cse.numTransientFailure += updateVal + } + } + + // Evaluate. + if cse.numReady > 0 { + return connectivity.Ready + } + if cse.numConnecting > 0 { + return connectivity.Connecting + } + return connectivity.TransientFailure +} diff --git a/vendor/google.golang.org/grpc/balancer/base/base.go b/vendor/google.golang.org/grpc/balancer/base/base.go new file mode 100644 index 000000000..012ace2f2 --- /dev/null +++ b/vendor/google.golang.org/grpc/balancer/base/base.go @@ -0,0 +1,52 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package base defines a balancer base that can be used to build balancers with +// different picking algorithms. +// +// The base balancer creates a new SubConn for each resolved address. The +// provided picker will only be notified about READY SubConns. +// +// This package is the base of round_robin balancer, its purpose is to be used +// to build round_robin like balancers with complex picking algorithms. +// Balancers with more complicated logic should try to implement a balancer +// builder from scratch. +// +// All APIs in this package are experimental. +package base + +import ( + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/resolver" +) + +// PickerBuilder creates balancer.Picker. +type PickerBuilder interface { + // Build takes a slice of ready SubConns, and returns a picker that will be + // used by gRPC to pick a SubConn. + Build(readySCs map[resolver.Address]balancer.SubConn) balancer.Picker +} + +// NewBalancerBuilder returns a balancer builder. The balancers +// built by this builder will use the picker builder to build pickers. +func NewBalancerBuilder(name string, pb PickerBuilder) balancer.Builder { + return &baseBuilder{ + name: name, + pickerBuilder: pb, + } +} diff --git a/vendor/google.golang.org/grpc/balancer/roundrobin/roundrobin.go b/vendor/google.golang.org/grpc/balancer/roundrobin/roundrobin.go index 9d2fbcd84..2eda0a1c2 100644 --- a/vendor/google.golang.org/grpc/balancer/roundrobin/roundrobin.go +++ b/vendor/google.golang.org/grpc/balancer/roundrobin/roundrobin.go @@ -26,145 +26,37 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc/balancer" - "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/balancer/base" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/resolver" ) +// Name is the name of round_robin balancer. +const Name = "round_robin" + // newBuilder creates a new roundrobin balancer builder. func newBuilder() balancer.Builder { - return &rrBuilder{} + return base.NewBalancerBuilder(Name, &rrPickerBuilder{}) } func init() { balancer.Register(newBuilder()) } -type rrBuilder struct{} - -func (*rrBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer { - return &rrBalancer{ - cc: cc, - subConns: make(map[resolver.Address]balancer.SubConn), - scStates: make(map[balancer.SubConn]connectivity.State), - csEvltr: &connectivityStateEvaluator{}, - // Initialize picker to a picker that always return - // ErrNoSubConnAvailable, because when state of a SubConn changes, we - // may call UpdateBalancerState with this picker. - picker: newPicker([]balancer.SubConn{}, nil), - } -} - -func (*rrBuilder) Name() string { - return "round_robin" -} +type rrPickerBuilder struct{} -type rrBalancer struct { - cc balancer.ClientConn - - csEvltr *connectivityStateEvaluator - state connectivity.State - - subConns map[resolver.Address]balancer.SubConn - scStates map[balancer.SubConn]connectivity.State - picker *picker -} - -func (b *rrBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) { - if err != nil { - grpclog.Infof("roundrobin.rrBalancer: HandleResolvedAddrs called with error %v", err) - return - } - grpclog.Infoln("roundrobin.rrBalancer: got new resolved addresses: ", addrs) - // addrsSet is the set converted from addrs, it's used for quick lookup of an address. - addrsSet := make(map[resolver.Address]struct{}) - for _, a := range addrs { - addrsSet[a] = struct{}{} - if _, ok := b.subConns[a]; !ok { - // a is a new address (not existing in b.subConns). - sc, err := b.cc.NewSubConn([]resolver.Address{a}, balancer.NewSubConnOptions{}) - if err != nil { - grpclog.Warningf("roundrobin.rrBalancer: failed to create new SubConn: %v", err) - continue - } - b.subConns[a] = sc - b.scStates[sc] = connectivity.Idle - sc.Connect() - } +func (*rrPickerBuilder) Build(readySCs map[resolver.Address]balancer.SubConn) balancer.Picker { + grpclog.Infof("roundrobinPicker: newPicker called with readySCs: %v", readySCs) + var scs []balancer.SubConn + for _, sc := range readySCs { + scs = append(scs, sc) } - for a, sc := range b.subConns { - // a was removed by resolver. - if _, ok := addrsSet[a]; !ok { - b.cc.RemoveSubConn(sc) - delete(b.subConns, a) - // Keep the state of this sc in b.scStates until sc's state becomes Shutdown. - // The entry will be deleted in HandleSubConnStateChange. - } - } -} - -// regeneratePicker takes a snapshot of the balancer, and generates a picker -// from it. The picker -// - always returns ErrTransientFailure if the balancer is in TransientFailure, -// - or does round robin selection of all READY SubConns otherwise. -func (b *rrBalancer) regeneratePicker() { - if b.state == connectivity.TransientFailure { - b.picker = newPicker(nil, balancer.ErrTransientFailure) - return - } - var readySCs []balancer.SubConn - for sc, st := range b.scStates { - if st == connectivity.Ready { - readySCs = append(readySCs, sc) - } - } - b.picker = newPicker(readySCs, nil) -} - -func (b *rrBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) { - grpclog.Infof("roundrobin.rrBalancer: handle SubConn state change: %p, %v", sc, s) - oldS, ok := b.scStates[sc] - if !ok { - grpclog.Infof("roundrobin.rrBalancer: got state changes for an unknown SubConn: %p, %v", sc, s) - return - } - b.scStates[sc] = s - switch s { - case connectivity.Idle: - sc.Connect() - case connectivity.Shutdown: - // When an address was removed by resolver, b called RemoveSubConn but - // kept the sc's state in scStates. Remove state for this sc here. - delete(b.scStates, sc) - } - - oldAggrState := b.state - b.state = b.csEvltr.recordTransition(oldS, s) - - // Regenerate picker when one of the following happens: - // - this sc became ready from not-ready - // - this sc became not-ready from ready - // - the aggregated state of balancer became TransientFailure from non-TransientFailure - // - the aggregated state of balancer became non-TransientFailure from TransientFailure - if (s == connectivity.Ready) != (oldS == connectivity.Ready) || - (b.state == connectivity.TransientFailure) != (oldAggrState == connectivity.TransientFailure) { - b.regeneratePicker() + return &rrPicker{ + subConns: scs, } - - b.cc.UpdateBalancerState(b.state, b.picker) - return -} - -// Close is a nop because roundrobin balancer doesn't internal state to clean -// up, and it doesn't need to call RemoveSubConn for the SubConns. -func (b *rrBalancer) Close() { } -type picker struct { - // If err is not nil, Pick always returns this err. It's immutable after - // picker is created. - err error - +type rrPicker struct { // subConns is the snapshot of the roundrobin balancer when this picker was // created. The slice is immutable. Each Get() will do a round robin // selection from it and return the selected SubConn. @@ -174,20 +66,7 @@ type picker struct { next int } -func newPicker(scs []balancer.SubConn, err error) *picker { - grpclog.Infof("roundrobinPicker: newPicker called with scs: %v, %v", scs, err) - if err != nil { - return &picker{err: err} - } - return &picker{ - subConns: scs, - } -} - -func (p *picker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) { - if p.err != nil { - return nil, nil, p.err - } +func (p *rrPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) { if len(p.subConns) <= 0 { return nil, nil, balancer.ErrNoSubConnAvailable } @@ -198,44 +77,3 @@ func (p *picker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer. p.mu.Unlock() return sc, nil, nil } - -// connectivityStateEvaluator gets updated by addrConns when their -// states transition, based on which it evaluates the state of -// ClientConn. -type connectivityStateEvaluator struct { - numReady uint64 // Number of addrConns in ready state. - numConnecting uint64 // Number of addrConns in connecting state. - numTransientFailure uint64 // Number of addrConns in transientFailure. -} - -// recordTransition records state change happening in every subConn and based on -// that it evaluates what aggregated state should be. -// It can only transition between Ready, Connecting and TransientFailure. Other states, -// Idle and Shutdown are transitioned into by ClientConn; in the beginning of the connection -// before any subConn is created ClientConn is in idle state. In the end when ClientConn -// closes it is in Shutdown state. -// -// recordTransition should only be called synchronously from the same goroutine. -func (cse *connectivityStateEvaluator) recordTransition(oldState, newState connectivity.State) connectivity.State { - // Update counters. - for idx, state := range []connectivity.State{oldState, newState} { - updateVal := 2*uint64(idx) - 1 // -1 for oldState and +1 for new. - switch state { - case connectivity.Ready: - cse.numReady += updateVal - case connectivity.Connecting: - cse.numConnecting += updateVal - case connectivity.TransientFailure: - cse.numTransientFailure += updateVal - } - } - - // Evaluate. - if cse.numReady > 0 { - return connectivity.Ready - } - if cse.numConnecting > 0 { - return connectivity.Connecting - } - return connectivity.TransientFailure -} diff --git a/vendor/google.golang.org/grpc/balancer/roundrobin/roundrobin_test.go b/vendor/google.golang.org/grpc/balancer/roundrobin/roundrobin_test.go index 7f953ff00..59cac4b1b 100644 --- a/vendor/google.golang.org/grpc/balancer/roundrobin/roundrobin_test.go +++ b/vendor/google.golang.org/grpc/balancer/roundrobin/roundrobin_test.go @@ -27,18 +27,17 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc" - "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/roundrobin" "google.golang.org/grpc/codes" _ "google.golang.org/grpc/grpclog/glogger" "google.golang.org/grpc/peer" "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" + "google.golang.org/grpc/status" testpb "google.golang.org/grpc/test/grpc_testing" "google.golang.org/grpc/test/leakcheck" ) -var rr = balancer.Get("round_robin") - type testServer struct { testpb.TestServiceServer } @@ -102,7 +101,7 @@ func TestOneBackend(t *testing.T) { } defer test.cleanup() - cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr)) + cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name)) if err != nil { t.Fatalf("failed to dial: %v", err) } @@ -111,7 +110,7 @@ func TestOneBackend(t *testing.T) { // The first RPC should fail because there's no address. ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) defer cancel() - if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || grpc.Code(err) != codes.DeadlineExceeded { + if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded { t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err) } @@ -134,7 +133,7 @@ func TestBackendsRoundRobin(t *testing.T) { } defer test.cleanup() - cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr)) + cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name)) if err != nil { t.Fatalf("failed to dial: %v", err) } @@ -143,7 +142,7 @@ func TestBackendsRoundRobin(t *testing.T) { // The first RPC should fail because there's no address. ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) defer cancel() - if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || grpc.Code(err) != codes.DeadlineExceeded { + if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded { t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err) } @@ -193,7 +192,7 @@ func TestAddressesRemoved(t *testing.T) { } defer test.cleanup() - cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr)) + cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name)) if err != nil { t.Fatalf("failed to dial: %v", err) } @@ -202,7 +201,7 @@ func TestAddressesRemoved(t *testing.T) { // The first RPC should fail because there's no address. ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) defer cancel() - if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || grpc.Code(err) != codes.DeadlineExceeded { + if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded { t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err) } @@ -216,7 +215,7 @@ func TestAddressesRemoved(t *testing.T) { for i := 0; i < 1000; i++ { ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) defer cancel() - if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) == codes.DeadlineExceeded { + if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); status.Code(err) == codes.DeadlineExceeded { return } time.Sleep(time.Millisecond) @@ -235,7 +234,7 @@ func TestCloseWithPendingRPC(t *testing.T) { } defer test.cleanup() - cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr)) + cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name)) if err != nil { t.Fatalf("failed to dial: %v", err) } @@ -248,7 +247,7 @@ func TestCloseWithPendingRPC(t *testing.T) { defer wg.Done() // This RPC blocks until cc is closed. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); grpc.Code(err) == codes.DeadlineExceeded { + if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) == codes.DeadlineExceeded { t.Errorf("RPC failed because of deadline after cc is closed; want error the client connection is closing") } cancel() @@ -269,7 +268,7 @@ func TestNewAddressWhileBlocking(t *testing.T) { } defer test.cleanup() - cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr)) + cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name)) if err != nil { t.Fatalf("failed to dial: %v", err) } @@ -278,7 +277,7 @@ func TestNewAddressWhileBlocking(t *testing.T) { // The first RPC should fail because there's no address. ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) defer cancel() - if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || grpc.Code(err) != codes.DeadlineExceeded { + if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded { t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err) } @@ -318,7 +317,7 @@ func TestOneServerDown(t *testing.T) { } defer test.cleanup() - cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr)) + cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name), grpc.WithWaitForHandshake()) if err != nil { t.Fatalf("failed to dial: %v", err) } @@ -327,7 +326,7 @@ func TestOneServerDown(t *testing.T) { // The first RPC should fail because there's no address. ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) defer cancel() - if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || grpc.Code(err) != codes.DeadlineExceeded { + if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded { t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err) } @@ -372,6 +371,7 @@ func TestOneServerDown(t *testing.T) { var targetSeen int for i := 0; i < 1000; i++ { if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil { + targetSeen = 0 t.Logf("EmptyCall() = _, %v, want _, <nil>", err) // Due to a race, this RPC could possibly get the connection that // was closing, and this RPC may fail. Keep trying when this @@ -415,7 +415,7 @@ func TestAllServersDown(t *testing.T) { } defer test.cleanup() - cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr)) + cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name), grpc.WithWaitForHandshake()) if err != nil { t.Fatalf("failed to dial: %v", err) } @@ -424,7 +424,7 @@ func TestAllServersDown(t *testing.T) { // The first RPC should fail because there's no address. ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) defer cancel() - if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || grpc.Code(err) != codes.DeadlineExceeded { + if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded { t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err) } @@ -468,7 +468,7 @@ func TestAllServersDown(t *testing.T) { } time.Sleep(100 * time.Millisecond) for i := 0; i < 1000; i++ { - if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) == codes.Unavailable { + if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}); status.Code(err) == codes.Unavailable { return } time.Sleep(time.Millisecond) diff --git a/vendor/google.golang.org/grpc/balancer_conn_wrappers.go b/vendor/google.golang.org/grpc/balancer_conn_wrappers.go index ebfee4a88..db6f0ae3f 100644 --- a/vendor/google.golang.org/grpc/balancer_conn_wrappers.go +++ b/vendor/google.golang.org/grpc/balancer_conn_wrappers.go @@ -19,6 +19,7 @@ package grpc import ( + "fmt" "sync" "google.golang.org/grpc/balancer" @@ -97,6 +98,7 @@ type ccBalancerWrapper struct { resolverUpdateCh chan *resolverUpdate done chan struct{} + mu sync.Mutex subConns map[*acBalancerWrapper]struct{} } @@ -141,7 +143,11 @@ func (ccb *ccBalancerWrapper) watcher() { select { case <-ccb.done: ccb.balancer.Close() - for acbw := range ccb.subConns { + ccb.mu.Lock() + scs := ccb.subConns + ccb.subConns = nil + ccb.mu.Unlock() + for acbw := range scs { ccb.cc.removeAddrConn(acbw.getAddrConn(), errConnDrain) } return @@ -183,6 +189,14 @@ func (ccb *ccBalancerWrapper) handleResolvedAddrs(addrs []resolver.Address, err } func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { + if len(addrs) <= 0 { + return nil, fmt.Errorf("grpc: cannot create SubConn with empty address list") + } + ccb.mu.Lock() + defer ccb.mu.Unlock() + if ccb.subConns == nil { + return nil, fmt.Errorf("grpc: ClientConn balancer wrapper was closed") + } ac, err := ccb.cc.newAddrConn(addrs) if err != nil { return nil, err @@ -200,15 +214,29 @@ func (ccb *ccBalancerWrapper) RemoveSubConn(sc balancer.SubConn) { if !ok { return } + ccb.mu.Lock() + defer ccb.mu.Unlock() + if ccb.subConns == nil { + return + } delete(ccb.subConns, acbw) ccb.cc.removeAddrConn(acbw.getAddrConn(), errConnDrain) } func (ccb *ccBalancerWrapper) UpdateBalancerState(s connectivity.State, p balancer.Picker) { + ccb.mu.Lock() + defer ccb.mu.Unlock() + if ccb.subConns == nil { + return + } ccb.cc.csMgr.updateState(s) ccb.cc.blockingpicker.updatePicker(p) } +func (ccb *ccBalancerWrapper) ResolveNow(o resolver.ResolveNowOption) { + ccb.cc.resolveNow(o) +} + func (ccb *ccBalancerWrapper) Target() string { return ccb.cc.target } @@ -223,6 +251,10 @@ type acBalancerWrapper struct { func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) { acbw.mu.Lock() defer acbw.mu.Unlock() + if len(addrs) <= 0 { + acbw.ac.tearDown(errConnDrain) + return + } if !acbw.ac.tryUpdateAddrs(addrs) { cc := acbw.ac.cc acbw.ac.mu.Lock() diff --git a/vendor/google.golang.org/grpc/balancer_switching_test.go b/vendor/google.golang.org/grpc/balancer_switching_test.go index 92c196d38..0d8b2a545 100644 --- a/vendor/google.golang.org/grpc/balancer_switching_test.go +++ b/vendor/google.golang.org/grpc/balancer_switching_test.go @@ -25,6 +25,7 @@ import ( "time" "golang.org/x/net/context" + "google.golang.org/grpc/balancer/roundrobin" _ "google.golang.org/grpc/grpclog/glogger" "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" @@ -38,8 +39,8 @@ func checkPickFirst(cc *ClientConn, servers []*server) error { err error ) connected := false - for i := 0; i < 1000; i++ { - if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); ErrorDesc(err) == servers[0].port { + for i := 0; i < 5000; i++ { + if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); errorDesc(err) == servers[0].port { if connected { // connected is set to false if peer is not server[0]. So if // connected is true here, this is the second time we saw @@ -53,12 +54,12 @@ func checkPickFirst(cc *ClientConn, servers []*server) error { time.Sleep(time.Millisecond) } if !connected { - return fmt.Errorf("pickfirst is not in effect after 1 second, EmptyCall() = _, %v, want _, %v", err, servers[0].port) + return fmt.Errorf("pickfirst is not in effect after 5 second, EmptyCall() = _, %v, want _, %v", err, servers[0].port) } // The following RPCs should all succeed with the first server. for i := 0; i < 3; i++ { err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc) - if ErrorDesc(err) != servers[0].port { + if errorDesc(err) != servers[0].port { return fmt.Errorf("Index %d: want peer %v, got peer %v", i, servers[0].port, err) } } @@ -78,15 +79,15 @@ func checkRoundRobin(cc *ClientConn, servers []*server) error { // picked by the closing pickfirst balancer, and the test becomes flaky. for _, s := range servers { var up bool - for i := 0; i < 1000; i++ { - if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); ErrorDesc(err) == s.port { + for i := 0; i < 5000; i++ { + if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); errorDesc(err) == s.port { up = true break } time.Sleep(time.Millisecond) } if !up { - return fmt.Errorf("server %v is not up within 1 second", s.port) + return fmt.Errorf("server %v is not up within 5 second", s.port) } } } @@ -94,7 +95,7 @@ func checkRoundRobin(cc *ClientConn, servers []*server) error { serverCount := len(servers) for i := 0; i < 3*serverCount; i++ { err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc) - if ErrorDesc(err) != servers[i%serverCount].port { + if errorDesc(err) != servers[i%serverCount].port { return fmt.Errorf("Index %d: want peer %v, got peer %v", i, servers[i%serverCount].port, err) } } @@ -131,3 +132,312 @@ func TestSwitchBalancer(t *testing.T) { t.Fatalf("check pickfirst returned non-nil error: %v", err) } } + +// Test that balancer specified by dial option will not be overridden. +func TestBalancerDialOption(t *testing.T) { + defer leakcheck.Check(t) + r, rcleanup := manual.GenerateAndRegisterManualResolver() + defer rcleanup() + + numServers := 2 + servers, _, scleanup := startServers(t, numServers, math.MaxInt32) + defer scleanup() + + cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithCodec(testCodec{}), WithBalancerName(roundrobin.Name)) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer cc.Close() + r.NewAddress([]resolver.Address{{Addr: servers[0].addr}, {Addr: servers[1].addr}}) + // The init balancer is roundrobin. + if err := checkRoundRobin(cc, servers); err != nil { + t.Fatalf("check roundrobin returned non-nil error: %v", err) + } + // Switch to pickfirst. + cc.handleServiceConfig(`{"loadBalancingPolicy": "pick_first"}`) + // Balancer is still roundrobin. + if err := checkRoundRobin(cc, servers); err != nil { + t.Fatalf("check roundrobin returned non-nil error: %v", err) + } +} + +// First addr update contains grpclb. +func TestSwitchBalancerGRPCLBFirst(t *testing.T) { + defer leakcheck.Check(t) + r, rcleanup := manual.GenerateAndRegisterManualResolver() + defer rcleanup() + + cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithCodec(testCodec{})) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer cc.Close() + + // ClientConn will switch balancer to grpclb when receives an address of + // type GRPCLB. + r.NewAddress([]resolver.Address{{Addr: "backend"}, {Addr: "grpclb", Type: resolver.GRPCLB}}) + var isGRPCLB bool + for i := 0; i < 5000; i++ { + cc.mu.Lock() + isGRPCLB = cc.curBalancerName == "grpclb" + cc.mu.Unlock() + if isGRPCLB { + break + } + time.Sleep(time.Millisecond) + } + if !isGRPCLB { + t.Fatalf("after 5 second, cc.balancer is of type %v, not grpclb", cc.curBalancerName) + } + + // New update containing new backend and new grpclb. Should not switch + // balancer. + r.NewAddress([]resolver.Address{{Addr: "backend2"}, {Addr: "grpclb2", Type: resolver.GRPCLB}}) + for i := 0; i < 200; i++ { + cc.mu.Lock() + isGRPCLB = cc.curBalancerName == "grpclb" + cc.mu.Unlock() + if !isGRPCLB { + break + } + time.Sleep(time.Millisecond) + } + if !isGRPCLB { + t.Fatalf("within 200 ms, cc.balancer switched to !grpclb, want grpclb") + } + + var isPickFirst bool + // Switch balancer to pickfirst. + r.NewAddress([]resolver.Address{{Addr: "backend"}}) + for i := 0; i < 5000; i++ { + cc.mu.Lock() + isPickFirst = cc.curBalancerName == PickFirstBalancerName + cc.mu.Unlock() + if isPickFirst { + break + } + time.Sleep(time.Millisecond) + } + if !isPickFirst { + t.Fatalf("after 5 second, cc.balancer is of type %v, not pick_first", cc.curBalancerName) + } +} + +// First addr update does not contain grpclb. +func TestSwitchBalancerGRPCLBSecond(t *testing.T) { + defer leakcheck.Check(t) + r, rcleanup := manual.GenerateAndRegisterManualResolver() + defer rcleanup() + + cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithCodec(testCodec{})) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer cc.Close() + + r.NewAddress([]resolver.Address{{Addr: "backend"}}) + var isPickFirst bool + for i := 0; i < 5000; i++ { + cc.mu.Lock() + isPickFirst = cc.curBalancerName == PickFirstBalancerName + cc.mu.Unlock() + if isPickFirst { + break + } + time.Sleep(time.Millisecond) + } + if !isPickFirst { + t.Fatalf("after 5 second, cc.balancer is of type %v, not pick_first", cc.curBalancerName) + } + + // ClientConn will switch balancer to grpclb when receives an address of + // type GRPCLB. + r.NewAddress([]resolver.Address{{Addr: "backend"}, {Addr: "grpclb", Type: resolver.GRPCLB}}) + var isGRPCLB bool + for i := 0; i < 5000; i++ { + cc.mu.Lock() + isGRPCLB = cc.curBalancerName == "grpclb" + cc.mu.Unlock() + if isGRPCLB { + break + } + time.Sleep(time.Millisecond) + } + if !isGRPCLB { + t.Fatalf("after 5 second, cc.balancer is of type %v, not grpclb", cc.curBalancerName) + } + + // New update containing new backend and new grpclb. Should not switch + // balancer. + r.NewAddress([]resolver.Address{{Addr: "backend2"}, {Addr: "grpclb2", Type: resolver.GRPCLB}}) + for i := 0; i < 200; i++ { + cc.mu.Lock() + isGRPCLB = cc.curBalancerName == "grpclb" + cc.mu.Unlock() + if !isGRPCLB { + break + } + time.Sleep(time.Millisecond) + } + if !isGRPCLB { + t.Fatalf("within 200 ms, cc.balancer switched to !grpclb, want grpclb") + } + + // Switch balancer back. + r.NewAddress([]resolver.Address{{Addr: "backend"}}) + for i := 0; i < 5000; i++ { + cc.mu.Lock() + isPickFirst = cc.curBalancerName == PickFirstBalancerName + cc.mu.Unlock() + if isPickFirst { + break + } + time.Sleep(time.Millisecond) + } + if !isPickFirst { + t.Fatalf("after 5 second, cc.balancer is of type %v, not pick_first", cc.curBalancerName) + } +} + +// Test that if the current balancer is roundrobin, after switching to grpclb, +// when the resolved address doesn't contain grpclb addresses, balancer will be +// switched back to roundrobin. +func TestSwitchBalancerGRPCLBRoundRobin(t *testing.T) { + defer leakcheck.Check(t) + r, rcleanup := manual.GenerateAndRegisterManualResolver() + defer rcleanup() + + cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithCodec(testCodec{})) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer cc.Close() + + r.NewServiceConfig(`{"loadBalancingPolicy": "round_robin"}`) + + r.NewAddress([]resolver.Address{{Addr: "backend"}}) + var isRoundRobin bool + for i := 0; i < 5000; i++ { + cc.mu.Lock() + isRoundRobin = cc.curBalancerName == "round_robin" + cc.mu.Unlock() + if isRoundRobin { + break + } + time.Sleep(time.Millisecond) + } + if !isRoundRobin { + t.Fatalf("after 5 second, cc.balancer is of type %v, not round_robin", cc.curBalancerName) + } + + // ClientConn will switch balancer to grpclb when receives an address of + // type GRPCLB. + r.NewAddress([]resolver.Address{{Addr: "grpclb", Type: resolver.GRPCLB}}) + var isGRPCLB bool + for i := 0; i < 5000; i++ { + cc.mu.Lock() + isGRPCLB = cc.curBalancerName == "grpclb" + cc.mu.Unlock() + if isGRPCLB { + break + } + time.Sleep(time.Millisecond) + } + if !isGRPCLB { + t.Fatalf("after 5 second, cc.balancer is of type %v, not grpclb", cc.curBalancerName) + } + + // Switch balancer back. + r.NewAddress([]resolver.Address{{Addr: "backend"}}) + for i := 0; i < 5000; i++ { + cc.mu.Lock() + isRoundRobin = cc.curBalancerName == "round_robin" + cc.mu.Unlock() + if isRoundRobin { + break + } + time.Sleep(time.Millisecond) + } + if !isRoundRobin { + t.Fatalf("after 5 second, cc.balancer is of type %v, not round_robin", cc.curBalancerName) + } +} + +// Test that if resolved address list contains grpclb, the balancer option in +// service config won't take effect. But when there's no grpclb address in a new +// resolved address list, balancer will be switched to the new one. +func TestSwitchBalancerGRPCLBServiceConfig(t *testing.T) { + defer leakcheck.Check(t) + r, rcleanup := manual.GenerateAndRegisterManualResolver() + defer rcleanup() + + cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithCodec(testCodec{})) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer cc.Close() + + r.NewAddress([]resolver.Address{{Addr: "backend"}}) + var isPickFirst bool + for i := 0; i < 5000; i++ { + cc.mu.Lock() + isPickFirst = cc.curBalancerName == PickFirstBalancerName + cc.mu.Unlock() + if isPickFirst { + break + } + time.Sleep(time.Millisecond) + } + if !isPickFirst { + t.Fatalf("after 5 second, cc.balancer is of type %v, not pick_first", cc.curBalancerName) + } + + // ClientConn will switch balancer to grpclb when receives an address of + // type GRPCLB. + r.NewAddress([]resolver.Address{{Addr: "grpclb", Type: resolver.GRPCLB}}) + var isGRPCLB bool + for i := 0; i < 5000; i++ { + cc.mu.Lock() + isGRPCLB = cc.curBalancerName == "grpclb" + cc.mu.Unlock() + if isGRPCLB { + break + } + time.Sleep(time.Millisecond) + } + if !isGRPCLB { + t.Fatalf("after 5 second, cc.balancer is of type %v, not grpclb", cc.curBalancerName) + } + + r.NewServiceConfig(`{"loadBalancingPolicy": "round_robin"}`) + var isRoundRobin bool + for i := 0; i < 200; i++ { + cc.mu.Lock() + isRoundRobin = cc.curBalancerName == "round_robin" + cc.mu.Unlock() + if isRoundRobin { + break + } + time.Sleep(time.Millisecond) + } + // Balancer should NOT switch to round_robin because resolved list contains + // grpclb. + if isRoundRobin { + t.Fatalf("within 200 ms, cc.balancer switched to round_robin, want grpclb") + } + + // Switch balancer back. + r.NewAddress([]resolver.Address{{Addr: "backend"}}) + for i := 0; i < 5000; i++ { + cc.mu.Lock() + isRoundRobin = cc.curBalancerName == "round_robin" + cc.mu.Unlock() + if isRoundRobin { + break + } + time.Sleep(time.Millisecond) + } + if !isRoundRobin { + t.Fatalf("after 5 second, cc.balancer is of type %v, not round_robin", cc.curBalancerName) + } +} diff --git a/vendor/google.golang.org/grpc/balancer_test.go b/vendor/google.golang.org/grpc/balancer_test.go index a1558f027..3e0fa19be 100644 --- a/vendor/google.golang.org/grpc/balancer_test.go +++ b/vendor/google.golang.org/grpc/balancer_test.go @@ -30,10 +30,12 @@ import ( "google.golang.org/grpc/codes" _ "google.golang.org/grpc/grpclog/glogger" "google.golang.org/grpc/naming" + "google.golang.org/grpc/status" "google.golang.org/grpc/test/leakcheck" // V1 balancer tests use passthrough resolver instead of dns. // TODO(bar) remove this when removing v1 balaner entirely. + _ "google.golang.org/grpc/resolver/passthrough" ) @@ -128,7 +130,7 @@ func TestNameDiscovery(t *testing.T) { defer cc.Close() req := "port" var reply string - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[0].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[0].port { t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port) } // Inject the name resolution change to remove servers[0] and add servers[1]. @@ -144,7 +146,7 @@ func TestNameDiscovery(t *testing.T) { r.w.inject(updates) // Loop until the rpcs in flight talks to servers[1]. for { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[1].port { break } time.Sleep(10 * time.Millisecond) @@ -204,7 +206,7 @@ func TestRoundRobin(t *testing.T) { var reply string // Loop until servers[1] is up for { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[1].port { break } time.Sleep(10 * time.Millisecond) @@ -217,14 +219,14 @@ func TestRoundRobin(t *testing.T) { r.w.inject([]*naming.Update{u}) // Loop until both servers[2] are up. for { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[2].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[2].port { break } time.Sleep(10 * time.Millisecond) } // Check the incoming RPCs served in a round-robin manner. for i := 0; i < 10; i++ { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[i%numServers].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[i%numServers].port { t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", i, err, servers[i%numServers].port) } } @@ -252,7 +254,7 @@ func TestCloseWithPendingRPC(t *testing.T) { // Loop until the above update applies. for { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) - if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); Code(err) == codes.DeadlineExceeded { + if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); status.Code(err) == codes.DeadlineExceeded { cancel() break } @@ -300,7 +302,7 @@ func TestGetOnWaitChannel(t *testing.T) { for { var reply string ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) - if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); Code(err) == codes.DeadlineExceeded { + if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); status.Code(err) == codes.DeadlineExceeded { cancel() break } @@ -332,7 +334,7 @@ func TestOneServerDown(t *testing.T) { numServers := 2 servers, r, cleanup := startServers(t, numServers, math.MaxUint32) defer cleanup() - cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}), WithWaitForHandshake()) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } @@ -348,7 +350,7 @@ func TestOneServerDown(t *testing.T) { var reply string // Loop until servers[1] is up for { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[1].port { break } time.Sleep(10 * time.Millisecond) @@ -401,7 +403,7 @@ func TestOneAddressRemoval(t *testing.T) { var reply string // Loop until servers[1] is up for { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[1].port { break } time.Sleep(10 * time.Millisecond) @@ -450,7 +452,7 @@ func checkServerUp(t *testing.T, currentServer *server) { defer cc.Close() var reply string for { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == port { break } time.Sleep(10 * time.Millisecond) @@ -511,7 +513,7 @@ func TestPickFirstCloseWithPendingRPC(t *testing.T) { // Loop until the above update applies. for { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) - if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); Code(err) == codes.DeadlineExceeded { + if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); status.Code(err) == codes.DeadlineExceeded { cancel() break } @@ -574,7 +576,7 @@ func TestPickFirstOrderAllServerUp(t *testing.T) { req := "port" var reply string for i := 0; i < 20; i++ { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[0].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[0].port { t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port) } time.Sleep(10 * time.Millisecond) @@ -589,13 +591,13 @@ func TestPickFirstOrderAllServerUp(t *testing.T) { r.w.inject([]*naming.Update{u}) // Loop until it changes to server[1] for { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[1].port { break } time.Sleep(10 * time.Millisecond) } for i := 0; i < 20; i++ { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[1].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[1].port { t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port) } time.Sleep(10 * time.Millisecond) @@ -609,7 +611,7 @@ func TestPickFirstOrderAllServerUp(t *testing.T) { } r.w.inject([]*naming.Update{u}) for i := 0; i < 20; i++ { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[1].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[1].port { t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port) } time.Sleep(10 * time.Millisecond) @@ -622,13 +624,13 @@ func TestPickFirstOrderAllServerUp(t *testing.T) { } r.w.inject([]*naming.Update{u}) for { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[2].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[2].port { break } time.Sleep(1 * time.Second) } for i := 0; i < 20; i++ { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[2].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[2].port { t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 2, err, servers[2].port) } time.Sleep(10 * time.Millisecond) @@ -641,13 +643,13 @@ func TestPickFirstOrderAllServerUp(t *testing.T) { } r.w.inject([]*naming.Update{u}) for { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[0].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[0].port { break } time.Sleep(1 * time.Second) } for i := 0; i < 20; i++ { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[0].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[0].port { t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port) } time.Sleep(10 * time.Millisecond) @@ -660,7 +662,7 @@ func TestPickFirstOrderOneServerDown(t *testing.T) { numServers := 3 servers, r, cleanup := startServers(t, numServers, math.MaxUint32) defer cleanup() - cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}), WithWaitForHandshake()) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } @@ -687,7 +689,7 @@ func TestPickFirstOrderOneServerDown(t *testing.T) { req := "port" var reply string for i := 0; i < 20; i++ { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[0].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[0].port { t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port) } time.Sleep(10 * time.Millisecond) @@ -698,13 +700,13 @@ func TestPickFirstOrderOneServerDown(t *testing.T) { servers[0].stop() // Loop until it changes to server[1] for { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[1].port { break } time.Sleep(10 * time.Millisecond) } for i := 0; i < 20; i++ { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[1].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[1].port { t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port) } time.Sleep(10 * time.Millisecond) @@ -719,7 +721,7 @@ func TestPickFirstOrderOneServerDown(t *testing.T) { checkServerUp(t, servers[0]) for i := 0; i < 20; i++ { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[1].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[1].port { t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port) } time.Sleep(10 * time.Millisecond) @@ -732,13 +734,13 @@ func TestPickFirstOrderOneServerDown(t *testing.T) { } r.w.inject([]*naming.Update{u}) for { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[0].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[0].port { break } time.Sleep(1 * time.Second) } for i := 0; i < 20; i++ { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[0].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[0].port { t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port) } time.Sleep(10 * time.Millisecond) diff --git a/vendor/google.golang.org/grpc/balancer_v1_wrapper.go b/vendor/google.golang.org/grpc/balancer_v1_wrapper.go index 6cb39071c..faabf87d0 100644 --- a/vendor/google.golang.org/grpc/balancer_v1_wrapper.go +++ b/vendor/google.golang.org/grpc/balancer_v1_wrapper.go @@ -28,6 +28,7 @@ import ( "google.golang.org/grpc/connectivity" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/resolver" + "google.golang.org/grpc/status" ) type balancerWrapperBuilder struct { @@ -173,10 +174,10 @@ func (bw *balancerWrapper) lbWatcher() { sc.Connect() } } else { - oldSC.UpdateAddresses(newAddrs) bw.mu.Lock() bw.connSt[oldSC].addr = addrs[0] bw.mu.Unlock() + oldSC.UpdateAddresses(newAddrs) } } else { var ( @@ -317,12 +318,12 @@ func (bw *balancerWrapper) Pick(ctx context.Context, opts balancer.PickOptions) Metadata: a.Metadata, }] if !ok && failfast { - return nil, nil, Errorf(codes.Unavailable, "there is no connection available") + return nil, nil, status.Errorf(codes.Unavailable, "there is no connection available") } if s, ok := bw.connSt[sc]; failfast && (!ok || s.s != connectivity.Ready) { // If the returned sc is not ready and RPC is failfast, // return error, and this RPC will fail. - return nil, nil, Errorf(codes.Unavailable, "there is no connection available") + return nil, nil, status.Errorf(codes.Unavailable, "there is no connection available") } } diff --git a/vendor/google.golang.org/grpc/call.go b/vendor/google.golang.org/grpc/call.go index 0854f84b9..13cf8b13b 100644 --- a/vendor/google.golang.org/grpc/call.go +++ b/vendor/google.golang.org/grpc/call.go @@ -29,6 +29,7 @@ import ( "google.golang.org/grpc/encoding" "google.golang.org/grpc/peer" "google.golang.org/grpc/stats" + "google.golang.org/grpc/status" "google.golang.org/grpc/transport" ) @@ -59,7 +60,7 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran } for { if c.maxReceiveMessageSize == nil { - return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") + return status.Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") } // Set dc if it exists and matches the message compression type used, @@ -113,7 +114,7 @@ func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, compressor = nil // Disable the legacy compressor. comp = encoding.GetCompressor(ct) if comp == nil { - return Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", ct) + return status.Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", ct) } } hdr, data, err := encode(dopts.codec, args, compressor, outPayload, comp) @@ -121,10 +122,10 @@ func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, return err } if c.maxSendMessageSize == nil { - return Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)") + return status.Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)") } if len(data) > *c.maxSendMessageSize { - return Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), *c.maxSendMessageSize) + return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), *c.maxSendMessageSize) } err = t.Write(stream, hdr, data, opts) if err == nil && outPayload != nil { @@ -277,11 +278,11 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli err = sendRequest(ctx, cc.dopts, cc.dopts.cp, c, callHdr, stream, t, args, topts) if err != nil { if done != nil { - updateRPCInfoInContext(ctx, rpcInfo{ - bytesSent: true, - bytesReceived: stream.BytesReceived(), + done(balancer.DoneInfo{ + Err: err, + BytesSent: true, + BytesReceived: stream.BytesReceived(), }) - done(balancer.DoneInfo{Err: err}) } // Retry a non-failfast RPC when // i) the server started to drain before this RPC was initiated. @@ -301,11 +302,11 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli err = recvResponse(ctx, cc.dopts, t, c, stream, reply) if err != nil { if done != nil { - updateRPCInfoInContext(ctx, rpcInfo{ - bytesSent: true, - bytesReceived: stream.BytesReceived(), + done(balancer.DoneInfo{ + Err: err, + BytesSent: true, + BytesReceived: stream.BytesReceived(), }) - done(balancer.DoneInfo{Err: err}) } if !c.failFast && stream.Unprocessed() { // In these cases, the server did not receive the data, but we still @@ -323,12 +324,13 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true) } t.CloseStream(stream, nil) + err = stream.Status().Err() if done != nil { - updateRPCInfoInContext(ctx, rpcInfo{ - bytesSent: true, - bytesReceived: stream.BytesReceived(), + done(balancer.DoneInfo{ + Err: err, + BytesSent: true, + BytesReceived: stream.BytesReceived(), }) - done(balancer.DoneInfo{Err: err}) } if !c.failFast && stream.Unprocessed() { // In these cases, the server did not receive the data, but we still @@ -339,6 +341,6 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli continue } } - return stream.Status().Err() + return err } } diff --git a/vendor/google.golang.org/grpc/call_test.go b/vendor/google.golang.org/grpc/call_test.go index f48d30e87..b95ade889 100644 --- a/vendor/google.golang.org/grpc/call_test.go +++ b/vendor/google.golang.org/grpc/call_test.go @@ -233,7 +233,7 @@ func TestInvokeLargeErr(t *testing.T) { if _, ok := status.FromError(err); !ok { t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.") } - if Code(err) != codes.Internal || len(ErrorDesc(err)) != sizeLargeErr { + if status.Code(err) != codes.Internal || len(errorDesc(err)) != sizeLargeErr { t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want an error of code %d and desc size %d", err, codes.Internal, sizeLargeErr) } cc.Close() @@ -250,7 +250,7 @@ func TestInvokeErrorSpecialChars(t *testing.T) { if _, ok := status.FromError(err); !ok { t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.") } - if got, want := ErrorDesc(err), weirdError; got != want { + if got, want := errorDesc(err), weirdError; got != want { t.Fatalf("grpc.Invoke(_, _, _, _, _) error = %q, want %q", got, want) } cc.Close() diff --git a/vendor/google.golang.org/grpc/clientconn.go b/vendor/google.golang.org/grpc/clientconn.go index ae605bc32..bfbef3621 100644 --- a/vendor/google.golang.org/grpc/clientconn.go +++ b/vendor/google.golang.org/grpc/clientconn.go @@ -95,8 +95,14 @@ type dialOptions struct { scChan <-chan ServiceConfig copts transport.ConnectOptions callOptions []CallOption - // This is to support v1 balancer. + // This is used by v1 balancer dial option WithBalancer to support v1 + // balancer, and also by WithBalancerName dial option. balancerBuilder balancer.Builder + // This is to support grpclb. + resolverBuilder resolver.Builder + // Custom user options for resolver.Build. + resolverBuildUserOptions interface{} + waitForHandshake bool } const ( @@ -107,6 +113,15 @@ const ( // DialOption configures how we set up the connection. type DialOption func(*dialOptions) +// WithWaitForHandshake blocks until the initial settings frame is received from the +// server before assigning RPCs to the connection. +// Experimental API. +func WithWaitForHandshake() DialOption { + return func(o *dialOptions) { + o.waitForHandshake = true + } +} + // WithWriteBufferSize lets you set the size of write buffer, this determines how much data can be batched // before doing a write on the wire. func WithWriteBufferSize(s int) DialOption { @@ -186,7 +201,8 @@ func WithDecompressor(dc Decompressor) DialOption { // WithBalancer returns a DialOption which sets a load balancer with the v1 API. // Name resolver will be ignored if this DialOption is specified. -// Deprecated: use the new balancer APIs in balancer package instead. +// +// Deprecated: use the new balancer APIs in balancer package and WithBalancerName. func WithBalancer(b Balancer) DialOption { return func(o *dialOptions) { o.balancerBuilder = &balancerWrapperBuilder{ @@ -195,12 +211,36 @@ func WithBalancer(b Balancer) DialOption { } } -// WithBalancerBuilder is for testing only. Users using custom balancers should -// register their balancer and use service config to choose the balancer to use. -func WithBalancerBuilder(b balancer.Builder) DialOption { - // TODO(bar) remove this when switching balancer is done. +// WithBalancerName sets the balancer that the ClientConn will be initialized +// with. Balancer registered with balancerName will be used. This function +// panics if no balancer was registered by balancerName. +// +// The balancer cannot be overridden by balancer option specified by service +// config. +// +// This is an EXPERIMENTAL API. +func WithBalancerName(balancerName string) DialOption { + builder := balancer.Get(balancerName) + if builder == nil { + panic(fmt.Sprintf("grpc.WithBalancerName: no balancer is registered for name %v", balancerName)) + } + return func(o *dialOptions) { + o.balancerBuilder = builder + } +} + +// withResolverBuilder is only for grpclb. +func withResolverBuilder(b resolver.Builder) DialOption { return func(o *dialOptions) { - o.balancerBuilder = b + o.resolverBuilder = b + } +} + +// WithResolverUserOptions returns a DialOption which sets the UserOptions +// field of resolver's BuildOption. +func WithResolverUserOptions(userOpt interface{}) DialOption { + return func(o *dialOptions) { + o.resolverBuildUserOptions = userOpt } } @@ -231,7 +271,7 @@ func WithBackoffConfig(b BackoffConfig) DialOption { return withBackoff(b) } -// withBackoff sets the backoff strategy used for retries after a +// withBackoff sets the backoff strategy used for connectRetryNum after a // failed connection attempt. // // This can be exported if arbitrary backoff strategies are allowed by gRPC. @@ -283,18 +323,23 @@ func WithTimeout(d time.Duration) DialOption { } } +func withContextDialer(f func(context.Context, string) (net.Conn, error)) DialOption { + return func(o *dialOptions) { + o.copts.Dialer = f + } +} + // WithDialer returns a DialOption that specifies a function to use for dialing network addresses. // If FailOnNonTempDialError() is set to true, and an error is returned by f, gRPC checks the error's // Temporary() method to decide if it should try to reconnect to the network address. func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption { - return func(o *dialOptions) { - o.copts.Dialer = func(ctx context.Context, addr string) (net.Conn, error) { + return withContextDialer( + func(ctx context.Context, addr string) (net.Conn, error) { if deadline, ok := ctx.Deadline(); ok { return f(addr, deadline.Sub(time.Now())) } return f(addr, 0) - } - } + }) } // WithStatsHandler returns a DialOption that specifies the stats handler @@ -480,17 +525,19 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * Dialer: cc.dopts.copts.Dialer, } - if cc.dopts.balancerBuilder != nil { - cc.customBalancer = true - // Build should not take long time. So it's ok to not have a goroutine for it. - cc.balancerWrapper = newCCBalancerWrapper(cc, cc.dopts.balancerBuilder, cc.balancerBuildOpts) - } - // Build the resolver. cc.resolverWrapper, err = newCCResolverWrapper(cc) if err != nil { return nil, fmt.Errorf("failed to build resolver: %v", err) } + // Start the resolver wrapper goroutine after resolverWrapper is created. + // + // If the goroutine is started before resolverWrapper is ready, the + // following may happen: The goroutine sends updates to cc. cc forwards + // those to balancer. Balancer creates new addrConn. addrConn fails to + // connect, and calls resolveNow(). resolveNow() tries to use the non-ready + // resolverWrapper. + cc.resolverWrapper.start() // A blocking dial blocks until the clientConn is ready. if cc.dopts.block { @@ -563,7 +610,6 @@ type ClientConn struct { dopts dialOptions csMgr *connectivityStateManager - customBalancer bool // If this is true, switching balancer will be disabled. balancerBuildOpts balancer.BuildOptions resolverWrapper *ccResolverWrapper blockingpicker *pickerWrapper @@ -575,6 +621,7 @@ type ClientConn struct { // Keepalive parameter can be updated if a GoAway is received. mkp keepalive.ClientParameters curBalancerName string + preBalancerName string // previous balancer name. curAddresses []resolver.Address balancerWrapper *ccBalancerWrapper } @@ -624,51 +671,92 @@ func (cc *ClientConn) handleResolvedAddrs(addrs []resolver.Address, err error) { cc.mu.Lock() defer cc.mu.Unlock() if cc.conns == nil { + // cc was closed. return } - // TODO(bar switching) when grpclb is submitted, check address type and start grpclb. - if !cc.customBalancer && cc.balancerWrapper == nil { - // No customBalancer was specified by DialOption, and this is the first - // time handling resolved addresses, create a pickfirst balancer. - builder := newPickfirstBuilder() - cc.curBalancerName = builder.Name() - cc.balancerWrapper = newCCBalancerWrapper(cc, builder, cc.balancerBuildOpts) + if reflect.DeepEqual(cc.curAddresses, addrs) { + return } - // TODO(bar switching) compare addresses, if there's no update, don't notify balancer. cc.curAddresses = addrs + + if cc.dopts.balancerBuilder == nil { + // Only look at balancer types and switch balancer if balancer dial + // option is not set. + var isGRPCLB bool + for _, a := range addrs { + if a.Type == resolver.GRPCLB { + isGRPCLB = true + break + } + } + var newBalancerName string + if isGRPCLB { + newBalancerName = grpclbName + } else { + // Address list doesn't contain grpclb address. Try to pick a + // non-grpclb balancer. + newBalancerName = cc.curBalancerName + // If current balancer is grpclb, switch to the previous one. + if newBalancerName == grpclbName { + newBalancerName = cc.preBalancerName + } + // The following could be true in two cases: + // - the first time handling resolved addresses + // (curBalancerName="") + // - the first time handling non-grpclb addresses + // (curBalancerName="grpclb", preBalancerName="") + if newBalancerName == "" { + newBalancerName = PickFirstBalancerName + } + } + cc.switchBalancer(newBalancerName) + } else if cc.balancerWrapper == nil { + // Balancer dial option was set, and this is the first time handling + // resolved addresses. Build a balancer with dopts.balancerBuilder. + cc.balancerWrapper = newCCBalancerWrapper(cc, cc.dopts.balancerBuilder, cc.balancerBuildOpts) + } + cc.balancerWrapper.handleResolvedAddrs(addrs, nil) } -// switchBalancer starts the switching from current balancer to the balancer with name. +// switchBalancer starts the switching from current balancer to the balancer +// with the given name. +// +// It will NOT send the current address list to the new balancer. If needed, +// caller of this function should send address list to the new balancer after +// this function returns. +// +// Caller must hold cc.mu. func (cc *ClientConn) switchBalancer(name string) { if cc.conns == nil { return } - grpclog.Infof("ClientConn switching balancer to %q", name) - if cc.customBalancer { - grpclog.Infoln("ignoring service config balancer configuration: WithBalancer DialOption used instead") + if strings.ToLower(cc.curBalancerName) == strings.ToLower(name) { return } - if cc.curBalancerName == name { + grpclog.Infof("ClientConn switching balancer to %q", name) + if cc.dopts.balancerBuilder != nil { + grpclog.Infoln("ignoring balancer switching: Balancer DialOption used instead") return } - // TODO(bar switching) change this to two steps: drain and close. // Keep track of sc in wrapper. - cc.balancerWrapper.close() + if cc.balancerWrapper != nil { + cc.balancerWrapper.close() + } builder := balancer.Get(name) if builder == nil { - grpclog.Infof("failed to get balancer builder for: %v (this should never happen...)", name) + grpclog.Infof("failed to get balancer builder for: %v, using pick_first instead", name) builder = newPickfirstBuilder() } + cc.preBalancerName = cc.curBalancerName cc.curBalancerName = builder.Name() cc.balancerWrapper = newCCBalancerWrapper(cc, builder, cc.balancerBuildOpts) - cc.balancerWrapper.handleResolvedAddrs(cc.curAddresses, nil) } func (cc *ClientConn) handleSubConnStateChange(sc balancer.SubConn, s connectivity.State) { @@ -684,6 +772,8 @@ func (cc *ClientConn) handleSubConnStateChange(sc balancer.SubConn, s connectivi } // newAddrConn creates an addrConn for addrs and adds it to cc.conns. +// +// Caller needs to make sure len(addrs) > 0. func (cc *ClientConn) newAddrConn(addrs []resolver.Address) (*addrConn, error) { ac := &addrConn{ cc: cc, @@ -774,6 +864,7 @@ func (ac *addrConn) tryUpdateAddrs(addrs []resolver.Address) bool { grpclog.Infof("addrConn: tryUpdateAddrs curAddrFound: %v", curAddrFound) if curAddrFound { ac.addrs = addrs + ac.reconnectIdx = 0 // Start reconnecting from beginning in the new list. } return curAddrFound @@ -816,13 +907,33 @@ func (cc *ClientConn) handleServiceConfig(js string) error { cc.mu.Lock() cc.scRaw = js cc.sc = sc - if sc.LB != nil { - cc.switchBalancer(*sc.LB) + if sc.LB != nil && *sc.LB != grpclbName { // "grpclb" is not a valid balancer option in service config. + if cc.curBalancerName == grpclbName { + // If current balancer is grpclb, there's at least one grpclb + // balancer address in the resolved list. Don't switch the balancer, + // but change the previous balancer name, so if a new resolved + // address list doesn't contain grpclb address, balancer will be + // switched to *sc.LB. + cc.preBalancerName = *sc.LB + } else { + cc.switchBalancer(*sc.LB) + cc.balancerWrapper.handleResolvedAddrs(cc.curAddresses, nil) + } } cc.mu.Unlock() return nil } +func (cc *ClientConn) resolveNow(o resolver.ResolveNowOption) { + cc.mu.Lock() + r := cc.resolverWrapper + cc.mu.Unlock() + if r == nil { + return + } + go r.resolveNow(o) +} + // Close tears down the ClientConn and all underlying connections. func (cc *ClientConn) Close() error { cc.cancel() @@ -859,15 +970,16 @@ type addrConn struct { ctx context.Context cancel context.CancelFunc - cc *ClientConn - curAddr resolver.Address - addrs []resolver.Address - dopts dialOptions - events trace.EventLog - acbw balancer.SubConn + cc *ClientConn + addrs []resolver.Address + dopts dialOptions + events trace.EventLog + acbw balancer.SubConn - mu sync.Mutex - state connectivity.State + mu sync.Mutex + curAddr resolver.Address + reconnectIdx int // The index in addrs list to start reconnecting from. + state connectivity.State // ready is closed and becomes nil when a new transport is up or failed // due to timeout. ready chan struct{} @@ -875,6 +987,14 @@ type addrConn struct { // The reason this addrConn is torn down. tearDownErr error + + connectRetryNum int + // backoffDeadline is the time until which resetTransport needs to + // wait before increasing connectRetryNum count. + backoffDeadline time.Time + // connectDeadline is the time by which all connection + // negotiations must complete. + connectDeadline time.Time } // adjustParams updates parameters used to create transports upon @@ -909,6 +1029,15 @@ func (ac *addrConn) errorf(format string, a ...interface{}) { // resetTransport recreates a transport to the address for ac. The old // transport will close itself on error or when the clientconn is closed. +// The created transport must receive initial settings frame from the server. +// In case that doesnt happen, transportMonitor will kill the newly created +// transport after connectDeadline has expired. +// In case there was an error on the transport before the settings frame was +// received, resetTransport resumes connecting to backends after the one that +// was previously connected to. In case end of the list is reached, resetTransport +// backs off until the original deadline. +// If the DialOption WithWaitForHandshake was set, resetTrasport returns +// successfully only after server settings are received. // // TODO(bar) make sure all state transitions are valid. func (ac *addrConn) resetTransport() error { @@ -922,19 +1051,38 @@ func (ac *addrConn) resetTransport() error { ac.ready = nil } ac.transport = nil - ac.curAddr = resolver.Address{} + ridx := ac.reconnectIdx ac.mu.Unlock() ac.cc.mu.RLock() ac.dopts.copts.KeepaliveParams = ac.cc.mkp ac.cc.mu.RUnlock() - for retries := 0; ; retries++ { - sleepTime := ac.dopts.bs.backoff(retries) - timeout := minConnectTimeout + var backoffDeadline, connectDeadline time.Time + for connectRetryNum := 0; ; connectRetryNum++ { ac.mu.Lock() - if timeout < time.Duration(int(sleepTime)/len(ac.addrs)) { - timeout = time.Duration(int(sleepTime) / len(ac.addrs)) + if ac.backoffDeadline.IsZero() { + // This means either a successful HTTP2 connection was established + // or this is the first time this addrConn is trying to establish a + // connection. + backoffFor := ac.dopts.bs.backoff(connectRetryNum) // time.Duration. + // This will be the duration that dial gets to finish. + dialDuration := minConnectTimeout + if backoffFor > dialDuration { + // Give dial more time as we keep failing to connect. + dialDuration = backoffFor + } + start := time.Now() + backoffDeadline = start.Add(backoffFor) + connectDeadline = start.Add(dialDuration) + ridx = 0 // Start connecting from the beginning. + } else { + // Continue trying to conect with the same deadlines. + connectRetryNum = ac.connectRetryNum + backoffDeadline = ac.backoffDeadline + connectDeadline = ac.connectDeadline + ac.backoffDeadline = time.Time{} + ac.connectDeadline = time.Time{} + ac.connectRetryNum = 0 } - connectTime := time.Now() if ac.state == connectivity.Shutdown { ac.mu.Unlock() return errConnClosing @@ -949,93 +1097,159 @@ func (ac *addrConn) resetTransport() error { copy(addrsIter, ac.addrs) copts := ac.dopts.copts ac.mu.Unlock() - for _, addr := range addrsIter { + connected, err := ac.createTransport(connectRetryNum, ridx, backoffDeadline, connectDeadline, addrsIter, copts) + if err != nil { + return err + } + if connected { + return nil + } + } +} + +// createTransport creates a connection to one of the backends in addrs. +// It returns true if a connection was established. +func (ac *addrConn) createTransport(connectRetryNum, ridx int, backoffDeadline, connectDeadline time.Time, addrs []resolver.Address, copts transport.ConnectOptions) (bool, error) { + for i := ridx; i < len(addrs); i++ { + addr := addrs[i] + target := transport.TargetInfo{ + Addr: addr.Addr, + Metadata: addr.Metadata, + Authority: ac.cc.authority, + } + done := make(chan struct{}) + onPrefaceReceipt := func() { ac.mu.Lock() - if ac.state == connectivity.Shutdown { - // ac.tearDown(...) has been invoked. - ac.mu.Unlock() - return errConnClosing + close(done) + if !ac.backoffDeadline.IsZero() { + // If we haven't already started reconnecting to + // other backends. + // Note, this can happen when writer notices an error + // and triggers resetTransport while at the same time + // reader receives the preface and invokes this closure. + ac.backoffDeadline = time.Time{} + ac.connectDeadline = time.Time{} + ac.connectRetryNum = 0 } ac.mu.Unlock() - sinfo := transport.TargetInfo{ - Addr: addr.Addr, - Metadata: addr.Metadata, - Authority: ac.cc.authority, - } - newTransport, err := transport.NewClientTransport(ac.cc.ctx, sinfo, copts, timeout) - if err != nil { - if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() { - ac.mu.Lock() - if ac.state != connectivity.Shutdown { - ac.state = connectivity.TransientFailure - ac.cc.handleSubConnStateChange(ac.acbw, ac.state) - } - ac.mu.Unlock() - return err - } - grpclog.Warningf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %v", err, addr) + } + // Do not cancel in the success path because of + // this issue in Go1.6: https://github.com/golang/go/issues/15078. + connectCtx, cancel := context.WithDeadline(ac.ctx, connectDeadline) + newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, target, copts, onPrefaceReceipt) + if err != nil { + cancel() + if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() { ac.mu.Lock() - if ac.state == connectivity.Shutdown { - // ac.tearDown(...) has been invoked. - ac.mu.Unlock() - return errConnClosing + if ac.state != connectivity.Shutdown { + ac.state = connectivity.TransientFailure + ac.cc.handleSubConnStateChange(ac.acbw, ac.state) } ac.mu.Unlock() - continue + return false, err } ac.mu.Lock() - ac.printf("ready") if ac.state == connectivity.Shutdown { // ac.tearDown(...) has been invoked. ac.mu.Unlock() - newTransport.Close() - return errConnClosing - } - ac.state = connectivity.Ready - ac.cc.handleSubConnStateChange(ac.acbw, ac.state) - t := ac.transport - ac.transport = newTransport - if t != nil { - t.Close() - } - ac.curAddr = addr - if ac.ready != nil { - close(ac.ready) - ac.ready = nil + return false, errConnClosing } ac.mu.Unlock() - return nil + grpclog.Warningf("grpc: addrConn.createTransport failed to connect to %v. Err :%v. Reconnecting...", addr, err) + continue + } + if ac.dopts.waitForHandshake { + select { + case <-done: + case <-connectCtx.Done(): + // Didn't receive server preface, must kill this new transport now. + grpclog.Warningf("grpc: addrConn.createTransport failed to receive server preface before deadline.") + newTr.Close() + break + case <-ac.ctx.Done(): + } } ac.mu.Lock() - ac.state = connectivity.TransientFailure + if ac.state == connectivity.Shutdown { + ac.mu.Unlock() + // ac.tearDonn(...) has been invoked. + newTr.Close() + return false, errConnClosing + } + ac.printf("ready") + ac.state = connectivity.Ready ac.cc.handleSubConnStateChange(ac.acbw, ac.state) + ac.transport = newTr + ac.curAddr = addr if ac.ready != nil { close(ac.ready) ac.ready = nil } - ac.mu.Unlock() - timer := time.NewTimer(sleepTime - time.Since(connectTime)) select { - case <-timer.C: - case <-ac.ctx.Done(): - timer.Stop() - return ac.ctx.Err() + case <-done: + // If the server has responded back with preface already, + // don't set the reconnect parameters. + default: + ac.connectRetryNum = connectRetryNum + ac.backoffDeadline = backoffDeadline + ac.connectDeadline = connectDeadline + ac.reconnectIdx = i + 1 // Start reconnecting from the next backend in the list. } + ac.mu.Unlock() + return true, nil + } + ac.mu.Lock() + ac.state = connectivity.TransientFailure + ac.cc.handleSubConnStateChange(ac.acbw, ac.state) + ac.cc.resolveNow(resolver.ResolveNowOption{}) + if ac.ready != nil { + close(ac.ready) + ac.ready = nil + } + ac.mu.Unlock() + timer := time.NewTimer(backoffDeadline.Sub(time.Now())) + select { + case <-timer.C: + case <-ac.ctx.Done(): timer.Stop() + return false, ac.ctx.Err() } + return false, nil } // Run in a goroutine to track the error in transport and create the // new transport if an error happens. It returns when the channel is closing. func (ac *addrConn) transportMonitor() { for { + var timer *time.Timer + var cdeadline <-chan time.Time ac.mu.Lock() t := ac.transport + if !ac.connectDeadline.IsZero() { + timer = time.NewTimer(ac.connectDeadline.Sub(time.Now())) + cdeadline = timer.C + } ac.mu.Unlock() // Block until we receive a goaway or an error occurs. select { case <-t.GoAway(): case <-t.Error(): + case <-cdeadline: + ac.mu.Lock() + // This implies that client received server preface. + if ac.backoffDeadline.IsZero() { + ac.mu.Unlock() + continue + } + ac.mu.Unlock() + timer = nil + // No server preface received until deadline. + // Kill the connection. + grpclog.Warningf("grpc: addrConn.transportMonitor didn't get server preface after waiting. Closing the new transport now.") + t.Close() + } + if timer != nil { + timer.Stop() } // If a GoAway happened, regardless of error, adjust our keepalive // parameters as appropriate. @@ -1053,6 +1267,7 @@ func (ac *addrConn) transportMonitor() { // resetTransport. Transition READY->CONNECTING is not valid. ac.state = connectivity.TransientFailure ac.cc.handleSubConnStateChange(ac.acbw, ac.state) + ac.cc.resolveNow(resolver.ResolveNowOption{}) ac.curAddr = resolver.Address{} ac.mu.Unlock() if err := ac.resetTransport(); err != nil { @@ -1140,6 +1355,9 @@ func (ac *addrConn) tearDown(err error) { ac.cancel() ac.mu.Lock() defer ac.mu.Unlock() + if ac.state == connectivity.Shutdown { + return + } ac.curAddr = resolver.Address{} if err == errConnDrain && ac.transport != nil { // GracefulClose(...) may be executed multiple times when @@ -1148,9 +1366,6 @@ func (ac *addrConn) tearDown(err error) { // address removal and GoAway. ac.transport.GracefulClose() } - if ac.state == connectivity.Shutdown { - return - } ac.state = connectivity.Shutdown ac.tearDownErr = err ac.cc.handleSubConnStateChange(ac.acbw, ac.state) diff --git a/vendor/google.golang.org/grpc/clientconn_test.go b/vendor/google.golang.org/grpc/clientconn_test.go index c0b0ba436..d87daf2b8 100644 --- a/vendor/google.golang.org/grpc/clientconn_test.go +++ b/vendor/google.golang.org/grpc/clientconn_test.go @@ -19,17 +19,21 @@ package grpc import ( + "io" "math" "net" "testing" "time" "golang.org/x/net/context" + "golang.org/x/net/http2" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/naming" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/resolver/manual" _ "google.golang.org/grpc/resolver/passthrough" "google.golang.org/grpc/test/leakcheck" "google.golang.org/grpc/testdata" @@ -44,6 +48,272 @@ func assertState(wantState connectivity.State, cc *ClientConn) (connectivity.Sta return state, state == wantState } +func TestDialWithMultipleBackendsNotSendingServerPreface(t *testing.T) { + defer leakcheck.Check(t) + numServers := 2 + servers := make([]net.Listener, numServers) + var err error + for i := 0; i < numServers; i++ { + servers[i], err = net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Error while listening. Err: %v", err) + } + } + dones := make([]chan struct{}, numServers) + for i := 0; i < numServers; i++ { + dones[i] = make(chan struct{}) + } + for i := 0; i < numServers; i++ { + go func(i int) { + defer func() { + close(dones[i]) + }() + conn, err := servers[i].Accept() + if err != nil { + t.Errorf("Error while accepting. Err: %v", err) + return + } + defer conn.Close() + switch i { + case 0: // 1st server accepts the connection and immediately closes it. + case 1: // 2nd server accepts the connection and sends settings frames. + framer := http2.NewFramer(conn, conn) + if err := framer.WriteSettings(http2.Setting{}); err != nil { + t.Errorf("Error while writing settings frame. %v", err) + return + } + conn.SetDeadline(time.Now().Add(time.Second)) + buf := make([]byte, 1024) + for { // Make sure the connection stays healthy. + _, err = conn.Read(buf) + if err == nil { + continue + } + if nerr, ok := err.(net.Error); !ok || !nerr.Timeout() { + t.Errorf("Server expected the conn.Read(_) to timeout instead got error: %v", err) + } + return + } + } + }(i) + } + r, cleanup := manual.GenerateAndRegisterManualResolver() + defer cleanup() + resolvedAddrs := make([]resolver.Address, numServers) + for i := 0; i < numServers; i++ { + resolvedAddrs[i] = resolver.Address{Addr: servers[i].Addr().String()} + } + r.InitialAddrs(resolvedAddrs) + client, err := Dial(r.Scheme()+":///test.server", WithInsecure()) + if err != nil { + t.Errorf("Dial failed. Err: %v", err) + } else { + defer client.Close() + } + time.Sleep(time.Second) // Close the servers after a second for cleanup. + for _, s := range servers { + s.Close() + } + for _, done := range dones { + <-done + } +} + +func TestDialWaitsForServerSettings(t *testing.T) { + defer leakcheck.Check(t) + server, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Error while listening. Err: %v", err) + } + defer server.Close() + done := make(chan struct{}) + sent := make(chan struct{}) + dialDone := make(chan struct{}) + go func() { // Launch the server. + defer func() { + close(done) + }() + conn, err := server.Accept() + if err != nil { + t.Errorf("Error while accepting. Err: %v", err) + return + } + defer conn.Close() + // Sleep so that if the test were to fail it + // will fail more often than not. + time.Sleep(100 * time.Millisecond) + framer := http2.NewFramer(conn, conn) + close(sent) + if err := framer.WriteSettings(http2.Setting{}); err != nil { + t.Errorf("Error while writing settings. Err: %v", err) + return + } + <-dialDone // Close conn only after dial returns. + }() + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + client, err := DialContext(ctx, server.Addr().String(), WithInsecure(), WithWaitForHandshake(), WithBlock()) + close(dialDone) + if err != nil { + cancel() + t.Fatalf("Error while dialing. Err: %v", err) + } + defer client.Close() + select { + case <-sent: + default: + t.Fatalf("Dial returned before server settings were sent") + } + <-done + +} + +func TestCloseConnectionWhenServerPrefaceNotReceived(t *testing.T) { + mctBkp := minConnectTimeout + // Call this only after transportMonitor goroutine has ended. + defer func() { + minConnectTimeout = mctBkp + }() + defer leakcheck.Check(t) + minConnectTimeout = time.Millisecond * 500 + server, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Error while listening. Err: %v", err) + } + defer server.Close() + done := make(chan struct{}) + clientDone := make(chan struct{}) + go func() { // Launch the server. + defer func() { + if done != nil { + close(done) + } + }() + conn1, err := server.Accept() + if err != nil { + t.Errorf("Error while accepting. Err: %v", err) + return + } + defer conn1.Close() + // Don't send server settings and make sure the connection is closed. + time.Sleep(time.Millisecond * 1500) // Since the first backoff is for a second. + conn1.SetDeadline(time.Now().Add(time.Second)) + b := make([]byte, 24) + for { + // Make sure the connection was closed by client. + _, err = conn1.Read(b) + if err == nil { + continue + } + if err != io.EOF { + t.Errorf(" conn1.Read(_) = _, %v, want _, io.EOF", err) + return + } + break + } + + conn2, err := server.Accept() // Accept a reconnection request from client. + if err != nil { + t.Errorf("Error while accepting. Err: %v", err) + return + } + defer conn2.Close() + framer := http2.NewFramer(conn2, conn2) + if err := framer.WriteSettings(http2.Setting{}); err != nil { + t.Errorf("Error while writing settings. Err: %v", err) + return + } + time.Sleep(time.Millisecond * 1500) // Since the first backoff is for a second. + conn2.SetDeadline(time.Now().Add(time.Millisecond * 500)) + for { + // Make sure the connection stays open and is closed + // only by connection timeout. + _, err = conn2.Read(b) + if err == nil { + continue + } + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + return + } + t.Errorf("Unexpected error while reading. Err: %v, want timeout error", err) + break + } + close(done) + done = nil + <-clientDone + + }() + client, err := Dial(server.Addr().String(), WithInsecure()) + if err != nil { + t.Fatalf("Error while dialing. Err: %v", err) + } + <-done + // TODO: The code from BEGIN to END should be delete once issue + // https://github.com/grpc/grpc-go/issues/1750 is fixed. + // BEGIN + // Set underlying addrConns state to Shutdown so that no reconnect + // attempts take place and thereby resetting minConnectTimeout is + // race free. + client.mu.Lock() + addrConns := client.conns + client.mu.Unlock() + for ac := range addrConns { + ac.mu.Lock() + ac.state = connectivity.Shutdown + ac.mu.Unlock() + } + // END + client.Close() + close(clientDone) +} + +func TestBackoffWhenNoServerPrefaceReceived(t *testing.T) { + defer leakcheck.Check(t) + server, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Error while listening. Err: %v", err) + } + defer server.Close() + done := make(chan struct{}) + go func() { // Launch the server. + defer func() { + close(done) + }() + conn, err := server.Accept() // Accept the connection only to close it immediately. + if err != nil { + t.Errorf("Error while accepting. Err: %v", err) + return + } + prevAt := time.Now() + conn.Close() + var prevDuration time.Duration + // Make sure the retry attempts are backed off properly. + for i := 0; i < 3; i++ { + conn, err := server.Accept() + if err != nil { + t.Errorf("Error while accepting. Err: %v", err) + return + } + meow := time.Now() + conn.Close() + dr := meow.Sub(prevAt) + if dr <= prevDuration { + t.Errorf("Client backoff did not increase with retries. Previous duration: %v, current duration: %v", prevDuration, dr) + return + } + prevDuration = dr + prevAt = meow + } + }() + client, err := Dial(server.Addr().String(), WithInsecure()) + if err != nil { + t.Fatalf("Error while dialing. Err: %v", err) + } + defer client.Close() + <-done + +} + func TestConnectivityStates(t *testing.T) { defer leakcheck.Check(t) servers, resolver, cleanup := startServers(t, 2, math.MaxUint32) @@ -342,7 +612,7 @@ func TestClientUpdatesParamsAfterGoAway(t *testing.T) { defer s.Stop() cc, err := Dial(addr, WithBlock(), WithInsecure(), WithKeepaliveParams(keepalive.ClientParameters{ Time: 50 * time.Millisecond, - Timeout: 1 * time.Millisecond, + Timeout: 100 * time.Millisecond, PermitWithoutStream: true, })) if err != nil { diff --git a/vendor/google.golang.org/grpc/codec.go b/vendor/google.golang.org/grpc/codec.go index b452a4ae8..43d81ed2a 100644 --- a/vendor/google.golang.org/grpc/codec.go +++ b/vendor/google.golang.org/grpc/codec.go @@ -69,6 +69,11 @@ func (p protoCodec) marshal(v interface{}, cb *cachedProtoBuffer) ([]byte, error } func (p protoCodec) Marshal(v interface{}) ([]byte, error) { + if pm, ok := v.(proto.Marshaler); ok { + // object can marshal itself, no need for buffer + return pm.Marshal() + } + cb := protoBufferPool.Get().(*cachedProtoBuffer) out, err := p.marshal(v, cb) @@ -79,10 +84,17 @@ func (p protoCodec) Marshal(v interface{}) ([]byte, error) { } func (p protoCodec) Unmarshal(data []byte, v interface{}) error { + protoMsg := v.(proto.Message) + protoMsg.Reset() + + if pu, ok := protoMsg.(proto.Unmarshaler); ok { + // object can unmarshal itself, no need for buffer + return pu.Unmarshal(data) + } + cb := protoBufferPool.Get().(*cachedProtoBuffer) cb.SetBuf(data) - v.(proto.Message).Reset() - err := cb.Unmarshal(v.(proto.Message)) + err := cb.Unmarshal(protoMsg) cb.SetBuf(nil) protoBufferPool.Put(cb) return err diff --git a/vendor/google.golang.org/grpc/codes/code_string.go b/vendor/google.golang.org/grpc/codes/code_string.go index d9cf9675b..0b206a578 100644 --- a/vendor/google.golang.org/grpc/codes/code_string.go +++ b/vendor/google.golang.org/grpc/codes/code_string.go @@ -1,16 +1,62 @@ -// Code generated by "stringer -type=Code"; DO NOT EDIT. +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ package codes import "strconv" -const _Code_name = "OKCanceledUnknownInvalidArgumentDeadlineExceededNotFoundAlreadyExistsPermissionDeniedResourceExhaustedFailedPreconditionAbortedOutOfRangeUnimplementedInternalUnavailableDataLossUnauthenticated" - -var _Code_index = [...]uint8{0, 2, 10, 17, 32, 48, 56, 69, 85, 102, 120, 127, 137, 150, 158, 169, 177, 192} - -func (i Code) String() string { - if i >= Code(len(_Code_index)-1) { - return "Code(" + strconv.FormatInt(int64(i), 10) + ")" +func (c Code) String() string { + switch c { + case OK: + return "OK" + case Canceled: + return "Canceled" + case Unknown: + return "Unknown" + case InvalidArgument: + return "InvalidArgument" + case DeadlineExceeded: + return "DeadlineExceeded" + case NotFound: + return "NotFound" + case AlreadyExists: + return "AlreadyExists" + case PermissionDenied: + return "PermissionDenied" + case ResourceExhausted: + return "ResourceExhausted" + case FailedPrecondition: + return "FailedPrecondition" + case Aborted: + return "Aborted" + case OutOfRange: + return "OutOfRange" + case Unimplemented: + return "Unimplemented" + case Internal: + return "Internal" + case Unavailable: + return "Unavailable" + case DataLoss: + return "DataLoss" + case Unauthenticated: + return "Unauthenticated" + default: + return "Code(" + strconv.FormatInt(int64(c), 10) + ")" } - return _Code_name[_Code_index[i]:_Code_index[i+1]] } diff --git a/vendor/google.golang.org/grpc/codes/codes.go b/vendor/google.golang.org/grpc/codes/codes.go index 21e7733a5..f3719d562 100644 --- a/vendor/google.golang.org/grpc/codes/codes.go +++ b/vendor/google.golang.org/grpc/codes/codes.go @@ -19,12 +19,13 @@ // Package codes defines the canonical error codes used by gRPC. It is // consistent across various languages. package codes // import "google.golang.org/grpc/codes" +import ( + "fmt" +) // A Code is an unsigned 32-bit error code as defined in the gRPC spec. type Code uint32 -//go:generate stringer -type=Code - const ( // OK is returned on success. OK Code = 0 @@ -142,3 +143,41 @@ const ( // DataLoss indicates unrecoverable data loss or corruption. DataLoss Code = 15 ) + +var strToCode = map[string]Code{ + `"OK"`: OK, + `"CANCELLED"`:/* [sic] */ Canceled, + `"UNKNOWN"`: Unknown, + `"INVALID_ARGUMENT"`: InvalidArgument, + `"DEADLINE_EXCEEDED"`: DeadlineExceeded, + `"NOT_FOUND"`: NotFound, + `"ALREADY_EXISTS"`: AlreadyExists, + `"PERMISSION_DENIED"`: PermissionDenied, + `"RESOURCE_EXHAUSTED"`: ResourceExhausted, + `"FAILED_PRECONDITION"`: FailedPrecondition, + `"ABORTED"`: Aborted, + `"OUT_OF_RANGE"`: OutOfRange, + `"UNIMPLEMENTED"`: Unimplemented, + `"INTERNAL"`: Internal, + `"UNAVAILABLE"`: Unavailable, + `"DATA_LOSS"`: DataLoss, + `"UNAUTHENTICATED"`: Unauthenticated, +} + +// UnmarshalJSON unmarshals b into the Code. +func (c *Code) UnmarshalJSON(b []byte) error { + // From json.Unmarshaler: By convention, to approximate the behavior of + // Unmarshal itself, Unmarshalers implement UnmarshalJSON([]byte("null")) as + // a no-op. + if string(b) == "null" { + return nil + } + if c == nil { + return fmt.Errorf("nil receiver passed to UnmarshalJSON") + } + if jc, ok := strToCode[string(b)]; ok { + *c = jc + return nil + } + return fmt.Errorf("invalid code: %q", string(b)) +} diff --git a/vendor/google.golang.org/grpc/codes/codes_test.go b/vendor/google.golang.org/grpc/codes/codes_test.go new file mode 100644 index 000000000..1e3b99184 --- /dev/null +++ b/vendor/google.golang.org/grpc/codes/codes_test.go @@ -0,0 +1,64 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package codes + +import ( + "encoding/json" + "reflect" + "testing" + + cpb "google.golang.org/genproto/googleapis/rpc/code" +) + +func TestUnmarshalJSON(t *testing.T) { + for s, v := range cpb.Code_value { + want := Code(v) + var got Code + if err := got.UnmarshalJSON([]byte(`"` + s + `"`)); err != nil || got != want { + t.Errorf("got.UnmarshalJSON(%q) = %v; want <nil>. got=%v; want %v", s, err, got, want) + } + } +} + +func TestJSONUnmarshal(t *testing.T) { + var got []Code + want := []Code{OK, NotFound, Internal, Canceled} + in := `["OK", "NOT_FOUND", "INTERNAL", "CANCELLED"]` + err := json.Unmarshal([]byte(in), &got) + if err != nil || !reflect.DeepEqual(got, want) { + t.Fatalf("json.Unmarshal(%q, &got) = %v; want <nil>. got=%v; want %v", in, err, got, want) + } +} + +func TestUnmarshalJSON_NilReceiver(t *testing.T) { + var got *Code + in := OK.String() + if err := got.UnmarshalJSON([]byte(in)); err == nil { + t.Errorf("got.UnmarshalJSON(%q) = nil; want <non-nil>. got=%v", in, got) + } +} + +func TestUnmarshalJSON_UnknownInput(t *testing.T) { + var got Code + for _, in := range [][]byte{[]byte(""), []byte("xxx"), []byte("Code(17)"), nil} { + if err := got.UnmarshalJSON([]byte(in)); err == nil { + t.Errorf("got.UnmarshalJSON(%q) = nil; want <non-nil>. got=%v", in, got) + } + } +} diff --git a/vendor/google.golang.org/grpc/grpclb.go b/vendor/google.golang.org/grpc/grpclb.go index db56ff362..d14a5d409 100644 --- a/vendor/google.golang.org/grpc/grpclb.go +++ b/vendor/google.golang.org/grpc/grpclb.go @@ -19,21 +19,32 @@ package grpc import ( - "errors" - "fmt" - "math/rand" - "net" + "strconv" + "strings" "sync" "time" "golang.org/x/net/context" - "google.golang.org/grpc/codes" - lbmpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" + lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages" "google.golang.org/grpc/grpclog" - "google.golang.org/grpc/metadata" - "google.golang.org/grpc/naming" + "google.golang.org/grpc/resolver" ) +const ( + lbTokeyKey = "lb-token" + defaultFallbackTimeout = 10 * time.Second + grpclbName = "grpclb" +) + +func convertDuration(d *lbpb.Duration) time.Duration { + if d == nil { + return 0 + } + return time.Duration(d.Seconds)*time.Second + time.Duration(d.Nanos)*time.Nanosecond +} + // Client API for LoadBalancer service. // Mostly copied from generated pb.go file. // To avoid circular dependency. @@ -59,646 +70,273 @@ type balanceLoadClientStream struct { ClientStream } -func (x *balanceLoadClientStream) Send(m *lbmpb.LoadBalanceRequest) error { +func (x *balanceLoadClientStream) Send(m *lbpb.LoadBalanceRequest) error { return x.ClientStream.SendMsg(m) } -func (x *balanceLoadClientStream) Recv() (*lbmpb.LoadBalanceResponse, error) { - m := new(lbmpb.LoadBalanceResponse) +func (x *balanceLoadClientStream) Recv() (*lbpb.LoadBalanceResponse, error) { + m := new(lbpb.LoadBalanceResponse) if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err } return m, nil } -// NewGRPCLBBalancer creates a grpclb load balancer. -func NewGRPCLBBalancer(r naming.Resolver) Balancer { - return &grpclbBalancer{ - r: r, - } +func init() { + balancer.Register(newLBBuilder()) } -type remoteBalancerInfo struct { - addr string - // the server name used for authentication with the remote LB server. - name string +// newLBBuilder creates a builder for grpclb. +func newLBBuilder() balancer.Builder { + return NewLBBuilderWithFallbackTimeout(defaultFallbackTimeout) } -// grpclbAddrInfo consists of the information of a backend server. -type grpclbAddrInfo struct { - addr Address - connected bool - // dropForRateLimiting indicates whether this particular request should be - // dropped by the client for rate limiting. - dropForRateLimiting bool - // dropForLoadBalancing indicates whether this particular request should be - // dropped by the client for load balancing. - dropForLoadBalancing bool +// NewLBBuilderWithFallbackTimeout creates a grpclb builder with the given +// fallbackTimeout. If no response is received from the remote balancer within +// fallbackTimeout, the backend addresses from the resolved address list will be +// used. +// +// Only call this function when a non-default fallback timeout is needed. +func NewLBBuilderWithFallbackTimeout(fallbackTimeout time.Duration) balancer.Builder { + return &lbBuilder{ + fallbackTimeout: fallbackTimeout, + } } -type grpclbBalancer struct { - r naming.Resolver - target string - mu sync.Mutex - seq int // a sequence number to make sure addrCh does not get stale addresses. - w naming.Watcher - addrCh chan []Address - rbs []remoteBalancerInfo - addrs []*grpclbAddrInfo - next int - waitCh chan struct{} - done bool - rand *rand.Rand - - clientStats lbmpb.ClientStats +type lbBuilder struct { + fallbackTimeout time.Duration } -func (b *grpclbBalancer) watchAddrUpdates(w naming.Watcher, ch chan []remoteBalancerInfo) error { - updates, err := w.Next() - if err != nil { - grpclog.Warningf("grpclb: failed to get next addr update from watcher: %v", err) - return err - } - b.mu.Lock() - defer b.mu.Unlock() - if b.done { - return ErrClientConnClosing - } - for _, update := range updates { - switch update.Op { - case naming.Add: - var exist bool - for _, v := range b.rbs { - // TODO: Is the same addr with different server name a different balancer? - if update.Addr == v.addr { - exist = true - break - } - } - if exist { - continue - } - md, ok := update.Metadata.(*naming.AddrMetadataGRPCLB) - if !ok { - // TODO: Revisit the handling here and may introduce some fallback mechanism. - grpclog.Errorf("The name resolution contains unexpected metadata %v", update.Metadata) - continue - } - switch md.AddrType { - case naming.Backend: - // TODO: Revisit the handling here and may introduce some fallback mechanism. - grpclog.Errorf("The name resolution does not give grpclb addresses") - continue - case naming.GRPCLB: - b.rbs = append(b.rbs, remoteBalancerInfo{ - addr: update.Addr, - name: md.ServerName, - }) - default: - grpclog.Errorf("Received unknow address type %d", md.AddrType) - continue - } - case naming.Delete: - for i, v := range b.rbs { - if update.Addr == v.addr { - copy(b.rbs[i:], b.rbs[i+1:]) - b.rbs = b.rbs[:len(b.rbs)-1] - break - } - } - default: - grpclog.Errorf("Unknown update.Op %v", update.Op) - } +func (b *lbBuilder) Name() string { + return grpclbName +} + +func (b *lbBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer { + // This generates a manual resolver builder with a random scheme. This + // scheme will be used to dial to remote LB, so we can send filtered address + // updates to remote LB ClientConn using this manual resolver. + scheme := "grpclb_internal_" + strconv.FormatInt(time.Now().UnixNano(), 36) + r := &lbManualResolver{scheme: scheme, ccb: cc} + + var target string + targetSplitted := strings.Split(cc.Target(), ":///") + if len(targetSplitted) < 2 { + target = cc.Target() + } else { + target = targetSplitted[1] } - // TODO: Fall back to the basic round-robin load balancing if the resulting address is - // not a load balancer. - select { - case <-ch: - default: + + lb := &lbBalancer{ + cc: cc, + target: target, + opt: opt, + fallbackTimeout: b.fallbackTimeout, + doneCh: make(chan struct{}), + + manualResolver: r, + csEvltr: &connectivityStateEvaluator{}, + subConns: make(map[resolver.Address]balancer.SubConn), + scStates: make(map[balancer.SubConn]connectivity.State), + picker: &errPicker{err: balancer.ErrNoSubConnAvailable}, + clientStats: &rpcStats{}, } - ch <- b.rbs - return nil + + return lb } -func convertDuration(d *lbmpb.Duration) time.Duration { - if d == nil { - return 0 - } - return time.Duration(d.Seconds)*time.Second + time.Duration(d.Nanos)*time.Nanosecond +type lbBalancer struct { + cc balancer.ClientConn + target string + opt balancer.BuildOptions + fallbackTimeout time.Duration + doneCh chan struct{} + + // manualResolver is used in the remote LB ClientConn inside grpclb. When + // resolved address updates are received by grpclb, filtered updates will be + // send to remote LB ClientConn through this resolver. + manualResolver *lbManualResolver + // The ClientConn to talk to the remote balancer. + ccRemoteLB *ClientConn + + // Support client side load reporting. Each picker gets a reference to this, + // and will update its content. + clientStats *rpcStats + + mu sync.Mutex // guards everything following. + // The full server list including drops, used to check if the newly received + // serverList contains anything new. Each generate picker will also have + // reference to this list to do the first layer pick. + fullServerList []*lbpb.Server + // All backends addresses, with metadata set to nil. This list contains all + // backend addresses in the same order and with the same duplicates as in + // serverlist. When generating picker, a SubConn slice with the same order + // but with only READY SCs will be gerenated. + backendAddrs []resolver.Address + // Roundrobin functionalities. + csEvltr *connectivityStateEvaluator + state connectivity.State + subConns map[resolver.Address]balancer.SubConn // Used to new/remove SubConn. + scStates map[balancer.SubConn]connectivity.State // Used to filter READY SubConns. + picker balancer.Picker + // Support fallback to resolved backend addresses if there's no response + // from remote balancer within fallbackTimeout. + fallbackTimerExpired bool + serverListReceived bool + // resolvedBackendAddrs is resolvedAddrs minus remote balancers. It's set + // when resolved address updates are received, and read in the goroutine + // handling fallback. + resolvedBackendAddrs []resolver.Address } -func (b *grpclbBalancer) processServerList(l *lbmpb.ServerList, seq int) { - if l == nil { +// regeneratePicker takes a snapshot of the balancer, and generates a picker from +// it. The picker +// - always returns ErrTransientFailure if the balancer is in TransientFailure, +// - does two layer roundrobin pick otherwise. +// Caller must hold lb.mu. +func (lb *lbBalancer) regeneratePicker() { + if lb.state == connectivity.TransientFailure { + lb.picker = &errPicker{err: balancer.ErrTransientFailure} return } - servers := l.GetServers() - var ( - sl []*grpclbAddrInfo - addrs []Address - ) - for _, s := range servers { - md := metadata.Pairs("lb-token", s.LoadBalanceToken) - ip := net.IP(s.IpAddress) - ipStr := ip.String() - if ip.To4() == nil { - // Add square brackets to ipv6 addresses, otherwise net.Dial() and - // net.SplitHostPort() will return too many colons error. - ipStr = fmt.Sprintf("[%s]", ipStr) - } - addr := Address{ - Addr: fmt.Sprintf("%s:%d", ipStr, s.Port), - Metadata: &md, + var readySCs []balancer.SubConn + for _, a := range lb.backendAddrs { + if sc, ok := lb.subConns[a]; ok { + if st, ok := lb.scStates[sc]; ok && st == connectivity.Ready { + readySCs = append(readySCs, sc) + } } - sl = append(sl, &grpclbAddrInfo{ - addr: addr, - dropForRateLimiting: s.DropForRateLimiting, - dropForLoadBalancing: s.DropForLoadBalancing, - }) - addrs = append(addrs, addr) } - b.mu.Lock() - defer b.mu.Unlock() - if b.done || seq < b.seq { - return - } - if len(sl) > 0 { - // reset b.next to 0 when replacing the server list. - b.next = 0 - b.addrs = sl - b.addrCh <- addrs - } - return -} -func (b *grpclbBalancer) sendLoadReport(s *balanceLoadClientStream, interval time.Duration, done <-chan struct{}) { - ticker := time.NewTicker(interval) - defer ticker.Stop() - for { - select { - case <-ticker.C: - case <-done: - return - } - b.mu.Lock() - stats := b.clientStats - b.clientStats = lbmpb.ClientStats{} // Clear the stats. - b.mu.Unlock() - t := time.Now() - stats.Timestamp = &lbmpb.Timestamp{ - Seconds: t.Unix(), - Nanos: int32(t.Nanosecond()), - } - if err := s.Send(&lbmpb.LoadBalanceRequest{ - LoadBalanceRequestType: &lbmpb.LoadBalanceRequest_ClientStats{ - ClientStats: &stats, - }, - }); err != nil { - grpclog.Errorf("grpclb: failed to send load report: %v", err) + if len(lb.fullServerList) <= 0 { + if len(readySCs) <= 0 { + lb.picker = &errPicker{err: balancer.ErrNoSubConnAvailable} return } - } -} - -func (b *grpclbBalancer) callRemoteBalancer(lbc *loadBalancerClient, seq int) (retry bool) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - stream, err := lbc.BalanceLoad(ctx) - if err != nil { - grpclog.Errorf("grpclb: failed to perform RPC to the remote balancer %v", err) + lb.picker = &rrPicker{subConns: readySCs} return } - b.mu.Lock() - if b.done { - b.mu.Unlock() - return - } - b.mu.Unlock() - initReq := &lbmpb.LoadBalanceRequest{ - LoadBalanceRequestType: &lbmpb.LoadBalanceRequest_InitialRequest{ - InitialRequest: &lbmpb.InitialLoadBalanceRequest{ - Name: b.target, - }, - }, + lb.picker = &lbPicker{ + serverList: lb.fullServerList, + subConns: readySCs, + stats: lb.clientStats, } - if err := stream.Send(initReq); err != nil { - grpclog.Errorf("grpclb: failed to send init request: %v", err) - // TODO: backoff on retry? - return true - } - reply, err := stream.Recv() - if err != nil { - grpclog.Errorf("grpclb: failed to recv init response: %v", err) - // TODO: backoff on retry? - return true - } - initResp := reply.GetInitialResponse() - if initResp == nil { - grpclog.Errorf("grpclb: reply from remote balancer did not include initial response.") - return - } - // TODO: Support delegation. - if initResp.LoadBalancerDelegate != "" { - // delegation - grpclog.Errorf("TODO: Delegation is not supported yet.") - return - } - streamDone := make(chan struct{}) - defer close(streamDone) - b.mu.Lock() - b.clientStats = lbmpb.ClientStats{} // Clear client stats. - b.mu.Unlock() - if d := convertDuration(initResp.ClientStatsReportInterval); d > 0 { - go b.sendLoadReport(stream, d, streamDone) - } - // Retrieve the server list. - for { - reply, err := stream.Recv() - if err != nil { - grpclog.Errorf("grpclb: failed to recv server list: %v", err) - break - } - b.mu.Lock() - if b.done || seq < b.seq { - b.mu.Unlock() - return - } - b.seq++ // tick when receiving a new list of servers. - seq = b.seq - b.mu.Unlock() - if serverList := reply.GetServerList(); serverList != nil { - b.processServerList(serverList, seq) - } - } - return true + return } -func (b *grpclbBalancer) Start(target string, config BalancerConfig) error { - b.rand = rand.New(rand.NewSource(time.Now().Unix())) - // TODO: Fall back to the basic direct connection if there is no name resolver. - if b.r == nil { - return errors.New("there is no name resolver installed") +func (lb *lbBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) { + grpclog.Infof("lbBalancer: handle SubConn state change: %p, %v", sc, s) + lb.mu.Lock() + defer lb.mu.Unlock() + + oldS, ok := lb.scStates[sc] + if !ok { + grpclog.Infof("lbBalancer: got state changes for an unknown SubConn: %p, %v", sc, s) + return } - b.target = target - b.mu.Lock() - if b.done { - b.mu.Unlock() - return ErrClientConnClosing + lb.scStates[sc] = s + switch s { + case connectivity.Idle: + sc.Connect() + case connectivity.Shutdown: + // When an address was removed by resolver, b called RemoveSubConn but + // kept the sc's state in scStates. Remove state for this sc here. + delete(lb.scStates, sc) } - b.addrCh = make(chan []Address) - w, err := b.r.Resolve(target) - if err != nil { - b.mu.Unlock() - grpclog.Errorf("grpclb: failed to resolve address: %v, err: %v", target, err) - return err - } - b.w = w - b.mu.Unlock() - balancerAddrsCh := make(chan []remoteBalancerInfo, 1) - // Spawn a goroutine to monitor the name resolution of remote load balancer. - go func() { - for { - if err := b.watchAddrUpdates(w, balancerAddrsCh); err != nil { - grpclog.Warningf("grpclb: the naming watcher stops working due to %v.\n", err) - close(balancerAddrsCh) - return - } - } - }() - // Spawn a goroutine to talk to the remote load balancer. - go func() { - var ( - cc *ClientConn - // ccError is closed when there is an error in the current cc. - // A new rb should be picked from rbs and connected. - ccError chan struct{} - rb *remoteBalancerInfo - rbs []remoteBalancerInfo - rbIdx int - ) - - defer func() { - if ccError != nil { - select { - case <-ccError: - default: - close(ccError) - } - } - if cc != nil { - cc.Close() - } - }() - - for { - var ok bool - select { - case rbs, ok = <-balancerAddrsCh: - if !ok { - return - } - foundIdx := -1 - if rb != nil { - for i, trb := range rbs { - if trb == *rb { - foundIdx = i - break - } - } - } - if foundIdx >= 0 { - if foundIdx >= 1 { - // Move the address in use to the beginning of the list. - b.rbs[0], b.rbs[foundIdx] = b.rbs[foundIdx], b.rbs[0] - rbIdx = 0 - } - continue // If found, don't dial new cc. - } else if len(rbs) > 0 { - // Pick a random one from the list, instead of always using the first one. - if l := len(rbs); l > 1 && rb != nil { - tmpIdx := b.rand.Intn(l - 1) - b.rbs[0], b.rbs[tmpIdx] = b.rbs[tmpIdx], b.rbs[0] - } - rbIdx = 0 - rb = &rbs[0] - } else { - // foundIdx < 0 && len(rbs) <= 0. - rb = nil - } - case <-ccError: - ccError = nil - if rbIdx < len(rbs)-1 { - rbIdx++ - rb = &rbs[rbIdx] - } else { - rb = nil - } - } - - if rb == nil { - continue - } - if cc != nil { - cc.Close() - } - // Talk to the remote load balancer to get the server list. - var ( - err error - dopts []DialOption - ) - if creds := config.DialCreds; creds != nil { - if rb.name != "" { - if err := creds.OverrideServerName(rb.name); err != nil { - grpclog.Warningf("grpclb: failed to override the server name in the credentials: %v", err) - continue - } - } - dopts = append(dopts, WithTransportCredentials(creds)) - } else { - dopts = append(dopts, WithInsecure()) - } - if dialer := config.Dialer; dialer != nil { - // WithDialer takes a different type of function, so we instead use a special DialOption here. - dopts = append(dopts, func(o *dialOptions) { o.copts.Dialer = dialer }) - } - dopts = append(dopts, WithBlock()) - ccError = make(chan struct{}) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - cc, err = DialContext(ctx, rb.addr, dopts...) - cancel() - if err != nil { - grpclog.Warningf("grpclb: failed to setup a connection to the remote balancer %v: %v", rb.addr, err) - close(ccError) - continue - } - b.mu.Lock() - b.seq++ // tick when getting a new balancer address - seq := b.seq - b.next = 0 - b.mu.Unlock() - go func(cc *ClientConn, ccError chan struct{}) { - lbc := &loadBalancerClient{cc} - b.callRemoteBalancer(lbc, seq) - cc.Close() - select { - case <-ccError: - default: - close(ccError) - } - }(cc, ccError) - } - }() - return nil -} + oldAggrState := lb.state + lb.state = lb.csEvltr.recordTransition(oldS, s) -func (b *grpclbBalancer) down(addr Address, err error) { - b.mu.Lock() - defer b.mu.Unlock() - for _, a := range b.addrs { - if addr == a.addr { - a.connected = false - break - } + // Regenerate picker when one of the following happens: + // - this sc became ready from not-ready + // - this sc became not-ready from ready + // - the aggregated state of balancer became TransientFailure from non-TransientFailure + // - the aggregated state of balancer became non-TransientFailure from TransientFailure + if (oldS == connectivity.Ready) != (s == connectivity.Ready) || + (lb.state == connectivity.TransientFailure) != (oldAggrState == connectivity.TransientFailure) { + lb.regeneratePicker() } + + lb.cc.UpdateBalancerState(lb.state, lb.picker) + return } -func (b *grpclbBalancer) Up(addr Address) func(error) { - b.mu.Lock() - defer b.mu.Unlock() - if b.done { - return nil - } - var cnt int - for _, a := range b.addrs { - if a.addr == addr { - if a.connected { - return nil - } - a.connected = true - } - if a.connected && !a.dropForRateLimiting && !a.dropForLoadBalancing { - cnt++ - } - } - // addr is the only one which is connected. Notify the Get() callers who are blocking. - if cnt == 1 && b.waitCh != nil { - close(b.waitCh) - b.waitCh = nil +// fallbackToBackendsAfter blocks for fallbackTimeout and falls back to use +// resolved backends (backends received from resolver, not from remote balancer) +// if no connection to remote balancers was successful. +func (lb *lbBalancer) fallbackToBackendsAfter(fallbackTimeout time.Duration) { + timer := time.NewTimer(fallbackTimeout) + defer timer.Stop() + select { + case <-timer.C: + case <-lb.doneCh: + return } - return func(err error) { - b.down(addr, err) + lb.mu.Lock() + if lb.serverListReceived { + lb.mu.Unlock() + return } + lb.fallbackTimerExpired = true + lb.refreshSubConns(lb.resolvedBackendAddrs) + lb.mu.Unlock() } -func (b *grpclbBalancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) { - var ch chan struct{} - b.mu.Lock() - if b.done { - b.mu.Unlock() - err = ErrClientConnClosing +// HandleResolvedAddrs sends the updated remoteLB addresses to remoteLB +// clientConn. The remoteLB clientConn will handle creating/removing remoteLB +// connections. +func (lb *lbBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) { + grpclog.Infof("lbBalancer: handleResolvedResult: %+v", addrs) + if len(addrs) <= 0 { return } - seq := b.seq - defer func() { - if err != nil { - return - } - put = func() { - s, ok := rpcInfoFromContext(ctx) - if !ok { - return - } - b.mu.Lock() - defer b.mu.Unlock() - if b.done || seq < b.seq { - return - } - b.clientStats.NumCallsFinished++ - if !s.bytesSent { - b.clientStats.NumCallsFinishedWithClientFailedToSend++ - } else if s.bytesReceived { - b.clientStats.NumCallsFinishedKnownReceived++ - } + var remoteBalancerAddrs, backendAddrs []resolver.Address + for _, a := range addrs { + if a.Type == resolver.GRPCLB { + remoteBalancerAddrs = append(remoteBalancerAddrs, a) + } else { + backendAddrs = append(backendAddrs, a) } - }() - - b.clientStats.NumCallsStarted++ - if len(b.addrs) > 0 { - if b.next >= len(b.addrs) { - b.next = 0 - } - next := b.next - for { - a := b.addrs[next] - next = (next + 1) % len(b.addrs) - if a.connected { - if !a.dropForRateLimiting && !a.dropForLoadBalancing { - addr = a.addr - b.next = next - b.mu.Unlock() - return - } - if !opts.BlockingWait { - b.next = next - if a.dropForLoadBalancing { - b.clientStats.NumCallsFinished++ - b.clientStats.NumCallsFinishedWithDropForLoadBalancing++ - } else if a.dropForRateLimiting { - b.clientStats.NumCallsFinished++ - b.clientStats.NumCallsFinishedWithDropForRateLimiting++ - } - b.mu.Unlock() - err = Errorf(codes.Unavailable, "%s drops requests", a.addr.Addr) - return - } - } - if next == b.next { - // Has iterated all the possible address but none is connected. - break - } - } - } - if !opts.BlockingWait { - b.clientStats.NumCallsFinished++ - b.clientStats.NumCallsFinishedWithClientFailedToSend++ - b.mu.Unlock() - err = Errorf(codes.Unavailable, "there is no address available") - return } - // Wait on b.waitCh for non-failfast RPCs. - if b.waitCh == nil { - ch = make(chan struct{}) - b.waitCh = ch - } else { - ch = b.waitCh - } - b.mu.Unlock() - for { - select { - case <-ctx.Done(): - b.mu.Lock() - b.clientStats.NumCallsFinished++ - b.clientStats.NumCallsFinishedWithClientFailedToSend++ - b.mu.Unlock() - err = ctx.Err() - return - case <-ch: - b.mu.Lock() - if b.done { - b.clientStats.NumCallsFinished++ - b.clientStats.NumCallsFinishedWithClientFailedToSend++ - b.mu.Unlock() - err = ErrClientConnClosing - return - } - if len(b.addrs) > 0 { - if b.next >= len(b.addrs) { - b.next = 0 - } - next := b.next - for { - a := b.addrs[next] - next = (next + 1) % len(b.addrs) - if a.connected { - if !a.dropForRateLimiting && !a.dropForLoadBalancing { - addr = a.addr - b.next = next - b.mu.Unlock() - return - } - if !opts.BlockingWait { - b.next = next - if a.dropForLoadBalancing { - b.clientStats.NumCallsFinished++ - b.clientStats.NumCallsFinishedWithDropForLoadBalancing++ - } else if a.dropForRateLimiting { - b.clientStats.NumCallsFinished++ - b.clientStats.NumCallsFinishedWithDropForRateLimiting++ - } - b.mu.Unlock() - err = Errorf(codes.Unavailable, "drop requests for the addreess %s", a.addr.Addr) - return - } - } - if next == b.next { - // Has iterated all the possible address but none is connected. - break - } - } - } - // The newly added addr got removed by Down() again. - if b.waitCh == nil { - ch = make(chan struct{}) - b.waitCh = ch - } else { - ch = b.waitCh - } - b.mu.Unlock() + if lb.ccRemoteLB == nil { + if len(remoteBalancerAddrs) <= 0 { + grpclog.Errorf("grpclb: no remote balancer address is available, should never happen") + return } + // First time receiving resolved addresses, create a cc to remote + // balancers. + lb.dialRemoteLB(remoteBalancerAddrs[0].ServerName) + // Start the fallback goroutine. + go lb.fallbackToBackendsAfter(lb.fallbackTimeout) } -} -func (b *grpclbBalancer) Notify() <-chan []Address { - return b.addrCh + // cc to remote balancers uses lb.manualResolver. Send the updated remote + // balancer addresses to it through manualResolver. + lb.manualResolver.NewAddress(remoteBalancerAddrs) + + lb.mu.Lock() + lb.resolvedBackendAddrs = backendAddrs + // If serverListReceived is true, connection to remote balancer was + // successful and there's no need to do fallback anymore. + // If fallbackTimerExpired is false, fallback hasn't happened yet. + if !lb.serverListReceived && lb.fallbackTimerExpired { + // This means we received a new list of resolved backends, and we are + // still in fallback mode. Need to update the list of backends we are + // using to the new list of backends. + lb.refreshSubConns(lb.resolvedBackendAddrs) + } + lb.mu.Unlock() } -func (b *grpclbBalancer) Close() error { - b.mu.Lock() - defer b.mu.Unlock() - if b.done { - return errBalancerClosed - } - b.done = true - if b.waitCh != nil { - close(b.waitCh) - } - if b.addrCh != nil { - close(b.addrCh) +func (lb *lbBalancer) Close() { + select { + case <-lb.doneCh: + return + default: } - if b.w != nil { - b.w.Close() + close(lb.doneCh) + if lb.ccRemoteLB != nil { + lb.ccRemoteLB.Close() } - return nil } diff --git a/vendor/google.golang.org/grpc/grpclb/grpclb_test.go b/vendor/google.golang.org/grpc/grpclb/grpclb_test.go index 1491be788..d83ea6b88 100644 --- a/vendor/google.golang.org/grpc/grpclb/grpclb_test.go +++ b/vendor/google.golang.org/grpc/grpclb/grpclb_test.go @@ -27,6 +27,7 @@ import ( "fmt" "io" "net" + "strconv" "strings" "sync" "testing" @@ -35,22 +36,27 @@ import ( "github.com/golang/protobuf/proto" "golang.org/x/net/context" "google.golang.org/grpc" + "google.golang.org/grpc/balancer" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" lbmpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages" lbspb "google.golang.org/grpc/grpclb/grpc_lb_v1/service" _ "google.golang.org/grpc/grpclog/glogger" "google.golang.org/grpc/metadata" - "google.golang.org/grpc/naming" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/resolver/manual" "google.golang.org/grpc/status" testpb "google.golang.org/grpc/test/grpc_testing" "google.golang.org/grpc/test/leakcheck" + + _ "google.golang.org/grpc/grpclog/glogger" ) var ( - lbsn = "bar.com" - besn = "foo.com" - lbToken = "iamatoken" + lbServerName = "bar.com" + beServerName = "foo.com" + lbToken = "iamatoken" // Resolver replaces localhost with fakeName in Next(). // Dialer replaces fakeName with localhost when dialing. @@ -58,83 +64,6 @@ var ( fakeName = "fake.Name" ) -type testWatcher struct { - // the channel to receives name resolution updates - update chan *naming.Update - // the side channel to get to know how many updates in a batch - side chan int - // the channel to notifiy update injector that the update reading is done - readDone chan int -} - -func (w *testWatcher) Next() (updates []*naming.Update, err error) { - n, ok := <-w.side - if !ok { - return nil, fmt.Errorf("w.side is closed") - } - for i := 0; i < n; i++ { - u, ok := <-w.update - if !ok { - break - } - if u != nil { - // Resolver replaces localhost with fakeName in Next(). - // Custom dialer will replace fakeName with localhost when dialing. - u.Addr = strings.Replace(u.Addr, "localhost", fakeName, 1) - updates = append(updates, u) - } - } - w.readDone <- 0 - return -} - -func (w *testWatcher) Close() { - close(w.side) -} - -// Inject naming resolution updates to the testWatcher. -func (w *testWatcher) inject(updates []*naming.Update) { - w.side <- len(updates) - for _, u := range updates { - w.update <- u - } - <-w.readDone -} - -type testNameResolver struct { - w *testWatcher - addrs []string -} - -func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) { - r.w = &testWatcher{ - update: make(chan *naming.Update, len(r.addrs)), - side: make(chan int, 1), - readDone: make(chan int), - } - r.w.side <- len(r.addrs) - for _, addr := range r.addrs { - r.w.update <- &naming.Update{ - Op: naming.Add, - Addr: addr, - Metadata: &naming.AddrMetadataGRPCLB{ - AddrType: naming.GRPCLB, - ServerName: lbsn, - }, - } - } - go func() { - <-r.w.readDone - }() - return r.w, nil -} - -func (r *testNameResolver) inject(updates []*naming.Update) { - if r.w != nil { - r.w.inject(updates) - } -} - type serverNameCheckCreds struct { mu sync.Mutex sn string @@ -199,23 +128,22 @@ func fakeNameDialer(addr string, timeout time.Duration) (net.Conn, error) { } type remoteBalancer struct { - sls []*lbmpb.ServerList - intervals []time.Duration + sls chan *lbmpb.ServerList statsDura time.Duration done chan struct{} mu sync.Mutex stats lbmpb.ClientStats } -func newRemoteBalancer(sls []*lbmpb.ServerList, intervals []time.Duration) *remoteBalancer { +func newRemoteBalancer(intervals []time.Duration) *remoteBalancer { return &remoteBalancer{ - sls: sls, - intervals: intervals, - done: make(chan struct{}), + sls: make(chan *lbmpb.ServerList, 1), + done: make(chan struct{}), } } func (b *remoteBalancer) stop() { + close(b.sls) close(b.done) } @@ -225,7 +153,7 @@ func (b *remoteBalancer) BalanceLoad(stream lbspb.LoadBalancer_BalanceLoadServer return err } initReq := req.GetInitialRequest() - if initReq.Name != besn { + if initReq.Name != beServerName { return status.Errorf(codes.InvalidArgument, "invalid service name: %v", initReq.Name) } resp := &lbmpb.LoadBalanceResponse{ @@ -260,8 +188,7 @@ func (b *remoteBalancer) BalanceLoad(stream lbspb.LoadBalancer_BalanceLoadServer b.mu.Unlock() } }() - for k, v := range b.sls { - time.Sleep(b.intervals[k]) + for v := range b.sls { resp = &lbmpb.LoadBalanceResponse{ LoadBalanceResponseType: &lbmpb.LoadBalanceResponse_ServerList{ ServerList: v, @@ -278,7 +205,8 @@ func (b *remoteBalancer) BalanceLoad(stream lbspb.LoadBalancer_BalanceLoadServer type testServer struct { testpb.TestServiceServer - addr string + addr string + fallback bool } const testmdkey = "testmd" @@ -288,7 +216,7 @@ func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.E if !ok { return nil, status.Error(codes.Internal, "failed to receive metadata") } - if md == nil || md["lb-token"][0] != lbToken { + if !s.fallback && (md == nil || md["lb-token"][0] != lbToken) { return nil, status.Errorf(codes.Internal, "received unexpected metadata: %v", md) } grpc.SetTrailer(ctx, metadata.Pairs(testmdkey, s.addr)) @@ -299,13 +227,13 @@ func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServ return nil } -func startBackends(sn string, lis ...net.Listener) (servers []*grpc.Server) { +func startBackends(sn string, fallback bool, lis ...net.Listener) (servers []*grpc.Server) { for _, l := range lis { creds := &serverNameCheckCreds{ sn: sn, } s := grpc.NewServer(grpc.Creds(creds)) - testpb.RegisterTestServiceServer(s, &testServer{addr: l.Addr().String()}) + testpb.RegisterTestServiceServer(s, &testServer{addr: l.Addr().String(), fallback: fallback}) servers = append(servers, s) go func(s *grpc.Server, l net.Listener) { s.Serve(l) @@ -348,7 +276,7 @@ func newLoadBalancer(numberOfBackends int) (tss *testServers, cleanup func(), er beListeners = append(beListeners, beLis) } - backends := startBackends(besn, beListeners...) + backends := startBackends(beServerName, false, beListeners...) // Start a load balancer. lbLis, err := net.Listen("tcp", "localhost:0") @@ -357,21 +285,21 @@ func newLoadBalancer(numberOfBackends int) (tss *testServers, cleanup func(), er return } lbCreds := &serverNameCheckCreds{ - sn: lbsn, + sn: lbServerName, } lb = grpc.NewServer(grpc.Creds(lbCreds)) if err != nil { err = fmt.Errorf("Failed to generate the port number %v", err) return } - ls = newRemoteBalancer(nil, nil) + ls = newRemoteBalancer(nil) lbspb.RegisterLoadBalancerServer(lb, ls) go func() { lb.Serve(lbLis) }() tss = &testServers{ - lbAddr: lbLis.Addr().String(), + lbAddr: fakeName + ":" + strconv.Itoa(lbLis.Addr().(*net.TCPAddr).Port), ls: ls, lb: lb, beIPs: beIPs, @@ -389,6 +317,10 @@ func newLoadBalancer(numberOfBackends int) (tss *testServers, cleanup func(), er func TestGRPCLB(t *testing.T) { defer leakcheck.Check(t) + + r, cleanup := manual.GenerateAndRegisterManualResolver() + defer cleanup() + tss, cleanup, err := newLoadBalancer(1) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) @@ -405,136 +337,175 @@ func TestGRPCLB(t *testing.T) { sl := &lbmpb.ServerList{ Servers: bes, } - tss.ls.sls = []*lbmpb.ServerList{sl} - tss.ls.intervals = []time.Duration{0} + tss.ls.sls <- sl creds := serverNameCheckCreds{ - expected: besn, + expected: beServerName, } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - cc, err := grpc.DialContext(ctx, besn, - grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})), - grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) + cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, + grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) + + r.NewAddress([]resolver.Address{{ + Addr: tss.lbAddr, + Type: resolver.GRPCLB, + ServerName: lbServerName, + }}) + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) } } -func TestDropRequest(t *testing.T) { +// The remote balancer sends response with duplicates to grpclb client. +func TestGRPCLBWeighted(t *testing.T) { defer leakcheck.Check(t) + + r, cleanup := manual.GenerateAndRegisterManualResolver() + defer cleanup() + tss, cleanup, err := newLoadBalancer(2) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } defer cleanup() - tss.ls.sls = []*lbmpb.ServerList{{ - Servers: []*lbmpb.Server{{ - IpAddress: tss.beIPs[0], - Port: int32(tss.bePorts[0]), - LoadBalanceToken: lbToken, - DropForLoadBalancing: true, - }, { - IpAddress: tss.beIPs[1], - Port: int32(tss.bePorts[1]), - LoadBalanceToken: lbToken, - DropForLoadBalancing: false, - }}, + + beServers := []*lbmpb.Server{{ + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), + LoadBalanceToken: lbToken, + }, { + IpAddress: tss.beIPs[1], + Port: int32(tss.bePorts[1]), + LoadBalanceToken: lbToken, }} - tss.ls.intervals = []time.Duration{0} + portsToIndex := make(map[int]int) + for i := range beServers { + portsToIndex[tss.bePorts[i]] = i + } + creds := serverNameCheckCreds{ - expected: besn, + expected: beServerName, } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - cc, err := grpc.DialContext(ctx, besn, - grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})), - grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) + cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, + grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) - // Wait until the first connection is up. - // The first one has Drop set to true, error should contain "drop requests". - for { - if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { - if strings.Contains(err.Error(), "drops requests") { - break + + r.NewAddress([]resolver.Address{{ + Addr: tss.lbAddr, + Type: resolver.GRPCLB, + ServerName: lbServerName, + }}) + + sequences := []string{"00101", "00011"} + for _, seq := range sequences { + var ( + bes []*lbmpb.Server + p peer.Peer + result string + ) + for _, s := range seq { + bes = append(bes, beServers[s-'0']) + } + tss.ls.sls <- &lbmpb.ServerList{Servers: bes} + + for i := 0; i < 1000; i++ { + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil { + t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) } + result += strconv.Itoa(portsToIndex[p.Addr.(*net.TCPAddr).Port]) } - } - // The 1st, non-fail-fast RPC should succeed. This ensures both server - // connections are made, because the first one has DropForLoadBalancing set to true. - if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { - t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", testC, err) - } - for i := 0; i < 3; i++ { - // Odd fail-fast RPCs should fail, because the 1st backend has DropForLoadBalancing - // set to true. - if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable { - t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable) - } - // Even fail-fast RPCs should succeed since they choose the - // non-drop-request backend according to the round robin policy. - if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { - t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) + // The generated result will be in format of "0010100101". + if !strings.Contains(result, strings.Repeat(seq, 2)) { + t.Errorf("got result sequence %q, want patten %q", result, seq) } } } -func TestDropRequestFailedNonFailFast(t *testing.T) { +func TestDropRequest(t *testing.T) { defer leakcheck.Check(t) + + r, cleanup := manual.GenerateAndRegisterManualResolver() + defer cleanup() + tss, cleanup, err := newLoadBalancer(1) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } defer cleanup() - be := &lbmpb.Server{ - IpAddress: tss.beIPs[0], - Port: int32(tss.bePorts[0]), - LoadBalanceToken: lbToken, - DropForLoadBalancing: true, - } - var bes []*lbmpb.Server - bes = append(bes, be) - sl := &lbmpb.ServerList{ - Servers: bes, + tss.ls.sls <- &lbmpb.ServerList{ + Servers: []*lbmpb.Server{{ + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), + LoadBalanceToken: lbToken, + DropForLoadBalancing: false, + }, { + DropForLoadBalancing: true, + }}, } - tss.ls.sls = []*lbmpb.ServerList{sl} - tss.ls.intervals = []time.Duration{0} creds := serverNameCheckCreds{ - expected: besn, + expected: beServerName, } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - cc, err := grpc.DialContext(ctx, besn, - grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})), - grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) + cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, + grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) - ctx, cancel = context.WithTimeout(context.Background(), 10*time.Millisecond) - defer cancel() - if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded { - t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.DeadlineExceeded) + + r.NewAddress([]resolver.Address{{ + Addr: tss.lbAddr, + Type: resolver.GRPCLB, + ServerName: lbServerName, + }}) + + // The 1st, non-fail-fast RPC should succeed. This ensures both server + // connections are made, because the first one has DropForLoadBalancing set to true. + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { + t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", testC, err) + } + for _, failfast := range []bool{true, false} { + for i := 0; i < 3; i++ { + // Even RPCs should fail, because the 2st backend has + // DropForLoadBalancing set to true. + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(failfast)); status.Code(err) != codes.Unavailable { + t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable) + } + // Odd RPCs should succeed since they choose the non-drop-request + // backend according to the round robin policy. + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(failfast)); err != nil { + t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) + } + } } } // When the balancer in use disconnects, grpclb should connect to the next address from resolved balancer address list. func TestBalancerDisconnects(t *testing.T) { defer leakcheck.Check(t) + + r, cleanup := manual.GenerateAndRegisterManualResolver() + defer cleanup() + var ( - lbAddrs []string - lbs []*grpc.Server + tests []*testServers + lbs []*grpc.Server ) - for i := 0; i < 3; i++ { + for i := 0; i < 2; i++ { tss, cleanup, err := newLoadBalancer(1) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) @@ -551,78 +522,166 @@ func TestBalancerDisconnects(t *testing.T) { sl := &lbmpb.ServerList{ Servers: bes, } - tss.ls.sls = []*lbmpb.ServerList{sl} - tss.ls.intervals = []time.Duration{0} + tss.ls.sls <- sl - lbAddrs = append(lbAddrs, tss.lbAddr) + tests = append(tests, tss) lbs = append(lbs, tss.lb) } creds := serverNameCheckCreds{ - expected: besn, + expected: beServerName, } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - resolver := &testNameResolver{ - addrs: lbAddrs[:2], - } - cc, err := grpc.DialContext(ctx, besn, - grpc.WithBalancer(grpc.NewGRPCLBBalancer(resolver)), - grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) + cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, + grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) - var previousTrailer string - trailer := metadata.MD{} - if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer), grpc.FailFast(false)); err != nil { + + r.NewAddress([]resolver.Address{{ + Addr: tests[0].lbAddr, + Type: resolver.GRPCLB, + ServerName: lbServerName, + }, { + Addr: tests[1].lbAddr, + Type: resolver.GRPCLB, + ServerName: lbServerName, + }}) + + var p peer.Peer + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil { t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) - } else { - previousTrailer = trailer[testmdkey][0] } - // The initial resolver update contains lbs[0] and lbs[1]. - // When lbs[0] is stopped, lbs[1] should be used. + if p.Addr.(*net.TCPAddr).Port != tests[0].bePorts[0] { + t.Fatalf("got peer: %v, want peer port: %v", p.Addr, tests[0].bePorts[0]) + } + lbs[0].Stop() - for { - if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer), grpc.FailFast(false)); err != nil { + // Stop balancer[0], balancer[1] should be used by grpclb. + // Check peer address to see if that happened. + for i := 0; i < 1000; i++ { + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil { t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) - } else if trailer[testmdkey][0] != previousTrailer { - // A new backend server should receive the request. - // The trailer contains the backend address, so the trailer should be different from the previous one. - previousTrailer = trailer[testmdkey][0] - break - } - time.Sleep(100 * time.Millisecond) - } - // Inject a update to add lbs[2] to resolved addresses. - resolver.inject([]*naming.Update{ - {Op: naming.Add, - Addr: lbAddrs[2], - Metadata: &naming.AddrMetadataGRPCLB{ - AddrType: naming.GRPCLB, - ServerName: lbsn, - }, - }, + } + if p.Addr.(*net.TCPAddr).Port == tests[1].bePorts[0] { + return + } + time.Sleep(time.Millisecond) + } + t.Fatalf("No RPC sent to second backend after 1 second") +} + +type customGRPCLBBuilder struct { + balancer.Builder + name string +} + +func (b *customGRPCLBBuilder) Name() string { + return b.name +} + +const grpclbCustomFallbackName = "grpclb_with_custom_fallback_timeout" + +func init() { + balancer.Register(&customGRPCLBBuilder{ + Builder: grpc.NewLBBuilderWithFallbackTimeout(100 * time.Millisecond), + name: grpclbCustomFallbackName, }) - // Stop lbs[1]. Now lbs[0] and lbs[1] are all stopped. lbs[2] should be used. - lbs[1].Stop() - for { - if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer), grpc.FailFast(false)); err != nil { +} + +func TestFallback(t *testing.T) { + defer leakcheck.Check(t) + + r, cleanup := manual.GenerateAndRegisterManualResolver() + defer cleanup() + + tss, cleanup, err := newLoadBalancer(1) + if err != nil { + t.Fatalf("failed to create new load balancer: %v", err) + } + defer cleanup() + + // Start a standalone backend. + beLis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen %v", err) + } + defer beLis.Close() + standaloneBEs := startBackends(beServerName, true, beLis) + defer stopBackends(standaloneBEs) + + be := &lbmpb.Server{ + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), + LoadBalanceToken: lbToken, + } + var bes []*lbmpb.Server + bes = append(bes, be) + sl := &lbmpb.ServerList{ + Servers: bes, + } + tss.ls.sls <- sl + creds := serverNameCheckCreds{ + expected: beServerName, + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, + grpc.WithBalancerName(grpclbCustomFallbackName), + grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) + if err != nil { + t.Fatalf("Failed to dial to the backend %v", err) + } + defer cc.Close() + testC := testpb.NewTestServiceClient(cc) + + r.NewAddress([]resolver.Address{{ + Addr: "", + Type: resolver.GRPCLB, + ServerName: lbServerName, + }, { + Addr: beLis.Addr().String(), + Type: resolver.Backend, + ServerName: beServerName, + }}) + + var p peer.Peer + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil { + t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) + } + if p.Addr.String() != beLis.Addr().String() { + t.Fatalf("got peer: %v, want peer: %v", p.Addr, beLis.Addr()) + } + + r.NewAddress([]resolver.Address{{ + Addr: tss.lbAddr, + Type: resolver.GRPCLB, + ServerName: lbServerName, + }, { + Addr: beLis.Addr().String(), + Type: resolver.Backend, + ServerName: beServerName, + }}) + + for i := 0; i < 1000; i++ { + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil { t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) - } else if trailer[testmdkey][0] != previousTrailer { - // A new backend server should receive the request. - // The trailer contains the backend address, so the trailer should be different from the previous one. - break } - time.Sleep(100 * time.Millisecond) + if p.Addr.(*net.TCPAddr).Port == tss.bePorts[0] { + return + } + time.Sleep(time.Millisecond) } + t.Fatalf("No RPC sent to backend behind remote balancer after 1 second") } type failPreRPCCred struct{} func (failPreRPCCred) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { - if strings.Contains(uri[0], "failtosend") { + if strings.Contains(uri[0], failtosendURI) { return nil, fmt.Errorf("rpc should fail to send") } return nil, nil @@ -640,35 +699,45 @@ func checkStats(stats *lbmpb.ClientStats, expected *lbmpb.ClientStats) error { } func runAndGetStats(t *testing.T, dropForLoadBalancing, dropForRateLimiting bool, runRPCs func(*grpc.ClientConn)) lbmpb.ClientStats { - tss, cleanup, err := newLoadBalancer(3) + defer leakcheck.Check(t) + + r, cleanup := manual.GenerateAndRegisterManualResolver() + defer cleanup() + + tss, cleanup, err := newLoadBalancer(1) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } defer cleanup() - tss.ls.sls = []*lbmpb.ServerList{{ + tss.ls.sls <- &lbmpb.ServerList{ Servers: []*lbmpb.Server{{ - IpAddress: tss.beIPs[2], - Port: int32(tss.bePorts[2]), + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), LoadBalanceToken: lbToken, DropForLoadBalancing: dropForLoadBalancing, DropForRateLimiting: dropForRateLimiting, }}, - }} - tss.ls.intervals = []time.Duration{0} + } tss.ls.statsDura = 100 * time.Millisecond - creds := serverNameCheckCreds{expected: besn} + creds := serverNameCheckCreds{expected: beServerName} ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - cc, err := grpc.DialContext(ctx, besn, - grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})), - grpc.WithTransportCredentials(&creds), grpc.WithPerRPCCredentials(failPreRPCCred{}), - grpc.WithBlock(), grpc.WithDialer(fakeNameDialer)) + cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, + grpc.WithTransportCredentials(&creds), + grpc.WithPerRPCCredentials(failPreRPCCred{}), + grpc.WithDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() + r.NewAddress([]resolver.Address{{ + Addr: tss.lbAddr, + Type: resolver.GRPCLB, + ServerName: lbServerName, + }}) + runRPCs(cc) time.Sleep(1 * time.Second) tss.ls.mu.Lock() @@ -677,7 +746,11 @@ func runAndGetStats(t *testing.T, dropForLoadBalancing, dropForRateLimiting bool return stats } -const countRPC = 40 +const ( + countRPC = 40 + failtosendURI = "failtosend" + dropErrDesc = "request dropped by grpclb" +) func TestGRPCLBStatsUnarySuccess(t *testing.T) { defer leakcheck.Check(t) @@ -709,7 +782,7 @@ func TestGRPCLBStatsUnaryDropLoadBalancing(t *testing.T) { for { c++ if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { - if strings.Contains(err.Error(), "drops requests") { + if strings.Contains(err.Error(), dropErrDesc) { break } } @@ -737,7 +810,7 @@ func TestGRPCLBStatsUnaryDropRateLimiting(t *testing.T) { for { c++ if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { - if strings.Contains(err.Error(), "drops requests") { + if strings.Contains(err.Error(), dropErrDesc) { break } } @@ -766,7 +839,7 @@ func TestGRPCLBStatsUnaryFailedToSend(t *testing.T) { t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) } for i := 0; i < countRPC-1; i++ { - grpc.Invoke(context.Background(), "failtosend", &testpb.Empty{}, nil, cc) + grpc.Invoke(context.Background(), failtosendURI, &testpb.Empty{}, nil, cc) } }) @@ -824,7 +897,7 @@ func TestGRPCLBStatsStreamingDropLoadBalancing(t *testing.T) { for { c++ if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { - if strings.Contains(err.Error(), "drops requests") { + if strings.Contains(err.Error(), dropErrDesc) { break } } @@ -852,7 +925,7 @@ func TestGRPCLBStatsStreamingDropRateLimiting(t *testing.T) { for { c++ if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { - if strings.Contains(err.Error(), "drops requests") { + if strings.Contains(err.Error(), dropErrDesc) { break } } @@ -887,7 +960,7 @@ func TestGRPCLBStatsStreamingFailedToSend(t *testing.T) { } } for i := 0; i < countRPC-1; i++ { - grpc.NewClientStream(context.Background(), &grpc.StreamDesc{}, cc, "failtosend") + grpc.NewClientStream(context.Background(), &grpc.StreamDesc{}, cc, failtosendURI) } }) diff --git a/vendor/google.golang.org/grpc/grpclb_picker.go b/vendor/google.golang.org/grpc/grpclb_picker.go new file mode 100644 index 000000000..872c7ccea --- /dev/null +++ b/vendor/google.golang.org/grpc/grpclb_picker.go @@ -0,0 +1,159 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +import ( + "sync" + "sync/atomic" + + "golang.org/x/net/context" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/codes" + lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages" + "google.golang.org/grpc/status" +) + +type rpcStats struct { + NumCallsStarted int64 + NumCallsFinished int64 + NumCallsFinishedWithDropForRateLimiting int64 + NumCallsFinishedWithDropForLoadBalancing int64 + NumCallsFinishedWithClientFailedToSend int64 + NumCallsFinishedKnownReceived int64 +} + +// toClientStats converts rpcStats to lbpb.ClientStats, and clears rpcStats. +func (s *rpcStats) toClientStats() *lbpb.ClientStats { + stats := &lbpb.ClientStats{ + NumCallsStarted: atomic.SwapInt64(&s.NumCallsStarted, 0), + NumCallsFinished: atomic.SwapInt64(&s.NumCallsFinished, 0), + NumCallsFinishedWithDropForRateLimiting: atomic.SwapInt64(&s.NumCallsFinishedWithDropForRateLimiting, 0), + NumCallsFinishedWithDropForLoadBalancing: atomic.SwapInt64(&s.NumCallsFinishedWithDropForLoadBalancing, 0), + NumCallsFinishedWithClientFailedToSend: atomic.SwapInt64(&s.NumCallsFinishedWithClientFailedToSend, 0), + NumCallsFinishedKnownReceived: atomic.SwapInt64(&s.NumCallsFinishedKnownReceived, 0), + } + return stats +} + +func (s *rpcStats) dropForRateLimiting() { + atomic.AddInt64(&s.NumCallsStarted, 1) + atomic.AddInt64(&s.NumCallsFinishedWithDropForRateLimiting, 1) + atomic.AddInt64(&s.NumCallsFinished, 1) +} + +func (s *rpcStats) dropForLoadBalancing() { + atomic.AddInt64(&s.NumCallsStarted, 1) + atomic.AddInt64(&s.NumCallsFinishedWithDropForLoadBalancing, 1) + atomic.AddInt64(&s.NumCallsFinished, 1) +} + +func (s *rpcStats) failedToSend() { + atomic.AddInt64(&s.NumCallsStarted, 1) + atomic.AddInt64(&s.NumCallsFinishedWithClientFailedToSend, 1) + atomic.AddInt64(&s.NumCallsFinished, 1) +} + +func (s *rpcStats) knownReceived() { + atomic.AddInt64(&s.NumCallsStarted, 1) + atomic.AddInt64(&s.NumCallsFinishedKnownReceived, 1) + atomic.AddInt64(&s.NumCallsFinished, 1) +} + +type errPicker struct { + // Pick always returns this err. + err error +} + +func (p *errPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) { + return nil, nil, p.err +} + +// rrPicker does roundrobin on subConns. It's typically used when there's no +// response from remote balancer, and grpclb falls back to the resolved +// backends. +// +// It guaranteed that len(subConns) > 0. +type rrPicker struct { + mu sync.Mutex + subConns []balancer.SubConn // The subConns that were READY when taking the snapshot. + subConnsNext int +} + +func (p *rrPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) { + p.mu.Lock() + defer p.mu.Unlock() + sc := p.subConns[p.subConnsNext] + p.subConnsNext = (p.subConnsNext + 1) % len(p.subConns) + return sc, nil, nil +} + +// lbPicker does two layers of picks: +// +// First layer: roundrobin on all servers in serverList, including drops and backends. +// - If it picks a drop, the RPC will fail as being dropped. +// - If it picks a backend, do a second layer pick to pick the real backend. +// +// Second layer: roundrobin on all READY backends. +// +// It's guaranteed that len(serverList) > 0. +type lbPicker struct { + mu sync.Mutex + serverList []*lbpb.Server + serverListNext int + subConns []balancer.SubConn // The subConns that were READY when taking the snapshot. + subConnsNext int + + stats *rpcStats +} + +func (p *lbPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) { + p.mu.Lock() + defer p.mu.Unlock() + + // Layer one roundrobin on serverList. + s := p.serverList[p.serverListNext] + p.serverListNext = (p.serverListNext + 1) % len(p.serverList) + + // If it's a drop, return an error and fail the RPC. + if s.DropForRateLimiting { + p.stats.dropForRateLimiting() + return nil, nil, status.Errorf(codes.Unavailable, "request dropped by grpclb") + } + if s.DropForLoadBalancing { + p.stats.dropForLoadBalancing() + return nil, nil, status.Errorf(codes.Unavailable, "request dropped by grpclb") + } + + // If not a drop but there's no ready subConns. + if len(p.subConns) <= 0 { + return nil, nil, balancer.ErrNoSubConnAvailable + } + + // Return the next ready subConn in the list, also collect rpc stats. + sc := p.subConns[p.subConnsNext] + p.subConnsNext = (p.subConnsNext + 1) % len(p.subConns) + done := func(info balancer.DoneInfo) { + if !info.BytesSent { + p.stats.failedToSend() + } else if info.BytesReceived { + p.stats.knownReceived() + } + } + return sc, done, nil +} diff --git a/vendor/google.golang.org/grpc/grpclb_remote_balancer.go b/vendor/google.golang.org/grpc/grpclb_remote_balancer.go new file mode 100644 index 000000000..1b580df26 --- /dev/null +++ b/vendor/google.golang.org/grpc/grpclb_remote_balancer.go @@ -0,0 +1,254 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +import ( + "fmt" + "net" + "reflect" + "time" + + "golang.org/x/net/context" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" + lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/resolver" +) + +// processServerList updates balaner's internal state, create/remove SubConns +// and regenerates picker using the received serverList. +func (lb *lbBalancer) processServerList(l *lbpb.ServerList) { + grpclog.Infof("lbBalancer: processing server list: %+v", l) + lb.mu.Lock() + defer lb.mu.Unlock() + + // Set serverListReceived to true so fallback will not take effect if it has + // not hit timeout. + lb.serverListReceived = true + + // If the new server list == old server list, do nothing. + if reflect.DeepEqual(lb.fullServerList, l.Servers) { + grpclog.Infof("lbBalancer: new serverlist same as the previous one, ignoring") + return + } + lb.fullServerList = l.Servers + + var backendAddrs []resolver.Address + for _, s := range l.Servers { + if s.DropForLoadBalancing || s.DropForRateLimiting { + continue + } + + md := metadata.Pairs(lbTokeyKey, s.LoadBalanceToken) + ip := net.IP(s.IpAddress) + ipStr := ip.String() + if ip.To4() == nil { + // Add square brackets to ipv6 addresses, otherwise net.Dial() and + // net.SplitHostPort() will return too many colons error. + ipStr = fmt.Sprintf("[%s]", ipStr) + } + addr := resolver.Address{ + Addr: fmt.Sprintf("%s:%d", ipStr, s.Port), + Metadata: &md, + } + + backendAddrs = append(backendAddrs, addr) + } + + // Call refreshSubConns to create/remove SubConns. + backendsUpdated := lb.refreshSubConns(backendAddrs) + // If no backend was updated, no SubConn will be newed/removed. But since + // the full serverList was different, there might be updates in drops or + // pick weights(different number of duplicates). We need to update picker + // with the fulllist. + if !backendsUpdated { + lb.regeneratePicker() + lb.cc.UpdateBalancerState(lb.state, lb.picker) + } +} + +// refreshSubConns creates/removes SubConns with backendAddrs. It returns a bool +// indicating whether the backendAddrs are different from the cached +// backendAddrs (whether any SubConn was newed/removed). +// Caller must hold lb.mu. +func (lb *lbBalancer) refreshSubConns(backendAddrs []resolver.Address) bool { + lb.backendAddrs = nil + var backendsUpdated bool + // addrsSet is the set converted from backendAddrs, it's used to quick + // lookup for an address. + addrsSet := make(map[resolver.Address]struct{}) + // Create new SubConns. + for _, addr := range backendAddrs { + addrWithoutMD := addr + addrWithoutMD.Metadata = nil + addrsSet[addrWithoutMD] = struct{}{} + lb.backendAddrs = append(lb.backendAddrs, addrWithoutMD) + + if _, ok := lb.subConns[addrWithoutMD]; !ok { + backendsUpdated = true + + // Use addrWithMD to create the SubConn. + sc, err := lb.cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{}) + if err != nil { + grpclog.Warningf("roundrobinBalancer: failed to create new SubConn: %v", err) + continue + } + lb.subConns[addrWithoutMD] = sc // Use the addr without MD as key for the map. + lb.scStates[sc] = connectivity.Idle + sc.Connect() + } + } + + for a, sc := range lb.subConns { + // a was removed by resolver. + if _, ok := addrsSet[a]; !ok { + backendsUpdated = true + + lb.cc.RemoveSubConn(sc) + delete(lb.subConns, a) + // Keep the state of this sc in b.scStates until sc's state becomes Shutdown. + // The entry will be deleted in HandleSubConnStateChange. + } + } + + return backendsUpdated +} + +func (lb *lbBalancer) readServerList(s *balanceLoadClientStream) error { + for { + reply, err := s.Recv() + if err != nil { + return fmt.Errorf("grpclb: failed to recv server list: %v", err) + } + if serverList := reply.GetServerList(); serverList != nil { + lb.processServerList(serverList) + } + } +} + +func (lb *lbBalancer) sendLoadReport(s *balanceLoadClientStream, interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + case <-s.Context().Done(): + return + } + stats := lb.clientStats.toClientStats() + t := time.Now() + stats.Timestamp = &lbpb.Timestamp{ + Seconds: t.Unix(), + Nanos: int32(t.Nanosecond()), + } + if err := s.Send(&lbpb.LoadBalanceRequest{ + LoadBalanceRequestType: &lbpb.LoadBalanceRequest_ClientStats{ + ClientStats: stats, + }, + }); err != nil { + return + } + } +} +func (lb *lbBalancer) callRemoteBalancer() error { + lbClient := &loadBalancerClient{cc: lb.ccRemoteLB} + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + stream, err := lbClient.BalanceLoad(ctx, FailFast(false)) + if err != nil { + return fmt.Errorf("grpclb: failed to perform RPC to the remote balancer %v", err) + } + + // grpclb handshake on the stream. + initReq := &lbpb.LoadBalanceRequest{ + LoadBalanceRequestType: &lbpb.LoadBalanceRequest_InitialRequest{ + InitialRequest: &lbpb.InitialLoadBalanceRequest{ + Name: lb.target, + }, + }, + } + if err := stream.Send(initReq); err != nil { + return fmt.Errorf("grpclb: failed to send init request: %v", err) + } + reply, err := stream.Recv() + if err != nil { + return fmt.Errorf("grpclb: failed to recv init response: %v", err) + } + initResp := reply.GetInitialResponse() + if initResp == nil { + return fmt.Errorf("grpclb: reply from remote balancer did not include initial response") + } + if initResp.LoadBalancerDelegate != "" { + return fmt.Errorf("grpclb: Delegation is not supported") + } + + go func() { + if d := convertDuration(initResp.ClientStatsReportInterval); d > 0 { + lb.sendLoadReport(stream, d) + } + }() + return lb.readServerList(stream) +} + +func (lb *lbBalancer) watchRemoteBalancer() { + for { + err := lb.callRemoteBalancer() + select { + case <-lb.doneCh: + return + default: + if err != nil { + grpclog.Error(err) + } + } + + } +} + +func (lb *lbBalancer) dialRemoteLB(remoteLBName string) { + var dopts []DialOption + if creds := lb.opt.DialCreds; creds != nil { + if err := creds.OverrideServerName(remoteLBName); err == nil { + dopts = append(dopts, WithTransportCredentials(creds)) + } else { + grpclog.Warningf("grpclb: failed to override the server name in the credentials: %v, using Insecure", err) + dopts = append(dopts, WithInsecure()) + } + } else { + dopts = append(dopts, WithInsecure()) + } + if lb.opt.Dialer != nil { + // WithDialer takes a different type of function, so we instead use a + // special DialOption here. + dopts = append(dopts, withContextDialer(lb.opt.Dialer)) + } + // Explicitly set pickfirst as the balancer. + dopts = append(dopts, WithBalancerName(PickFirstBalancerName)) + dopts = append(dopts, withResolverBuilder(lb.manualResolver)) + // Dial using manualResolver.Scheme, which is a random scheme generated + // when init grpclb. The target name is not important. + cc, err := Dial("grpclb:///grpclb.server", dopts...) + if err != nil { + grpclog.Fatalf("failed to dial: %v", err) + } + lb.ccRemoteLB = cc + go lb.watchRemoteBalancer() +} diff --git a/vendor/google.golang.org/grpc/grpclb_util.go b/vendor/google.golang.org/grpc/grpclb_util.go new file mode 100644 index 000000000..93ab2db32 --- /dev/null +++ b/vendor/google.golang.org/grpc/grpclb_util.go @@ -0,0 +1,90 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +import ( + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/resolver" +) + +// The parent ClientConn should re-resolve when grpclb loses connection to the +// remote balancer. When the ClientConn inside grpclb gets a TransientFailure, +// it calls lbManualResolver.ResolveNow(), which calls parent ClientConn's +// ResolveNow, and eventually results in re-resolve happening in parent +// ClientConn's resolver (DNS for example). +// +// parent +// ClientConn +// +-----------------------------------------------------------------+ +// | parent +---------------------------------+ | +// | DNS ClientConn | grpclb | | +// | resolver balancerWrapper | | | +// | + + | grpclb grpclb | | +// | | | | ManualResolver ClientConn | | +// | | | | + + | | +// | | | | | | Transient | | +// | | | | | | Failure | | +// | | | | | <--------- | | | +// | | | <--------------- | ResolveNow | | | +// | | <--------- | ResolveNow | | | | | +// | | ResolveNow | | | | | | +// | | | | | | | | +// | + + | + + | | +// | +---------------------------------+ | +// +-----------------------------------------------------------------+ + +// lbManualResolver is used by the ClientConn inside grpclb. It's a manual +// resolver with a special ResolveNow() function. +// +// When ResolveNow() is called, it calls ResolveNow() on the parent ClientConn, +// so when grpclb client lose contact with remote balancers, the parent +// ClientConn's resolver will re-resolve. +type lbManualResolver struct { + scheme string + ccr resolver.ClientConn + + ccb balancer.ClientConn +} + +func (r *lbManualResolver) Build(_ resolver.Target, cc resolver.ClientConn, _ resolver.BuildOption) (resolver.Resolver, error) { + r.ccr = cc + return r, nil +} + +func (r *lbManualResolver) Scheme() string { + return r.scheme +} + +// ResolveNow calls resolveNow on the parent ClientConn. +func (r *lbManualResolver) ResolveNow(o resolver.ResolveNowOption) { + r.ccb.ResolveNow(o) +} + +// Close is a noop for Resolver. +func (*lbManualResolver) Close() {} + +// NewAddress calls cc.NewAddress. +func (r *lbManualResolver) NewAddress(addrs []resolver.Address) { + r.ccr.NewAddress(addrs) +} + +// NewServiceConfig calls cc.NewServiceConfig. +func (r *lbManualResolver) NewServiceConfig(sc string) { + r.ccr.NewServiceConfig(sc) +} diff --git a/vendor/google.golang.org/grpc/naming/go17.go b/vendor/google.golang.org/grpc/naming/go17.go index a537b08c6..57b65d7b8 100644 --- a/vendor/google.golang.org/grpc/naming/go17.go +++ b/vendor/google.golang.org/grpc/naming/go17.go @@ -1,4 +1,4 @@ -// +build go1.6, !go1.8 +// +build go1.6,!go1.8 /* * diff --git a/vendor/google.golang.org/grpc/picker_wrapper.go b/vendor/google.golang.org/grpc/picker_wrapper.go index 9085dbc9c..db82bfb3a 100644 --- a/vendor/google.golang.org/grpc/picker_wrapper.go +++ b/vendor/google.golang.org/grpc/picker_wrapper.go @@ -97,7 +97,7 @@ func (bp *pickerWrapper) pick(ctx context.Context, failfast bool, opts balancer. p = bp.picker bp.mu.Unlock() - subConn, put, err := p.Pick(ctx, opts) + subConn, done, err := p.Pick(ctx, opts) if err != nil { switch err { @@ -120,7 +120,7 @@ func (bp *pickerWrapper) pick(ctx context.Context, failfast bool, opts balancer. continue } if t, ok := acw.getAddrConn().getReadyTransport(); ok { - return t, put, nil + return t, done, nil } grpclog.Infof("blockingPicker: the picked transport is not ready, loop back to repick") // If ok == false, ac.state is not READY. diff --git a/vendor/google.golang.org/grpc/pickfirst.go b/vendor/google.golang.org/grpc/pickfirst.go index e83ca2b0d..bf659d49d 100644 --- a/vendor/google.golang.org/grpc/pickfirst.go +++ b/vendor/google.golang.org/grpc/pickfirst.go @@ -26,6 +26,9 @@ import ( "google.golang.org/grpc/resolver" ) +// PickFirstBalancerName is the name of the pick_first balancer. +const PickFirstBalancerName = "pick_first" + func newPickfirstBuilder() balancer.Builder { return &pickfirstBuilder{} } @@ -37,7 +40,7 @@ func (*pickfirstBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions } func (*pickfirstBuilder) Name() string { - return "pick_first" + return PickFirstBalancerName } type pickfirstBalancer struct { diff --git a/vendor/google.golang.org/grpc/pickfirst_test.go b/vendor/google.golang.org/grpc/pickfirst_test.go index e58b3422c..2f85febff 100644 --- a/vendor/google.golang.org/grpc/pickfirst_test.go +++ b/vendor/google.golang.org/grpc/pickfirst_test.go @@ -28,9 +28,17 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" + "google.golang.org/grpc/status" "google.golang.org/grpc/test/leakcheck" ) +func errorDesc(err error) string { + if s, ok := status.FromError(err); ok { + return s.Message() + } + return err.Error() +} + func TestOneBackendPickfirst(t *testing.T) { defer leakcheck.Check(t) r, rcleanup := manual.GenerateAndRegisterManualResolver() @@ -40,7 +48,7 @@ func TestOneBackendPickfirst(t *testing.T) { servers, _, scleanup := startServers(t, numServers, math.MaxInt32) defer scleanup() - cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithBalancerBuilder(newPickfirstBuilder()), WithCodec(testCodec{})) + cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("failed to dial: %v", err) } @@ -50,14 +58,14 @@ func TestOneBackendPickfirst(t *testing.T) { defer cancel() req := "port" var reply string - if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || Code(err) != codes.DeadlineExceeded { + if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || status.Code(err) != codes.DeadlineExceeded { t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err) } r.NewAddress([]resolver.Address{{Addr: servers[0].addr}}) // The second RPC should succeed. for i := 0; i < 1000; i++ { - if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[0].port { + if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[0].port { return } time.Sleep(time.Millisecond) @@ -74,7 +82,7 @@ func TestBackendsPickfirst(t *testing.T) { servers, _, scleanup := startServers(t, numServers, math.MaxInt32) defer scleanup() - cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithBalancerBuilder(newPickfirstBuilder()), WithCodec(testCodec{})) + cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("failed to dial: %v", err) } @@ -84,14 +92,14 @@ func TestBackendsPickfirst(t *testing.T) { defer cancel() req := "port" var reply string - if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || Code(err) != codes.DeadlineExceeded { + if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || status.Code(err) != codes.DeadlineExceeded { t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err) } r.NewAddress([]resolver.Address{{Addr: servers[0].addr}, {Addr: servers[1].addr}}) // The second RPC should succeed with the first server. for i := 0; i < 1000; i++ { - if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[0].port { + if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[0].port { return } time.Sleep(time.Millisecond) @@ -108,7 +116,7 @@ func TestNewAddressWhileBlockingPickfirst(t *testing.T) { servers, _, scleanup := startServers(t, numServers, math.MaxInt32) defer scleanup() - cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithBalancerBuilder(newPickfirstBuilder()), WithCodec(testCodec{})) + cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("failed to dial: %v", err) } @@ -118,7 +126,7 @@ func TestNewAddressWhileBlockingPickfirst(t *testing.T) { defer cancel() req := "port" var reply string - if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || Code(err) != codes.DeadlineExceeded { + if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || status.Code(err) != codes.DeadlineExceeded { t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err) } @@ -145,7 +153,7 @@ func TestCloseWithPendingRPCPickfirst(t *testing.T) { _, _, scleanup := startServers(t, numServers, math.MaxInt32) defer scleanup() - cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithBalancerBuilder(newPickfirstBuilder()), WithCodec(testCodec{})) + cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("failed to dial: %v", err) } @@ -155,7 +163,7 @@ func TestCloseWithPendingRPCPickfirst(t *testing.T) { defer cancel() req := "port" var reply string - if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || Code(err) != codes.DeadlineExceeded { + if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || status.Code(err) != codes.DeadlineExceeded { t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err) } @@ -182,7 +190,7 @@ func TestOneServerDownPickfirst(t *testing.T) { servers, _, scleanup := startServers(t, numServers, math.MaxInt32) defer scleanup() - cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithBalancerBuilder(newPickfirstBuilder()), WithCodec(testCodec{})) + cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithCodec(testCodec{}), WithWaitForHandshake()) if err != nil { t.Fatalf("failed to dial: %v", err) } @@ -192,14 +200,14 @@ func TestOneServerDownPickfirst(t *testing.T) { defer cancel() req := "port" var reply string - if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || Code(err) != codes.DeadlineExceeded { + if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || status.Code(err) != codes.DeadlineExceeded { t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err) } r.NewAddress([]resolver.Address{{Addr: servers[0].addr}, {Addr: servers[1].addr}}) // The second RPC should succeed with the first server. for i := 0; i < 1000; i++ { - if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[0].port { + if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[0].port { break } time.Sleep(time.Millisecond) @@ -207,7 +215,7 @@ func TestOneServerDownPickfirst(t *testing.T) { servers[0].stop() for i := 0; i < 1000; i++ { - if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port { + if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[1].port { return } time.Sleep(time.Millisecond) @@ -224,7 +232,7 @@ func TestAllServersDownPickfirst(t *testing.T) { servers, _, scleanup := startServers(t, numServers, math.MaxInt32) defer scleanup() - cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithBalancerBuilder(newPickfirstBuilder()), WithCodec(testCodec{})) + cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithCodec(testCodec{}), WithWaitForHandshake()) if err != nil { t.Fatalf("failed to dial: %v", err) } @@ -234,14 +242,14 @@ func TestAllServersDownPickfirst(t *testing.T) { defer cancel() req := "port" var reply string - if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || Code(err) != codes.DeadlineExceeded { + if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || status.Code(err) != codes.DeadlineExceeded { t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err) } r.NewAddress([]resolver.Address{{Addr: servers[0].addr}, {Addr: servers[1].addr}}) // The second RPC should succeed with the first server. for i := 0; i < 1000; i++ { - if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[0].port { + if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[0].port { break } time.Sleep(time.Millisecond) @@ -251,7 +259,7 @@ func TestAllServersDownPickfirst(t *testing.T) { servers[i].stop() } for i := 0; i < 1000; i++ { - if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); Code(err) == codes.Unavailable { + if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); status.Code(err) == codes.Unavailable { return } time.Sleep(time.Millisecond) @@ -268,7 +276,7 @@ func TestAddressesRemovedPickfirst(t *testing.T) { servers, _, scleanup := startServers(t, numServers, math.MaxInt32) defer scleanup() - cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithBalancerBuilder(newPickfirstBuilder()), WithCodec(testCodec{})) + cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("failed to dial: %v", err) } @@ -278,19 +286,19 @@ func TestAddressesRemovedPickfirst(t *testing.T) { defer cancel() req := "port" var reply string - if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || Code(err) != codes.DeadlineExceeded { + if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || status.Code(err) != codes.DeadlineExceeded { t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err) } r.NewAddress([]resolver.Address{{Addr: servers[0].addr}, {Addr: servers[1].addr}, {Addr: servers[2].addr}}) for i := 0; i < 1000; i++ { - if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[0].port { + if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[0].port { break } time.Sleep(time.Millisecond) } for i := 0; i < 20; i++ { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[0].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[0].port { t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port) } time.Sleep(10 * time.Millisecond) @@ -299,13 +307,13 @@ func TestAddressesRemovedPickfirst(t *testing.T) { // Remove server[0]. r.NewAddress([]resolver.Address{{Addr: servers[1].addr}, {Addr: servers[2].addr}}) for i := 0; i < 1000; i++ { - if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port { + if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[1].port { break } time.Sleep(time.Millisecond) } for i := 0; i < 20; i++ { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[1].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[1].port { t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port) } time.Sleep(10 * time.Millisecond) @@ -314,7 +322,7 @@ func TestAddressesRemovedPickfirst(t *testing.T) { // Append server[0], nothing should change. r.NewAddress([]resolver.Address{{Addr: servers[1].addr}, {Addr: servers[2].addr}, {Addr: servers[0].addr}}) for i := 0; i < 20; i++ { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[1].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[1].port { t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port) } time.Sleep(10 * time.Millisecond) @@ -323,13 +331,13 @@ func TestAddressesRemovedPickfirst(t *testing.T) { // Remove server[1]. r.NewAddress([]resolver.Address{{Addr: servers[2].addr}, {Addr: servers[0].addr}}) for i := 0; i < 1000; i++ { - if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[2].port { + if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[2].port { break } time.Sleep(time.Millisecond) } for i := 0; i < 20; i++ { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[2].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[2].port { t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 2, err, servers[2].port) } time.Sleep(10 * time.Millisecond) @@ -338,13 +346,13 @@ func TestAddressesRemovedPickfirst(t *testing.T) { // Remove server[2]. r.NewAddress([]resolver.Address{{Addr: servers[0].addr}}) for i := 0; i < 1000; i++ { - if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[0].port { + if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[0].port { break } time.Sleep(time.Millisecond) } for i := 0; i < 20; i++ { - if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[0].port { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[0].port { t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port) } time.Sleep(10 * time.Millisecond) diff --git a/vendor/google.golang.org/grpc/resolver/resolver.go b/vendor/google.golang.org/grpc/resolver/resolver.go index 0dd887fa5..df097eedf 100644 --- a/vendor/google.golang.org/grpc/resolver/resolver.go +++ b/vendor/google.golang.org/grpc/resolver/resolver.go @@ -38,7 +38,7 @@ func Register(b Builder) { // Get returns the resolver builder registered with the given scheme. // If no builder is register with the scheme, the default scheme will // be used. -// If the default scheme is not modified, "dns" will be the default +// If the default scheme is not modified, "passthrough" will be the default // scheme, and the preinstalled dns resolver will be used. // If the default scheme is modified, and a resolver is registered with // the scheme, that resolver will be returned. @@ -55,7 +55,7 @@ func Get(scheme string) Builder { } // SetDefaultScheme sets the default scheme that will be used. -// The default default scheme is "dns". +// The default default scheme is "passthrough". func SetDefaultScheme(scheme string) { defaultScheme = scheme } @@ -78,7 +78,9 @@ type Address struct { // Type is the type of this address. Type AddressType // ServerName is the name of this address. - // It's the name of the grpc load balancer, which will be used for authentication. + // + // e.g. if Type is GRPCLB, ServerName should be the name of the remote load + // balancer, not the name of the backend. ServerName string // Metadata is the information associated with Addr, which may be used // to make load balancing decision. @@ -88,10 +90,18 @@ type Address struct { // BuildOption includes additional information for the builder to create // the resolver. type BuildOption struct { + // UserOptions can be used to pass configuration between DialOptions and the + // resolver. + UserOptions interface{} } // ClientConn contains the callbacks for resolver to notify any updates // to the gRPC ClientConn. +// +// This interface is to be implemented by gRPC. Users should not need a +// brand new implementation of this interface. For the situations like +// testing, the new implementation should embed this interface. This allows +// gRPC to add new methods to this interface. type ClientConn interface { // NewAddress is called by resolver to notify ClientConn a new list // of resolved addresses. @@ -128,8 +138,10 @@ type ResolveNowOption struct{} // Resolver watches for the updates on the specified target. // Updates include address updates and service config updates. type Resolver interface { - // ResolveNow will be called by gRPC to try to resolve the target name again. - // It's just a hint, resolver can ignore this if it's not necessary. + // ResolveNow will be called by gRPC to try to resolve the target name + // again. It's just a hint, resolver can ignore this if it's not necessary. + // + // It could be called multiple times concurrently. ResolveNow(ResolveNowOption) // Close closes the resolver. Close() diff --git a/vendor/google.golang.org/grpc/resolver_conn_wrapper.go b/vendor/google.golang.org/grpc/resolver_conn_wrapper.go index c07e174a8..ef5d4c286 100644 --- a/vendor/google.golang.org/grpc/resolver_conn_wrapper.go +++ b/vendor/google.golang.org/grpc/resolver_conn_wrapper.go @@ -61,12 +61,18 @@ func parseTarget(target string) (ret resolver.Target) { // newCCResolverWrapper parses cc.target for scheme and gets the resolver // builder for this scheme. It then builds the resolver and starts the // monitoring goroutine for it. +// +// If withResolverBuilder dial option is set, the specified resolver will be +// used instead. func newCCResolverWrapper(cc *ClientConn) (*ccResolverWrapper, error) { grpclog.Infof("dialing to target with scheme: %q", cc.parsedTarget.Scheme) - rb := resolver.Get(cc.parsedTarget.Scheme) + rb := cc.dopts.resolverBuilder if rb == nil { - return nil, fmt.Errorf("could not get resolver for scheme: %q", cc.parsedTarget.Scheme) + rb = resolver.Get(cc.parsedTarget.Scheme) + if rb == nil { + return nil, fmt.Errorf("could not get resolver for scheme: %q", cc.parsedTarget.Scheme) + } } ccr := &ccResolverWrapper{ @@ -77,14 +83,19 @@ func newCCResolverWrapper(cc *ClientConn) (*ccResolverWrapper, error) { } var err error - ccr.resolver, err = rb.Build(cc.parsedTarget, ccr, resolver.BuildOption{}) + ccr.resolver, err = rb.Build(cc.parsedTarget, ccr, resolver.BuildOption{ + UserOptions: cc.dopts.resolverBuildUserOptions, + }) if err != nil { return nil, err } - go ccr.watcher() return ccr, nil } +func (ccr *ccResolverWrapper) start() { + go ccr.watcher() +} + // watcher processes address updates and service config updates sequencially. // Otherwise, we need to resolve possible races between address and service // config (e.g. they specify different balancer types). @@ -119,6 +130,10 @@ func (ccr *ccResolverWrapper) watcher() { } } +func (ccr *ccResolverWrapper) resolveNow(o resolver.ResolveNowOption) { + ccr.resolver.ResolveNow(o) +} + func (ccr *ccResolverWrapper) close() { ccr.resolver.Close() close(ccr.done) diff --git a/vendor/google.golang.org/grpc/resolver_test.go b/vendor/google.golang.org/grpc/resolver_test.go new file mode 100644 index 000000000..6aba13c1d --- /dev/null +++ b/vendor/google.golang.org/grpc/resolver_test.go @@ -0,0 +1,99 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +import ( + "fmt" + "strings" + "testing" + "time" + + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/resolver/manual" + "google.golang.org/grpc/test/leakcheck" +) + +func TestResolverServiceConfigBeforeAddressNotPanic(t *testing.T) { + defer leakcheck.Check(t) + r, rcleanup := manual.GenerateAndRegisterManualResolver() + defer rcleanup() + + cc, err := Dial(r.Scheme()+":///test.server", WithInsecure()) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer cc.Close() + + // SwitchBalancer before NewAddress. There was no balancer created, this + // makes sure we don't call close on nil balancerWrapper. + r.NewServiceConfig(`{"loadBalancingPolicy": "round_robin"}`) // This should not panic. + + time.Sleep(time.Second) // Sleep to make sure the service config is handled by ClientConn. +} + +func TestResolverEmptyUpdateNotPanic(t *testing.T) { + defer leakcheck.Check(t) + r, rcleanup := manual.GenerateAndRegisterManualResolver() + defer rcleanup() + + cc, err := Dial(r.Scheme()+":///test.server", WithInsecure()) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer cc.Close() + + // This make sure we don't create addrConn with empty address list. + r.NewAddress([]resolver.Address{}) // This should not panic. + + time.Sleep(time.Second) // Sleep to make sure the service config is handled by ClientConn. +} + +var ( + errTestResolverFailBuild = fmt.Errorf("test resolver build error") +) + +type testResolverFailBuilder struct { + buildOpt resolver.BuildOption +} + +func (r *testResolverFailBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOption) (resolver.Resolver, error) { + r.buildOpt = opts + return nil, errTestResolverFailBuild +} +func (r *testResolverFailBuilder) Scheme() string { + return "testResolverFailBuilderScheme" +} + +// Tests that options in WithResolverUserOptions are passed to resolver.Build(). +func TestResolverUserOptions(t *testing.T) { + r := &testResolverFailBuilder{} + + userOpt := "testUserOpt" + _, err := Dial("scheme:///test.server", WithInsecure(), + withResolverBuilder(r), + WithResolverUserOptions(userOpt), + ) + if err == nil || !strings.Contains(err.Error(), errTestResolverFailBuild.Error()) { + t.Fatalf("Dial with testResolverFailBuilder returns err: %v, want: %v", err, errTestResolverFailBuild) + } + + if r.buildOpt.UserOptions != userOpt { + t.Fatalf("buildOpt.UserOptions = %T %+v, want %v", r.buildOpt.UserOptions, r.buildOpt.UserOptions, userOpt) + } +} diff --git a/vendor/google.golang.org/grpc/rpc_util.go b/vendor/google.golang.org/grpc/rpc_util.go index 1881d3dd6..bf384b644 100644 --- a/vendor/google.golang.org/grpc/rpc_util.go +++ b/vendor/google.golang.org/grpc/rpc_util.go @@ -293,10 +293,10 @@ func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byt return pf, nil, nil } if int64(length) > int64(maxInt) { - return 0, nil, Errorf(codes.ResourceExhausted, "grpc: received message larger than max length allowed on current machine (%d vs. %d)", length, maxInt) + return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max length allowed on current machine (%d vs. %d)", length, maxInt) } if int(length) > maxReceiveMessageSize { - return 0, nil, Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize) + return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize) } // TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead // of making it for each message: @@ -326,7 +326,7 @@ func encode(c Codec, msg interface{}, cp Compressor, outPayload *stats.OutPayloa var err error b, err = c.Marshal(msg) if err != nil { - return nil, nil, Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error()) + return nil, nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error()) } if outPayload != nil { outPayload.Payload = msg @@ -340,20 +340,20 @@ func encode(c Codec, msg interface{}, cp Compressor, outPayload *stats.OutPayloa if compressor != nil { z, _ := compressor.Compress(cbuf) if _, err := z.Write(b); err != nil { - return nil, nil, Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) + return nil, nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) } z.Close() } else { // If Compressor is not set by UseCompressor, use default Compressor if err := cp.Do(cbuf, b); err != nil { - return nil, nil, Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) + return nil, nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) } } b = cbuf.Bytes() } } if uint(len(b)) > math.MaxUint32 { - return nil, nil, Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b)) + return nil, nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b)) } bufHeader := make([]byte, payloadLen+sizeLen) @@ -409,26 +409,26 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{ if dc != nil { d, err = dc.Do(bytes.NewReader(d)) if err != nil { - return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) + return status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) } } else { dcReader, err := compressor.Decompress(bytes.NewReader(d)) if err != nil { - return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) + return status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) } d, err = ioutil.ReadAll(dcReader) if err != nil { - return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) + return status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) } } } if len(d) > maxReceiveMessageSize { // TODO: Revisit the error code. Currently keep it consistent with java // implementation. - return Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(d), maxReceiveMessageSize) + return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(d), maxReceiveMessageSize) } if err := c.Unmarshal(d, m); err != nil { - return Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err) + return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err) } if inPayload != nil { inPayload.RecvTime = time.Now() @@ -441,9 +441,7 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{ } type rpcInfo struct { - failfast bool - bytesSent bool - bytesReceived bool + failfast bool } type rpcInfoContextKey struct{} @@ -457,18 +455,10 @@ func rpcInfoFromContext(ctx context.Context) (s *rpcInfo, ok bool) { return } -func updateRPCInfoInContext(ctx context.Context, s rpcInfo) { - if ss, ok := rpcInfoFromContext(ctx); ok { - ss.bytesReceived = s.bytesReceived - ss.bytesSent = s.bytesSent - } - return -} - // Code returns the error code for err if it was produced by the rpc system. // Otherwise, it returns codes.Unknown. // -// Deprecated; use status.FromError and Code method instead. +// Deprecated: use status.FromError and Code method instead. func Code(err error) codes.Code { if s, ok := status.FromError(err); ok { return s.Code() @@ -479,7 +469,7 @@ func Code(err error) codes.Code { // ErrorDesc returns the error description of err if it was produced by the rpc system. // Otherwise, it returns err.Error() or empty string when err is nil. // -// Deprecated; use status.FromError and Message method instead. +// Deprecated: use status.FromError and Message method instead. func ErrorDesc(err error) string { if s, ok := status.FromError(err); ok { return s.Message() @@ -490,7 +480,7 @@ func ErrorDesc(err error) string { // Errorf returns an error containing an error code and a description; // Errorf returns nil if c is OK. // -// Deprecated; use status.Errorf instead. +// Deprecated: use status.Errorf instead. func Errorf(c codes.Code, format string, a ...interface{}) error { return status.Errorf(c, format, a...) } @@ -510,6 +500,6 @@ const ( ) // Version is the current grpc version. -const Version = "1.8.2" +const Version = "1.9.1" const grpcUA = "grpc-go/" + Version diff --git a/vendor/google.golang.org/grpc/server.go b/vendor/google.golang.org/grpc/server.go index e9737fc49..f65162168 100644 --- a/vendor/google.golang.org/grpc/server.go +++ b/vendor/google.golang.org/grpc/server.go @@ -92,11 +92,7 @@ type Server struct { conns map[io.Closer]bool serve bool drain bool - ctx context.Context - cancel context.CancelFunc - // A CondVar to let GracefulStop() blocks until all the pending RPCs are finished - // and all the transport goes away. - cv *sync.Cond + cv *sync.Cond // signaled when connections close for GracefulStop m map[string]*service // service name -> service info events trace.EventLog @@ -104,6 +100,7 @@ type Server struct { done chan struct{} quitOnce sync.Once doneOnce sync.Once + serveWG sync.WaitGroup // counts active Serve goroutines for GracefulStop } type options struct { @@ -343,7 +340,6 @@ func NewServer(opt ...ServerOption) *Server { done: make(chan struct{}), } s.cv = sync.NewCond(&s.mu) - s.ctx, s.cancel = context.WithCancel(context.Background()) if EnableTracing { _, file, line, _ := runtime.Caller(1) s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line)) @@ -474,10 +470,23 @@ func (s *Server) Serve(lis net.Listener) error { s.printf("serving") s.serve = true if s.lis == nil { + // Serve called after Stop or GracefulStop. s.mu.Unlock() lis.Close() return ErrServerStopped } + + s.serveWG.Add(1) + defer func() { + s.serveWG.Done() + select { + // Stop or GracefulStop called; block until done and return nil. + case <-s.quit: + <-s.done + default: + } + }() + s.lis[lis] = true s.mu.Unlock() defer func() { @@ -511,33 +520,39 @@ func (s *Server) Serve(lis net.Listener) error { timer := time.NewTimer(tempDelay) select { case <-timer.C: - case <-s.ctx.Done(): + case <-s.quit: + timer.Stop() + return nil } - timer.Stop() continue } s.mu.Lock() s.printf("done serving; Accept = %v", err) s.mu.Unlock() - // If Stop or GracefulStop is called, block until they are done and return nil select { case <-s.quit: - <-s.done return nil default: } return err } tempDelay = 0 - // Start a new goroutine to deal with rawConn - // so we don't stall this Accept loop goroutine. - go s.handleRawConn(rawConn) + // Start a new goroutine to deal with rawConn so we don't stall this Accept + // loop goroutine. + // + // Make sure we account for the goroutine so GracefulStop doesn't nil out + // s.conns before this conn can be added. + s.serveWG.Add(1) + go func() { + s.handleRawConn(rawConn) + s.serveWG.Done() + }() } } -// handleRawConn is run in its own goroutine and handles a just-accepted -// connection that has not had any I/O performed on it yet. +// handleRawConn forks a goroutine to handle a just-accepted connection that +// has not had any I/O performed on it yet. func (s *Server) handleRawConn(rawConn net.Conn) { rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout)) conn, authInfo, err := s.useTransportAuthenticator(rawConn) @@ -562,17 +577,28 @@ func (s *Server) handleRawConn(rawConn net.Conn) { } s.mu.Unlock() + var serve func() + c := conn.(io.Closer) if s.opts.useHandlerImpl { - rawConn.SetDeadline(time.Time{}) - s.serveUsingHandler(conn) + serve = func() { s.serveUsingHandler(conn) } } else { + // Finish handshaking (HTTP2) st := s.newHTTP2Transport(conn, authInfo) if st == nil { return } - rawConn.SetDeadline(time.Time{}) - s.serveStreams(st) + c = st + serve = func() { s.serveStreams(st) } + } + + rawConn.SetDeadline(time.Time{}) + if !s.addConn(c) { + return } + go func() { + serve() + s.removeConn(c) + }() } // newHTTP2Transport sets up a http/2 transport (using the @@ -599,15 +625,10 @@ func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) tr grpclog.Warningln("grpc: Server.Serve failed to create ServerTransport: ", err) return nil } - if !s.addConn(st) { - st.Close() - return nil - } return st } func (s *Server) serveStreams(st transport.ServerTransport) { - defer s.removeConn(st) defer st.Close() var wg sync.WaitGroup st.HandleStreams(func(stream *transport.Stream) { @@ -641,11 +662,6 @@ var _ http.Handler = (*Server)(nil) // // conn is the *tls.Conn that's already been authenticated. func (s *Server) serveUsingHandler(conn net.Conn) { - if !s.addConn(conn) { - conn.Close() - return - } - defer s.removeConn(conn) h2s := &http2.Server{ MaxConcurrentStreams: s.opts.maxConcurrentStreams, } @@ -685,7 +701,6 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } if !s.addConn(st) { - st.Close() return } defer s.removeConn(st) @@ -715,9 +730,15 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea func (s *Server) addConn(c io.Closer) bool { s.mu.Lock() defer s.mu.Unlock() - if s.conns == nil || s.drain { + if s.conns == nil { + c.Close() return false } + if s.drain { + // Transport added after we drained our existing conns: drain it + // immediately. + c.(transport.ServerTransport).Drain() + } s.conns[c] = true return true } @@ -826,7 +847,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. return err } if err == io.ErrUnexpectedEOF { - err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error()) + err = status.Errorf(codes.Internal, io.ErrUnexpectedEOF.Error()) } if err != nil { if st, ok := status.FromError(err); ok { @@ -868,13 +889,13 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. if dc != nil { req, err = dc.Do(bytes.NewReader(req)) if err != nil { - return Errorf(codes.Internal, err.Error()) + return status.Errorf(codes.Internal, err.Error()) } } else { tmp, _ := decomp.Decompress(bytes.NewReader(req)) req, err = ioutil.ReadAll(tmp) if err != nil { - return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) + return status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) } } } @@ -1158,6 +1179,7 @@ func (s *Server) Stop() { }) defer func() { + s.serveWG.Wait() s.doneOnce.Do(func() { close(s.done) }) @@ -1180,7 +1202,6 @@ func (s *Server) Stop() { } s.mu.Lock() - s.cancel() if s.events != nil { s.events.Finish() s.events = nil @@ -1203,21 +1224,27 @@ func (s *Server) GracefulStop() { }() s.mu.Lock() - defer s.mu.Unlock() if s.conns == nil { + s.mu.Unlock() return } for lis := range s.lis { lis.Close() } s.lis = nil - s.cancel() if !s.drain { for c := range s.conns { c.(transport.ServerTransport).Drain() } s.drain = true } + + // Wait for serving threads to be ready to exit. Only then can we be sure no + // new conns will be created. + s.mu.Unlock() + s.serveWG.Wait() + s.mu.Lock() + for len(s.conns) != 0 { s.cv.Wait() } @@ -1226,6 +1253,7 @@ func (s *Server) GracefulStop() { s.events.Finish() s.events = nil } + s.mu.Unlock() } func init() { @@ -1246,7 +1274,7 @@ func SetHeader(ctx context.Context, md metadata.MD) error { } stream, ok := transport.StreamFromContext(ctx) if !ok { - return Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) + return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) } return stream.SetHeader(md) } @@ -1256,7 +1284,7 @@ func SetHeader(ctx context.Context, md metadata.MD) error { func SendHeader(ctx context.Context, md metadata.MD) error { stream, ok := transport.StreamFromContext(ctx) if !ok { - return Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) + return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) } t := stream.ServerTransport() if t == nil { @@ -1276,7 +1304,7 @@ func SetTrailer(ctx context.Context, md metadata.MD) error { } stream, ok := transport.StreamFromContext(ctx) if !ok { - return Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) + return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) } return stream.SetTrailer(md) } diff --git a/vendor/google.golang.org/grpc/server_test.go b/vendor/google.golang.org/grpc/server_test.go index cd2f2c083..4a7a52440 100644 --- a/vendor/google.golang.org/grpc/server_test.go +++ b/vendor/google.golang.org/grpc/server_test.go @@ -49,7 +49,7 @@ func TestStopBeforeServe(t *testing.T) { // server.Serve is responsible for closing the listener, even if the // server was already stopped. err = lis.Close() - if got, want := ErrorDesc(err), "use of closed"; !strings.Contains(got, want) { + if got, want := errorDesc(err), "use of closed"; !strings.Contains(got, want) { t.Errorf("Close() error = %q, want %q", got, want) } } diff --git a/vendor/google.golang.org/grpc/service_config.go b/vendor/google.golang.org/grpc/service_config.go index cde648334..53fa88f37 100644 --- a/vendor/google.golang.org/grpc/service_config.go +++ b/vendor/google.golang.org/grpc/service_config.go @@ -20,6 +20,9 @@ package grpc import ( "encoding/json" + "fmt" + "strconv" + "strings" "time" "google.golang.org/grpc/grpclog" @@ -70,12 +73,48 @@ type ServiceConfig struct { Methods map[string]MethodConfig } -func parseTimeout(t *string) (*time.Duration, error) { - if t == nil { +func parseDuration(s *string) (*time.Duration, error) { + if s == nil { return nil, nil } - d, err := time.ParseDuration(*t) - return &d, err + if !strings.HasSuffix(*s, "s") { + return nil, fmt.Errorf("malformed duration %q", *s) + } + ss := strings.SplitN((*s)[:len(*s)-1], ".", 3) + if len(ss) > 2 { + return nil, fmt.Errorf("malformed duration %q", *s) + } + // hasDigits is set if either the whole or fractional part of the number is + // present, since both are optional but one is required. + hasDigits := false + var d time.Duration + if len(ss[0]) > 0 { + i, err := strconv.ParseInt(ss[0], 10, 32) + if err != nil { + return nil, fmt.Errorf("malformed duration %q: %v", *s, err) + } + d = time.Duration(i) * time.Second + hasDigits = true + } + if len(ss) == 2 && len(ss[1]) > 0 { + if len(ss[1]) > 9 { + return nil, fmt.Errorf("malformed duration %q", *s) + } + f, err := strconv.ParseInt(ss[1], 10, 64) + if err != nil { + return nil, fmt.Errorf("malformed duration %q: %v", *s, err) + } + for i := 9; i > len(ss[1]); i-- { + f *= 10 + } + d += time.Duration(f) + hasDigits = true + } + if !hasDigits { + return nil, fmt.Errorf("malformed duration %q", *s) + } + + return &d, nil } type jsonName struct { @@ -128,7 +167,7 @@ func parseServiceConfig(js string) (ServiceConfig, error) { if m.Name == nil { continue } - d, err := parseTimeout(m.Timeout) + d, err := parseDuration(m.Timeout) if err != nil { grpclog.Warningf("grpc: parseServiceConfig error unmarshaling %s due to %v", js, err) return ServiceConfig{}, err @@ -182,18 +221,6 @@ func getMaxSize(mcMax, doptMax *int, defaultVal int) *int { return doptMax } -func newBool(b bool) *bool { - return &b -} - func newInt(b int) *int { return &b } - -func newDuration(b time.Duration) *time.Duration { - return &b -} - -func newString(b string) *string { - return &b -} diff --git a/vendor/google.golang.org/grpc/service_config_test.go b/vendor/google.golang.org/grpc/service_config_test.go index 7e985457e..8301a5061 100644 --- a/vendor/google.golang.org/grpc/service_config_test.go +++ b/vendor/google.golang.org/grpc/service_config_test.go @@ -19,6 +19,8 @@ package grpc import ( + "fmt" + "math" "reflect" "testing" "time" @@ -321,3 +323,64 @@ func TestPraseMsgSize(t *testing.T) { } } } + +func TestParseDuration(t *testing.T) { + testCases := []struct { + s *string + want *time.Duration + err bool + }{ + {s: nil, want: nil}, + {s: newString("1s"), want: newDuration(time.Second)}, + {s: newString("-1s"), want: newDuration(-time.Second)}, + {s: newString("1.1s"), want: newDuration(1100 * time.Millisecond)}, + {s: newString("1.s"), want: newDuration(time.Second)}, + {s: newString("1.0s"), want: newDuration(time.Second)}, + {s: newString(".002s"), want: newDuration(2 * time.Millisecond)}, + {s: newString(".002000s"), want: newDuration(2 * time.Millisecond)}, + {s: newString("0.003s"), want: newDuration(3 * time.Millisecond)}, + {s: newString("0.000004s"), want: newDuration(4 * time.Microsecond)}, + {s: newString("5000.000000009s"), want: newDuration(5000*time.Second + 9*time.Nanosecond)}, + {s: newString("4999.999999999s"), want: newDuration(5000*time.Second - time.Nanosecond)}, + {s: newString("1"), err: true}, + {s: newString("s"), err: true}, + {s: newString(".s"), err: true}, + {s: newString("1 s"), err: true}, + {s: newString(" 1s"), err: true}, + {s: newString("1ms"), err: true}, + {s: newString("1.1.1s"), err: true}, + {s: newString("Xs"), err: true}, + {s: newString("as"), err: true}, + {s: newString(".0000000001s"), err: true}, + {s: newString(fmt.Sprint(math.MaxInt32) + "s"), want: newDuration(math.MaxInt32 * time.Second)}, + {s: newString(fmt.Sprint(int64(math.MaxInt32)+1) + "s"), err: true}, + } + for _, tc := range testCases { + got, err := parseDuration(tc.s) + if tc.err != (err != nil) || + (got == nil) != (tc.want == nil) || + (got != nil && *got != *tc.want) { + wantErr := "<nil>" + if tc.err { + wantErr = "<non-nil error>" + } + s := "<nil>" + if tc.s != nil { + s = `&"` + *tc.s + `"` + } + t.Errorf("parseDuration(%v) = %v, %v; want %v, %v", s, got, err, tc.want, wantErr) + } + } +} + +func newBool(b bool) *bool { + return &b +} + +func newDuration(b time.Duration) *time.Duration { + return &b +} + +func newString(b string) *string { + return &b +} diff --git a/vendor/google.golang.org/grpc/stats/stats_test.go b/vendor/google.golang.org/grpc/stats/stats_test.go index 0df23f3dc..fef0d7c65 100644 --- a/vendor/google.golang.org/grpc/stats/stats_test.go +++ b/vendor/google.golang.org/grpc/stats/stats_test.go @@ -1,3 +1,5 @@ +// +build go1.7 + /* * * Copyright 2016 gRPC authors. @@ -64,10 +66,10 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (* md, ok := metadata.FromIncomingContext(ctx) if ok { if err := grpc.SendHeader(ctx, md); err != nil { - return nil, status.Errorf(grpc.Code(err), "grpc.SendHeader(_, %v) = %v, want <nil>", md, err) + return nil, status.Errorf(status.Code(err), "grpc.SendHeader(_, %v) = %v, want <nil>", md, err) } if err := grpc.SetTrailer(ctx, testTrailerMetadata); err != nil { - return nil, status.Errorf(grpc.Code(err), "grpc.SetTrailer(_, %v) = %v, want <nil>", testTrailerMetadata, err) + return nil, status.Errorf(status.Code(err), "grpc.SetTrailer(_, %v) = %v, want <nil>", testTrailerMetadata, err) } } @@ -82,7 +84,7 @@ func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServ md, ok := metadata.FromIncomingContext(stream.Context()) if ok { if err := stream.SendHeader(md); err != nil { - return status.Errorf(grpc.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil) + return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil) } stream.SetTrailer(testTrailerMetadata) } @@ -110,7 +112,7 @@ func (s *testServer) ClientStreamCall(stream testpb.TestService_ClientStreamCall md, ok := metadata.FromIncomingContext(stream.Context()) if ok { if err := stream.SendHeader(md); err != nil { - return status.Errorf(grpc.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil) + return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil) } stream.SetTrailer(testTrailerMetadata) } @@ -134,7 +136,7 @@ func (s *testServer) ServerStreamCall(in *testpb.SimpleRequest, stream testpb.Te md, ok := metadata.FromIncomingContext(stream.Context()) if ok { if err := stream.SendHeader(md); err != nil { - return status.Errorf(grpc.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil) + return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil) } stream.SetTrailer(testTrailerMetadata) } @@ -646,7 +648,14 @@ func checkEnd(t *testing.T, d *gotData, e *expectedData) { if st.EndTime.IsZero() { t.Fatalf("st.EndTime = %v, want <non-zero>", st.EndTime) } - if grpc.Code(st.Error) != grpc.Code(e.err) || grpc.ErrorDesc(st.Error) != grpc.ErrorDesc(e.err) { + + actual, ok := status.FromError(st.Error) + if !ok { + t.Fatalf("expected st.Error to be a statusError, got %T", st.Error) + } + + expectedStatus, _ := status.FromError(e.err) + if actual.Code() != expectedStatus.Code() || actual.Message() != expectedStatus.Message() { t.Fatalf("st.Error = %v, want %v", st.Error, e.err) } } diff --git a/vendor/google.golang.org/grpc/status/status.go b/vendor/google.golang.org/grpc/status/status.go index 871dc4b31..d9defaebc 100644 --- a/vendor/google.golang.org/grpc/status/status.go +++ b/vendor/google.golang.org/grpc/status/status.go @@ -125,8 +125,8 @@ func FromError(err error) (s *Status, ok bool) { if err == nil { return &Status{s: &spb.Status{Code: int32(codes.OK)}}, true } - if s, ok := err.(*statusError); ok { - return s.status(), true + if se, ok := err.(*statusError); ok { + return se.status(), true } return nil, false } @@ -166,3 +166,16 @@ func (s *Status) Details() []interface{} { } return details } + +// Code returns the Code of the error if it is a Status error, codes.OK if err +// is nil, or codes.Unknown otherwise. +func Code(err error) codes.Code { + // Don't use FromError to avoid allocation of OK status. + if err == nil { + return codes.OK + } + if se, ok := err.(*statusError); ok { + return se.status().Code() + } + return codes.Unknown +} diff --git a/vendor/google.golang.org/grpc/stream.go b/vendor/google.golang.org/grpc/stream.go index 9eeaafef8..f91381995 100644 --- a/vendor/google.golang.org/grpc/stream.go +++ b/vendor/google.golang.org/grpc/stream.go @@ -163,7 +163,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth if ct != encoding.Identity { comp = encoding.GetCompressor(ct) if comp == nil { - return nil, Errorf(codes.Internal, "grpc: Compressor is not installed for requested grpc-encoding %q", ct) + return nil, status.Errorf(codes.Internal, "grpc: Compressor is not installed for requested grpc-encoding %q", ct) } } } else if cc.dopts.cp != nil { @@ -232,7 +232,14 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth s, err = t.NewStream(ctx, callHdr) if err != nil { if done != nil { - done(balancer.DoneInfo{Err: err}) + doneInfo := balancer.DoneInfo{Err: err} + if _, ok := err.(transport.ConnectionError); ok { + // If error is connection error, transport was sending data on wire, + // and we are not sure if anything has been sent on wire. + // If error is not connection error, we are sure nothing has been sent. + doneInfo.BytesSent = true + } + done(doneInfo) done = nil } // In the event of any error from NewStream, we never attempted to write @@ -393,10 +400,10 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { return err } if cs.c.maxSendMessageSize == nil { - return Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)") + return status.Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)") } if len(data) > *cs.c.maxSendMessageSize { - return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), *cs.c.maxSendMessageSize) + return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), *cs.c.maxSendMessageSize) } err = cs.t.Write(cs.s, hdr, data, &transport.Options{Last: false}) if err == nil && outPayload != nil { @@ -414,7 +421,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { } } if cs.c.maxReceiveMessageSize == nil { - return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") + return status.Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") } if !cs.decompSet { // Block until we receive headers containing received message encoding. @@ -456,7 +463,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { // Special handling for client streaming rpc. // This recv expects EOF or errors, so we don't collect inPayload. if cs.c.maxReceiveMessageSize == nil { - return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") + return status.Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") } err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil, cs.decomp) cs.closeTransportStream(err) @@ -529,11 +536,11 @@ func (cs *clientStream) finish(err error) { o.after(cs.c) } if cs.done != nil { - updateRPCInfoInContext(cs.s.Context(), rpcInfo{ - bytesSent: true, - bytesReceived: cs.s.BytesReceived(), + cs.done(balancer.DoneInfo{ + Err: err, + BytesSent: true, + BytesReceived: cs.s.BytesReceived(), }) - cs.done(balancer.DoneInfo{Err: err}) cs.done = nil } if cs.statsHandler != nil { @@ -653,7 +660,7 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { return err } if len(data) > ss.maxSendMessageSize { - return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), ss.maxSendMessageSize) + return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), ss.maxSendMessageSize) } if err := ss.t.Write(ss.s, hdr, data, &transport.Options{Last: false}); err != nil { return toRPCErr(err) @@ -693,7 +700,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { return err } if err == io.ErrUnexpectedEOF { - err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error()) + err = status.Errorf(codes.Internal, io.ErrUnexpectedEOF.Error()) } return toRPCErr(err) } diff --git a/vendor/google.golang.org/grpc/transport/control.go b/vendor/google.golang.org/grpc/transport/control.go index 63194830d..0474b0907 100644 --- a/vendor/google.golang.org/grpc/transport/control.go +++ b/vendor/google.golang.org/grpc/transport/control.go @@ -116,6 +116,7 @@ type goAway struct { func (*goAway) item() {} type flushIO struct { + closeTr bool } func (*flushIO) item() {} diff --git a/vendor/google.golang.org/grpc/transport/go16.go b/vendor/google.golang.org/grpc/transport/go16.go index 7cffee11e..5babcf9b8 100644 --- a/vendor/google.golang.org/grpc/transport/go16.go +++ b/vendor/google.golang.org/grpc/transport/go16.go @@ -22,6 +22,7 @@ package transport import ( "net" + "net/http" "google.golang.org/grpc/codes" @@ -43,3 +44,8 @@ func ContextErr(err error) StreamError { } return streamErrorf(codes.Internal, "Unexpected error from context packet: %v", err) } + +// contextFromRequest returns a background context. +func contextFromRequest(r *http.Request) context.Context { + return context.Background() +} diff --git a/vendor/google.golang.org/grpc/transport/go17.go b/vendor/google.golang.org/grpc/transport/go17.go index 2464e69fa..b7fa6bdb9 100644 --- a/vendor/google.golang.org/grpc/transport/go17.go +++ b/vendor/google.golang.org/grpc/transport/go17.go @@ -23,6 +23,7 @@ package transport import ( "context" "net" + "net/http" "google.golang.org/grpc/codes" @@ -44,3 +45,8 @@ func ContextErr(err error) StreamError { } return streamErrorf(codes.Internal, "Unexpected error from context packet: %v", err) } + +// contextFromRequest returns a context from the HTTP Request. +func contextFromRequest(r *http.Request) context.Context { + return r.Context() +} diff --git a/vendor/google.golang.org/grpc/transport/handler_server.go b/vendor/google.golang.org/grpc/transport/handler_server.go index 7e0fdb359..27c4ebb5f 100644 --- a/vendor/google.golang.org/grpc/transport/handler_server.go +++ b/vendor/google.golang.org/grpc/transport/handler_server.go @@ -284,12 +284,12 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error { func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), traceCtx func(context.Context, string) context.Context) { // With this transport type there will be exactly 1 stream: this HTTP request. - var ctx context.Context + ctx := contextFromRequest(ht.req) var cancel context.CancelFunc if ht.timeoutSet { - ctx, cancel = context.WithTimeout(context.Background(), ht.timeout) + ctx, cancel = context.WithTimeout(ctx, ht.timeout) } else { - ctx, cancel = context.WithCancel(context.Background()) + ctx, cancel = context.WithCancel(ctx) } // requestOver is closed when either the request's context is done diff --git a/vendor/google.golang.org/grpc/transport/http2_client.go b/vendor/google.golang.org/grpc/transport/http2_client.go index 0f58a390a..4a122692a 100644 --- a/vendor/google.golang.org/grpc/transport/http2_client.go +++ b/vendor/google.golang.org/grpc/transport/http2_client.go @@ -20,6 +20,7 @@ package transport import ( "bytes" + "fmt" "io" "math" "net" @@ -93,6 +94,11 @@ type http2Client struct { bdpEst *bdpEstimator outQuotaVersion uint32 + // onSuccess is a callback that client transport calls upon + // receiving server preface to signal that a succefull HTTP2 + // connection was established. + onSuccess func() + mu sync.Mutex // guard the following variables state transportState // the state of underlying connection activeStreams map[uint32]*Stream @@ -145,16 +151,12 @@ func isTemporary(err error) bool { // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // and starts to receive messages on it. Non-nil error returns if construction // fails. -func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions, timeout time.Duration) (_ ClientTransport, err error) { +func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts ConnectOptions, onSuccess func()) (_ ClientTransport, err error) { scheme := "http" ctx, cancel := context.WithCancel(ctx) - connectCtx, connectCancel := context.WithTimeout(ctx, timeout) defer func() { if err != nil { cancel() - // Don't call connectCancel in success path due to a race in Go 1.6: - // https://github.com/golang/go/issues/15078. - connectCancel() } }() @@ -240,6 +242,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions, t kp: kp, statsHandler: opts.StatsHandler, initialWindowSize: initialWindowSize, + onSuccess: onSuccess, } if opts.InitialWindowSize >= defaultWindowSize { t.initialWindowSize = opts.InitialWindowSize @@ -300,7 +303,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions, t t.framer.writer.Flush() go func() { loopyWriter(t.ctx, t.controlBuf, t.itemHandler) - t.Close() + t.conn.Close() }() if t.kp.Time != infinity { go t.keepalive() @@ -1122,7 +1125,6 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { s.mu.Unlock() return } - if len(state.mdata) > 0 { s.trailer = state.mdata } @@ -1160,6 +1162,7 @@ func (t *http2Client) reader() { t.Close() return } + t.onSuccess() t.handleSettings(sf, true) // loop to keep reading incoming messages on this transport. @@ -1234,8 +1237,7 @@ func (t *http2Client) applySettings(ss []http2.Setting) { // TODO(mmukhi): A lot of this code(and code in other places in the tranpsort layer) // is duplicated between the client and the server. // The transport layer needs to be refactored to take care of this. -func (t *http2Client) itemHandler(i item) error { - var err error +func (t *http2Client) itemHandler(i item) (err error) { defer func() { if err != nil { errorf(" error in itemHandler: %v", err) @@ -1243,10 +1245,11 @@ func (t *http2Client) itemHandler(i item) error { }() switch i := i.(type) { case *dataFrame: - err = t.framer.fr.WriteData(i.streamID, i.endStream, i.d) - if err == nil { - i.f() + if err := t.framer.fr.WriteData(i.streamID, i.endStream, i.d); err != nil { + return err } + i.f() + return nil case *headerFrame: t.hBuf.Reset() for _, f := range i.hf { @@ -1280,31 +1283,33 @@ func (t *http2Client) itemHandler(i item) error { return err } } + return nil case *windowUpdate: - err = t.framer.fr.WriteWindowUpdate(i.streamID, i.increment) + return t.framer.fr.WriteWindowUpdate(i.streamID, i.increment) case *settings: - err = t.framer.fr.WriteSettings(i.ss...) + return t.framer.fr.WriteSettings(i.ss...) case *settingsAck: - err = t.framer.fr.WriteSettingsAck() + return t.framer.fr.WriteSettingsAck() case *resetStream: // If the server needs to be to intimated about stream closing, // then we need to make sure the RST_STREAM frame is written to // the wire before the headers of the next stream waiting on // streamQuota. We ensure this by adding to the streamsQuota pool // only after having acquired the writableChan to send RST_STREAM. - err = t.framer.fr.WriteRSTStream(i.streamID, i.code) + err := t.framer.fr.WriteRSTStream(i.streamID, i.code) t.streamsQuota.add(1) + return err case *flushIO: - err = t.framer.writer.Flush() + return t.framer.writer.Flush() case *ping: if !i.ack { t.bdpEst.timesnap(i.data) } - err = t.framer.fr.WritePing(i.ack, i.data) + return t.framer.fr.WritePing(i.ack, i.data) default: errorf("transport: http2Client.controller got unexpected item type %v", i) + return fmt.Errorf("transport: http2Client.controller got unexpected item type %v", i) } - return err } // keepalive running in a separate goroutune makes sure the connection is alive by sending pings. diff --git a/vendor/google.golang.org/grpc/transport/http2_server.go b/vendor/google.golang.org/grpc/transport/http2_server.go index 4a95363cc..6d252c53a 100644 --- a/vendor/google.golang.org/grpc/transport/http2_server.go +++ b/vendor/google.golang.org/grpc/transport/http2_server.go @@ -228,6 +228,12 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err } t.framer.writer.Flush() + defer func() { + if err != nil { + t.Close() + } + }() + // Check the validity of client preface. preface := make([]byte, len(clientPreface)) if _, err := io.ReadFull(t.conn, preface); err != nil { @@ -239,8 +245,7 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err frame, err := t.framer.fr.ReadFrame() if err == io.EOF || err == io.ErrUnexpectedEOF { - t.Close() - return + return nil, err } if err != nil { return nil, connectionErrorf(false, err, "transport: http2Server.HandleStreams failed to read initial settings frame: %v", err) @@ -254,7 +259,7 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err go func() { loopyWriter(t.ctx, t.controlBuf, t.itemHandler) - t.Close() + t.conn.Close() }() go t.keepalive() return t, nil @@ -1069,6 +1074,9 @@ func (t *http2Server) itemHandler(i item) error { if !i.headsUp { // Stop accepting more streams now. t.state = draining + if len(t.activeStreams) == 0 { + i.closeConn = true + } t.mu.Unlock() if err := t.framer.fr.WriteGoAway(sid, i.code, i.debugData); err != nil { return err @@ -1076,8 +1084,7 @@ func (t *http2Server) itemHandler(i item) error { if i.closeConn { // Abruptly close the connection following the GoAway (via // loopywriter). But flush out what's inside the buffer first. - t.framer.writer.Flush() - return fmt.Errorf("transport: Connection closing") + t.controlBuf.put(&flushIO{closeTr: true}) } return nil } @@ -1107,7 +1114,13 @@ func (t *http2Server) itemHandler(i item) error { }() return nil case *flushIO: - return t.framer.writer.Flush() + if err := t.framer.writer.Flush(); err != nil { + return err + } + if i.closeTr { + return ErrConnClosing + } + return nil case *ping: if !i.ack { t.bdpEst.timesnap(i.data) @@ -1155,7 +1168,7 @@ func (t *http2Server) closeStream(s *Stream) { t.idle = time.Now() } if t.state == draining && len(t.activeStreams) == 0 { - defer t.Close() + defer t.controlBuf.put(&flushIO{closeTr: true}) } t.mu.Unlock() // In case stream sending and receiving are invoked in separate diff --git a/vendor/google.golang.org/grpc/transport/transport.go b/vendor/google.golang.org/grpc/transport/transport.go index b7a5dbe42..2e7bcaeaa 100644 --- a/vendor/google.golang.org/grpc/transport/transport.go +++ b/vendor/google.golang.org/grpc/transport/transport.go @@ -26,7 +26,6 @@ import ( "io" "net" "sync" - "time" "golang.org/x/net/context" "golang.org/x/net/http2" @@ -249,11 +248,28 @@ type Stream struct { unprocessed bool // set if the server sends a refused stream or GOAWAY including this stream } +func (s *Stream) waitOnHeader() error { + if s.headerChan == nil { + // On the server headerChan is always nil since a stream originates + // only after having received headers. + return nil + } + wc := s.waiters + select { + case <-wc.ctx.Done(): + return ContextErr(wc.ctx.Err()) + case <-wc.goAway: + return errStreamDrain + case <-s.headerChan: + return nil + } +} + // RecvCompress returns the compression algorithm applied to the inbound // message. It is empty string if there is no compression applied. func (s *Stream) RecvCompress() string { - if s.headerChan != nil { - <-s.headerChan + if err := s.waitOnHeader(); err != nil { + return "" } return s.recvCompress } @@ -279,15 +295,7 @@ func (s *Stream) GoAway() <-chan struct{} { // is available. It blocks until i) the metadata is ready or ii) there is no // header metadata or iii) the stream is canceled/expired. func (s *Stream) Header() (metadata.MD, error) { - var err error - select { - case <-s.ctx.Done(): - err = ContextErr(s.ctx.Err()) - case <-s.goAway: - err = errStreamDrain - case <-s.headerChan: - return s.header.Copy(), nil - } + err := s.waitOnHeader() // Even if the stream is closed, header is returned if available. select { case <-s.headerChan: @@ -506,8 +514,8 @@ type TargetInfo struct { // NewClientTransport establishes the transport with the required ConnectOptions // and returns it to the caller. -func NewClientTransport(ctx context.Context, target TargetInfo, opts ConnectOptions, timeout time.Duration) (ClientTransport, error) { - return newHTTP2Client(ctx, target, opts, timeout) +func NewClientTransport(connectCtx, ctx context.Context, target TargetInfo, opts ConnectOptions, onSuccess func()) (ClientTransport, error) { + return newHTTP2Client(connectCtx, ctx, target, opts, onSuccess) } // Options provides additional hints and information for message diff --git a/vendor/google.golang.org/grpc/transport/transport_test.go b/vendor/google.golang.org/grpc/transport/transport_test.go index 3e29c9179..8c004ec6b 100644 --- a/vendor/google.golang.org/grpc/transport/transport_test.go +++ b/vendor/google.golang.org/grpc/transport/transport_test.go @@ -361,8 +361,10 @@ func setUpWithOptions(t *testing.T, port int, serverConfig *ServerConfig, ht hTy target := TargetInfo{ Addr: addr, } - ct, connErr = NewClientTransport(context.Background(), target, copts, 2*time.Second) + connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) + ct, connErr = NewClientTransport(connectCtx, context.Background(), target, copts, func() {}) if connErr != nil { + cancel() // Do not cancel in success path. t.Fatalf("failed to create transport: %v", connErr) } return server, ct @@ -384,8 +386,10 @@ func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, done chan net.Con } done <- conn }() - tr, err := NewClientTransport(context.Background(), TargetInfo{Addr: lis.Addr().String()}, copts, 2*time.Second) + connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) + tr, err := NewClientTransport(connectCtx, context.Background(), TargetInfo{Addr: lis.Addr().String()}, copts, func() {}) if err != nil { + cancel() // Do not cancel in success path. // Server clean-up. lis.Close() if conn, ok := <-done; ok { @@ -685,7 +689,7 @@ func TestKeepaliveServerEnforcementWithAbusiveClientNoRPC(t *testing.T) { clientOptions := ConnectOptions{ KeepaliveParams: keepalive.ClientParameters{ Time: 50 * time.Millisecond, - Timeout: 50 * time.Millisecond, + Timeout: 1 * time.Second, PermitWithoutStream: true, }, } @@ -693,7 +697,7 @@ func TestKeepaliveServerEnforcementWithAbusiveClientNoRPC(t *testing.T) { defer server.stop() defer client.Close() - timeout := time.NewTimer(2 * time.Second) + timeout := time.NewTimer(10 * time.Second) select { case <-client.GoAway(): if !timeout.Stop() { @@ -720,7 +724,7 @@ func TestKeepaliveServerEnforcementWithAbusiveClientWithRPC(t *testing.T) { clientOptions := ConnectOptions{ KeepaliveParams: keepalive.ClientParameters{ Time: 50 * time.Millisecond, - Timeout: 50 * time.Millisecond, + Timeout: 1 * time.Second, }, } server, client := setUpWithOptions(t, 0, serverConfig, suspended, clientOptions) @@ -730,7 +734,7 @@ func TestKeepaliveServerEnforcementWithAbusiveClientWithRPC(t *testing.T) { if _, err := client.NewStream(context.Background(), &CallHdr{Flush: true}); err != nil { t.Fatalf("Client failed to create stream.") } - timeout := time.NewTimer(2 * time.Second) + timeout := time.NewTimer(10 * time.Second) select { case <-client.GoAway(): if !timeout.Stop() { @@ -758,7 +762,7 @@ func TestKeepaliveServerEnforcementWithObeyingClientNoRPC(t *testing.T) { clientOptions := ConnectOptions{ KeepaliveParams: keepalive.ClientParameters{ Time: 101 * time.Millisecond, - Timeout: 50 * time.Millisecond, + Timeout: 1 * time.Second, PermitWithoutStream: true, }, } @@ -767,7 +771,7 @@ func TestKeepaliveServerEnforcementWithObeyingClientNoRPC(t *testing.T) { defer client.Close() // Give keepalive enough time. - time.Sleep(2 * time.Second) + time.Sleep(3 * time.Second) // Assert that connection is healthy. ct := client.(*http2Client) ct.mu.Lock() @@ -786,7 +790,7 @@ func TestKeepaliveServerEnforcementWithObeyingClientWithRPC(t *testing.T) { clientOptions := ConnectOptions{ KeepaliveParams: keepalive.ClientParameters{ Time: 101 * time.Millisecond, - Timeout: 50 * time.Millisecond, + Timeout: 1 * time.Second, }, } server, client := setUpWithOptions(t, 0, serverConfig, suspended, clientOptions) @@ -798,7 +802,7 @@ func TestKeepaliveServerEnforcementWithObeyingClientWithRPC(t *testing.T) { } // Give keepalive enough time. - time.Sleep(2 * time.Second) + time.Sleep(3 * time.Second) // Assert that connection is healthy. ct := client.(*http2Client) ct.mu.Lock() @@ -1360,7 +1364,15 @@ func TestServerConnDecoupledFromApplicationRead(t *testing.T) { } func TestServerWithMisbehavedClient(t *testing.T) { - server, ct := setUp(t, 0, math.MaxUint32, suspended) + serverConfig := &ServerConfig{ + InitialWindowSize: defaultWindowSize, + InitialConnWindowSize: defaultWindowSize, + } + connectOptions := ConnectOptions{ + InitialWindowSize: defaultWindowSize, + InitialConnWindowSize: defaultWindowSize, + } + server, ct := setUpWithOptions(t, 0, serverConfig, suspended, connectOptions) callHdr := &CallHdr{ Host: "localhost", Method: "foo", @@ -2091,8 +2103,10 @@ func setUpHTTPStatusTest(t *testing.T, httpStatus int, wh writeHeaders) (stream wh: wh, } server.start(t, lis) - client, err = newHTTP2Client(context.Background(), TargetInfo{Addr: lis.Addr().String()}, ConnectOptions{}, 2*time.Second) + connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) + client, err = newHTTP2Client(connectCtx, context.Background(), TargetInfo{Addr: lis.Addr().String()}, ConnectOptions{}, func() {}) if err != nil { + cancel() // Do not cancel in success path. t.Fatalf("Error creating client. Err: %v", err) } stream, err = client.NewStream(context.Background(), &CallHdr{Method: "bogus/method", Flush: true}) diff --git a/vendor/google.golang.org/grpc/vet.sh b/vendor/google.golang.org/grpc/vet.sh index 02d4bae39..2ad94fed9 100755 --- a/vendor/google.golang.org/grpc/vet.sh +++ b/vendor/google.golang.org/grpc/vet.sh @@ -23,8 +23,7 @@ if [ "$1" = "-install" ]; then golang.org/x/tools/cmd/goimports \ honnef.co/go/tools/cmd/staticcheck \ github.com/client9/misspell/cmd/misspell \ - github.com/golang/protobuf/protoc-gen-go \ - golang.org/x/tools/cmd/stringer + github.com/golang/protobuf/protoc-gen-go if [[ "$check_proto" = "true" ]]; then if [[ "$TRAVIS" = "true" ]]; then PROTOBUF_VERSION=3.3.0 @@ -52,7 +51,7 @@ fi git ls-files "*.go" | xargs grep -L "\(Copyright [0-9]\{4,\} gRPC authors\)\|DO NOT EDIT" 2>&1 | tee /dev/stderr | (! read) gofmt -s -d -l . 2>&1 | tee /dev/stderr | (! read) goimports -l . 2>&1 | tee /dev/stderr | (! read) -golint ./... 2>&1 | (grep -vE "(_mock|_string|\.pb)\.go:" || true) | tee /dev/stderr | (! read) +golint ./... 2>&1 | (grep -vE "(_mock|\.pb)\.go:" || true) | tee /dev/stderr | (! read) # Undo any edits made by this script. cleanup() { @@ -65,7 +64,7 @@ trap cleanup EXIT git ls-files "*.go" | xargs sed -i 's:"golang.org/x/net/context":"context":' set +o pipefail # TODO: Stop filtering pb.go files once golang/protobuf#214 is fixed. -go tool vet -all . 2>&1 | grep -vF '.pb.go:' | tee /dev/stderr | (! read) +go tool vet -all . 2>&1 | grep -vE '(clientconn|transport\/transport_test).go:.*cancel (function|var)' | grep -vF '.pb.go:' | tee /dev/stderr | (! read) set -o pipefail git reset --hard HEAD |