diff options
author | 2017-11-03 00:20:15 -0700 | |
---|---|---|
committer | 2017-11-03 07:20:15 +0000 | |
commit | 1fc0c16968304b169fb7eb52f162ffa4c7cdc55c (patch) | |
tree | e242db88ffc7aa291a6344d1ab1d6cd07293cba7 /vendor/google.golang.org | |
parent | af6086d6535e3309e3bd1e521cb4d51ef113f850 (diff) | |
download | coredns-1fc0c16968304b169fb7eb52f162ffa4c7cdc55c.tar.gz coredns-1fc0c16968304b169fb7eb52f162ffa4c7cdc55c.tar.zst coredns-1fc0c16968304b169fb7eb52f162ffa4c7cdc55c.zip |
Update vendor libraries except client-go, apimachinery and ugorji/go (#1197)
This fix updates vendor libraries except client-go, apimachinery
and ugorji/go, as github.com/ugorji/go/codec is causing compatibilities issues.
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'vendor/google.golang.org')
67 files changed, 5430 insertions, 2342 deletions
diff --git a/vendor/google.golang.org/genproto/googleapis/rpc/status/status.pb.go b/vendor/google.golang.org/genproto/googleapis/rpc/status/status.pb.go index 40e79375b..8867ae781 100644 --- a/vendor/google.golang.org/genproto/googleapis/rpc/status/status.pb.go +++ b/vendor/google.golang.org/genproto/googleapis/rpc/status/status.pb.go @@ -45,7 +45,7 @@ const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package // error message is needed, put the localized message in the error details or // localize it in the client. The optional error details may contain arbitrary // information about the error. There is a predefined set of error detail types -// in the package `google.rpc` which can be used for common error conditions. +// in the package `google.rpc` that can be used for common error conditions. // // # Language mapping // @@ -68,7 +68,7 @@ const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package // errors. // // - Workflow errors. A typical workflow has multiple steps. Each step may -// have a `Status` message for error reporting purpose. +// have a `Status` message for error reporting. // // - Batch operations. If a client uses batch request and batch response, the // `Status` message should be used directly inside batch response, one for @@ -87,8 +87,8 @@ type Status struct { // user-facing error message should be localized and sent in the // [google.rpc.Status.details][google.rpc.Status.details] field, or localized by the client. Message string `protobuf:"bytes,2,opt,name=message" json:"message,omitempty"` - // A list of messages that carry the error details. There will be a - // common set of message types for APIs to use. + // A list of messages that carry the error details. There is a common set of + // message types for APIs to use. Details []*google_protobuf.Any `protobuf:"bytes,3,rep,name=details" json:"details,omitempty"` } diff --git a/vendor/google.golang.org/grpc/.please-update b/vendor/google.golang.org/grpc/.please-update new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/vendor/google.golang.org/grpc/.please-update diff --git a/vendor/google.golang.org/grpc/.travis.yml b/vendor/google.golang.org/grpc/.travis.yml index 9032f8dcd..22bf25004 100644 --- a/vendor/google.golang.org/grpc/.travis.yml +++ b/vendor/google.golang.org/grpc/.travis.yml @@ -1,20 +1,20 @@ language: go go: - - 1.6.x - 1.7.x - 1.8.x + - 1.9.x + +matrix: + include: + - go: 1.9.x + env: ARCH=386 go_import_path: google.golang.org/grpc before_install: - - if [[ $TRAVIS_GO_VERSION = 1.8* ]]; then go get -u github.com/golang/lint/golint honnef.co/go/tools/cmd/staticcheck; fi - - go get -u golang.org/x/tools/cmd/goimports github.com/axw/gocov/gocov github.com/mattn/goveralls golang.org/x/tools/cmd/cover + - if [[ "$TRAVIS_GO_VERSION" = 1.9* && "$ARCH" != "386" ]]; then ./vet.sh -install || exit 1; fi script: - - '! gofmt -s -d -l . 2>&1 | read' - - '! goimports -l . | read' - - 'if [[ $TRAVIS_GO_VERSION = 1.8* ]]; then ! golint ./... | grep -vE "(_mock|_string|\.pb)\.go:"; fi' - - 'if [[ $TRAVIS_GO_VERSION = 1.8* ]]; then ! go tool vet -all . 2>&1 | grep -vF .pb.go:; fi' # https://github.com/golang/protobuf/issues/214 + - if [[ "$TRAVIS_GO_VERSION" = 1.9* && "$ARCH" != "386" ]]; then ./vet.sh || exit 1; fi - make test testrace - - 'if [[ $TRAVIS_GO_VERSION = 1.8* ]]; then staticcheck -ignore google.golang.org/grpc/transport/transport_test.go:SA2002 ./...; fi' # TODO(menghanl): fix these diff --git a/vendor/google.golang.org/grpc/Makefile b/vendor/google.golang.org/grpc/Makefile index 03bb01f0b..39606b564 100644 --- a/vendor/google.golang.org/grpc/Makefile +++ b/vendor/google.golang.org/grpc/Makefile @@ -20,24 +20,17 @@ proto: echo "error: protoc not installed" >&2; \ exit 1; \ fi - go get -u -v github.com/golang/protobuf/protoc-gen-go - # use $$dir as the root for all proto files in the same directory - for dir in $$(git ls-files '*.proto' | xargs -n1 dirname | uniq); do \ - protoc -I $$dir --go_out=plugins=grpc:$$dir $$dir/*.proto; \ - done + go generate google.golang.org/grpc/... test: testdeps - go test -v -cpu 1,4 google.golang.org/grpc/... + go test -cpu 1,4 google.golang.org/grpc/... testrace: testdeps - go test -v -race -cpu 1,4 google.golang.org/grpc/... + go test -race -cpu 1,4 google.golang.org/grpc/... clean: go clean -i google.golang.org/grpc/... -coverage: testdeps - ./coverage.sh --coveralls - .PHONY: \ all \ deps \ diff --git a/vendor/google.golang.org/grpc/README.md b/vendor/google.golang.org/grpc/README.md index 72c7325cc..622a5dc3e 100644 --- a/vendor/google.golang.org/grpc/README.md +++ b/vendor/google.golang.org/grpc/README.md @@ -10,13 +10,13 @@ Installation To install this package, you need to install Go and setup your Go workspace on your computer. The simplest way to install the library is to run: ``` -$ go get google.golang.org/grpc +$ go get -u google.golang.org/grpc ``` Prerequisites ------------- -This requires Go 1.6 or later. +This requires Go 1.7 or later. Constraints ----------- diff --git a/vendor/google.golang.org/grpc/balancer.go b/vendor/google.golang.org/grpc/balancer.go index cde472c81..ab65049dd 100644 --- a/vendor/google.golang.org/grpc/balancer.go +++ b/vendor/google.golang.org/grpc/balancer.go @@ -395,3 +395,14 @@ func (rr *roundRobin) Close() error { } return nil } + +// pickFirst is used to test multi-addresses in one addrConn in which all addresses share the same addrConn. +// It is a wrapper around roundRobin balancer. The logic of all methods works fine because balancer.Get() +// returns the only address Up by resetTransport(). +type pickFirst struct { + *roundRobin +} + +func pickFirstBalancerV1(r naming.Resolver) Balancer { + return &pickFirst{&roundRobin{r: r}} +} diff --git a/vendor/google.golang.org/grpc/balancer/balancer.go b/vendor/google.golang.org/grpc/balancer/balancer.go new file mode 100644 index 000000000..84e10b630 --- /dev/null +++ b/vendor/google.golang.org/grpc/balancer/balancer.go @@ -0,0 +1,206 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package balancer defines APIs for load balancing in gRPC. +// All APIs in this package are experimental. +package balancer + +import ( + "errors" + "net" + + "golang.org/x/net/context" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/resolver" +) + +var ( + // m is a map from name to balancer builder. + m = make(map[string]Builder) + // defaultBuilder is the default balancer to use. + defaultBuilder Builder // TODO(bar) install pickfirst as default. +) + +// Register registers the balancer builder to the balancer map. +// b.Name will be used as the name registered with this builder. +func Register(b Builder) { + m[b.Name()] = b +} + +// Get returns the resolver builder registered with the given name. +// If no builder is register with the name, the default pickfirst will +// be used. +func Get(name string) Builder { + if b, ok := m[name]; ok { + return b + } + return defaultBuilder +} + +// SubConn represents a gRPC sub connection. +// Each sub connection contains a list of addresses. gRPC will +// try to connect to them (in sequence), and stop trying the +// remainder once one connection is successful. +// +// The reconnect backoff will be applied on the list, not a single address. +// For example, try_on_all_addresses -> backoff -> try_on_all_addresses. +// +// All SubConns start in IDLE, and will not try to connect. To trigger +// the connecting, Balancers must call Connect. +// When the connection encounters an error, it will reconnect immediately. +// When the connection becomes IDLE, it will not reconnect unless Connect is +// called. +type SubConn interface { + // UpdateAddresses updates the addresses used in this SubConn. + // gRPC checks if currently-connected address is still in the new list. + // If it's in the list, the connection will be kept. + // If it's not in the list, the connection will gracefully closed, and + // a new connection will be created. + // + // This will trigger a state transition for the SubConn. + UpdateAddresses([]resolver.Address) + // Connect starts the connecting for this SubConn. + Connect() +} + +// NewSubConnOptions contains options to create new SubConn. +type NewSubConnOptions struct{} + +// ClientConn represents a gRPC ClientConn. +type ClientConn interface { + // NewSubConn is called by balancer to create a new SubConn. + // It doesn't block and wait for the connections to be established. + // Behaviors of the SubConn can be controlled by options. + NewSubConn([]resolver.Address, NewSubConnOptions) (SubConn, error) + // RemoveSubConn removes the SubConn from ClientConn. + // The SubConn will be shutdown. + RemoveSubConn(SubConn) + + // UpdateBalancerState is called by balancer to nofity gRPC that some internal + // state in balancer has changed. + // + // gRPC will update the connectivity state of the ClientConn, and will call pick + // on the new picker to pick new SubConn. + UpdateBalancerState(s connectivity.State, p Picker) + + // Target returns the dial target for this ClientConn. + Target() string +} + +// BuildOptions contains additional information for Build. +type BuildOptions struct { + // DialCreds is the transport credential the Balancer implementation can + // use to dial to a remote load balancer server. The Balancer implementations + // can ignore this if it does not need to talk to another party securely. + DialCreds credentials.TransportCredentials + // Dialer is the custom dialer the Balancer implementation can use to dial + // to a remote load balancer server. The Balancer implementations + // can ignore this if it doesn't need to talk to remote balancer. + Dialer func(context.Context, string) (net.Conn, error) +} + +// Builder creates a balancer. +type Builder interface { + // Build creates a new balancer with the ClientConn. + Build(cc ClientConn, opts BuildOptions) Balancer + // Name returns the name of balancers built by this builder. + // It will be used to pick balancers (for example in service config). + Name() string +} + +// PickOptions contains addition information for the Pick operation. +type PickOptions struct{} + +// DoneInfo contains additional information for done. +type DoneInfo struct { + // Err is the rpc error the RPC finished with. It could be nil. + Err error +} + +var ( + // ErrNoSubConnAvailable indicates no SubConn is available for pick(). + // gRPC will block the RPC until a new picker is available via UpdateBalancerState(). + ErrNoSubConnAvailable = errors.New("no SubConn is available") + // ErrTransientFailure indicates all SubConns are in TransientFailure. + // WaitForReady RPCs will block, non-WaitForReady RPCs will fail. + ErrTransientFailure = errors.New("all SubConns are in TransientFailure") +) + +// Picker is used by gRPC to pick a SubConn to send an RPC. +// Balancer is expected to generate a new picker from its snapshot everytime its +// internal state has changed. +// +// The pickers used by gRPC can be updated by ClientConn.UpdateBalancerState(). +type Picker interface { + // Pick returns the SubConn to be used to send the RPC. + // The returned SubConn must be one returned by NewSubConn(). + // + // This functions is expected to return: + // - a SubConn that is known to be READY; + // - ErrNoSubConnAvailable if no SubConn is available, but progress is being + // made (for example, some SubConn is in CONNECTING mode); + // - other errors if no active connecting is happening (for example, all SubConn + // are in TRANSIENT_FAILURE mode). + // + // If a SubConn is returned: + // - If it is READY, gRPC will send the RPC on it; + // - If it is not ready, or becomes not ready after it's returned, gRPC will block + // this call until a new picker is updated and will call pick on the new picker. + // + // If the returned error is not nil: + // - If the error is ErrNoSubConnAvailable, gRPC will block until UpdateBalancerState() + // - If the error is ErrTransientFailure: + // - If the RPC is wait-for-ready, gRPC will block until UpdateBalancerState() + // is called to pick again; + // - Otherwise, RPC will fail with unavailable error. + // - Else (error is other non-nil error): + // - The RPC will fail with unavailable error. + // + // The returned done() function will be called once the rpc has finished, with the + // final status of that RPC. + // done may be nil if balancer doesn't care about the RPC status. + Pick(ctx context.Context, opts PickOptions) (conn SubConn, done func(DoneInfo), err error) +} + +// Balancer takes input from gRPC, manages SubConns, and collects and aggregates +// the connectivity states. +// +// It also generates and updates the Picker used by gRPC to pick SubConns for RPCs. +// +// HandleSubConnectionStateChange, HandleResolvedAddrs and Close are guaranteed +// to be called synchronously from the same goroutine. +// There's no guarantee on picker.Pick, it may be called anytime. +type Balancer interface { + // HandleSubConnStateChange is called by gRPC when the connectivity state + // of sc has changed. + // Balancer is expected to aggregate all the state of SubConn and report + // that back to gRPC. + // Balancer should also generate and update Pickers when its internal state has + // been changed by the new state. + HandleSubConnStateChange(sc SubConn, state connectivity.State) + // HandleResolvedAddrs is called by gRPC to send updated resolved addresses to + // balancers. + // Balancer can create new SubConn or remove SubConn with the addresses. + // An empty address slice and a non-nil error will be passed if the resolver returns + // non-nil error to gRPC. + HandleResolvedAddrs([]resolver.Address, error) + // Close closes the balancer. The balancer is not required to call + // ClientConn.RemoveSubConn for its existing SubConns. + Close() +} diff --git a/vendor/google.golang.org/grpc/balancer_conn_wrappers.go b/vendor/google.golang.org/grpc/balancer_conn_wrappers.go new file mode 100644 index 000000000..e4a95fd5c --- /dev/null +++ b/vendor/google.golang.org/grpc/balancer_conn_wrappers.go @@ -0,0 +1,248 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +import ( + "sync" + + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/resolver" +) + +// scStateUpdate contains the subConn and the new state it changed to. +type scStateUpdate struct { + sc balancer.SubConn + state connectivity.State +} + +// scStateUpdateBuffer is an unbounded channel for scStateChangeTuple. +// TODO make a general purpose buffer that uses interface{}. +type scStateUpdateBuffer struct { + c chan *scStateUpdate + mu sync.Mutex + backlog []*scStateUpdate +} + +func newSCStateUpdateBuffer() *scStateUpdateBuffer { + return &scStateUpdateBuffer{ + c: make(chan *scStateUpdate, 1), + } +} + +func (b *scStateUpdateBuffer) put(t *scStateUpdate) { + b.mu.Lock() + defer b.mu.Unlock() + if len(b.backlog) == 0 { + select { + case b.c <- t: + return + default: + } + } + b.backlog = append(b.backlog, t) +} + +func (b *scStateUpdateBuffer) load() { + b.mu.Lock() + defer b.mu.Unlock() + if len(b.backlog) > 0 { + select { + case b.c <- b.backlog[0]: + b.backlog[0] = nil + b.backlog = b.backlog[1:] + default: + } + } +} + +// get returns the channel that receives a recvMsg in the buffer. +// +// Upon receiving, the caller should call load to send another +// scStateChangeTuple onto the channel if there is any. +func (b *scStateUpdateBuffer) get() <-chan *scStateUpdate { + return b.c +} + +// resolverUpdate contains the new resolved addresses or error if there's +// any. +type resolverUpdate struct { + addrs []resolver.Address + err error +} + +// ccBalancerWrapper is a wrapper on top of cc for balancers. +// It implements balancer.ClientConn interface. +type ccBalancerWrapper struct { + cc *ClientConn + balancer balancer.Balancer + stateChangeQueue *scStateUpdateBuffer + resolverUpdateCh chan *resolverUpdate + done chan struct{} +} + +func newCCBalancerWrapper(cc *ClientConn, b balancer.Builder, bopts balancer.BuildOptions) *ccBalancerWrapper { + ccb := &ccBalancerWrapper{ + cc: cc, + stateChangeQueue: newSCStateUpdateBuffer(), + resolverUpdateCh: make(chan *resolverUpdate, 1), + done: make(chan struct{}), + } + go ccb.watcher() + ccb.balancer = b.Build(ccb, bopts) + return ccb +} + +// watcher balancer functions sequencially, so the balancer can be implemeneted +// lock-free. +func (ccb *ccBalancerWrapper) watcher() { + for { + select { + case t := <-ccb.stateChangeQueue.get(): + ccb.stateChangeQueue.load() + ccb.balancer.HandleSubConnStateChange(t.sc, t.state) + case t := <-ccb.resolverUpdateCh: + ccb.balancer.HandleResolvedAddrs(t.addrs, t.err) + case <-ccb.done: + } + + select { + case <-ccb.done: + ccb.balancer.Close() + return + default: + } + } +} + +func (ccb *ccBalancerWrapper) close() { + close(ccb.done) +} + +func (ccb *ccBalancerWrapper) handleSubConnStateChange(sc balancer.SubConn, s connectivity.State) { + // When updating addresses for a SubConn, if the address in use is not in + // the new addresses, the old ac will be tearDown() and a new ac will be + // created. tearDown() generates a state change with Shutdown state, we + // don't want the balancer to receive this state change. So before + // tearDown() on the old ac, ac.acbw (acWrapper) will be set to nil, and + // this function will be called with (nil, Shutdown). We don't need to call + // balancer method in this case. + if sc == nil { + return + } + ccb.stateChangeQueue.put(&scStateUpdate{ + sc: sc, + state: s, + }) +} + +func (ccb *ccBalancerWrapper) handleResolvedAddrs(addrs []resolver.Address, err error) { + select { + case <-ccb.resolverUpdateCh: + default: + } + ccb.resolverUpdateCh <- &resolverUpdate{ + addrs: addrs, + err: err, + } +} + +func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { + grpclog.Infof("ccBalancerWrapper: new subconn: %v", addrs) + ac, err := ccb.cc.newAddrConn(addrs) + if err != nil { + return nil, err + } + acbw := &acBalancerWrapper{ac: ac} + ac.acbw = acbw + return acbw, nil +} + +func (ccb *ccBalancerWrapper) RemoveSubConn(sc balancer.SubConn) { + grpclog.Infof("ccBalancerWrapper: removing subconn") + acbw, ok := sc.(*acBalancerWrapper) + if !ok { + return + } + ccb.cc.removeAddrConn(acbw.getAddrConn(), errConnDrain) +} + +func (ccb *ccBalancerWrapper) UpdateBalancerState(s connectivity.State, p balancer.Picker) { + grpclog.Infof("ccBalancerWrapper: updating state and picker called by balancer: %v, %p", s, p) + ccb.cc.csMgr.updateState(s) + ccb.cc.blockingpicker.updatePicker(p) +} + +func (ccb *ccBalancerWrapper) Target() string { + return ccb.cc.target +} + +// acBalancerWrapper is a wrapper on top of ac for balancers. +// It implements balancer.SubConn interface. +type acBalancerWrapper struct { + mu sync.Mutex + ac *addrConn +} + +func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) { + grpclog.Infof("acBalancerWrapper: UpdateAddresses called with %v", addrs) + acbw.mu.Lock() + defer acbw.mu.Unlock() + if !acbw.ac.tryUpdateAddrs(addrs) { + cc := acbw.ac.cc + acbw.ac.mu.Lock() + // Set old ac.acbw to nil so the Shutdown state update will be ignored + // by balancer. + // + // TODO(bar) the state transition could be wrong when tearDown() old ac + // and creating new ac, fix the transition. + acbw.ac.acbw = nil + acbw.ac.mu.Unlock() + acState := acbw.ac.getState() + acbw.ac.tearDown(errConnDrain) + + if acState == connectivity.Shutdown { + return + } + + ac, err := cc.newAddrConn(addrs) + if err != nil { + grpclog.Warningf("acBalancerWrapper: UpdateAddresses: failed to newAddrConn: %v", err) + return + } + acbw.ac = ac + ac.acbw = acbw + if acState != connectivity.Idle { + ac.connect(false) + } + } +} + +func (acbw *acBalancerWrapper) Connect() { + acbw.mu.Lock() + defer acbw.mu.Unlock() + acbw.ac.connect(false) +} + +func (acbw *acBalancerWrapper) getAddrConn() *addrConn { + acbw.mu.Lock() + defer acbw.mu.Unlock() + return acbw.ac +} diff --git a/vendor/google.golang.org/grpc/balancer_test.go b/vendor/google.golang.org/grpc/balancer_test.go index 4f733a6aa..29dbe0a67 100644 --- a/vendor/google.golang.org/grpc/balancer_test.go +++ b/vendor/google.golang.org/grpc/balancer_test.go @@ -21,13 +21,16 @@ package grpc import ( "fmt" "math" + "strconv" "sync" "testing" "time" "golang.org/x/net/context" "google.golang.org/grpc/codes" + _ "google.golang.org/grpc/grpclog/glogger" "google.golang.org/grpc/naming" + "google.golang.org/grpc/test/leakcheck" ) type testWatcher struct { @@ -55,6 +58,7 @@ func (w *testWatcher) Next() (updates []*naming.Update, err error) { } func (w *testWatcher) Close() { + close(w.side) } // Inject naming resolution updates to the testWatcher. @@ -88,7 +92,7 @@ func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) { return r.w, nil } -func startServers(t *testing.T, numServers int, maxStreams uint32) ([]*server, *testNameResolver) { +func startServers(t *testing.T, numServers int, maxStreams uint32) ([]*server, *testNameResolver, func()) { var servers []*server for i := 0; i < numServers; i++ { s := newTestServer() @@ -99,18 +103,25 @@ func startServers(t *testing.T, numServers int, maxStreams uint32) ([]*server, * // Point to server[0] addr := "localhost:" + servers[0].port return servers, &testNameResolver{ - addr: addr, - } + addr: addr, + }, func() { + for i := 0; i < numServers; i++ { + servers[i].stop() + } + } } func TestNameDiscovery(t *testing.T) { + defer leakcheck.Check(t) // Start 2 servers on 2 ports. numServers := 2 - servers, r := startServers(t, numServers, math.MaxUint32) + servers, r, cleanup := startServers(t, numServers, math.MaxUint32) + defer cleanup() cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } + defer cc.Close() req := "port" var reply string if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[0].port { @@ -134,18 +145,17 @@ func TestNameDiscovery(t *testing.T) { } time.Sleep(10 * time.Millisecond) } - cc.Close() - for i := 0; i < numServers; i++ { - servers[i].stop() - } } func TestEmptyAddrs(t *testing.T) { - servers, r := startServers(t, 1, math.MaxUint32) + defer leakcheck.Check(t) + servers, r, cleanup := startServers(t, 1, math.MaxUint32) + defer cleanup() cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } + defer cc.Close() var reply string if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil || reply != expectedResponse { t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, reply = %q, want %q, <nil>", err, reply, expectedResponse) @@ -160,23 +170,26 @@ func TestEmptyAddrs(t *testing.T) { // Loop until the above updates apply. for { time.Sleep(10 * time.Millisecond) - ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); err != nil { + cancel() break } + cancel() } - cc.Close() - servers[0].stop() } func TestRoundRobin(t *testing.T) { + defer leakcheck.Check(t) // Start 3 servers on 3 ports. numServers := 3 - servers, r := startServers(t, numServers, math.MaxUint32) + servers, r, cleanup := startServers(t, numServers, math.MaxUint32) + defer cleanup() cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } + defer cc.Close() // Add servers[1] to the service discovery. u := &naming.Update{ Op: naming.Add, @@ -211,18 +224,17 @@ func TestRoundRobin(t *testing.T) { t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", i, err, servers[i%numServers].port) } } - cc.Close() - for i := 0; i < numServers; i++ { - servers[i].stop() - } } func TestCloseWithPendingRPC(t *testing.T) { - servers, r := startServers(t, 1, math.MaxUint32) + defer leakcheck.Check(t) + servers, r, cleanup := startServers(t, 1, math.MaxUint32) + defer cleanup() cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } + defer cc.Close() var reply string if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil { t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port) @@ -235,11 +247,13 @@ func TestCloseWithPendingRPC(t *testing.T) { r.w.inject(updates) // Loop until the above update applies. for { - ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); Code(err) == codes.DeadlineExceeded { + cancel() break } time.Sleep(10 * time.Millisecond) + cancel() } // Issue 2 RPCs which should be completed with error status once cc is closed. var wg sync.WaitGroup @@ -262,15 +276,17 @@ func TestCloseWithPendingRPC(t *testing.T) { time.Sleep(5 * time.Millisecond) cc.Close() wg.Wait() - servers[0].stop() } func TestGetOnWaitChannel(t *testing.T) { - servers, r := startServers(t, 1, math.MaxUint32) + defer leakcheck.Check(t) + servers, r, cleanup := startServers(t, 1, math.MaxUint32) + defer cleanup() cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } + defer cc.Close() // Remove all servers so that all upcoming RPCs will block on waitCh. updates := []*naming.Update{{ Op: naming.Delete, @@ -279,10 +295,12 @@ func TestGetOnWaitChannel(t *testing.T) { r.w.inject(updates) for { var reply string - ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); Code(err) == codes.DeadlineExceeded { + cancel() break } + cancel() time.Sleep(10 * time.Millisecond) } var wg sync.WaitGroup @@ -302,18 +320,19 @@ func TestGetOnWaitChannel(t *testing.T) { r.w.inject(updates) // Wait until the above RPC succeeds. wg.Wait() - cc.Close() - servers[0].stop() } func TestOneServerDown(t *testing.T) { + defer leakcheck.Check(t) // Start 2 servers. numServers := 2 - servers, r := startServers(t, numServers, math.MaxUint32) + servers, r, cleanup := startServers(t, numServers, math.MaxUint32) + defer cleanup() cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } + defer cc.Close() // Add servers[1] to the service discovery. var updates []*naming.Update updates = append(updates, &naming.Update{ @@ -354,20 +373,19 @@ func TestOneServerDown(t *testing.T) { }() } wg.Wait() - cc.Close() - for i := 0; i < numServers; i++ { - servers[i].stop() - } } func TestOneAddressRemoval(t *testing.T) { + defer leakcheck.Check(t) // Start 2 servers. numServers := 2 - servers, r := startServers(t, numServers, math.MaxUint32) + servers, r, cleanup := startServers(t, numServers, math.MaxUint32) + defer cleanup() cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } + defer cc.Close() // Add servers[1] to the service discovery. var updates []*naming.Update updates = append(updates, &naming.Update{ @@ -416,8 +434,365 @@ func TestOneAddressRemoval(t *testing.T) { }() } wg.Wait() +} + +func checkServerUp(t *testing.T, currentServer *server) { + req := "port" + port := currentServer.port + cc, err := Dial("localhost:"+port, WithBlock(), WithInsecure(), WithCodec(testCodec{})) + if err != nil { + t.Fatalf("Failed to create ClientConn: %v", err) + } + defer cc.Close() + var reply string + for { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == port { + break + } + time.Sleep(10 * time.Millisecond) + } +} + +func TestPickFirstEmptyAddrs(t *testing.T) { + defer leakcheck.Check(t) + servers, r, cleanup := startServers(t, 1, math.MaxUint32) + defer cleanup() + cc, err := Dial("foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + if err != nil { + t.Fatalf("Failed to create ClientConn: %v", err) + } + defer cc.Close() + var reply string + if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil || reply != expectedResponse { + t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, reply = %q, want %q, <nil>", err, reply, expectedResponse) + } + // Inject name resolution change to remove the server so that there is no address + // available after that. + u := &naming.Update{ + Op: naming.Delete, + Addr: "localhost:" + servers[0].port, + } + r.w.inject([]*naming.Update{u}) + // Loop until the above updates apply. + for { + time.Sleep(10 * time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); err != nil { + cancel() + break + } + cancel() + } +} + +func TestPickFirstCloseWithPendingRPC(t *testing.T) { + defer leakcheck.Check(t) + servers, r, cleanup := startServers(t, 1, math.MaxUint32) + defer cleanup() + cc, err := Dial("foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + if err != nil { + t.Fatalf("Failed to create ClientConn: %v", err) + } + defer cc.Close() + var reply string + if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil { + t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port) + } + // Remove the server. + updates := []*naming.Update{{ + Op: naming.Delete, + Addr: "localhost:" + servers[0].port, + }} + r.w.inject(updates) + // Loop until the above update applies. + for { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); Code(err) == codes.DeadlineExceeded { + cancel() + break + } + time.Sleep(10 * time.Millisecond) + cancel() + } + // Issue 2 RPCs which should be completed with error status once cc is closed. + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + var reply string + if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil { + t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err) + } + }() + go func() { + defer wg.Done() + var reply string + time.Sleep(5 * time.Millisecond) + if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil { + t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err) + } + }() + time.Sleep(5 * time.Millisecond) cc.Close() - for i := 0; i < numServers; i++ { - servers[i].stop() + wg.Wait() +} + +func TestPickFirstOrderAllServerUp(t *testing.T) { + defer leakcheck.Check(t) + // Start 3 servers on 3 ports. + numServers := 3 + servers, r, cleanup := startServers(t, numServers, math.MaxUint32) + defer cleanup() + cc, err := Dial("foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + if err != nil { + t.Fatalf("Failed to create ClientConn: %v", err) + } + defer cc.Close() + // Add servers[1] and [2] to the service discovery. + u := &naming.Update{ + Op: naming.Add, + Addr: "localhost:" + servers[1].port, + } + r.w.inject([]*naming.Update{u}) + + u = &naming.Update{ + Op: naming.Add, + Addr: "localhost:" + servers[2].port, + } + r.w.inject([]*naming.Update{u}) + + // Loop until all 3 servers are up + checkServerUp(t, servers[0]) + checkServerUp(t, servers[1]) + checkServerUp(t, servers[2]) + + // Check the incoming RPCs served in server[0] + req := "port" + var reply string + for i := 0; i < 20; i++ { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[0].port { + t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port) + } + time.Sleep(10 * time.Millisecond) } + + // Delete server[0] in the balancer, the incoming RPCs served in server[1] + // For test addrconn, close server[0] instead + u = &naming.Update{ + Op: naming.Delete, + Addr: "localhost:" + servers[0].port, + } + r.w.inject([]*naming.Update{u}) + // Loop until it changes to server[1] + for { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port { + break + } + time.Sleep(10 * time.Millisecond) + } + for i := 0; i < 20; i++ { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[1].port { + t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port) + } + time.Sleep(10 * time.Millisecond) + } + + // Add server[0] back to the balancer, the incoming RPCs served in server[1] + // Add is append operation, the order of Notify now is {server[1].port server[2].port server[0].port} + u = &naming.Update{ + Op: naming.Add, + Addr: "localhost:" + servers[0].port, + } + r.w.inject([]*naming.Update{u}) + for i := 0; i < 20; i++ { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[1].port { + t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port) + } + time.Sleep(10 * time.Millisecond) + } + + // Delete server[1] in the balancer, the incoming RPCs served in server[2] + u = &naming.Update{ + Op: naming.Delete, + Addr: "localhost:" + servers[1].port, + } + r.w.inject([]*naming.Update{u}) + for { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[2].port { + break + } + time.Sleep(1 * time.Second) + } + for i := 0; i < 20; i++ { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[2].port { + t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 2, err, servers[2].port) + } + time.Sleep(10 * time.Millisecond) + } + + // Delete server[2] in the balancer, the incoming RPCs served in server[0] + u = &naming.Update{ + Op: naming.Delete, + Addr: "localhost:" + servers[2].port, + } + r.w.inject([]*naming.Update{u}) + for { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[0].port { + break + } + time.Sleep(1 * time.Second) + } + for i := 0; i < 20; i++ { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[0].port { + t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port) + } + time.Sleep(10 * time.Millisecond) + } +} + +func TestPickFirstOrderOneServerDown(t *testing.T) { + defer leakcheck.Check(t) + // Start 3 servers on 3 ports. + numServers := 3 + servers, r, cleanup := startServers(t, numServers, math.MaxUint32) + defer cleanup() + cc, err := Dial("foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + if err != nil { + t.Fatalf("Failed to create ClientConn: %v", err) + } + defer cc.Close() + // Add servers[1] and [2] to the service discovery. + u := &naming.Update{ + Op: naming.Add, + Addr: "localhost:" + servers[1].port, + } + r.w.inject([]*naming.Update{u}) + + u = &naming.Update{ + Op: naming.Add, + Addr: "localhost:" + servers[2].port, + } + r.w.inject([]*naming.Update{u}) + + // Loop until all 3 servers are up + checkServerUp(t, servers[0]) + checkServerUp(t, servers[1]) + checkServerUp(t, servers[2]) + + // Check the incoming RPCs served in server[0] + req := "port" + var reply string + for i := 0; i < 20; i++ { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[0].port { + t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port) + } + time.Sleep(10 * time.Millisecond) + } + + // server[0] down, incoming RPCs served in server[1], but the order of Notify still remains + // {server[0] server[1] server[2]} + servers[0].stop() + // Loop until it changes to server[1] + for { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port { + break + } + time.Sleep(10 * time.Millisecond) + } + for i := 0; i < 20; i++ { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[1].port { + t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port) + } + time.Sleep(10 * time.Millisecond) + } + + // up the server[0] back, the incoming RPCs served in server[1] + p, _ := strconv.Atoi(servers[0].port) + servers[0] = newTestServer() + go servers[0].start(t, p, math.MaxUint32) + defer servers[0].stop() + servers[0].wait(t, 2*time.Second) + checkServerUp(t, servers[0]) + + for i := 0; i < 20; i++ { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[1].port { + t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port) + } + time.Sleep(10 * time.Millisecond) + } + + // Delete server[1] in the balancer, the incoming RPCs served in server[0] + u = &naming.Update{ + Op: naming.Delete, + Addr: "localhost:" + servers[1].port, + } + r.w.inject([]*naming.Update{u}) + for { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[0].port { + break + } + time.Sleep(1 * time.Second) + } + for i := 0; i < 20; i++ { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[0].port { + t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port) + } + time.Sleep(10 * time.Millisecond) + } +} + +func TestPickFirstOneAddressRemoval(t *testing.T) { + defer leakcheck.Check(t) + // Start 2 servers. + numServers := 2 + servers, r, cleanup := startServers(t, numServers, math.MaxUint32) + defer cleanup() + cc, err := Dial("localhost:"+servers[0].port, WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + if err != nil { + t.Fatalf("Failed to create ClientConn: %v", err) + } + defer cc.Close() + // Add servers[1] to the service discovery. + var updates []*naming.Update + updates = append(updates, &naming.Update{ + Op: naming.Add, + Addr: "localhost:" + servers[1].port, + }) + r.w.inject(updates) + + // Create a new cc to Loop until servers[1] is up + checkServerUp(t, servers[0]) + checkServerUp(t, servers[1]) + + var wg sync.WaitGroup + numRPC := 100 + sleepDuration := 10 * time.Millisecond + wg.Add(1) + go func() { + time.Sleep(sleepDuration) + // After sleepDuration, delete server[0]. + var updates []*naming.Update + updates = append(updates, &naming.Update{ + Op: naming.Delete, + Addr: "localhost:" + servers[0].port, + }) + r.w.inject(updates) + wg.Done() + }() + + // All non-failfast RPCs should not fail because there's at least one connection available. + for i := 0; i < numRPC; i++ { + wg.Add(1) + go func() { + var reply string + time.Sleep(sleepDuration) + // After sleepDuration, invoke RPC. + // server[0] is removed around the same time to make it racy between balancer and gRPC internals. + if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil { + t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err) + } + wg.Done() + }() + } + wg.Wait() } diff --git a/vendor/google.golang.org/grpc/balancer_v1_wrapper.go b/vendor/google.golang.org/grpc/balancer_v1_wrapper.go new file mode 100644 index 000000000..9d0616080 --- /dev/null +++ b/vendor/google.golang.org/grpc/balancer_v1_wrapper.go @@ -0,0 +1,367 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +import ( + "sync" + + "golang.org/x/net/context" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/resolver" +) + +type balancerWrapperBuilder struct { + b Balancer // The v1 balancer. +} + +func (bwb *balancerWrapperBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { + bwb.b.Start(cc.Target(), BalancerConfig{ + DialCreds: opts.DialCreds, + Dialer: opts.Dialer, + }) + _, pickfirst := bwb.b.(*pickFirst) + bw := &balancerWrapper{ + balancer: bwb.b, + pickfirst: pickfirst, + cc: cc, + startCh: make(chan struct{}), + conns: make(map[resolver.Address]balancer.SubConn), + connSt: make(map[balancer.SubConn]*scState), + csEvltr: &connectivityStateEvaluator{}, + state: connectivity.Idle, + } + cc.UpdateBalancerState(connectivity.Idle, bw) + go bw.lbWatcher() + return bw +} + +func (bwb *balancerWrapperBuilder) Name() string { + return "wrapper" +} + +type scState struct { + addr Address // The v1 address type. + s connectivity.State + down func(error) +} + +type balancerWrapper struct { + balancer Balancer // The v1 balancer. + pickfirst bool + + cc balancer.ClientConn + + // To aggregate the connectivity state. + csEvltr *connectivityStateEvaluator + state connectivity.State + + mu sync.Mutex + conns map[resolver.Address]balancer.SubConn + connSt map[balancer.SubConn]*scState + // This channel is closed when handling the first resolver result. + // lbWatcher blocks until this is closed, to avoid race between + // - NewSubConn is created, cc wants to notify balancer of state changes; + // - Build hasn't return, cc doesn't have access to balancer. + startCh chan struct{} +} + +// lbWatcher watches the Notify channel of the balancer and manages +// connections accordingly. +func (bw *balancerWrapper) lbWatcher() { + <-bw.startCh + grpclog.Infof("balancerWrapper: is pickfirst: %v\n", bw.pickfirst) + notifyCh := bw.balancer.Notify() + if notifyCh == nil { + // There's no resolver in the balancer. Connect directly. + a := resolver.Address{ + Addr: bw.cc.Target(), + Type: resolver.Backend, + } + sc, err := bw.cc.NewSubConn([]resolver.Address{a}, balancer.NewSubConnOptions{}) + if err != nil { + grpclog.Warningf("Error creating connection to %v. Err: %v", a, err) + } else { + bw.mu.Lock() + bw.conns[a] = sc + bw.connSt[sc] = &scState{ + addr: Address{Addr: bw.cc.Target()}, + s: connectivity.Idle, + } + bw.mu.Unlock() + sc.Connect() + } + return + } + + for addrs := range notifyCh { + grpclog.Infof("balancerWrapper: got update addr from Notify: %v\n", addrs) + if bw.pickfirst { + var ( + oldA resolver.Address + oldSC balancer.SubConn + ) + bw.mu.Lock() + for oldA, oldSC = range bw.conns { + break + } + bw.mu.Unlock() + if len(addrs) <= 0 { + if oldSC != nil { + // Teardown old sc. + bw.mu.Lock() + delete(bw.conns, oldA) + delete(bw.connSt, oldSC) + bw.mu.Unlock() + bw.cc.RemoveSubConn(oldSC) + } + continue + } + + var newAddrs []resolver.Address + for _, a := range addrs { + newAddr := resolver.Address{ + Addr: a.Addr, + Type: resolver.Backend, // All addresses from balancer are all backends. + ServerName: "", + Metadata: a.Metadata, + } + newAddrs = append(newAddrs, newAddr) + } + if oldSC == nil { + // Create new sc. + sc, err := bw.cc.NewSubConn(newAddrs, balancer.NewSubConnOptions{}) + if err != nil { + grpclog.Warningf("Error creating connection to %v. Err: %v", newAddrs, err) + } else { + bw.mu.Lock() + // For pickfirst, there should be only one SubConn, so the + // address doesn't matter. All states updating (up and down) + // and picking should all happen on that only SubConn. + bw.conns[resolver.Address{}] = sc + bw.connSt[sc] = &scState{ + addr: addrs[0], // Use the first address. + s: connectivity.Idle, + } + bw.mu.Unlock() + sc.Connect() + } + } else { + oldSC.UpdateAddresses(newAddrs) + bw.mu.Lock() + bw.connSt[oldSC].addr = addrs[0] + bw.mu.Unlock() + } + } else { + var ( + add []resolver.Address // Addresses need to setup connections. + del []balancer.SubConn // Connections need to tear down. + ) + resAddrs := make(map[resolver.Address]Address) + for _, a := range addrs { + resAddrs[resolver.Address{ + Addr: a.Addr, + Type: resolver.Backend, // All addresses from balancer are all backends. + ServerName: "", + Metadata: a.Metadata, + }] = a + } + bw.mu.Lock() + for a := range resAddrs { + if _, ok := bw.conns[a]; !ok { + add = append(add, a) + } + } + for a, c := range bw.conns { + if _, ok := resAddrs[a]; !ok { + del = append(del, c) + delete(bw.conns, a) + // Keep the state of this sc in bw.connSt until its state becomes Shutdown. + } + } + bw.mu.Unlock() + for _, a := range add { + sc, err := bw.cc.NewSubConn([]resolver.Address{a}, balancer.NewSubConnOptions{}) + if err != nil { + grpclog.Warningf("Error creating connection to %v. Err: %v", a, err) + } else { + bw.mu.Lock() + bw.conns[a] = sc + bw.connSt[sc] = &scState{ + addr: resAddrs[a], + s: connectivity.Idle, + } + bw.mu.Unlock() + sc.Connect() + } + } + for _, c := range del { + bw.cc.RemoveSubConn(c) + } + } + } +} + +func (bw *balancerWrapper) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) { + grpclog.Infof("balancerWrapper: handle subconn state change: %p, %v", sc, s) + bw.mu.Lock() + defer bw.mu.Unlock() + scSt, ok := bw.connSt[sc] + if !ok { + return + } + if s == connectivity.Idle { + sc.Connect() + } + oldS := scSt.s + scSt.s = s + if oldS != connectivity.Ready && s == connectivity.Ready { + scSt.down = bw.balancer.Up(scSt.addr) + } else if oldS == connectivity.Ready && s != connectivity.Ready { + if scSt.down != nil { + scSt.down(errConnClosing) + } + } + sa := bw.csEvltr.recordTransition(oldS, s) + if bw.state != sa { + bw.state = sa + } + bw.cc.UpdateBalancerState(bw.state, bw) + if s == connectivity.Shutdown { + // Remove state for this sc. + delete(bw.connSt, sc) + } + return +} + +func (bw *balancerWrapper) HandleResolvedAddrs([]resolver.Address, error) { + bw.mu.Lock() + defer bw.mu.Unlock() + select { + case <-bw.startCh: + default: + close(bw.startCh) + } + // There should be a resolver inside the balancer. + // All updates here, if any, are ignored. + return +} + +func (bw *balancerWrapper) Close() { + bw.mu.Lock() + defer bw.mu.Unlock() + select { + case <-bw.startCh: + default: + close(bw.startCh) + } + bw.balancer.Close() + return +} + +// The picker is the balancerWrapper itself. +// Pick should never return ErrNoSubConnAvailable. +// It either blocks or returns error, consistent with v1 balancer Get(). +func (bw *balancerWrapper) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) { + failfast := true // Default failfast is true. + if ss, ok := rpcInfoFromContext(ctx); ok { + failfast = ss.failfast + } + a, p, err := bw.balancer.Get(ctx, BalancerGetOptions{BlockingWait: !failfast}) + if err != nil { + return nil, nil, err + } + var done func(balancer.DoneInfo) + if p != nil { + done = func(i balancer.DoneInfo) { p() } + } + var sc balancer.SubConn + bw.mu.Lock() + defer bw.mu.Unlock() + if bw.pickfirst { + // Get the first sc in conns. + for _, sc = range bw.conns { + break + } + } else { + var ok bool + sc, ok = bw.conns[resolver.Address{ + Addr: a.Addr, + Type: resolver.Backend, + ServerName: "", + Metadata: a.Metadata, + }] + if !ok && failfast { + return nil, nil, Errorf(codes.Unavailable, "there is no connection available") + } + if s, ok := bw.connSt[sc]; failfast && (!ok || s.s != connectivity.Ready) { + // If the returned sc is not ready and RPC is failfast, + // return error, and this RPC will fail. + return nil, nil, Errorf(codes.Unavailable, "there is no connection available") + } + } + + return sc, done, nil +} + +// connectivityStateEvaluator gets updated by addrConns when their +// states transition, based on which it evaluates the state of +// ClientConn. +type connectivityStateEvaluator struct { + mu sync.Mutex + numReady uint64 // Number of addrConns in ready state. + numConnecting uint64 // Number of addrConns in connecting state. + numTransientFailure uint64 // Number of addrConns in transientFailure. +} + +// recordTransition records state change happening in every subConn and based on +// that it evaluates what aggregated state should be. +// It can only transition between Ready, Connecting and TransientFailure. Other states, +// Idle and Shutdown are transitioned into by ClientConn; in the beginning of the connection +// before any subConn is created ClientConn is in idle state. In the end when ClientConn +// closes it is in Shutdown state. +// TODO Note that in later releases, a ClientConn with no activity will be put into an Idle state. +func (cse *connectivityStateEvaluator) recordTransition(oldState, newState connectivity.State) connectivity.State { + cse.mu.Lock() + defer cse.mu.Unlock() + + // Update counters. + for idx, state := range []connectivity.State{oldState, newState} { + updateVal := 2*uint64(idx) - 1 // -1 for oldState and +1 for new. + switch state { + case connectivity.Ready: + cse.numReady += updateVal + case connectivity.Connecting: + cse.numConnecting += updateVal + case connectivity.TransientFailure: + cse.numTransientFailure += updateVal + } + } + + // Evaluate. + if cse.numReady > 0 { + return connectivity.Ready + } + if cse.numConnecting > 0 { + return connectivity.Connecting + } + return connectivity.TransientFailure +} diff --git a/vendor/google.golang.org/grpc/call.go b/vendor/google.golang.org/grpc/call.go index 797190f14..1ef2507c3 100644 --- a/vendor/google.golang.org/grpc/call.go +++ b/vendor/google.golang.org/grpc/call.go @@ -25,6 +25,7 @@ import ( "golang.org/x/net/context" "golang.org/x/net/trace" + "google.golang.org/grpc/balancer" "google.golang.org/grpc/codes" "google.golang.org/grpc/peer" "google.golang.org/grpc/stats" @@ -99,17 +100,17 @@ func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, Client: true, } } - outBuf, err := encode(dopts.codec, args, compressor, cbuf, outPayload) + hdr, data, err := encode(dopts.codec, args, compressor, cbuf, outPayload) if err != nil { return err } if c.maxSendMessageSize == nil { return Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)") } - if len(outBuf) > *c.maxSendMessageSize { - return Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(outBuf), *c.maxSendMessageSize) + if len(data) > *c.maxSendMessageSize { + return Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), *c.maxSendMessageSize) } - err = t.Write(stream, outBuf, opts) + err = t.Write(stream, hdr, data, opts) if err == nil && outPayload != nil { outPayload.SentTime = time.Now() dopts.copts.StatsHandler.HandleRPC(ctx, outPayload) @@ -135,7 +136,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli } func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (e error) { - c := defaultCallInfo + c := defaultCallInfo() mc := cc.GetMethodConfig(method) if mc.WaitForReady != nil { c.failFast = !*mc.WaitForReady @@ -149,13 +150,13 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli opts = append(cc.dopts.callOptions, opts...) for _, o := range opts { - if err := o.before(&c); err != nil { + if err := o.before(c); err != nil { return toRPCErr(err) } } defer func() { for _, o := range opts { - o.after(&c) + o.after(c) } }() @@ -178,7 +179,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli } }() } - ctx = newContextWithRPCInfo(ctx) + ctx = newContextWithRPCInfo(ctx, c.failFast) sh := cc.dopts.copts.StatsHandler if sh != nil { ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: c.failFast}) @@ -206,9 +207,9 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli err error t transport.ClientTransport stream *transport.Stream - // Record the put handler from Balancer.Get(...). It is called once the + // Record the done handler from Balancer.Get(...). It is called once the // RPC has completed or failed. - put func() + done func(balancer.DoneInfo) ) // TODO(zhaoq): Need a formal spec of fail-fast. callHdr := &transport.CallHdr{ @@ -222,10 +223,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli callHdr.Creds = c.creds } - gopts := BalancerGetOptions{ - BlockingWait: !c.failFast, - } - t, put, err = cc.getTransport(ctx, gopts) + t, done, err = cc.getTransport(ctx, c.failFast) if err != nil { // TODO(zhaoq): Probably revisit the error handling. if _, ok := status.FromError(err); ok { @@ -245,14 +243,14 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli } stream, err = t.NewStream(ctx, callHdr) if err != nil { - if put != nil { + if done != nil { if _, ok := err.(transport.ConnectionError); ok { // If error is connection error, transport was sending data on wire, // and we are not sure if anything has been sent on wire. // If error is not connection error, we are sure nothing has been sent. updateRPCInfoInContext(ctx, rpcInfo{bytesSent: true, bytesReceived: false}) } - put() + done(balancer.DoneInfo{Err: err}) } if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast { continue @@ -262,14 +260,14 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli if peer, ok := peer.FromContext(stream.Context()); ok { c.peer = peer } - err = sendRequest(ctx, cc.dopts, cc.dopts.cp, &c, callHdr, stream, t, args, topts) + err = sendRequest(ctx, cc.dopts, cc.dopts.cp, c, callHdr, stream, t, args, topts) if err != nil { - if put != nil { + if done != nil { updateRPCInfoInContext(ctx, rpcInfo{ bytesSent: stream.BytesSent(), bytesReceived: stream.BytesReceived(), }) - put() + done(balancer.DoneInfo{Err: err}) } // Retry a non-failfast RPC when // i) there is a connection error; or @@ -279,14 +277,14 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli } return toRPCErr(err) } - err = recvResponse(ctx, cc.dopts, t, &c, stream, reply) + err = recvResponse(ctx, cc.dopts, t, c, stream, reply) if err != nil { - if put != nil { + if done != nil { updateRPCInfoInContext(ctx, rpcInfo{ bytesSent: stream.BytesSent(), bytesReceived: stream.BytesReceived(), }) - put() + done(balancer.DoneInfo{Err: err}) } if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast { continue @@ -297,12 +295,12 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true) } t.CloseStream(stream, nil) - if put != nil { + if done != nil { updateRPCInfoInContext(ctx, rpcInfo{ bytesSent: stream.BytesSent(), bytesReceived: stream.BytesReceived(), }) - put() + done(balancer.DoneInfo{Err: err}) } return stream.Status().Err() } diff --git a/vendor/google.golang.org/grpc/call_test.go b/vendor/google.golang.org/grpc/call_test.go index deb3cb6ee..f48d30e87 100644 --- a/vendor/google.golang.org/grpc/call_test.go +++ b/vendor/google.golang.org/grpc/call_test.go @@ -32,6 +32,7 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/grpc/test/leakcheck" "google.golang.org/grpc/transport" ) @@ -104,18 +105,19 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) { } } // send a response back to end the stream. - reply, err := encode(testCodec{}, &expectedResponse, nil, nil, nil) + hdr, data, err := encode(testCodec{}, &expectedResponse, nil, nil, nil) if err != nil { t.Errorf("Failed to encode the response: %v", err) return } - h.t.Write(s, reply, &transport.Options{}) + h.t.Write(s, hdr, data, &transport.Options{}) h.t.WriteStatus(s, status.New(codes.OK, "")) } type server struct { lis net.Listener port string + addr string startedErr chan error // sent nil or an error after server starts mu sync.Mutex conns map[transport.ServerTransport]bool @@ -137,7 +139,8 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32) { s.startedErr <- fmt.Errorf("failed to listen: %v", err) return } - _, p, err := net.SplitHostPort(s.lis.Addr().String()) + s.addr = s.lis.Addr().String() + _, p, err := net.SplitHostPort(s.addr) if err != nil { s.startedErr <- fmt.Errorf("failed to parse listener address: %v", err) return @@ -211,6 +214,7 @@ func setUp(t *testing.T, port int, maxStreams uint32) (*server, *ClientConn) { } func TestInvoke(t *testing.T) { + defer leakcheck.Check(t) server, cc := setUp(t, 0, math.MaxUint32) var reply string if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil || reply != expectedResponse { @@ -221,6 +225,7 @@ func TestInvoke(t *testing.T) { } func TestInvokeLargeErr(t *testing.T) { + defer leakcheck.Check(t) server, cc := setUp(t, 0, math.MaxUint32) var reply string req := "hello" @@ -237,6 +242,7 @@ func TestInvokeLargeErr(t *testing.T) { // TestInvokeErrorSpecialChars checks that error messages don't get mangled. func TestInvokeErrorSpecialChars(t *testing.T) { + defer leakcheck.Check(t) server, cc := setUp(t, 0, math.MaxUint32) var reply string req := "weird error" @@ -253,6 +259,7 @@ func TestInvokeErrorSpecialChars(t *testing.T) { // TestInvokeCancel checks that an Invoke with a canceled context is not sent. func TestInvokeCancel(t *testing.T) { + defer leakcheck.Check(t) server, cc := setUp(t, 0, math.MaxUint32) var reply string req := "canceled" @@ -271,6 +278,7 @@ func TestInvokeCancel(t *testing.T) { // TestInvokeCancelClosedNonFail checks that a canceled non-failfast RPC // on a closed client will terminate. func TestInvokeCancelClosedNonFailFast(t *testing.T) { + defer leakcheck.Check(t) server, cc := setUp(t, 0, math.MaxUint32) var reply string cc.Close() diff --git a/vendor/google.golang.org/grpc/clientconn.go b/vendor/google.golang.org/grpc/clientconn.go index e3e3140f1..886bead9d 100644 --- a/vendor/google.golang.org/grpc/clientconn.go +++ b/vendor/google.golang.org/grpc/clientconn.go @@ -20,17 +20,22 @@ package grpc import ( "errors" + "fmt" + "math" "net" + "reflect" "strings" "sync" "time" "golang.org/x/net/context" "golang.org/x/net/trace" + "google.golang.org/grpc/balancer" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/resolver" "google.golang.org/grpc/stats" "google.golang.org/grpc/transport" ) @@ -78,23 +83,40 @@ type dialOptions struct { cp Compressor dc Decompressor bs backoffStrategy - balancer Balancer block bool insecure bool timeout time.Duration scChan <-chan ServiceConfig copts transport.ConnectOptions callOptions []CallOption + // This is to support v1 balancer. + balancerBuilder balancer.Builder } const ( defaultClientMaxReceiveMessageSize = 1024 * 1024 * 4 - defaultClientMaxSendMessageSize = 1024 * 1024 * 4 + defaultClientMaxSendMessageSize = math.MaxInt32 ) // DialOption configures how we set up the connection. type DialOption func(*dialOptions) +// WithWriteBufferSize lets you set the size of write buffer, this determines how much data can be batched +// before doing a write on the wire. +func WithWriteBufferSize(s int) DialOption { + return func(o *dialOptions) { + o.copts.WriteBufferSize = s + } +} + +// WithReadBufferSize lets you set the size of read buffer, this determines how much data can be read at most +// for each read syscall. +func WithReadBufferSize(s int) DialOption { + return func(o *dialOptions) { + o.copts.ReadBufferSize = s + } +} + // WithInitialWindowSize returns a DialOption which sets the value for initial window size on a stream. // The lower bound for window size is 64K and any value smaller than that will be ignored. func WithInitialWindowSize(s int32) DialOption { @@ -146,10 +168,23 @@ func WithDecompressor(dc Decompressor) DialOption { } } -// WithBalancer returns a DialOption which sets a load balancer. +// WithBalancer returns a DialOption which sets a load balancer with the v1 API. +// Name resolver will be ignored if this DialOption is specified. +// Deprecated: use the new balancer APIs in balancer package instead. func WithBalancer(b Balancer) DialOption { return func(o *dialOptions) { - o.balancer = b + o.balancerBuilder = &balancerWrapperBuilder{ + b: b, + } + } +} + +// WithBalancerBuilder is for testing only. Users using custom balancers should +// register their balancer and use service config to choose the balancer to use. +func WithBalancerBuilder(b balancer.Builder) DialOption { + // TODO(bar) remove this when switching balancer is done. + return func(o *dialOptions) { + o.balancerBuilder = b } } @@ -270,7 +305,7 @@ func WithUserAgent(s string) DialOption { } } -// WithKeepaliveParams returns a DialOption that specifies keepalive paramaters for the client transport. +// WithKeepaliveParams returns a DialOption that specifies keepalive parameters for the client transport. func WithKeepaliveParams(kp keepalive.ClientParameters) DialOption { return func(o *dialOptions) { o.copts.KeepaliveParams = kp @@ -313,20 +348,37 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * cc := &ClientConn{ target: target, csMgr: &connectivityStateManager{}, - conns: make(map[Address]*addrConn), + conns: make(map[*addrConn]struct{}), + + blockingpicker: newPickerWrapper(), } - cc.csEvltr = &connectivityStateEvaluator{csMgr: cc.csMgr} cc.ctx, cc.cancel = context.WithCancel(context.Background()) for _, opt := range opts { opt(&cc.dopts) } + + if !cc.dopts.insecure { + if cc.dopts.copts.TransportCredentials == nil { + return nil, errNoTransportSecurity + } + } else { + if cc.dopts.copts.TransportCredentials != nil { + return nil, errCredentialsConflict + } + for _, cd := range cc.dopts.copts.PerRPCCredentials { + if cd.RequireTransportSecurity() { + return nil, errTransportCredentialsMissing + } + } + } + cc.mkp = cc.dopts.copts.KeepaliveParams if cc.dopts.copts.Dialer == nil { cc.dopts.copts.Dialer = newProxyDialer( func(ctx context.Context, addr string) (net.Conn, error) { - return dialContext(ctx, "tcp", addr) + return (&net.Dialer{}).DialContext(ctx, "tcp", addr) }, ) } @@ -382,49 +434,41 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * } else { cc.authority = target } - waitC := make(chan error, 1) - go func() { - defer close(waitC) - if cc.dopts.balancer == nil && cc.sc.LB != nil { - cc.dopts.balancer = cc.sc.LB + + if cc.dopts.balancerBuilder != nil { + var credsClone credentials.TransportCredentials + if creds != nil { + credsClone = creds.Clone() } - if cc.dopts.balancer != nil { - var credsClone credentials.TransportCredentials - if creds != nil { - credsClone = creds.Clone() - } - config := BalancerConfig{ - DialCreds: credsClone, - Dialer: cc.dopts.copts.Dialer, - } - if err := cc.dopts.balancer.Start(target, config); err != nil { + buildOpts := balancer.BuildOptions{ + DialCreds: credsClone, + Dialer: cc.dopts.copts.Dialer, + } + // Build should not take long time. So it's ok to not have a goroutine for it. + // TODO(bar) init balancer after first resolver result to support service config balancer. + cc.balancerWrapper = newCCBalancerWrapper(cc, cc.dopts.balancerBuilder, buildOpts) + } else { + waitC := make(chan error, 1) + go func() { + defer close(waitC) + // No balancer, or no resolver within the balancer. Connect directly. + ac, err := cc.newAddrConn([]resolver.Address{{Addr: target}}) + if err != nil { waitC <- err return } - ch := cc.dopts.balancer.Notify() - if ch != nil { - if cc.dopts.block { - doneChan := make(chan struct{}) - go cc.lbWatcher(doneChan) - <-doneChan - } else { - go cc.lbWatcher(nil) - } + if err := ac.connect(cc.dopts.block); err != nil { + waitC <- err return } - } - // No balancer, or no resolver within the balancer. Connect directly. - if err := cc.resetAddrConn(Address{Addr: target}, cc.dopts.block, nil); err != nil { - waitC <- err - return - } - }() - select { - case <-ctx.Done(): - return nil, ctx.Err() - case err := <-waitC: - if err != nil { - return nil, err + }() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case err := <-waitC: + if err != nil { + return nil, err + } } } if cc.dopts.scChan != nil && !scSet { @@ -442,55 +486,35 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * go cc.scWatcher() } - return cc, nil -} + // Build the resolver. + cc.resolverWrapper, err = newCCResolverWrapper(cc) + if err != nil { + return nil, fmt.Errorf("failed to build resolver: %v", err) + } -// connectivityStateEvaluator gets updated by addrConns when their -// states transition, based on which it evaluates the state of -// ClientConn. -// Note: This code will eventually sit in the balancer in the new design. -type connectivityStateEvaluator struct { - csMgr *connectivityStateManager - mu sync.Mutex - numReady uint64 // Number of addrConns in ready state. - numConnecting uint64 // Number of addrConns in connecting state. - numTransientFailure uint64 // Number of addrConns in transientFailure. -} + if cc.balancerWrapper != nil && cc.resolverWrapper == nil { + // TODO(bar) there should always be a resolver (DNS as the default). + // Unblock balancer initialization with a fake resolver update if there's no resolver. + // The balancer wrapper will not read the addresses, so an empty list works. + // TODO(bar) remove this after the real resolver is started. + cc.balancerWrapper.handleResolvedAddrs([]resolver.Address{}, nil) + } -// recordTransition records state change happening in every addrConn and based on -// that it evaluates what state the ClientConn is in. -// It can only transition between connectivity.Ready, connectivity.Connecting and connectivity.TransientFailure. Other states, -// Idle and connectivity.Shutdown are transitioned into by ClientConn; in the begining of the connection -// before any addrConn is created ClientConn is in idle state. In the end when ClientConn -// closes it is in connectivity.Shutdown state. -// TODO Note that in later releases, a ClientConn with no activity will be put into an Idle state. -func (cse *connectivityStateEvaluator) recordTransition(oldState, newState connectivity.State) { - cse.mu.Lock() - defer cse.mu.Unlock() - - // Update counters. - for idx, state := range []connectivity.State{oldState, newState} { - updateVal := 2*uint64(idx) - 1 // -1 for oldState and +1 for new. - switch state { - case connectivity.Ready: - cse.numReady += updateVal - case connectivity.Connecting: - cse.numConnecting += updateVal - case connectivity.TransientFailure: - cse.numTransientFailure += updateVal + // A blocking dial blocks until the clientConn is ready. + if cc.dopts.block { + for { + s := cc.GetState() + if s == connectivity.Ready { + break + } + if !cc.WaitForStateChange(ctx, s) { + // ctx got timeout or canceled. + return nil, ctx.Err() + } } } - // Evaluate. - if cse.numReady > 0 { - cse.csMgr.updateState(connectivity.Ready) - return - } - if cse.numConnecting > 0 { - cse.csMgr.updateState(connectivity.Connecting) - return - } - cse.csMgr.updateState(connectivity.TransientFailure) + return cc, nil } // connectivityStateManager keeps the connectivity.State of ClientConn. @@ -545,11 +569,15 @@ type ClientConn struct { authority string dopts dialOptions csMgr *connectivityStateManager - csEvltr *connectivityStateEvaluator // This will eventually be part of balancer. + + balancerWrapper *ccBalancerWrapper + resolverWrapper *ccResolverWrapper + + blockingpicker *pickerWrapper mu sync.RWMutex sc ServiceConfig - conns map[Address]*addrConn + conns map[*addrConn]struct{} // Keepalive parameter can be updated if a GoAway is received. mkp keepalive.ClientParameters } @@ -576,56 +604,6 @@ func (cc *ClientConn) GetState() connectivity.State { return cc.csMgr.getState() } -// lbWatcher watches the Notify channel of the balancer in cc and manages -// connections accordingly. If doneChan is not nil, it is closed after the -// first successfull connection is made. -func (cc *ClientConn) lbWatcher(doneChan chan struct{}) { - for addrs := range cc.dopts.balancer.Notify() { - var ( - add []Address // Addresses need to setup connections. - del []*addrConn // Connections need to tear down. - ) - cc.mu.Lock() - for _, a := range addrs { - if _, ok := cc.conns[a]; !ok { - add = append(add, a) - } - } - for k, c := range cc.conns { - var keep bool - for _, a := range addrs { - if k == a { - keep = true - break - } - } - if !keep { - del = append(del, c) - delete(cc.conns, c.addr) - } - } - cc.mu.Unlock() - for _, a := range add { - var err error - if doneChan != nil { - err = cc.resetAddrConn(a, true, nil) - if err == nil { - close(doneChan) - doneChan = nil - } - } else { - err = cc.resetAddrConn(a, false, nil) - } - if err != nil { - grpclog.Warningf("Error creating connection to %v. Err: %v", a, err) - } - } - for _, c := range del { - c.tearDown(errConnDrain) - } - } -} - func (cc *ClientConn) scWatcher() { for { select { @@ -644,67 +622,64 @@ func (cc *ClientConn) scWatcher() { } } -// resetAddrConn creates an addrConn for addr and adds it to cc.conns. -// If there is an old addrConn for addr, it will be torn down, using tearDownErr as the reason. -// If tearDownErr is nil, errConnDrain will be used instead. -// -// We should never need to replace an addrConn with a new one. This function is only used -// as newAddrConn to create new addrConn. -// TODO rename this function and clean up the code. -func (cc *ClientConn) resetAddrConn(addr Address, block bool, tearDownErr error) error { +// newAddrConn creates an addrConn for addrs and adds it to cc.conns. +func (cc *ClientConn) newAddrConn(addrs []resolver.Address) (*addrConn, error) { ac := &addrConn{ cc: cc, - addr: addr, + addrs: addrs, dopts: cc.dopts, } ac.ctx, ac.cancel = context.WithCancel(cc.ctx) - ac.csEvltr = cc.csEvltr - if EnableTracing { - ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr) - } - if !ac.dopts.insecure { - if ac.dopts.copts.TransportCredentials == nil { - return errNoTransportSecurity - } - } else { - if ac.dopts.copts.TransportCredentials != nil { - return errCredentialsConflict - } - for _, cd := range ac.dopts.copts.PerRPCCredentials { - if cd.RequireTransportSecurity() { - return errTransportCredentialsMissing - } - } - } // Track ac in cc. This needs to be done before any getTransport(...) is called. cc.mu.Lock() if cc.conns == nil { cc.mu.Unlock() - return ErrClientConnClosing + return nil, ErrClientConnClosing } - stale := cc.conns[ac.addr] - cc.conns[ac.addr] = ac + cc.conns[ac] = struct{}{} cc.mu.Unlock() - if stale != nil { - // There is an addrConn alive on ac.addr already. This could be due to - // a buggy Balancer that reports duplicated Addresses. - if tearDownErr == nil { - // tearDownErr is nil if resetAddrConn is called by - // 1) Dial - // 2) lbWatcher - // In both cases, the stale ac should drain, not close. - stale.tearDown(errConnDrain) - } else { - stale.tearDown(tearDownErr) - } + return ac, nil +} + +// removeAddrConn removes the addrConn in the subConn from clientConn. +// It also tears down the ac with the given error. +func (cc *ClientConn) removeAddrConn(ac *addrConn, err error) { + cc.mu.Lock() + if cc.conns == nil { + cc.mu.Unlock() + return + } + delete(cc.conns, ac) + cc.mu.Unlock() + ac.tearDown(err) +} + +// connect starts to creating transport and also starts the transport monitor +// goroutine for this ac. +// It does nothing if the ac is not IDLE. +// TODO(bar) Move this to the addrConn section. +// This was part of resetAddrConn, keep it here to make the diff look clean. +func (ac *addrConn) connect(block bool) error { + ac.mu.Lock() + if ac.state == connectivity.Shutdown { + ac.mu.Unlock() + return errConnClosing + } + if ac.state != connectivity.Idle { + ac.mu.Unlock() + return nil + } + ac.state = connectivity.Connecting + if ac.cc.balancerWrapper != nil { + ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state) + } else { + ac.cc.csMgr.updateState(ac.state) } + ac.mu.Unlock() + if block { - if err := ac.resetTransport(false); err != nil { + if err := ac.resetTransport(); err != nil { if err != errConnClosing { - // Tear down ac and delete it from cc.conns. - cc.mu.Lock() - delete(cc.conns, ac.addr) - cc.mu.Unlock() ac.tearDown(err) } if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() { @@ -717,8 +692,8 @@ func (cc *ClientConn) resetAddrConn(addr Address, block bool, tearDownErr error) } else { // Start a goroutine connecting to the server asynchronously. go func() { - if err := ac.resetTransport(false); err != nil { - grpclog.Warningf("Failed to dial %s: %v; please retry.", ac.addr.Addr, err) + if err := ac.resetTransport(); err != nil { + grpclog.Warningf("Failed to dial %s: %v; please retry.", ac.addrs[0].Addr, err) if err != errConnClosing { // Keep this ac in cc.conns, to get the reason it's torn down. ac.tearDown(err) @@ -731,6 +706,36 @@ func (cc *ClientConn) resetAddrConn(addr Address, block bool, tearDownErr error) return nil } +// tryUpdateAddrs tries to update ac.addrs with the new addresses list. +// +// It checks whether current connected address of ac is in the new addrs list. +// - If true, it updates ac.addrs and returns true. The ac will keep using +// the existing connection. +// - If false, it does nothing and returns false. +func (ac *addrConn) tryUpdateAddrs(addrs []resolver.Address) bool { + ac.mu.Lock() + defer ac.mu.Unlock() + grpclog.Infof("addrConn: tryUpdateAddrs curAddr: %v, addrs: %v", ac.curAddr, addrs) + if ac.state == connectivity.Shutdown { + ac.addrs = addrs + return true + } + + var curAddrFound bool + for _, a := range addrs { + if reflect.DeepEqual(ac.curAddr, a) { + curAddrFound = true + break + } + } + grpclog.Infof("addrConn: tryUpdateAddrs curAddrFound: %v", curAddrFound) + if curAddrFound { + ac.addrs = addrs + } + + return curAddrFound +} + // GetMethodConfig gets the method config of the input method. // If there's an exact match for input method (i.e. /service/method), we return // the corresponding MethodConfig. @@ -750,58 +755,37 @@ func (cc *ClientConn) GetMethodConfig(method string) MethodConfig { return m } -func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) { - var ( - ac *addrConn - ok bool - put func() - ) - if cc.dopts.balancer == nil { +func (cc *ClientConn) getTransport(ctx context.Context, failfast bool) (transport.ClientTransport, func(balancer.DoneInfo), error) { + if cc.balancerWrapper == nil { // If balancer is nil, there should be only one addrConn available. cc.mu.RLock() if cc.conns == nil { cc.mu.RUnlock() + // TODO this function returns toRPCErr and non-toRPCErr. Clean up + // the errors in ClientConn. return nil, nil, toRPCErr(ErrClientConnClosing) } - for _, ac = range cc.conns { + var ac *addrConn + for ac = range cc.conns { // Break after the first iteration to get the first addrConn. - ok = true break } cc.mu.RUnlock() - } else { - var ( - addr Address - err error - ) - addr, put, err = cc.dopts.balancer.Get(ctx, opts) - if err != nil { - return nil, nil, toRPCErr(err) - } - cc.mu.RLock() - if cc.conns == nil { - cc.mu.RUnlock() - return nil, nil, toRPCErr(ErrClientConnClosing) + if ac == nil { + return nil, nil, errConnClosing } - ac, ok = cc.conns[addr] - cc.mu.RUnlock() - } - if !ok { - if put != nil { - updateRPCInfoInContext(ctx, rpcInfo{bytesSent: false, bytesReceived: false}) - put() + t, err := ac.wait(ctx, false /*hasBalancer*/, failfast) + if err != nil { + return nil, nil, err } - return nil, nil, errConnClosing + return t, nil, nil } - t, err := ac.wait(ctx, cc.dopts.balancer != nil, !opts.BlockingWait) + + t, done, err := cc.blockingpicker.pick(ctx, failfast, balancer.PickOptions{}) if err != nil { - if put != nil { - updateRPCInfoInContext(ctx, rpcInfo{bytesSent: false, bytesReceived: false}) - put() - } - return nil, nil, err + return nil, nil, toRPCErr(err) } - return t, put, nil + return t, done, nil } // Close tears down the ClientConn and all underlying connections. @@ -817,10 +801,14 @@ func (cc *ClientConn) Close() error { cc.conns = nil cc.csMgr.updateState(connectivity.Shutdown) cc.mu.Unlock() - if cc.dopts.balancer != nil { - cc.dopts.balancer.Close() + cc.blockingpicker.close() + if cc.resolverWrapper != nil { + cc.resolverWrapper.close() + } + if cc.balancerWrapper != nil { + cc.balancerWrapper.close() } - for _, ac := range conns { + for ac := range conns { ac.tearDown(ErrClientConnClosing) } return nil @@ -831,16 +819,15 @@ type addrConn struct { ctx context.Context cancel context.CancelFunc - cc *ClientConn - addr Address - dopts dialOptions - events trace.EventLog - - csEvltr *connectivityStateEvaluator + cc *ClientConn + curAddr resolver.Address + addrs []resolver.Address + dopts dialOptions + events trace.EventLog + acbw balancer.SubConn mu sync.Mutex state connectivity.State - down func(error) // the handler called when a connection is down. // ready is closed and becomes nil when a new transport is up or failed // due to timeout. ready chan struct{} @@ -880,108 +867,127 @@ func (ac *addrConn) errorf(format string, a ...interface{}) { } } -// resetTransport recreates a transport to the address for ac. -// For the old transport: -// - if drain is true, it will be gracefully closed. -// - otherwise, it will be closed. -func (ac *addrConn) resetTransport(drain bool) error { +// resetTransport recreates a transport to the address for ac. The old +// transport will close itself on error or when the clientconn is closed. +// +// TODO(bar) make sure all state transitions are valid. +func (ac *addrConn) resetTransport() error { ac.mu.Lock() if ac.state == connectivity.Shutdown { ac.mu.Unlock() return errConnClosing } - ac.printf("connecting") - if ac.down != nil { - ac.down(downErrorf(false, true, "%v", errNetworkIO)) - ac.down = nil + if ac.ready != nil { + close(ac.ready) + ac.ready = nil } - oldState := ac.state - ac.state = connectivity.Connecting - ac.csEvltr.recordTransition(oldState, ac.state) - t := ac.transport ac.transport = nil + ac.curAddr = resolver.Address{} ac.mu.Unlock() - if t != nil && !drain { - t.Close() - } ac.cc.mu.RLock() ac.dopts.copts.KeepaliveParams = ac.cc.mkp ac.cc.mu.RUnlock() for retries := 0; ; retries++ { + sleepTime := ac.dopts.bs.backoff(retries) + timeout := minConnectTimeout ac.mu.Lock() + if timeout < time.Duration(int(sleepTime)/len(ac.addrs)) { + timeout = time.Duration(int(sleepTime) / len(ac.addrs)) + } + connectTime := time.Now() if ac.state == connectivity.Shutdown { - // ac.tearDown(...) has been invoked. ac.mu.Unlock() return errConnClosing } - ac.mu.Unlock() - sleepTime := ac.dopts.bs.backoff(retries) - timeout := minConnectTimeout - if timeout < sleepTime { - timeout = sleepTime - } - ctx, cancel := context.WithTimeout(ac.ctx, timeout) - connectTime := time.Now() - sinfo := transport.TargetInfo{ - Addr: ac.addr.Addr, - Metadata: ac.addr.Metadata, + ac.printf("connecting") + if ac.state != connectivity.Connecting { + ac.state = connectivity.Connecting + // TODO(bar) remove condition once we always have a balancer. + if ac.cc.balancerWrapper != nil { + ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state) + } else { + ac.cc.csMgr.updateState(ac.state) + } } - newTransport, err := transport.NewClientTransport(ctx, sinfo, ac.dopts.copts) - // Don't call cancel in success path due to a race in Go 1.6: - // https://github.com/golang/go/issues/15078. - if err != nil { - cancel() - - if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() { - return err + // copy ac.addrs in case of race + addrsIter := make([]resolver.Address, len(ac.addrs)) + copy(addrsIter, ac.addrs) + copts := ac.dopts.copts + ac.mu.Unlock() + for _, addr := range addrsIter { + ac.mu.Lock() + if ac.state == connectivity.Shutdown { + // ac.tearDown(...) has been invoked. + ac.mu.Unlock() + return errConnClosing + } + ac.mu.Unlock() + sinfo := transport.TargetInfo{ + Addr: addr.Addr, + Metadata: addr.Metadata, + } + newTransport, err := transport.NewClientTransport(ac.cc.ctx, sinfo, copts, timeout) + if err != nil { + if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() { + return err + } + grpclog.Warningf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %v", err, addr) + ac.mu.Lock() + if ac.state == connectivity.Shutdown { + // ac.tearDown(...) has been invoked. + ac.mu.Unlock() + return errConnClosing + } + ac.mu.Unlock() + continue } - grpclog.Warningf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %v", err, ac.addr) ac.mu.Lock() + ac.printf("ready") if ac.state == connectivity.Shutdown { // ac.tearDown(...) has been invoked. ac.mu.Unlock() + newTransport.Close() return errConnClosing } - ac.errorf("transient failure: %v", err) - oldState = ac.state - ac.state = connectivity.TransientFailure - ac.csEvltr.recordTransition(oldState, ac.state) + ac.state = connectivity.Ready + if ac.cc.balancerWrapper != nil { + ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state) + } else { + ac.cc.csMgr.updateState(ac.state) + } + t := ac.transport + ac.transport = newTransport + if t != nil { + t.Close() + } + ac.curAddr = addr if ac.ready != nil { close(ac.ready) ac.ready = nil } ac.mu.Unlock() - timer := time.NewTimer(sleepTime - time.Since(connectTime)) - select { - case <-timer.C: - case <-ac.ctx.Done(): - timer.Stop() - return ac.ctx.Err() - } - timer.Stop() - continue + return nil } ac.mu.Lock() - ac.printf("ready") - if ac.state == connectivity.Shutdown { - // ac.tearDown(...) has been invoked. - ac.mu.Unlock() - newTransport.Close() - return errConnClosing + ac.state = connectivity.TransientFailure + if ac.cc.balancerWrapper != nil { + ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state) + } else { + ac.cc.csMgr.updateState(ac.state) } - oldState = ac.state - ac.state = connectivity.Ready - ac.csEvltr.recordTransition(oldState, ac.state) - ac.transport = newTransport if ac.ready != nil { close(ac.ready) ac.ready = nil } - if ac.cc.dopts.balancer != nil { - ac.down = ac.cc.dopts.balancer.Up(ac.addr) - } ac.mu.Unlock() - return nil + timer := time.NewTimer(sleepTime - time.Since(connectTime)) + select { + case <-timer.C: + case <-ac.ctx.Done(): + timer.Stop() + return ac.ctx.Err() + } + timer.Stop() } } @@ -992,76 +998,39 @@ func (ac *addrConn) transportMonitor() { ac.mu.Lock() t := ac.transport ac.mu.Unlock() + // Block until we receive a goaway or an error occurs. select { - // This is needed to detect the teardown when - // the addrConn is idle (i.e., no RPC in flight). - case <-ac.ctx.Done(): - select { - case <-t.Error(): - t.Close() - default: - } - return case <-t.GoAway(): - ac.adjustParams(t.GetGoAwayReason()) - // If GoAway happens without any network I/O error, the underlying transport - // will be gracefully closed, and a new transport will be created. - // (The transport will be closed when all the pending RPCs finished or failed.) - // If GoAway and some network I/O error happen concurrently, the underlying transport - // will be closed, and a new transport will be created. - var drain bool - select { - case <-t.Error(): - default: - drain = true - } - if err := ac.resetTransport(drain); err != nil { - grpclog.Infof("get error from resetTransport %v, transportMonitor returning", err) - if err != errConnClosing { - // Keep this ac in cc.conns, to get the reason it's torn down. - ac.tearDown(err) - } - return - } case <-t.Error(): - select { - case <-ac.ctx.Done(): - t.Close() - return - case <-t.GoAway(): - ac.adjustParams(t.GetGoAwayReason()) - if err := ac.resetTransport(false); err != nil { - grpclog.Infof("get error from resetTransport %v, transportMonitor returning", err) - if err != errConnClosing { - // Keep this ac in cc.conns, to get the reason it's torn down. - ac.tearDown(err) - } - return - } - default: - } + } + // If a GoAway happened, regardless of error, adjust our keepalive + // parameters as appropriate. + select { + case <-t.GoAway(): + ac.adjustParams(t.GetGoAwayReason()) + default: + } + ac.mu.Lock() + // Set connectivity state to TransientFailure before calling + // resetTransport. Transition READY->CONNECTING is not valid. + ac.state = connectivity.TransientFailure + if ac.cc.balancerWrapper != nil { + ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state) + } else { + ac.cc.csMgr.updateState(ac.state) + } + ac.curAddr = resolver.Address{} + ac.mu.Unlock() + if err := ac.resetTransport(); err != nil { ac.mu.Lock() - if ac.state == connectivity.Shutdown { - // ac has been shutdown. - ac.mu.Unlock() - return - } - oldState := ac.state - ac.state = connectivity.TransientFailure - ac.csEvltr.recordTransition(oldState, ac.state) + ac.printf("transport exiting: %v", err) ac.mu.Unlock() - if err := ac.resetTransport(false); err != nil { - grpclog.Infof("get error from resetTransport %v, transportMonitor returning", err) - ac.mu.Lock() - ac.printf("transport exiting: %v", err) - ac.mu.Unlock() - grpclog.Warningf("grpc: addrConn.transportMonitor exits due to: %v", err) - if err != errConnClosing { - // Keep this ac in cc.conns, to get the reason it's torn down. - ac.tearDown(err) - } - return + grpclog.Warningf("grpc: addrConn.transportMonitor exits due to: %v", err) + if err != errConnClosing { + // Keep this ac in cc.conns, to get the reason it's torn down. + ac.tearDown(err) } + return } } } @@ -1106,6 +1075,28 @@ func (ac *addrConn) wait(ctx context.Context, hasBalancer, failfast bool) (trans } } +// getReadyTransport returns the transport if ac's state is READY. +// Otherwise it returns nil, false. +// If ac's state is IDLE, it will trigger ac to connect. +func (ac *addrConn) getReadyTransport() (transport.ClientTransport, bool) { + ac.mu.Lock() + if ac.state == connectivity.Ready { + t := ac.transport + ac.mu.Unlock() + return t, true + } + var idle bool + if ac.state == connectivity.Idle { + idle = true + } + ac.mu.Unlock() + // Trigger idle ac to connect. + if idle { + ac.connect(false) + } + return nil, false +} + // tearDown starts to tear down the addrConn. // TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in // some edge cases (e.g., the caller opens and closes many addrConn's in a @@ -1113,13 +1104,9 @@ func (ac *addrConn) wait(ctx context.Context, hasBalancer, failfast bool) (trans // tearDown doesn't remove ac from ac.cc.conns. func (ac *addrConn) tearDown(err error) { ac.cancel() - ac.mu.Lock() + ac.curAddr = resolver.Address{} defer ac.mu.Unlock() - if ac.down != nil { - ac.down(downErrorf(false, false, "%v", err)) - ac.down = nil - } if err == errConnDrain && ac.transport != nil { // GracefulClose(...) may be executed multiple times when // i) receiving multiple GoAway frames from the server; or @@ -1130,10 +1117,13 @@ func (ac *addrConn) tearDown(err error) { if ac.state == connectivity.Shutdown { return } - oldState := ac.state ac.state = connectivity.Shutdown ac.tearDownErr = err - ac.csEvltr.recordTransition(oldState, ac.state) + if ac.cc.balancerWrapper != nil { + ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state) + } else { + ac.cc.csMgr.updateState(ac.state) + } if ac.events != nil { ac.events.Finish() ac.events = nil @@ -1142,8 +1132,11 @@ func (ac *addrConn) tearDown(err error) { close(ac.ready) ac.ready = nil } - if ac.transport != nil && err != errConnDrain { - ac.transport.Close() - } return } + +func (ac *addrConn) getState() connectivity.State { + ac.mu.Lock() + defer ac.mu.Unlock() + return ac.state +} diff --git a/vendor/google.golang.org/grpc/clientconn_test.go b/vendor/google.golang.org/grpc/clientconn_test.go index 203b533a6..47801e962 100644 --- a/vendor/google.golang.org/grpc/clientconn_test.go +++ b/vendor/google.golang.org/grpc/clientconn_test.go @@ -30,10 +30,10 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/naming" + "google.golang.org/grpc/test/leakcheck" + "google.golang.org/grpc/testdata" ) -const tlsDir = "testdata/" - func assertState(wantState connectivity.State, cc *ClientConn) (connectivity.State, bool) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -44,13 +44,9 @@ func assertState(wantState connectivity.State, cc *ClientConn) (connectivity.Sta } func TestConnectivityStates(t *testing.T) { - servers, resolver := startServers(t, 2, math.MaxUint32) - defer func() { - for i := 0; i < 2; i++ { - servers[i].stop() - } - }() - + defer leakcheck.Check(t) + servers, resolver, cleanup := startServers(t, 2, math.MaxUint32) + defer cleanup() cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(resolver)), WithInsecure()) if err != nil { t.Fatalf("Dial(\"foo.bar.com\", WithBalancer(_)) = _, %v, want _ <nil>", err) @@ -85,6 +81,7 @@ func TestConnectivityStates(t *testing.T) { } func TestDialTimeout(t *testing.T) { + defer leakcheck.Check(t) conn, err := Dial("Non-Existent.Server:80", WithTimeout(time.Millisecond), WithBlock(), WithInsecure()) if err == nil { conn.Close() @@ -95,7 +92,8 @@ func TestDialTimeout(t *testing.T) { } func TestTLSDialTimeout(t *testing.T) { - creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") + defer leakcheck.Check(t) + creds, err := credentials.NewClientTLSFromFile(testdata.Path("ca.pem"), "x.test.youtube.com") if err != nil { t.Fatalf("Failed to create credentials %v", err) } @@ -109,6 +107,7 @@ func TestTLSDialTimeout(t *testing.T) { } func TestDefaultAuthority(t *testing.T) { + defer leakcheck.Check(t) target := "Non-Existent.Server:8080" conn, err := Dial(target, WithInsecure()) if err != nil { @@ -121,8 +120,9 @@ func TestDefaultAuthority(t *testing.T) { } func TestTLSServerNameOverwrite(t *testing.T) { + defer leakcheck.Check(t) overwriteServerName := "over.write.server.name" - creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", overwriteServerName) + creds, err := credentials.NewClientTLSFromFile(testdata.Path("ca.pem"), overwriteServerName) if err != nil { t.Fatalf("Failed to create credentials %v", err) } @@ -137,6 +137,7 @@ func TestTLSServerNameOverwrite(t *testing.T) { } func TestWithAuthority(t *testing.T) { + defer leakcheck.Check(t) overwriteServerName := "over.write.server.name" conn, err := Dial("Non-Existent.Server:80", WithInsecure(), WithAuthority(overwriteServerName)) if err != nil { @@ -149,8 +150,9 @@ func TestWithAuthority(t *testing.T) { } func TestWithAuthorityAndTLS(t *testing.T) { + defer leakcheck.Check(t) overwriteServerName := "over.write.server.name" - creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", overwriteServerName) + creds, err := credentials.NewClientTLSFromFile(testdata.Path("ca.pem"), overwriteServerName) if err != nil { t.Fatalf("Failed to create credentials %v", err) } @@ -165,6 +167,7 @@ func TestWithAuthorityAndTLS(t *testing.T) { } func TestDialContextCancel(t *testing.T) { + defer leakcheck.Check(t) ctx, cancel := context.WithCancel(context.Background()) cancel() if _, err := DialContext(ctx, "Non-Existent.Server:80", WithBlock(), WithInsecure()); err != context.Canceled { @@ -199,6 +202,7 @@ func (b *blockingBalancer) Close() error { } func TestDialWithBlockingBalancer(t *testing.T) { + defer leakcheck.Check(t) ctx, cancel := context.WithCancel(context.Background()) dialDone := make(chan struct{}) go func() { @@ -221,7 +225,8 @@ func (c securePerRPCCredentials) RequireTransportSecurity() bool { } func TestCredentialsMisuse(t *testing.T) { - tlsCreds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") + defer leakcheck.Check(t) + tlsCreds, err := credentials.NewClientTLSFromFile(testdata.Path("ca.pem"), "x.test.youtube.com") if err != nil { t.Fatalf("Failed to create authenticator %v", err) } @@ -236,10 +241,12 @@ func TestCredentialsMisuse(t *testing.T) { } func TestWithBackoffConfigDefault(t *testing.T) { + defer leakcheck.Check(t) testBackoffConfigSet(t, &DefaultBackoffConfig) } func TestWithBackoffConfig(t *testing.T) { + defer leakcheck.Check(t) b := BackoffConfig{MaxDelay: DefaultBackoffConfig.MaxDelay / 2} expected := b setDefaults(&expected) // defaults should be set @@ -247,6 +254,7 @@ func TestWithBackoffConfig(t *testing.T) { } func TestWithBackoffMaxDelay(t *testing.T) { + defer leakcheck.Check(t) md := DefaultBackoffConfig.MaxDelay / 2 expected := BackoffConfig{MaxDelay: md} setDefaults(&expected) @@ -294,7 +302,9 @@ func nonTemporaryErrorDialer(addr string, timeout time.Duration) (net.Conn, erro } func TestDialWithBlockErrorOnNonTemporaryErrorDialer(t *testing.T) { - ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer leakcheck.Check(t) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() if _, err := DialContext(ctx, "", WithInsecure(), WithDialer(nonTemporaryErrorDialer), WithBlock(), FailOnNonTempDialError(true)); err != nonTemporaryError { t.Fatalf("Dial(%q) = %v, want %v", "", err, nonTemporaryError) } @@ -332,6 +342,7 @@ func (b *emptyBalancer) Close() error { } func TestNonblockingDialWithEmptyBalancer(t *testing.T) { + defer leakcheck.Check(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() dialDone := make(chan error) @@ -350,6 +361,7 @@ func TestNonblockingDialWithEmptyBalancer(t *testing.T) { } func TestClientUpdatesParamsAfterGoAway(t *testing.T) { + defer leakcheck.Check(t) lis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("Failed to listen. Err: %v", err) diff --git a/vendor/google.golang.org/grpc/codes/code_string.go b/vendor/google.golang.org/grpc/codes/code_string.go index e6762d084..259837060 100644 --- a/vendor/google.golang.org/grpc/codes/code_string.go +++ b/vendor/google.golang.org/grpc/codes/code_string.go @@ -1,4 +1,4 @@ -// generated by stringer -type=Code; DO NOT EDIT +// Code generated by "stringer -type=Code"; DO NOT EDIT. package codes @@ -9,7 +9,7 @@ const _Code_name = "OKCanceledUnknownInvalidArgumentDeadlineExceededNotFoundAlre var _Code_index = [...]uint8{0, 2, 10, 17, 32, 48, 56, 69, 85, 102, 120, 127, 137, 150, 158, 169, 177, 192} func (i Code) String() string { - if i+1 >= Code(len(_Code_index)) { + if i >= Code(len(_Code_index)-1) { return fmt.Sprintf("Code(%d)", i) } return _Code_name[_Code_index[i]:_Code_index[i+1]] diff --git a/vendor/google.golang.org/grpc/coverage.sh b/vendor/google.golang.org/grpc/coverage.sh deleted file mode 100755 index b85f9181d..000000000 --- a/vendor/google.golang.org/grpc/coverage.sh +++ /dev/null @@ -1,48 +0,0 @@ -#!/usr/bin/env bash - - -set -e - -workdir=.cover -profile="$workdir/cover.out" -mode=set -end2endtest="google.golang.org/grpc/test" - -generate_cover_data() { - rm -rf "$workdir" - mkdir "$workdir" - - for pkg in "$@"; do - if [ $pkg == "google.golang.org/grpc" -o $pkg == "google.golang.org/grpc/transport" -o $pkg == "google.golang.org/grpc/metadata" -o $pkg == "google.golang.org/grpc/credentials" ] - then - f="$workdir/$(echo $pkg | tr / -)" - go test -covermode="$mode" -coverprofile="$f.cover" "$pkg" - go test -covermode="$mode" -coverpkg "$pkg" -coverprofile="$f.e2e.cover" "$end2endtest" - fi - done - - echo "mode: $mode" >"$profile" - grep -h -v "^mode:" "$workdir"/*.cover >>"$profile" -} - -show_cover_report() { - go tool cover -${1}="$profile" -} - -push_to_coveralls() { - goveralls -coverprofile="$profile" -} - -generate_cover_data $(go list ./...) -show_cover_report func -case "$1" in -"") - ;; ---html) - show_cover_report html ;; ---coveralls) - push_to_coveralls ;; -*) - echo >&2 "error: invalid option: $1" ;; -esac -rm -rf "$workdir" diff --git a/vendor/google.golang.org/grpc/credentials/credentials_test.go b/vendor/google.golang.org/grpc/credentials/credentials_test.go index 6be68a422..9b13db51d 100644 --- a/vendor/google.golang.org/grpc/credentials/credentials_test.go +++ b/vendor/google.golang.org/grpc/credentials/credentials_test.go @@ -24,6 +24,7 @@ import ( "testing" "golang.org/x/net/context" + "google.golang.org/grpc/testdata" ) func TestTLSOverrideServerName(t *testing.T) { @@ -50,8 +51,6 @@ func TestTLSClone(t *testing.T) { } -const tlsDir = "../test/testdata/" - type serverHandshake func(net.Conn) (AuthInfo, error) func TestClientHandshakeReturnsAuthInfo(t *testing.T) { @@ -129,7 +128,7 @@ func launchServer(t *testing.T, hs serverHandshake, done chan AuthInfo) net.List return lis } -// Is run in a seperate goroutine. +// Is run in a separate goroutine. func serverHandle(t *testing.T, hs serverHandshake, done chan AuthInfo, lis net.Listener) { serverRawConn, err := lis.Accept() if err != nil { @@ -162,7 +161,7 @@ func clientHandle(t *testing.T, hs func(net.Conn, string) (AuthInfo, error), lis // Server handshake implementation in gRPC. func gRPCServerHandshake(conn net.Conn) (AuthInfo, error) { - serverTLS, err := NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key") + serverTLS, err := NewServerTLSFromFile(testdata.Path("server1.pem"), testdata.Path("server1.key")) if err != nil { return nil, err } @@ -184,7 +183,7 @@ func gRPCClientHandshake(conn net.Conn, lisAddr string) (AuthInfo, error) { } func tlsServerHandshake(conn net.Conn) (AuthInfo, error) { - cert, err := tls.LoadX509KeyPair(tlsDir+"server1.pem", tlsDir+"server1.key") + cert, err := tls.LoadX509KeyPair(testdata.Path("server1.pem"), testdata.Path("server1.key")) if err != nil { return nil, err } diff --git a/vendor/google.golang.org/grpc/doc.go b/vendor/google.golang.org/grpc/doc.go index 41b675a81..187adbb11 100644 --- a/vendor/google.golang.org/grpc/doc.go +++ b/vendor/google.golang.org/grpc/doc.go @@ -1,4 +1,22 @@ /* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* Package grpc implements an RPC system called gRPC. See grpc.io for more information about gRPC. diff --git a/vendor/google.golang.org/grpc/go16.go b/vendor/google.golang.org/grpc/go16.go deleted file mode 100644 index f3dbf2170..000000000 --- a/vendor/google.golang.org/grpc/go16.go +++ /dev/null @@ -1,98 +0,0 @@ -// +build go1.6,!go1.7 - -/* - * - * Copyright 2016 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package grpc - -import ( - "fmt" - "io" - "net" - "net/http" - "os" - - "golang.org/x/net/context" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "google.golang.org/grpc/transport" -) - -// dialContext connects to the address on the named network. -func dialContext(ctx context.Context, network, address string) (net.Conn, error) { - return (&net.Dialer{Cancel: ctx.Done()}).Dial(network, address) -} - -func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) error { - req.Cancel = ctx.Done() - if err := req.Write(conn); err != nil { - return fmt.Errorf("failed to write the HTTP request: %v", err) - } - return nil -} - -// toRPCErr converts an error into an error from the status package. -func toRPCErr(err error) error { - if _, ok := status.FromError(err); ok { - return err - } - switch e := err.(type) { - case transport.StreamError: - return status.Error(e.Code, e.Desc) - case transport.ConnectionError: - return status.Error(codes.Unavailable, e.Desc) - default: - switch err { - case context.DeadlineExceeded: - return status.Error(codes.DeadlineExceeded, err.Error()) - case context.Canceled: - return status.Error(codes.Canceled, err.Error()) - case ErrClientConnClosing: - return status.Error(codes.FailedPrecondition, err.Error()) - } - } - return status.Error(codes.Unknown, err.Error()) -} - -// convertCode converts a standard Go error into its canonical code. Note that -// this is only used to translate the error returned by the server applications. -func convertCode(err error) codes.Code { - switch err { - case nil: - return codes.OK - case io.EOF: - return codes.OutOfRange - case io.ErrClosedPipe, io.ErrNoProgress, io.ErrShortBuffer, io.ErrShortWrite, io.ErrUnexpectedEOF: - return codes.FailedPrecondition - case os.ErrInvalid: - return codes.InvalidArgument - case context.Canceled: - return codes.Canceled - case context.DeadlineExceeded: - return codes.DeadlineExceeded - } - switch { - case os.IsExist(err): - return codes.AlreadyExists - case os.IsNotExist(err): - return codes.NotFound - case os.IsPermission(err): - return codes.PermissionDenied - } - return codes.Unknown -} diff --git a/vendor/google.golang.org/grpc/go17.go b/vendor/google.golang.org/grpc/go17.go deleted file mode 100644 index a3421d99e..000000000 --- a/vendor/google.golang.org/grpc/go17.go +++ /dev/null @@ -1,98 +0,0 @@ -// +build go1.7 - -/* - * - * Copyright 2016 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package grpc - -import ( - "context" - "io" - "net" - "net/http" - "os" - - netctx "golang.org/x/net/context" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "google.golang.org/grpc/transport" -) - -// dialContext connects to the address on the named network. -func dialContext(ctx context.Context, network, address string) (net.Conn, error) { - return (&net.Dialer{}).DialContext(ctx, network, address) -} - -func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) error { - req = req.WithContext(ctx) - if err := req.Write(conn); err != nil { - return err - } - return nil -} - -// toRPCErr converts an error into an error from the status package. -func toRPCErr(err error) error { - if _, ok := status.FromError(err); ok { - return err - } - switch e := err.(type) { - case transport.StreamError: - return status.Error(e.Code, e.Desc) - case transport.ConnectionError: - return status.Error(codes.Unavailable, e.Desc) - default: - switch err { - case context.DeadlineExceeded, netctx.DeadlineExceeded: - return status.Error(codes.DeadlineExceeded, err.Error()) - case context.Canceled, netctx.Canceled: - return status.Error(codes.Canceled, err.Error()) - case ErrClientConnClosing: - return status.Error(codes.FailedPrecondition, err.Error()) - } - } - return status.Error(codes.Unknown, err.Error()) -} - -// convertCode converts a standard Go error into its canonical code. Note that -// this is only used to translate the error returned by the server applications. -func convertCode(err error) codes.Code { - switch err { - case nil: - return codes.OK - case io.EOF: - return codes.OutOfRange - case io.ErrClosedPipe, io.ErrNoProgress, io.ErrShortBuffer, io.ErrShortWrite, io.ErrUnexpectedEOF: - return codes.FailedPrecondition - case os.ErrInvalid: - return codes.InvalidArgument - case context.Canceled, netctx.Canceled: - return codes.Canceled - case context.DeadlineExceeded, netctx.DeadlineExceeded: - return codes.DeadlineExceeded - } - switch { - case os.IsExist(err): - return codes.AlreadyExists - case os.IsNotExist(err): - return codes.NotFound - case os.IsPermission(err): - return codes.PermissionDenied - } - return codes.Unknown -} diff --git a/vendor/google.golang.org/grpc/grpclb.go b/vendor/google.golang.org/grpc/grpclb.go index 619985e60..db56ff362 100644 --- a/vendor/google.golang.org/grpc/grpclb.go +++ b/vendor/google.golang.org/grpc/grpclb.go @@ -28,7 +28,7 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc/codes" - lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1" + lbmpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/metadata" "google.golang.org/grpc/naming" @@ -59,41 +59,21 @@ type balanceLoadClientStream struct { ClientStream } -func (x *balanceLoadClientStream) Send(m *lbpb.LoadBalanceRequest) error { +func (x *balanceLoadClientStream) Send(m *lbmpb.LoadBalanceRequest) error { return x.ClientStream.SendMsg(m) } -func (x *balanceLoadClientStream) Recv() (*lbpb.LoadBalanceResponse, error) { - m := new(lbpb.LoadBalanceResponse) +func (x *balanceLoadClientStream) Recv() (*lbmpb.LoadBalanceResponse, error) { + m := new(lbmpb.LoadBalanceResponse) if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err } return m, nil } -// AddressType indicates the address type returned by name resolution. -type AddressType uint8 - -const ( - // Backend indicates the server is a backend server. - Backend AddressType = iota - // GRPCLB indicates the server is a grpclb load balancer. - GRPCLB -) - -// AddrMetadataGRPCLB contains the information the name resolver for grpclb should provide. The -// name resolver used by the grpclb balancer is required to provide this type of metadata in -// its address updates. -type AddrMetadataGRPCLB struct { - // AddrType is the type of server (grpc load balancer or backend). - AddrType AddressType - // ServerName is the name of the grpc load balancer. Used for authentication. - ServerName string -} - // NewGRPCLBBalancer creates a grpclb load balancer. func NewGRPCLBBalancer(r naming.Resolver) Balancer { - return &balancer{ + return &grpclbBalancer{ r: r, } } @@ -116,25 +96,24 @@ type grpclbAddrInfo struct { dropForLoadBalancing bool } -type balancer struct { - r naming.Resolver - target string - mu sync.Mutex - seq int // a sequence number to make sure addrCh does not get stale addresses. - w naming.Watcher - addrCh chan []Address - rbs []remoteBalancerInfo - addrs []*grpclbAddrInfo - next int - waitCh chan struct{} - done bool - expTimer *time.Timer - rand *rand.Rand - - clientStats lbpb.ClientStats +type grpclbBalancer struct { + r naming.Resolver + target string + mu sync.Mutex + seq int // a sequence number to make sure addrCh does not get stale addresses. + w naming.Watcher + addrCh chan []Address + rbs []remoteBalancerInfo + addrs []*grpclbAddrInfo + next int + waitCh chan struct{} + done bool + rand *rand.Rand + + clientStats lbmpb.ClientStats } -func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan []remoteBalancerInfo) error { +func (b *grpclbBalancer) watchAddrUpdates(w naming.Watcher, ch chan []remoteBalancerInfo) error { updates, err := w.Next() if err != nil { grpclog.Warningf("grpclb: failed to get next addr update from watcher: %v", err) @@ -159,18 +138,18 @@ func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan []remoteBalancerIn if exist { continue } - md, ok := update.Metadata.(*AddrMetadataGRPCLB) + md, ok := update.Metadata.(*naming.AddrMetadataGRPCLB) if !ok { // TODO: Revisit the handling here and may introduce some fallback mechanism. grpclog.Errorf("The name resolution contains unexpected metadata %v", update.Metadata) continue } switch md.AddrType { - case Backend: + case naming.Backend: // TODO: Revisit the handling here and may introduce some fallback mechanism. grpclog.Errorf("The name resolution does not give grpclb addresses") continue - case GRPCLB: + case naming.GRPCLB: b.rbs = append(b.rbs, remoteBalancerInfo{ addr: update.Addr, name: md.ServerName, @@ -201,34 +180,18 @@ func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan []remoteBalancerIn return nil } -func (b *balancer) serverListExpire(seq int) { - b.mu.Lock() - defer b.mu.Unlock() - // TODO: gRPC interanls do not clear the connections when the server list is stale. - // This means RPCs will keep using the existing server list until b receives new - // server list even though the list is expired. Revisit this behavior later. - if b.done || seq < b.seq { - return - } - b.next = 0 - b.addrs = nil - // Ask grpc internals to close all the corresponding connections. - b.addrCh <- nil -} - -func convertDuration(d *lbpb.Duration) time.Duration { +func convertDuration(d *lbmpb.Duration) time.Duration { if d == nil { return 0 } return time.Duration(d.Seconds)*time.Second + time.Duration(d.Nanos)*time.Nanosecond } -func (b *balancer) processServerList(l *lbpb.ServerList, seq int) { +func (b *grpclbBalancer) processServerList(l *lbmpb.ServerList, seq int) { if l == nil { return } servers := l.GetServers() - expiration := convertDuration(l.GetExpirationInterval()) var ( sl []*grpclbAddrInfo addrs []Address @@ -263,20 +226,11 @@ func (b *balancer) processServerList(l *lbpb.ServerList, seq int) { b.next = 0 b.addrs = sl b.addrCh <- addrs - if b.expTimer != nil { - b.expTimer.Stop() - b.expTimer = nil - } - if expiration > 0 { - b.expTimer = time.AfterFunc(expiration, func() { - b.serverListExpire(seq) - }) - } } return } -func (b *balancer) sendLoadReport(s *balanceLoadClientStream, interval time.Duration, done <-chan struct{}) { +func (b *grpclbBalancer) sendLoadReport(s *balanceLoadClientStream, interval time.Duration, done <-chan struct{}) { ticker := time.NewTicker(interval) defer ticker.Stop() for { @@ -287,15 +241,15 @@ func (b *balancer) sendLoadReport(s *balanceLoadClientStream, interval time.Dura } b.mu.Lock() stats := b.clientStats - b.clientStats = lbpb.ClientStats{} // Clear the stats. + b.clientStats = lbmpb.ClientStats{} // Clear the stats. b.mu.Unlock() t := time.Now() - stats.Timestamp = &lbpb.Timestamp{ + stats.Timestamp = &lbmpb.Timestamp{ Seconds: t.Unix(), Nanos: int32(t.Nanosecond()), } - if err := s.Send(&lbpb.LoadBalanceRequest{ - LoadBalanceRequestType: &lbpb.LoadBalanceRequest_ClientStats{ + if err := s.Send(&lbmpb.LoadBalanceRequest{ + LoadBalanceRequestType: &lbmpb.LoadBalanceRequest_ClientStats{ ClientStats: &stats, }, }); err != nil { @@ -305,7 +259,7 @@ func (b *balancer) sendLoadReport(s *balanceLoadClientStream, interval time.Dura } } -func (b *balancer) callRemoteBalancer(lbc *loadBalancerClient, seq int) (retry bool) { +func (b *grpclbBalancer) callRemoteBalancer(lbc *loadBalancerClient, seq int) (retry bool) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() stream, err := lbc.BalanceLoad(ctx) @@ -319,9 +273,9 @@ func (b *balancer) callRemoteBalancer(lbc *loadBalancerClient, seq int) (retry b return } b.mu.Unlock() - initReq := &lbpb.LoadBalanceRequest{ - LoadBalanceRequestType: &lbpb.LoadBalanceRequest_InitialRequest{ - InitialRequest: &lbpb.InitialLoadBalanceRequest{ + initReq := &lbmpb.LoadBalanceRequest{ + LoadBalanceRequestType: &lbmpb.LoadBalanceRequest_InitialRequest{ + InitialRequest: &lbmpb.InitialLoadBalanceRequest{ Name: b.target, }, }, @@ -351,7 +305,7 @@ func (b *balancer) callRemoteBalancer(lbc *loadBalancerClient, seq int) (retry b streamDone := make(chan struct{}) defer close(streamDone) b.mu.Lock() - b.clientStats = lbpb.ClientStats{} // Clear client stats. + b.clientStats = lbmpb.ClientStats{} // Clear client stats. b.mu.Unlock() if d := convertDuration(initResp.ClientStatsReportInterval); d > 0 { go b.sendLoadReport(stream, d, streamDone) @@ -378,7 +332,7 @@ func (b *balancer) callRemoteBalancer(lbc *loadBalancerClient, seq int) (retry b return true } -func (b *balancer) Start(target string, config BalancerConfig) error { +func (b *grpclbBalancer) Start(target string, config BalancerConfig) error { b.rand = rand.New(rand.NewSource(time.Now().Unix())) // TODO: Fall back to the basic direct connection if there is no name resolver. if b.r == nil { @@ -507,8 +461,11 @@ func (b *balancer) Start(target string, config BalancerConfig) error { // WithDialer takes a different type of function, so we instead use a special DialOption here. dopts = append(dopts, func(o *dialOptions) { o.copts.Dialer = dialer }) } + dopts = append(dopts, WithBlock()) ccError = make(chan struct{}) - cc, err = Dial(rb.addr, dopts...) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + cc, err = DialContext(ctx, rb.addr, dopts...) + cancel() if err != nil { grpclog.Warningf("grpclb: failed to setup a connection to the remote balancer %v: %v", rb.addr, err) close(ccError) @@ -534,7 +491,7 @@ func (b *balancer) Start(target string, config BalancerConfig) error { return nil } -func (b *balancer) down(addr Address, err error) { +func (b *grpclbBalancer) down(addr Address, err error) { b.mu.Lock() defer b.mu.Unlock() for _, a := range b.addrs { @@ -545,7 +502,7 @@ func (b *balancer) down(addr Address, err error) { } } -func (b *balancer) Up(addr Address) func(error) { +func (b *grpclbBalancer) Up(addr Address) func(error) { b.mu.Lock() defer b.mu.Unlock() if b.done { @@ -573,7 +530,7 @@ func (b *balancer) Up(addr Address) func(error) { } } -func (b *balancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) { +func (b *grpclbBalancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) { var ch chan struct{} b.mu.Lock() if b.done { @@ -643,17 +600,10 @@ func (b *balancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Addre } } if !opts.BlockingWait { - if len(b.addrs) == 0 { - b.clientStats.NumCallsFinished++ - b.clientStats.NumCallsFinishedWithClientFailedToSend++ - b.mu.Unlock() - err = Errorf(codes.Unavailable, "there is no address available") - return - } - // Returns the next addr on b.addrs for a failfast RPC. - addr = b.addrs[b.next].addr - b.next++ + b.clientStats.NumCallsFinished++ + b.clientStats.NumCallsFinishedWithClientFailedToSend++ b.mu.Unlock() + err = Errorf(codes.Unavailable, "there is no address available") return } // Wait on b.waitCh for non-failfast RPCs. @@ -730,20 +680,17 @@ func (b *balancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Addre } } -func (b *balancer) Notify() <-chan []Address { +func (b *grpclbBalancer) Notify() <-chan []Address { return b.addrCh } -func (b *balancer) Close() error { +func (b *grpclbBalancer) Close() error { b.mu.Lock() defer b.mu.Unlock() if b.done { return errBalancerClosed } b.done = true - if b.expTimer != nil { - b.expTimer.Stop() - } if b.waitCh != nil { close(b.waitCh) } diff --git a/vendor/google.golang.org/grpc/grpclb/grpc_lb_v1/grpclb.pb.go b/vendor/google.golang.org/grpc/grpclb/grpc_lb_v1/messages/messages.pb.go index f63941bd8..f4a27125a 100644 --- a/vendor/google.golang.org/grpc/grpclb/grpc_lb_v1/grpclb.pb.go +++ b/vendor/google.golang.org/grpc/grpclb/grpc_lb_v1/messages/messages.pb.go @@ -1,12 +1,11 @@ -// Code generated by protoc-gen-go. -// source: grpclb.proto -// DO NOT EDIT! +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: grpc_lb_v1/messages/messages.proto /* -Package grpc_lb_v1 is a generated protocol buffer package. +Package messages is a generated protocol buffer package. It is generated from these files: - grpclb.proto + grpc_lb_v1/messages/messages.proto It has these top-level messages: Duration @@ -19,7 +18,7 @@ It has these top-level messages: ServerList Server */ -package grpc_lb_v1 +package messages import proto "github.com/golang/protobuf/proto" import fmt "fmt" @@ -473,11 +472,6 @@ type ServerList struct { // across more servers. The client should consume the server list in order // unless instructed otherwise via the client_config. Servers []*Server `protobuf:"bytes,1,rep,name=servers" json:"servers,omitempty"` - // Indicates the amount of time that the client should consider this server - // list as valid. It may be considered stale after waiting this interval of - // time after receiving the list. If the interval is not positive, the - // client can assume the list is valid until the next list is received. - ExpirationInterval *Duration `protobuf:"bytes,3,opt,name=expiration_interval,json=expirationInterval" json:"expiration_interval,omitempty"` } func (m *ServerList) Reset() { *m = ServerList{} } @@ -492,13 +486,6 @@ func (m *ServerList) GetServers() []*Server { return nil } -func (m *ServerList) GetExpirationInterval() *Duration { - if m != nil { - return m.ExpirationInterval - } - return nil -} - // Contains server information. When none of the [drop_for_*] fields are true, // use the other fields. When drop_for_rate_limiting is true, ignore all other // fields. Use drop_for_load_balancing only when it is true and @@ -576,54 +563,53 @@ func init() { proto.RegisterType((*Server)(nil), "grpc.lb.v1.Server") } -func init() { proto.RegisterFile("grpclb.proto", fileDescriptor0) } +func init() { proto.RegisterFile("grpc_lb_v1/messages/messages.proto", fileDescriptor0) } var fileDescriptor0 = []byte{ - // 733 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x94, 0x55, 0xdd, 0x4e, 0x1b, 0x39, - 0x14, 0x66, 0x36, 0xfc, 0xe5, 0x24, 0x5a, 0x58, 0x93, 0x85, 0xc0, 0xc2, 0x2e, 0x1b, 0xa9, 0x34, - 0xaa, 0x68, 0x68, 0x43, 0x7b, 0xd1, 0x9f, 0x9b, 0x02, 0x45, 0x41, 0xe5, 0xa2, 0x72, 0xa8, 0x7a, - 0x55, 0x59, 0x4e, 0xc6, 0x80, 0xc5, 0xc4, 0x9e, 0xda, 0x4e, 0x68, 0x2f, 0x7b, 0xd9, 0x47, 0xe9, - 0x63, 0x54, 0x7d, 0x86, 0xbe, 0x4f, 0x65, 0x7b, 0x26, 0x33, 0x90, 0x1f, 0xd4, 0xbb, 0xf1, 0xf1, - 0x77, 0xbe, 0xf3, 0xf9, 0xd8, 0xdf, 0x19, 0x28, 0x5f, 0xa8, 0xb8, 0x1b, 0x75, 0x1a, 0xb1, 0x92, - 0x46, 0x22, 0xb0, 0xab, 0x46, 0xd4, 0x69, 0x0c, 0x1e, 0xd7, 0x9e, 0xc3, 0xe2, 0x51, 0x5f, 0x51, - 0xc3, 0xa5, 0x40, 0x55, 0x58, 0xd0, 0xac, 0x2b, 0x45, 0xa8, 0xab, 0xc1, 0x76, 0x50, 0x2f, 0xe0, - 0x74, 0x89, 0x2a, 0x30, 0x27, 0xa8, 0x90, 0xba, 0xfa, 0xc7, 0x76, 0x50, 0x9f, 0xc3, 0x7e, 0x51, - 0x7b, 0x01, 0xc5, 0x33, 0xde, 0x63, 0xda, 0xd0, 0x5e, 0xfc, 0xdb, 0xc9, 0xdf, 0x03, 0x40, 0xa7, - 0x92, 0x86, 0x07, 0x34, 0xa2, 0xa2, 0xcb, 0x30, 0xfb, 0xd8, 0x67, 0xda, 0xa0, 0xb7, 0xb0, 0xc4, - 0x05, 0x37, 0x9c, 0x46, 0x44, 0xf9, 0x90, 0xa3, 0x2b, 0x35, 0xef, 0x35, 0x32, 0xd5, 0x8d, 0x13, - 0x0f, 0x19, 0xcd, 0x6f, 0xcd, 0xe0, 0x3f, 0x93, 0xfc, 0x94, 0xf1, 0x25, 0x94, 0xbb, 0x11, 0x67, - 0xc2, 0x10, 0x6d, 0xa8, 0xf1, 0x2a, 0x4a, 0xcd, 0xb5, 0x3c, 0xdd, 0xa1, 0xdb, 0x6f, 0xdb, 0xed, - 0xd6, 0x0c, 0x2e, 0x75, 0xb3, 0xe5, 0xc1, 0x3f, 0xb0, 0x1e, 0x49, 0x1a, 0x92, 0x8e, 0x2f, 0x93, - 0x8a, 0x22, 0xe6, 0x73, 0xcc, 0x6a, 0x7b, 0xb0, 0x3e, 0x51, 0x09, 0x42, 0x30, 0x2b, 0x68, 0x8f, - 0x39, 0xf9, 0x45, 0xec, 0xbe, 0x6b, 0x5f, 0x67, 0xa1, 0x94, 0x2b, 0x86, 0xf6, 0xa1, 0x68, 0xd2, - 0x0e, 0x26, 0xe7, 0xfc, 0x3b, 0x2f, 0x6c, 0xd8, 0x5e, 0x9c, 0xe1, 0xd0, 0x03, 0xf8, 0x4b, 0xf4, - 0x7b, 0xa4, 0x4b, 0xa3, 0x48, 0xdb, 0x33, 0x29, 0xc3, 0x42, 0x77, 0xaa, 0x02, 0x5e, 0x12, 0xfd, - 0xde, 0xa1, 0x8d, 0xb7, 0x7d, 0x18, 0xed, 0x02, 0xca, 0xb0, 0xe7, 0x5c, 0x70, 0x7d, 0xc9, 0xc2, - 0x6a, 0xc1, 0x81, 0x97, 0x53, 0xf0, 0x71, 0x12, 0x47, 0x04, 0x1a, 0xa3, 0x68, 0x72, 0xcd, 0xcd, - 0x25, 0x09, 0x95, 0x8c, 0xc9, 0xb9, 0x54, 0x44, 0x51, 0xc3, 0x48, 0xc4, 0x7b, 0xdc, 0x70, 0x71, - 0x51, 0x9d, 0x75, 0x4c, 0xf7, 0x6f, 0x33, 0xbd, 0xe7, 0xe6, 0xf2, 0x48, 0xc9, 0xf8, 0x58, 0x2a, - 0x4c, 0x0d, 0x3b, 0x4d, 0xe0, 0x88, 0xc2, 0xde, 0x9d, 0x05, 0x72, 0xed, 0xb6, 0x15, 0xe6, 0x5c, - 0x85, 0xfa, 0x94, 0x0a, 0x59, 0xef, 0x6d, 0x89, 0x0f, 0xf0, 0x70, 0x52, 0x89, 0xe4, 0x19, 0x9c, - 0x53, 0x1e, 0xb1, 0x90, 0x18, 0x49, 0x34, 0x13, 0x61, 0x75, 0xde, 0x15, 0xd8, 0x19, 0x57, 0xc0, - 0x5f, 0xd5, 0xb1, 0xc3, 0x9f, 0xc9, 0x36, 0x13, 0x21, 0x6a, 0xc1, 0xff, 0x63, 0xe8, 0xaf, 0x84, - 0xbc, 0x16, 0x44, 0xb1, 0x2e, 0xe3, 0x03, 0x16, 0x56, 0x17, 0x1c, 0xe5, 0xd6, 0x6d, 0xca, 0x37, - 0x16, 0x85, 0x13, 0x50, 0xed, 0x47, 0x00, 0x2b, 0x37, 0x9e, 0x8d, 0x8e, 0xa5, 0xd0, 0x0c, 0xb5, - 0x61, 0x39, 0x73, 0x80, 0x8f, 0x25, 0x4f, 0x63, 0xe7, 0x2e, 0x0b, 0x78, 0x74, 0x6b, 0x06, 0x2f, - 0x0d, 0x3d, 0x90, 0x90, 0x3e, 0x83, 0x92, 0x66, 0x6a, 0xc0, 0x14, 0x89, 0xb8, 0x36, 0x89, 0x07, - 0x56, 0xf3, 0x7c, 0x6d, 0xb7, 0x7d, 0xca, 0x9d, 0x87, 0x40, 0x0f, 0x57, 0x07, 0x9b, 0xb0, 0x71, - 0xcb, 0x01, 0x9e, 0xd3, 0x5b, 0xe0, 0x5b, 0x00, 0x1b, 0x93, 0xa5, 0xa0, 0x27, 0xb0, 0x9a, 0x4f, - 0x56, 0x24, 0x64, 0x11, 0xbb, 0xa0, 0x26, 0xb5, 0x45, 0x25, 0xca, 0x92, 0xd4, 0x51, 0xb2, 0x87, - 0xde, 0xc1, 0x66, 0xde, 0xb2, 0x44, 0xb1, 0x58, 0x2a, 0x43, 0xb8, 0x30, 0x4c, 0x0d, 0x68, 0x94, - 0xc8, 0xaf, 0xe4, 0xe5, 0xa7, 0x43, 0x0c, 0xaf, 0xe7, 0xdc, 0x8b, 0x5d, 0xde, 0x49, 0x92, 0x56, - 0xfb, 0x12, 0x00, 0x64, 0xc7, 0x44, 0xbb, 0x76, 0x62, 0xd9, 0x95, 0x9d, 0x58, 0x85, 0x7a, 0xa9, - 0x89, 0x46, 0xfb, 0x81, 0x53, 0x08, 0x7a, 0x0d, 0x2b, 0xec, 0x53, 0xcc, 0x7d, 0x95, 0x4c, 0x4a, - 0x61, 0x8a, 0x14, 0x94, 0x25, 0x0c, 0x35, 0xfc, 0x0c, 0x60, 0xde, 0x53, 0xa3, 0x2d, 0x00, 0x1e, - 0x13, 0x1a, 0x86, 0x8a, 0x69, 0x3f, 0x34, 0xcb, 0xb8, 0xc8, 0xe3, 0x57, 0x3e, 0x60, 0xe7, 0x87, - 0x55, 0x9f, 0x4c, 0x4d, 0xf7, 0x6d, 0xed, 0x7c, 0xe3, 0x2e, 0x8c, 0xbc, 0x62, 0xc2, 0x69, 0x28, - 0xe2, 0xe5, 0x5c, 0x2b, 0xcf, 0x6c, 0x1c, 0xed, 0xc3, 0xea, 0x14, 0xdb, 0x2e, 0xe2, 0x95, 0x70, - 0x8c, 0x45, 0x9f, 0xc2, 0xda, 0x34, 0x2b, 0x2e, 0xe2, 0x4a, 0x38, 0xc6, 0x76, 0xcd, 0x0e, 0x94, - 0x73, 0xf7, 0xaf, 0x10, 0x86, 0x52, 0xf2, 0x6d, 0xc3, 0xe8, 0xdf, 0x7c, 0x83, 0x46, 0x87, 0xe5, - 0xc6, 0x7f, 0x13, 0xf7, 0xfd, 0x43, 0xaa, 0x07, 0x8f, 0x82, 0xce, 0xbc, 0xfb, 0x7d, 0xed, 0xff, - 0x0a, 0x00, 0x00, 0xff, 0xff, 0x64, 0xbf, 0xda, 0x5e, 0xce, 0x06, 0x00, 0x00, + // 709 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x94, 0x55, 0xdd, 0x4e, 0x1b, 0x3b, + 0x10, 0x26, 0x27, 0x01, 0x92, 0x09, 0x3a, 0xe4, 0x98, 0x1c, 0x08, 0x14, 0x24, 0xba, 0x52, 0x69, + 0x54, 0xd1, 0x20, 0xa0, 0xbd, 0xe8, 0xcf, 0x45, 0x1b, 0x10, 0x0a, 0x2d, 0x17, 0x95, 0x43, 0x55, + 0xa9, 0x52, 0x65, 0x39, 0xd9, 0x21, 0x58, 0x6c, 0xec, 0xad, 0xed, 0x04, 0xf5, 0x11, 0xfa, 0x28, + 0x7d, 0x8c, 0xaa, 0xcf, 0xd0, 0xf7, 0xa9, 0xd6, 0xbb, 0x9b, 0x5d, 0x20, 0x80, 0x7a, 0x67, 0x8f, + 0xbf, 0xf9, 0xbe, 0xf1, 0xac, 0xbf, 0x59, 0xf0, 0x06, 0x3a, 0xec, 0xb3, 0xa0, 0xc7, 0xc6, 0xbb, + 0x3b, 0x43, 0x34, 0x86, 0x0f, 0xd0, 0x4c, 0x16, 0xad, 0x50, 0x2b, 0xab, 0x08, 0x44, 0x98, 0x56, + 0xd0, 0x6b, 0x8d, 0x77, 0xbd, 0x97, 0x50, 0x3e, 0x1c, 0x69, 0x6e, 0x85, 0x92, 0xa4, 0x01, 0xf3, + 0x06, 0xfb, 0x4a, 0xfa, 0xa6, 0x51, 0xd8, 0x2c, 0x34, 0x8b, 0x34, 0xdd, 0x92, 0x3a, 0xcc, 0x4a, + 0x2e, 0x95, 0x69, 0xfc, 0xb3, 0x59, 0x68, 0xce, 0xd2, 0x78, 0xe3, 0xbd, 0x82, 0xca, 0xa9, 0x18, + 0xa2, 0xb1, 0x7c, 0x18, 0xfe, 0x75, 0xf2, 0xcf, 0x02, 0x90, 0x13, 0xc5, 0xfd, 0x36, 0x0f, 0xb8, + 0xec, 0x23, 0xc5, 0xaf, 0x23, 0x34, 0x96, 0x7c, 0x80, 0x45, 0x21, 0x85, 0x15, 0x3c, 0x60, 0x3a, + 0x0e, 0x39, 0xba, 0xea, 0xde, 0xa3, 0x56, 0x56, 0x75, 0xeb, 0x38, 0x86, 0xdc, 0xcc, 0xef, 0xcc, + 0xd0, 0x7f, 0x93, 0xfc, 0x94, 0xf1, 0x35, 0x2c, 0xf4, 0x03, 0x81, 0xd2, 0x32, 0x63, 0xb9, 0x8d, + 0xab, 0xa8, 0xee, 0xad, 0xe4, 0xe9, 0x0e, 0xdc, 0x79, 0x37, 0x3a, 0xee, 0xcc, 0xd0, 0x6a, 0x3f, + 0xdb, 0xb6, 0x1f, 0xc0, 0x6a, 0xa0, 0xb8, 0xcf, 0x7a, 0xb1, 0x4c, 0x5a, 0x14, 0xb3, 0xdf, 0x42, + 0xf4, 0x76, 0x60, 0xf5, 0xd6, 0x4a, 0x08, 0x81, 0x92, 0xe4, 0x43, 0x74, 0xe5, 0x57, 0xa8, 0x5b, + 0x7b, 0xdf, 0x4b, 0x50, 0xcd, 0x89, 0x91, 0x7d, 0xa8, 0xd8, 0xb4, 0x83, 0xc9, 0x3d, 0xff, 0xcf, + 0x17, 0x36, 0x69, 0x2f, 0xcd, 0x70, 0xe4, 0x09, 0xfc, 0x27, 0x47, 0x43, 0xd6, 0xe7, 0x41, 0x60, + 0xa2, 0x3b, 0x69, 0x8b, 0xbe, 0xbb, 0x55, 0x91, 0x2e, 0xca, 0xd1, 0xf0, 0x20, 0x8a, 0x77, 0xe3, + 0x30, 0xd9, 0x06, 0x92, 0x61, 0xcf, 0x84, 0x14, 0xe6, 0x1c, 0xfd, 0x46, 0xd1, 0x81, 0x6b, 0x29, + 0xf8, 0x28, 0x89, 0x13, 0x06, 0xad, 0x9b, 0x68, 0x76, 0x29, 0xec, 0x39, 0xf3, 0xb5, 0x0a, 0xd9, + 0x99, 0xd2, 0x4c, 0x73, 0x8b, 0x2c, 0x10, 0x43, 0x61, 0x85, 0x1c, 0x34, 0x4a, 0x8e, 0xe9, 0xf1, + 0x75, 0xa6, 0x4f, 0xc2, 0x9e, 0x1f, 0x6a, 0x15, 0x1e, 0x29, 0x4d, 0xb9, 0xc5, 0x93, 0x04, 0x4e, + 0x38, 0xec, 0xdc, 0x2b, 0x90, 0x6b, 0x77, 0xa4, 0x30, 0xeb, 0x14, 0x9a, 0x77, 0x28, 0x64, 0xbd, + 0x8f, 0x24, 0xbe, 0xc0, 0xd3, 0xdb, 0x24, 0x92, 0x67, 0x70, 0xc6, 0x45, 0x80, 0x3e, 0xb3, 0x8a, + 0x19, 0x94, 0x7e, 0x63, 0xce, 0x09, 0x6c, 0x4d, 0x13, 0x88, 0x3f, 0xd5, 0x91, 0xc3, 0x9f, 0xaa, + 0x2e, 0x4a, 0x9f, 0x74, 0xe0, 0xe1, 0x14, 0xfa, 0x0b, 0xa9, 0x2e, 0x25, 0xd3, 0xd8, 0x47, 0x31, + 0x46, 0xbf, 0x31, 0xef, 0x28, 0x37, 0xae, 0x53, 0xbe, 0x8f, 0x50, 0x34, 0x01, 0x79, 0xbf, 0x0a, + 0xb0, 0x74, 0xe5, 0xd9, 0x98, 0x50, 0x49, 0x83, 0xa4, 0x0b, 0xb5, 0xcc, 0x01, 0x71, 0x2c, 0x79, + 0x1a, 0x5b, 0xf7, 0x59, 0x20, 0x46, 0x77, 0x66, 0xe8, 0xe2, 0xc4, 0x03, 0x09, 0xe9, 0x0b, 0xa8, + 0x1a, 0xd4, 0x63, 0xd4, 0x2c, 0x10, 0xc6, 0x26, 0x1e, 0x58, 0xce, 0xf3, 0x75, 0xdd, 0xf1, 0x89, + 0x70, 0x1e, 0x02, 0x33, 0xd9, 0xb5, 0xd7, 0x61, 0xed, 0x9a, 0x03, 0x62, 0xce, 0xd8, 0x02, 0x3f, + 0x0a, 0xb0, 0x76, 0x7b, 0x29, 0xe4, 0x19, 0x2c, 0xe7, 0x93, 0x35, 0xf3, 0x31, 0xc0, 0x01, 0xb7, + 0xa9, 0x2d, 0xea, 0x41, 0x96, 0xa4, 0x0f, 0x93, 0x33, 0xf2, 0x11, 0xd6, 0xf3, 0x96, 0x65, 0x1a, + 0x43, 0xa5, 0x2d, 0x13, 0xd2, 0xa2, 0x1e, 0xf3, 0x20, 0x29, 0xbf, 0x9e, 0x2f, 0x3f, 0x1d, 0x62, + 0x74, 0x35, 0xe7, 0x5e, 0xea, 0xf2, 0x8e, 0x93, 0x34, 0xef, 0x0d, 0x40, 0x76, 0x4b, 0xb2, 0x1d, + 0x0d, 0xac, 0x68, 0x17, 0x0d, 0xac, 0x62, 0xb3, 0xba, 0x47, 0x6e, 0xb6, 0x83, 0xa6, 0x90, 0x77, + 0xa5, 0x72, 0xb1, 0x56, 0xf2, 0x7e, 0x17, 0x60, 0x2e, 0x3e, 0x21, 0x1b, 0x00, 0x22, 0x64, 0xdc, + 0xf7, 0x35, 0x9a, 0x78, 0xe4, 0x2d, 0xd0, 0x8a, 0x08, 0xdf, 0xc6, 0x81, 0xc8, 0xfd, 0x91, 0x76, + 0x32, 0xf3, 0xdc, 0x3a, 0x32, 0xe3, 0x95, 0x4e, 0x5a, 0x75, 0x81, 0xd2, 0x99, 0xb1, 0x42, 0x6b, + 0xb9, 0x46, 0x9c, 0x46, 0x71, 0xb2, 0x0f, 0xcb, 0x77, 0x98, 0xae, 0x4c, 0x97, 0xfc, 0x29, 0x06, + 0x7b, 0x0e, 0x2b, 0x77, 0x19, 0xa9, 0x4c, 0xeb, 0xfe, 0x14, 0xd3, 0xb4, 0xe1, 0x73, 0x39, 0xfd, + 0x47, 0xf4, 0xe6, 0xdc, 0x4f, 0x62, 0xff, 0x4f, 0x00, 0x00, 0x00, 0xff, 0xff, 0xa3, 0x36, 0x86, + 0xa6, 0x4a, 0x06, 0x00, 0x00, } diff --git a/vendor/google.golang.org/grpc/grpclb/grpc_lb_v1/grpclb.proto b/vendor/google.golang.org/grpc/grpclb/grpc_lb_v1/messages/messages.proto index b13b3438c..2ed04551f 100644 --- a/vendor/google.golang.org/grpc/grpclb/grpc_lb_v1/grpclb.proto +++ b/vendor/google.golang.org/grpc/grpclb/grpc_lb_v1/messages/messages.proto @@ -15,6 +15,7 @@ syntax = "proto3"; package grpc.lb.v1; +option go_package = "messages"; message Duration { // Signed seconds of the span of time. Must be from -315,576,000,000 @@ -31,7 +32,6 @@ message Duration { } message Timestamp { - // Represents seconds of UTC time since Unix epoch // 1970-01-01T00:00:00Z. Must be from 0001-01-01T00:00:00Z to // 9999-12-31T23:59:59Z inclusive. @@ -44,12 +44,6 @@ message Timestamp { int32 nanos = 2; } -service LoadBalancer { - // Bidirectional rpc to get a list of servers. - rpc BalanceLoad(stream LoadBalanceRequest) - returns (stream LoadBalanceResponse); -} - message LoadBalanceRequest { oneof load_balance_request_type { // This message should be sent on the first request to the load balancer. @@ -127,11 +121,8 @@ message ServerList { // unless instructed otherwise via the client_config. repeated Server servers = 1; - // Indicates the amount of time that the client should consider this server - // list as valid. It may be considered stale after waiting this interval of - // time after receiving the list. If the interval is not positive, the - // client can assume the list is valid until the next list is received. - Duration expiration_interval = 3; + // Was google.protobuf.Duration expiration_interval. + reserved 3; } // Contains server information. When none of the [drop_for_*] fields are true, diff --git a/vendor/google.golang.org/grpc/grpclb/grpclb_server_generated.go b/vendor/google.golang.org/grpc/grpclb/grpclb_server_generated.go deleted file mode 100644 index 3cf1ac858..000000000 --- a/vendor/google.golang.org/grpc/grpclb/grpclb_server_generated.go +++ /dev/null @@ -1,72 +0,0 @@ -/* - * - * Copyright 2017 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -// This file contains the generated server side code. -// It's only used for grpclb testing. - -package grpclb - -import ( - "google.golang.org/grpc" - lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1" -) - -// Server API for LoadBalancer service - -type loadBalancerServer interface { - // Bidirectional rpc to get a list of servers. - BalanceLoad(*loadBalancerBalanceLoadServer) error -} - -func registerLoadBalancerServer(s *grpc.Server, srv loadBalancerServer) { - s.RegisterService( - &grpc.ServiceDesc{ - ServiceName: "grpc.lb.v1.LoadBalancer", - HandlerType: (*loadBalancerServer)(nil), - Methods: []grpc.MethodDesc{}, - Streams: []grpc.StreamDesc{ - { - StreamName: "BalanceLoad", - Handler: balanceLoadHandler, - ServerStreams: true, - ClientStreams: true, - }, - }, - Metadata: "grpclb.proto", - }, srv) -} - -func balanceLoadHandler(srv interface{}, stream grpc.ServerStream) error { - return srv.(loadBalancerServer).BalanceLoad(&loadBalancerBalanceLoadServer{stream}) -} - -type loadBalancerBalanceLoadServer struct { - grpc.ServerStream -} - -func (x *loadBalancerBalanceLoadServer) Send(m *lbpb.LoadBalanceResponse) error { - return x.ServerStream.SendMsg(m) -} - -func (x *loadBalancerBalanceLoadServer) Recv() (*lbpb.LoadBalanceRequest, error) { - m := new(lbpb.LoadBalanceRequest) - if err := x.ServerStream.RecvMsg(m); err != nil { - return nil, err - } - return m, nil -} diff --git a/vendor/google.golang.org/grpc/grpclb/grpclb_test.go b/vendor/google.golang.org/grpc/grpclb/grpclb_test.go index d58535041..46c1fe5b9 100644 --- a/vendor/google.golang.org/grpc/grpclb/grpclb_test.go +++ b/vendor/google.golang.org/grpc/grpclb/grpclb_test.go @@ -16,8 +16,11 @@ * */ -// Package grpclb is currently used only for grpclb testing. -package grpclb +//go:generate protoc --go_out=plugins=:. grpc_lb_v1/messages/messages.proto +//go:generate protoc --go_out=Mgrpc_lb_v1/messages/messages.proto=google.golang.org/grpc/grpclb/grpc_lb_v1/messages,plugins=grpc:. grpc_lb_v1/service/service.proto + +// Package grpclb_test is currently used only for grpclb testing. +package grpclb_test import ( "errors" @@ -34,10 +37,13 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" - lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1" + lbmpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages" + lbspb "google.golang.org/grpc/grpclb/grpc_lb_v1/service" + _ "google.golang.org/grpc/grpclog/glogger" "google.golang.org/grpc/metadata" "google.golang.org/grpc/naming" testpb "google.golang.org/grpc/test/grpc_testing" + "google.golang.org/grpc/test/leakcheck" ) var ( @@ -82,6 +88,7 @@ func (w *testWatcher) Next() (updates []*naming.Update, err error) { } func (w *testWatcher) Close() { + close(w.side) } // Inject naming resolution updates to the testWatcher. @@ -109,8 +116,8 @@ func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) { r.w.update <- &naming.Update{ Op: naming.Add, Addr: addr, - Metadata: &grpc.AddrMetadataGRPCLB{ - AddrType: grpc.GRPCLB, + Metadata: &naming.AddrMetadataGRPCLB{ + AddrType: naming.GRPCLB, ServerName: lbsn, }, } @@ -181,15 +188,15 @@ func fakeNameDialer(addr string, timeout time.Duration) (net.Conn, error) { } type remoteBalancer struct { - sls []*lbpb.ServerList + sls []*lbmpb.ServerList intervals []time.Duration statsDura time.Duration done chan struct{} mu sync.Mutex - stats lbpb.ClientStats + stats lbmpb.ClientStats } -func newRemoteBalancer(sls []*lbpb.ServerList, intervals []time.Duration) *remoteBalancer { +func newRemoteBalancer(sls []*lbmpb.ServerList, intervals []time.Duration) *remoteBalancer { return &remoteBalancer{ sls: sls, intervals: intervals, @@ -201,7 +208,7 @@ func (b *remoteBalancer) stop() { close(b.done) } -func (b *remoteBalancer) BalanceLoad(stream *loadBalancerBalanceLoadServer) error { +func (b *remoteBalancer) BalanceLoad(stream lbspb.LoadBalancer_BalanceLoadServer) error { req, err := stream.Recv() if err != nil { return err @@ -210,10 +217,10 @@ func (b *remoteBalancer) BalanceLoad(stream *loadBalancerBalanceLoadServer) erro if initReq.Name != besn { return grpc.Errorf(codes.InvalidArgument, "invalid service name: %v", initReq.Name) } - resp := &lbpb.LoadBalanceResponse{ - LoadBalanceResponseType: &lbpb.LoadBalanceResponse_InitialResponse{ - InitialResponse: &lbpb.InitialLoadBalanceResponse{ - ClientStatsReportInterval: &lbpb.Duration{ + resp := &lbmpb.LoadBalanceResponse{ + LoadBalanceResponseType: &lbmpb.LoadBalanceResponse_InitialResponse{ + InitialResponse: &lbmpb.InitialLoadBalanceResponse{ + ClientStatsReportInterval: &lbmpb.Duration{ Seconds: int64(b.statsDura.Seconds()), Nanos: int32(b.statsDura.Nanoseconds() - int64(b.statsDura.Seconds())*1e9), }, @@ -226,7 +233,7 @@ func (b *remoteBalancer) BalanceLoad(stream *loadBalancerBalanceLoadServer) erro go func() { for { var ( - req *lbpb.LoadBalanceRequest + req *lbmpb.LoadBalanceRequest err error ) if req, err = stream.Recv(); err != nil { @@ -244,8 +251,8 @@ func (b *remoteBalancer) BalanceLoad(stream *loadBalancerBalanceLoadServer) erro }() for k, v := range b.sls { time.Sleep(b.intervals[k]) - resp = &lbpb.LoadBalanceResponse{ - LoadBalanceResponseType: &lbpb.LoadBalanceResponse_ServerList{ + resp = &lbmpb.LoadBalanceResponse{ + LoadBalanceResponseType: &lbmpb.LoadBalanceResponse_ServerList{ ServerList: v, }, } @@ -347,7 +354,7 @@ func newLoadBalancer(numberOfBackends int) (tss *testServers, cleanup func(), er return } ls = newRemoteBalancer(nil, nil) - registerLoadBalancerServer(lb, ls) + lbspb.RegisterLoadBalancerServer(lb, ls) go func() { lb.Serve(lbLis) }() @@ -370,23 +377,24 @@ func newLoadBalancer(numberOfBackends int) (tss *testServers, cleanup func(), er } func TestGRPCLB(t *testing.T) { + defer leakcheck.Check(t) tss, cleanup, err := newLoadBalancer(1) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } defer cleanup() - be := &lbpb.Server{ + be := &lbmpb.Server{ IpAddress: tss.beIPs[0], Port: int32(tss.bePorts[0]), LoadBalanceToken: lbToken, } - var bes []*lbpb.Server + var bes []*lbmpb.Server bes = append(bes, be) - sl := &lbpb.ServerList{ + sl := &lbmpb.ServerList{ Servers: bes, } - tss.ls.sls = []*lbpb.ServerList{sl} + tss.ls.sls = []*lbmpb.ServerList{sl} tss.ls.intervals = []time.Duration{0} creds := serverNameCheckCreds{ expected: besn, @@ -399,21 +407,22 @@ func TestGRPCLB(t *testing.T) { if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } + defer cc.Close() testC := testpb.NewTestServiceClient(cc) - if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) } - cc.Close() } func TestDropRequest(t *testing.T) { + defer leakcheck.Check(t) tss, cleanup, err := newLoadBalancer(2) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } defer cleanup() - tss.ls.sls = []*lbpb.ServerList{{ - Servers: []*lbpb.Server{{ + tss.ls.sls = []*lbmpb.ServerList{{ + Servers: []*lbmpb.Server{{ IpAddress: tss.beIPs[0], Port: int32(tss.bePorts[0]), LoadBalanceToken: lbToken, @@ -437,7 +446,17 @@ func TestDropRequest(t *testing.T) { if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } + defer cc.Close() testC := testpb.NewTestServiceClient(cc) + // Wait until the first connection is up. + // The first one has Drop set to true, error should contain "drop requests". + for { + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { + if strings.Contains(err.Error(), "drops requests") { + break + } + } + } // The 1st, non-fail-fast RPC should succeed. This ensures both server // connections are made, because the first one has DropForLoadBalancing set to true. if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { @@ -455,27 +474,27 @@ func TestDropRequest(t *testing.T) { t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) } } - cc.Close() } func TestDropRequestFailedNonFailFast(t *testing.T) { + defer leakcheck.Check(t) tss, cleanup, err := newLoadBalancer(1) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } defer cleanup() - be := &lbpb.Server{ + be := &lbmpb.Server{ IpAddress: tss.beIPs[0], Port: int32(tss.bePorts[0]), LoadBalanceToken: lbToken, DropForLoadBalancing: true, } - var bes []*lbpb.Server + var bes []*lbmpb.Server bes = append(bes, be) - sl := &lbpb.ServerList{ + sl := &lbmpb.ServerList{ Servers: bes, } - tss.ls.sls = []*lbpb.ServerList{sl} + tss.ls.sls = []*lbmpb.ServerList{sl} tss.ls.intervals = []time.Duration{0} creds := serverNameCheckCreds{ expected: besn, @@ -488,77 +507,18 @@ func TestDropRequestFailedNonFailFast(t *testing.T) { if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } + defer cc.Close() testC := testpb.NewTestServiceClient(cc) ctx, cancel = context.WithTimeout(context.Background(), 10*time.Millisecond) defer cancel() if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded { t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.DeadlineExceeded) } - cc.Close() -} - -func TestServerExpiration(t *testing.T) { - tss, cleanup, err := newLoadBalancer(1) - if err != nil { - t.Fatalf("failed to create new load balancer: %v", err) - } - defer cleanup() - be := &lbpb.Server{ - IpAddress: tss.beIPs[0], - Port: int32(tss.bePorts[0]), - LoadBalanceToken: lbToken, - } - var bes []*lbpb.Server - bes = append(bes, be) - exp := &lbpb.Duration{ - Seconds: 0, - Nanos: 100000000, // 100ms - } - var sls []*lbpb.ServerList - sl := &lbpb.ServerList{ - Servers: bes, - ExpirationInterval: exp, - } - sls = append(sls, sl) - sl = &lbpb.ServerList{ - Servers: bes, - } - sls = append(sls, sl) - var intervals []time.Duration - intervals = append(intervals, 0) - intervals = append(intervals, 500*time.Millisecond) - tss.ls.sls = sls - tss.ls.intervals = intervals - creds := serverNameCheckCreds{ - expected: besn, - } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, besn, - grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})), - grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) - if err != nil { - t.Fatalf("Failed to dial to the backend %v", err) - } - testC := testpb.NewTestServiceClient(cc) - if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { - t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) - } - // Sleep and wake up when the first server list gets expired. - time.Sleep(150 * time.Millisecond) - if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable { - t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable) - } - // A non-failfast rpc should be succeeded after the second server list is received from - // the remote load balancer. - if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { - t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) - } - cc.Close() } // When the balancer in use disconnects, grpclb should connect to the next address from resolved balancer address list. func TestBalancerDisconnects(t *testing.T) { + defer leakcheck.Check(t) var ( lbAddrs []string lbs []*grpc.Server @@ -570,17 +530,17 @@ func TestBalancerDisconnects(t *testing.T) { } defer cleanup() - be := &lbpb.Server{ + be := &lbmpb.Server{ IpAddress: tss.beIPs[0], Port: int32(tss.bePorts[0]), LoadBalanceToken: lbToken, } - var bes []*lbpb.Server + var bes []*lbmpb.Server bes = append(bes, be) - sl := &lbpb.ServerList{ + sl := &lbmpb.ServerList{ Servers: bes, } - tss.ls.sls = []*lbpb.ServerList{sl} + tss.ls.sls = []*lbmpb.ServerList{sl} tss.ls.intervals = []time.Duration{0} lbAddrs = append(lbAddrs, tss.lbAddr) @@ -601,6 +561,7 @@ func TestBalancerDisconnects(t *testing.T) { if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } + defer cc.Close() testC := testpb.NewTestServiceClient(cc) var previousTrailer string trailer := metadata.MD{} @@ -627,8 +588,8 @@ func TestBalancerDisconnects(t *testing.T) { resolver.inject([]*naming.Update{ {Op: naming.Add, Addr: lbAddrs[2], - Metadata: &grpc.AddrMetadataGRPCLB{ - AddrType: grpc.GRPCLB, + Metadata: &naming.AddrMetadataGRPCLB{ + AddrType: naming.GRPCLB, ServerName: lbsn, }, }, @@ -645,7 +606,6 @@ func TestBalancerDisconnects(t *testing.T) { } time.Sleep(100 * time.Millisecond) } - cc.Close() } type failPreRPCCred struct{} @@ -661,21 +621,21 @@ func (failPreRPCCred) RequireTransportSecurity() bool { return false } -func checkStats(stats *lbpb.ClientStats, expected *lbpb.ClientStats) error { +func checkStats(stats *lbmpb.ClientStats, expected *lbmpb.ClientStats) error { if !proto.Equal(stats, expected) { return fmt.Errorf("stats not equal: got %+v, want %+v", stats, expected) } return nil } -func runAndGetStats(t *testing.T, dropForLoadBalancing, dropForRateLimiting bool, runRPCs func(*grpc.ClientConn)) lbpb.ClientStats { +func runAndGetStats(t *testing.T, dropForLoadBalancing, dropForRateLimiting bool, runRPCs func(*grpc.ClientConn)) lbmpb.ClientStats { tss, cleanup, err := newLoadBalancer(3) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } defer cleanup() - tss.ls.sls = []*lbpb.ServerList{{ - Servers: []*lbpb.Server{{ + tss.ls.sls = []*lbmpb.ServerList{{ + Servers: []*lbmpb.Server{{ IpAddress: tss.beIPs[2], Port: int32(tss.bePorts[2]), LoadBalanceToken: lbToken, @@ -709,6 +669,7 @@ func runAndGetStats(t *testing.T, dropForLoadBalancing, dropForRateLimiting bool const countRPC = 40 func TestGRPCLBStatsUnarySuccess(t *testing.T) { + defer leakcheck.Check(t) stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) { testC := testpb.NewTestServiceClient(cc) // The first non-failfast RPC succeeds, all connections are up. @@ -720,7 +681,7 @@ func TestGRPCLBStatsUnarySuccess(t *testing.T) { } }) - if err := checkStats(&stats, &lbpb.ClientStats{ + if err := checkStats(&stats, &lbmpb.ClientStats{ NumCallsStarted: int64(countRPC), NumCallsFinished: int64(countRPC), NumCallsFinishedKnownReceived: int64(countRPC), @@ -730,6 +691,7 @@ func TestGRPCLBStatsUnarySuccess(t *testing.T) { } func TestGRPCLBStatsUnaryDropLoadBalancing(t *testing.T) { + defer leakcheck.Check(t) c := 0 stats := runAndGetStats(t, true, false, func(cc *grpc.ClientConn) { testC := testpb.NewTestServiceClient(cc) @@ -746,7 +708,7 @@ func TestGRPCLBStatsUnaryDropLoadBalancing(t *testing.T) { } }) - if err := checkStats(&stats, &lbpb.ClientStats{ + if err := checkStats(&stats, &lbmpb.ClientStats{ NumCallsStarted: int64(countRPC + c), NumCallsFinished: int64(countRPC + c), NumCallsFinishedWithDropForLoadBalancing: int64(countRPC + 1), @@ -757,6 +719,7 @@ func TestGRPCLBStatsUnaryDropLoadBalancing(t *testing.T) { } func TestGRPCLBStatsUnaryDropRateLimiting(t *testing.T) { + defer leakcheck.Check(t) c := 0 stats := runAndGetStats(t, false, true, func(cc *grpc.ClientConn) { testC := testpb.NewTestServiceClient(cc) @@ -773,7 +736,7 @@ func TestGRPCLBStatsUnaryDropRateLimiting(t *testing.T) { } }) - if err := checkStats(&stats, &lbpb.ClientStats{ + if err := checkStats(&stats, &lbmpb.ClientStats{ NumCallsStarted: int64(countRPC + c), NumCallsFinished: int64(countRPC + c), NumCallsFinishedWithDropForRateLimiting: int64(countRPC + 1), @@ -784,6 +747,7 @@ func TestGRPCLBStatsUnaryDropRateLimiting(t *testing.T) { } func TestGRPCLBStatsUnaryFailedToSend(t *testing.T) { + defer leakcheck.Check(t) stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) { testC := testpb.NewTestServiceClient(cc) // The first non-failfast RPC succeeds, all connections are up. @@ -795,7 +759,7 @@ func TestGRPCLBStatsUnaryFailedToSend(t *testing.T) { } }) - if err := checkStats(&stats, &lbpb.ClientStats{ + if err := checkStats(&stats, &lbmpb.ClientStats{ NumCallsStarted: int64(countRPC), NumCallsFinished: int64(countRPC), NumCallsFinishedWithClientFailedToSend: int64(countRPC - 1), @@ -806,6 +770,7 @@ func TestGRPCLBStatsUnaryFailedToSend(t *testing.T) { } func TestGRPCLBStatsStreamingSuccess(t *testing.T) { + defer leakcheck.Check(t) stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) { testC := testpb.NewTestServiceClient(cc) // The first non-failfast RPC succeeds, all connections are up. @@ -831,7 +796,7 @@ func TestGRPCLBStatsStreamingSuccess(t *testing.T) { } }) - if err := checkStats(&stats, &lbpb.ClientStats{ + if err := checkStats(&stats, &lbmpb.ClientStats{ NumCallsStarted: int64(countRPC), NumCallsFinished: int64(countRPC), NumCallsFinishedKnownReceived: int64(countRPC), @@ -841,6 +806,7 @@ func TestGRPCLBStatsStreamingSuccess(t *testing.T) { } func TestGRPCLBStatsStreamingDropLoadBalancing(t *testing.T) { + defer leakcheck.Check(t) c := 0 stats := runAndGetStats(t, true, false, func(cc *grpc.ClientConn) { testC := testpb.NewTestServiceClient(cc) @@ -857,7 +823,7 @@ func TestGRPCLBStatsStreamingDropLoadBalancing(t *testing.T) { } }) - if err := checkStats(&stats, &lbpb.ClientStats{ + if err := checkStats(&stats, &lbmpb.ClientStats{ NumCallsStarted: int64(countRPC + c), NumCallsFinished: int64(countRPC + c), NumCallsFinishedWithDropForLoadBalancing: int64(countRPC + 1), @@ -868,6 +834,7 @@ func TestGRPCLBStatsStreamingDropLoadBalancing(t *testing.T) { } func TestGRPCLBStatsStreamingDropRateLimiting(t *testing.T) { + defer leakcheck.Check(t) c := 0 stats := runAndGetStats(t, false, true, func(cc *grpc.ClientConn) { testC := testpb.NewTestServiceClient(cc) @@ -884,7 +851,7 @@ func TestGRPCLBStatsStreamingDropRateLimiting(t *testing.T) { } }) - if err := checkStats(&stats, &lbpb.ClientStats{ + if err := checkStats(&stats, &lbmpb.ClientStats{ NumCallsStarted: int64(countRPC + c), NumCallsFinished: int64(countRPC + c), NumCallsFinishedWithDropForRateLimiting: int64(countRPC + 1), @@ -895,6 +862,7 @@ func TestGRPCLBStatsStreamingDropRateLimiting(t *testing.T) { } func TestGRPCLBStatsStreamingFailedToSend(t *testing.T) { + defer leakcheck.Check(t) stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) { testC := testpb.NewTestServiceClient(cc) // The first non-failfast RPC succeeds, all connections are up. @@ -912,7 +880,7 @@ func TestGRPCLBStatsStreamingFailedToSend(t *testing.T) { } }) - if err := checkStats(&stats, &lbpb.ClientStats{ + if err := checkStats(&stats, &lbmpb.ClientStats{ NumCallsStarted: int64(countRPC), NumCallsFinished: int64(countRPC), NumCallsFinishedWithClientFailedToSend: int64(countRPC - 1), diff --git a/vendor/google.golang.org/grpc/grpclog/grpclog.go b/vendor/google.golang.org/grpc/grpclog/grpclog.go index 73d117097..16a7d8886 100644 --- a/vendor/google.golang.org/grpc/grpclog/grpclog.go +++ b/vendor/google.golang.org/grpc/grpclog/grpclog.go @@ -1,33 +1,18 @@ /* * - * Copyright 2017, Google Inc. - * All rights reserved. + * Copyright 2017 gRPC authors. * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above - * copyright notice, this list of conditions and the following disclaimer - * in the documentation and/or other materials provided with the - * distribution. - * * Neither the name of Google Inc. nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. + * http://www.apache.org/licenses/LICENSE-2.0 * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. * */ @@ -99,6 +84,7 @@ func Errorln(args ...interface{}) { // It calls os.Exit() with exit code 1. func Fatal(args ...interface{}) { logger.Fatal(args...) + // Make sure fatal logs will exit. os.Exit(1) } @@ -106,6 +92,7 @@ func Fatal(args ...interface{}) { // It calles os.Exit() with exit code 1. func Fatalf(format string, args ...interface{}) { logger.Fatalf(format, args...) + // Make sure fatal logs will exit. os.Exit(1) } @@ -113,6 +100,7 @@ func Fatalf(format string, args ...interface{}) { // It calle os.Exit()) with exit code 1. func Fatalln(args ...interface{}) { logger.Fatalln(args...) + // Make sure fatal logs will exit. os.Exit(1) } diff --git a/vendor/google.golang.org/grpc/grpclog/loggerv2.go b/vendor/google.golang.org/grpc/grpclog/loggerv2.go index f5193be92..d49325776 100644 --- a/vendor/google.golang.org/grpc/grpclog/loggerv2.go +++ b/vendor/google.golang.org/grpc/grpclog/loggerv2.go @@ -1,33 +1,18 @@ /* * - * Copyright 2017, Google Inc. - * All rights reserved. + * Copyright 2017 gRPC authors. * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above - * copyright notice, this list of conditions and the following disclaimer - * in the documentation and/or other materials provided with the - * distribution. - * * Neither the name of Google Inc. nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. + * http://www.apache.org/licenses/LICENSE-2.0 * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. * */ @@ -62,13 +47,16 @@ type LoggerV2 interface { // Errorf logs to ERROR log. Arguments are handled in the manner of fmt.Printf. Errorf(format string, args ...interface{}) // Fatal logs to ERROR log. Arguments are handled in the manner of fmt.Print. - // This function should call os.Exit() with a non-zero exit code. + // gRPC ensures that all Fatal logs will exit with os.Exit(1). + // Implementations may also call os.Exit() with a non-zero exit code. Fatal(args ...interface{}) // Fatalln logs to ERROR log. Arguments are handled in the manner of fmt.Println. - // This function should call os.Exit() with a non-zero exit code. + // gRPC ensures that all Fatal logs will exit with os.Exit(1). + // Implementations may also call os.Exit() with a non-zero exit code. Fatalln(args ...interface{}) // Fatalf logs to ERROR log. Arguments are handled in the manner of fmt.Printf. - // This function should call os.Exit() with a non-zero exit code. + // gRPC ensures that all Fatal logs will exit with os.Exit(1). + // Implementations may also call os.Exit() with a non-zero exit code. Fatalf(format string, args ...interface{}) // V reports whether verbosity level l is at least the requested verbose level. V(l int) bool @@ -189,14 +177,17 @@ func (g *loggerT) Errorf(format string, args ...interface{}) { func (g *loggerT) Fatal(args ...interface{}) { g.m[fatalLog].Fatal(args...) + // No need to call os.Exit() again because log.Logger.Fatal() calls os.Exit(). } func (g *loggerT) Fatalln(args ...interface{}) { g.m[fatalLog].Fatalln(args...) + // No need to call os.Exit() again because log.Logger.Fatal() calls os.Exit(). } func (g *loggerT) Fatalf(format string, args ...interface{}) { g.m[fatalLog].Fatalf(format, args...) + // No need to call os.Exit() again because log.Logger.Fatal() calls os.Exit(). } func (g *loggerT) V(l int) bool { diff --git a/vendor/google.golang.org/grpc/grpclog/loggerv2_test.go b/vendor/google.golang.org/grpc/grpclog/loggerv2_test.go index 61c0efe34..756f215f9 100644 --- a/vendor/google.golang.org/grpc/grpclog/loggerv2_test.go +++ b/vendor/google.golang.org/grpc/grpclog/loggerv2_test.go @@ -1,33 +1,18 @@ /* * - * Copyright 2017, Google Inc. - * All rights reserved. + * Copyright 2017 gRPC authors. * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above - * copyright notice, this list of conditions and the following disclaimer - * in the documentation and/or other materials provided with the - * distribution. - * * Neither the name of Google Inc. nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. + * http://www.apache.org/licenses/LICENSE-2.0 * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. * */ diff --git a/vendor/google.golang.org/grpc/metadata/metadata.go b/vendor/google.golang.org/grpc/metadata/metadata.go index be4f9e73c..ccfea5d45 100644 --- a/vendor/google.golang.org/grpc/metadata/metadata.go +++ b/vendor/google.golang.org/grpc/metadata/metadata.go @@ -44,6 +44,9 @@ type MD map[string][]string // - lowercase letters: a-z // - special characters: -_. // Uppercase letters are automatically converted to lowercase. +// +// Keys beginning with "grpc-" are reserved for grpc-internal use only and may +// result in errors if set in metadata. func New(m map[string]string) MD { md := MD{} for k, val := range m { @@ -62,6 +65,9 @@ func New(m map[string]string) MD { // - lowercase letters: a-z // - special characters: -_. // Uppercase letters are automatically converted to lowercase. +// +// Keys beginning with "grpc-" are reserved for grpc-internal use only and may +// result in errors if set in metadata. func Pairs(kv ...string) MD { if len(kv)%2 == 1 { panic(fmt.Sprintf("metadata: Pairs got the odd number of input pairs for metadata: %d", len(kv))) @@ -104,11 +110,6 @@ func Join(mds ...MD) MD { type mdIncomingKey struct{} type mdOutgoingKey struct{} -// NewContext is a wrapper for NewOutgoingContext(ctx, md). Deprecated. -func NewContext(ctx context.Context, md MD) context.Context { - return NewOutgoingContext(ctx, md) -} - // NewIncomingContext creates a new context with incoming md attached. func NewIncomingContext(ctx context.Context, md MD) context.Context { return context.WithValue(ctx, mdIncomingKey{}, md) @@ -119,11 +120,6 @@ func NewOutgoingContext(ctx context.Context, md MD) context.Context { return context.WithValue(ctx, mdOutgoingKey{}, md) } -// FromContext is a wrapper for FromIncomingContext(ctx). Deprecated. -func FromContext(ctx context.Context) (md MD, ok bool) { - return FromIncomingContext(ctx) -} - // FromIncomingContext returns the incoming metadata in ctx if it exists. The // returned MD should not be modified. Writing to it may cause races. // Modification should be made to copies of the returned MD. diff --git a/vendor/google.golang.org/grpc/naming/dns_resolver.go b/vendor/google.golang.org/grpc/naming/dns_resolver.go new file mode 100644 index 000000000..7e69a2ca0 --- /dev/null +++ b/vendor/google.golang.org/grpc/naming/dns_resolver.go @@ -0,0 +1,290 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package naming + +import ( + "errors" + "fmt" + "net" + "strconv" + "time" + + "golang.org/x/net/context" + "google.golang.org/grpc/grpclog" +) + +const ( + defaultPort = "443" + defaultFreq = time.Minute * 30 +) + +var ( + errMissingAddr = errors.New("missing address") + errWatcherClose = errors.New("watcher has been closed") +) + +// NewDNSResolverWithFreq creates a DNS Resolver that can resolve DNS names, and +// create watchers that poll the DNS server using the frequency set by freq. +func NewDNSResolverWithFreq(freq time.Duration) (Resolver, error) { + return &dnsResolver{freq: freq}, nil +} + +// NewDNSResolver creates a DNS Resolver that can resolve DNS names, and create +// watchers that poll the DNS server using the default frequency defined by defaultFreq. +func NewDNSResolver() (Resolver, error) { + return NewDNSResolverWithFreq(defaultFreq) +} + +// dnsResolver handles name resolution for names following the DNS scheme +type dnsResolver struct { + // frequency of polling the DNS server that the watchers created by this resolver will use. + freq time.Duration +} + +// formatIP returns ok = false if addr is not a valid textual representation of an IP address. +// If addr is an IPv4 address, return the addr and ok = true. +// If addr is an IPv6 address, return the addr enclosed in square brackets and ok = true. +func formatIP(addr string) (addrIP string, ok bool) { + ip := net.ParseIP(addr) + if ip == nil { + return "", false + } + if ip.To4() != nil { + return addr, true + } + return "[" + addr + "]", true +} + +// parseTarget takes the user input target string, returns formatted host and port info. +// If target doesn't specify a port, set the port to be the defaultPort. +// If target is in IPv6 format and host-name is enclosed in sqarue brackets, brackets +// are strippd when setting the host. +// examples: +// target: "www.google.com" returns host: "www.google.com", port: "443" +// target: "ipv4-host:80" returns host: "ipv4-host", port: "80" +// target: "[ipv6-host]" returns host: "ipv6-host", port: "443" +// target: ":80" returns host: "localhost", port: "80" +// target: ":" returns host: "localhost", port: "443" +func parseTarget(target string) (host, port string, err error) { + if target == "" { + return "", "", errMissingAddr + } + + if ip := net.ParseIP(target); ip != nil { + // target is an IPv4 or IPv6(without brackets) address + return target, defaultPort, nil + } + if host, port, err := net.SplitHostPort(target); err == nil { + // target has port, i.e ipv4-host:port, [ipv6-host]:port, host-name:port + if host == "" { + // Keep consistent with net.Dial(): If the host is empty, as in ":80", the local system is assumed. + host = "localhost" + } + if port == "" { + // If the port field is empty(target ends with colon), e.g. "[::1]:", defaultPort is used. + port = defaultPort + } + return host, port, nil + } + if host, port, err := net.SplitHostPort(target + ":" + defaultPort); err == nil { + // target doesn't have port + return host, port, nil + } + return "", "", fmt.Errorf("invalid target address %v", target) +} + +// Resolve creates a watcher that watches the name resolution of the target. +func (r *dnsResolver) Resolve(target string) (Watcher, error) { + host, port, err := parseTarget(target) + if err != nil { + return nil, err + } + + if net.ParseIP(host) != nil { + ipWatcher := &ipWatcher{ + updateChan: make(chan *Update, 1), + } + host, _ = formatIP(host) + ipWatcher.updateChan <- &Update{Op: Add, Addr: host + ":" + port} + return ipWatcher, nil + } + + ctx, cancel := context.WithCancel(context.Background()) + return &dnsWatcher{ + r: r, + host: host, + port: port, + ctx: ctx, + cancel: cancel, + t: time.NewTimer(0), + }, nil +} + +// dnsWatcher watches for the name resolution update for a specific target +type dnsWatcher struct { + r *dnsResolver + host string + port string + // The latest resolved address set + curAddrs map[string]*Update + ctx context.Context + cancel context.CancelFunc + t *time.Timer +} + +// ipWatcher watches for the name resolution update for an IP address. +type ipWatcher struct { + updateChan chan *Update +} + +// Next returns the adrress resolution Update for the target. For IP address, +// the resolution is itself, thus polling name server is unncessary. Therefore, +// Next() will return an Update the first time it is called, and will be blocked +// for all following calls as no Update exisits until watcher is closed. +func (i *ipWatcher) Next() ([]*Update, error) { + u, ok := <-i.updateChan + if !ok { + return nil, errWatcherClose + } + return []*Update{u}, nil +} + +// Close closes the ipWatcher. +func (i *ipWatcher) Close() { + close(i.updateChan) +} + +// AddressType indicates the address type returned by name resolution. +type AddressType uint8 + +const ( + // Backend indicates the server is a backend server. + Backend AddressType = iota + // GRPCLB indicates the server is a grpclb load balancer. + GRPCLB +) + +// AddrMetadataGRPCLB contains the information the name resolver for grpclb should provide. The +// name resolver used by the grpclb balancer is required to provide this type of metadata in +// its address updates. +type AddrMetadataGRPCLB struct { + // AddrType is the type of server (grpc load balancer or backend). + AddrType AddressType + // ServerName is the name of the grpc load balancer. Used for authentication. + ServerName string +} + +// compileUpdate compares the old resolved addresses and newly resolved addresses, +// and generates an update list +func (w *dnsWatcher) compileUpdate(newAddrs map[string]*Update) []*Update { + var res []*Update + for a, u := range w.curAddrs { + if _, ok := newAddrs[a]; !ok { + u.Op = Delete + res = append(res, u) + } + } + for a, u := range newAddrs { + if _, ok := w.curAddrs[a]; !ok { + res = append(res, u) + } + } + return res +} + +func (w *dnsWatcher) lookupSRV() map[string]*Update { + newAddrs := make(map[string]*Update) + _, srvs, err := lookupSRV(w.ctx, "grpclb", "tcp", w.host) + if err != nil { + grpclog.Infof("grpc: failed dns SRV record lookup due to %v.\n", err) + return nil + } + for _, s := range srvs { + lbAddrs, err := lookupHost(w.ctx, s.Target) + if err != nil { + grpclog.Warningf("grpc: failed load banlacer address dns lookup due to %v.\n", err) + continue + } + for _, a := range lbAddrs { + a, ok := formatIP(a) + if !ok { + grpclog.Errorf("grpc: failed IP parsing due to %v.\n", err) + continue + } + addr := a + ":" + strconv.Itoa(int(s.Port)) + newAddrs[addr] = &Update{Addr: addr, + Metadata: AddrMetadataGRPCLB{AddrType: GRPCLB, ServerName: s.Target}} + } + } + return newAddrs +} + +func (w *dnsWatcher) lookupHost() map[string]*Update { + newAddrs := make(map[string]*Update) + addrs, err := lookupHost(w.ctx, w.host) + if err != nil { + grpclog.Warningf("grpc: failed dns A record lookup due to %v.\n", err) + return nil + } + for _, a := range addrs { + a, ok := formatIP(a) + if !ok { + grpclog.Errorf("grpc: failed IP parsing due to %v.\n", err) + continue + } + addr := a + ":" + w.port + newAddrs[addr] = &Update{Addr: addr} + } + return newAddrs +} + +func (w *dnsWatcher) lookup() []*Update { + newAddrs := w.lookupSRV() + if newAddrs == nil { + // If failed to get any balancer address (either no corresponding SRV for the + // target, or caused by failure during resolution/parsing of the balancer target), + // return any A record info available. + newAddrs = w.lookupHost() + } + result := w.compileUpdate(newAddrs) + w.curAddrs = newAddrs + return result +} + +// Next returns the resolved address update(delta) for the target. If there's no +// change, it will sleep for 30 mins and try to resolve again after that. +func (w *dnsWatcher) Next() ([]*Update, error) { + for { + select { + case <-w.ctx.Done(): + return nil, errWatcherClose + case <-w.t.C: + } + result := w.lookup() + // Next lookup should happen after an interval defined by w.r.freq. + w.t.Reset(w.r.freq) + if len(result) > 0 { + return result, nil + } + } +} + +func (w *dnsWatcher) Close() { + w.cancel() +} diff --git a/vendor/google.golang.org/grpc/naming/dns_resolver_test.go b/vendor/google.golang.org/grpc/naming/dns_resolver_test.go new file mode 100644 index 000000000..be1ac1aec --- /dev/null +++ b/vendor/google.golang.org/grpc/naming/dns_resolver_test.go @@ -0,0 +1,315 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package naming + +import ( + "fmt" + "net" + "reflect" + "sync" + "testing" + "time" +) + +func newUpdateWithMD(op Operation, addr, lb string) *Update { + return &Update{ + Op: op, + Addr: addr, + Metadata: AddrMetadataGRPCLB{AddrType: GRPCLB, ServerName: lb}, + } +} + +func toMap(u []*Update) map[string]*Update { + m := make(map[string]*Update) + for _, v := range u { + m[v.Addr] = v + } + return m +} + +func TestCompileUpdate(t *testing.T) { + tests := []struct { + oldAddrs []string + newAddrs []string + want []*Update + }{ + { + []string{}, + []string{"1.0.0.1"}, + []*Update{{Op: Add, Addr: "1.0.0.1"}}, + }, + { + []string{"1.0.0.1"}, + []string{"1.0.0.1"}, + []*Update{}, + }, + { + []string{"1.0.0.0"}, + []string{"1.0.0.1"}, + []*Update{{Op: Delete, Addr: "1.0.0.0"}, {Op: Add, Addr: "1.0.0.1"}}, + }, + { + []string{"1.0.0.1"}, + []string{"1.0.0.0"}, + []*Update{{Op: Add, Addr: "1.0.0.0"}, {Op: Delete, Addr: "1.0.0.1"}}, + }, + { + []string{"1.0.0.1"}, + []string{"1.0.0.1", "1.0.0.2", "1.0.0.3"}, + []*Update{{Op: Add, Addr: "1.0.0.2"}, {Op: Add, Addr: "1.0.0.3"}}, + }, + { + []string{"1.0.0.1", "1.0.0.2", "1.0.0.3"}, + []string{"1.0.0.0"}, + []*Update{{Op: Add, Addr: "1.0.0.0"}, {Op: Delete, Addr: "1.0.0.1"}, {Op: Delete, Addr: "1.0.0.2"}, {Op: Delete, Addr: "1.0.0.3"}}, + }, + { + []string{"1.0.0.1", "1.0.0.3", "1.0.0.5"}, + []string{"1.0.0.2", "1.0.0.3", "1.0.0.6"}, + []*Update{{Op: Delete, Addr: "1.0.0.1"}, {Op: Add, Addr: "1.0.0.2"}, {Op: Delete, Addr: "1.0.0.5"}, {Op: Add, Addr: "1.0.0.6"}}, + }, + { + []string{"1.0.0.1", "1.0.0.1", "1.0.0.2"}, + []string{"1.0.0.1"}, + []*Update{{Op: Delete, Addr: "1.0.0.2"}}, + }, + } + + var w dnsWatcher + for _, c := range tests { + w.curAddrs = make(map[string]*Update) + newUpdates := make(map[string]*Update) + for _, a := range c.oldAddrs { + w.curAddrs[a] = &Update{Addr: a} + } + for _, a := range c.newAddrs { + newUpdates[a] = &Update{Addr: a} + } + r := w.compileUpdate(newUpdates) + if !reflect.DeepEqual(toMap(c.want), toMap(r)) { + t.Errorf("w(%+v).compileUpdate(%+v) = %+v, want %+v", c.oldAddrs, c.newAddrs, updatesToSlice(r), updatesToSlice(c.want)) + } + } +} + +func TestResolveFunc(t *testing.T) { + tests := []struct { + addr string + want error + }{ + // TODO(yuxuanli): More false cases? + {"www.google.com", nil}, + {"foo.bar:12345", nil}, + {"127.0.0.1", nil}, + {"127.0.0.1:12345", nil}, + {"[::1]:80", nil}, + {"[2001:db8:a0b:12f0::1]:21", nil}, + {":80", nil}, + {"127.0.0...1:12345", nil}, + {"[fe80::1%lo0]:80", nil}, + {"golang.org:http", nil}, + {"[2001:db8::1]:http", nil}, + {":", nil}, + {"", errMissingAddr}, + {"[2001:db8:a0b:12f0::1", fmt.Errorf("invalid target address %v", "[2001:db8:a0b:12f0::1")}, + } + + r, err := NewDNSResolver() + if err != nil { + t.Errorf("%v", err) + } + for _, v := range tests { + _, err := r.Resolve(v.addr) + if !reflect.DeepEqual(err, v.want) { + t.Errorf("Resolve(%q) = %v, want %v", v.addr, err, v.want) + } + } +} + +var hostLookupTbl = map[string][]string{ + "foo.bar.com": {"1.2.3.4", "5.6.7.8"}, + "ipv4.single.fake": {"1.2.3.4"}, + "ipv4.multi.fake": {"1.2.3.4", "5.6.7.8", "9.10.11.12"}, + "ipv6.single.fake": {"2607:f8b0:400a:801::1001"}, + "ipv6.multi.fake": {"2607:f8b0:400a:801::1001", "2607:f8b0:400a:801::1002", "2607:f8b0:400a:801::1003"}, +} + +func hostLookup(host string) ([]string, error) { + if addrs, ok := hostLookupTbl[host]; ok { + return addrs, nil + } + return nil, fmt.Errorf("failed to lookup host:%s resolution in hostLookupTbl", host) +} + +var srvLookupTbl = map[string][]*net.SRV{ + "_grpclb._tcp.srv.ipv4.single.fake": {&net.SRV{Target: "ipv4.single.fake", Port: 1234}}, + "_grpclb._tcp.srv.ipv4.multi.fake": {&net.SRV{Target: "ipv4.multi.fake", Port: 1234}}, + "_grpclb._tcp.srv.ipv6.single.fake": {&net.SRV{Target: "ipv6.single.fake", Port: 1234}}, + "_grpclb._tcp.srv.ipv6.multi.fake": {&net.SRV{Target: "ipv6.multi.fake", Port: 1234}}, +} + +func srvLookup(service, proto, name string) (string, []*net.SRV, error) { + cname := "_" + service + "._" + proto + "." + name + if srvs, ok := srvLookupTbl[cname]; ok { + return cname, srvs, nil + } + return "", nil, fmt.Errorf("failed to lookup srv record for %s in srvLookupTbl", cname) +} + +func updatesToSlice(updates []*Update) []Update { + res := make([]Update, len(updates)) + for i, u := range updates { + res[i] = *u + } + return res +} + +func testResolver(t *testing.T, freq time.Duration, slp time.Duration) { + tests := []struct { + target string + want []*Update + }{ + { + "foo.bar.com", + []*Update{{Op: Add, Addr: "1.2.3.4" + colonDefaultPort}, {Op: Add, Addr: "5.6.7.8" + colonDefaultPort}}, + }, + { + "foo.bar.com:1234", + []*Update{{Op: Add, Addr: "1.2.3.4:1234"}, {Op: Add, Addr: "5.6.7.8:1234"}}, + }, + { + "srv.ipv4.single.fake", + []*Update{newUpdateWithMD(Add, "1.2.3.4:1234", "ipv4.single.fake")}, + }, + { + "srv.ipv4.multi.fake", + []*Update{ + newUpdateWithMD(Add, "1.2.3.4:1234", "ipv4.multi.fake"), + newUpdateWithMD(Add, "5.6.7.8:1234", "ipv4.multi.fake"), + newUpdateWithMD(Add, "9.10.11.12:1234", "ipv4.multi.fake")}, + }, + { + "srv.ipv6.single.fake", + []*Update{newUpdateWithMD(Add, "[2607:f8b0:400a:801::1001]:1234", "ipv6.single.fake")}, + }, + { + "srv.ipv6.multi.fake", + []*Update{ + newUpdateWithMD(Add, "[2607:f8b0:400a:801::1001]:1234", "ipv6.multi.fake"), + newUpdateWithMD(Add, "[2607:f8b0:400a:801::1002]:1234", "ipv6.multi.fake"), + newUpdateWithMD(Add, "[2607:f8b0:400a:801::1003]:1234", "ipv6.multi.fake"), + }, + }, + } + + for _, a := range tests { + r, err := NewDNSResolverWithFreq(freq) + if err != nil { + t.Fatalf("%v\n", err) + } + w, err := r.Resolve(a.target) + if err != nil { + t.Fatalf("%v\n", err) + } + updates, err := w.Next() + if err != nil { + t.Fatalf("%v\n", err) + } + if !reflect.DeepEqual(toMap(a.want), toMap(updates)) { + t.Errorf("Resolve(%q) = %+v, want %+v\n", a.target, updatesToSlice(updates), updatesToSlice(a.want)) + } + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for { + _, err := w.Next() + if err != nil { + return + } + t.Error("Execution shouldn't reach here, since w.Next() should be blocked until close happen.") + } + }() + // Sleep for sometime to let watcher do more than one lookup + time.Sleep(slp) + w.Close() + wg.Wait() + } +} + +func TestResolve(t *testing.T) { + defer replaceNetFunc()() + testResolver(t, time.Millisecond*5, time.Millisecond*10) +} + +const colonDefaultPort = ":" + defaultPort + +func TestIPWatcher(t *testing.T) { + tests := []struct { + target string + want []*Update + }{ + {"127.0.0.1", []*Update{{Op: Add, Addr: "127.0.0.1" + colonDefaultPort}}}, + {"127.0.0.1:12345", []*Update{{Op: Add, Addr: "127.0.0.1:12345"}}}, + {"::1", []*Update{{Op: Add, Addr: "[::1]" + colonDefaultPort}}}, + {"[::1]:12345", []*Update{{Op: Add, Addr: "[::1]:12345"}}}, + {"[::1]:", []*Update{{Op: Add, Addr: "[::1]:443"}}}, + {"2001:db8:85a3::8a2e:370:7334", []*Update{{Op: Add, Addr: "[2001:db8:85a3::8a2e:370:7334]" + colonDefaultPort}}}, + {"[2001:db8:85a3::8a2e:370:7334]", []*Update{{Op: Add, Addr: "[2001:db8:85a3::8a2e:370:7334]" + colonDefaultPort}}}, + {"[2001:db8:85a3::8a2e:370:7334]:12345", []*Update{{Op: Add, Addr: "[2001:db8:85a3::8a2e:370:7334]:12345"}}}, + {"[2001:db8::1]:http", []*Update{{Op: Add, Addr: "[2001:db8::1]:http"}}}, + // TODO(yuxuanli): zone support? + } + + for _, v := range tests { + r, err := NewDNSResolverWithFreq(time.Millisecond * 5) + if err != nil { + t.Fatalf("%v\n", err) + } + w, err := r.Resolve(v.target) + if err != nil { + t.Fatalf("%v\n", err) + } + var updates []*Update + var wg sync.WaitGroup + wg.Add(1) + count := 0 + go func() { + defer wg.Done() + for { + u, err := w.Next() + if err != nil { + return + } + updates = u + count++ + } + }() + // Sleep for sometime to let watcher do more than one lookup + time.Sleep(time.Millisecond * 10) + w.Close() + wg.Wait() + if !reflect.DeepEqual(v.want, updates) { + t.Errorf("Resolve(%q) = %v, want %+v\n", v.target, updatesToSlice(updates), updatesToSlice(v.want)) + } + if count != 1 { + t.Errorf("IPWatcher Next() should return only once, not %d times\n", count) + } + } +} diff --git a/vendor/google.golang.org/grpc/naming/go17.go b/vendor/google.golang.org/grpc/naming/go17.go new file mode 100644 index 000000000..8bdf21e79 --- /dev/null +++ b/vendor/google.golang.org/grpc/naming/go17.go @@ -0,0 +1,34 @@ +// +build go1.7, !go1.8 + +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package naming + +import ( + "net" + + "golang.org/x/net/context" +) + +var ( + lookupHost = func(ctx context.Context, host string) ([]string, error) { return net.LookupHost(host) } + lookupSRV = func(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) { + return net.LookupSRV(service, proto, name) + } +) diff --git a/vendor/google.golang.org/grpc/naming/go17_test.go b/vendor/google.golang.org/grpc/naming/go17_test.go new file mode 100644 index 000000000..d1de221a5 --- /dev/null +++ b/vendor/google.golang.org/grpc/naming/go17_test.go @@ -0,0 +1,42 @@ +// +build go1.7, !go1.8 + +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package naming + +import ( + "net" + + "golang.org/x/net/context" +) + +func replaceNetFunc() func() { + oldLookupHost := lookupHost + oldLookupSRV := lookupSRV + lookupHost = func(ctx context.Context, host string) ([]string, error) { + return hostLookup(host) + } + lookupSRV = func(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) { + return srvLookup(service, proto, name) + } + return func() { + lookupHost = oldLookupHost + lookupSRV = oldLookupSRV + } +} diff --git a/vendor/google.golang.org/grpc/naming/go18.go b/vendor/google.golang.org/grpc/naming/go18.go new file mode 100644 index 000000000..b5a0f8427 --- /dev/null +++ b/vendor/google.golang.org/grpc/naming/go18.go @@ -0,0 +1,28 @@ +// +build go1.8 + +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package naming + +import "net" + +var ( + lookupHost = net.DefaultResolver.LookupHost + lookupSRV = net.DefaultResolver.LookupSRV +) diff --git a/vendor/google.golang.org/grpc/naming/go18_test.go b/vendor/google.golang.org/grpc/naming/go18_test.go new file mode 100644 index 000000000..5e297539b --- /dev/null +++ b/vendor/google.golang.org/grpc/naming/go18_test.go @@ -0,0 +1,41 @@ +// +build go1.8 + +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package naming + +import ( + "context" + "net" +) + +func replaceNetFunc() func() { + oldLookupHost := lookupHost + oldLookupSRV := lookupSRV + lookupHost = func(ctx context.Context, host string) ([]string, error) { + return hostLookup(host) + } + lookupSRV = func(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) { + return srvLookup(service, proto, name) + } + return func() { + lookupHost = oldLookupHost + lookupSRV = oldLookupSRV + } +} diff --git a/vendor/google.golang.org/grpc/picker_wrapper.go b/vendor/google.golang.org/grpc/picker_wrapper.go new file mode 100644 index 000000000..9085dbc9c --- /dev/null +++ b/vendor/google.golang.org/grpc/picker_wrapper.go @@ -0,0 +1,141 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +import ( + "sync" + + "golang.org/x/net/context" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/status" + "google.golang.org/grpc/transport" +) + +// pickerWrapper is a wrapper of balancer.Picker. It blocks on certain pick +// actions and unblock when there's a picker update. +type pickerWrapper struct { + mu sync.Mutex + done bool + blockingCh chan struct{} + picker balancer.Picker +} + +func newPickerWrapper() *pickerWrapper { + bp := &pickerWrapper{blockingCh: make(chan struct{})} + return bp +} + +// updatePicker is called by UpdateBalancerState. It unblocks all blocked pick. +func (bp *pickerWrapper) updatePicker(p balancer.Picker) { + bp.mu.Lock() + if bp.done { + bp.mu.Unlock() + return + } + bp.picker = p + // bp.blockingCh should never be nil. + close(bp.blockingCh) + bp.blockingCh = make(chan struct{}) + bp.mu.Unlock() +} + +// pick returns the transport that will be used for the RPC. +// It may block in the following cases: +// - there's no picker +// - the current picker returns ErrNoSubConnAvailable +// - the current picker returns other errors and failfast is false. +// - the subConn returned by the current picker is not READY +// When one of these situations happens, pick blocks until the picker gets updated. +func (bp *pickerWrapper) pick(ctx context.Context, failfast bool, opts balancer.PickOptions) (transport.ClientTransport, func(balancer.DoneInfo), error) { + var ( + p balancer.Picker + ch chan struct{} + ) + + for { + bp.mu.Lock() + if bp.done { + bp.mu.Unlock() + return nil, nil, ErrClientConnClosing + } + + if bp.picker == nil { + ch = bp.blockingCh + } + if ch == bp.blockingCh { + // This could happen when either: + // - bp.picker is nil (the previous if condition), or + // - has called pick on the current picker. + bp.mu.Unlock() + select { + case <-ctx.Done(): + return nil, nil, ctx.Err() + case <-ch: + } + continue + } + + ch = bp.blockingCh + p = bp.picker + bp.mu.Unlock() + + subConn, put, err := p.Pick(ctx, opts) + + if err != nil { + switch err { + case balancer.ErrNoSubConnAvailable: + continue + case balancer.ErrTransientFailure: + if !failfast { + continue + } + return nil, nil, status.Errorf(codes.Unavailable, "%v", err) + default: + // err is some other error. + return nil, nil, toRPCErr(err) + } + } + + acw, ok := subConn.(*acBalancerWrapper) + if !ok { + grpclog.Infof("subconn returned from pick is not *acBalancerWrapper") + continue + } + if t, ok := acw.getAddrConn().getReadyTransport(); ok { + return t, put, nil + } + grpclog.Infof("blockingPicker: the picked transport is not ready, loop back to repick") + // If ok == false, ac.state is not READY. + // A valid picker always returns READY subConn. This means the state of ac + // just changed, and picker will be updated shortly. + // continue back to the beginning of the for loop to repick. + } +} + +func (bp *pickerWrapper) close() { + bp.mu.Lock() + defer bp.mu.Unlock() + if bp.done { + return + } + bp.done = true + close(bp.blockingCh) +} diff --git a/vendor/google.golang.org/grpc/picker_wrapper_test.go b/vendor/google.golang.org/grpc/picker_wrapper_test.go new file mode 100644 index 000000000..23bc8f243 --- /dev/null +++ b/vendor/google.golang.org/grpc/picker_wrapper_test.go @@ -0,0 +1,160 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +import ( + "fmt" + "sync/atomic" + "testing" + "time" + + "golang.org/x/net/context" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" + _ "google.golang.org/grpc/grpclog/glogger" + "google.golang.org/grpc/test/leakcheck" + "google.golang.org/grpc/transport" +) + +const goroutineCount = 5 + +var ( + testT = &testTransport{} + testSC = &acBalancerWrapper{ac: &addrConn{ + state: connectivity.Ready, + transport: testT, + }} + testSCNotReady = &acBalancerWrapper{ac: &addrConn{ + state: connectivity.TransientFailure, + }} +) + +type testTransport struct { + transport.ClientTransport +} + +type testingPicker struct { + err error + sc balancer.SubConn + maxCalled int64 +} + +func (p *testingPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) { + if atomic.AddInt64(&p.maxCalled, -1) < 0 { + return nil, nil, fmt.Errorf("Pick called to many times (> goroutineCount)") + } + if p.err != nil { + return nil, nil, p.err + } + return p.sc, nil, nil +} + +func TestBlockingPickTimeout(t *testing.T) { + defer leakcheck.Check(t) + bp := newPickerWrapper() + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + if _, _, err := bp.pick(ctx, true, balancer.PickOptions{}); err != context.DeadlineExceeded { + t.Errorf("bp.pick returned error %v, want DeadlineExceeded", err) + } +} + +func TestBlockingPick(t *testing.T) { + defer leakcheck.Check(t) + bp := newPickerWrapper() + // All goroutines should block because picker is nil in bp. + var finishedCount uint64 + for i := goroutineCount; i > 0; i-- { + go func() { + if tr, _, err := bp.pick(context.Background(), true, balancer.PickOptions{}); err != nil || tr != testT { + t.Errorf("bp.pick returned non-nil error: %v", err) + } + atomic.AddUint64(&finishedCount, 1) + }() + } + time.Sleep(50 * time.Millisecond) + if c := atomic.LoadUint64(&finishedCount); c != 0 { + t.Errorf("finished goroutines count: %v, want 0", c) + } + bp.updatePicker(&testingPicker{sc: testSC, maxCalled: goroutineCount}) +} + +func TestBlockingPickNoSubAvailable(t *testing.T) { + defer leakcheck.Check(t) + bp := newPickerWrapper() + var finishedCount uint64 + bp.updatePicker(&testingPicker{err: balancer.ErrNoSubConnAvailable, maxCalled: goroutineCount}) + // All goroutines should block because picker returns no sc avilable. + for i := goroutineCount; i > 0; i-- { + go func() { + if tr, _, err := bp.pick(context.Background(), true, balancer.PickOptions{}); err != nil || tr != testT { + t.Errorf("bp.pick returned non-nil error: %v", err) + } + atomic.AddUint64(&finishedCount, 1) + }() + } + time.Sleep(50 * time.Millisecond) + if c := atomic.LoadUint64(&finishedCount); c != 0 { + t.Errorf("finished goroutines count: %v, want 0", c) + } + bp.updatePicker(&testingPicker{sc: testSC, maxCalled: goroutineCount}) +} + +func TestBlockingPickTransientWaitforready(t *testing.T) { + defer leakcheck.Check(t) + bp := newPickerWrapper() + bp.updatePicker(&testingPicker{err: balancer.ErrTransientFailure, maxCalled: goroutineCount}) + var finishedCount uint64 + // All goroutines should block because picker returns transientFailure and + // picks are not failfast. + for i := goroutineCount; i > 0; i-- { + go func() { + if tr, _, err := bp.pick(context.Background(), false, balancer.PickOptions{}); err != nil || tr != testT { + t.Errorf("bp.pick returned non-nil error: %v", err) + } + atomic.AddUint64(&finishedCount, 1) + }() + } + time.Sleep(time.Millisecond) + if c := atomic.LoadUint64(&finishedCount); c != 0 { + t.Errorf("finished goroutines count: %v, want 0", c) + } + bp.updatePicker(&testingPicker{sc: testSC, maxCalled: goroutineCount}) +} + +func TestBlockingPickSCNotReady(t *testing.T) { + defer leakcheck.Check(t) + bp := newPickerWrapper() + bp.updatePicker(&testingPicker{sc: testSCNotReady, maxCalled: goroutineCount}) + var finishedCount uint64 + // All goroutines should block because sc is not ready. + for i := goroutineCount; i > 0; i-- { + go func() { + if tr, _, err := bp.pick(context.Background(), true, balancer.PickOptions{}); err != nil || tr != testT { + t.Errorf("bp.pick returned non-nil error: %v", err) + } + atomic.AddUint64(&finishedCount, 1) + }() + } + time.Sleep(time.Millisecond) + if c := atomic.LoadUint64(&finishedCount); c != 0 { + t.Errorf("finished goroutines count: %v, want 0", c) + } + bp.updatePicker(&testingPicker{sc: testSC, maxCalled: goroutineCount}) +} diff --git a/vendor/google.golang.org/grpc/pickfirst.go b/vendor/google.golang.org/grpc/pickfirst.go new file mode 100644 index 000000000..7f993ef5a --- /dev/null +++ b/vendor/google.golang.org/grpc/pickfirst.go @@ -0,0 +1,95 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +import ( + "golang.org/x/net/context" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/resolver" +) + +func newPickfirstBuilder() balancer.Builder { + return &pickfirstBuilder{} +} + +type pickfirstBuilder struct{} + +func (*pickfirstBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer { + return &pickfirstBalancer{cc: cc} +} + +func (*pickfirstBuilder) Name() string { + return "pickfirst" +} + +type pickfirstBalancer struct { + cc balancer.ClientConn + sc balancer.SubConn +} + +func (b *pickfirstBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) { + if err != nil { + grpclog.Infof("pickfirstBalancer: HandleResolvedAddrs called with error %v", err) + return + } + if b.sc == nil { + b.sc, err = b.cc.NewSubConn(addrs, balancer.NewSubConnOptions{}) + if err != nil { + grpclog.Errorf("pickfirstBalancer: failed to NewSubConn: %v", err) + return + } + b.cc.UpdateBalancerState(connectivity.Idle, &picker{sc: b.sc}) + } else { + b.sc.UpdateAddresses(addrs) + } +} + +func (b *pickfirstBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) { + grpclog.Infof("pickfirstBalancer: HandleSubConnStateChange: %p, %v", sc, s) + if b.sc != sc || s == connectivity.Shutdown { + b.sc = nil + return + } + + switch s { + case connectivity.Ready, connectivity.Idle: + b.cc.UpdateBalancerState(s, &picker{sc: sc}) + case connectivity.Connecting: + b.cc.UpdateBalancerState(s, &picker{err: balancer.ErrNoSubConnAvailable}) + case connectivity.TransientFailure: + b.cc.UpdateBalancerState(s, &picker{err: balancer.ErrTransientFailure}) + } +} + +func (b *pickfirstBalancer) Close() { +} + +type picker struct { + err error + sc balancer.SubConn +} + +func (p *picker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) { + if p.err != nil { + return nil, nil, p.err + } + return p.sc, nil, nil +} diff --git a/vendor/google.golang.org/grpc/pickfirst_test.go b/vendor/google.golang.org/grpc/pickfirst_test.go new file mode 100644 index 000000000..e58b3422c --- /dev/null +++ b/vendor/google.golang.org/grpc/pickfirst_test.go @@ -0,0 +1,352 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +import ( + "math" + "sync" + "testing" + "time" + + "golang.org/x/net/context" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/resolver/manual" + "google.golang.org/grpc/test/leakcheck" +) + +func TestOneBackendPickfirst(t *testing.T) { + defer leakcheck.Check(t) + r, rcleanup := manual.GenerateAndRegisterManualResolver() + defer rcleanup() + + numServers := 1 + servers, _, scleanup := startServers(t, numServers, math.MaxInt32) + defer scleanup() + + cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithBalancerBuilder(newPickfirstBuilder()), WithCodec(testCodec{})) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer cc.Close() + // The first RPC should fail because there's no address. + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + req := "port" + var reply string + if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || Code(err) != codes.DeadlineExceeded { + t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err) + } + + r.NewAddress([]resolver.Address{{Addr: servers[0].addr}}) + // The second RPC should succeed. + for i := 0; i < 1000; i++ { + if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[0].port { + return + } + time.Sleep(time.Millisecond) + } + t.Fatalf("EmptyCall() = _, %v, want _, %v", err, servers[0].port) +} + +func TestBackendsPickfirst(t *testing.T) { + defer leakcheck.Check(t) + r, rcleanup := manual.GenerateAndRegisterManualResolver() + defer rcleanup() + + numServers := 2 + servers, _, scleanup := startServers(t, numServers, math.MaxInt32) + defer scleanup() + + cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithBalancerBuilder(newPickfirstBuilder()), WithCodec(testCodec{})) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer cc.Close() + // The first RPC should fail because there's no address. + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + req := "port" + var reply string + if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || Code(err) != codes.DeadlineExceeded { + t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err) + } + + r.NewAddress([]resolver.Address{{Addr: servers[0].addr}, {Addr: servers[1].addr}}) + // The second RPC should succeed with the first server. + for i := 0; i < 1000; i++ { + if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[0].port { + return + } + time.Sleep(time.Millisecond) + } + t.Fatalf("EmptyCall() = _, %v, want _, %v", err, servers[0].port) +} + +func TestNewAddressWhileBlockingPickfirst(t *testing.T) { + defer leakcheck.Check(t) + r, rcleanup := manual.GenerateAndRegisterManualResolver() + defer rcleanup() + + numServers := 1 + servers, _, scleanup := startServers(t, numServers, math.MaxInt32) + defer scleanup() + + cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithBalancerBuilder(newPickfirstBuilder()), WithCodec(testCodec{})) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer cc.Close() + // The first RPC should fail because there's no address. + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + req := "port" + var reply string + if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || Code(err) != codes.DeadlineExceeded { + t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err) + } + + var wg sync.WaitGroup + for i := 0; i < 3; i++ { + wg.Add(1) + go func() { + defer wg.Done() + // This RPC blocks until NewAddress is called. + Invoke(context.Background(), "/foo/bar", &req, &reply, cc) + }() + } + time.Sleep(50 * time.Millisecond) + r.NewAddress([]resolver.Address{{Addr: servers[0].addr}}) + wg.Wait() +} + +func TestCloseWithPendingRPCPickfirst(t *testing.T) { + defer leakcheck.Check(t) + r, rcleanup := manual.GenerateAndRegisterManualResolver() + defer rcleanup() + + numServers := 1 + _, _, scleanup := startServers(t, numServers, math.MaxInt32) + defer scleanup() + + cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithBalancerBuilder(newPickfirstBuilder()), WithCodec(testCodec{})) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer cc.Close() + // The first RPC should fail because there's no address. + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + req := "port" + var reply string + if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || Code(err) != codes.DeadlineExceeded { + t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err) + } + + var wg sync.WaitGroup + for i := 0; i < 3; i++ { + wg.Add(1) + go func() { + defer wg.Done() + // This RPC blocks until NewAddress is called. + Invoke(context.Background(), "/foo/bar", &req, &reply, cc) + }() + } + time.Sleep(50 * time.Millisecond) + cc.Close() + wg.Wait() +} + +func TestOneServerDownPickfirst(t *testing.T) { + defer leakcheck.Check(t) + r, rcleanup := manual.GenerateAndRegisterManualResolver() + defer rcleanup() + + numServers := 2 + servers, _, scleanup := startServers(t, numServers, math.MaxInt32) + defer scleanup() + + cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithBalancerBuilder(newPickfirstBuilder()), WithCodec(testCodec{})) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer cc.Close() + // The first RPC should fail because there's no address. + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + req := "port" + var reply string + if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || Code(err) != codes.DeadlineExceeded { + t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err) + } + + r.NewAddress([]resolver.Address{{Addr: servers[0].addr}, {Addr: servers[1].addr}}) + // The second RPC should succeed with the first server. + for i := 0; i < 1000; i++ { + if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[0].port { + break + } + time.Sleep(time.Millisecond) + } + + servers[0].stop() + for i := 0; i < 1000; i++ { + if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port { + return + } + time.Sleep(time.Millisecond) + } + t.Fatalf("EmptyCall() = _, %v, want _, %v", err, servers[0].port) +} + +func TestAllServersDownPickfirst(t *testing.T) { + defer leakcheck.Check(t) + r, rcleanup := manual.GenerateAndRegisterManualResolver() + defer rcleanup() + + numServers := 2 + servers, _, scleanup := startServers(t, numServers, math.MaxInt32) + defer scleanup() + + cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithBalancerBuilder(newPickfirstBuilder()), WithCodec(testCodec{})) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer cc.Close() + // The first RPC should fail because there's no address. + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + req := "port" + var reply string + if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || Code(err) != codes.DeadlineExceeded { + t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err) + } + + r.NewAddress([]resolver.Address{{Addr: servers[0].addr}, {Addr: servers[1].addr}}) + // The second RPC should succeed with the first server. + for i := 0; i < 1000; i++ { + if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[0].port { + break + } + time.Sleep(time.Millisecond) + } + + for i := 0; i < numServers; i++ { + servers[i].stop() + } + for i := 0; i < 1000; i++ { + if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); Code(err) == codes.Unavailable { + return + } + time.Sleep(time.Millisecond) + } + t.Fatalf("EmptyCall() = _, %v, want _, error with code unavailable", err) +} + +func TestAddressesRemovedPickfirst(t *testing.T) { + defer leakcheck.Check(t) + r, rcleanup := manual.GenerateAndRegisterManualResolver() + defer rcleanup() + + numServers := 3 + servers, _, scleanup := startServers(t, numServers, math.MaxInt32) + defer scleanup() + + cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithBalancerBuilder(newPickfirstBuilder()), WithCodec(testCodec{})) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer cc.Close() + // The first RPC should fail because there's no address. + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + req := "port" + var reply string + if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || Code(err) != codes.DeadlineExceeded { + t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err) + } + + r.NewAddress([]resolver.Address{{Addr: servers[0].addr}, {Addr: servers[1].addr}, {Addr: servers[2].addr}}) + for i := 0; i < 1000; i++ { + if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[0].port { + break + } + time.Sleep(time.Millisecond) + } + for i := 0; i < 20; i++ { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[0].port { + t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port) + } + time.Sleep(10 * time.Millisecond) + } + + // Remove server[0]. + r.NewAddress([]resolver.Address{{Addr: servers[1].addr}, {Addr: servers[2].addr}}) + for i := 0; i < 1000; i++ { + if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port { + break + } + time.Sleep(time.Millisecond) + } + for i := 0; i < 20; i++ { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[1].port { + t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port) + } + time.Sleep(10 * time.Millisecond) + } + + // Append server[0], nothing should change. + r.NewAddress([]resolver.Address{{Addr: servers[1].addr}, {Addr: servers[2].addr}, {Addr: servers[0].addr}}) + for i := 0; i < 20; i++ { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[1].port { + t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port) + } + time.Sleep(10 * time.Millisecond) + } + + // Remove server[1]. + r.NewAddress([]resolver.Address{{Addr: servers[2].addr}, {Addr: servers[0].addr}}) + for i := 0; i < 1000; i++ { + if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[2].port { + break + } + time.Sleep(time.Millisecond) + } + for i := 0; i < 20; i++ { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[2].port { + t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 2, err, servers[2].port) + } + time.Sleep(10 * time.Millisecond) + } + + // Remove server[2]. + r.NewAddress([]resolver.Address{{Addr: servers[0].addr}}) + for i := 0; i < 1000; i++ { + if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[0].port { + break + } + time.Sleep(time.Millisecond) + } + for i := 0; i < 20; i++ { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[0].port { + t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port) + } + time.Sleep(10 * time.Millisecond) + } +} diff --git a/vendor/google.golang.org/grpc/proxy.go b/vendor/google.golang.org/grpc/proxy.go index 2d40236e2..3e17efec6 100644 --- a/vendor/google.golang.org/grpc/proxy.go +++ b/vendor/google.golang.org/grpc/proxy.go @@ -82,7 +82,8 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, addr string) (_ Header: map[string][]string{"User-Agent": {grpcUA}}, }) - if err := sendHTTPRequest(ctx, req, conn); err != nil { + req = req.WithContext(ctx) + if err := req.Write(conn); err != nil { return nil, fmt.Errorf("failed to write the HTTP request: %v", err) } diff --git a/vendor/google.golang.org/grpc/proxy_test.go b/vendor/google.golang.org/grpc/proxy_test.go index 835b15af9..39ee123cc 100644 --- a/vendor/google.golang.org/grpc/proxy_test.go +++ b/vendor/google.golang.org/grpc/proxy_test.go @@ -1,3 +1,5 @@ +// +build !race + /* * * Copyright 2017 gRPC authors. @@ -28,6 +30,7 @@ import ( "time" "golang.org/x/net/context" + "google.golang.org/grpc/test/leakcheck" ) const ( @@ -45,29 +48,6 @@ func overwrite(hpfe func(req *http.Request) (*url.URL, error)) func() { } } -func TestMapAddressEnv(t *testing.T) { - // Overwrite the function in the test and restore them in defer. - hpfe := func(req *http.Request) (*url.URL, error) { - if req.URL.Host == envTestAddr { - return &url.URL{ - Scheme: "https", - Host: envProxyAddr, - }, nil - } - return nil, nil - } - defer overwrite(hpfe)() - - // envTestAddr should be handled by ProxyFromEnvironment. - got, err := mapAddress(context.Background(), envTestAddr) - if err != nil { - t.Error(err) - } - if got != envProxyAddr { - t.Errorf("want %v, got %v", envProxyAddr, got) - } -} - type proxyServer struct { t *testing.T lis net.Listener @@ -118,6 +98,7 @@ func (p *proxyServer) stop() { } func TestHTTPConnect(t *testing.T) { + defer leakcheck.Check(t) plis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("failed to listen: %v", err) @@ -175,3 +156,27 @@ func TestHTTPConnect(t *testing.T) { t.Fatalf("received msg: %v, want %v", recvBuf, msg) } } + +func TestMapAddressEnv(t *testing.T) { + defer leakcheck.Check(t) + // Overwrite the function in the test and restore them in defer. + hpfe := func(req *http.Request) (*url.URL, error) { + if req.URL.Host == envTestAddr { + return &url.URL{ + Scheme: "https", + Host: envProxyAddr, + }, nil + } + return nil, nil + } + defer overwrite(hpfe)() + + // envTestAddr should be handled by ProxyFromEnvironment. + got, err := mapAddress(context.Background(), envTestAddr) + if err != nil { + t.Error(err) + } + if got != envProxyAddr { + t.Errorf("want %v, got %v", envProxyAddr, got) + } +} diff --git a/vendor/google.golang.org/grpc/resolver/resolver.go b/vendor/google.golang.org/grpc/resolver/resolver.go new file mode 100644 index 000000000..49307e8fe --- /dev/null +++ b/vendor/google.golang.org/grpc/resolver/resolver.go @@ -0,0 +1,143 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package resolver defines APIs for name resolution in gRPC. +// All APIs in this package are experimental. +package resolver + +var ( + // m is a map from scheme to resolver builder. + m = make(map[string]Builder) + // defaultScheme is the default scheme to use. + defaultScheme string +) + +// TODO(bar) install dns resolver in init(){}. + +// Register registers the resolver builder to the resolver map. +// b.Scheme will be used as the scheme registered with this builder. +func Register(b Builder) { + m[b.Scheme()] = b +} + +// Get returns the resolver builder registered with the given scheme. +// If no builder is register with the scheme, the default scheme will +// be used. +// If the default scheme is not modified, "dns" will be the default +// scheme, and the preinstalled dns resolver will be used. +// If the default scheme is modified, and a resolver is registered with +// the scheme, that resolver will be returned. +// If the default scheme is modified, and no resolver is registered with +// the scheme, nil will be returned. +func Get(scheme string) Builder { + if b, ok := m[scheme]; ok { + return b + } + if b, ok := m[defaultScheme]; ok { + return b + } + return nil +} + +// SetDefaultScheme sets the default scheme that will be used. +// The default default scheme is "dns". +func SetDefaultScheme(scheme string) { + defaultScheme = scheme +} + +// AddressType indicates the address type returned by name resolution. +type AddressType uint8 + +const ( + // Backend indicates the address is for a backend server. + Backend AddressType = iota + // GRPCLB indicates the address is for a grpclb load balancer. + GRPCLB +) + +// Address represents a server the client connects to. +// This is the EXPERIMENTAL API and may be changed or extended in the future. +type Address struct { + // Addr is the server address on which a connection will be established. + Addr string + // Type is the type of this address. + Type AddressType + // ServerName is the name of this address. + // It's the name of the grpc load balancer, which will be used for authentication. + ServerName string + // Metadata is the information associated with Addr, which may be used + // to make load balancing decision. + Metadata interface{} +} + +// BuildOption includes additional information for the builder to create +// the resolver. +type BuildOption struct { +} + +// ClientConn contains the callbacks for resolver to notify any updates +// to the gRPC ClientConn. +type ClientConn interface { + // NewAddress is called by resolver to notify ClientConn a new list + // of resolved addresses. + // The address list should be the complete list of resolved addresses. + NewAddress(addresses []Address) + // NewServiceConfig is called by resolver to notify ClientConn a new + // service config. The service config should be provided as a json string. + NewServiceConfig(serviceConfig string) +} + +// Target represents a target for gRPC, as specified in: +// https://github.com/grpc/grpc/blob/master/doc/naming.md. +type Target struct { + Scheme string + Authority string + Endpoint string +} + +// Builder creates a resolver that will be used to watch name resolution updates. +type Builder interface { + // Build creates a new resolver for the given target. + // + // gRPC dial calls Build synchronously, and fails if the returned error is + // not nil. + Build(target Target, cc ClientConn, opts BuildOption) (Resolver, error) + // Scheme returns the scheme supported by this resolver. + // Scheme is defined at https://github.com/grpc/grpc/blob/master/doc/naming.md. + Scheme() string +} + +// ResolveNowOption includes additional information for ResolveNow. +type ResolveNowOption struct{} + +// Resolver watches for the updates on the specified target. +// Updates include address updates and service config updates. +type Resolver interface { + // ResolveNow will be called by gRPC to try to resolve the target name again. + // It's just a hint, resolver can ignore this if it's not necessary. + ResolveNow(ResolveNowOption) + // Close closes the resolver. + Close() +} + +// UnregisterForTesting removes the resolver builder with the given scheme from the +// resolver map. +// This function is for testing only. +func UnregisterForTesting(scheme string) { + delete(m, scheme) +} diff --git a/vendor/google.golang.org/grpc/resolver_conn_wrapper.go b/vendor/google.golang.org/grpc/resolver_conn_wrapper.go new file mode 100644 index 000000000..7d53964d0 --- /dev/null +++ b/vendor/google.golang.org/grpc/resolver_conn_wrapper.go @@ -0,0 +1,139 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +import ( + "strings" + + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/resolver" +) + +// ccResolverWrapper is a wrapper on top of cc for resolvers. +// It implements resolver.ClientConnection interface. +type ccResolverWrapper struct { + cc *ClientConn + resolver resolver.Resolver + addrCh chan []resolver.Address + scCh chan string + done chan struct{} +} + +// split2 returns the values from strings.SplitN(s, sep, 2). +// If sep is not found, it returns "", s instead. +func split2(s, sep string) (string, string) { + spl := strings.SplitN(s, sep, 2) + if len(spl) < 2 { + return "", s + } + return spl[0], spl[1] +} + +// parseTarget splits target into a struct containing scheme, authority and +// endpoint. +func parseTarget(target string) (ret resolver.Target) { + ret.Scheme, ret.Endpoint = split2(target, "://") + ret.Authority, ret.Endpoint = split2(ret.Endpoint, "/") + return ret +} + +// newCCResolverWrapper parses cc.target for scheme and gets the resolver +// builder for this scheme. It then builds the resolver and starts the +// monitoring goroutine for it. +// +// This function could return nil, nil, in tests for old behaviors. +// TODO(bar) never return nil, nil when DNS becomes the default resolver. +func newCCResolverWrapper(cc *ClientConn) (*ccResolverWrapper, error) { + target := parseTarget(cc.target) + grpclog.Infof("dialing to target with scheme: %q", target.Scheme) + + rb := resolver.Get(target.Scheme) + if rb == nil { + // TODO(bar) return error when DNS becomes the default (implemented and + // registered by DNS package). + grpclog.Infof("could not get resolver for scheme: %q", target.Scheme) + return nil, nil + } + + ccr := &ccResolverWrapper{ + cc: cc, + addrCh: make(chan []resolver.Address, 1), + scCh: make(chan string, 1), + done: make(chan struct{}), + } + + var err error + ccr.resolver, err = rb.Build(target, ccr, resolver.BuildOption{}) + if err != nil { + return nil, err + } + go ccr.watcher() + return ccr, nil +} + +// watcher processes address updates and service config updates sequencially. +// Otherwise, we need to resolve possible races between address and service +// config (e.g. they specify different balancer types). +func (ccr *ccResolverWrapper) watcher() { + for { + select { + case <-ccr.done: + return + default: + } + + select { + case addrs := <-ccr.addrCh: + grpclog.Infof("ccResolverWrapper: sending new addresses to balancer wrapper: %v", addrs) + // TODO(bar switching) this should never be nil. Pickfirst should be default. + if ccr.cc.balancerWrapper != nil { + // TODO(bar switching) create balancer if it's nil? + ccr.cc.balancerWrapper.handleResolvedAddrs(addrs, nil) + } + case sc := <-ccr.scCh: + grpclog.Infof("ccResolverWrapper: got new service config: %v", sc) + case <-ccr.done: + return + } + } +} + +func (ccr *ccResolverWrapper) close() { + ccr.resolver.Close() + close(ccr.done) +} + +// NewAddress is called by the resolver implemenetion to send addresses to gRPC. +func (ccr *ccResolverWrapper) NewAddress(addrs []resolver.Address) { + select { + case <-ccr.addrCh: + default: + } + ccr.addrCh <- addrs +} + +// NewServiceConfig is called by the resolver implemenetion to send service +// configs to gPRC. +func (ccr *ccResolverWrapper) NewServiceConfig(sc string) { + select { + case <-ccr.scCh: + default: + } + ccr.scCh <- sc +} diff --git a/vendor/google.golang.org/grpc/resolver_conn_wrapper_test.go b/vendor/google.golang.org/grpc/resolver_conn_wrapper_test.go new file mode 100644 index 000000000..024942301 --- /dev/null +++ b/vendor/google.golang.org/grpc/resolver_conn_wrapper_test.go @@ -0,0 +1,47 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +import ( + "testing" + + "google.golang.org/grpc/resolver" +) + +func TestParseTarget(t *testing.T) { + for _, test := range []resolver.Target{ + {"", "", ""}, + {"a", "", ""}, + {"", "a", ""}, + {"", "", "a"}, + {"a", "b", ""}, + {"a", "", "b"}, + {"", "a", "b"}, + {"a", "b", "c"}, + {"dns", "a.server.com", "google.com"}, + {"dns", "a.server.com", "google.com"}, + {"dns", "a.server.com", "google.com/?a=b"}, + } { + str := test.Scheme + "://" + test.Authority + "/" + test.Endpoint + got := parseTarget(str) + if got != test { + t.Errorf("parseTarget(%q) = %v, want %v", str, got, test) + } + } +} diff --git a/vendor/google.golang.org/grpc/rpc_util.go b/vendor/google.golang.org/grpc/rpc_util.go index ace206b8f..d5f142d82 100644 --- a/vendor/google.golang.org/grpc/rpc_util.go +++ b/vendor/google.golang.org/grpc/rpc_util.go @@ -21,10 +21,12 @@ package grpc import ( "bytes" "compress/gzip" + stdctx "context" "encoding/binary" "io" "io/ioutil" "math" + "os" "sync" "time" @@ -63,6 +65,7 @@ func NewGZIPCompressor() Compressor { func (c *gzipCompressor) Do(w io.Writer, p []byte) error { z := c.pool.Get().(*gzip.Writer) + defer c.pool.Put(z) z.Reset(w) if _, err := z.Write(p); err != nil { return err @@ -131,7 +134,9 @@ type callInfo struct { creds credentials.PerRPCCredentials } -var defaultCallInfo = callInfo{failFast: true} +func defaultCallInfo() *callInfo { + return &callInfo{failFast: true} +} // CallOption configures a Call before it starts or extracts information from // a Call after it completes. @@ -287,19 +292,20 @@ func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byt return pf, msg, nil } -// encode serializes msg and prepends the message header. If msg is nil, it -// generates the message header of 0 message length. -func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer, outPayload *stats.OutPayload) ([]byte, error) { - var ( - b []byte - length uint +// encode serializes msg and returns a buffer of message header and a buffer of msg. +// If msg is nil, it generates the message header and an empty msg buffer. +func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer, outPayload *stats.OutPayload) ([]byte, []byte, error) { + var b []byte + const ( + payloadLen = 1 + sizeLen = 4 ) + if msg != nil { var err error - // TODO(zhaoq): optimize to reduce memory alloc and copying. b, err = c.Marshal(msg) if err != nil { - return nil, Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error()) + return nil, nil, Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error()) } if outPayload != nil { outPayload.Payload = msg @@ -309,39 +315,28 @@ func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer, outPayl } if cp != nil { if err := cp.Do(cbuf, b); err != nil { - return nil, Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) + return nil, nil, Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) } b = cbuf.Bytes() } - length = uint(len(b)) - } - if length > math.MaxUint32 { - return nil, Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", length) } - const ( - payloadLen = 1 - sizeLen = 4 - ) - - var buf = make([]byte, payloadLen+sizeLen+len(b)) + if uint(len(b)) > math.MaxUint32 { + return nil, nil, Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b)) + } - // Write payload format + bufHeader := make([]byte, payloadLen+sizeLen) if cp == nil { - buf[0] = byte(compressionNone) + bufHeader[0] = byte(compressionNone) } else { - buf[0] = byte(compressionMade) + bufHeader[0] = byte(compressionMade) } // Write length of b into buf - binary.BigEndian.PutUint32(buf[1:], uint32(length)) - // Copy encoded msg to buf - copy(buf[5:], b) - + binary.BigEndian.PutUint32(bufHeader[payloadLen:], uint32(len(b))) if outPayload != nil { - outPayload.WireLength = len(buf) + outPayload.WireLength = payloadLen + sizeLen + len(b) } - - return buf, nil + return bufHeader, b, nil } func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) error { @@ -393,14 +388,15 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{ } type rpcInfo struct { + failfast bool bytesSent bool bytesReceived bool } type rpcInfoContextKey struct{} -func newContextWithRPCInfo(ctx context.Context) context.Context { - return context.WithValue(ctx, rpcInfoContextKey{}, &rpcInfo{}) +func newContextWithRPCInfo(ctx context.Context, failfast bool) context.Context { + return context.WithValue(ctx, rpcInfoContextKey{}, &rpcInfo{failfast: failfast}) } func rpcInfoFromContext(ctx context.Context) (s *rpcInfo, ok bool) { @@ -410,11 +406,63 @@ func rpcInfoFromContext(ctx context.Context) (s *rpcInfo, ok bool) { func updateRPCInfoInContext(ctx context.Context, s rpcInfo) { if ss, ok := rpcInfoFromContext(ctx); ok { - *ss = s + ss.bytesReceived = s.bytesReceived + ss.bytesSent = s.bytesSent } return } +// toRPCErr converts an error into an error from the status package. +func toRPCErr(err error) error { + if _, ok := status.FromError(err); ok { + return err + } + switch e := err.(type) { + case transport.StreamError: + return status.Error(e.Code, e.Desc) + case transport.ConnectionError: + return status.Error(codes.Unavailable, e.Desc) + default: + switch err { + case context.DeadlineExceeded, stdctx.DeadlineExceeded: + return status.Error(codes.DeadlineExceeded, err.Error()) + case context.Canceled, stdctx.Canceled: + return status.Error(codes.Canceled, err.Error()) + case ErrClientConnClosing: + return status.Error(codes.FailedPrecondition, err.Error()) + } + } + return status.Error(codes.Unknown, err.Error()) +} + +// convertCode converts a standard Go error into its canonical code. Note that +// this is only used to translate the error returned by the server applications. +func convertCode(err error) codes.Code { + switch err { + case nil: + return codes.OK + case io.EOF: + return codes.OutOfRange + case io.ErrClosedPipe, io.ErrNoProgress, io.ErrShortBuffer, io.ErrShortWrite, io.ErrUnexpectedEOF: + return codes.FailedPrecondition + case os.ErrInvalid: + return codes.InvalidArgument + case context.Canceled, stdctx.Canceled: + return codes.Canceled + case context.DeadlineExceeded, stdctx.DeadlineExceeded: + return codes.DeadlineExceeded + } + switch { + case os.IsExist(err): + return codes.AlreadyExists + case os.IsNotExist(err): + return codes.NotFound + case os.IsPermission(err): + return codes.PermissionDenied + } + return codes.Unknown +} + // Code returns the error code for err if it was produced by the rpc system. // Otherwise, it returns codes.Unknown. // @@ -461,7 +509,7 @@ type MethodConfig struct { // MaxReqSize is the maximum allowed payload size for an individual request in a // stream (client->server) in bytes. The size which is measured is the serialized // payload after per-message compression (but before stream compression) in bytes. - // The actual value used is the minumum of the value specified here and the value set + // The actual value used is the minimum of the value specified here and the value set // by the application via the gRPC client API. If either one is not set, then the other // will be used. If neither is set, then the built-in default is used. MaxReqSize *int @@ -506,7 +554,7 @@ func getMaxSize(mcMax, doptMax *int, defaultVal int) *int { // SupportPackageIsVersion3 is referenced from generated protocol buffer files. // The latest support package version is 4. -// SupportPackageIsVersion3 is kept for compability. It will be removed in the +// SupportPackageIsVersion3 is kept for compatibility. It will be removed in the // next support package version update. const SupportPackageIsVersion3 = true @@ -519,6 +567,6 @@ const SupportPackageIsVersion3 = true const SupportPackageIsVersion4 = true // Version is the current grpc version. -const Version = "1.5.2" +const Version = "1.7.1" const grpcUA = "grpc-go/" + Version diff --git a/vendor/google.golang.org/grpc/rpc_util_test.go b/vendor/google.golang.org/grpc/rpc_util_test.go index 7cbad491a..23c471e2e 100644 --- a/vendor/google.golang.org/grpc/rpc_util_test.go +++ b/vendor/google.golang.org/grpc/rpc_util_test.go @@ -104,14 +104,15 @@ func TestEncode(t *testing.T) { msg proto.Message cp Compressor // outputs - b []byte - err error + hdr []byte + data []byte + err error }{ - {nil, nil, []byte{0, 0, 0, 0, 0}, nil}, + {nil, nil, []byte{0, 0, 0, 0, 0}, []byte{}, nil}, } { - b, err := encode(protoCodec{}, test.msg, nil, nil, nil) - if err != test.err || !bytes.Equal(b, test.b) { - t.Fatalf("encode(_, _, %v, _) = %v, %v\nwant %v, %v", test.cp, b, err, test.b, test.err) + hdr, data, err := encode(protoCodec{}, test.msg, nil, nil, nil) + if err != test.err || !bytes.Equal(hdr, test.hdr) || !bytes.Equal(data, test.data) { + t.Fatalf("encode(_, _, %v, _) = %v, %v, %v\nwant %v, %v, %v", test.cp, hdr, data, err, test.hdr, test.data, test.err) } } } @@ -164,8 +165,8 @@ func TestToRPCErr(t *testing.T) { // bytes. func bmEncode(b *testing.B, mSize int) { msg := &perfpb.Buffer{Body: make([]byte, mSize)} - encoded, _ := encode(protoCodec{}, msg, nil, nil, nil) - encodedSz := int64(len(encoded)) + encodeHdr, encodeData, _ := encode(protoCodec{}, msg, nil, nil, nil) + encodedSz := int64(len(encodeHdr) + len(encodeData)) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/vendor/google.golang.org/grpc/server.go b/vendor/google.golang.org/grpc/server.go index 5e9da3d95..7c882dbe6 100644 --- a/vendor/google.golang.org/grpc/server.go +++ b/vendor/google.golang.org/grpc/server.go @@ -23,6 +23,7 @@ import ( "errors" "fmt" "io" + "math" "net" "net/http" "reflect" @@ -48,7 +49,7 @@ import ( const ( defaultServerMaxReceiveMessageSize = 1024 * 1024 * 4 - defaultServerMaxSendMessageSize = 1024 * 1024 * 4 + defaultServerMaxSendMessageSize = math.MaxInt32 ) type methodHandler func(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor UnaryServerInterceptor) (interface{}, error) @@ -115,6 +116,8 @@ type options struct { keepalivePolicy keepalive.EnforcementPolicy initialWindowSize int32 initialConnWindowSize int32 + writeBufferSize int + readBufferSize int } var defaultServerOptions = options{ @@ -125,6 +128,22 @@ var defaultServerOptions = options{ // A ServerOption sets options such as credentials, codec and keepalive parameters, etc. type ServerOption func(*options) +// WriteBufferSize lets you set the size of write buffer, this determines how much data can be batched +// before doing a write on the wire. +func WriteBufferSize(s int) ServerOption { + return func(o *options) { + o.writeBufferSize = s + } +} + +// ReadBufferSize lets you set the size of read buffer, this determines how much data can be read at most +// for one read syscall. +func ReadBufferSize(s int) ServerOption { + return func(o *options) { + o.readBufferSize = s + } +} + // InitialWindowSize returns a ServerOption that sets window size for stream. // The lower bound for window size is 64K and any value smaller than that will be ignored. func InitialWindowSize(s int32) ServerOption { @@ -259,7 +278,7 @@ func StatsHandler(h stats.Handler) ServerOption { // handler that will be invoked instead of returning the "unimplemented" gRPC // error whenever a request is received for an unregistered service or method. // The handling function has full access to the Context of the request and the -// stream, and the invocation passes through interceptors. +// stream, and the invocation bypasses interceptors. func UnknownServiceHandler(streamHandler StreamHandler) ServerOption { return func(o *options) { o.unknownStreamDesc = &StreamDesc{ @@ -523,6 +542,8 @@ func (s *Server) serveHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) KeepalivePolicy: s.opts.keepalivePolicy, InitialWindowSize: s.opts.initialWindowSize, InitialConnWindowSize: s.opts.initialConnWindowSize, + WriteBufferSize: s.opts.writeBufferSize, + ReadBufferSize: s.opts.readBufferSize, } st, err := transport.NewServerTransport("http2", c, config) if err != nil { @@ -588,6 +609,30 @@ func (s *Server) serveUsingHandler(conn net.Conn) { }) } +// ServeHTTP implements the Go standard library's http.Handler +// interface by responding to the gRPC request r, by looking up +// the requested gRPC method in the gRPC server s. +// +// The provided HTTP request must have arrived on an HTTP/2 +// connection. When using the Go standard library's server, +// practically this means that the Request must also have arrived +// over TLS. +// +// To share one port (such as 443 for https) between gRPC and an +// existing http.Handler, use a root http.Handler such as: +// +// if r.ProtoMajor == 2 && strings.HasPrefix( +// r.Header.Get("Content-Type"), "application/grpc") { +// grpcServer.ServeHTTP(w, r) +// } else { +// yourMux.ServeHTTP(w, r) +// } +// +// Note that ServeHTTP uses Go's HTTP/2 server implementation which is totally +// separate from grpc-go's HTTP/2 server. Performance and features may vary +// between the two paths. ServeHTTP does not support some gRPC features +// available through grpc-go's HTTP/2 server, and it is currently EXPERIMENTAL +// and subject to change. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { st, err := transport.NewServerHandlerTransport(w, r) if err != nil { @@ -652,15 +697,15 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str if s.opts.statsHandler != nil { outPayload = &stats.OutPayload{} } - p, err := encode(s.opts.codec, msg, cp, cbuf, outPayload) + hdr, data, err := encode(s.opts.codec, msg, cp, cbuf, outPayload) if err != nil { grpclog.Errorln("grpc: server failed to encode response: ", err) return err } - if len(p) > s.opts.maxSendMessageSize { - return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(p), s.opts.maxSendMessageSize) + if len(data) > s.opts.maxSendMessageSize { + return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), s.opts.maxSendMessageSize) } - err = t.Write(stream, p, opts) + err = t.Write(stream, hdr, data, opts) if err == nil && outPayload != nil { outPayload.SentTime = time.Now() s.opts.statsHandler.HandleRPC(stream.Context(), outPayload) @@ -866,9 +911,6 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp trInfo: trInfo, statsHandler: sh, } - if ss.cp != nil { - ss.cbuf = new(bytes.Buffer) - } if trInfo != nil { trInfo.tr.LazyLog(&trInfo.firstLine, false) defer func() { diff --git a/vendor/google.golang.org/grpc/server_test.go b/vendor/google.golang.org/grpc/server_test.go index 335f56a0d..6438b5f8a 100644 --- a/vendor/google.golang.org/grpc/server_test.go +++ b/vendor/google.golang.org/grpc/server_test.go @@ -23,6 +23,8 @@ import ( "reflect" "strings" "testing" + + "google.golang.org/grpc/test/leakcheck" ) type emptyServiceServer interface{} @@ -30,6 +32,7 @@ type emptyServiceServer interface{} type testServer struct{} func TestStopBeforeServe(t *testing.T) { + defer leakcheck.Check(t) lis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("failed to create listener: %v", err) @@ -51,6 +54,7 @@ func TestStopBeforeServe(t *testing.T) { } func TestGetServiceInfo(t *testing.T) { + defer leakcheck.Check(t) testSd := ServiceDesc{ ServiceName: "grpc.testing.EmptyService", HandlerType: (*emptyServiceServer)(nil), diff --git a/vendor/google.golang.org/grpc/stats/stats.go b/vendor/google.golang.org/grpc/stats/stats.go index b85f166c5..d5aa2f793 100644 --- a/vendor/google.golang.org/grpc/stats/stats.go +++ b/vendor/google.golang.org/grpc/stats/stats.go @@ -16,6 +16,8 @@ * */ +//go:generate protoc --go_out=plugins=grpc:. grpc_testing/test.proto + // Package stats is for collecting and reporting various network and RPC stats. // This package is for monitoring purpose only. All fields are read-only. // All APIs are experimental. @@ -24,6 +26,8 @@ package stats // import "google.golang.org/grpc/stats" import ( "net" "time" + + "golang.org/x/net/context" ) // RPCStats contains stats information about RPCs. @@ -131,8 +135,6 @@ func (s *OutPayload) isRPCStats() {} type OutHeader struct { // Client is true if this OutHeader is from client side. Client bool - // WireLength is the wire length of header. - WireLength int // The following fields are valid only if Client is true. // FullMethod is the full RPC method string, i.e., /package.service/method. @@ -169,7 +171,9 @@ type End struct { Client bool // EndTime is the time when the RPC ends. EndTime time.Time - // Error is the error just happened. It implements status.Status if non-nil. + // Error is the error the RPC ended with. It is an error generated from + // status.Status and can be converted back to status.Status using + // status.FromError if non-nil. Error error } @@ -206,3 +210,85 @@ type ConnEnd struct { func (s *ConnEnd) IsClient() bool { return s.Client } func (s *ConnEnd) isConnStats() {} + +type incomingTagsKey struct{} +type outgoingTagsKey struct{} + +// SetTags attaches stats tagging data to the context, which will be sent in +// the outgoing RPC with the header grpc-tags-bin. Subsequent calls to +// SetTags will overwrite the values from earlier calls. +// +// NOTE: this is provided only for backward compatibility with existing clients +// and will likely be removed in an upcoming release. New uses should transmit +// this type of data using metadata with a different, non-reserved (i.e. does +// not begin with "grpc-") header name. +func SetTags(ctx context.Context, b []byte) context.Context { + return context.WithValue(ctx, outgoingTagsKey{}, b) +} + +// Tags returns the tags from the context for the inbound RPC. +// +// NOTE: this is provided only for backward compatibility with existing clients +// and will likely be removed in an upcoming release. New uses should transmit +// this type of data using metadata with a different, non-reserved (i.e. does +// not begin with "grpc-") header name. +func Tags(ctx context.Context) []byte { + b, _ := ctx.Value(incomingTagsKey{}).([]byte) + return b +} + +// SetIncomingTags attaches stats tagging data to the context, to be read by +// the application (not sent in outgoing RPCs). +// +// This is intended for gRPC-internal use ONLY. +func SetIncomingTags(ctx context.Context, b []byte) context.Context { + return context.WithValue(ctx, incomingTagsKey{}, b) +} + +// OutgoingTags returns the tags from the context for the outbound RPC. +// +// This is intended for gRPC-internal use ONLY. +func OutgoingTags(ctx context.Context) []byte { + b, _ := ctx.Value(outgoingTagsKey{}).([]byte) + return b +} + +type incomingTraceKey struct{} +type outgoingTraceKey struct{} + +// SetTrace attaches stats tagging data to the context, which will be sent in +// the outgoing RPC with the header grpc-trace-bin. Subsequent calls to +// SetTrace will overwrite the values from earlier calls. +// +// NOTE: this is provided only for backward compatibility with existing clients +// and will likely be removed in an upcoming release. New uses should transmit +// this type of data using metadata with a different, non-reserved (i.e. does +// not begin with "grpc-") header name. +func SetTrace(ctx context.Context, b []byte) context.Context { + return context.WithValue(ctx, outgoingTraceKey{}, b) +} + +// Trace returns the trace from the context for the inbound RPC. +// +// NOTE: this is provided only for backward compatibility with existing clients +// and will likely be removed in an upcoming release. New uses should transmit +// this type of data using metadata with a different, non-reserved (i.e. does +// not begin with "grpc-") header name. +func Trace(ctx context.Context) []byte { + b, _ := ctx.Value(incomingTraceKey{}).([]byte) + return b +} + +// SetIncomingTrace attaches stats tagging data to the context, to be read by +// the application (not sent in outgoing RPCs). It is intended for +// gRPC-internal use. +func SetIncomingTrace(ctx context.Context, b []byte) context.Context { + return context.WithValue(ctx, incomingTraceKey{}, b) +} + +// OutgoingTrace returns the trace from the context for the outbound RPC. It is +// intended for gRPC-internal use. +func OutgoingTrace(ctx context.Context) []byte { + b, _ := ctx.Value(outgoingTraceKey{}).([]byte) + return b +} diff --git a/vendor/google.golang.org/grpc/stats/stats_test.go b/vendor/google.golang.org/grpc/stats/stats_test.go index a25773d82..141324c0d 100644 --- a/vendor/google.golang.org/grpc/stats/stats_test.go +++ b/vendor/google.golang.org/grpc/stats/stats_test.go @@ -106,7 +106,7 @@ func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServ } func (s *testServer) ClientStreamCall(stream testpb.TestService_ClientStreamCallServer) error { - md, ok := metadata.FromContext(stream.Context()) + md, ok := metadata.FromIncomingContext(stream.Context()) if ok { if err := stream.SendHeader(md); err != nil { return grpc.Errorf(grpc.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil) @@ -130,7 +130,7 @@ func (s *testServer) ClientStreamCall(stream testpb.TestService_ClientStreamCall } func (s *testServer) ServerStreamCall(in *testpb.SimpleRequest, stream testpb.TestService_ServerStreamCallServer) error { - md, ok := metadata.FromContext(stream.Context()) + md, ok := metadata.FromIncomingContext(stream.Context()) if ok { if err := stream.SendHeader(md); err != nil { return grpc.Errorf(grpc.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil) @@ -330,7 +330,7 @@ func (te *test) doClientStreamCall(c *rpcConfig) ([]*testpb.SimpleRequest, *test err error ) tc := testpb.NewTestServiceClient(te.clientConn()) - stream, err := tc.ClientStreamCall(metadata.NewContext(context.Background(), testMetadata), grpc.FailFast(c.failfast)) + stream, err := tc.ClientStreamCall(metadata.NewOutgoingContext(context.Background(), testMetadata), grpc.FailFast(c.failfast)) if err != nil { return reqs, resp, err } @@ -365,7 +365,7 @@ func (te *test) doServerStreamCall(c *rpcConfig) (*testpb.SimpleRequest, []*test startID = errorID } req = &testpb.SimpleRequest{Id: startID} - stream, err := tc.ServerStreamCall(metadata.NewContext(context.Background(), testMetadata), req, grpc.FailFast(c.failfast)) + stream, err := tc.ServerStreamCall(metadata.NewOutgoingContext(context.Background(), testMetadata), req, grpc.FailFast(c.failfast)) if err != nil { return req, resps, err } @@ -444,10 +444,6 @@ func checkInHeader(t *testing.T, d *gotData, e *expectedData) { if d.ctx == nil { t.Fatalf("d.ctx = nil, want <non-nil>") } - // TODO check real length, not just > 0. - if st.WireLength <= 0 { - t.Fatalf("st.Lenght = 0, want > 0") - } if !d.client { if st.FullMethod != e.method { t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method) @@ -530,18 +526,13 @@ func checkInPayload(t *testing.T, d *gotData, e *expectedData) { func checkInTrailer(t *testing.T, d *gotData, e *expectedData) { var ( ok bool - st *stats.InTrailer ) - if st, ok = d.s.(*stats.InTrailer); !ok { + if _, ok = d.s.(*stats.InTrailer); !ok { t.Fatalf("got %T, want InTrailer", d.s) } if d.ctx == nil { t.Fatalf("d.ctx = nil, want <non-nil>") } - // TODO check real length, not just > 0. - if st.WireLength <= 0 { - t.Fatalf("st.Lenght = 0, want > 0") - } } func checkOutHeader(t *testing.T, d *gotData, e *expectedData) { @@ -555,10 +546,6 @@ func checkOutHeader(t *testing.T, d *gotData, e *expectedData) { if d.ctx == nil { t.Fatalf("d.ctx = nil, want <non-nil>") } - // TODO check real length, not just > 0. - if st.WireLength <= 0 { - t.Fatalf("st.Lenght = 0, want > 0") - } if d.client { if st.FullMethod != e.method { t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method) @@ -642,10 +629,6 @@ func checkOutTrailer(t *testing.T, d *gotData, e *expectedData) { if st.Client { t.Fatalf("st IsClient = true, want false") } - // TODO check real length, not just > 0. - if st.WireLength <= 0 { - t.Fatalf("st.Lenght = 0, want > 0") - } } func checkEnd(t *testing.T, d *gotData, e *expectedData) { @@ -830,7 +813,9 @@ func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []f err: err, } + h.mu.Lock() checkConnStats(t, h.gotConn) + h.mu.Unlock() checkServerStats(t, h.gotRPC, expect, checkFuncs) } @@ -1123,7 +1108,9 @@ func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map err: err, } + h.mu.Lock() checkConnStats(t, h.gotConn) + h.mu.Unlock() checkClientStats(t, h.gotRPC, expect, checkFuncs) } @@ -1238,3 +1225,41 @@ func TestClientStatsFullDuplexRPCNotCallingLastRecv(t *testing.T) { end: {checkEnd, 1}, }) } + +func TestTags(t *testing.T) { + b := []byte{5, 2, 4, 3, 1} + ctx := stats.SetTags(context.Background(), b) + if tg := stats.OutgoingTags(ctx); !reflect.DeepEqual(tg, b) { + t.Errorf("OutgoingTags(%v) = %v; want %v", ctx, tg, b) + } + if tg := stats.Tags(ctx); tg != nil { + t.Errorf("Tags(%v) = %v; want nil", ctx, tg) + } + + ctx = stats.SetIncomingTags(context.Background(), b) + if tg := stats.Tags(ctx); !reflect.DeepEqual(tg, b) { + t.Errorf("Tags(%v) = %v; want %v", ctx, tg, b) + } + if tg := stats.OutgoingTags(ctx); tg != nil { + t.Errorf("OutgoingTags(%v) = %v; want nil", ctx, tg) + } +} + +func TestTrace(t *testing.T) { + b := []byte{5, 2, 4, 3, 1} + ctx := stats.SetTrace(context.Background(), b) + if tr := stats.OutgoingTrace(ctx); !reflect.DeepEqual(tr, b) { + t.Errorf("OutgoingTrace(%v) = %v; want %v", ctx, tr, b) + } + if tr := stats.Trace(ctx); tr != nil { + t.Errorf("Trace(%v) = %v; want nil", ctx, tr) + } + + ctx = stats.SetIncomingTrace(context.Background(), b) + if tr := stats.Trace(ctx); !reflect.DeepEqual(tr, b) { + t.Errorf("Trace(%v) = %v; want %v", ctx, tr, b) + } + if tr := stats.OutgoingTrace(ctx); tr != nil { + t.Errorf("OutgoingTrace(%v) = %v; want nil", ctx, tr) + } +} diff --git a/vendor/google.golang.org/grpc/status/status.go b/vendor/google.golang.org/grpc/status/status.go index 68a3ac2f0..871dc4b31 100644 --- a/vendor/google.golang.org/grpc/status/status.go +++ b/vendor/google.golang.org/grpc/status/status.go @@ -28,9 +28,11 @@ package status import ( + "errors" "fmt" "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes" spb "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/grpc/codes" ) @@ -128,3 +130,39 @@ func FromError(err error) (s *Status, ok bool) { } return nil, false } + +// WithDetails returns a new status with the provided details messages appended to the status. +// If any errors are encountered, it returns nil and the first error encountered. +func (s *Status) WithDetails(details ...proto.Message) (*Status, error) { + if s.Code() == codes.OK { + return nil, errors.New("no error details for status with code OK") + } + // s.Code() != OK implies that s.Proto() != nil. + p := s.Proto() + for _, detail := range details { + any, err := ptypes.MarshalAny(detail) + if err != nil { + return nil, err + } + p.Details = append(p.Details, any) + } + return &Status{s: p}, nil +} + +// Details returns a slice of details messages attached to the status. +// If a detail cannot be decoded, the error is returned in place of the detail. +func (s *Status) Details() []interface{} { + if s == nil || s.s == nil { + return nil + } + details := make([]interface{}, 0, len(s.s.Details)) + for _, any := range s.s.Details { + detail := &ptypes.DynamicAny{} + if err := ptypes.UnmarshalAny(any, detail); err != nil { + details = append(details, err) + continue + } + details = append(details, detail.Message) + } + return details +} diff --git a/vendor/google.golang.org/grpc/status/status_test.go b/vendor/google.golang.org/grpc/status/status_test.go index 798ee8d78..69be8c9f6 100644 --- a/vendor/google.golang.org/grpc/status/status_test.go +++ b/vendor/google.golang.org/grpc/status/status_test.go @@ -19,11 +19,17 @@ package status import ( + "errors" + "fmt" "reflect" "testing" "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes" apb "github.com/golang/protobuf/ptypes/any" + dpb "github.com/golang/protobuf/ptypes/duration" + cpb "google.golang.org/genproto/googleapis/rpc/code" + epb "google.golang.org/genproto/googleapis/rpc/errdetails" spb "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/grpc/codes" ) @@ -112,3 +118,144 @@ func TestFromErrorOK(t *testing.T) { t.Fatalf("FromError(nil) = %v, %v; want <Code()=%s, Message()=%q, Err=nil>, true", s, ok, code, message) } } + +func TestStatus_ErrorDetails(t *testing.T) { + tests := []struct { + code codes.Code + details []proto.Message + }{ + { + code: codes.NotFound, + details: nil, + }, + { + code: codes.NotFound, + details: []proto.Message{ + &epb.ResourceInfo{ + ResourceType: "book", + ResourceName: "projects/1234/books/5678", + Owner: "User", + }, + }, + }, + { + code: codes.Internal, + details: []proto.Message{ + &epb.DebugInfo{ + StackEntries: []string{ + "first stack", + "second stack", + }, + }, + }, + }, + { + code: codes.Unavailable, + details: []proto.Message{ + &epb.RetryInfo{ + RetryDelay: &dpb.Duration{Seconds: 60}, + }, + &epb.ResourceInfo{ + ResourceType: "book", + ResourceName: "projects/1234/books/5678", + Owner: "User", + }, + }, + }, + } + + for _, tc := range tests { + s, err := New(tc.code, "").WithDetails(tc.details...) + if err != nil { + t.Fatalf("(%v).WithDetails(%+v) failed: %v", str(s), tc.details, err) + } + details := s.Details() + for i := range details { + if !proto.Equal(details[i].(proto.Message), tc.details[i]) { + t.Fatalf("(%v).Details()[%d] = %+v, want %+v", str(s), i, details[i], tc.details[i]) + } + } + } +} + +func TestStatus_WithDetails_Fail(t *testing.T) { + tests := []*Status{ + nil, + FromProto(nil), + New(codes.OK, ""), + } + for _, s := range tests { + if s, err := s.WithDetails(); err == nil || s != nil { + t.Fatalf("(%v).WithDetails(%+v) = %v, %v; want nil, non-nil", str(s), []proto.Message{}, s, err) + } + } +} + +func TestStatus_ErrorDetails_Fail(t *testing.T) { + tests := []struct { + s *Status + i []interface{} + }{ + { + nil, + nil, + }, + { + FromProto(nil), + nil, + }, + { + New(codes.OK, ""), + []interface{}{}, + }, + { + FromProto(&spb.Status{ + Code: int32(cpb.Code_CANCELLED), + Details: []*apb.Any{ + { + TypeUrl: "", + Value: []byte{}, + }, + mustMarshalAny(&epb.ResourceInfo{ + ResourceType: "book", + ResourceName: "projects/1234/books/5678", + Owner: "User", + }), + }, + }), + []interface{}{ + errors.New(`message type url "" is invalid`), + &epb.ResourceInfo{ + ResourceType: "book", + ResourceName: "projects/1234/books/5678", + Owner: "User", + }, + }, + }, + } + for _, tc := range tests { + got := tc.s.Details() + if !reflect.DeepEqual(got, tc.i) { + t.Errorf("(%v).Details() = %+v, want %+v", str(tc.s), got, tc.i) + } + } +} + +func str(s *Status) string { + if s == nil { + return "nil" + } + if s.s == nil { + return "<Code=OK>" + } + return fmt.Sprintf("<Code=%v, Message=%q, Details=%+v>", codes.Code(s.s.GetCode()), s.s.GetMessage(), s.s.GetDetails()) +} + +// mustMarshalAny converts a protobuf message to an any. +func mustMarshalAny(msg proto.Message) *apb.Any { + any, err := ptypes.MarshalAny(msg) + if err != nil { + panic(fmt.Sprintf("ptypes.MarshalAny(%+v) failed: %v", msg, err)) + } + return any +} diff --git a/vendor/google.golang.org/grpc/stream.go b/vendor/google.golang.org/grpc/stream.go index 1c621ba87..75eab40b1 100644 --- a/vendor/google.golang.org/grpc/stream.go +++ b/vendor/google.golang.org/grpc/stream.go @@ -27,6 +27,7 @@ import ( "golang.org/x/net/context" "golang.org/x/net/trace" + "google.golang.org/grpc/balancer" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" @@ -106,10 +107,10 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth var ( t transport.ClientTransport s *transport.Stream - put func() + done func(balancer.DoneInfo) cancel context.CancelFunc ) - c := defaultCallInfo + c := defaultCallInfo() mc := cc.GetMethodConfig(method) if mc.WaitForReady != nil { c.failFast = !*mc.WaitForReady @@ -117,11 +118,16 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth if mc.Timeout != nil { ctx, cancel = context.WithTimeout(ctx, *mc.Timeout) + defer func() { + if err != nil { + cancel() + } + }() } opts = append(cc.dopts.callOptions, opts...) for _, o := range opts { - if err := o.before(&c); err != nil { + if err := o.before(c); err != nil { return nil, toRPCErr(err) } } @@ -162,7 +168,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth } }() } - ctx = newContextWithRPCInfo(ctx) + ctx = newContextWithRPCInfo(ctx, c.failFast) sh := cc.dopts.copts.StatsHandler if sh != nil { ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: c.failFast}) @@ -183,11 +189,8 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth } }() } - gopts := BalancerGetOptions{ - BlockingWait: !c.failFast, - } for { - t, put, err = cc.getTransport(ctx, gopts) + t, done, err = cc.getTransport(ctx, c.failFast) if err != nil { // TODO(zhaoq): Probably revisit the error handling. if _, ok := status.FromError(err); ok { @@ -205,15 +208,15 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth s, err = t.NewStream(ctx, callHdr) if err != nil { - if _, ok := err.(transport.ConnectionError); ok && put != nil { + if _, ok := err.(transport.ConnectionError); ok && done != nil { // If error is connection error, transport was sending data on wire, // and we are not sure if anything has been sent on wire. // If error is not connection error, we are sure nothing has been sent. updateRPCInfoInContext(ctx, rpcInfo{bytesSent: true, bytesReceived: false}) } - if put != nil { - put() - put = nil + if done != nil { + done(balancer.DoneInfo{Err: err}) + done = nil } if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast { continue @@ -235,10 +238,10 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth dc: cc.dopts.dc, cancel: cancel, - put: put, - t: t, - s: s, - p: &parser{r: s}, + done: done, + t: t, + s: s, + p: &parser{r: s}, tracing: EnableTracing, trInfo: trInfo, @@ -246,9 +249,6 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth statsCtx: ctx, statsHandler: cc.dopts.copts.StatsHandler, } - if cc.dopts.cp != nil { - cs.cbuf = new(bytes.Buffer) - } // Listen on ctx.Done() to detect cancellation and s.Done() to detect normal termination // when there is no pending I/O operations on this stream. go func() { @@ -278,21 +278,20 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth // clientStream implements a client side Stream. type clientStream struct { opts []CallOption - c callInfo + c *callInfo t transport.ClientTransport s *transport.Stream p *parser desc *StreamDesc codec Codec cp Compressor - cbuf *bytes.Buffer dc Decompressor cancel context.CancelFunc tracing bool // set to EnableTracing when the clientStream is created. mu sync.Mutex - put func() + done func(balancer.DoneInfo) closed bool finished bool // trInfo.tr is set when the clientStream is created (if EnableTracing is true), @@ -342,7 +341,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { return } if err == io.EOF { - // Specialize the process for server streaming. SendMesg is only called + // Specialize the process for server streaming. SendMsg is only called // once when creating the stream object. io.EOF needs to be skipped when // the rpc is early finished (before the stream object is created.). // TODO: It is probably better to move this into the generated code. @@ -362,22 +361,17 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { Client: true, } } - out, err := encode(cs.codec, m, cs.cp, cs.cbuf, outPayload) - defer func() { - if cs.cbuf != nil { - cs.cbuf.Reset() - } - }() + hdr, data, err := encode(cs.codec, m, cs.cp, bytes.NewBuffer([]byte{}), outPayload) if err != nil { return err } if cs.c.maxSendMessageSize == nil { return Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)") } - if len(out) > *cs.c.maxSendMessageSize { - return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(out), *cs.c.maxSendMessageSize) + if len(data) > *cs.c.maxSendMessageSize { + return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), *cs.c.maxSendMessageSize) } - err = cs.t.Write(cs.s, out, &transport.Options{Last: false}) + err = cs.t.Write(cs.s, hdr, data, &transport.Options{Last: false}) if err == nil && outPayload != nil { outPayload.SentTime = time.Now() cs.statsHandler.HandleRPC(cs.statsCtx, outPayload) @@ -449,7 +443,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { } func (cs *clientStream) CloseSend() (err error) { - err = cs.t.Write(cs.s, nil, &transport.Options{Last: true}) + err = cs.t.Write(cs.s, nil, nil, &transport.Options{Last: true}) defer func() { if err != nil { cs.finish(err) @@ -489,15 +483,15 @@ func (cs *clientStream) finish(err error) { } }() for _, o := range cs.opts { - o.after(&cs.c) + o.after(cs.c) } - if cs.put != nil { + if cs.done != nil { updateRPCInfoInContext(cs.s.Context(), rpcInfo{ bytesSent: cs.s.BytesSent(), bytesReceived: cs.s.BytesReceived(), }) - cs.put() - cs.put = nil + cs.done(balancer.DoneInfo{Err: err}) + cs.done = nil } if cs.statsHandler != nil { end := &stats.End{ @@ -552,7 +546,6 @@ type serverStream struct { codec Codec cp Compressor dc Decompressor - cbuf *bytes.Buffer maxReceiveMessageSize int maxSendMessageSize int trInfo *traceInfo @@ -599,24 +592,23 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { } ss.mu.Unlock() } + if err != nil && err != io.EOF { + st, _ := status.FromError(toRPCErr(err)) + ss.t.WriteStatus(ss.s, st) + } }() var outPayload *stats.OutPayload if ss.statsHandler != nil { outPayload = &stats.OutPayload{} } - out, err := encode(ss.codec, m, ss.cp, ss.cbuf, outPayload) - defer func() { - if ss.cbuf != nil { - ss.cbuf.Reset() - } - }() + hdr, data, err := encode(ss.codec, m, ss.cp, bytes.NewBuffer([]byte{}), outPayload) if err != nil { return err } - if len(out) > ss.maxSendMessageSize { - return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(out), ss.maxSendMessageSize) + if len(data) > ss.maxSendMessageSize { + return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), ss.maxSendMessageSize) } - if err := ss.t.Write(ss.s, out, &transport.Options{Last: false}); err != nil { + if err := ss.t.Write(ss.s, hdr, data, &transport.Options{Last: false}); err != nil { return toRPCErr(err) } if outPayload != nil { @@ -640,6 +632,10 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { } ss.mu.Unlock() } + if err != nil && err != io.EOF { + st, _ := status.FromError(toRPCErr(err)) + ss.t.WriteStatus(ss.s, st) + } }() var inPayload *stats.InPayload if ss.statsHandler != nil { diff --git a/vendor/google.golang.org/grpc/tap/tap.go b/vendor/google.golang.org/grpc/tap/tap.go index decb6786b..22b8fb50d 100644 --- a/vendor/google.golang.org/grpc/tap/tap.go +++ b/vendor/google.golang.org/grpc/tap/tap.go @@ -32,8 +32,20 @@ type Info struct { // TODO: More to be added. } -// ServerInHandle defines the function which runs when a new stream is created -// on the server side. Note that it is executed in the per-connection I/O goroutine(s) instead -// of per-RPC goroutine. Therefore, users should NOT have any blocking/time-consuming -// work in this handle. Otherwise all the RPCs would slow down. +// ServerInHandle defines the function which runs before a new stream is created +// on the server side. If it returns a non-nil error, the stream will not be +// created and a RST_STREAM will be sent back to the client with REFUSED_STREAM. +// The client will receive an RPC error "code = Unavailable, desc = stream +// terminated by RST_STREAM with error code: REFUSED_STREAM". +// +// It's intended to be used in situations where you don't want to waste the +// resources to accept the new stream (e.g. rate-limiting). And the content of +// the error will be ignored and won't be sent back to the client. For other +// general usages, please use interceptors. +// +// Note that it is executed in the per-connection I/O goroutine(s) instead of +// per-RPC goroutine. Therefore, users should NOT have any +// blocking/time-consuming work in this handle. Otherwise all the RPCs would +// slow down. Also, for the same reason, this handle won't be called +// concurrently by gRPC. type ServerInHandle func(ctx context.Context, info *Info) (context.Context, error) diff --git a/vendor/google.golang.org/grpc/trace.go b/vendor/google.golang.org/grpc/trace.go index b419c9e3d..c1c96dedc 100644 --- a/vendor/google.golang.org/grpc/trace.go +++ b/vendor/google.golang.org/grpc/trace.go @@ -31,7 +31,7 @@ import ( // EnableTracing controls whether to trace RPCs using the golang.org/x/net/trace package. // This should only be set before any RPCs are sent or received by this program. -var EnableTracing = true +var EnableTracing bool // methodFamily returns the trace family for the given method. // It turns "/pkg.Service/GetFoo" into "pkg.Service". @@ -76,6 +76,15 @@ func (f *firstLine) String() string { return line.String() } +const truncateSize = 100 + +func truncate(x string, l int) string { + if l > len(x) { + return x + } + return x[:l] +} + // payload represents an RPC request or response payload. type payload struct { sent bool // whether this is an outgoing payload @@ -85,9 +94,9 @@ type payload struct { func (p payload) String() string { if p.sent { - return fmt.Sprintf("sent: %v", p.msg) + return truncate(fmt.Sprintf("sent: %v", p.msg), truncateSize) } - return fmt.Sprintf("recv: %v", p.msg) + return truncate(fmt.Sprintf("recv: %v", p.msg), truncateSize) } type fmtStringer struct { diff --git a/vendor/google.golang.org/grpc/transport/bdp_estimator.go b/vendor/google.golang.org/grpc/transport/bdp_estimator.go index 643652ade..8dd2ed427 100644 --- a/vendor/google.golang.org/grpc/transport/bdp_estimator.go +++ b/vendor/google.golang.org/grpc/transport/bdp_estimator.go @@ -1,3 +1,21 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + package transport import ( @@ -41,7 +59,7 @@ type bdpEstimator struct { sample uint32 // bwMax is the maximum bandwidth noted so far (bytes/sec). bwMax float64 - // bool to keep track of the begining of a new measurement cycle. + // bool to keep track of the beginning of a new measurement cycle. isSent bool // Callback to update the window sizes. updateFlowControl func(n uint32) @@ -52,7 +70,7 @@ type bdpEstimator struct { } // timesnap registers the time bdp ping was sent out so that -// network rtt can be calculated when its ack is recieved. +// network rtt can be calculated when its ack is received. // It is called (by controller) when the bdpPing is // being written on the wire. func (b *bdpEstimator) timesnap(d [8]byte) { @@ -101,7 +119,7 @@ func (b *bdpEstimator) calculate(d [8]byte) { b.rtt += (rttSample - b.rtt) * float64(alpha) } b.isSent = false - // The number of bytes accumalated so far in the sample is smaller + // The number of bytes accumulated so far in the sample is smaller // than or equal to 1.5 times the real BDP on a saturated connection. bwCurrent := float64(b.sample) / (b.rtt * float64(1.5)) if bwCurrent > b.bwMax { diff --git a/vendor/google.golang.org/grpc/transport/control.go b/vendor/google.golang.org/grpc/transport/control.go index 501eb03c4..dd1a8d42e 100644 --- a/vendor/google.golang.org/grpc/transport/control.go +++ b/vendor/google.golang.org/grpc/transport/control.go @@ -22,9 +22,11 @@ import ( "fmt" "math" "sync" + "sync/atomic" "time" "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" ) const ( @@ -44,15 +46,44 @@ const ( defaultKeepalivePolicyMinTime = time.Duration(5 * time.Minute) // max window limit set by HTTP2 Specs. maxWindowSize = math.MaxInt32 + // defaultLocalSendQuota sets is default value for number of data + // bytes that each stream can schedule before some of it being + // flushed out. + defaultLocalSendQuota = 64 * 1024 ) // The following defines various control items which could flow through // the control buffer of transport. They represent different aspects of // control tasks, e.g., flow control, settings, streaming resetting, etc. + +type headerFrame struct { + streamID uint32 + hf []hpack.HeaderField + endStream bool +} + +func (*headerFrame) item() {} + +type continuationFrame struct { + streamID uint32 + endHeaders bool + headerBlockFragment []byte +} + +type dataFrame struct { + streamID uint32 + endStream bool + d []byte + f func() +} + +func (*dataFrame) item() {} + +func (*continuationFrame) item() {} + type windowUpdate struct { streamID uint32 increment uint32 - flush bool } func (*windowUpdate) item() {} @@ -97,8 +128,9 @@ func (*ping) item() {} type quotaPool struct { c chan int - mu sync.Mutex - quota int + mu sync.Mutex + version uint32 + quota int } // newQuotaPool creates a quotaPool which has quota q available to consume. @@ -119,6 +151,10 @@ func newQuotaPool(q int) *quotaPool { func (qb *quotaPool) add(v int) { qb.mu.Lock() defer qb.mu.Unlock() + qb.lockedAdd(v) +} + +func (qb *quotaPool) lockedAdd(v int) { select { case n := <-qb.c: qb.quota += n @@ -139,6 +175,35 @@ func (qb *quotaPool) add(v int) { } } +func (qb *quotaPool) addAndUpdate(v int) { + qb.mu.Lock() + defer qb.mu.Unlock() + qb.lockedAdd(v) + // Update the version only after having added to the quota + // so that if acquireWithVesrion sees the new vesrion it is + // guaranteed to have seen the updated quota. + // Also, still keep this inside of the lock, so that when + // compareAndExecute is processing, this function doesn't + // get executed partially (quota gets updated but the version + // doesn't). + atomic.AddUint32(&(qb.version), 1) +} + +func (qb *quotaPool) acquireWithVersion() (<-chan int, uint32) { + return qb.c, atomic.LoadUint32(&(qb.version)) +} + +func (qb *quotaPool) compareAndExecute(version uint32, success, failure func()) bool { + qb.mu.Lock() + defer qb.mu.Unlock() + if version == atomic.LoadUint32(&(qb.version)) { + success() + return true + } + failure() + return false +} + // acquire returns the channel on which available quota amounts are sent. func (qb *quotaPool) acquire() <-chan int { return qb.c diff --git a/vendor/google.golang.org/grpc/transport/go16.go b/vendor/google.golang.org/grpc/transport/go16.go deleted file mode 100644 index 7cffee11e..000000000 --- a/vendor/google.golang.org/grpc/transport/go16.go +++ /dev/null @@ -1,45 +0,0 @@ -// +build go1.6,!go1.7 - -/* - * - * Copyright 2016 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package transport - -import ( - "net" - - "google.golang.org/grpc/codes" - - "golang.org/x/net/context" -) - -// dialContext connects to the address on the named network. -func dialContext(ctx context.Context, network, address string) (net.Conn, error) { - return (&net.Dialer{Cancel: ctx.Done()}).Dial(network, address) -} - -// ContextErr converts the error from context package into a StreamError. -func ContextErr(err error) StreamError { - switch err { - case context.DeadlineExceeded: - return streamErrorf(codes.DeadlineExceeded, "%v", err) - case context.Canceled: - return streamErrorf(codes.Canceled, "%v", err) - } - return streamErrorf(codes.Internal, "Unexpected error from context packet: %v", err) -} diff --git a/vendor/google.golang.org/grpc/transport/go17.go b/vendor/google.golang.org/grpc/transport/go17.go deleted file mode 100644 index 2464e69fa..000000000 --- a/vendor/google.golang.org/grpc/transport/go17.go +++ /dev/null @@ -1,46 +0,0 @@ -// +build go1.7 - -/* - * - * Copyright 2016 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package transport - -import ( - "context" - "net" - - "google.golang.org/grpc/codes" - - netctx "golang.org/x/net/context" -) - -// dialContext connects to the address on the named network. -func dialContext(ctx context.Context, network, address string) (net.Conn, error) { - return (&net.Dialer{}).DialContext(ctx, network, address) -} - -// ContextErr converts the error from context package into a StreamError. -func ContextErr(err error) StreamError { - switch err { - case context.DeadlineExceeded, netctx.DeadlineExceeded: - return streamErrorf(codes.DeadlineExceeded, "%v", err) - case context.Canceled, netctx.Canceled: - return streamErrorf(codes.Canceled, "%v", err) - } - return streamErrorf(codes.Internal, "Unexpected error from context packet: %v", err) -} diff --git a/vendor/google.golang.org/grpc/transport/handler_server.go b/vendor/google.golang.org/grpc/transport/handler_server.go index 27372b508..f1f6caf89 100644 --- a/vendor/google.golang.org/grpc/transport/handler_server.go +++ b/vendor/google.golang.org/grpc/transport/handler_server.go @@ -33,6 +33,7 @@ import ( "sync" "time" + "github.com/golang/protobuf/proto" "golang.org/x/net/context" "golang.org/x/net/http2" "google.golang.org/grpc/codes" @@ -121,6 +122,11 @@ type serverHandlerTransport struct { // ServeHTTP (HandleStreams) goroutine. The channel is closed // when WriteStatus is called. writes chan func() + + mu sync.Mutex + // streamDone indicates whether WriteStatus has been called and writes channel + // has been closed. + streamDone bool } func (ht *serverHandlerTransport) Close() error { @@ -167,11 +173,17 @@ func (ht *serverHandlerTransport) do(fn func()) error { case <-ht.closedCh: return ErrConnClosing } - } } func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) error { + ht.mu.Lock() + if ht.streamDone { + ht.mu.Unlock() + return nil + } + ht.streamDone = true + ht.mu.Unlock() err := ht.do(func() { ht.writeCommonHeaders(s) @@ -186,7 +198,15 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro h.Set("Grpc-Message", encodeGrpcMessage(m)) } - // TODO: Support Grpc-Status-Details-Bin + if p := st.Proto(); p != nil && len(p.Details) > 0 { + stBytes, err := proto.Marshal(p) + if err != nil { + // TODO: return error instead, when callers are able to handle it. + panic(err) + } + + h.Set("Grpc-Status-Details-Bin", encodeBinHeader(stBytes)) + } if md := s.Trailer(); len(md) > 0 { for k, vv := range md { @@ -225,16 +245,17 @@ func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) { // and https://golang.org/pkg/net/http/#example_ResponseWriter_trailers h.Add("Trailer", "Grpc-Status") h.Add("Trailer", "Grpc-Message") - // TODO: Support Grpc-Status-Details-Bin + h.Add("Trailer", "Grpc-Status-Details-Bin") if s.sendCompress != "" { h.Set("Grpc-Encoding", s.sendCompress) } } -func (ht *serverHandlerTransport) Write(s *Stream, data []byte, opts *Options) error { +func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { return ht.do(func() { ht.writeCommonHeaders(s) + ht.rw.Write(hdr) ht.rw.Write(data) if !opts.Delay { ht.rw.(http.Flusher).Flush() diff --git a/vendor/google.golang.org/grpc/transport/handler_server_test.go b/vendor/google.golang.org/grpc/transport/handler_server_test.go index 65d61a8b0..06fe813ca 100644 --- a/vendor/google.golang.org/grpc/transport/handler_server_test.go +++ b/vendor/google.golang.org/grpc/transport/handler_server_test.go @@ -26,10 +26,14 @@ import ( "net/http/httptest" "net/url" "reflect" + "sync" "testing" "time" + "github.com/golang/protobuf/proto" + dpb "github.com/golang/protobuf/ptypes/duration" "golang.org/x/net/context" + epb "google.golang.org/genproto/googleapis/rpc/errdetails" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -199,6 +203,7 @@ func TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { "user-agent": {"x/y a/b"}, "meta-foo": {"foo-val"}, } + if !reflect.DeepEqual(ht.headerMD, want) { return fmt.Errorf("metdata = %#v; want %#v", ht.headerMD, want) } @@ -294,7 +299,7 @@ func TestHandlerTransport_HandleStreams(t *testing.T) { wantHeader := http.Header{ "Date": nil, "Content-Type": {"application/grpc"}, - "Trailer": {"Grpc-Status", "Grpc-Message"}, + "Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"}, "Grpc-Status": {"0"}, } if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) { @@ -314,6 +319,7 @@ func TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) { func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) { st := newHandleStreamTest(t) + handleStream := func(s *Stream) { st.ht.WriteStatus(s, status.New(statusCode, msg)) } @@ -324,10 +330,11 @@ func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) wantHeader := http.Header{ "Date": nil, "Content-Type": {"application/grpc"}, - "Trailer": {"Grpc-Status", "Grpc-Message"}, + "Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"}, "Grpc-Status": {fmt.Sprint(uint32(statusCode))}, "Grpc-Message": {encodeGrpcMessage(msg)}, } + if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) { t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", st.rw.HeaderMap, wantHeader) } @@ -375,7 +382,7 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) { wantHeader := http.Header{ "Date": nil, "Content-Type": {"application/grpc"}, - "Trailer": {"Grpc-Status", "Grpc-Message"}, + "Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"}, "Grpc-Status": {"4"}, "Grpc-Message": {encodeGrpcMessage("too slow")}, } @@ -383,3 +390,73 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) { t.Errorf("Header+Trailer Map mismatch.\n got: %#v\nwant: %#v", rw.HeaderMap, wantHeader) } } + +func TestHandlerTransport_HandleStreams_MultiWriteStatus(t *testing.T) { + st := newHandleStreamTest(t) + handleStream := func(s *Stream) { + if want := "/service/foo.bar"; s.method != want { + t.Errorf("stream method = %q; want %q", s.method, want) + } + st.bodyw.Close() // no body + + var wg sync.WaitGroup + wg.Add(5) + for i := 0; i < 5; i++ { + go func() { + defer wg.Done() + st.ht.WriteStatus(s, status.New(codes.OK, "")) + }() + } + wg.Wait() + } + st.ht.HandleStreams( + func(s *Stream) { go handleStream(s) }, + func(ctx context.Context, method string) context.Context { return ctx }, + ) +} + +func TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) { + errDetails := []proto.Message{ + &epb.RetryInfo{ + RetryDelay: &dpb.Duration{Seconds: 60}, + }, + &epb.ResourceInfo{ + ResourceType: "foo bar", + ResourceName: "service.foo.bar", + Owner: "User", + }, + } + + statusCode := codes.ResourceExhausted + msg := "you are being throttled" + st, err := status.New(statusCode, msg).WithDetails(errDetails...) + if err != nil { + t.Fatal(err) + } + + stBytes, err := proto.Marshal(st.Proto()) + if err != nil { + t.Fatal(err) + } + + hst := newHandleStreamTest(t) + handleStream := func(s *Stream) { + hst.ht.WriteStatus(s, st) + } + hst.ht.HandleStreams( + func(s *Stream) { go handleStream(s) }, + func(ctx context.Context, method string) context.Context { return ctx }, + ) + wantHeader := http.Header{ + "Date": nil, + "Content-Type": {"application/grpc"}, + "Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"}, + "Grpc-Status": {fmt.Sprint(uint32(statusCode))}, + "Grpc-Message": {encodeGrpcMessage(msg)}, + "Grpc-Status-Details-Bin": {encodeBinHeader(stBytes)}, + } + + if !reflect.DeepEqual(hst.rw.HeaderMap, wantHeader) { + t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", hst.rw.HeaderMap, wantHeader) + } +} diff --git a/vendor/google.golang.org/grpc/transport/http2_client.go b/vendor/google.golang.org/grpc/transport/http2_client.go index d4fc6815e..1abb62e6d 100644 --- a/vendor/google.golang.org/grpc/transport/http2_client.go +++ b/vendor/google.golang.org/grpc/transport/http2_client.go @@ -43,6 +43,7 @@ import ( // http2Client implements the ClientTransport interface with HTTP2. type http2Client struct { ctx context.Context + cancel context.CancelFunc target string // server name/addr userAgent string md interface{} @@ -52,17 +53,6 @@ type http2Client struct { authInfo credentials.AuthInfo // auth info about the connection nextID uint32 // the next stream ID to be used - // writableChan synchronizes write access to the transport. - // A writer acquires the write lock by sending a value on writableChan - // and releases it by receiving from writableChan. - writableChan chan int - // shutdownChan is closed when Close is called. - // Blocking operations should select on shutdownChan to avoid - // blocking forever after Close. - // TODO(zhaoq): Maybe have a channel context? - shutdownChan chan struct{} - // errorChan is closed to notify the I/O error to the caller. - errorChan chan struct{} // goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor) // that the server sent GoAway on this transport. goAway chan struct{} @@ -98,7 +88,8 @@ type http2Client struct { initialWindowSize int32 - bdpEst *bdpEstimator + bdpEst *bdpEstimator + outQuotaVersion uint32 mu sync.Mutex // guard the following variables state transportState // the state of underlying connection @@ -118,7 +109,7 @@ func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error if fn != nil { return fn(ctx, addr) } - return dialContext(ctx, "tcp", addr) + return (&net.Dialer{}).DialContext(ctx, "tcp", addr) } func isTemporary(err error) bool { @@ -152,9 +143,18 @@ func isTemporary(err error) bool { // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // and starts to receive messages on it. Non-nil error returns if construction // fails. -func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (_ ClientTransport, err error) { +func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions, timeout time.Duration) (_ ClientTransport, err error) { scheme := "http" - conn, err := dial(ctx, opts.Dialer, addr.Addr) + ctx, cancel := context.WithCancel(ctx) + connectCtx, connectCancel := context.WithTimeout(ctx, timeout) + defer func() { + connectCancel() + if err != nil { + cancel() + } + }() + + conn, err := dial(connectCtx, opts.Dialer, addr.Addr) if err != nil { if opts.FailOnNonTempDialError { return nil, connectionErrorf(isTemporary(err), err, "transport: error while dialing: %v", err) @@ -173,7 +173,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( ) if creds := opts.TransportCredentials; creds != nil { scheme = "https" - conn, authInfo, err = creds.ClientHandshake(ctx, addr.Addr, conn) + conn, authInfo, err = creds.ClientHandshake(connectCtx, addr.Addr, conn) if err != nil { // Credentials handshake errors are typically considered permanent // to avoid retrying on e.g. bad certificates. @@ -197,8 +197,17 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( dynamicWindow = false } var buf bytes.Buffer + writeBufSize := defaultWriteBufSize + if opts.WriteBufferSize > 0 { + writeBufSize = opts.WriteBufferSize + } + readBufSize := defaultReadBufSize + if opts.ReadBufferSize > 0 { + readBufSize = opts.ReadBufferSize + } t := &http2Client{ ctx: ctx, + cancel: cancel, target: addr.Addr, userAgent: opts.UserAgent, md: addr.Metadata, @@ -208,14 +217,11 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( authInfo: authInfo, // The client initiated stream id is odd starting from 1. nextID: 1, - writableChan: make(chan int, 1), - shutdownChan: make(chan struct{}), - errorChan: make(chan struct{}), goAway: make(chan struct{}), awakenKeepalive: make(chan struct{}, 1), - framer: newFramer(conn), hBuf: &buf, hEnc: hpack.NewEncoder(&buf), + framer: newFramer(conn, writeBufSize, readBufSize), controlBuf: newControlBuffer(), fc: &inFlow{limit: uint32(icwz)}, sendQuotaPool: newQuotaPool(defaultWindowSize), @@ -269,12 +275,12 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( return nil, connectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) } if t.initialWindowSize != defaultWindowSize { - err = t.framer.writeSettings(true, http2.Setting{ + err = t.framer.fr.WriteSettings(http2.Setting{ ID: http2.SettingInitialWindowSize, Val: uint32(t.initialWindowSize), }) } else { - err = t.framer.writeSettings(true) + err = t.framer.fr.WriteSettings() } if err != nil { t.Close() @@ -282,31 +288,35 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( } // Adjust the connection flow control window if needed. if delta := uint32(icwz - defaultWindowSize); delta > 0 { - if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil { + if err := t.framer.fr.WriteWindowUpdate(0, delta); err != nil { t.Close() return nil, connectionErrorf(true, err, "transport: failed to write window update: %v", err) } } - go t.controller() + t.framer.writer.Flush() + go func() { + loopyWriter(t.ctx, t.controlBuf, t.itemHandler) + t.Close() + }() if t.kp.Time != infinity { go t.keepalive() } - t.writableChan <- 0 return t, nil } func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { // TODO(zhaoq): Handle uint32 overflow of Stream.id. s := &Stream{ - id: t.nextID, - done: make(chan struct{}), - goAway: make(chan struct{}), - method: callHdr.Method, - sendCompress: callHdr.SendCompress, - buf: newRecvBuffer(), - fc: &inFlow{limit: uint32(t.initialWindowSize)}, - sendQuotaPool: newQuotaPool(int(t.streamSendQuota)), - headerChan: make(chan struct{}), + id: t.nextID, + done: make(chan struct{}), + goAway: make(chan struct{}), + method: callHdr.Method, + sendCompress: callHdr.SendCompress, + buf: newRecvBuffer(), + fc: &inFlow{limit: uint32(t.initialWindowSize)}, + sendQuotaPool: newQuotaPool(int(t.streamSendQuota)), + localSendQuota: newQuotaPool(defaultLocalSendQuota), + headerChan: make(chan struct{}), } t.nextID += 2 s.requestRead = func(n int) { @@ -348,18 +358,13 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea // Create an audience string only if needed. if len(t.creds) > 0 || callHdr.Creds != nil { // Construct URI required to get auth request metadata. - var port string - if pos := strings.LastIndex(t.target, ":"); pos != -1 { - // Omit port if it is the default one. - if t.target[pos+1:] != "443" { - port = ":" + t.target[pos+1:] - } - } + // Omit port if it is the default one. + host := strings.TrimSuffix(callHdr.Host, ":443") pos := strings.LastIndex(callHdr.Method, "/") if pos == -1 { pos = len(callHdr.Method) } - audience = "https://" + callHdr.Host + port + callHdr.Method[:pos] + audience = "https://" + host + callHdr.Method[:pos] } for _, c := range t.creds { data, err := c.GetRequestMetadata(ctx, audience) @@ -372,13 +377,13 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea authData[k] = v } } - callAuthData := make(map[string]string) + callAuthData := map[string]string{} // Check if credentials.PerRPCCredentials were provided via call options. // Note: if these credentials are provided both via dial options and call // options, then both sets of credentials will be applied. if callCreds := callHdr.Creds; callCreds != nil { if !t.isSecure && callCreds.RequireTransportSecurity() { - return nil, streamErrorf(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure conneciton") + return nil, streamErrorf(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection") } data, err := callCreds.GetRequestMetadata(ctx, audience) if err != nil { @@ -404,7 +409,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea return nil, ErrConnClosing } t.mu.Unlock() - sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire()) + sq, err := wait(ctx, t.ctx, nil, nil, t.streamsQuota.acquire()) if err != nil { return nil, err } @@ -412,69 +417,41 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea if sq > 1 { t.streamsQuota.add(sq - 1) } - if _, err := wait(ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { - // Return the quota back now because there is no stream returned to the caller. - if _, ok := err.(StreamError); ok { - t.streamsQuota.add(1) - } - return nil, err - } - t.mu.Lock() - if t.state == draining { - t.mu.Unlock() - t.streamsQuota.add(1) - // Need to make t writable again so that the rpc in flight can still proceed. - t.writableChan <- 0 - return nil, ErrStreamDrain - } - if t.state != reachable { - t.mu.Unlock() - return nil, ErrConnClosing - } - s := t.newStream(ctx, callHdr) - t.activeStreams[s.id] = s - // If the number of active streams change from 0 to 1, then check if keepalive - // has gone dormant. If so, wake it up. - if len(t.activeStreams) == 1 { - select { - case t.awakenKeepalive <- struct{}{}: - t.framer.writePing(false, false, [8]byte{}) - default: - } - } - - t.mu.Unlock() - - // HPACK encodes various headers. Note that once WriteField(...) is - // called, the corresponding headers/continuation frame has to be sent - // because hpack.Encoder is stateful. - t.hBuf.Reset() - t.hEnc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"}) - t.hEnc.WriteField(hpack.HeaderField{Name: ":scheme", Value: t.scheme}) - t.hEnc.WriteField(hpack.HeaderField{Name: ":path", Value: callHdr.Method}) - t.hEnc.WriteField(hpack.HeaderField{Name: ":authority", Value: callHdr.Host}) - t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) - t.hEnc.WriteField(hpack.HeaderField{Name: "user-agent", Value: t.userAgent}) - t.hEnc.WriteField(hpack.HeaderField{Name: "te", Value: "trailers"}) + // TODO(mmukhi): Benchmark if the perfomance gets better if count the metadata and other header fields + // first and create a slice of that exact size. + // Make the slice of certain predictable size to reduce allocations made by append. + hfLen := 7 // :method, :scheme, :path, :authority, content-type, user-agent, te + hfLen += len(authData) + len(callAuthData) + headerFields := make([]hpack.HeaderField, 0, hfLen) + headerFields = append(headerFields, hpack.HeaderField{Name: ":method", Value: "POST"}) + headerFields = append(headerFields, hpack.HeaderField{Name: ":scheme", Value: t.scheme}) + headerFields = append(headerFields, hpack.HeaderField{Name: ":path", Value: callHdr.Method}) + headerFields = append(headerFields, hpack.HeaderField{Name: ":authority", Value: callHdr.Host}) + headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) + headerFields = append(headerFields, hpack.HeaderField{Name: "user-agent", Value: t.userAgent}) + headerFields = append(headerFields, hpack.HeaderField{Name: "te", Value: "trailers"}) if callHdr.SendCompress != "" { - t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress}) + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress}) } if dl, ok := ctx.Deadline(); ok { // Send out timeout regardless its value. The server can detect timeout context by itself. + // TODO(mmukhi): Perhaps this field should be updated when actually writing out to the wire. timeout := dl.Sub(time.Now()) - t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)}) + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)}) } - for k, v := range authData { - t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) } for k, v := range callAuthData { - t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + } + if b := stats.OutgoingTags(ctx); b != nil { + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-tags-bin", Value: encodeBinHeader(b)}) + } + if b := stats.OutgoingTrace(ctx); b != nil { + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-trace-bin", Value: encodeBinHeader(b)}) } - var ( - endHeaders bool - ) if md, ok := metadata.FromOutgoingContext(ctx); ok { for k, vv := range md { // HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set. @@ -482,7 +459,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea continue } for _, v := range vv { - t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) } } } @@ -492,46 +469,42 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea continue } for _, v := range vv { - t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) } } } - first := true - bufLen := t.hBuf.Len() - // Sends the headers in a single batch even when they span multiple frames. - for !endHeaders { - size := t.hBuf.Len() - if size > http2MaxFrameLen { - size = http2MaxFrameLen - } else { - endHeaders = true - } - var flush bool - if callHdr.Flush && endHeaders { - flush = true - } - if first { - // Sends a HeadersFrame to server to start a new stream. - p := http2.HeadersFrameParam{ - StreamID: s.id, - BlockFragment: t.hBuf.Next(size), - EndStream: false, - EndHeaders: endHeaders, - } - // Do a force flush for the buffered frames iff it is the last headers frame - // and there is header metadata to be sent. Otherwise, there is flushing until - // the corresponding data frame is written. - err = t.framer.writeHeaders(flush, p) - first = false - } else { - // Sends Continuation frames for the leftover headers. - err = t.framer.writeContinuation(flush, s.id, endHeaders, t.hBuf.Next(size)) - } - if err != nil { - t.notifyError(err) - return nil, connectionErrorf(true, err, "transport: %v", err) + t.mu.Lock() + if t.state == draining { + t.mu.Unlock() + t.streamsQuota.add(1) + return nil, ErrStreamDrain + } + if t.state != reachable { + t.mu.Unlock() + return nil, ErrConnClosing + } + s := t.newStream(ctx, callHdr) + t.activeStreams[s.id] = s + // If the number of active streams change from 0 to 1, then check if keepalive + // has gone dormant. If so, wake it up. + if len(t.activeStreams) == 1 { + select { + case t.awakenKeepalive <- struct{}{}: + t.controlBuf.put(&ping{data: [8]byte{}}) + // Fill the awakenKeepalive channel again as this channel must be + // kept non-writable except at the point that the keepalive() + // goroutine is waiting either to be awaken or shutdown. + t.awakenKeepalive <- struct{}{} + default: } } + t.controlBuf.put(&headerFrame{ + streamID: s.id, + hf: headerFields, + endStream: false, + }) + t.mu.Unlock() + s.mu.Lock() s.bytesSent = true s.mu.Unlock() @@ -539,7 +512,6 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea if t.statsHandler != nil { outHeader := &stats.OutHeader{ Client: true, - WireLength: bufLen, FullMethod: callHdr.Method, RemoteAddr: t.remoteAddr, LocalAddr: t.localAddr, @@ -547,7 +519,6 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea } t.statsHandler.HandleRPC(s.ctx, outHeader) } - t.writableChan <- 0 return s, nil } @@ -617,12 +588,9 @@ func (t *http2Client) Close() (err error) { t.mu.Unlock() return } - if t.state == reachable || t.state == draining { - close(t.errorChan) - } t.state = closing t.mu.Unlock() - close(t.shutdownChan) + t.cancel() err = t.conn.Close() t.mu.Lock() streams := t.activeStreams @@ -644,23 +612,18 @@ func (t *http2Client) Close() (err error) { } t.statsHandler.HandleConn(t.ctx, connEnd) } - return + return err } +// GracefulClose sets the state to draining, which prevents new streams from +// being created and causes the transport to be closed when the last active +// stream is closed. If there are no active streams, the transport is closed +// immediately. This does nothing if the transport is already draining or +// closing. func (t *http2Client) GracefulClose() error { t.mu.Lock() switch t.state { - case unreachable: - // The server may close the connection concurrently. t is not available for - // any streams. Close it now. - t.mu.Unlock() - t.Close() - return nil - case closing: - t.mu.Unlock() - return nil - } - if t.state == draining { + case closing, draining: t.mu.Unlock() return nil } @@ -675,21 +638,38 @@ func (t *http2Client) GracefulClose() error { // Write formats the data into HTTP2 data frame(s) and sends it out. The caller // should proceed only if Write returns nil. -// TODO(zhaoq): opts.Delay is ignored in this implementation. Support it later -// if it improves the performance. -func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { - r := bytes.NewBuffer(data) - for { - var p []byte - if r.Len() > 0 { +func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { + select { + case <-s.ctx.Done(): + return ContextErr(s.ctx.Err()) + case <-t.ctx.Done(): + return ErrConnClosing + default: + } + + if hdr == nil && data == nil && opts.Last { + // stream.CloseSend uses this to send an empty frame with endStream=True + t.controlBuf.put(&dataFrame{streamID: s.id, endStream: true, f: func() {}}) + return nil + } + // Add data to header frame so that we can equally distribute data across frames. + emptyLen := http2MaxFrameLen - len(hdr) + if emptyLen > len(data) { + emptyLen = len(data) + } + hdr = append(hdr, data[:emptyLen]...) + data = data[emptyLen:] + for idx, r := range [][]byte{hdr, data} { + for len(r) > 0 { size := http2MaxFrameLen // Wait until the stream has some quota to send the data. - sq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, s.sendQuotaPool.acquire()) + quotaChan, quotaVer := s.sendQuotaPool.acquireWithVersion() + sq, err := wait(s.ctx, t.ctx, s.done, s.goAway, quotaChan) if err != nil { return err } // Wait until the transport has some quota to send the data. - tq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.sendQuotaPool.acquire()) + tq, err := wait(s.ctx, t.ctx, s.done, s.goAway, t.sendQuotaPool.acquire()) if err != nil { return err } @@ -699,69 +679,51 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { if tq < size { size = tq } - p = r.Next(size) - ps := len(p) - if ps < sq { - // Overbooked stream quota. Return it back. - s.sendQuotaPool.add(sq - ps) + if size > len(r) { + size = len(r) } + p := r[:size] + ps := len(p) if ps < tq { // Overbooked transport quota. Return it back. t.sendQuotaPool.add(tq - ps) } - } - var ( - endStream bool - forceFlush bool - ) - if opts.Last && r.Len() == 0 { - endStream = true - } - // Indicate there is a writer who is about to write a data frame. - t.framer.adjustNumWriters(1) - // Got some quota. Try to acquire writing privilege on the transport. - if _, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.writableChan); err != nil { - if _, ok := err.(StreamError); ok || err == io.EOF { - // Return the connection quota back. - t.sendQuotaPool.add(len(p)) + // Acquire local send quota to be able to write to the controlBuf. + ltq, err := wait(s.ctx, t.ctx, s.done, s.goAway, s.localSendQuota.acquire()) + if err != nil { + if _, ok := err.(ConnectionError); !ok { + t.sendQuotaPool.add(ps) + } + return err } - if t.framer.adjustNumWriters(-1) == 0 { - // This writer is the last one in this batch and has the - // responsibility to flush the buffered frames. It queues - // a flush request to controlBuf instead of flushing directly - // in order to avoid the race with other writing or flushing. - t.controlBuf.put(&flushIO{}) + s.localSendQuota.add(ltq - ps) // It's ok if we make it negative. + var endStream bool + // See if this is the last frame to be written. + if opts.Last { + if len(r)-size == 0 { // No more data in r after this iteration. + if idx == 0 { // We're writing data header. + if len(data) == 0 { // There's no data to follow. + endStream = true + } + } else { // We're writing data. + endStream = true + } + } } - return err - } - select { - case <-s.ctx.Done(): - t.sendQuotaPool.add(len(p)) - if t.framer.adjustNumWriters(-1) == 0 { - t.controlBuf.put(&flushIO{}) + success := func() { + t.controlBuf.put(&dataFrame{streamID: s.id, endStream: endStream, d: p, f: func() { s.localSendQuota.add(ps) }}) + if ps < sq { + s.sendQuotaPool.lockedAdd(sq - ps) + } + r = r[ps:] + } + failure := func() { + s.sendQuotaPool.lockedAdd(sq) + } + if !s.sendQuotaPool.compareAndExecute(quotaVer, success, failure) { + t.sendQuotaPool.add(ps) + s.localSendQuota.add(ps) } - t.writableChan <- 0 - return ContextErr(s.ctx.Err()) - default: - } - if r.Len() == 0 && t.framer.adjustNumWriters(0) == 1 { - // Do a force flush iff this is last frame for the entire gRPC message - // and the caller is the only writer at this moment. - forceFlush = true - } - // If WriteData fails, all the pending streams will be handled - // by http2Client.Close(). No explicit CloseStream() needs to be - // invoked. - if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil { - t.notifyError(err) - return connectionErrorf(true, err, "transport: %v", err) - } - if t.framer.adjustNumWriters(-1) == 0 { - t.framer.flushWrite() - } - t.writableChan <- 0 - if r.Len() == 0 { - break } } if !opts.Last { @@ -792,11 +754,11 @@ func (t *http2Client) adjustWindow(s *Stream, n uint32) { return } if w := s.fc.maybeAdjust(n); w > 0 { - // Piggyback conneciton's window update along. + // Piggyback connection's window update along. if cw := t.fc.resetPendingUpdate(); cw > 0 { - t.controlBuf.put(&windowUpdate{0, cw, false}) + t.controlBuf.put(&windowUpdate{0, cw}) } - t.controlBuf.put(&windowUpdate{s.id, w, true}) + t.controlBuf.put(&windowUpdate{s.id, w}) } } @@ -811,9 +773,9 @@ func (t *http2Client) updateWindow(s *Stream, n uint32) { } if w := s.fc.onRead(n); w > 0 { if cw := t.fc.resetPendingUpdate(); cw > 0 { - t.controlBuf.put(&windowUpdate{0, cw, false}) + t.controlBuf.put(&windowUpdate{0, cw}) } - t.controlBuf.put(&windowUpdate{s.id, w, true}) + t.controlBuf.put(&windowUpdate{s.id, w}) } } @@ -827,7 +789,7 @@ func (t *http2Client) updateFlowControl(n uint32) { } t.initialWindowSize = int32(n) t.mu.Unlock() - t.controlBuf.put(&windowUpdate{0, t.fc.newLimit(n), false}) + t.controlBuf.put(&windowUpdate{0, t.fc.newLimit(n)}) t.controlBuf.put(&settings{ ack: false, ss: []http2.Setting{ @@ -857,15 +819,17 @@ func (t *http2Client) handleData(f *http2.DataFrame) { // Furthermore, if a bdpPing is being sent out we can piggyback // connection's window update for the bytes we just received. if sendBDPPing { - t.controlBuf.put(&windowUpdate{0, uint32(size), false}) + if size != 0 { // Could've been an empty data frame. + t.controlBuf.put(&windowUpdate{0, uint32(size)}) + } t.controlBuf.put(bdpPing) } else { if err := t.fc.onData(uint32(size)); err != nil { - t.notifyError(connectionErrorf(true, err, "%v", err)) + t.Close() return } if w := t.fc.onRead(uint32(size)); w > 0 { - t.controlBuf.put(&windowUpdate{0, w, true}) + t.controlBuf.put(&windowUpdate{0, w}) } } // Select the right stream to dispatch. @@ -889,7 +853,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) { } if f.Header().Flags.Has(http2.FlagDataPadded) { if w := s.fc.onRead(uint32(size) - uint32(len(f.Data()))); w > 0 { - t.controlBuf.put(&windowUpdate{s.id, w, true}) + t.controlBuf.put(&windowUpdate{s.id, w}) } } s.mu.Unlock() @@ -935,7 +899,7 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) { warningf("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error %v", f.ErrCode) statusCode = codes.Unknown } - s.finish(status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %d", f.ErrCode)) + s.finish(status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode)) s.mu.Unlock() s.write(recvMsg{err: io.EOF}) } @@ -978,10 +942,10 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { id := f.LastStreamID if id > 0 && id%2 != 1 { t.mu.Unlock() - t.notifyError(connectionErrorf(true, nil, "received illegal http2 GOAWAY frame: stream ID %d is even", f.LastStreamID)) + t.Close() return } - // A client can recieve multiple GoAways from server (look at https://github.com/grpc/grpc-go/issues/1387). + // A client can receive multiple GoAways from server (look at https://github.com/grpc/grpc-go/issues/1387). // The idea is that the first GoAway will be sent with an ID of MaxInt32 and the second GoAway will be sent after an RTT delay // with the ID of the last stream the server will process. // Therefore, when we get the first GoAway we don't really close any streams. While in case of second GoAway we @@ -992,7 +956,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { // If there are multiple GoAways the first one should always have an ID greater than the following ones. if id > t.prevGoAwayID { t.mu.Unlock() - t.notifyError(connectionErrorf(true, nil, "received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStramID %d, currently recv %d", id, f.LastStreamID)) + t.Close() return } default: @@ -1136,22 +1100,22 @@ func handleMalformedHTTP2(s *Stream, err error) { // TODO(zhaoq): Check the validity of the incoming frame sequence. func (t *http2Client) reader() { // Check the validity of server preface. - frame, err := t.framer.readFrame() + frame, err := t.framer.fr.ReadFrame() if err != nil { - t.notifyError(err) + t.Close() return } atomic.CompareAndSwapUint32(&t.activity, 0, 1) sf, ok := frame.(*http2.SettingsFrame) if !ok { - t.notifyError(err) + t.Close() return } t.handleSettings(sf) // loop to keep reading incoming messages on this transport. for { - frame, err := t.framer.readFrame() + frame, err := t.framer.fr.ReadFrame() atomic.CompareAndSwapUint32(&t.activity, 0, 1) if err != nil { // Abort an active stream if the http2.Framer returns a @@ -1163,12 +1127,12 @@ func (t *http2Client) reader() { t.mu.Unlock() if s != nil { // use error detail to provide better err message - handleMalformedHTTP2(s, streamErrorf(http2ErrConvTab[se.Code], "%v", t.framer.errorDetail())) + handleMalformedHTTP2(s, streamErrorf(http2ErrConvTab[se.Code], "%v", t.framer.fr.ErrorDetail())) } continue } else { // Transport error. - t.notifyError(err) + t.Close() return } } @@ -1212,7 +1176,7 @@ func (t *http2Client) applySettings(ss []http2.Setting) { t.mu.Lock() for _, stream := range t.activeStreams { // Adjust the sending quota for each stream. - stream.sendQuotaPool.add(int(s.Val) - int(t.streamSendQuota)) + stream.sendQuotaPool.addAndUpdate(int(s.Val) - int(t.streamSendQuota)) } t.streamSendQuota = s.Val t.mu.Unlock() @@ -1220,52 +1184,78 @@ func (t *http2Client) applySettings(ss []http2.Setting) { } } -// controller running in a separate goroutine takes charge of sending control -// frames (e.g., window update, reset stream, setting, etc.) to the server. -func (t *http2Client) controller() { - for { - select { - case i := <-t.controlBuf.get(): - t.controlBuf.load() - select { - case <-t.writableChan: - switch i := i.(type) { - case *windowUpdate: - t.framer.writeWindowUpdate(i.flush, i.streamID, i.increment) - case *settings: - if i.ack { - t.framer.writeSettingsAck(true) - t.applySettings(i.ss) - } else { - t.framer.writeSettings(true, i.ss...) - } - case *resetStream: - // If the server needs to be to intimated about stream closing, - // then we need to make sure the RST_STREAM frame is written to - // the wire before the headers of the next stream waiting on - // streamQuota. We ensure this by adding to the streamsQuota pool - // only after having acquired the writableChan to send RST_STREAM. - t.streamsQuota.add(1) - t.framer.writeRSTStream(true, i.streamID, i.code) - case *flushIO: - t.framer.flushWrite() - case *ping: - if !i.ack { - t.bdpEst.timesnap(i.data) - } - t.framer.writePing(true, i.ack, i.data) - default: - errorf("transport: http2Client.controller got unexpected item type %v\n", i) - } - t.writableChan <- 0 - continue - case <-t.shutdownChan: - return +// TODO(mmukhi): A lot of this code(and code in other places in the tranpsort layer) +// is duplicated between the client and the server. +// The transport layer needs to be refactored to take care of this. +func (t *http2Client) itemHandler(i item) error { + var err error + switch i := i.(type) { + case *dataFrame: + err = t.framer.fr.WriteData(i.streamID, i.endStream, i.d) + if err == nil { + i.f() + } + case *headerFrame: + t.hBuf.Reset() + for _, f := range i.hf { + t.hEnc.WriteField(f) + } + endHeaders := false + first := true + for !endHeaders { + size := t.hBuf.Len() + if size > http2MaxFrameLen { + size = http2MaxFrameLen + } else { + endHeaders = true } - case <-t.shutdownChan: - return + if first { + first = false + err = t.framer.fr.WriteHeaders(http2.HeadersFrameParam{ + StreamID: i.streamID, + BlockFragment: t.hBuf.Next(size), + EndStream: i.endStream, + EndHeaders: endHeaders, + }) + } else { + err = t.framer.fr.WriteContinuation( + i.streamID, + endHeaders, + t.hBuf.Next(size), + ) + } + if err != nil { + return err + } + } + case *windowUpdate: + err = t.framer.fr.WriteWindowUpdate(i.streamID, i.increment) + case *settings: + if i.ack { + t.applySettings(i.ss) + err = t.framer.fr.WriteSettingsAck() + } else { + err = t.framer.fr.WriteSettings(i.ss...) } + case *resetStream: + // If the server needs to be to intimated about stream closing, + // then we need to make sure the RST_STREAM frame is written to + // the wire before the headers of the next stream waiting on + // streamQuota. We ensure this by adding to the streamsQuota pool + // only after having acquired the writableChan to send RST_STREAM. + err = t.framer.fr.WriteRSTStream(i.streamID, i.code) + t.streamsQuota.add(1) + case *flushIO: + err = t.framer.writer.Flush() + case *ping: + if !i.ack { + t.bdpEst.timesnap(i.data) + } + err = t.framer.fr.WritePing(i.ack, i.data) + default: + errorf("transport: http2Client.controller got unexpected item type %v\n", i) } + return err } // keepalive running in a separate goroutune makes sure the connection is alive by sending pings. @@ -1289,7 +1279,7 @@ func (t *http2Client) keepalive() { case <-t.awakenKeepalive: // If the control gets here a ping has been sent // need to reset the timer with keepalive.Timeout. - case <-t.shutdownChan: + case <-t.ctx.Done(): return } } else { @@ -1308,13 +1298,13 @@ func (t *http2Client) keepalive() { } t.Close() return - case <-t.shutdownChan: + case <-t.ctx.Done(): if !timer.Stop() { <-timer.C } return } - case <-t.shutdownChan: + case <-t.ctx.Done(): if !timer.Stop() { <-timer.C } @@ -1324,25 +1314,9 @@ func (t *http2Client) keepalive() { } func (t *http2Client) Error() <-chan struct{} { - return t.errorChan + return t.ctx.Done() } func (t *http2Client) GoAway() <-chan struct{} { return t.goAway } - -func (t *http2Client) notifyError(err error) { - t.mu.Lock() - // make sure t.errorChan is closed only once. - if t.state == draining { - t.mu.Unlock() - t.Close() - return - } - if t.state == reachable { - t.state = unreachable - close(t.errorChan) - infof("transport: http2Client.notifyError got notified that the client transport was broken %v.", err) - } - t.mu.Unlock() -} diff --git a/vendor/google.golang.org/grpc/transport/http2_server.go b/vendor/google.golang.org/grpc/transport/http2_server.go index 92ab4d9c3..bad29b88a 100644 --- a/vendor/google.golang.org/grpc/transport/http2_server.go +++ b/vendor/google.golang.org/grpc/transport/http2_server.go @@ -21,6 +21,7 @@ package transport import ( "bytes" "errors" + "fmt" "io" "math" "math/rand" @@ -51,23 +52,16 @@ var ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHe // http2Server implements the ServerTransport interface with HTTP2. type http2Server struct { ctx context.Context + cancel context.CancelFunc conn net.Conn remoteAddr net.Addr localAddr net.Addr maxStreamID uint32 // max stream ID ever seen authInfo credentials.AuthInfo // auth info about the connection inTapHandle tap.ServerInHandle - // writableChan synchronizes write access to the transport. - // A writer acquires the write lock by receiving a value on writableChan - // and releases it by sending on writableChan. - writableChan chan int - // shutdownChan is closed when Close is called. - // Blocking operations should select on shutdownChan to avoid - // blocking forever after Close. - shutdownChan chan struct{} - framer *framer - hBuf *bytes.Buffer // the buffer for HPACK encoding - hEnc *hpack.Encoder // HPACK encoder + framer *framer + hBuf *bytes.Buffer // the buffer for HPACK encoding + hEnc *hpack.Encoder // HPACK encoder // The max number of concurrent streams. maxStreams uint32 // controlBuf delivers all the control related tasks (e.g., window @@ -110,7 +104,7 @@ type http2Server struct { // the per-stream outbound flow control window size set by the peer. streamSendQuota uint32 // idle is the time instant when the connection went idle. - // This is either the begining of the connection or when the number of + // This is either the beginning of the connection or when the number of // RPCs go down to 0. // When the connection is busy, this value is set to 0. idle time.Time @@ -119,7 +113,15 @@ type http2Server struct { // newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is // returned if something goes wrong. func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err error) { - framer := newFramer(conn) + writeBufSize := defaultWriteBufSize + if config.WriteBufferSize > 0 { + writeBufSize = config.WriteBufferSize + } + readBufSize := defaultReadBufSize + if config.ReadBufferSize > 0 { + readBufSize = config.ReadBufferSize + } + framer := newFramer(conn, writeBufSize, readBufSize) // Send initial settings as connection preface to client. var isettings []http2.Setting // TODO(zhaoq): Have a better way to signal "no limit" because 0 is @@ -149,12 +151,12 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err ID: http2.SettingInitialWindowSize, Val: uint32(iwz)}) } - if err := framer.writeSettings(true, isettings...); err != nil { + if err := framer.fr.WriteSettings(isettings...); err != nil { return nil, connectionErrorf(true, err, "transport: %v", err) } // Adjust the connection flow control window if needed. if delta := uint32(icwz - defaultWindowSize); delta > 0 { - if err := framer.writeWindowUpdate(true, 0, delta); err != nil { + if err := framer.fr.WriteWindowUpdate(0, delta); err != nil { return nil, connectionErrorf(true, err, "transport: %v", err) } } @@ -181,8 +183,10 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err kep.MinTime = defaultKeepalivePolicyMinTime } var buf bytes.Buffer + ctx, cancel := context.WithCancel(context.Background()) t := &http2Server{ - ctx: context.Background(), + ctx: ctx, + cancel: cancel, conn: conn, remoteAddr: conn.RemoteAddr(), localAddr: conn.LocalAddr(), @@ -196,8 +200,6 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err fc: &inFlow{limit: uint32(icwz)}, sendQuotaPool: newQuotaPool(defaultWindowSize), state: reachable, - writableChan: make(chan int, 1), - shutdownChan: make(chan struct{}), activeStreams: make(map[uint32]*Stream), streamSendQuota: defaultWindowSize, stats: config.StatsHandler, @@ -220,37 +222,43 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err connBegin := &stats.ConnBegin{} t.stats.HandleConn(t.ctx, connBegin) } - go t.controller() + t.framer.writer.Flush() + go func() { + loopyWriter(t.ctx, t.controlBuf, t.itemHandler) + t.Close() + }() go t.keepalive() - t.writableChan <- 0 return t, nil } // operateHeader takes action on the decoded headers. func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) (close bool) { - buf := newRecvBuffer() - s := &Stream{ - id: frame.Header().StreamID, - st: t, - buf: buf, - fc: &inFlow{limit: uint32(t.initialWindowSize)}, - } + streamID := frame.Header().StreamID var state decodeState for _, hf := range frame.Fields { if err := state.processHeaderField(hf); err != nil { if se, ok := err.(StreamError); ok { - t.controlBuf.put(&resetStream{s.id, statusCodeConvTab[se.Code]}) + t.controlBuf.put(&resetStream{streamID, statusCodeConvTab[se.Code]}) } return } } + buf := newRecvBuffer() + s := &Stream{ + id: streamID, + st: t, + buf: buf, + fc: &inFlow{limit: uint32(t.initialWindowSize)}, + recvCompress: state.encoding, + method: state.method, + } + if frame.StreamEnded() { // s is just created by the caller. No lock needed. s.state = streamReadDone } - s.recvCompress = state.encoding if state.timeoutSet { s.ctx, s.cancel = context.WithTimeout(t.ctx, state.timeout) } else { @@ -272,17 +280,12 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( if len(state.mdata) > 0 { s.ctx = metadata.NewIncomingContext(s.ctx, state.mdata) } - s.trReader = &transportReader{ - reader: &recvBufferReader{ - ctx: s.ctx, - recv: s.buf, - }, - windowHandler: func(n int) { - t.updateWindow(s, uint32(n)) - }, + if state.statsTags != nil { + s.ctx = stats.SetIncomingTags(s.ctx, state.statsTags) + } + if state.statsTrace != nil { + s.ctx = stats.SetIncomingTrace(s.ctx, state.statsTrace) } - s.recvCompress = state.encoding - s.method = state.method if t.inTapHandle != nil { var err error info := &tap.Info{ @@ -302,18 +305,19 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( } if uint32(len(t.activeStreams)) >= t.maxStreams { t.mu.Unlock() - t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream}) + t.controlBuf.put(&resetStream{streamID, http2.ErrCodeRefusedStream}) return } - if s.id%2 != 1 || s.id <= t.maxStreamID { + if streamID%2 != 1 || streamID <= t.maxStreamID { t.mu.Unlock() // illegal gRPC stream id. - errorf("transport: http2Server.HandleStreams received an illegal stream id: %v", s.id) + errorf("transport: http2Server.HandleStreams received an illegal stream id: %v", streamID) return true } - t.maxStreamID = s.id + t.maxStreamID = streamID s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota)) - t.activeStreams[s.id] = s + s.localSendQuota = newQuotaPool(defaultLocalSendQuota) + t.activeStreams[streamID] = s if len(t.activeStreams) == 1 { t.idle = time.Time{} } @@ -333,6 +337,15 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( } t.stats.HandleRPC(s.ctx, inHeader) } + s.trReader = &transportReader{ + reader: &recvBufferReader{ + ctx: s.ctx, + recv: s.buf, + }, + windowHandler: func(n int) { + t.updateWindow(s, uint32(n)) + }, + } handle(s) return } @@ -357,7 +370,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context. return } - frame, err := t.framer.readFrame() + frame, err := t.framer.fr.ReadFrame() if err == io.EOF || err == io.ErrUnexpectedEOF { t.Close() return @@ -377,7 +390,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context. t.handleSettings(sf) for { - frame, err := t.framer.readFrame() + frame, err := t.framer.fr.ReadFrame() atomic.StoreUint32(&t.activity, 1) if err != nil { if se, ok := err.(http2.StreamError); ok { @@ -448,9 +461,9 @@ func (t *http2Server) adjustWindow(s *Stream, n uint32) { } if w := s.fc.maybeAdjust(n); w > 0 { if cw := t.fc.resetPendingUpdate(); cw > 0 { - t.controlBuf.put(&windowUpdate{0, cw, false}) + t.controlBuf.put(&windowUpdate{0, cw}) } - t.controlBuf.put(&windowUpdate{s.id, w, true}) + t.controlBuf.put(&windowUpdate{s.id, w}) } } @@ -465,9 +478,9 @@ func (t *http2Server) updateWindow(s *Stream, n uint32) { } if w := s.fc.onRead(n); w > 0 { if cw := t.fc.resetPendingUpdate(); cw > 0 { - t.controlBuf.put(&windowUpdate{0, cw, false}) + t.controlBuf.put(&windowUpdate{0, cw}) } - t.controlBuf.put(&windowUpdate{s.id, w, true}) + t.controlBuf.put(&windowUpdate{s.id, w}) } } @@ -481,7 +494,7 @@ func (t *http2Server) updateFlowControl(n uint32) { } t.initialWindowSize = int32(n) t.mu.Unlock() - t.controlBuf.put(&windowUpdate{0, t.fc.newLimit(n), false}) + t.controlBuf.put(&windowUpdate{0, t.fc.newLimit(n)}) t.controlBuf.put(&settings{ ack: false, ss: []http2.Setting{ @@ -512,7 +525,9 @@ func (t *http2Server) handleData(f *http2.DataFrame) { // Furthermore, if a bdpPing is being sent out we can piggyback // connection's window update for the bytes we just received. if sendBDPPing { - t.controlBuf.put(&windowUpdate{0, uint32(size), false}) + if size != 0 { // Could be an empty frame. + t.controlBuf.put(&windowUpdate{0, uint32(size)}) + } t.controlBuf.put(bdpPing) } else { if err := t.fc.onData(uint32(size)); err != nil { @@ -521,7 +536,7 @@ func (t *http2Server) handleData(f *http2.DataFrame) { return } if w := t.fc.onRead(uint32(size)); w > 0 { - t.controlBuf.put(&windowUpdate{0, w, true}) + t.controlBuf.put(&windowUpdate{0, w}) } } // Select the right stream to dispatch. @@ -543,7 +558,7 @@ func (t *http2Server) handleData(f *http2.DataFrame) { } if f.Header().Flags.Has(http2.FlagDataPadded) { if w := s.fc.onRead(uint32(size) - uint32(len(f.Data()))); w > 0 { - t.controlBuf.put(&windowUpdate{s.id, w, true}) + t.controlBuf.put(&windowUpdate{s.id, w}) } } s.mu.Unlock() @@ -584,10 +599,23 @@ func (t *http2Server) handleSettings(f *http2.SettingsFrame) { ss = append(ss, s) return nil }) - // The settings will be applied once the ack is sent. t.controlBuf.put(&settings{ack: true, ss: ss}) } +func (t *http2Server) applySettings(ss []http2.Setting) { + for _, s := range ss { + if s.ID == http2.SettingInitialWindowSize { + t.mu.Lock() + for _, stream := range t.activeStreams { + stream.sendQuotaPool.addAndUpdate(int(s.Val) - int(t.streamSendQuota)) + } + t.streamSendQuota = s.Val + t.mu.Unlock() + } + + } +} + const ( maxPingStrikes = 2 defaultPingTimeout = 2 * time.Hour @@ -625,7 +653,7 @@ func (t *http2Server) handlePing(f *http2.PingFrame) { t.mu.Unlock() if ns < 1 && !t.kep.PermitWithoutStream { // Keepalive shouldn't be active thus, this new ping should - // have come after atleast defaultPingTimeout. + // have come after at least defaultPingTimeout. if t.lastPingAt.Add(defaultPingTimeout).After(now) { t.pingStrikes++ } @@ -638,6 +666,7 @@ func (t *http2Server) handlePing(f *http2.PingFrame) { if t.pingStrikes > maxPingStrikes { // Send goaway and close the connection. + errorf("transport: Got to too many pings from the client, closing the connection.") t.controlBuf.put(&goAway{code: http2.ErrCodeEnhanceYourCalm, debugData: []byte("too_many_pings"), closeConn: true}) } } @@ -654,47 +683,16 @@ func (t *http2Server) handleWindowUpdate(f *http2.WindowUpdateFrame) { } } -func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) error { - first := true - endHeaders := false - var err error - defer func() { - if err == nil { - // Reset ping strikes when seding headers since that might cause the - // peer to send ping. - atomic.StoreUint32(&t.resetPingStrikes, 1) - } - }() - // Sends the headers in a single batch. - for !endHeaders { - size := t.hBuf.Len() - if size > http2MaxFrameLen { - size = http2MaxFrameLen - } else { - endHeaders = true - } - if first { - p := http2.HeadersFrameParam{ - StreamID: s.id, - BlockFragment: b.Next(size), - EndStream: endStream, - EndHeaders: endHeaders, - } - err = t.framer.writeHeaders(endHeaders, p) - first = false - } else { - err = t.framer.writeContinuation(endHeaders, s.id, endHeaders, b.Next(size)) - } - if err != nil { - t.Close() - return connectionErrorf(true, err, "transport: %v", err) - } - } - return nil -} - // WriteHeader sends the header metedata md back to the client. func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { + select { + case <-s.ctx.Done(): + return ContextErr(s.ctx.Err()) + case <-t.ctx.Done(): + return ErrConnClosing + default: + } + s.mu.Lock() if s.headerOk || s.state == streamDone { s.mu.Unlock() @@ -710,14 +708,13 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { } md = s.header s.mu.Unlock() - if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { - return err - } - t.hBuf.Reset() - t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) + // TODO(mmukhi): Benchmark if the perfomance gets better if count the metadata and other header fields + // first and create a slice of that exact size. + headerFields := make([]hpack.HeaderField, 0, 2) // at least :status, content-type will be there if none else. + headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"}) + headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) if s.sendCompress != "" { - t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress}) + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress}) } for k, vv := range md { if isReservedHeader(k) { @@ -725,20 +722,20 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { continue } for _, v := range vv { - t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) } } - bufLen := t.hBuf.Len() - if err := t.writeHeaders(s, t.hBuf, false); err != nil { - return err - } + t.controlBuf.put(&headerFrame{ + streamID: s.id, + hf: headerFields, + endStream: false, + }) if t.stats != nil { outHeader := &stats.OutHeader{ - WireLength: bufLen, + //WireLength: // TODO(mmukhi): Revisit this later, if needed. } t.stats.HandleRPC(s.Context(), outHeader) } - t.writableChan <- 0 return nil } @@ -747,6 +744,12 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { // TODO(zhaoq): Now it indicates the end of entire stream. Revisit if early // OK is adopted. func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { + select { + case <-t.ctx.Done(): + return ErrConnClosing + default: + } + var headersSent, hasHeader bool s.mu.Lock() if s.state == streamDone { @@ -766,20 +769,15 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { headersSent = true } - if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { - return err - } - t.hBuf.Reset() + // TODO(mmukhi): Benchmark if the perfomance gets better if count the metadata and other header fields + // first and create a slice of that exact size. + headerFields := make([]hpack.HeaderField, 0, 2) // grpc-status and grpc-message will be there if none else. if !headersSent { - t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) + headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"}) + headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) } - t.hEnc.WriteField( - hpack.HeaderField{ - Name: "grpc-status", - Value: strconv.Itoa(int(st.Code())), - }) - t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())}) + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status", Value: strconv.Itoa(int(st.Code()))}) + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())}) if p := st.Proto(); p != nil && len(p.Details) > 0 { stBytes, err := proto.Marshal(p) @@ -788,7 +786,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { panic(err) } - t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)}) + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)}) } // Attach the trailer metadata. @@ -798,29 +796,32 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { continue } for _, v := range vv { - t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) } } - bufLen := t.hBuf.Len() - if err := t.writeHeaders(s, t.hBuf, true); err != nil { - t.Close() - return err - } + t.controlBuf.put(&headerFrame{ + streamID: s.id, + hf: headerFields, + endStream: true, + }) if t.stats != nil { - outTrailer := &stats.OutTrailer{ - WireLength: bufLen, - } - t.stats.HandleRPC(s.Context(), outTrailer) + t.stats.HandleRPC(s.Context(), &stats.OutTrailer{}) } t.closeStream(s) - t.writableChan <- 0 return nil } // Write converts the data into HTTP2 data frame and sends it out. Non-nil error // is returns if it fails (e.g., framing error, transport error). -func (t *http2Server) Write(s *Stream, data []byte, opts *Options) (err error) { - // TODO(zhaoq): Support multi-writers for a single stream. +func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) (err error) { + select { + case <-s.ctx.Done(): + return ContextErr(s.ctx.Err()) + case <-t.ctx.Done(): + return ErrConnClosing + default: + } + var writeHeaderFrame bool s.mu.Lock() if s.state == streamDone { @@ -834,103 +835,81 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) (err error) { if writeHeaderFrame { t.WriteHeader(s, nil) } - r := bytes.NewBuffer(data) - for { - if r.Len() == 0 { - return nil - } - size := http2MaxFrameLen - // Wait until the stream has some quota to send the data. - sq, err := wait(s.ctx, nil, nil, t.shutdownChan, s.sendQuotaPool.acquire()) - if err != nil { - return err - } - // Wait until the transport has some quota to send the data. - tq, err := wait(s.ctx, nil, nil, t.shutdownChan, t.sendQuotaPool.acquire()) - if err != nil { - return err - } - if sq < size { - size = sq - } - if tq < size { - size = tq - } - p := r.Next(size) - ps := len(p) - if ps < sq { - // Overbooked stream quota. Return it back. - s.sendQuotaPool.add(sq - ps) - } - if ps < tq { - // Overbooked transport quota. Return it back. - t.sendQuotaPool.add(tq - ps) - } - t.framer.adjustNumWriters(1) - // Got some quota. Try to acquire writing privilege on the - // transport. - if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { - if _, ok := err.(StreamError); ok { - // Return the connection quota back. - t.sendQuotaPool.add(ps) + // Add data to header frame so that we can equally distribute data across frames. + emptyLen := http2MaxFrameLen - len(hdr) + if emptyLen > len(data) { + emptyLen = len(data) + } + hdr = append(hdr, data[:emptyLen]...) + data = data[emptyLen:] + for _, r := range [][]byte{hdr, data} { + for len(r) > 0 { + size := http2MaxFrameLen + // Wait until the stream has some quota to send the data. + quotaChan, quotaVer := s.sendQuotaPool.acquireWithVersion() + sq, err := wait(s.ctx, t.ctx, nil, nil, quotaChan) + if err != nil { + return err } - if t.framer.adjustNumWriters(-1) == 0 { - // This writer is the last one in this batch and has the - // responsibility to flush the buffered frames. It queues - // a flush request to controlBuf instead of flushing directly - // in order to avoid the race with other writing or flushing. - t.controlBuf.put(&flushIO{}) + // Wait until the transport has some quota to send the data. + tq, err := wait(s.ctx, t.ctx, nil, nil, t.sendQuotaPool.acquire()) + if err != nil { + return err } - return err - } - select { - case <-s.ctx.Done(): - t.sendQuotaPool.add(ps) - if t.framer.adjustNumWriters(-1) == 0 { - t.controlBuf.put(&flushIO{}) + if sq < size { + size = sq } - t.writableChan <- 0 - return ContextErr(s.ctx.Err()) - default: - } - var forceFlush bool - if r.Len() == 0 && t.framer.adjustNumWriters(0) == 1 && !opts.Last { - forceFlush = true - } - // Reset ping strikes when sending data since this might cause - // the peer to send ping. - atomic.StoreUint32(&t.resetPingStrikes, 1) - if err := t.framer.writeData(forceFlush, s.id, false, p); err != nil { - t.Close() - return connectionErrorf(true, err, "transport: %v", err) - } - if t.framer.adjustNumWriters(-1) == 0 { - t.framer.flushWrite() - } - t.writableChan <- 0 - } - -} - -func (t *http2Server) applySettings(ss []http2.Setting) { - for _, s := range ss { - if s.ID == http2.SettingInitialWindowSize { - t.mu.Lock() - defer t.mu.Unlock() - for _, stream := range t.activeStreams { - stream.sendQuotaPool.add(int(s.Val) - int(t.streamSendQuota)) + if tq < size { + size = tq + } + if size > len(r) { + size = len(r) + } + p := r[:size] + ps := len(p) + if ps < tq { + // Overbooked transport quota. Return it back. + t.sendQuotaPool.add(tq - ps) + } + // Acquire local send quota to be able to write to the controlBuf. + ltq, err := wait(s.ctx, t.ctx, nil, nil, s.localSendQuota.acquire()) + if err != nil { + if _, ok := err.(ConnectionError); !ok { + t.sendQuotaPool.add(ps) + } + return err + } + s.localSendQuota.add(ltq - ps) // It's ok we make this negative. + // Reset ping strikes when sending data since this might cause + // the peer to send ping. + atomic.StoreUint32(&t.resetPingStrikes, 1) + success := func() { + t.controlBuf.put(&dataFrame{streamID: s.id, endStream: false, d: p, f: func() { + s.localSendQuota.add(ps) + }}) + if ps < sq { + // Overbooked stream quota. Return it back. + s.sendQuotaPool.lockedAdd(sq - ps) + } + r = r[ps:] + } + failure := func() { + s.sendQuotaPool.lockedAdd(sq) + } + if !s.sendQuotaPool.compareAndExecute(quotaVer, success, failure) { + t.sendQuotaPool.add(ps) + s.localSendQuota.add(ps) } - t.streamSendQuota = s.Val } - } + return nil } // keepalive running in a separate goroutine does the following: // 1. Gracefully closes an idle connection after a duration of keepalive.MaxConnectionIdle. // 2. Gracefully closes any connection after a duration of keepalive.MaxConnectionAge. // 3. Forcibly closes a connection after an additive period of keepalive.MaxConnectionAgeGrace over keepalive.MaxConnectionAge. -// 4. Makes sure a connection is alive by sending pings with a frequency of keepalive.Time and closes a non-resposive connection +// 4. Makes sure a connection is alive by sending pings with a frequency of keepalive.Time and closes a non-responsive connection // after an additional duration of keepalive.Timeout. func (t *http2Server) keepalive() { p := &ping{} @@ -939,7 +918,7 @@ func (t *http2Server) keepalive() { maxAge := time.NewTimer(t.kp.MaxConnectionAge) keepalive := time.NewTimer(t.kp.Time) // NOTE: All exit paths of this function should reset their - // respecitve timers. A failure to do so will cause the + // respective timers. A failure to do so will cause the // following clean-up to deadlock and eventually leak. defer func() { if !maxIdle.Stop() { @@ -982,7 +961,7 @@ func (t *http2Server) keepalive() { t.Close() // Reseting the timer so that the clean-up doesn't deadlock. maxAge.Reset(infinity) - case <-t.shutdownChan: + case <-t.ctx.Done(): } return case <-keepalive.C: @@ -1000,7 +979,7 @@ func (t *http2Server) keepalive() { pingSent = true t.controlBuf.put(p) keepalive.Reset(t.kp.Timeout) - case <-t.shutdownChan: + case <-t.ctx.Done(): return } } @@ -1008,92 +987,129 @@ func (t *http2Server) keepalive() { var goAwayPing = &ping{data: [8]byte{1, 6, 1, 8, 0, 3, 3, 9}} -// controller running in a separate goroutine takes charge of sending control -// frames (e.g., window update, reset stream, setting, etc.) to the server. -func (t *http2Server) controller() { - for { - select { - case i := <-t.controlBuf.get(): - t.controlBuf.load() +// TODO(mmukhi): A lot of this code(and code in other places in the tranpsort layer) +// is duplicated between the client and the server. +// The transport layer needs to be refactored to take care of this. +func (t *http2Server) itemHandler(i item) error { + switch i := i.(type) { + case *dataFrame: + if err := t.framer.fr.WriteData(i.streamID, i.endStream, i.d); err != nil { + return err + } + i.f() + return nil + case *headerFrame: + t.hBuf.Reset() + for _, f := range i.hf { + t.hEnc.WriteField(f) + } + first := true + endHeaders := false + for !endHeaders { + size := t.hBuf.Len() + if size > http2MaxFrameLen { + size = http2MaxFrameLen + } else { + endHeaders = true + } + var err error + if first { + first = false + err = t.framer.fr.WriteHeaders(http2.HeadersFrameParam{ + StreamID: i.streamID, + BlockFragment: t.hBuf.Next(size), + EndStream: i.endStream, + EndHeaders: endHeaders, + }) + } else { + err = t.framer.fr.WriteContinuation( + i.streamID, + endHeaders, + t.hBuf.Next(size), + ) + } + if err != nil { + return err + } + } + atomic.StoreUint32(&t.resetPingStrikes, 1) + return nil + case *windowUpdate: + return t.framer.fr.WriteWindowUpdate(i.streamID, i.increment) + case *settings: + if i.ack { + t.applySettings(i.ss) + return t.framer.fr.WriteSettingsAck() + } + return t.framer.fr.WriteSettings(i.ss...) + case *resetStream: + return t.framer.fr.WriteRSTStream(i.streamID, i.code) + case *goAway: + t.mu.Lock() + if t.state == closing { + t.mu.Unlock() + // The transport is closing. + return fmt.Errorf("transport: Connection closing") + } + sid := t.maxStreamID + if !i.headsUp { + // Stop accepting more streams now. + t.state = draining + t.mu.Unlock() + if err := t.framer.fr.WriteGoAway(sid, i.code, i.debugData); err != nil { + return err + } + if i.closeConn { + // Abruptly close the connection following the GoAway (via + // loopywriter). But flush out what's inside the buffer first. + t.framer.writer.Flush() + return fmt.Errorf("transport: Connection closing") + } + return nil + } + t.mu.Unlock() + // For a graceful close, send out a GoAway with stream ID of MaxUInt32, + // Follow that with a ping and wait for the ack to come back or a timer + // to expire. During this time accept new streams since they might have + // originated before the GoAway reaches the client. + // After getting the ack or timer expiration send out another GoAway this + // time with an ID of the max stream server intends to process. + if err := t.framer.fr.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, []byte{}); err != nil { + return err + } + if err := t.framer.fr.WritePing(false, goAwayPing.data); err != nil { + return err + } + go func() { + timer := time.NewTimer(time.Minute) + defer timer.Stop() select { - case <-t.writableChan: - switch i := i.(type) { - case *windowUpdate: - t.framer.writeWindowUpdate(i.flush, i.streamID, i.increment) - case *settings: - if i.ack { - t.framer.writeSettingsAck(true) - t.applySettings(i.ss) - } else { - t.framer.writeSettings(true, i.ss...) - } - case *resetStream: - t.framer.writeRSTStream(true, i.streamID, i.code) - case *goAway: - t.mu.Lock() - if t.state == closing { - t.mu.Unlock() - // The transport is closing. - return - } - sid := t.maxStreamID - if !i.headsUp { - // Stop accepting more streams now. - t.state = draining - t.mu.Unlock() - t.framer.writeGoAway(true, sid, i.code, i.debugData) - if i.closeConn { - // Abruptly close the connection following the GoAway. - t.Close() - } - t.writableChan <- 0 - continue - } - t.mu.Unlock() - // For a graceful close, send out a GoAway with stream ID of MaxUInt32, - // Follow that with a ping and wait for the ack to come back or a timer - // to expire. During this time accept new streams since they might have - // originated before the GoAway reaches the client. - // After getting the ack or timer expiration send out another GoAway this - // time with an ID of the max stream server intends to process. - t.framer.writeGoAway(true, math.MaxUint32, http2.ErrCodeNo, []byte{}) - t.framer.writePing(true, false, goAwayPing.data) - go func() { - timer := time.NewTimer(time.Minute) - defer timer.Stop() - select { - case <-t.drainChan: - case <-timer.C: - case <-t.shutdownChan: - return - } - t.controlBuf.put(&goAway{code: i.code, debugData: i.debugData}) - }() - case *flushIO: - t.framer.flushWrite() - case *ping: - if !i.ack { - t.bdpEst.timesnap(i.data) - } - t.framer.writePing(true, i.ack, i.data) - default: - errorf("transport: http2Server.controller got unexpected item type %v\n", i) - } - t.writableChan <- 0 - continue - case <-t.shutdownChan: + case <-t.drainChan: + case <-timer.C: + case <-t.ctx.Done(): return } - case <-t.shutdownChan: - return + t.controlBuf.put(&goAway{code: i.code, debugData: i.debugData}) + }() + return nil + case *flushIO: + return t.framer.writer.Flush() + case *ping: + if !i.ack { + t.bdpEst.timesnap(i.data) } + return t.framer.fr.WritePing(i.ack, i.data) + default: + err := status.Errorf(codes.Internal, "transport: http2Server.controller got unexpected item type %t", i) + errorf("%v", err) + return err } } // Close starts shutting down the http2Server transport. // TODO(zhaoq): Now the destruction is not blocked on any pending streams. This // could cause some resource issue. Revisit this later. -func (t *http2Server) Close() (err error) { +func (t *http2Server) Close() error { t.mu.Lock() if t.state == closing { t.mu.Unlock() @@ -1103,8 +1119,8 @@ func (t *http2Server) Close() (err error) { streams := t.activeStreams t.activeStreams = nil t.mu.Unlock() - close(t.shutdownChan) - err = t.conn.Close() + t.cancel() + err := t.conn.Close() // Cancel all active streams. for _, s := range streams { s.cancel() @@ -1113,7 +1129,7 @@ func (t *http2Server) Close() (err error) { connEnd := &stats.ConnEnd{} t.stats.HandleConn(t.ctx, connEnd) } - return + return err } // closeStream clears the footprint of a stream when the stream is not needed diff --git a/vendor/google.golang.org/grpc/transport/http_util.go b/vendor/google.golang.org/grpc/transport/http_util.go index 685c6fbf9..39f878cfd 100644 --- a/vendor/google.golang.org/grpc/transport/http_util.go +++ b/vendor/google.golang.org/grpc/transport/http_util.go @@ -28,7 +28,6 @@ import ( "net/http" "strconv" "strings" - "sync/atomic" "time" "github.com/golang/protobuf/proto" @@ -45,7 +44,8 @@ const ( // http://http2.github.io/http2-spec/#SettingValues http2InitHeaderTableSize = 4096 // http2IOBufSize specifies the buffer size for sending frames. - http2IOBufSize = 32 * 1024 + defaultWriteBufSize = 32 * 1024 + defaultReadBufSize = 32 * 1024 ) var ( @@ -111,7 +111,9 @@ type decodeState struct { timeout time.Duration method string // key-value metadata map from the peer. - mdata map[string][]string + mdata map[string][]string + statsTags []byte + statsTrace []byte } // isReservedHeader checks whether hdr belongs to HTTP2 headers @@ -235,6 +237,13 @@ func (d *decodeState) decodeResponseHeader(frame *http2.MetaHeadersFrame) error } +func (d *decodeState) addMetadata(k, v string) { + if d.mdata == nil { + d.mdata = make(map[string][]string) + } + d.mdata[k] = append(d.mdata[k], v) +} + func (d *decodeState) processHeaderField(f hpack.HeaderField) error { switch f.Name { case "content-type": @@ -275,18 +284,30 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) error { return streamErrorf(codes.Internal, "transport: malformed http-status: %v", err) } d.httpStatus = &code + case "grpc-tags-bin": + v, err := decodeBinHeader(f.Value) + if err != nil { + return streamErrorf(codes.Internal, "transport: malformed grpc-tags-bin: %v", err) + } + d.statsTags = v + d.addMetadata(f.Name, string(v)) + case "grpc-trace-bin": + v, err := decodeBinHeader(f.Value) + if err != nil { + return streamErrorf(codes.Internal, "transport: malformed grpc-trace-bin: %v", err) + } + d.statsTrace = v + d.addMetadata(f.Name, string(v)) default: - if !isReservedHeader(f.Name) || isWhitelistedPseudoHeader(f.Name) { - if d.mdata == nil { - d.mdata = make(map[string][]string) - } - v, err := decodeMetadataHeader(f.Name, f.Value) - if err != nil { - errorf("Failed to decode metadata header (%q, %q): %v", f.Name, f.Value, err) - return nil - } - d.mdata[f.Name] = append(d.mdata[f.Name], v) + if isReservedHeader(f.Name) && !isWhitelistedPseudoHeader(f.Name) { + break } + v, err := decodeMetadataHeader(f.Name, f.Value) + if err != nil { + errorf("Failed to decode metadata header (%q, %q): %v", f.Name, f.Value, err) + return nil + } + d.addMetadata(f.Name, string(v)) } return nil } @@ -454,10 +475,10 @@ type framer struct { fr *http2.Framer } -func newFramer(conn net.Conn) *framer { +func newFramer(conn net.Conn, writeBufferSize, readBufferSize int) *framer { f := &framer{ - reader: bufio.NewReaderSize(conn, http2IOBufSize), - writer: bufio.NewWriterSize(conn, http2IOBufSize), + reader: bufio.NewReaderSize(conn, readBufferSize), + writer: bufio.NewWriterSize(conn, writeBufferSize), } f.fr = http2.NewFramer(f.writer, f.reader) // Opt-in to Frame reuse API on framer to reduce garbage. @@ -466,132 +487,3 @@ func newFramer(conn net.Conn) *framer { f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil) return f } - -func (f *framer) adjustNumWriters(i int32) int32 { - return atomic.AddInt32(&f.numWriters, i) -} - -// The following writeXXX functions can only be called when the caller gets -// unblocked from writableChan channel (i.e., owns the privilege to write). - -func (f *framer) writeContinuation(forceFlush bool, streamID uint32, endHeaders bool, headerBlockFragment []byte) error { - if err := f.fr.WriteContinuation(streamID, endHeaders, headerBlockFragment); err != nil { - return err - } - if forceFlush { - return f.writer.Flush() - } - return nil -} - -func (f *framer) writeData(forceFlush bool, streamID uint32, endStream bool, data []byte) error { - if err := f.fr.WriteData(streamID, endStream, data); err != nil { - return err - } - if forceFlush { - return f.writer.Flush() - } - return nil -} - -func (f *framer) writeGoAway(forceFlush bool, maxStreamID uint32, code http2.ErrCode, debugData []byte) error { - if err := f.fr.WriteGoAway(maxStreamID, code, debugData); err != nil { - return err - } - if forceFlush { - return f.writer.Flush() - } - return nil -} - -func (f *framer) writeHeaders(forceFlush bool, p http2.HeadersFrameParam) error { - if err := f.fr.WriteHeaders(p); err != nil { - return err - } - if forceFlush { - return f.writer.Flush() - } - return nil -} - -func (f *framer) writePing(forceFlush, ack bool, data [8]byte) error { - if err := f.fr.WritePing(ack, data); err != nil { - return err - } - if forceFlush { - return f.writer.Flush() - } - return nil -} - -func (f *framer) writePriority(forceFlush bool, streamID uint32, p http2.PriorityParam) error { - if err := f.fr.WritePriority(streamID, p); err != nil { - return err - } - if forceFlush { - return f.writer.Flush() - } - return nil -} - -func (f *framer) writePushPromise(forceFlush bool, p http2.PushPromiseParam) error { - if err := f.fr.WritePushPromise(p); err != nil { - return err - } - if forceFlush { - return f.writer.Flush() - } - return nil -} - -func (f *framer) writeRSTStream(forceFlush bool, streamID uint32, code http2.ErrCode) error { - if err := f.fr.WriteRSTStream(streamID, code); err != nil { - return err - } - if forceFlush { - return f.writer.Flush() - } - return nil -} - -func (f *framer) writeSettings(forceFlush bool, settings ...http2.Setting) error { - if err := f.fr.WriteSettings(settings...); err != nil { - return err - } - if forceFlush { - return f.writer.Flush() - } - return nil -} - -func (f *framer) writeSettingsAck(forceFlush bool) error { - if err := f.fr.WriteSettingsAck(); err != nil { - return err - } - if forceFlush { - return f.writer.Flush() - } - return nil -} - -func (f *framer) writeWindowUpdate(forceFlush bool, streamID, incr uint32) error { - if err := f.fr.WriteWindowUpdate(streamID, incr); err != nil { - return err - } - if forceFlush { - return f.writer.Flush() - } - return nil -} - -func (f *framer) flushWrite() error { - return f.writer.Flush() -} - -func (f *framer) readFrame() (http2.Frame, error) { - return f.fr.ReadFrame() -} - -func (f *framer) errorDetail() error { - return f.fr.ErrorDetail() -} diff --git a/vendor/google.golang.org/grpc/transport/transport.go b/vendor/google.golang.org/grpc/transport/transport.go index 14eb1627f..bde8fa5c3 100644 --- a/vendor/google.golang.org/grpc/transport/transport.go +++ b/vendor/google.golang.org/grpc/transport/transport.go @@ -21,10 +21,12 @@ package transport // import "google.golang.org/grpc/transport" import ( + stdctx "context" "fmt" "io" "net" "sync" + "time" "golang.org/x/net/context" "golang.org/x/net/http2" @@ -67,20 +69,20 @@ func newRecvBuffer() *recvBuffer { func (b *recvBuffer) put(r recvMsg) { b.mu.Lock() - defer b.mu.Unlock() if len(b.backlog) == 0 { select { case b.c <- r: + b.mu.Unlock() return default: } } b.backlog = append(b.backlog, r) + b.mu.Unlock() } func (b *recvBuffer) load() { b.mu.Lock() - defer b.mu.Unlock() if len(b.backlog) > 0 { select { case b.c <- b.backlog[0]: @@ -89,6 +91,7 @@ func (b *recvBuffer) load() { default: } } + b.mu.Unlock() } // get returns the channel that receives a recvMsg in the buffer. @@ -116,7 +119,11 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) { if r.err != nil { return 0, r.err } - defer func() { r.err = err }() + n, r.err = r.read(p) + return n, r.err +} + +func (r *recvBufferReader) read(p []byte) (n int, err error) { if r.last != nil && len(r.last) > 0 { // Read remaining data left in last call. copied := copy(p, r.last) @@ -160,20 +167,20 @@ func newControlBuffer() *controlBuffer { func (b *controlBuffer) put(r item) { b.mu.Lock() - defer b.mu.Unlock() if len(b.backlog) == 0 { select { case b.c <- r: + b.mu.Unlock() return default: } } b.backlog = append(b.backlog, r) + b.mu.Unlock() } func (b *controlBuffer) load() { b.mu.Lock() - defer b.mu.Unlock() if len(b.backlog) > 0 { select { case b.c <- b.backlog[0]: @@ -182,6 +189,7 @@ func (b *controlBuffer) load() { default: } } + b.mu.Unlock() } // get returns the channel that receives an item in the buffer. @@ -231,7 +239,8 @@ type Stream struct { // is used to adjust flow control, if need be. requestRead func(int) - sendQuotaPool *quotaPool + sendQuotaPool *quotaPool + localSendQuota *quotaPool // Close headerChan to indicate the end of reception of header metadata. headerChan chan struct{} // header caches the received header metadata. @@ -309,8 +318,9 @@ func (s *Stream) Header() (metadata.MD, error) { // side only. func (s *Stream) Trailer() metadata.MD { s.mu.RLock() - defer s.mu.RUnlock() - return s.trailer.Copy() + c := s.trailer.Copy() + s.mu.RUnlock() + return c } // ServerTransport returns the underlying ServerTransport for the stream. @@ -338,14 +348,16 @@ func (s *Stream) Status() *status.Status { // Server side only. func (s *Stream) SetHeader(md metadata.MD) error { s.mu.Lock() - defer s.mu.Unlock() if s.headerOk || s.state == streamDone { + s.mu.Unlock() return ErrIllegalHeaderWrite } if md.Len() == 0 { + s.mu.Unlock() return nil } s.header = metadata.Join(s.header, md) + s.mu.Unlock() return nil } @@ -356,8 +368,8 @@ func (s *Stream) SetTrailer(md metadata.MD) error { return nil } s.mu.Lock() - defer s.mu.Unlock() s.trailer = metadata.Join(s.trailer, md) + s.mu.Unlock() return nil } @@ -408,15 +420,17 @@ func (s *Stream) finish(st *status.Status) { // BytesSent indicates whether any bytes have been sent on this stream. func (s *Stream) BytesSent() bool { s.mu.Lock() - defer s.mu.Unlock() - return s.bytesSent + bs := s.bytesSent + s.mu.Unlock() + return bs } // BytesReceived indicates whether any bytes have been received on this stream. func (s *Stream) BytesReceived() bool { s.mu.Lock() - defer s.mu.Unlock() - return s.bytesReceived + br := s.bytesReceived + s.mu.Unlock() + return br } // GoString is implemented by Stream so context.String() won't @@ -445,7 +459,6 @@ type transportState int const ( reachable transportState = iota - unreachable closing draining ) @@ -460,6 +473,8 @@ type ServerConfig struct { KeepalivePolicy keepalive.EnforcementPolicy InitialWindowSize int32 InitialConnWindowSize int32 + WriteBufferSize int + ReadBufferSize int } // NewServerTransport creates a ServerTransport with conn or non-nil error @@ -487,10 +502,14 @@ type ConnectOptions struct { KeepaliveParams keepalive.ClientParameters // StatsHandler stores the handler for stats. StatsHandler stats.Handler - // InitialWindowSize sets the intial window size for a stream. + // InitialWindowSize sets the initial window size for a stream. InitialWindowSize int32 - // InitialConnWindowSize sets the intial window size for a connection. + // InitialConnWindowSize sets the initial window size for a connection. InitialConnWindowSize int32 + // WriteBufferSize sets the size of write buffer which in turn determines how much data can be batched before it's written on the wire. + WriteBufferSize int + // ReadBufferSize sets the size of read buffer, which in turn determines how much data can be read at most for one read syscall. + ReadBufferSize int } // TargetInfo contains the information of the target such as network address and metadata. @@ -501,8 +520,8 @@ type TargetInfo struct { // NewClientTransport establishes the transport with the required ConnectOptions // and returns it to the caller. -func NewClientTransport(ctx context.Context, target TargetInfo, opts ConnectOptions) (ClientTransport, error) { - return newHTTP2Client(ctx, target, opts) +func NewClientTransport(ctx context.Context, target TargetInfo, opts ConnectOptions, timeout time.Duration) (ClientTransport, error) { + return newHTTP2Client(ctx, target, opts, timeout) } // Options provides additional hints and information for message @@ -514,7 +533,7 @@ type Options struct { // Delay is a hint to the transport implementation for whether // the data could be buffered for a batching write. The - // Transport implementation may ignore the hint. + // transport implementation may ignore the hint. Delay bool } @@ -560,7 +579,7 @@ type ClientTransport interface { // Write sends the data for the given stream. A nil stream indicates // the write is to be performed on the transport as a whole. - Write(s *Stream, data []byte, opts *Options) error + Write(s *Stream, hdr []byte, data []byte, opts *Options) error // NewStream creates a Stream for an RPC. NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error) @@ -602,7 +621,7 @@ type ServerTransport interface { // Write sends the data for the given stream. // Write may not be called on all streams. - Write(s *Stream, data []byte, opts *Options) error + Write(s *Stream, hdr []byte, data []byte, opts *Options) error // WriteStatus sends the status of a stream to the client. WriteStatus is // the final call made on a stream and always occurs. @@ -684,34 +703,33 @@ func (e StreamError) Error() string { return fmt.Sprintf("stream error: code = %s desc = %q", e.Code, e.Desc) } -// wait blocks until it can receive from ctx.Done, closing, or proceed. -// If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err. -// If it receives from done, it returns 0, io.EOF if ctx is not done; otherwise -// it return the StreamError for ctx.Err. -// If it receives from goAway, it returns 0, ErrStreamDrain. -// If it receives from closing, it returns 0, ErrConnClosing. -// If it receives from proceed, it returns the received integer, nil. -func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-chan int) (int, error) { +// wait blocks until it can receive from one of the provided contexts or channels +func wait(ctx, tctx context.Context, done, goAway <-chan struct{}, proceed <-chan int) (int, error) { select { case <-ctx.Done(): return 0, ContextErr(ctx.Err()) case <-done: - // User cancellation has precedence. - select { - case <-ctx.Done(): - return 0, ContextErr(ctx.Err()) - default: - } return 0, io.EOF case <-goAway: return 0, ErrStreamDrain - case <-closing: + case <-tctx.Done(): return 0, ErrConnClosing case i := <-proceed: return i, nil } } +// ContextErr converts the error from context package into a StreamError. +func ContextErr(err error) StreamError { + switch err { + case context.DeadlineExceeded, stdctx.DeadlineExceeded: + return streamErrorf(codes.DeadlineExceeded, "%v", err) + case context.Canceled, stdctx.Canceled: + return streamErrorf(codes.Canceled, "%v", err) + } + return streamErrorf(codes.Internal, "Unexpected error from context packet: %v", err) +} + // GoAwayReason contains the reason for the GoAway frame received. type GoAwayReason uint8 @@ -721,6 +739,39 @@ const ( // NoReason is the default value when GoAway frame is received. NoReason GoAwayReason = 1 // TooManyPings indicates that a GoAway frame with ErrCodeEnhanceYourCalm - // was recieved and that the debug data said "too_many_pings". + // was received and that the debug data said "too_many_pings". TooManyPings GoAwayReason = 2 ) + +// loopyWriter is run in a separate go routine. It is the single code path that will +// write data on wire. +func loopyWriter(ctx context.Context, cbuf *controlBuffer, handler func(item) error) { + for { + select { + case i := <-cbuf.get(): + cbuf.load() + if err := handler(i); err != nil { + return + } + case <-ctx.Done(): + return + } + hasData: + for { + select { + case i := <-cbuf.get(): + cbuf.load() + if err := handler(i); err != nil { + return + } + case <-ctx.Done(): + return + default: + if err := handler(&flushIO{}); err != nil { + return + } + break hasData + } + } + } +} diff --git a/vendor/google.golang.org/grpc/transport/transport_test.go b/vendor/google.golang.org/grpc/transport/transport_test.go index 861047889..e1dd080a1 100644 --- a/vendor/google.golang.org/grpc/transport/transport_test.go +++ b/vendor/google.golang.org/grpc/transport/transport_test.go @@ -49,6 +49,7 @@ type server struct { startedErr chan error // error (or nil) with server start value mu sync.Mutex conns map[ServerTransport]bool + h *testStreamHandler } var ( @@ -60,7 +61,8 @@ var ( ) type testStreamHandler struct { - t *http2Server + t *http2Server + notify chan struct{} } type hType int @@ -68,6 +70,7 @@ type hType int const ( normal hType = iota suspended + notifyCall misbehaved encodingRequiredStatus invalidHeaderField @@ -76,6 +79,19 @@ const ( pingpong ) +func (h *testStreamHandler) handleStreamAndNotify(s *Stream) { + if h.notify == nil { + return + } + go func() { + select { + case <-h.notify: + default: + close(h.notify) + } + }() +} + func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) { req := expectedRequest resp := expectedResponse @@ -92,7 +108,7 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) { t.Fatalf("handleStream got %v, want %v", p, req) } // send a response back to the client. - h.t.Write(s, resp, &Options{}) + h.t.Write(s, nil, resp, &Options{}) // send the trailer to end the stream. h.t.WriteStatus(s, status.New(codes.OK, "")) } @@ -112,17 +128,10 @@ func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *Stream) { buf[0] = byte(0) binary.BigEndian.PutUint32(buf[1:], uint32(sz)) copy(buf[5:], msg) - h.t.Write(s, buf, &Options{}) + h.t.Write(s, nil, buf, &Options{}) } } -// handleStreamSuspension blocks until s.ctx is canceled. -func (h *testStreamHandler) handleStreamSuspension(s *Stream) { - go func() { - <-s.ctx.Done() - }() -} - func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) { conn, ok := s.ServerTransport().(*http2Server) if !ok { @@ -131,7 +140,6 @@ func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) { var sent int p := make([]byte, http2MaxFrameLen) for sent < initialWindowSize { - <-conn.writableChan n := initialWindowSize - sent // The last message may be smaller than http2MaxFrameLen if n <= http2MaxFrameLen { @@ -144,11 +152,7 @@ func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) { p = make([]byte, n+1) } } - if err := conn.framer.writeData(true, s.id, false, p); err != nil { - conn.writableChan <- 0 - break - } - conn.writableChan <- 0 + conn.controlBuf.put(&dataFrame{s.id, false, p, func() {}}) sent += len(p) } } @@ -159,13 +163,13 @@ func (h *testStreamHandler) handleStreamEncodingRequiredStatus(t *testing.T, s * } func (h *testStreamHandler) handleStreamInvalidHeaderField(t *testing.T, s *Stream) { - <-h.t.writableChan - h.t.hBuf.Reset() - h.t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: expectedInvalidHeaderField}) - if err := h.t.writeHeaders(s, h.t.hBuf, false); err != nil { - t.Fatalf("Failed to write headers: %v", err) - } - h.t.writableChan <- 0 + headerFields := []hpack.HeaderField{} + headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: expectedInvalidHeaderField}) + h.t.controlBuf.put(&headerFrame{ + streamID: s.id, + hf: headerFields, + endStream: false, + }) } func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) { @@ -190,7 +194,7 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) { t.Fatalf("handleStream got %v, want %v", p, req) } // send a response back to the client. - h.t.Write(s, resp, &Options{}) + h.t.Write(s, nil, resp, &Options{}) // send the trailer to end the stream. h.t.WriteStatus(s, status.New(codes.OK, "")) } @@ -215,7 +219,7 @@ func (h *testStreamHandler) handleStreamDelayWrite(t *testing.T, s *Stream) { // Wait before sending. Give time to client to start reading // before server starts sending. time.Sleep(2 * time.Second) - h.t.Write(s, resp, &Options{}) + h.t.Write(s, nil, resp, &Options{}) // send the trailer to end the stream. h.t.WriteStatus(s, status.New(codes.OK, "")) } @@ -256,11 +260,17 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT return } s.conns[transport] = true + h := &testStreamHandler{t: transport.(*http2Server)} + s.h = h s.mu.Unlock() - h := &testStreamHandler{transport.(*http2Server)} switch ht { + case notifyCall: + go transport.HandleStreams(h.handleStreamAndNotify, + func(ctx context.Context, _ string) context.Context { + return ctx + }) case suspended: - go transport.HandleStreams(h.handleStreamSuspension, + go transport.HandleStreams(func(*Stream) {}, // Do nothing to handle the stream. func(ctx context.Context, method string) context.Context { return ctx }) @@ -347,7 +357,7 @@ func setUpWithOptions(t *testing.T, port int, serverConfig *ServerConfig, ht hTy target := TargetInfo{ Addr: addr, } - ct, connErr = NewClientTransport(context.Background(), target, copts) + ct, connErr = NewClientTransport(context.Background(), target, copts, 2*time.Second) if connErr != nil { t.Fatalf("failed to create transport: %v", connErr) } @@ -370,7 +380,7 @@ func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, done chan net.Con } done <- conn }() - tr, err := NewClientTransport(context.Background(), TargetInfo{Addr: lis.Addr().String()}, copts) + tr, err := NewClientTransport(context.Background(), TargetInfo{Addr: lis.Addr().String()}, copts, 2*time.Second) if err != nil { // Server clean-up. lis.Close() @@ -808,7 +818,7 @@ func TestClientSendAndReceive(t *testing.T) { Last: true, Delay: false, } - if err := ct.Write(s1, expectedRequest, &opts); err != nil && err != io.EOF { + if err := ct.Write(s1, nil, expectedRequest, &opts); err != nil && err != io.EOF { t.Fatalf("failed to send data: %v", err) } p := make([]byte, len(expectedResponse)) @@ -845,7 +855,7 @@ func performOneRPC(ct ClientTransport) { Last: true, Delay: false, } - if err := ct.Write(s, expectedRequest, &opts); err == nil || err == io.EOF { + if err := ct.Write(s, []byte{}, expectedRequest, &opts); err == nil || err == io.EOF { time.Sleep(5 * time.Millisecond) // The following s.Recv()'s could error out because the // underlying transport is gone. @@ -889,7 +899,7 @@ func TestLargeMessage(t *testing.T) { if err != nil { t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err) } - if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { + if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err) } p := make([]byte, len(expectedResponseLarge)) @@ -921,7 +931,7 @@ func TestLargeMessageWithDelayRead(t *testing.T) { if err != nil { t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err) } - if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { + if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err) } p := make([]byte, len(expectedResponseLarge)) @@ -959,7 +969,7 @@ func TestLargeMessageDelayWrite(t *testing.T) { // Give time to server to start reading before client starts sending. time.Sleep(2 * time.Second) - if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { + if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err) } p := make([]byte, len(expectedResponseLarge)) @@ -1005,7 +1015,7 @@ func TestGracefulClose(t *testing.T) { Delay: false, } // The stream which was created before graceful close can still proceed. - if err := ct.Write(s, expectedRequest, &opts); err != nil && err != io.EOF { + if err := ct.Write(s, nil, expectedRequest, &opts); err != nil && err != io.EOF { t.Fatalf("%v.Write(_, _, _) = %v, want <nil>", ct, err) } p := make([]byte, len(expectedResponse)) @@ -1027,14 +1037,15 @@ func TestLargeMessageSuspension(t *testing.T) { Method: "foo.Large", } // Set a long enough timeout for writing a large message out. - ctx, _ := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() s, err := ct.NewStream(ctx, callHdr) if err != nil { t.Fatalf("failed to open stream: %v", err) } // Write should not be done successfully due to flow control. msg := make([]byte, initialWindowSize*8) - err = ct.Write(s, msg, &Options{Last: true, Delay: false}) + err = ct.Write(s, nil, msg, &Options{Last: true, Delay: false}) expectedErr := streamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded) if err != expectedErr { t.Fatalf("Write got %v, want %v", err, expectedErr) @@ -1156,12 +1167,7 @@ func TestServerContextCanceledOnClosedConnection(t *testing.T) { if err != nil { t.Fatalf("Failed to open stream: %v", err) } - // Make sure the headers frame is flushed out. - <-cc.writableChan - if err = cc.framer.writeData(true, s.id, false, make([]byte, http2MaxFrameLen)); err != nil { - t.Fatalf("Failed to write data: %v", err) - } - cc.writableChan <- 0 + cc.controlBuf.put(&dataFrame{s.id, false, make([]byte, http2MaxFrameLen), func() {}}) // Loop until the server side stream is created. var ss *Stream for { @@ -1192,7 +1198,7 @@ func TestClientConnDecoupledFromApplicationRead(t *testing.T) { InitialWindowSize: defaultWindowSize, InitialConnWindowSize: defaultWindowSize, } - server, client := setUpWithOptions(t, 0, &ServerConfig{}, suspended, connectOptions) + server, client := setUpWithOptions(t, 0, &ServerConfig{}, notifyCall, connectOptions) defer server.stop() defer client.Close() @@ -1211,66 +1217,56 @@ func TestClientConnDecoupledFromApplicationRead(t *testing.T) { for k := range server.conns { st = k.(*http2Server) } + notifyChan := make(chan struct{}) + server.h.notify = notifyChan server.mu.Unlock() cstream1, err := client.NewStream(context.Background(), &CallHdr{Flush: true}) if err != nil { t.Fatalf("Client failed to create first stream. Err: %v", err) } + <-notifyChan var sstream1 *Stream // Access stream on the server. - waitWhileTrue(t, func() (bool, error) { - st.mu.Lock() - defer st.mu.Unlock() - - if len(st.activeStreams) != 1 { - return true, fmt.Errorf("timed-out while waiting for server to have created a stream") - } - for _, v := range st.activeStreams { + st.mu.Lock() + for _, v := range st.activeStreams { + if v.id == cstream1.id { sstream1 = v } - return false, nil - }) - + } + st.mu.Unlock() + if sstream1 == nil { + t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream1.id) + } // Exhaust client's connection window. - <-st.writableChan - if err := st.framer.writeData(true, sstream1.id, true, make([]byte, defaultWindowSize)); err != nil { - st.writableChan <- 0 + if err := st.Write(sstream1, []byte{}, make([]byte, defaultWindowSize), &Options{}); err != nil { t.Fatalf("Server failed to write data. Err: %v", err) } - st.writableChan <- 0 + notifyChan = make(chan struct{}) + server.mu.Lock() + server.h.notify = notifyChan + server.mu.Unlock() // Create another stream on client. cstream2, err := client.NewStream(context.Background(), &CallHdr{Flush: true}) if err != nil { t.Fatalf("Client failed to create second stream. Err: %v", err) } - + <-notifyChan var sstream2 *Stream - waitWhileTrue(t, func() (bool, error) { - st.mu.Lock() - defer st.mu.Unlock() - - if len(st.activeStreams) != 2 { - return true, fmt.Errorf("timed-out while waiting for server to have created the second stream") - } - for _, v := range st.activeStreams { - if v.id == cstream2.id { - sstream2 = v - } - } - if sstream2 == nil { - return true, fmt.Errorf("didn't find stream corresponding to client cstream.id: %v on the server", cstream2.id) + st.mu.Lock() + for _, v := range st.activeStreams { + if v.id == cstream2.id { + sstream2 = v } - return false, nil - }) - + } + st.mu.Unlock() + if sstream2 == nil { + t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream2.id) + } // Server should be able to send data on the new stream, even though the client hasn't read anything on the first stream. - <-st.writableChan - if err := st.framer.writeData(true, sstream2.id, true, make([]byte, defaultWindowSize)); err != nil { - st.writableChan <- 0 + if err := st.Write(sstream2, []byte{}, make([]byte, defaultWindowSize), &Options{}); err != nil { t.Fatalf("Server failed to write data. Err: %v", err) } - st.writableChan <- 0 // Client should be able to read data on second stream. if _, err := cstream2.Read(make([]byte, defaultWindowSize)); err != nil { @@ -1311,7 +1307,7 @@ func TestServerConnDecoupledFromApplicationRead(t *testing.T) { t.Fatalf("Failed to create 1st stream. Err: %v", err) } // Exhaust server's connection window. - if err := client.Write(cstream1, make([]byte, defaultWindowSize), &Options{Last: true}); err != nil { + if err := client.Write(cstream1, nil, make([]byte, defaultWindowSize), &Options{Last: true}); err != nil { t.Fatalf("Client failed to write data. Err: %v", err) } //Client should be able to create another stream and send data on it. @@ -1319,7 +1315,7 @@ func TestServerConnDecoupledFromApplicationRead(t *testing.T) { if err != nil { t.Fatalf("Failed to create 2nd stream. Err: %v", err) } - if err := client.Write(cstream2, make([]byte, defaultWindowSize), &Options{}); err != nil { + if err := client.Write(cstream2, nil, make([]byte, defaultWindowSize), &Options{}); err != nil { t.Fatalf("Client failed to write data. Err: %v", err) } // Get the streams on server. @@ -1342,11 +1338,7 @@ func TestServerConnDecoupledFromApplicationRead(t *testing.T) { st.mu.Unlock() // Trying to write more on a max-ed out stream should result in a RST_STREAM from the server. ct := client.(*http2Client) - <-ct.writableChan - if err := ct.framer.writeData(true, cstream2.id, true, make([]byte, 1)); err != nil { - t.Fatalf("Client failed to write. Err: %v", err) - } - ct.writableChan <- 0 + ct.controlBuf.put(&dataFrame{cstream2.id, true, make([]byte, 1), func() {}}) code := http2ErrConvTab[http2.ErrCodeFlowControl] waitWhileTrue(t, func() (bool, error) { cstream2.mu.Lock() @@ -1403,11 +1395,7 @@ func TestServerWithMisbehavedClient(t *testing.T) { } var sent int // Drain the stream flow control window - <-cc.writableChan - if err = cc.framer.writeData(true, s.id, false, make([]byte, http2MaxFrameLen)); err != nil { - t.Fatalf("Failed to write data: %v", err) - } - cc.writableChan <- 0 + cc.controlBuf.put(&dataFrame{s.id, false, make([]byte, http2MaxFrameLen), func() {}}) sent += http2MaxFrameLen // Wait until the server creates the corresponding stream and receive some data. var ss *Stream @@ -1432,11 +1420,7 @@ func TestServerWithMisbehavedClient(t *testing.T) { } // Keep sending until the server inbound window is drained for that stream. for sent <= initialWindowSize { - <-cc.writableChan - if err = cc.framer.writeData(true, s.id, false, make([]byte, 1)); err != nil { - t.Fatalf("Failed to write data: %v", err) - } - cc.writableChan <- 0 + cc.controlBuf.put(&dataFrame{s.id, false, make([]byte, 1), func() {}}) sent++ } // Server sent a resetStream for s already. @@ -1474,7 +1458,7 @@ func TestClientWithMisbehavedServer(t *testing.T) { t.Fatalf("Failed to open stream: %v", err) } d := make([]byte, 1) - if err := ct.Write(s, d, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { + if err := ct.Write(s, nil, d, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { t.Fatalf("Failed to write: %v", err) } // Read without window update. @@ -1516,7 +1500,7 @@ func TestEncodingRequiredStatus(t *testing.T) { Last: true, Delay: false, } - if err := ct.Write(s, expectedRequest, &opts); err != nil && err != io.EOF { + if err := ct.Write(s, nil, expectedRequest, &opts); err != nil && err != io.EOF { t.Fatalf("Failed to write the request: %v", err) } p := make([]byte, http2MaxFrameLen) @@ -1544,7 +1528,7 @@ func TestInvalidHeaderField(t *testing.T) { Last: true, Delay: false, } - if err := ct.Write(s, expectedRequest, &opts); err != nil && err != io.EOF { + if err := ct.Write(s, nil, expectedRequest, &opts); err != nil && err != io.EOF { t.Fatalf("Failed to write the request: %v", err) } p := make([]byte, http2MaxFrameLen) @@ -1696,7 +1680,7 @@ func testAccountCheckWindowSize(t *testing.T, wc windowSizeConfig) { }) ctx, cancel := context.WithTimeout(context.Background(), time.Second) - serverSendQuota, err := wait(ctx, nil, nil, nil, st.sendQuotaPool.acquire()) + serverSendQuota, err := wait(ctx, context.Background(), nil, nil, st.sendQuotaPool.acquire()) if err != nil { t.Fatalf("Error while acquiring sendQuota on server. Err: %v", err) } @@ -1718,7 +1702,7 @@ func testAccountCheckWindowSize(t *testing.T, wc windowSizeConfig) { t.Fatalf("Client transport flow control window size is %v, want %v", limit, connectOptions.InitialConnWindowSize) } ctx, cancel = context.WithTimeout(context.Background(), time.Second) - clientSendQuota, err := wait(ctx, nil, nil, nil, ct.sendQuotaPool.acquire()) + clientSendQuota, err := wait(ctx, context.Background(), nil, nil, ct.sendQuotaPool.acquire()) if err != nil { t.Fatalf("Error while acquiring sendQuota on client. Err: %v", err) } @@ -1787,7 +1771,7 @@ func TestAccountCheckExpandingWindow(t *testing.T) { opts := Options{} header := make([]byte, 5) for i := 1; i <= 10; i++ { - if err := ct.Write(cstream, buf, &opts); err != nil { + if err := ct.Write(cstream, nil, buf, &opts); err != nil { t.Fatalf("Error on client while writing message: %v", err) } if _, err := cstream.Read(header); err != nil { @@ -1853,8 +1837,9 @@ func TestAccountCheckExpandingWindow(t *testing.T) { st.fc.mu.Unlock() // Check flow conrtrol window on client stream is equal to out flow on server stream. - ctx, _ := context.WithTimeout(context.Background(), time.Second) - serverStreamSendQuota, err := wait(ctx, nil, nil, nil, sstream.sendQuotaPool.acquire()) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + serverStreamSendQuota, err := wait(ctx, context.Background(), nil, nil, sstream.sendQuotaPool.acquire()) + cancel() if err != nil { return true, fmt.Errorf("error while acquiring server stream send quota. Err: %v", err) } @@ -1867,8 +1852,9 @@ func TestAccountCheckExpandingWindow(t *testing.T) { } // Check flow control window on server stream is equal to out flow on client stream. - ctx, _ = context.WithTimeout(context.Background(), time.Second) - clientStreamSendQuota, err := wait(ctx, nil, nil, nil, cstream.sendQuotaPool.acquire()) + ctx, cancel = context.WithTimeout(context.Background(), time.Second) + clientStreamSendQuota, err := wait(ctx, context.Background(), nil, nil, cstream.sendQuotaPool.acquire()) + cancel() if err != nil { return true, fmt.Errorf("error while acquiring client stream send quota. Err: %v", err) } @@ -1881,8 +1867,9 @@ func TestAccountCheckExpandingWindow(t *testing.T) { } // Check flow control window on client transport is equal to out flow of server transport. - ctx, _ = context.WithTimeout(context.Background(), time.Second) - serverTrSendQuota, err := wait(ctx, nil, nil, nil, st.sendQuotaPool.acquire()) + ctx, cancel = context.WithTimeout(context.Background(), time.Second) + serverTrSendQuota, err := wait(ctx, context.Background(), nil, nil, st.sendQuotaPool.acquire()) + cancel() if err != nil { return true, fmt.Errorf("error while acquring server transport send quota. Err: %v", err) } @@ -1895,8 +1882,9 @@ func TestAccountCheckExpandingWindow(t *testing.T) { } // Check flow control window on server transport is equal to out flow of client transport. - ctx, _ = context.WithTimeout(context.Background(), time.Second) - clientTrSendQuota, err := wait(ctx, nil, nil, nil, ct.sendQuotaPool.acquire()) + ctx, cancel = context.WithTimeout(context.Background(), time.Second) + clientTrSendQuota, err := wait(ctx, context.Background(), nil, nil, ct.sendQuotaPool.acquire()) + cancel() if err != nil { return true, fmt.Errorf("error while acquiring client transport send quota. Err: %v", err) } @@ -1945,15 +1933,12 @@ func writeOneHeader(framer *http2.Framer, sid uint32, httpStatus int) error { var buf bytes.Buffer henc := hpack.NewEncoder(&buf) henc.WriteField(hpack.HeaderField{Name: ":status", Value: fmt.Sprint(httpStatus)}) - if err := framer.WriteHeaders(http2.HeadersFrameParam{ + return framer.WriteHeaders(http2.HeadersFrameParam{ StreamID: sid, BlockFragment: buf.Bytes(), EndStream: true, EndHeaders: true, - }); err != nil { - return err - } - return nil + }) } func writeTwoHeaders(framer *http2.Framer, sid uint32, httpStatus int) error { @@ -1975,15 +1960,12 @@ func writeTwoHeaders(framer *http2.Framer, sid uint32, httpStatus int) error { Name: ":status", Value: fmt.Sprint(httpStatus), }) - if err := framer.WriteHeaders(http2.HeadersFrameParam{ + return framer.WriteHeaders(http2.HeadersFrameParam{ StreamID: sid, BlockFragment: buf.Bytes(), EndStream: true, EndHeaders: true, - }); err != nil { - return err - } - return nil + }) } type httpServer struct { @@ -2007,8 +1989,8 @@ func (s *httpServer) start(t *testing.T, lis net.Listener) { t.Errorf("Error at server-side while reading preface from cleint. Err: %v", err) return } - reader := bufio.NewReaderSize(s.conn, http2IOBufSize) - writer := bufio.NewWriterSize(s.conn, http2IOBufSize) + reader := bufio.NewReaderSize(s.conn, defaultWriteBufSize) + writer := bufio.NewWriterSize(s.conn, defaultReadBufSize) framer := http2.NewFramer(writer, reader) if err = framer.WriteSettingsAck(); err != nil { t.Errorf("Error at server-side while sending Settings ack. Err: %v", err) @@ -2073,7 +2055,7 @@ func setUpHTTPStatusTest(t *testing.T, httpStatus int, wh writeHeaders) (stream wh: wh, } server.start(t, lis) - client, err = newHTTP2Client(context.Background(), TargetInfo{Addr: lis.Addr().String()}, ConnectOptions{}) + client, err = newHTTP2Client(context.Background(), TargetInfo{Addr: lis.Addr().String()}, ConnectOptions{}, 2*time.Second) if err != nil { t.Fatalf("Error creating client. Err: %v", err) } diff --git a/vendor/google.golang.org/grpc/vet.sh b/vendor/google.golang.org/grpc/vet.sh new file mode 100755 index 000000000..72ef3290a --- /dev/null +++ b/vendor/google.golang.org/grpc/vet.sh @@ -0,0 +1,78 @@ +#!/bin/bash + +set -ex # Exit on error; debugging enabled. +set -o pipefail # Fail a pipe if any sub-command fails. + +die() { + echo "$@" >&2 + exit 1 +} + +# TODO: Remove this check and the mangling below once "context" is imported +# directly. +if git status --porcelain | read; then + die "Uncommitted or untracked files found; commit changes first" +fi + +PATH="$GOPATH/bin:$GOROOT/bin:$PATH" + +# Check proto in manual runs or cron runs. +if [[ "$TRAVIS" != "true" || "$TRAVIS_EVENT_TYPE" = "cron" ]]; then + check_proto="true" +fi + +if [ "$1" = "-install" ]; then + go get -d \ + google.golang.org/grpc/... + go get -u \ + github.com/golang/lint/golint \ + golang.org/x/tools/cmd/goimports \ + honnef.co/go/tools/cmd/staticcheck \ + github.com/golang/protobuf/protoc-gen-go \ + golang.org/x/tools/cmd/stringer + if [[ "$check_proto" = "true" ]]; then + if [[ "$TRAVIS" = "true" ]]; then + PROTOBUF_VERSION=3.3.0 + PROTOC_FILENAME=protoc-${PROTOBUF_VERSION}-linux-x86_64.zip + pushd /home/travis + wget https://github.com/google/protobuf/releases/download/v${PROTOBUF_VERSION}/${PROTOC_FILENAME} + unzip ${PROTOC_FILENAME} + bin/protoc --version + popd + elif ! which protoc > /dev/null; then + die "Please install protoc into your path" + fi + fi + exit 0 +elif [[ "$#" -ne 0 ]]; then + die "Unknown argument(s): $*" +fi + +git ls-files "*.go" | xargs grep -L "\(Copyright [0-9]\{4,\} gRPC authors\)\|DO NOT EDIT" 2>&1 | tee /dev/stderr | (! read) +gofmt -s -d -l . 2>&1 | tee /dev/stderr | (! read) +goimports -l . 2>&1 | tee /dev/stderr | (! read) +golint ./... 2>&1 | (grep -vE "(_mock|_string|\.pb)\.go:" || true) | tee /dev/stderr | (! read) + +# Undo any edits made by this script. +cleanup() { + git reset --hard HEAD +} +trap cleanup EXIT + +# Rewrite golang.org/x/net/context -> context imports (see grpc/grpc-go#1484). +# TODO: Remove this mangling once "context" is imported directly (grpc/grpc-go#711). +git ls-files "*.go" | xargs sed -i 's:"golang.org/x/net/context":"context":' +set +o pipefail +# TODO: Stop filtering pb.go files once golang/protobuf#214 is fixed. +go tool vet -all . 2>&1 | grep -vF '.pb.go:' | tee /dev/stderr | (! read) +set -o pipefail +git reset --hard HEAD + +if [[ "$check_proto" = "true" ]]; then + PATH="/home/travis/bin:$PATH" make proto && \ + git status --porcelain 2>&1 | (! read) || \ + (git status; git --no-pager diff; exit 1) +fi + +# TODO(menghanl): fix errors in transport_test. +staticcheck -ignore google.golang.org/grpc/transport/transport_test.go:SA2002 ./... |