TUN-9467: bump coredns to solve CVE

João Oliveirinha
2025-06-12 10:46:10 +00:00
committed by João "Pisco" Fernandes
parent f8d12c9d39
commit a408612f26
459 changed files with 30077 additions and 16165 deletions

View File

@@ -4,7 +4,7 @@ We definitely welcome your patches and contributions to gRPC! Please read the gR
organization's [governance rules](https://github.com/grpc/grpc-community/blob/master/governance.md)
and [contribution guidelines](https://github.com/grpc/grpc-community/blob/master/CONTRIBUTING.md) before proceeding.
If you are new to github, please start by reading [Pull Request howto](https://help.github.com/articles/about-pull-requests/)
If you are new to GitHub, please start by reading [Pull Request howto](https://help.github.com/articles/about-pull-requests/)
## Legal requirements
@@ -25,8 +25,8 @@ How to get your contributions merged smoothly and quickly.
is a great place to start. These issues are well-documented and usually can be
resolved with a single pull request.
- If you are adding a new file, make sure it has the copyright message template
at the top as a comment. You can copy over the message from an existing file
and update the year.
- The grpc package should only depend on standard Go packages and a small number
@@ -39,12 +39,12 @@ How to get your contributions merged smoothly and quickly.
proposal](https://github.com/grpc/proposal).
- Provide a good **PR description** as a record of **what** change is being made
and **why** it was made. Link to a github issue if it exists.
and **why** it was made. Link to a GitHub issue if it exists.
- If you want to fix formatting or style, consider whether your changes are an
obvious improvement or might be considered a personal preference. If a style
change is based on preference, it likely will not be accepted. If it corrects
widely agreed-upon anti-patterns, then please do create a PR and explain the
benefits of the change.
- Unless your PR is trivial, you should expect there will be reviewer comments
@@ -66,7 +66,7 @@ How to get your contributions merged smoothly and quickly.
- **All tests need to be passing** before your change can be merged. We
recommend you **run tests locally** before creating your PR to catch breakages
early on.
- `VET_SKIP_PROTO=1 ./vet.sh` to catch vet errors
- `./scripts/vet.sh` to catch vet errors
- `go test -cpu 1,4 -timeout 7m ./...` to run the tests
- `go test -race -cpu 1,4 -timeout 7m ./...` to run tests in race mode

View File

@@ -9,20 +9,28 @@ for general contribution guidelines.
## Maintainers (in alphabetical order)
- [cesarghali](https://github.com/cesarghali), Google LLC
- [aranjans](https://github.com/aranjans), Google LLC
- [arjan-bal](https://github.com/arjan-bal), Google LLC
- [arvindbr8](https://github.com/arvindbr8), Google LLC
- [atollena](https://github.com/atollena), Datadog, Inc.
- [dfawley](https://github.com/dfawley), Google LLC
- [easwars](https://github.com/easwars), Google LLC
- [menghanl](https://github.com/menghanl), Google LLC
- [srini100](https://github.com/srini100), Google LLC
- [erm-g](https://github.com/erm-g), Google LLC
- [gtcooke94](https://github.com/gtcooke94), Google LLC
- [purnesh42h](https://github.com/purnesh42h), Google LLC
- [zasweq](https://github.com/zasweq), Google LLC
## Emeritus Maintainers (in alphabetical order)
- [adelez](https://github.com/adelez), Google LLC
- [canguler](https://github.com/canguler), Google LLC
- [iamqizhao](https://github.com/iamqizhao), Google LLC
- [jadekler](https://github.com/jadekler), Google LLC
- [jtattermusch](https://github.com/jtattermusch), Google LLC
- [lyuxuan](https://github.com/lyuxuan), Google LLC
- [makmukhi](https://github.com/makmukhi), Google LLC
- [matt-kwong](https://github.com/matt-kwong), Google LLC
- [nicolasnoble](https://github.com/nicolasnoble), Google LLC
- [yongni](https://github.com/yongni), Google LLC
- [adelez](https://github.com/adelez)
- [canguler](https://github.com/canguler)
- [cesarghali](https://github.com/cesarghali)
- [iamqizhao](https://github.com/iamqizhao)
- [jeanbza](https://github.com/jeanbza)
- [jtattermusch](https://github.com/jtattermusch)
- [lyuxuan](https://github.com/lyuxuan)
- [makmukhi](https://github.com/makmukhi)
- [matt-kwong](https://github.com/matt-kwong)
- [menghanl](https://github.com/menghanl)
- [nicolasnoble](https://github.com/nicolasnoble)
- [srini100](https://github.com/srini100)
- [yongni](https://github.com/yongni)

View File

@@ -30,17 +30,20 @@ testdeps:
GO111MODULE=on go get -d -v -t google.golang.org/grpc/...
vet: vetdeps
./vet.sh
./scripts/vet.sh
vetdeps:
./vet.sh -install
./scripts/vet.sh -install
.PHONY: \
all \
build \
clean \
deps \
proto \
test \
testsubmodule \
testrace \
testdeps \
vet \
vetdeps

View File

@@ -10,7 +10,7 @@ RPC framework that puts mobile and HTTP/2 first. For more information see the
## Prerequisites
- **[Go][]**: any one of the **three latest major** [releases][go-releases].
- **[Go][]**: any one of the **two latest major** [releases][go-releases].
## Installation

View File

@@ -1,3 +1,3 @@
# Security Policy
For information on gRPC Security Policy and reporting potentional security issues, please see [gRPC CVE Process](https://github.com/grpc/proposal/blob/master/P4-grpc-cve-process.md).
For information on gRPC Security Policy and reporting potential security issues, please see [gRPC CVE Process](https://github.com/grpc/proposal/blob/master/P4-grpc-cve-process.md).

View File

@@ -39,7 +39,7 @@ type Config struct {
MaxDelay time.Duration
}
// DefaultConfig is a backoff configuration with the default values specfied
// DefaultConfig is a backoff configuration with the default values specified
// at https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md.
//
// This should be useful for callers who want to configure backoff with
// non-default values only for a subset of the options.
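// Illustrative sketch (not part of this diff): wiring a backoff.Config into a
// client using the documented default values. The target string and function
// name are assumptions for the example.
import (
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/backoff"
	"google.golang.org/grpc/credentials/insecure"
)

func dialWithBackoff() (*grpc.ClientConn, error) {
	bc := backoff.Config{
		BaseDelay:  1.0 * time.Second, // documented defaults; tune as needed
		Multiplier: 1.6,
		Jitter:     0.2,
		MaxDelay:   120 * time.Second,
	}
	return grpc.NewClient("dns:///example.com:443", // illustrative target
		grpc.WithConnectParams(grpc.ConnectParams{Backoff: bc}),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
	)
}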

View File

@@ -30,6 +30,7 @@ import (
"google.golang.org/grpc/channelz"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
estats "google.golang.org/grpc/experimental/stats"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/metadata"
@@ -74,6 +75,8 @@ func unregisterForTesting(name string) {
func init() {
internal.BalancerUnregister = unregisterForTesting
internal.ConnectedAddress = connectedAddress
internal.SetConnectedAddress = setConnectedAddress
}
// Get returns the resolver builder registered with the given name.
@@ -92,54 +95,6 @@ func Get(name string) Builder {
return nil
}
// A SubConn represents a single connection to a gRPC backend service.
//
// Each SubConn contains a list of addresses.
//
// All SubConns start in IDLE, and will not try to connect. To trigger the
// connecting, Balancers must call Connect. If a connection re-enters IDLE,
// Balancers must call Connect again to trigger a new connection attempt.
//
// gRPC will try to connect to the addresses in sequence, and stop trying the
// remainder once the first connection is successful. If an attempt to connect
// to all addresses encounters an error, the SubConn will enter
// TRANSIENT_FAILURE for a backoff period, and then transition to IDLE.
//
// Once established, if a connection is lost, the SubConn will transition
// directly to IDLE.
//
// This interface is to be implemented by gRPC. Users should not need their own
// implementation of this interface. For situations like testing, any
// implementations should embed this interface. This allows gRPC to add new
// methods to this interface.
type SubConn interface {
// UpdateAddresses updates the addresses used in this SubConn.
// gRPC checks if currently-connected address is still in the new list.
// If it's in the list, the connection will be kept.
// If it's not in the list, the connection will be gracefully closed, and
// a new connection will be created.
//
// This will trigger a state transition for the SubConn.
//
// Deprecated: this method will be removed. Create new SubConns for new
// addresses instead.
UpdateAddresses([]resolver.Address)
// Connect starts the connecting for this SubConn.
Connect()
// GetOrBuildProducer returns a reference to the existing Producer for this
// ProducerBuilder in this SubConn, or, if one does not currently exist,
// creates a new one and returns it. Returns a close function which must
// be called when the Producer is no longer needed.
GetOrBuildProducer(ProducerBuilder) (p Producer, close func())
// Shutdown shuts down the SubConn gracefully. Any started RPCs will be
// allowed to complete. No future calls should be made on the SubConn.
// One final state update will be delivered to the StateListener (or
// UpdateSubConnState; deprecated) with ConnectivityState of Shutdown to
// indicate the shutdown operation. This may be delivered before
// in-progress RPCs are complete and the actual connection is closed.
Shutdown()
}
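// Hedged sketch (not from this diff) of the lifecycle described above: a
// balancer creates the SubConn with a StateListener and calls Connect again
// whenever the connection drops back to IDLE. Names are illustrative.
import (
	"google.golang.org/grpc/balancer"
	"google.golang.org/grpc/connectivity"
	"google.golang.org/grpc/resolver"
)

func manageSubConn(cc balancer.ClientConn, addr resolver.Address) (balancer.SubConn, error) {
	var sc balancer.SubConn
	var err error
	sc, err = cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{
		StateListener: func(scs balancer.SubConnState) {
			if scs.ConnectivityState == connectivity.Idle {
				sc.Connect() // a lost connection re-enters IDLE; reconnect
			}
		},
	})
	if err != nil {
		return nil, err
	}
	sc.Connect() // SubConns start in IDLE and never connect on their own
	return sc, nil
}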
// NewSubConnOptions contains options to create new SubConn.
type NewSubConnOptions struct {
// CredsBundle is the credentials bundle that will be used in the created
@@ -174,6 +129,13 @@ type State struct {
// brand new implementation of this interface. For the situations like
// testing, the new implementation should embed this interface. This allows
// gRPC to add new methods to this interface.
//
// NOTICE: This interface is intended to be implemented by gRPC, or intercepted
// by custom load balancing policies. Users should not need their own complete
// implementation of this interface -- they should always delegate to a
// ClientConn passed to Builder.Build() by embedding it in their
// implementations. An embedded ClientConn must never be nil, or runtime panics
// will occur.
type ClientConn interface {
// NewSubConn is called by balancer to create a new SubConn.
// It doesn't block and wait for the connections to be established.
@@ -212,6 +174,17 @@ type ClientConn interface {
//
// Deprecated: Use the Target field in the BuildOptions instead.
Target() string
// MetricsRecorder provides the metrics recorder that balancers can use to
// record metrics. Balancer implementations that do not register metrics on
// the metrics registry and record on them can ignore this method. The returned
// MetricsRecorder is guaranteed to never be nil.
MetricsRecorder() estats.MetricsRecorder
// EnforceClientConnEmbedding is included to force implementers to embed
// another implementation of this interface, allowing gRPC to add methods
// without breaking users.
internal.EnforceClientConnEmbedding
}
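// A minimal sketch of the embedding requirement in the NOTICE above: an
// intercepting wrapper embeds the real ClientConn (which must be non-nil) and
// overrides only the methods it cares about. ccWrapper is an assumed name.
type ccWrapper struct {
	balancer.ClientConn // embedding satisfies EnforceClientConnEmbedding
}

func (c *ccWrapper) UpdateState(s balancer.State) {
	// Inspect or rewrite the state here, then delegate to the real ClientConn.
	c.ClientConn.UpdateState(s)
}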
// BuildOptions contains additional information for Build.
@@ -403,15 +376,6 @@ type ExitIdler interface {
ExitIdle()
}
// SubConnState describes the state of a SubConn.
type SubConnState struct {
// ConnectivityState is the connectivity state of the SubConn.
ConnectivityState connectivity.State
// ConnectionError is set if the ConnectivityState is TransientFailure,
// describing the reason the SubConn failed. Otherwise, it is nil.
ConnectionError error
}
// ClientConnState describes the state of a ClientConn relevant to the
// balancer.
type ClientConnState struct {
@@ -424,20 +388,3 @@ type ClientConnState struct {
// ErrBadResolverState may be returned by UpdateClientConnState to indicate a
// problem with the provided name resolver data.
var ErrBadResolverState = errors.New("bad resolver state")
// A ProducerBuilder is a simple constructor for a Producer. It is used by the
// SubConn to create producers when needed.
type ProducerBuilder interface {
// Build creates a Producer. The first parameter is always a
// grpc.ClientConnInterface (a type to allow creating RPCs/streams on the
// associated SubConn), but is declared as `any` to avoid a dependency
// cycle. Should also return a close function that will be called when all
// references to the Producer have been given up.
Build(grpcClientConnInterface any) (p Producer, close func())
}
// A Producer is a type shared among potentially many consumers. It is
// associated with a SubConn, and an implementation will typically contain
// other methods to provide additional functionality, e.g. configuration or
// subscription registration.
type Producer any
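// Hedged sketch of the producer pattern removed above (the signature matches
// the ProducerBuilder shown in this hunk; the load-report names are assumed).
import (
	"google.golang.org/grpc"
	"google.golang.org/grpc/balancer"
)

type loadReportProducer struct {
	cc grpc.ClientConnInterface // used to start RPCs/streams on the SubConn
}

type loadReportBuilder struct{}

func (loadReportBuilder) Build(cci any) (balancer.Producer, func()) {
	p := &loadReportProducer{cc: cci.(grpc.ClientConnInterface)}
	// The returned close function runs once all references to the producer
	// have been released, e.g. to stop a background stream.
	return p, func() {}
}

// A consumer would then share it via:
//   p, closeFn := sc.GetOrBuildProducer(loadReportBuilder{})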

View File

@@ -36,12 +36,12 @@ type baseBuilder struct {
config Config
}
func (bb *baseBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer {
func (bb *baseBuilder) Build(cc balancer.ClientConn, _ balancer.BuildOptions) balancer.Balancer {
bal := &baseBalancer{
cc: cc,
pickerBuilder: bb.pickerBuilder,
subConns: resolver.NewAddressMap(),
subConns: resolver.NewAddressMapV2[balancer.SubConn](),
scStates: make(map[balancer.SubConn]connectivity.State),
csEvltr: &balancer.ConnectivityStateEvaluator{},
config: bb.config,
@@ -65,7 +65,7 @@ type baseBalancer struct {
csEvltr *balancer.ConnectivityStateEvaluator
state connectivity.State
subConns *resolver.AddressMap
subConns *resolver.AddressMapV2[balancer.SubConn]
scStates map[balancer.SubConn]connectivity.State
picker balancer.Picker
config Config
@@ -100,7 +100,7 @@ func (b *baseBalancer) UpdateClientConnState(s balancer.ClientConnState) error {
// Successful resolution; clear resolver error and ensure we return nil.
b.resolverErr = nil
// addrsSet is the set converted from addrs; it's used for quick lookup of an address.
addrsSet := resolver.NewAddressMap()
addrsSet := resolver.NewAddressMapV2[any]()
for _, a := range s.ResolverState.Addresses {
addrsSet.Set(a, nil)
if _, ok := b.subConns.Get(a); !ok {
@@ -122,8 +122,7 @@ func (b *baseBalancer) UpdateClientConnState(s balancer.ClientConnState) error {
}
}
for _, a := range b.subConns.Keys() {
sci, _ := b.subConns.Get(a)
sc := sci.(balancer.SubConn)
sc, _ := b.subConns.Get(a)
// a was removed by resolver.
if _, ok := addrsSet.Get(a); !ok {
sc.Shutdown()
@@ -133,7 +132,7 @@ func (b *baseBalancer) UpdateClientConnState(s balancer.ClientConnState) error {
}
}
// If resolver state contains no addresses, return an error so ClientConn
// will trigger re-resolve. Also records this as an resolver error, so when
// will trigger re-resolve. Also records this as a resolver error, so when
// the overall state turns transient failure, the error message will have
// the zero address information.
if len(s.ResolverState.Addresses) == 0 {
@@ -173,8 +172,7 @@ func (b *baseBalancer) regeneratePicker() {
// Filter out all ready SCs from full subConn map.
for _, addr := range b.subConns.Keys() {
sci, _ := b.subConns.Get(addr)
sc := sci.(balancer.SubConn)
sc, _ := b.subConns.Get(addr)
if st, ok := b.scStates[sc]; ok && st == connectivity.Ready {
readySCs[sc] = SubConnInfo{Address: addr}
}
@@ -259,6 +257,6 @@ type errPicker struct {
err error // Pick() always returns this err.
}
func (p *errPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
func (p *errPicker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
return balancer.PickResult{}, p.err
}
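// The mechanical effect of the AddressMapV2 migration in this hunk, as a
// sketch: the map is generic, so the old two-step assertion
// (sci, _ := m.Get(a); sc := sci.(balancer.SubConn)) disappears.
import (
	"google.golang.org/grpc/balancer"
	"google.golang.org/grpc/resolver"
)

func shutdownAll(m *resolver.AddressMapV2[balancer.SubConn]) {
	for _, addr := range m.Keys() {
		sc, _ := m.Get(addr) // sc is already a balancer.SubConn
		sc.Shutdown()
	}
}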

View File

@@ -0,0 +1,356 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package endpointsharding implements a load balancing policy that manages
// homogeneous child policies each owning a single endpoint.
//
// # Experimental
//
// Notice: This package is EXPERIMENTAL and may be changed or removed in a
// later release.
package endpointsharding
import (
"errors"
rand "math/rand/v2"
"sync"
"sync/atomic"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/resolver"
)
// ChildState is the balancer state of a child along with the endpoint which
// identifies the child balancer.
type ChildState struct {
Endpoint resolver.Endpoint
State balancer.State
// Balancer exposes only the ExitIdler interface of the child LB policy.
// Other methods of the child policy are called only by endpointsharding.
Balancer balancer.ExitIdler
}
// Options are the options to configure the behaviour of the
// endpointsharding balancer.
type Options struct {
// DisableAutoReconnect allows the balancer to keep child balancers in the
// IDLE state until they are explicitly triggered to exit using the
// ChildState obtained from the endpointsharding picker. When set to false,
// the endpointsharding balancer will automatically call ExitIdle on child
// connections that report IDLE.
DisableAutoReconnect bool
}
// ChildBuilderFunc creates a new balancer with the ClientConn. It has the same
// type as the balancer.Builder.Build method.
type ChildBuilderFunc func(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer
// NewBalancer returns a load balancing policy that manages homogeneous child
// policies each owning a single endpoint. The endpointsharding balancer
// forwards the LoadBalancingConfig in ClientConn state updates to its children.
func NewBalancer(cc balancer.ClientConn, opts balancer.BuildOptions, childBuilder ChildBuilderFunc, esOpts Options) balancer.Balancer {
es := &endpointSharding{
cc: cc,
bOpts: opts,
esOpts: esOpts,
childBuilder: childBuilder,
}
es.children.Store(resolver.NewEndpointMap[*balancerWrapper]())
return es
}
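// Hedged sketch of how a parent policy might use NewBalancer so that each
// unique endpoint gets its own pick_first child (wiring assumed for
// illustration).
import (
	"google.golang.org/grpc/balancer"
	"google.golang.org/grpc/balancer/endpointsharding"
	"google.golang.org/grpc/balancer/pickfirst/pickfirstleaf"
)

func buildSharded(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
	child := balancer.Get(pickfirstleaf.Name).Build // per-endpoint child policy
	return endpointsharding.NewBalancer(cc, opts, child, endpointsharding.Options{})
}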
// endpointSharding is a balancer that wraps child balancers. It creates a child
// balancer with child config for every unique Endpoint received. It updates the
// child states on any update from parent or child.
type endpointSharding struct {
cc balancer.ClientConn
bOpts balancer.BuildOptions
esOpts Options
childBuilder ChildBuilderFunc
// childMu synchronizes calls to any single child. It must be held for all
// calls into a child. To avoid deadlocks, do not acquire childMu while
// holding mu.
childMu sync.Mutex
children atomic.Pointer[resolver.EndpointMap[*balancerWrapper]]
// inhibitChildUpdates is set during UpdateClientConnState/ResolverError
// calls (calls to children will each produce an update, only want one
// update).
inhibitChildUpdates atomic.Bool
// mu synchronizes access to the state stored in balancerWrappers in the
// children field. mu must not be held during calls into a child since
// synchronous calls back from the child may require taking mu, causing a
// deadlock. To avoid deadlocks, do not acquire childMu while holding mu.
mu sync.Mutex
}
// UpdateClientConnState creates a child for new endpoints and deletes children
// for endpoints that are no longer present. It also updates all the children,
// and sends a single synchronous update of the children's aggregated state at
// the end of the UpdateClientConnState operation. Endpoints with no addresses
// are ignored. It returns the first error found from a child, but fully
// processes the new update.
func (es *endpointSharding) UpdateClientConnState(state balancer.ClientConnState) error {
es.childMu.Lock()
defer es.childMu.Unlock()
es.inhibitChildUpdates.Store(true)
defer func() {
es.inhibitChildUpdates.Store(false)
es.updateState()
}()
var ret error
children := es.children.Load()
newChildren := resolver.NewEndpointMap[*balancerWrapper]()
// Update/Create new children.
for _, endpoint := range state.ResolverState.Endpoints {
if _, ok := newChildren.Get(endpoint); ok {
// Endpoint child was already created, continue to avoid duplicate
// update.
continue
}
childBalancer, ok := children.Get(endpoint)
if ok {
// Endpoint attributes may have changed, update the stored endpoint.
es.mu.Lock()
childBalancer.childState.Endpoint = endpoint
es.mu.Unlock()
} else {
childBalancer = &balancerWrapper{
childState: ChildState{Endpoint: endpoint},
ClientConn: es.cc,
es: es,
}
childBalancer.childState.Balancer = childBalancer
childBalancer.child = es.childBuilder(childBalancer, es.bOpts)
}
newChildren.Set(endpoint, childBalancer)
if err := childBalancer.updateClientConnStateLocked(balancer.ClientConnState{
BalancerConfig: state.BalancerConfig,
ResolverState: resolver.State{
Endpoints: []resolver.Endpoint{endpoint},
Attributes: state.ResolverState.Attributes,
},
}); err != nil && ret == nil {
// Return the first error found, and always commit full processing of
// updating children. If callers need more specific errors across all
// endpoints, they should perform those validations themselves; this is
// a current limitation kept for simplicity's sake.
ret = err
}
}
// Delete old children that are no longer present.
for _, e := range children.Keys() {
child, _ := children.Get(e)
if _, ok := newChildren.Get(e); !ok {
child.closeLocked()
}
}
es.children.Store(newChildren)
if newChildren.Len() == 0 {
return balancer.ErrBadResolverState
}
return ret
}
// ResolverError forwards the resolver error to all of the endpointSharding's
// children and sends a single synchronous update of the childStates at the end
// of the ResolverError operation.
func (es *endpointSharding) ResolverError(err error) {
es.childMu.Lock()
defer es.childMu.Unlock()
es.inhibitChildUpdates.Store(true)
defer func() {
es.inhibitChildUpdates.Store(false)
es.updateState()
}()
children := es.children.Load()
for _, child := range children.Values() {
child.resolverErrorLocked(err)
}
}
func (es *endpointSharding) UpdateSubConnState(balancer.SubConn, balancer.SubConnState) {
// UpdateSubConnState is deprecated.
}
func (es *endpointSharding) Close() {
es.childMu.Lock()
defer es.childMu.Unlock()
children := es.children.Load()
for _, child := range children.Values() {
child.closeLocked()
}
}
// updateState updates this component's state. It sends the aggregated state
// and, if needed, a picker with round-robin behavior carrying all the child
// states.
func (es *endpointSharding) updateState() {
if es.inhibitChildUpdates.Load() {
return
}
var readyPickers, connectingPickers, idlePickers, transientFailurePickers []balancer.Picker
es.mu.Lock()
defer es.mu.Unlock()
children := es.children.Load()
childStates := make([]ChildState, 0, children.Len())
for _, child := range children.Values() {
childState := child.childState
childStates = append(childStates, childState)
childPicker := childState.State.Picker
switch childState.State.ConnectivityState {
case connectivity.Ready:
readyPickers = append(readyPickers, childPicker)
case connectivity.Connecting:
connectingPickers = append(connectingPickers, childPicker)
case connectivity.Idle:
idlePickers = append(idlePickers, childPicker)
case connectivity.TransientFailure:
transientFailurePickers = append(transientFailurePickers, childPicker)
// connectivity.Shutdown shouldn't appear.
}
}
// Construct the round-robin picker based on the aggregated state. Whatever
// the aggregated state, use only the pickers for children currently in that
// state.
var aggState connectivity.State
var pickers []balancer.Picker
if len(readyPickers) >= 1 {
aggState = connectivity.Ready
pickers = readyPickers
} else if len(connectingPickers) >= 1 {
aggState = connectivity.Connecting
pickers = connectingPickers
} else if len(idlePickers) >= 1 {
aggState = connectivity.Idle
pickers = idlePickers
} else if len(transientFailurePickers) >= 1 {
aggState = connectivity.TransientFailure
pickers = transientFailurePickers
} else {
aggState = connectivity.TransientFailure
pickers = []balancer.Picker{base.NewErrPicker(errors.New("no children to pick from"))}
} // No children (resolver error before valid update).
p := &pickerWithChildStates{
pickers: pickers,
childStates: childStates,
next: uint32(rand.IntN(len(pickers))),
}
es.cc.UpdateState(balancer.State{
ConnectivityState: aggState,
Picker: p,
})
}
// pickerWithChildStates delegates to the pickers it holds in a round robin
// fashion. It also contains the childStates of all the endpointSharding's
// children.
type pickerWithChildStates struct {
pickers []balancer.Picker
childStates []ChildState
next uint32
}
func (p *pickerWithChildStates) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
nextIndex := atomic.AddUint32(&p.next, 1)
picker := p.pickers[nextIndex%uint32(len(p.pickers))]
return picker.Pick(info)
}
// ChildStatesFromPicker returns the state of all the children managed by the
// endpoint sharding balancer that created this picker.
func ChildStatesFromPicker(picker balancer.Picker) []ChildState {
p, ok := picker.(*pickerWithChildStates)
if !ok {
return nil
}
return p.childStates
}
// balancerWrapper is a wrapper of a balancer. It identifies a child balancer
// by its endpoint, and persists the most recent child balancer state.
type balancerWrapper struct {
// The following fields are initialized at build time and read-only after
// that and therefore do not need to be guarded by a mutex.
// child contains the wrapped balancer. Access its methods only through
// methods on balancerWrapper to ensure proper synchronization
child balancer.Balancer
balancer.ClientConn // embed to intercept UpdateState, doesn't deal with SubConns
es *endpointSharding
// Access to the following fields is guarded by es.mu.
childState ChildState
isClosed bool
}
func (bw *balancerWrapper) UpdateState(state balancer.State) {
bw.es.mu.Lock()
bw.childState.State = state
bw.es.mu.Unlock()
if state.ConnectivityState == connectivity.Idle && !bw.es.esOpts.DisableAutoReconnect {
bw.ExitIdle()
}
bw.es.updateState()
}
// ExitIdle pings an IDLE child balancer to exit idle in a new goroutine to
// avoid deadlocks due to synchronous balancer state updates.
func (bw *balancerWrapper) ExitIdle() {
if ei, ok := bw.child.(balancer.ExitIdler); ok {
go func() {
bw.es.childMu.Lock()
if !bw.isClosed {
ei.ExitIdle()
}
bw.es.childMu.Unlock()
}()
}
}
// updateClientConnStateLocked delivers the ClientConnState to the child
// balancer. Callers must hold the child mutex of the parent endpointsharding
// balancer.
func (bw *balancerWrapper) updateClientConnStateLocked(ccs balancer.ClientConnState) error {
return bw.child.UpdateClientConnState(ccs)
}
// closeLocked closes the child balancer. Callers must hold the child mutex of
// the parent endpointsharding balancer.
func (bw *balancerWrapper) closeLocked() {
bw.child.Close()
bw.isClosed = true
}
func (bw *balancerWrapper) resolverErrorLocked(err error) {
bw.child.ResolverError(err)
}
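// To close the loop on the picker above: a parent policy can recover the
// per-endpoint states from any state update this balancer publishes. A sketch
// (function name assumed):
import (
	"fmt"

	"google.golang.org/grpc/balancer"
	"google.golang.org/grpc/balancer/endpointsharding"
)

func logChildStates(state balancer.State) {
	for _, cs := range endpointsharding.ChildStatesFromPicker(state.Picker) {
		fmt.Printf("endpoint %v -> %v\n", cs.Endpoint.Addresses, cs.State.ConnectivityState)
	}
}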

View File

@@ -0,0 +1,35 @@
/*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package internal contains code internal to the pickfirst package.
package internal
import (
rand "math/rand/v2"
"time"
)
var (
// RandShuffle pseudo-randomizes the order of addresses.
RandShuffle = rand.Shuffle
// TimeAfterFunc allows mocking the timer for testing connection delay
// related functionality.
TimeAfterFunc = func(d time.Duration, f func()) func() {
timer := time.AfterFunc(d, f)
return func() { timer.Stop() }
}
)
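// RandShuffle and TimeAfterFunc exist precisely so tests can swap them out. A
// hedged sketch of a test helper that makes timers fire immediately (helper
// name assumed):
import (
	"time"

	"google.golang.org/grpc/balancer/pickfirst/internal"
)

func stubTimer() (restore func()) {
	orig := internal.TimeAfterFunc
	internal.TimeAfterFunc = func(_ time.Duration, f func()) func() {
		f()              // fire immediately instead of waiting
		return func() {} // no-op cancel
	}
	return func() { internal.TimeAfterFunc = orig }
}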

View File

@@ -16,45 +16,60 @@
*
*/
package grpc
// Package pickfirst contains the pick_first load balancing policy.
package pickfirst
import (
"encoding/json"
"errors"
"fmt"
rand "math/rand/v2"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/pickfirst/internal"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/envconfig"
internalgrpclog "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/pretty"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
_ "google.golang.org/grpc/balancer/pickfirst/pickfirstleaf" // For automatically registering the new pickfirst if required.
)
func init() {
if envconfig.NewPickFirstEnabled {
return
}
balancer.Register(pickfirstBuilder{})
}
var logger = grpclog.Component("pick-first-lb")
const (
// PickFirstBalancerName is the name of the pick_first balancer.
PickFirstBalancerName = "pick_first"
logPrefix = "[pick-first-lb %p] "
// Name is the name of the pick_first balancer.
Name = "pick_first"
logPrefix = "[pick-first-lb %p] "
)
type pickfirstBuilder struct{}
func (pickfirstBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer {
func (pickfirstBuilder) Build(cc balancer.ClientConn, _ balancer.BuildOptions) balancer.Balancer {
b := &pickfirstBalancer{cc: cc}
b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(logPrefix, b))
return b
}
func (pickfirstBuilder) Name() string {
return PickFirstBalancerName
return Name
}
type pfConfig struct {
serviceconfig.LoadBalancingConfig `json:"-"`
// If set to true, instructs the LB policy to shuffle the order of the list
// of addresses received from the name resolver before attempting to
// of endpoints received from the name resolver before attempting to
// connect to them.
ShuffleAddressList bool `json:"shuffleAddressList"`
}
@@ -93,9 +108,17 @@ func (b *pickfirstBalancer) ResolverError(err error) {
})
}
// Shuffler is an interface for shuffling an address list.
type Shuffler interface {
ShuffleAddressListForTesting(n int, swap func(i, j int))
}
// ShuffleAddressListForTesting pseudo-randomizes the order of addresses. n
// is the number of elements. swap swaps the elements with indexes i and j.
func ShuffleAddressListForTesting(n int, swap func(i, j int)) { rand.Shuffle(n, swap) }
func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState) error {
addrs := state.ResolverState.Addresses
if len(addrs) == 0 {
if len(state.ResolverState.Addresses) == 0 && len(state.ResolverState.Endpoints) == 0 {
// The resolver reported an empty address list. Treat it like an error by
// calling b.ResolverError.
if b.subConn != nil {
@@ -107,22 +130,49 @@ func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState
b.ResolverError(errors.New("produced zero addresses"))
return balancer.ErrBadResolverState
}
// We don't have to guard this block with the env var because ParseConfig
// already does so.
cfg, ok := state.BalancerConfig.(pfConfig)
if state.BalancerConfig != nil && !ok {
return fmt.Errorf("pickfirst: received illegal BalancerConfig (type %T): %v", state.BalancerConfig, state.BalancerConfig)
}
if cfg.ShuffleAddressList {
addrs = append([]resolver.Address{}, addrs...)
grpcrand.Shuffle(len(addrs), func(i, j int) { addrs[i], addrs[j] = addrs[j], addrs[i] })
}
if b.logger.V(2) {
b.logger.Infof("Received new config %s, resolver state %s", pretty.ToJSON(cfg), pretty.ToJSON(state.ResolverState))
}
var addrs []resolver.Address
if endpoints := state.ResolverState.Endpoints; len(endpoints) != 0 {
// Perform the optional shuffling described in gRFC A62. The shuffling will
// change the order of endpoints but not touch the order of the addresses
// within each endpoint. - A61
if cfg.ShuffleAddressList {
endpoints = append([]resolver.Endpoint{}, endpoints...)
internal.RandShuffle(len(endpoints), func(i, j int) { endpoints[i], endpoints[j] = endpoints[j], endpoints[i] })
}
// "Flatten the list by concatenating the ordered list of addresses for each
// of the endpoints, in order." - A61
for _, endpoint := range endpoints {
// "In the flattened list, interleave addresses from the two address
// families, as per RFC-8305 section 4." - A61
// TODO: support the above language.
addrs = append(addrs, endpoint.Addresses...)
}
} else {
// Endpoints not set, process addresses until we migrate resolver
// emissions fully to Endpoints. The top channel does wrap emitted
// addresses with endpoints, however some balancers such as weighted
// target do not forward the corresponding correct endpoints down/split
// endpoints properly. Once all balancers correctly forward endpoints
// down, can delete this else conditional.
addrs = state.ResolverState.Addresses
if cfg.ShuffleAddressList {
addrs = append([]resolver.Address{}, addrs...)
rand.Shuffle(len(addrs), func(i, j int) { addrs[i], addrs[j] = addrs[j], addrs[i] })
}
}
if b.subConn != nil {
b.cc.UpdateAddresses(b.subConn, addrs)
return nil
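// For reference, shuffleAddressList reaches this policy through the service
// config. A hedged client-side sketch (target and helper name are assumed,
// and this assumes the vendored gRPC provides grpc.NewClient):
import (
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

func newShuffledClient() (*grpc.ClientConn, error) {
	const svcCfg = `{"loadBalancingConfig": [{"pick_first": {"shuffleAddressList": true}}]}`
	return grpc.NewClient("dns:///example.com:443", // illustrative target
		grpc.WithDefaultServiceConfig(svcCfg),
		grpc.WithTransportCredentials(insecure.NewCredentials()))
}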

View File

@@ -0,0 +1,927 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package pickfirstleaf contains the pick_first load balancing policy which
// will be the universal leaf policy after dualstack changes are implemented.
//
// # Experimental
//
// Notice: This package is EXPERIMENTAL and may be changed or removed in a
// later release.
package pickfirstleaf
import (
"encoding/json"
"errors"
"fmt"
"net"
"net/netip"
"sync"
"time"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/pickfirst/internal"
"google.golang.org/grpc/connectivity"
expstats "google.golang.org/grpc/experimental/stats"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/envconfig"
internalgrpclog "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/pretty"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
)
func init() {
if envconfig.NewPickFirstEnabled {
// Register as the default pick_first balancer.
Name = "pick_first"
}
balancer.Register(pickfirstBuilder{})
}
type (
// enableHealthListenerKeyType is a unique key type used in resolver
// attributes to indicate whether the health listener usage is enabled.
enableHealthListenerKeyType struct{}
// managedByPickfirstKeyType is an attribute key type to inform Outlier
// Detection that the generic health listener is being used.
// TODO: https://github.com/grpc/grpc-go/issues/7915 - Remove this when
// implementing the dualstack design. This is a hack. Once Dualstack is
// completed, outlier detection will stop sending ejection updates through
// the connectivity listener.
managedByPickfirstKeyType struct{}
)
var (
logger = grpclog.Component("pick-first-leaf-lb")
// Name is the name of the pick_first_leaf balancer.
// It is changed to "pick_first" in init() if this balancer is to be
// registered as the default pickfirst.
Name = "pick_first_leaf"
disconnectionsMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{
Name: "grpc.lb.pick_first.disconnections",
Description: "EXPERIMENTAL. Number of times the selected subchannel becomes disconnected.",
Unit: "disconnection",
Labels: []string{"grpc.target"},
Default: false,
})
connectionAttemptsSucceededMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{
Name: "grpc.lb.pick_first.connection_attempts_succeeded",
Description: "EXPERIMENTAL. Number of successful connection attempts.",
Unit: "attempt",
Labels: []string{"grpc.target"},
Default: false,
})
connectionAttemptsFailedMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{
Name: "grpc.lb.pick_first.connection_attempts_failed",
Description: "EXPERIMENTAL. Number of failed connection attempts.",
Unit: "attempt",
Labels: []string{"grpc.target"},
Default: false,
})
)
const (
// TODO: change to pick-first when this becomes the default pick_first policy.
logPrefix = "[pick-first-leaf-lb %p] "
// connectionDelayInterval is the time to wait for during the happy eyeballs
// pass before starting the next connection attempt.
connectionDelayInterval = 250 * time.Millisecond
)
type ipAddrFamily int
const (
// ipAddrFamilyUnknown represents strings that can't be parsed as an IP
// address.
ipAddrFamilyUnknown ipAddrFamily = iota
ipAddrFamilyV4
ipAddrFamilyV6
)
type pickfirstBuilder struct{}
func (pickfirstBuilder) Build(cc balancer.ClientConn, bo balancer.BuildOptions) balancer.Balancer {
b := &pickfirstBalancer{
cc: cc,
target: bo.Target.String(),
metricsRecorder: cc.MetricsRecorder(),
subConns: resolver.NewAddressMapV2[*scData](),
state: connectivity.Connecting,
cancelConnectionTimer: func() {},
}
b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(logPrefix, b))
return b
}
func (b pickfirstBuilder) Name() string {
return Name
}
func (pickfirstBuilder) ParseConfig(js json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
var cfg pfConfig
if err := json.Unmarshal(js, &cfg); err != nil {
return nil, fmt.Errorf("pickfirst: unable to unmarshal LB policy config: %s, error: %v", string(js), err)
}
return cfg, nil
}
// EnableHealthListener updates the state to configure pickfirst for using a
// generic health listener.
func EnableHealthListener(state resolver.State) resolver.State {
state.Attributes = state.Attributes.WithValue(enableHealthListenerKeyType{}, true)
return state
}
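// Sketch of how a wrapping policy would opt its pick_first child into the
// generic health listener before forwarding resolver state (helper name
// assumed):
import (
	"google.golang.org/grpc/balancer"
	"google.golang.org/grpc/balancer/pickfirst/pickfirstleaf"
)

func forwardWithHealth(child balancer.Balancer, ccs balancer.ClientConnState) error {
	ccs.ResolverState = pickfirstleaf.EnableHealthListener(ccs.ResolverState)
	return child.UpdateClientConnState(ccs)
}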
// IsManagedByPickfirst returns whether an address belongs to a SubConn
// managed by the pickfirst LB policy.
// TODO: https://github.com/grpc/grpc-go/issues/7915 - This is a hack to disable
// outlier_detection via the connectivity listener when using pick_first.
// Once Dualstack changes are complete, all SubConns will be created by
// pick_first and outlier detection will only use the health listener for
// ejection. This hack can then be removed.
func IsManagedByPickfirst(addr resolver.Address) bool {
return addr.BalancerAttributes.Value(managedByPickfirstKeyType{}) != nil
}
type pfConfig struct {
serviceconfig.LoadBalancingConfig `json:"-"`
// If set to true, instructs the LB policy to shuffle the order of the list
// of endpoints received from the name resolver before attempting to
// connect to them.
ShuffleAddressList bool `json:"shuffleAddressList"`
}
// scData keeps track of the current state of the subConn.
// It is not safe for concurrent access.
type scData struct {
// The following fields are initialized at build time and read-only after
// that.
subConn balancer.SubConn
addr resolver.Address
rawConnectivityState connectivity.State
// The effective connectivity state based on raw connectivity, health state
// and after following sticky TransientFailure behaviour defined in A62.
effectiveState connectivity.State
lastErr error
connectionFailedInFirstPass bool
}
func (b *pickfirstBalancer) newSCData(addr resolver.Address) (*scData, error) {
addr.BalancerAttributes = addr.BalancerAttributes.WithValue(managedByPickfirstKeyType{}, true)
sd := &scData{
rawConnectivityState: connectivity.Idle,
effectiveState: connectivity.Idle,
addr: addr,
}
sc, err := b.cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{
StateListener: func(state balancer.SubConnState) {
b.updateSubConnState(sd, state)
},
})
if err != nil {
return nil, err
}
sd.subConn = sc
return sd, nil
}
type pickfirstBalancer struct {
// The following fields are initialized at build time and read-only after
// that and therefore do not need to be guarded by a mutex.
logger *internalgrpclog.PrefixLogger
cc balancer.ClientConn
target string
metricsRecorder expstats.MetricsRecorder // guaranteed to be non-nil
// The mutex ensures that updates triggered from the idle picker are
// synchronized with the already-serialized resolver and SubConn state
// updates.
mu sync.Mutex
// State reported to the channel based on SubConn states and resolver
// updates.
state connectivity.State
// scData for active SubConns mapped by address.
subConns *resolver.AddressMapV2[*scData]
addressList addressList
firstPass bool
numTF int
cancelConnectionTimer func()
healthCheckingEnabled bool
}
// ResolverError is called by the ClientConn when the name resolver produces
// an error or when pickfirst determined the resolver update to be invalid.
func (b *pickfirstBalancer) ResolverError(err error) {
b.mu.Lock()
defer b.mu.Unlock()
b.resolverErrorLocked(err)
}
func (b *pickfirstBalancer) resolverErrorLocked(err error) {
if b.logger.V(2) {
b.logger.Infof("Received error from the name resolver: %v", err)
}
// The picker will not change since the balancer does not currently
// report an error. If the balancer hasn't received a single good resolver
// update yet, transition to TRANSIENT_FAILURE.
if b.state != connectivity.TransientFailure && b.addressList.size() > 0 {
if b.logger.V(2) {
b.logger.Infof("Ignoring resolver error because balancer is using a previous good update.")
}
return
}
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.TransientFailure,
Picker: &picker{err: fmt.Errorf("name resolver error: %v", err)},
})
}
func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState) error {
b.mu.Lock()
defer b.mu.Unlock()
b.cancelConnectionTimer()
if len(state.ResolverState.Addresses) == 0 && len(state.ResolverState.Endpoints) == 0 {
// Cleanup state pertaining to the previous resolver state.
// Treat an empty address list like an error by calling b.ResolverError.
b.closeSubConnsLocked()
b.addressList.updateAddrs(nil)
b.resolverErrorLocked(errors.New("produced zero addresses"))
return balancer.ErrBadResolverState
}
b.healthCheckingEnabled = state.ResolverState.Attributes.Value(enableHealthListenerKeyType{}) != nil
cfg, ok := state.BalancerConfig.(pfConfig)
if state.BalancerConfig != nil && !ok {
return fmt.Errorf("pickfirst: received illegal BalancerConfig (type %T): %v: %w", state.BalancerConfig, state.BalancerConfig, balancer.ErrBadResolverState)
}
if b.logger.V(2) {
b.logger.Infof("Received new config %s, resolver state %s", pretty.ToJSON(cfg), pretty.ToJSON(state.ResolverState))
}
var newAddrs []resolver.Address
if endpoints := state.ResolverState.Endpoints; len(endpoints) != 0 {
// Perform the optional shuffling described in gRFC A62. The shuffling
// will change the order of endpoints but not touch the order of the
// addresses within each endpoint. - A61
if cfg.ShuffleAddressList {
endpoints = append([]resolver.Endpoint{}, endpoints...)
internal.RandShuffle(len(endpoints), func(i, j int) { endpoints[i], endpoints[j] = endpoints[j], endpoints[i] })
}
// "Flatten the list by concatenating the ordered list of addresses for
// each of the endpoints, in order." - A61
for _, endpoint := range endpoints {
newAddrs = append(newAddrs, endpoint.Addresses...)
}
} else {
// Endpoints not set, process addresses until we migrate resolver
// emissions fully to Endpoints. The top channel does wrap emitted
// addresses with endpoints, however some balancers such as weighted
// target do not forward the corresponding correct endpoints down/split
// endpoints properly. Once all balancers correctly forward endpoints
// down, can delete this else conditional.
newAddrs = state.ResolverState.Addresses
if cfg.ShuffleAddressList {
newAddrs = append([]resolver.Address{}, newAddrs...)
internal.RandShuffle(len(newAddrs), func(i, j int) { newAddrs[i], newAddrs[j] = newAddrs[j], newAddrs[i] })
}
}
// If an address appears in multiple endpoints or in the same endpoint
// multiple times, we keep it only once. We will create only one SubConn
// for the address because an AddressMap is used to store SubConns.
// Not de-duplicating would result in attempting to connect to the same
// SubConn multiple times in the same pass. We don't want this.
newAddrs = deDupAddresses(newAddrs)
newAddrs = interleaveAddresses(newAddrs)
prevAddr := b.addressList.currentAddress()
prevSCData, found := b.subConns.Get(prevAddr)
prevAddrsCount := b.addressList.size()
isPrevRawConnectivityStateReady := found && prevSCData.rawConnectivityState == connectivity.Ready
b.addressList.updateAddrs(newAddrs)
// If the previous ready SubConn exists in new address list,
// keep this connection and don't create new SubConns.
if isPrevRawConnectivityStateReady && b.addressList.seekTo(prevAddr) {
return nil
}
b.reconcileSubConnsLocked(newAddrs)
// If it's the first resolver update or the balancer was already READY
// (but the new address list does not contain the ready SubConn) or
// CONNECTING, enter CONNECTING.
// We may be in TRANSIENT_FAILURE due to a previous empty address list,
// we should still enter CONNECTING because the sticky TF behaviour
// mentioned in A62 applies only when the TRANSIENT_FAILURE is reported
// due to connectivity failures.
if isPrevRawConnectivityStateReady || b.state == connectivity.Connecting || prevAddrsCount == 0 {
// Start connection attempt at first address.
b.forceUpdateConcludedStateLocked(balancer.State{
ConnectivityState: connectivity.Connecting,
Picker: &picker{err: balancer.ErrNoSubConnAvailable},
})
b.startFirstPassLocked()
} else if b.state == connectivity.TransientFailure {
// If we're in TRANSIENT_FAILURE, we stay in TRANSIENT_FAILURE until
// we're READY. See A62.
b.startFirstPassLocked()
}
return nil
}
// UpdateSubConnState is unused as a StateListener is always registered when
// creating SubConns.
func (b *pickfirstBalancer) UpdateSubConnState(subConn balancer.SubConn, state balancer.SubConnState) {
b.logger.Errorf("UpdateSubConnState(%v, %+v) called unexpectedly", subConn, state)
}
func (b *pickfirstBalancer) Close() {
b.mu.Lock()
defer b.mu.Unlock()
b.closeSubConnsLocked()
b.cancelConnectionTimer()
b.state = connectivity.Shutdown
}
// ExitIdle moves the balancer out of idle state. It can be called concurrently
// by the idlePicker and clientConn, so access to variables should be
// synchronized.
func (b *pickfirstBalancer) ExitIdle() {
b.mu.Lock()
defer b.mu.Unlock()
if b.state == connectivity.Idle {
b.startFirstPassLocked()
}
}
func (b *pickfirstBalancer) startFirstPassLocked() {
b.firstPass = true
b.numTF = 0
// Reset the connection attempt record for existing SubConns.
for _, sd := range b.subConns.Values() {
sd.connectionFailedInFirstPass = false
}
b.requestConnectionLocked()
}
func (b *pickfirstBalancer) closeSubConnsLocked() {
for _, sd := range b.subConns.Values() {
sd.subConn.Shutdown()
}
b.subConns = resolver.NewAddressMapV2[*scData]()
}
// deDupAddresses ensures that each address appears only once in the slice.
func deDupAddresses(addrs []resolver.Address) []resolver.Address {
seenAddrs := resolver.NewAddressMapV2[*scData]()
retAddrs := []resolver.Address{}
for _, addr := range addrs {
if _, ok := seenAddrs.Get(addr); ok {
continue
}
seenAddrs.Set(addr, nil) // record the address so later duplicates are skipped
retAddrs = append(retAddrs, addr)
}
return retAddrs
}
// interleaveAddresses interleaves addresses of both families (IPv4 and IPv6)
// as per RFC-8305 section 4.
// Whichever address family is first in the list is followed by an address of
// the other address family; that is, if the first address in the list is IPv6,
// then the first IPv4 address should be moved up in the list to be second in
// the list. It doesn't support configuring "First Address Family Count", i.e.
// there will always be a single member of the first address family at the
// beginning of the interleaved list.
// Addresses that are neither IPv4 nor IPv6 are treated as part of a third
// "unknown" family for interleaving.
// See: https://datatracker.ietf.org/doc/html/rfc8305#autoid-6
func interleaveAddresses(addrs []resolver.Address) []resolver.Address {
familyAddrsMap := map[ipAddrFamily][]resolver.Address{}
interleavingOrder := []ipAddrFamily{}
for _, addr := range addrs {
family := addressFamily(addr.Addr)
if _, found := familyAddrsMap[family]; !found {
interleavingOrder = append(interleavingOrder, family)
}
familyAddrsMap[family] = append(familyAddrsMap[family], addr)
}
interleavedAddrs := make([]resolver.Address, 0, len(addrs))
for curFamilyIdx := 0; len(interleavedAddrs) < len(addrs); curFamilyIdx = (curFamilyIdx + 1) % len(interleavingOrder) {
// Some IP types may have fewer addresses than others, so we look for
// the next type that has a remaining member to add to the interleaved
// list.
family := interleavingOrder[curFamilyIdx]
remainingMembers := familyAddrsMap[family]
if len(remainingMembers) > 0 {
interleavedAddrs = append(interleavedAddrs, remainingMembers[0])
familyAddrsMap[family] = remainingMembers[1:]
}
}
return interleavedAddrs
}
// addressFamily returns the ipAddrFamily after parsing the address string.
// If the address isn't of the format "ip-address:port", it returns
// ipAddrFamilyUnknown. The address may be valid even if it's not an IP when
// using a resolver like passthrough where the address may be a hostname in
// some format that the dialer can resolve.
func addressFamily(address string) ipAddrFamily {
// Parse the IP after removing the port.
host, _, err := net.SplitHostPort(address)
if err != nil {
return ipAddrFamilyUnknown
}
ip, err := netip.ParseAddr(host)
if err != nil {
return ipAddrFamilyUnknown
}
switch {
case ip.Is4() || ip.Is4In6():
return ipAddrFamilyV4
case ip.Is6():
return ipAddrFamilyV6
default:
return ipAddrFamilyUnknown
}
}
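// Worked illustration of the two helpers above, using assumed addresses.
// addressFamily classifies each entry, and interleaveAddresses alternates
// families starting with the family seen first (RFC 8305 section 4):
//
//	in:  [2001:db8::1]:443, [2001:db8::2]:443, 192.0.2.1:443, 192.0.2.2:443
//	out: [2001:db8::1]:443, 192.0.2.1:443, [2001:db8::2]:443, 192.0.2.2:443
//
// An in-package sketch:
func exampleInterleave() []resolver.Address {
	addrs := []resolver.Address{
		{Addr: "[2001:db8::1]:443"}, {Addr: "[2001:db8::2]:443"},
		{Addr: "192.0.2.1:443"}, {Addr: "192.0.2.2:443"},
	}
	return interleaveAddresses(addrs)
}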
// reconcileSubConnsLocked updates the active subchannels based on a new address
// list from the resolver. It does this by:
// - closing subchannels: any existing subchannels associated with addresses
// that are no longer in the updated list are shut down.
// - removing subchannels: entries for these closed subchannels are removed
// from the subchannel map.
//
// This ensures that the subchannel map accurately reflects the current set of
// addresses received from the name resolver.
func (b *pickfirstBalancer) reconcileSubConnsLocked(newAddrs []resolver.Address) {
newAddrsMap := resolver.NewAddressMapV2[bool]()
for _, addr := range newAddrs {
newAddrsMap.Set(addr, true)
}
for _, oldAddr := range b.subConns.Keys() {
if _, ok := newAddrsMap.Get(oldAddr); ok {
continue
}
val, _ := b.subConns.Get(oldAddr)
val.subConn.Shutdown()
b.subConns.Delete(oldAddr)
}
}
// shutdownRemainingLocked shuts down remaining subConns. Called when a subConn
// becomes ready, which means that all other SubConns must be shut down.
func (b *pickfirstBalancer) shutdownRemainingLocked(selected *scData) {
b.cancelConnectionTimer()
for _, sd := range b.subConns.Values() {
if sd.subConn != selected.subConn {
sd.subConn.Shutdown()
}
}
b.subConns = resolver.NewAddressMapV2[*scData]()
b.subConns.Set(selected.addr, selected)
}
// requestConnectionLocked starts connecting on the subchannel corresponding to
// the current address. If no subchannel exists, one is created. If the current
// subchannel is in TransientFailure, a connection to the next address is
// attempted until a subchannel is found.
func (b *pickfirstBalancer) requestConnectionLocked() {
if !b.addressList.isValid() {
return
}
var lastErr error
for valid := true; valid; valid = b.addressList.increment() {
curAddr := b.addressList.currentAddress()
sd, ok := b.subConns.Get(curAddr)
if !ok {
var err error
// We want to assign the new scData to sd from the outer scope,
// hence we can't use := below.
sd, err = b.newSCData(curAddr)
if err != nil {
// This should never happen, unless the clientConn is being shut
// down.
if b.logger.V(2) {
b.logger.Infof("Failed to create a subConn for address %v: %v", curAddr.String(), err)
}
// Do nothing, the LB policy will be closed soon.
return
}
b.subConns.Set(curAddr, sd)
}
switch sd.rawConnectivityState {
case connectivity.Idle:
sd.subConn.Connect()
b.scheduleNextConnectionLocked()
return
case connectivity.TransientFailure:
// The SubConn is being re-used and failed during a previous pass
// over the addressList. It has not completed backoff yet.
// Mark it as having failed and try the next address.
sd.connectionFailedInFirstPass = true
lastErr = sd.lastErr
continue
case connectivity.Connecting:
// Wait for the connection attempt to complete or the timer to fire
// before attempting the next address.
b.scheduleNextConnectionLocked()
return
default:
b.logger.Errorf("SubConn with unexpected state %v present in SubConns map.", sd.rawConnectivityState)
return
}
}
// All the remaining addresses in the list are in TRANSIENT_FAILURE, end the
// first pass if possible.
b.endFirstPassIfPossibleLocked(lastErr)
}
func (b *pickfirstBalancer) scheduleNextConnectionLocked() {
b.cancelConnectionTimer()
if !b.addressList.hasNext() {
return
}
curAddr := b.addressList.currentAddress()
cancelled := false // Access to this is protected by the balancer's mutex.
closeFn := internal.TimeAfterFunc(connectionDelayInterval, func() {
b.mu.Lock()
defer b.mu.Unlock()
// If the scheduled task is cancelled while acquiring the mutex, return.
if cancelled {
return
}
if b.logger.V(2) {
b.logger.Infof("Happy Eyeballs timer expired while waiting for connection to %q.", curAddr.Addr)
}
if b.addressList.increment() {
b.requestConnectionLocked()
}
})
// Access to the cancellation callback held by the balancer is guarded by
// the balancer's mutex, so it's safe to set the boolean from the callback.
b.cancelConnectionTimer = sync.OnceFunc(func() {
cancelled = true
closeFn()
})
}
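// The cancellation idiom above, isolated as a hedged sketch (names assumed):
// the flag is only touched under the shared mutex, and sync.OnceFunc makes
// cancellation idempotent even if the timer has already fired.
import (
	"sync"
	"time"

	"google.golang.org/grpc/balancer/pickfirst/internal"
)

func startCancellableTask(mu *sync.Mutex, d time.Duration, task func()) (cancel func()) {
	cancelled := false
	stop := internal.TimeAfterFunc(d, func() {
		mu.Lock()
		defer mu.Unlock()
		if cancelled { // cancelled while the callback waited for the lock
			return
		}
		task()
	})
	return sync.OnceFunc(func() {
		mu.Lock()
		cancelled = true
		mu.Unlock()
		stop() // stops the timer if it hasn't fired yet
	})
}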
func (b *pickfirstBalancer) updateSubConnState(sd *scData, newState balancer.SubConnState) {
b.mu.Lock()
defer b.mu.Unlock()
oldState := sd.rawConnectivityState
sd.rawConnectivityState = newState.ConnectivityState
// Previously relevant SubConns can still callback with state updates.
// To prevent pickers from returning these obsolete SubConns, this logic
// is included to check if the current list of active SubConns includes this
// SubConn.
if !b.isActiveSCData(sd) {
return
}
if newState.ConnectivityState == connectivity.Shutdown {
sd.effectiveState = connectivity.Shutdown
return
}
// Record a connection attempt when exiting CONNECTING.
if newState.ConnectivityState == connectivity.TransientFailure {
sd.connectionFailedInFirstPass = true
connectionAttemptsFailedMetric.Record(b.metricsRecorder, 1, b.target)
}
if newState.ConnectivityState == connectivity.Ready {
connectionAttemptsSucceededMetric.Record(b.metricsRecorder, 1, b.target)
b.shutdownRemainingLocked(sd)
if !b.addressList.seekTo(sd.addr) {
// This should not fail as we should have only one SubConn after
// entering READY. The SubConn should be present in the addressList.
b.logger.Errorf("Address %q not found address list in %v", sd.addr, b.addressList.addresses)
return
}
if !b.healthCheckingEnabled {
if b.logger.V(2) {
b.logger.Infof("SubConn %p reported connectivity state READY and the health listener is disabled. Transitioning SubConn to READY.", sd.subConn)
}
sd.effectiveState = connectivity.Ready
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.Ready,
Picker: &picker{result: balancer.PickResult{SubConn: sd.subConn}},
})
return
}
if b.logger.V(2) {
b.logger.Infof("SubConn %p reported connectivity state READY. Registering health listener.", sd.subConn)
}
// Send a CONNECTING update to take the SubConn out of sticky-TF if
// required.
sd.effectiveState = connectivity.Connecting
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.Connecting,
Picker: &picker{err: balancer.ErrNoSubConnAvailable},
})
sd.subConn.RegisterHealthListener(func(scs balancer.SubConnState) {
b.updateSubConnHealthState(sd, scs)
})
return
}
// If the LB policy is READY, and it receives a subchannel state change,
// it means that the READY subchannel has failed.
// A SubConn can also transition from CONNECTING directly to IDLE when
// a transport is successfully created, but the connection fails
// before the SubConn can send the notification for READY. We treat
// this as a successful connection and transition to IDLE.
// TODO: https://github.com/grpc/grpc-go/issues/7862 - Remove the second
// part of the if condition below once the issue is fixed.
if oldState == connectivity.Ready || (oldState == connectivity.Connecting && newState.ConnectivityState == connectivity.Idle) {
// Once a transport fails, the balancer enters IDLE and starts from
// the first address when the picker is used.
b.shutdownRemainingLocked(sd)
sd.effectiveState = newState.ConnectivityState
// A READY SubConn was interspersed between CONNECTING and IDLE; we need to
// account for that.
if oldState == connectivity.Connecting {
// A known issue (https://github.com/grpc/grpc-go/issues/7862)
// causes a race that prevents the READY state change notification.
// This works around it.
connectionAttemptsSucceededMetric.Record(b.metricsRecorder, 1, b.target)
}
disconnectionsMetric.Record(b.metricsRecorder, 1, b.target)
b.addressList.reset()
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.Idle,
Picker: &idlePicker{exitIdle: sync.OnceFunc(b.ExitIdle)},
})
return
}
if b.firstPass {
switch newState.ConnectivityState {
case connectivity.Connecting:
// The effective state can be IDLE, CONNECTING, or
// TRANSIENT_FAILURE. If it's TRANSIENT_FAILURE, stay in
// TRANSIENT_FAILURE until it's READY. See A62.
if sd.effectiveState != connectivity.TransientFailure {
sd.effectiveState = connectivity.Connecting
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.Connecting,
Picker: &picker{err: balancer.ErrNoSubConnAvailable},
})
}
case connectivity.TransientFailure:
sd.lastErr = newState.ConnectionError
sd.effectiveState = connectivity.TransientFailure
// Since we're re-using common SubConns while handling resolver
// updates, we could receive an out of turn TRANSIENT_FAILURE from
// a pass over the previous address list. Happy Eyeballs will also
// cause out of order updates to arrive.
if curAddr := b.addressList.currentAddress(); equalAddressIgnoringBalAttributes(&curAddr, &sd.addr) {
b.cancelConnectionTimer()
if b.addressList.increment() {
b.requestConnectionLocked()
return
}
}
// End the first pass if we've seen a TRANSIENT_FAILURE from all
// SubConns once.
b.endFirstPassIfPossibleLocked(newState.ConnectionError)
}
return
}
// We have finished the first pass, keep re-connecting failing SubConns.
switch newState.ConnectivityState {
case connectivity.TransientFailure:
b.numTF = (b.numTF + 1) % b.subConns.Len()
sd.lastErr = newState.ConnectionError
if b.numTF%b.subConns.Len() == 0 {
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.TransientFailure,
Picker: &picker{err: newState.ConnectionError},
})
}
// We don't need to request re-resolution since the SubConn already
// does that before reporting TRANSIENT_FAILURE.
// TODO: #7534 - Move re-resolution requests from SubConn into
// pick_first.
case connectivity.Idle:
sd.subConn.Connect()
}
}
// endFirstPassIfPossibleLocked ends the first happy-eyeballs pass if all the
// addresses are tried and their SubConns have reported a failure.
func (b *pickfirstBalancer) endFirstPassIfPossibleLocked(lastErr error) {
// An optimization to avoid iterating over the entire SubConn map.
if b.addressList.isValid() {
return
}
// Connect() has been called on all the SubConns. The first pass can be
// ended if all the SubConns have reported a failure.
for _, sd := range b.subConns.Values() {
if !sd.connectionFailedInFirstPass {
return
}
}
b.firstPass = false
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.TransientFailure,
Picker: &picker{err: lastErr},
})
// Start re-connecting all the SubConns that are already in IDLE.
for _, sd := range b.subConns.Values() {
if sd.rawConnectivityState == connectivity.Idle {
sd.subConn.Connect()
}
}
}
func (b *pickfirstBalancer) isActiveSCData(sd *scData) bool {
activeSD, found := b.subConns.Get(sd.addr)
return found && activeSD == sd
}
func (b *pickfirstBalancer) updateSubConnHealthState(sd *scData, state balancer.SubConnState) {
b.mu.Lock()
defer b.mu.Unlock()
// Previously relevant SubConns can still call back with state updates.
// To prevent pickers from returning these obsolete SubConns, check that
// the current list of active SubConns includes this SubConn.
if !b.isActiveSCData(sd) {
return
}
sd.effectiveState = state.ConnectivityState
switch state.ConnectivityState {
case connectivity.Ready:
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.Ready,
Picker: &picker{result: balancer.PickResult{SubConn: sd.subConn}},
})
case connectivity.TransientFailure:
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.TransientFailure,
Picker: &picker{err: fmt.Errorf("pickfirst: health check failure: %v", state.ConnectionError)},
})
case connectivity.Connecting:
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.Connecting,
Picker: &picker{err: balancer.ErrNoSubConnAvailable},
})
default:
b.logger.Errorf("Got unexpected health update for SubConn %p: %v", state)
}
}
// updateBalancerState stores the state reported to the channel and calls
// ClientConn.UpdateState(). As an optimization, it avoids sending duplicate
// updates to the channel.
func (b *pickfirstBalancer) updateBalancerState(newState balancer.State) {
// In case of TransientFailures allow the picker to be updated to update
// the connectivity error, in all other cases don't send duplicate state
// updates.
if newState.ConnectivityState == b.state && b.state != connectivity.TransientFailure {
return
}
b.forceUpdateConcludedStateLocked(newState)
}
// forceUpdateConcludedStateLocked stores the state reported to the channel and
// calls ClientConn.UpdateState().
// A separate function is defined to force-update the ClientConn state since
// the channel doesn't assume that LB policies start in CONNECTING and
// instead relies on the LB policy to send an initial CONNECTING update.
func (b *pickfirstBalancer) forceUpdateConcludedStateLocked(newState balancer.State) {
b.state = newState.ConnectivityState
b.cc.UpdateState(newState)
}
type picker struct {
result balancer.PickResult
err error
}
func (p *picker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
return p.result, p.err
}
// idlePicker is used when the SubConn is IDLE and kicks the SubConn into
// CONNECTING when Pick is called.
type idlePicker struct {
exitIdle func()
}
func (i *idlePicker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
i.exitIdle()
return balancer.PickResult{}, balancer.ErrNoSubConnAvailable
}
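// Illustrative sketch (not part of this file): idlePicker wraps ExitIdle in
// sync.OnceFunc so that, however many RPCs hit the picker while the channel
// is IDLE, the balancer is kicked out of idle exactly once. The stdlib
// helper behaves like this:
//
//	exit := sync.OnceFunc(func() { fmt.Println("ExitIdle triggered") })
//	exit() // runs the function and prints once
//	exit() // no-op: repeated Picks do not re-trigger ExitIdle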
// addressList manages sequentially iterating over addresses present in a list
// of endpoints. It provides a one-dimensional view of the addresses present in
// the endpoints.
// This type is not safe for concurrent access.
type addressList struct {
addresses []resolver.Address
idx int
}
func (al *addressList) isValid() bool {
return al.idx < len(al.addresses)
}
func (al *addressList) size() int {
return len(al.addresses)
}
// increment moves to the next index in the address list.
// This method returns false if it went off the list, true otherwise.
func (al *addressList) increment() bool {
if !al.isValid() {
return false
}
al.idx++
return al.idx < len(al.addresses)
}
// currentAddress returns the current address pointed to in the addressList.
// If the list is in an invalid state, it returns an empty address instead.
func (al *addressList) currentAddress() resolver.Address {
if !al.isValid() {
return resolver.Address{}
}
return al.addresses[al.idx]
}
func (al *addressList) reset() {
al.idx = 0
}
func (al *addressList) updateAddrs(addrs []resolver.Address) {
al.addresses = addrs
al.reset()
}
// seekTo returns false if the needle was not found, in which case the
// current index is left unchanged.
func (al *addressList) seekTo(needle resolver.Address) bool {
for ai, addr := range al.addresses {
if !equalAddressIgnoringBalAttributes(&addr, &needle) {
continue
}
al.idx = ai
return true
}
return false
}
// hasNext returns whether the list has another address after the current
// one, i.e. whether incrementing would keep the index within the list. If
// the list has already moved past the end, it returns false.
func (al *addressList) hasNext() bool {
if !al.isValid() {
return false
}
return al.idx+1 < len(al.addresses)
}
// equalAddressIgnoringBalAttributes returns true if a and b are considered
// equal. This is different from the Equal method on the resolver.Address type
// which considers all fields to determine equality. Here, we only consider
// fields that are meaningful to the SubConn.
func equalAddressIgnoringBalAttributes(a, b *resolver.Address) bool {
return a.Addr == b.Addr && a.ServerName == b.ServerName &&
a.Attributes.Equal(b.Attributes)
}
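// Illustrative sketch (not part of this file): how the addressList helpers
// above compose during a single pick_first pass. tryConnect is a
// hypothetical stand-in for requesting a connection.
//
//	var al addressList
//	al.updateAddrs(addrs) // idx == 0; the first address is current
//	for al.isValid() {
//		tryConnect(al.currentAddress())
//		if !al.increment() {
//			break // walked off the end: the pass is over
//		}
//	}
//	al.reset() // a later pass starts again from the first address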


@@ -22,12 +22,13 @@
package roundrobin
import (
"sync/atomic"
"fmt"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/balancer/endpointsharding"
"google.golang.org/grpc/balancer/pickfirst/pickfirstleaf"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/grpcrand"
internalgrpclog "google.golang.org/grpc/internal/grpclog"
)
// Name is the name of the round_robin balancer.
@@ -35,47 +36,44 @@ const Name = "round_robin"
var logger = grpclog.Component("roundrobin")
// newBuilder creates a new roundrobin balancer builder.
func newBuilder() balancer.Builder {
return base.NewBalancerBuilder(Name, &rrPickerBuilder{}, base.Config{HealthCheck: true})
}
func init() {
balancer.Register(newBuilder())
balancer.Register(builder{})
}
type rrPickerBuilder struct{}
type builder struct{}
func (*rrPickerBuilder) Build(info base.PickerBuildInfo) balancer.Picker {
logger.Infof("roundrobinPicker: Build called with info: %v", info)
if len(info.ReadySCs) == 0 {
return base.NewErrPicker(balancer.ErrNoSubConnAvailable)
func (bb builder) Name() string {
return Name
}
func (bb builder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
childBuilder := balancer.Get(pickfirstleaf.Name).Build
bal := &rrBalancer{
cc: cc,
Balancer: endpointsharding.NewBalancer(cc, opts, childBuilder, endpointsharding.Options{}),
}
scs := make([]balancer.SubConn, 0, len(info.ReadySCs))
for sc := range info.ReadySCs {
scs = append(scs, sc)
}
return &rrPicker{
subConns: scs,
// Start at a random index, as the same RR balancer rebuilds a new
// picker when SubConn states change, and we don't want to apply excess
// load to the first server in the list.
next: uint32(grpcrand.Intn(len(scs))),
bal.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[%p] ", bal))
bal.logger.Infof("Created")
return bal
}
type rrBalancer struct {
balancer.Balancer
cc balancer.ClientConn
logger *internalgrpclog.PrefixLogger
}
func (b *rrBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
return b.Balancer.UpdateClientConnState(balancer.ClientConnState{
// Enable the health listener in pickfirst children for client side health
// checks and outlier detection, if configured.
ResolverState: pickfirstleaf.EnableHealthListener(ccs.ResolverState),
})
}
func (b *rrBalancer) ExitIdle() {
// Should always be ok, as child is endpoint sharding.
if ei, ok := b.Balancer.(balancer.ExitIdler); ok {
ei.ExitIdle()
}
}
type rrPicker struct {
// subConns is the snapshot of the roundrobin balancer when this picker was
// created. The slice is immutable. Each Get() will do a round robin
// selection from it and return the selected SubConn.
subConns []balancer.SubConn
next uint32
}
func (p *rrPicker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
subConnsLen := uint32(len(p.subConns))
nextIndex := atomic.AddUint32(&p.next, 1)
sc := p.subConns[nextIndex%subConnsLen]
return balancer.PickResult{SubConn: sc}, nil
}
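// Illustrative sketch (not part of this file): applications select
// round_robin per channel through the service config; the target and
// credentials below are placeholders.
//
//	conn, err := grpc.NewClient("dns:///my-service:443",
//		grpc.WithTransportCredentials(insecure.NewCredentials()),
//		grpc.WithDefaultServiceConfig(`{"loadBalancingConfig": [{"round_robin":{}}]}`))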

vendor/google.golang.org/grpc/balancer/subconn.go (new generated vendored file)

@@ -0,0 +1,134 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package balancer
import (
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/resolver"
)
// A SubConn represents a single connection to a gRPC backend service.
//
// All SubConns start in IDLE, and will not try to connect. To trigger a
// connection attempt, Balancers must call Connect.
//
// If the connection attempt fails, the SubConn will transition to
// TRANSIENT_FAILURE for a backoff period, and then return to IDLE. If the
// connection attempt succeeds, it will transition to READY.
//
// If a READY SubConn becomes disconnected, the SubConn will transition to IDLE.
//
// If a connection re-enters IDLE, Balancers must call Connect again to trigger
// a new connection attempt.
//
// Each SubConn contains a list of addresses. gRPC will try to connect to the
// addresses in sequence, and stop trying the remainder once the first
// connection is successful. However, this behavior is deprecated. SubConns
// should only use a single address.
//
// NOTICE: This interface is intended to be implemented by gRPC, or intercepted
// by custom load balancing policies. Users should not need their own complete
// implementation of this interface -- they should always delegate to a SubConn
// returned by ClientConn.NewSubConn() by embedding it in their implementations.
// An embedded SubConn must never be nil, or runtime panics will occur.
type SubConn interface {
// UpdateAddresses updates the addresses used in this SubConn.
// gRPC checks if the currently-connected address is still in the new list.
// If it's in the list, the connection will be kept.
// If it's not in the list, the connection will gracefully close, and
// a new connection will be created.
//
// This will trigger a state transition for the SubConn.
//
// Deprecated: this method will be removed. Create new SubConns for new
// addresses instead.
UpdateAddresses([]resolver.Address)
// Connect starts the connecting for this SubConn.
Connect()
// GetOrBuildProducer returns a reference to the existing Producer for this
// ProducerBuilder in this SubConn, or, if one does not currently exist,
// creates a new one and returns it. Returns a close function which may be
// called when the Producer is no longer needed. Otherwise the producer
// will automatically be closed upon connection loss or subchannel close.
// Should only be called on a SubConn in state Ready. Otherwise the
// producer will be unable to create streams.
GetOrBuildProducer(ProducerBuilder) (p Producer, close func())
// Shutdown shuts down the SubConn gracefully. Any started RPCs will be
// allowed to complete. No future calls should be made on the SubConn.
// One final state update will be delivered to the StateListener (or
// UpdateSubConnState; deprecated) with ConnectivityState of Shutdown to
// indicate the shutdown operation. This may be delivered before
// in-progress RPCs are complete and the actual connection is closed.
Shutdown()
// RegisterHealthListener registers a health listener that receives health
// updates for a Ready SubConn. Only one health listener can be registered
// at a time. A health listener should be registered each time the SubConn's
// connectivity state changes to READY. Registering a health listener when
// the connectivity state is not READY may result in undefined behaviour.
// This method must not be called synchronously while handling an update
// from a previously registered health listener.
RegisterHealthListener(func(SubConnState))
// EnforceSubConnEmbedding is included to force implementers to embed
// another implementation of this interface, allowing gRPC to add methods
// without breaking users.
internal.EnforceSubConnEmbedding
}
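// Illustrative sketch (not part of this file): the lifecycle above from a
// balancer's point of view. cc and addrs are assumed inputs and error
// handling is elided.
//
//	var sc SubConn
//	sc, _ = cc.NewSubConn(addrs, NewSubConnOptions{
//		StateListener: func(scs SubConnState) {
//			// READY -> IDLE signals a disconnection; a new attempt
//			// requires another explicit Connect call.
//			if scs.ConnectivityState == connectivity.Idle {
//				sc.Connect()
//			}
//		},
//	})
//	sc.Connect() // SubConns start IDLE and never dial on their own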
// A ProducerBuilder is a simple constructor for a Producer. It is used by the
// SubConn to create producers when needed.
type ProducerBuilder interface {
// Build creates a Producer. The first parameter is always a
// grpc.ClientConnInterface (a type to allow creating RPCs/streams on the
// associated SubConn), but is declared as `any` to avoid a dependency
// cycle. Build also returns a close function that will be called when all
// references to the Producer have been given up for a SubConn, or when a
// connectivity state change occurs on the SubConn. The close function
// should always block until all asynchronous cleanup work is completed.
Build(grpcClientConnInterface any) (p Producer, close func())
}
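// Illustrative sketch (not part of this file): a hypothetical builder whose
// producer runs a background goroutine; the returned close function blocks
// until that goroutine has finished, as the contract above requires.
//
//	type reportBuilder struct{}
//
//	func (reportBuilder) Build(grpcClientConnInterface any) (Producer, func()) {
//		stop, done := make(chan struct{}), make(chan struct{})
//		go func() {
//			defer close(done)
//			<-stop // placeholder for real work on the SubConn
//		}()
//		return struct{}{}, func() { close(stop); <-done }
//	}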
// SubConnState describes the state of a SubConn.
type SubConnState struct {
// ConnectivityState is the connectivity state of the SubConn.
ConnectivityState connectivity.State
// ConnectionError is set if the ConnectivityState is TransientFailure,
// describing the reason the SubConn failed. Otherwise, it is nil.
ConnectionError error
// connectedAddr contains the connected address when ConnectivityState is
// Ready. Otherwise, it is indeterminate.
connectedAddress resolver.Address
}
// connectedAddress returns the connected address for a SubConnState. The
// address is only valid if the state is READY.
func connectedAddress(scs SubConnState) resolver.Address {
return scs.connectedAddress
}
// setConnectedAddress sets the connected address for a SubConnState.
func setConnectedAddress(scs *SubConnState, addr resolver.Address) {
scs.connectedAddress = addr
}
// A Producer is a type shared among potentially many consumers. It is
// associated with a SubConn, and an implementation will typically contain
// other methods to provide additional functionality, e.g. configuration or
// subscription registration.
type Producer any


@@ -24,11 +24,25 @@ import (
"sync"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/experimental/stats"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/balancer/gracefulswitch"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/status"
)
var (
setConnectedAddress = internal.SetConnectedAddress.(func(*balancer.SubConnState, resolver.Address))
// noOpRegisterHealthListenerFn is used when client side health checking is
// disabled. It sends a single READY update on the registered listener.
noOpRegisterHealthListenerFn = func(_ context.Context, listener func(balancer.SubConnState)) func() {
listener(balancer.SubConnState{ConnectivityState: connectivity.Ready})
return func() {}
}
)
// ccBalancerWrapper sits between the ClientConn and the Balancer.
@@ -46,6 +60,7 @@ import (
// It uses the gracefulswitch.Balancer internally to ensure that balancer
// switches happen in a graceful manner.
type ccBalancerWrapper struct {
internal.EnforceClientConnEmbedding
// The following fields are initialized when the wrapper is created and are
// read-only afterwards, and therefore can be accessed without a mutex.
cc *ClientConn
@@ -87,12 +102,16 @@ func newCCBalancerWrapper(cc *ClientConn) *ccBalancerWrapper {
return ccb
}
func (ccb *ccBalancerWrapper) MetricsRecorder() stats.MetricsRecorder {
return ccb.cc.metricsRecorderList
}
// updateClientConnState is invoked by grpc to push a ClientConnState update to
// the underlying balancer. This is always executed from the serializer, so
// it is safe to call into the balancer here.
func (ccb *ccBalancerWrapper) updateClientConnState(ccs *balancer.ClientConnState) error {
errCh := make(chan error)
ok := ccb.serializer.Schedule(func(ctx context.Context) {
uccs := func(ctx context.Context) {
defer close(errCh)
if ctx.Err() != nil || ccb.balancer == nil {
return
@@ -107,17 +126,23 @@ func (ccb *ccBalancerWrapper) updateClientConnState(ccs *balancer.ClientConnStat
logger.Infof("error from balancer.UpdateClientConnState: %v", err)
}
errCh <- err
})
if !ok {
return nil
}
onFailure := func() { close(errCh) }
// UpdateClientConnState can race with Close, and when the latter wins, the
// serializer is closed, and the attempt to schedule the callback will fail.
// It is acceptable to ignore this failure. But since we want to handle the
// state update in a blocking fashion (when we successfully schedule the
// callback), we have to use the ScheduleOr method and not the MaybeSchedule
// method on the serializer.
ccb.serializer.ScheduleOr(uccs, onFailure)
return <-errCh
}
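// Illustrative sketch (not part of this file): the schedule-or-fail pattern
// used above, assuming a serializer whose ScheduleOr runs exactly one of
// its two callbacks. Because each path closes errCh, the blocking receive
// can never deadlock. doUpdate is a hypothetical blocking call.
//
//	errCh := make(chan error, 1)
//	onSuccess := func(ctx context.Context) {
//		defer close(errCh)
//		errCh <- doUpdate(ctx)
//	}
//	onFailure := func() { close(errCh) } // serializer already shut down
//	serializer.ScheduleOr(onSuccess, onFailure)
//	err := <-errCh // nil when the serializer was closed first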
// resolverError is invoked by grpc to push a resolver error to the underlying
// balancer. The call to the balancer is executed from the serializer.
func (ccb *ccBalancerWrapper) resolverError(err error) {
ccb.serializer.Schedule(func(ctx context.Context) {
ccb.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil || ccb.balancer == nil {
return
}
@@ -133,7 +158,7 @@ func (ccb *ccBalancerWrapper) close() {
ccb.closed = true
ccb.mu.Unlock()
channelz.Info(logger, ccb.cc.channelz, "ccBalancerWrapper: closing")
ccb.serializer.Schedule(func(context.Context) {
ccb.serializer.TrySchedule(func(context.Context) {
if ccb.balancer == nil {
return
}
@@ -145,7 +170,7 @@ func (ccb *ccBalancerWrapper) close() {
// exitIdle invokes the balancer's exitIdle method in the serializer.
func (ccb *ccBalancerWrapper) exitIdle() {
ccb.serializer.Schedule(func(ctx context.Context) {
ccb.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil || ccb.balancer == nil {
return
}
@@ -177,12 +202,13 @@ func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer
ac: ac,
producers: make(map[balancer.ProducerBuilder]*refCountedProducer),
stateListener: opts.StateListener,
healthData: newHealthData(connectivity.Idle),
}
ac.acbw = acbw
return acbw, nil
}
func (ccb *ccBalancerWrapper) RemoveSubConn(sc balancer.SubConn) {
func (ccb *ccBalancerWrapper) RemoveSubConn(balancer.SubConn) {
// The graceful switch balancer will never call this.
logger.Errorf("ccb RemoveSubConn(%v) called unexpectedly, sc")
}
@@ -198,6 +224,10 @@ func (ccb *ccBalancerWrapper) UpdateAddresses(sc balancer.SubConn, addrs []resol
func (ccb *ccBalancerWrapper) UpdateState(s balancer.State) {
ccb.cc.mu.Lock()
defer ccb.cc.mu.Unlock()
if ccb.cc.conns == nil {
// The CC has been closed; ignore this update.
return
}
ccb.mu.Lock()
if ccb.closed {
@@ -238,25 +268,77 @@ func (ccb *ccBalancerWrapper) Target() string {
// acBalancerWrapper is a wrapper on top of ac for balancers.
// It implements balancer.SubConn interface.
type acBalancerWrapper struct {
internal.EnforceSubConnEmbedding
ac *addrConn // read-only
ccb *ccBalancerWrapper // read-only
stateListener func(balancer.SubConnState)
mu sync.Mutex
producers map[balancer.ProducerBuilder]*refCountedProducer
producersMu sync.Mutex
producers map[balancer.ProducerBuilder]*refCountedProducer
// Access to healthData is protected by healthMu.
healthMu sync.Mutex
// healthData is stored as a pointer to detect when the health listener is
// dropped or updated. This is required as closures can't be compared for
// equality.
healthData *healthData
}
// healthData holds data related to health state reporting.
type healthData struct {
// connectivityState stores the most recent connectivity state delivered
// to the LB policy. This is stored to avoid sending updates when the
// SubConn has already exited connectivity state READY.
connectivityState connectivity.State
// closeHealthProducer stores function to close the ref counted health
// producer. The health producer is automatically closed when the SubConn
// state changes.
closeHealthProducer func()
}
func newHealthData(s connectivity.State) *healthData {
return &healthData{
connectivityState: s,
closeHealthProducer: func() {},
}
}
// updateState is invoked by grpc to push a subConn state update to the
// underlying balancer.
func (acbw *acBalancerWrapper) updateState(s connectivity.State, err error) {
acbw.ccb.serializer.Schedule(func(ctx context.Context) {
func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolver.Address, err error) {
acbw.ccb.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil || acbw.ccb.balancer == nil {
return
}
// Invalidate all producers on any state change.
acbw.closeProducers()
// Even though it is optional for balancers, gracefulswitch ensures
// opts.StateListener is set, so this cannot ever be nil.
// TODO: delete this comment when UpdateSubConnState is removed.
acbw.stateListener(balancer.SubConnState{ConnectivityState: s, ConnectionError: err})
scs := balancer.SubConnState{ConnectivityState: s, ConnectionError: err}
if s == connectivity.Ready {
setConnectedAddress(&scs, curAddr)
}
// Invalidate the health listener by updating the healthData.
acbw.healthMu.Lock()
// A race may occur if a health listener is registered soon after the
// connectivity state is set but before the stateListener is called.
// Two cases may arise:
// 1. The new state is not READY: RegisterHealthListener has checks to
// ensure no updates are sent when the connectivity state is not
// READY.
// 2. The new state is READY: This means that the old state wasn't Ready.
// The RegisterHealthListener API mentions that a health listener
// must not be registered when a SubConn is not ready to avoid such
// races. When this happens, the LB policy would get health updates
// on the old listener. When the LB policy registers a new listener
// on receiving the connectivity update, the health updates will be
// sent to the new health listener.
acbw.healthData = newHealthData(scs.ConnectivityState)
acbw.healthMu.Unlock()
acbw.stateListener(scs)
})
}
@@ -273,6 +355,7 @@ func (acbw *acBalancerWrapper) Connect() {
}
func (acbw *acBalancerWrapper) Shutdown() {
acbw.closeProducers()
acbw.ccb.cc.removeAddrConn(acbw.ac, errConnDrain)
}
@@ -280,9 +363,10 @@ func (acbw *acBalancerWrapper) Shutdown() {
// ready, blocks until it is or ctx expires. Returns an error when the context
// expires or the addrConn is shut down.
func (acbw *acBalancerWrapper) NewStream(ctx context.Context, desc *StreamDesc, method string, opts ...CallOption) (ClientStream, error) {
transport, err := acbw.ac.getTransport(ctx)
if err != nil {
return nil, err
transport := acbw.ac.getReadyTransport()
if transport == nil {
return nil, status.Errorf(codes.Unavailable, "SubConn state is not Ready")
}
return newNonRetryClientStream(ctx, desc, method, transport, acbw.ac, opts...)
}
@@ -307,15 +391,15 @@ type refCountedProducer struct {
}
func (acbw *acBalancerWrapper) GetOrBuildProducer(pb balancer.ProducerBuilder) (balancer.Producer, func()) {
acbw.mu.Lock()
defer acbw.mu.Unlock()
acbw.producersMu.Lock()
defer acbw.producersMu.Unlock()
// Look up existing producer from this builder.
pData := acbw.producers[pb]
if pData == nil {
// Not found; create a new one and add it to the producers map.
p, close := pb.Build(acbw)
pData = &refCountedProducer{producer: p, close: close}
p, closeFn := pb.Build(acbw)
pData = &refCountedProducer{producer: p, close: closeFn}
acbw.producers[pb] = pData
}
// Account for this new reference.
@@ -325,13 +409,112 @@ func (acbw *acBalancerWrapper) GetOrBuildProducer(pb balancer.ProducerBuilder) (
// and delete the refCountedProducer from the map if the total reference
// count goes to zero.
unref := func() {
acbw.mu.Lock()
acbw.producersMu.Lock()
// If closeProducers has already closed this producer instance, refs is
// set to 0, so the check after decrementing will never pass, and the
// producer will not be double-closed.
pData.refs--
if pData.refs == 0 {
defer pData.close() // Run outside the acbw mutex
delete(acbw.producers, pb)
}
acbw.mu.Unlock()
acbw.producersMu.Unlock()
}
return pData.producer, grpcsync.OnceFunc(unref)
return pData.producer, sync.OnceFunc(unref)
}
func (acbw *acBalancerWrapper) closeProducers() {
acbw.producersMu.Lock()
defer acbw.producersMu.Unlock()
for pb, pData := range acbw.producers {
pData.refs = 0
pData.close()
delete(acbw.producers, pb)
}
}
// healthProducerRegisterFn is a type alias for the health producer's function
// for registering listeners.
type healthProducerRegisterFn = func(context.Context, balancer.SubConn, string, func(balancer.SubConnState)) func()
// healthListenerRegFn returns a function to register a listener for health
// updates. If client side health checks are disabled, the registered listener
// will get a single READY (raw connectivity state) update.
//
// Client side health checking is enabled when all the following
// conditions are satisfied:
// 1. Health checking is not disabled using the dial option.
// 2. The health package is imported.
// 3. The health check config is present in the service config.
func (acbw *acBalancerWrapper) healthListenerRegFn() func(context.Context, func(balancer.SubConnState)) func() {
if acbw.ccb.cc.dopts.disableHealthCheck {
return noOpRegisterHealthListenerFn
}
regHealthLisFn := internal.RegisterClientHealthCheckListener
if regHealthLisFn == nil {
// The health package is not imported.
return noOpRegisterHealthListenerFn
}
cfg := acbw.ac.cc.healthCheckConfig()
if cfg == nil {
return noOpRegisterHealthListenerFn
}
return func(ctx context.Context, listener func(balancer.SubConnState)) func() {
return regHealthLisFn.(healthProducerRegisterFn)(ctx, acbw, cfg.ServiceName, listener)
}
}
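// Illustrative sketch (not part of this file): satisfying the three
// conditions above from application code. The service name is a
// placeholder.
//
//	import _ "google.golang.org/grpc/health" // condition 2: link in the health package
//
//	conn, err := grpc.NewClient(target,
//		grpc.WithDefaultServiceConfig(`{
//			"healthCheckConfig": {"serviceName": "my.service.Name"},
//			"loadBalancingConfig": [{"round_robin":{}}]
//		}`)) // condition 3: health check config in the service config
//	// Condition 1 holds as long as WithDisableHealthCheck is not used.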
// RegisterHealthListener accepts a health listener from the LB policy. It sends
// updates to the health listener as long as the SubConn's connectivity state
// doesn't change and a new health listener is not registered. To invalidate
// the currently registered health listener, acbw updates the healthData. If a
// nil listener is registered, the active health listener is dropped.
func (acbw *acBalancerWrapper) RegisterHealthListener(listener func(balancer.SubConnState)) {
acbw.healthMu.Lock()
defer acbw.healthMu.Unlock()
acbw.healthData.closeHealthProducer()
// Listeners should not be registered when the connectivity state
// isn't Ready. This may happen when the balancer registers a listener
// after the connectivityState is updated, but before it is notified
// of the update.
if acbw.healthData.connectivityState != connectivity.Ready {
return
}
// Replace the health data to stop sending updates to any previously
// registered health listeners.
hd := newHealthData(connectivity.Ready)
acbw.healthData = hd
if listener == nil {
return
}
registerFn := acbw.healthListenerRegFn()
acbw.ccb.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil || acbw.ccb.balancer == nil {
return
}
// Don't send updates if a new listener is registered.
acbw.healthMu.Lock()
defer acbw.healthMu.Unlock()
if acbw.healthData != hd {
return
}
// Serialize the health updates from the health producer with
// other calls into the LB policy.
listenerWrapper := func(scs balancer.SubConnState) {
acbw.ccb.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil || acbw.ccb.balancer == nil {
return
}
acbw.healthMu.Lock()
defer acbw.healthMu.Unlock()
if acbw.healthData != hd {
return
}
listener(scs)
})
}
hd.closeHealthProducer = registerFn(ctx, listenerWrapper)
})
}


@@ -18,8 +18,8 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.32.0
// protoc v4.25.2
// protoc-gen-go v1.36.5
// protoc v5.27.1
// source: grpc/binlog/v1/binarylog.proto
package grpc_binarylog_v1
@@ -31,6 +31,7 @@ import (
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
@@ -233,10 +234,7 @@ func (Address_Type) EnumDescriptor() ([]byte, []int) {
// Log entry we store in binary logs
type GrpcLogEntry struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState `protogen:"open.v1"`
// The timestamp of the binary log message
Timestamp *timestamppb.Timestamp `protobuf:"bytes,1,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
// Uniquely identifies a call. The value must not be 0 in order to disambiguate
@@ -255,7 +253,7 @@ type GrpcLogEntry struct {
// The logger uses one of the following fields to record the payload,
// according to the type of the log entry.
//
// Types that are assignable to Payload:
// Types that are valid to be assigned to Payload:
//
// *GrpcLogEntry_ClientHeader
// *GrpcLogEntry_ServerHeader
@@ -269,16 +267,16 @@ type GrpcLogEntry struct {
// EVENT_TYPE_SERVER_HEADER normally or EVENT_TYPE_SERVER_TRAILER in
// the case of trailers-only. On server side, peer is always
// logged on EVENT_TYPE_CLIENT_HEADER.
Peer *Address `protobuf:"bytes,11,opt,name=peer,proto3" json:"peer,omitempty"`
Peer *Address `protobuf:"bytes,11,opt,name=peer,proto3" json:"peer,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *GrpcLogEntry) Reset() {
*x = GrpcLogEntry{}
if protoimpl.UnsafeEnabled {
mi := &file_grpc_binlog_v1_binarylog_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_grpc_binlog_v1_binarylog_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *GrpcLogEntry) String() string {
@@ -289,7 +287,7 @@ func (*GrpcLogEntry) ProtoMessage() {}
func (x *GrpcLogEntry) ProtoReflect() protoreflect.Message {
mi := &file_grpc_binlog_v1_binarylog_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -339,37 +337,45 @@ func (x *GrpcLogEntry) GetLogger() GrpcLogEntry_Logger {
return GrpcLogEntry_LOGGER_UNKNOWN
}
func (m *GrpcLogEntry) GetPayload() isGrpcLogEntry_Payload {
if m != nil {
return m.Payload
func (x *GrpcLogEntry) GetPayload() isGrpcLogEntry_Payload {
if x != nil {
return x.Payload
}
return nil
}
func (x *GrpcLogEntry) GetClientHeader() *ClientHeader {
if x, ok := x.GetPayload().(*GrpcLogEntry_ClientHeader); ok {
return x.ClientHeader
if x != nil {
if x, ok := x.Payload.(*GrpcLogEntry_ClientHeader); ok {
return x.ClientHeader
}
}
return nil
}
func (x *GrpcLogEntry) GetServerHeader() *ServerHeader {
if x, ok := x.GetPayload().(*GrpcLogEntry_ServerHeader); ok {
return x.ServerHeader
if x != nil {
if x, ok := x.Payload.(*GrpcLogEntry_ServerHeader); ok {
return x.ServerHeader
}
}
return nil
}
func (x *GrpcLogEntry) GetMessage() *Message {
if x, ok := x.GetPayload().(*GrpcLogEntry_Message); ok {
return x.Message
if x != nil {
if x, ok := x.Payload.(*GrpcLogEntry_Message); ok {
return x.Message
}
}
return nil
}
func (x *GrpcLogEntry) GetTrailer() *Trailer {
if x, ok := x.GetPayload().(*GrpcLogEntry_Trailer); ok {
return x.Trailer
if x != nil {
if x, ok := x.Payload.(*GrpcLogEntry_Trailer); ok {
return x.Trailer
}
}
return nil
}
@@ -418,10 +424,7 @@ func (*GrpcLogEntry_Message) isGrpcLogEntry_Payload() {}
func (*GrpcLogEntry_Trailer) isGrpcLogEntry_Payload() {}
type ClientHeader struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState `protogen:"open.v1"`
// This contains only the metadata from the application.
Metadata *Metadata `protobuf:"bytes,1,opt,name=metadata,proto3" json:"metadata,omitempty"`
// The name of the RPC method, which looks something like:
@@ -435,16 +438,16 @@ type ClientHeader struct {
// <host> or <host>:<port> .
Authority string `protobuf:"bytes,3,opt,name=authority,proto3" json:"authority,omitempty"`
// the RPC timeout
Timeout *durationpb.Duration `protobuf:"bytes,4,opt,name=timeout,proto3" json:"timeout,omitempty"`
Timeout *durationpb.Duration `protobuf:"bytes,4,opt,name=timeout,proto3" json:"timeout,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ClientHeader) Reset() {
*x = ClientHeader{}
if protoimpl.UnsafeEnabled {
mi := &file_grpc_binlog_v1_binarylog_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_grpc_binlog_v1_binarylog_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ClientHeader) String() string {
@@ -455,7 +458,7 @@ func (*ClientHeader) ProtoMessage() {}
func (x *ClientHeader) ProtoReflect() protoreflect.Message {
mi := &file_grpc_binlog_v1_binarylog_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -499,21 +502,18 @@ func (x *ClientHeader) GetTimeout() *durationpb.Duration {
}
type ServerHeader struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState `protogen:"open.v1"`
// This contains only the metadata from the application.
Metadata *Metadata `protobuf:"bytes,1,opt,name=metadata,proto3" json:"metadata,omitempty"`
Metadata *Metadata `protobuf:"bytes,1,opt,name=metadata,proto3" json:"metadata,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ServerHeader) Reset() {
*x = ServerHeader{}
if protoimpl.UnsafeEnabled {
mi := &file_grpc_binlog_v1_binarylog_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_grpc_binlog_v1_binarylog_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ServerHeader) String() string {
@@ -524,7 +524,7 @@ func (*ServerHeader) ProtoMessage() {}
func (x *ServerHeader) ProtoReflect() protoreflect.Message {
mi := &file_grpc_binlog_v1_binarylog_proto_msgTypes[2]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -547,10 +547,7 @@ func (x *ServerHeader) GetMetadata() *Metadata {
}
type Trailer struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState `protogen:"open.v1"`
// This contains only the metadata from the application.
Metadata *Metadata `protobuf:"bytes,1,opt,name=metadata,proto3" json:"metadata,omitempty"`
// The gRPC status code.
@@ -561,15 +558,15 @@ type Trailer struct {
// The value of the 'grpc-status-details-bin' metadata key. If
// present, this is always an encoded 'google.rpc.Status' message.
StatusDetails []byte `protobuf:"bytes,4,opt,name=status_details,json=statusDetails,proto3" json:"status_details,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *Trailer) Reset() {
*x = Trailer{}
if protoimpl.UnsafeEnabled {
mi := &file_grpc_binlog_v1_binarylog_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_grpc_binlog_v1_binarylog_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *Trailer) String() string {
@@ -580,7 +577,7 @@ func (*Trailer) ProtoMessage() {}
func (x *Trailer) ProtoReflect() protoreflect.Message {
mi := &file_grpc_binlog_v1_binarylog_proto_msgTypes[3]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -625,24 +622,21 @@ func (x *Trailer) GetStatusDetails() []byte {
// Message payload, used by CLIENT_MESSAGE and SERVER_MESSAGE
type Message struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState `protogen:"open.v1"`
// Length of the message. It may not be the same as the length of the
// data field, as the logging payload can be truncated or omitted.
Length uint32 `protobuf:"varint,1,opt,name=length,proto3" json:"length,omitempty"`
// May be truncated or omitted.
Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"`
Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *Message) Reset() {
*x = Message{}
if protoimpl.UnsafeEnabled {
mi := &file_grpc_binlog_v1_binarylog_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_grpc_binlog_v1_binarylog_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *Message) String() string {
@@ -653,7 +647,7 @@ func (*Message) ProtoMessage() {}
func (x *Message) ProtoReflect() protoreflect.Message {
mi := &file_grpc_binlog_v1_binarylog_proto_msgTypes[4]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -704,20 +698,17 @@ func (x *Message) GetData() []byte {
// header is just a normal metadata key.
// The pair will not count towards the size limit.
type Metadata struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
state protoimpl.MessageState `protogen:"open.v1"`
Entry []*MetadataEntry `protobuf:"bytes,1,rep,name=entry,proto3" json:"entry,omitempty"`
unknownFields protoimpl.UnknownFields
Entry []*MetadataEntry `protobuf:"bytes,1,rep,name=entry,proto3" json:"entry,omitempty"`
sizeCache protoimpl.SizeCache
}
func (x *Metadata) Reset() {
*x = Metadata{}
if protoimpl.UnsafeEnabled {
mi := &file_grpc_binlog_v1_binarylog_proto_msgTypes[5]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_grpc_binlog_v1_binarylog_proto_msgTypes[5]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *Metadata) String() string {
@@ -728,7 +719,7 @@ func (*Metadata) ProtoMessage() {}
func (x *Metadata) ProtoReflect() protoreflect.Message {
mi := &file_grpc_binlog_v1_binarylog_proto_msgTypes[5]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -752,21 +743,18 @@ func (x *Metadata) GetEntry() []*MetadataEntry {
// A metadata key value pair
type MetadataEntry struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
state protoimpl.MessageState `protogen:"open.v1"`
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
unknownFields protoimpl.UnknownFields
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
sizeCache protoimpl.SizeCache
}
func (x *MetadataEntry) Reset() {
*x = MetadataEntry{}
if protoimpl.UnsafeEnabled {
mi := &file_grpc_binlog_v1_binarylog_proto_msgTypes[6]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_grpc_binlog_v1_binarylog_proto_msgTypes[6]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *MetadataEntry) String() string {
@@ -777,7 +765,7 @@ func (*MetadataEntry) ProtoMessage() {}
func (x *MetadataEntry) ProtoReflect() protoreflect.Message {
mi := &file_grpc_binlog_v1_binarylog_proto_msgTypes[6]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -808,23 +796,20 @@ func (x *MetadataEntry) GetValue() []byte {
// Address information
type Address struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Type Address_Type `protobuf:"varint,1,opt,name=type,proto3,enum=grpc.binarylog.v1.Address_Type" json:"type,omitempty"`
Address string `protobuf:"bytes,2,opt,name=address,proto3" json:"address,omitempty"`
state protoimpl.MessageState `protogen:"open.v1"`
Type Address_Type `protobuf:"varint,1,opt,name=type,proto3,enum=grpc.binarylog.v1.Address_Type" json:"type,omitempty"`
Address string `protobuf:"bytes,2,opt,name=address,proto3" json:"address,omitempty"`
// only for TYPE_IPV4 and TYPE_IPV6
IpPort uint32 `protobuf:"varint,3,opt,name=ip_port,json=ipPort,proto3" json:"ip_port,omitempty"`
IpPort uint32 `protobuf:"varint,3,opt,name=ip_port,json=ipPort,proto3" json:"ip_port,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *Address) Reset() {
*x = Address{}
if protoimpl.UnsafeEnabled {
mi := &file_grpc_binlog_v1_binarylog_proto_msgTypes[7]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_grpc_binlog_v1_binarylog_proto_msgTypes[7]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *Address) String() string {
@@ -835,7 +820,7 @@ func (*Address) ProtoMessage() {}
func (x *Address) ProtoReflect() protoreflect.Message {
mi := &file_grpc_binlog_v1_binarylog_proto_msgTypes[7]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -873,7 +858,7 @@ func (x *Address) GetIpPort() uint32 {
var File_grpc_binlog_v1_binarylog_proto protoreflect.FileDescriptor
var file_grpc_binlog_v1_binarylog_proto_rawDesc = []byte{
var file_grpc_binlog_v1_binarylog_proto_rawDesc = string([]byte{
0x0a, 0x1e, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x62, 0x69, 0x6e, 0x6c, 0x6f, 0x67, 0x2f, 0x76, 0x31,
0x2f, 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x6c, 0x6f, 0x67, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x12, 0x11, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x6c, 0x6f, 0x67,
@@ -999,23 +984,23 @@ var file_grpc_binlog_v1_binarylog_proto_rawDesc = []byte{
0x69, 0x6e, 0x61, 0x72, 0x79, 0x6c, 0x6f, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x5f, 0x62, 0x69,
0x6e, 0x61, 0x72, 0x79, 0x6c, 0x6f, 0x67, 0x5f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x33,
}
})
var (
file_grpc_binlog_v1_binarylog_proto_rawDescOnce sync.Once
file_grpc_binlog_v1_binarylog_proto_rawDescData = file_grpc_binlog_v1_binarylog_proto_rawDesc
file_grpc_binlog_v1_binarylog_proto_rawDescData []byte
)
func file_grpc_binlog_v1_binarylog_proto_rawDescGZIP() []byte {
file_grpc_binlog_v1_binarylog_proto_rawDescOnce.Do(func() {
file_grpc_binlog_v1_binarylog_proto_rawDescData = protoimpl.X.CompressGZIP(file_grpc_binlog_v1_binarylog_proto_rawDescData)
file_grpc_binlog_v1_binarylog_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_grpc_binlog_v1_binarylog_proto_rawDesc), len(file_grpc_binlog_v1_binarylog_proto_rawDesc)))
})
return file_grpc_binlog_v1_binarylog_proto_rawDescData
}
var file_grpc_binlog_v1_binarylog_proto_enumTypes = make([]protoimpl.EnumInfo, 3)
var file_grpc_binlog_v1_binarylog_proto_msgTypes = make([]protoimpl.MessageInfo, 8)
var file_grpc_binlog_v1_binarylog_proto_goTypes = []interface{}{
var file_grpc_binlog_v1_binarylog_proto_goTypes = []any{
(GrpcLogEntry_EventType)(0), // 0: grpc.binarylog.v1.GrpcLogEntry.EventType
(GrpcLogEntry_Logger)(0), // 1: grpc.binarylog.v1.GrpcLogEntry.Logger
(Address_Type)(0), // 2: grpc.binarylog.v1.Address.Type
@@ -1057,105 +1042,7 @@ func file_grpc_binlog_v1_binarylog_proto_init() {
if File_grpc_binlog_v1_binarylog_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_grpc_binlog_v1_binarylog_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GrpcLogEntry); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_grpc_binlog_v1_binarylog_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ClientHeader); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_grpc_binlog_v1_binarylog_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ServerHeader); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_grpc_binlog_v1_binarylog_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Trailer); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_grpc_binlog_v1_binarylog_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Message); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_grpc_binlog_v1_binarylog_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Metadata); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_grpc_binlog_v1_binarylog_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*MetadataEntry); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_grpc_binlog_v1_binarylog_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Address); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
file_grpc_binlog_v1_binarylog_proto_msgTypes[0].OneofWrappers = []interface{}{
file_grpc_binlog_v1_binarylog_proto_msgTypes[0].OneofWrappers = []any{
(*GrpcLogEntry_ClientHeader)(nil),
(*GrpcLogEntry_ServerHeader)(nil),
(*GrpcLogEntry_Message)(nil),
@@ -1165,7 +1052,7 @@ func file_grpc_binlog_v1_binarylog_proto_init() {
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_grpc_binlog_v1_binarylog_proto_rawDesc,
RawDescriptor: unsafe.Slice(unsafe.StringData(file_grpc_binlog_v1_binarylog_proto_rawDesc), len(file_grpc_binlog_v1_binarylog_proto_rawDesc)),
NumEnums: 3,
NumMessages: 8,
NumExtensions: 0,
@@ -1177,7 +1064,6 @@ func file_grpc_binlog_v1_binarylog_proto_init() {
MessageInfos: file_grpc_binlog_v1_binarylog_proto_msgTypes,
}.Build()
File_grpc_binlog_v1_binarylog_proto = out.File
file_grpc_binlog_v1_binarylog_proto_rawDesc = nil
file_grpc_binlog_v1_binarylog_proto_goTypes = nil
file_grpc_binlog_v1_binarylog_proto_depIdxs = nil
}


@@ -24,6 +24,7 @@ import (
"fmt"
"math"
"net/url"
"slices"
"strings"
"sync"
"sync/atomic"
@@ -31,14 +32,15 @@ import (
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/balancer/pickfirst"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/idle"
"google.golang.org/grpc/internal/pretty"
iresolver "google.golang.org/grpc/internal/resolver"
"google.golang.org/grpc/internal/stats"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/resolver"
@@ -73,6 +75,8 @@ var (
// invalidDefaultServiceConfigErrPrefix is used to prefix the json parsing error for the default
// service config.
invalidDefaultServiceConfigErrPrefix = "grpc: the provided default service config is invalid"
// PickFirstBalancerName is the name of the pick_first balancer.
PickFirstBalancerName = pickfirst.Name
)
// The following errors are returned from Dial and DialContext
@@ -114,15 +118,30 @@ func (dcs *defaultConfigSelector) SelectConfig(rpcInfo iresolver.RPCInfo) (*ires
// NewClient creates a new gRPC "channel" for the target URI provided. No I/O
// is performed. Use of the ClientConn for RPCs will automatically cause it to
// connect. Connect may be used to manually create a connection, but for most
// users this is unnecessary.
// connect. The Connect method may be called to manually create a connection,
// but for most users this should be unnecessary.
//
// The target name syntax is defined in
// https://github.com/grpc/grpc/blob/master/doc/naming.md. e.g. to use dns
// resolver, a "dns:///" prefix should be applied to the target.
// https://github.com/grpc/grpc/blob/master/doc/naming.md. E.g. to use the dns
// name resolver, a "dns:///" prefix may be applied to the target. The default
// name resolver will be used if no scheme is detected, or if the parsed scheme
// is not a registered name resolver. The default resolver is "dns" but can be
// overridden using the resolver package's SetDefaultScheme.
//
// The DialOptions returned by WithBlock, WithTimeout, and
// WithReturnConnectionError are ignored by this function.
// Examples:
//
// - "foo.googleapis.com:8080"
// - "dns:///foo.googleapis.com:8080"
// - "dns:///foo.googleapis.com"
// - "dns:///10.0.0.213:8080"
// - "dns:///%5B2001:db8:85a3:8d3:1319:8a2e:370:7348%5D:443"
// - "dns://8.8.8.8/foo.googleapis.com:8080"
// - "dns://8.8.8.8/foo.googleapis.com"
// - "zookeeper://zk.example.com:9900/example_service"
//
// The DialOptions returned by WithBlock, WithTimeout,
// WithReturnConnectionError, and FailOnNonTempDialError are ignored by this
// function.
func NewClient(target string, opts ...DialOption) (conn *ClientConn, err error) {
cc := &ClientConn{
target: target,
@@ -152,6 +171,16 @@ func NewClient(target string, opts ...DialOption) (conn *ClientConn, err error)
for _, opt := range opts {
opt.apply(&cc.dopts)
}
// Determine the resolver to use.
if err := cc.initParsedTargetAndResolverBuilder(); err != nil {
return nil, err
}
for _, opt := range globalPerTargetDialOptions {
opt.DialOptionForTarget(cc.parsedTarget.URL).apply(&cc.dopts)
}
chainUnaryClientInterceptors(cc)
chainStreamClientInterceptors(cc)
@@ -160,42 +189,38 @@ func NewClient(target string, opts ...DialOption) (conn *ClientConn, err error)
}
if cc.dopts.defaultServiceConfigRawJSON != nil {
scpr := parseServiceConfig(*cc.dopts.defaultServiceConfigRawJSON)
scpr := parseServiceConfig(*cc.dopts.defaultServiceConfigRawJSON, cc.dopts.maxCallAttempts)
if scpr.Err != nil {
return nil, fmt.Errorf("%s: %v", invalidDefaultServiceConfigErrPrefix, scpr.Err)
}
cc.dopts.defaultServiceConfig, _ = scpr.Config.(*ServiceConfig)
}
cc.mkp = cc.dopts.copts.KeepaliveParams
cc.keepaliveParams = cc.dopts.copts.KeepaliveParams
// Register ClientConn with channelz.
if err = cc.initAuthority(); err != nil {
return nil, err
}
// Register ClientConn with channelz. Note that this is done only once
// channel creation can no longer fail.
cc.channelzRegistration(target)
// TODO: Ideally it should be impossible to error from this function after
// channelz registration. This will require removing some channelz logs
// from the following functions that can error. Errors can be returned to
// the user, and successful logs can be emitted here, after the checks have
// passed and channelz is subsequently registered.
// Determine the resolver to use.
if err := cc.parseTargetAndFindResolver(); err != nil {
channelz.RemoveEntry(cc.channelz.ID)
return nil, err
}
if err = cc.determineAuthority(); err != nil {
channelz.RemoveEntry(cc.channelz.ID)
return nil, err
}
channelz.Infof(logger, cc.channelz, "parsed dial target is: %#v", cc.parsedTarget)
channelz.Infof(logger, cc.channelz, "Channel authority set to %q", cc.authority)
cc.csMgr = newConnectivityStateManager(cc.ctx, cc.channelz)
cc.pickerWrapper = newPickerWrapper(cc.dopts.copts.StatsHandlers)
cc.metricsRecorderList = stats.NewMetricsRecorderList(cc.dopts.copts.StatsHandlers)
cc.initIdleStateLocked() // Safe to call without the lock, since nothing else has a reference to cc.
cc.idlenessMgr = idle.NewManager((*idler)(cc), cc.dopts.idleTimeout)
return cc, nil
}
// Dial calls DialContext(context.Background(), target, opts...).
//
// Deprecated: use NewClient instead. Will be supported throughout 1.x.
func Dial(target string, opts ...DialOption) (*ClientConn, error) {
return DialContext(context.Background(), target, opts...)
}
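// Illustrative sketch (not part of this file): the preferred replacement for
// Dial. NewClient performs no I/O; the channel connects on the first RPC or
// eagerly via Connect. The target and credentials are placeholders.
//
//	conn, err := grpc.NewClient("dns:///foo.googleapis.com:8080",
//		grpc.WithTransportCredentials(insecure.NewCredentials()))
//	if err != nil {
//		log.Fatalf("grpc.NewClient: %v", err)
//	}
//	defer conn.Close()
//	conn.Connect() // optional: start connecting before the first RPC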
@@ -209,10 +234,17 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
// "passthrough" for backward compatibility. This distinction should not matter
// to most users, but could matter to legacy users that specify a custom dialer
// and expect it to receive the target string directly.
//
// Deprecated: use NewClient instead. Will be supported throughout 1.x.
func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) {
// At the end of this method, we kick the channel out of idle, rather than
// waiting for the first RPC.
opts = append([]DialOption{withDefaultScheme("passthrough")}, opts...)
//
// The WithLocalDNSResolution dial option added here preserves Dial's
// historical behavior: with the default passthrough scheme, hostname
// resolution is skipped; when the "dns" scheme is used, resolution is
// performed on the client.
opts = append([]DialOption{withDefaultScheme("passthrough"), WithLocalDNSResolution()}, opts...)
cc, err := NewClient(target, opts...)
if err != nil {
return nil, err
@@ -582,13 +614,14 @@ type ClientConn struct {
cancel context.CancelFunc // Cancelled on close.
// The following are initialized at dial time, and are read-only after that.
target string // User's dial target.
parsedTarget resolver.Target // See parseTargetAndFindResolver().
authority string // See determineAuthority().
dopts dialOptions // Default and user specified dial options.
channelz *channelz.Channel // Channelz object.
resolverBuilder resolver.Builder // See parseTargetAndFindResolver().
idlenessMgr *idle.Manager
target string // User's dial target.
parsedTarget resolver.Target // See initParsedTargetAndResolverBuilder().
authority string // See initAuthority().
dopts dialOptions // Default and user specified dial options.
channelz *channelz.Channel // Channelz object.
resolverBuilder resolver.Builder // See initParsedTargetAndResolverBuilder().
idlenessMgr *idle.Manager
metricsRecorderList *stats.MetricsRecorderList
// The following provide their own synchronization, and therefore don't
// require cc.mu to be held to access them.
@@ -604,7 +637,7 @@ type ClientConn struct {
balancerWrapper *ccBalancerWrapper // Always recreated whenever entering idle to simplify Close.
sc *ServiceConfig // Latest service config received from the resolver.
conns map[*addrConn]struct{} // Set to nil on close.
mkp keepalive.ClientParameters // May be updated upon receipt of a GoAway.
keepaliveParams keepalive.ClientParameters // May be updated upon receipt of a GoAway.
// firstResolveEvent is used to track whether the name resolver sent us at
// least one update. RPCs block on this event. May be accessed without mu
// if we know we cannot be asked to enter idle mode while accessing it (e.g.
@@ -618,11 +651,6 @@ type ClientConn struct {
// WaitForStateChange waits until the connectivity.State of ClientConn changes from sourceState or
// ctx expires. A true value is returned in the former case and false in the latter.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func (cc *ClientConn) WaitForStateChange(ctx context.Context, sourceState connectivity.State) bool {
ch := cc.csMgr.getNotifyChan()
if cc.csMgr.getState() != sourceState {
@@ -637,11 +665,6 @@ func (cc *ClientConn) WaitForStateChange(ctx context.Context, sourceState connec
}
// GetState returns the connectivity.State of ClientConn.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a later
// release.
func (cc *ClientConn) GetState() connectivity.State {
return cc.csMgr.getState()
}
@@ -688,8 +711,7 @@ func (cc *ClientConn) waitForResolvedAddrs(ctx context.Context) error {
var emptyServiceConfig *ServiceConfig
func init() {
balancer.Register(pickfirstBuilder{})
cfg := parseServiceConfig("{}")
cfg := parseServiceConfig("{}", defaultMaxCallAttempts)
if cfg.Err != nil {
panic(fmt.Sprintf("impossible error parsing empty service config: %v", cfg.Err))
}
@@ -772,10 +794,7 @@ func (cc *ClientConn) updateResolverStateAndUnlock(s resolver.State, err error)
}
}
var balCfg serviceconfig.LoadBalancingConfig
if cc.sc != nil && cc.sc.lbConfig != nil {
balCfg = cc.sc.lbConfig
}
balCfg := cc.sc.lbConfig
bw := cc.balancerWrapper
cc.mu.Unlock()
@@ -805,17 +824,11 @@ func (cc *ClientConn) applyFailingLBLocked(sc *serviceconfig.ParseResult) {
cc.csMgr.updateState(connectivity.TransientFailure)
}
// Makes a copy of the input addresses slice and clears out the balancer
// attributes field. Addresses are passed during subconn creation and address
// update operations. In both cases, we will clear the balancer attributes by
// calling this function, and therefore we will be able to use the Equal method
// provided by the resolver.Address type for comparison.
func copyAddressesWithoutBalancerAttributes(in []resolver.Address) []resolver.Address {
// Makes a copy of the input addresses slice. Addresses are passed during
// subconn creation and address update operations.
func copyAddresses(in []resolver.Address) []resolver.Address {
out := make([]resolver.Address, len(in))
for i := range in {
out[i] = in[i]
out[i].BalancerAttributes = nil
}
copy(out, in)
return out
}
@@ -830,14 +843,16 @@ func (cc *ClientConn) newAddrConnLocked(addrs []resolver.Address, opts balancer.
ac := &addrConn{
state: connectivity.Idle,
cc: cc,
addrs: copyAddressesWithoutBalancerAttributes(addrs),
addrs: copyAddresses(addrs),
scopts: opts,
dopts: cc.dopts,
channelz: channelz.RegisterSubChannel(cc.channelz, ""),
resetBackoff: make(chan struct{}),
stateChan: make(chan struct{}),
}
ac.ctx, ac.cancel = context.WithCancel(cc.ctx)
// Start with our address set to the first address; this may be updated if
// we connect to different addresses.
ac.channelz.ChannelMetrics.Target.Store(&addrs[0].Addr)
channelz.AddTraceEvent(logger, ac.channelz, 0, &channelz.TraceEvent{
Desc: "Subchannel created",
@@ -871,7 +886,13 @@ func (cc *ClientConn) Target() string {
return cc.target
}
// CanonicalTarget returns the canonical target string of the ClientConn.
// CanonicalTarget returns the canonical target string used when creating cc.
//
// This always has the form "<scheme>://[authority]/<endpoint>". For example:
//
// - "dns:///example.com:42"
// - "dns://8.8.8.8/example.com:42"
// - "unix:///path/to/socket"
func (cc *ClientConn) CanonicalTarget() string {
return cc.parsedTarget.String()
}
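For illustration, a sketch with a placeholder address showing how the raw and canonical targets differ once the default scheme is applied:

cc, err := grpc.NewClient("localhost:50051",
	grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
	log.Fatal(err)
}
fmt.Println(cc.Target())          // "localhost:50051"
fmt.Println(cc.CanonicalTarget()) // "dns:///localhost:50051"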
@@ -908,32 +929,37 @@ func (ac *addrConn) connect() error {
ac.mu.Unlock()
return nil
}
ac.mu.Unlock()
ac.resetTransport()
ac.resetTransportAndUnlock()
return nil
}
func equalAddresses(a, b []resolver.Address) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if !v.Equal(b[i]) {
return false
}
}
return true
// equalAddressIgnoringBalAttributes returns true if a and b are considered equal.
// This is different from the Equal method on the resolver.Address type which
// considers all fields to determine equality. Here, we only consider fields
// that are meaningful to the subConn.
func equalAddressIgnoringBalAttributes(a, b *resolver.Address) bool {
return a.Addr == b.Addr && a.ServerName == b.ServerName &&
a.Attributes.Equal(b.Attributes) &&
a.Metadata == b.Metadata
}
func equalAddressesIgnoringBalAttributes(a, b []resolver.Address) bool {
return slices.EqualFunc(a, b, func(a, b resolver.Address) bool { return equalAddressIgnoringBalAttributes(&a, &b) })
}
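An in-package sketch of the distinction (the attribute key and value are invented):

a := resolver.Address{Addr: "10.0.0.1:443", ServerName: "backend"}
b := a
b.BalancerAttributes = attributes.New("lb.weight", 10) // hypothetical attribute

_ = a.Equal(b)                                // false: compares all fields
_ = equalAddressIgnoringBalAttributes(&a, &b) // true: BalancerAttributes ignored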
// updateAddrs updates ac.addrs with the new addresses list and handles active
// connections or connection attempts.
func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
ac.mu.Lock()
channelz.Infof(logger, ac.channelz, "addrConn: updateAddrs curAddr: %v, addrs: %v", pretty.ToJSON(ac.curAddr), pretty.ToJSON(addrs))
addrs = copyAddresses(addrs)
limit := len(addrs)
if limit > 5 {
limit = 5
}
channelz.Infof(logger, ac.channelz, "addrConn: updateAddrs addrs (%d of %d): %v", limit, len(addrs), addrs[:limit])
addrs = copyAddressesWithoutBalancerAttributes(addrs)
if equalAddresses(ac.addrs, addrs) {
ac.mu.Lock()
if equalAddressesIgnoringBalAttributes(ac.addrs, addrs) {
ac.mu.Unlock()
return
}
@@ -952,7 +978,7 @@ func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
// Try to find the connected address.
for _, a := range addrs {
a.ServerName = ac.cc.getServerName(a)
if a.Equal(ac.curAddr) {
if equalAddressIgnoringBalAttributes(&a, &ac.curAddr) {
// We are connected to a valid address, so do nothing but
// update the addresses.
ac.mu.Unlock()
@@ -978,11 +1004,9 @@ func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
ac.updateConnectivityState(connectivity.Idle, nil)
}
ac.mu.Unlock()
// Since we were connecting/connected, we should start a new connection
// attempt.
go ac.resetTransport()
go ac.resetTransportAndUnlock()
}
// getServerName determines the serverName to be used in the connection
@@ -1138,10 +1162,15 @@ func (cc *ClientConn) Close() error {
<-cc.resolverWrapper.serializer.Done()
<-cc.balancerWrapper.serializer.Done()
var wg sync.WaitGroup
for ac := range conns {
ac.tearDown(ErrClientConnClosing)
wg.Add(1)
go func(ac *addrConn) {
defer wg.Done()
ac.tearDown(ErrClientConnClosing)
}(ac)
}
wg.Wait()
cc.addTraceEvent("deleted")
// TraceEvent needs to be called before RemoveEntry, as TraceEvent may add
// trace reference to the entity being deleted, and thus prevent it from being
@@ -1167,13 +1196,16 @@ type addrConn struct {
// is received, transport is closed, ac has been torn down).
transport transport.ClientTransport // The current transport.
// This mutex is used on the RPC path, so its usage should be minimized as
// much as possible.
// TODO: Find a lock-free way to retrieve the transport and state from the
// addrConn.
mu sync.Mutex
curAddr resolver.Address // The current address.
addrs []resolver.Address // All addresses that the resolver resolved to.
// Use updateConnectivityState for updating addrConn's connectivity state.
state connectivity.State
stateChan chan struct{} // closed and recreated on every state change.
state connectivity.State
backoffIdx int // Needs to be stateful for resetConnectBackoff.
resetBackoff chan struct{}
@@ -1186,9 +1218,6 @@ func (ac *addrConn) updateConnectivityState(s connectivity.State, lastErr error)
if ac.state == s {
return
}
// When changing states, reset the state change channel.
close(ac.stateChan)
ac.stateChan = make(chan struct{})
ac.state = s
ac.channelz.ChannelMetrics.State.Store(&s)
if lastErr == nil {
@@ -1196,25 +1225,26 @@ func (ac *addrConn) updateConnectivityState(s connectivity.State, lastErr error)
} else {
channelz.Infof(logger, ac.channelz, "Subchannel Connectivity change to %v, last error: %s", s, lastErr)
}
ac.acbw.updateState(s, lastErr)
ac.acbw.updateState(s, ac.curAddr, lastErr)
}
// adjustParams updates parameters used to create transports upon
// receiving a GoAway.
func (ac *addrConn) adjustParams(r transport.GoAwayReason) {
switch r {
case transport.GoAwayTooManyPings:
if r == transport.GoAwayTooManyPings {
v := 2 * ac.dopts.copts.KeepaliveParams.Time
ac.cc.mu.Lock()
if v > ac.cc.mkp.Time {
ac.cc.mkp.Time = v
if v > ac.cc.keepaliveParams.Time {
ac.cc.keepaliveParams.Time = v
}
ac.cc.mu.Unlock()
}
}
func (ac *addrConn) resetTransport() {
ac.mu.Lock()
// resetTransportAndUnlock unconditionally connects the addrConn.
//
// ac.mu must be held by the caller, and this function will guarantee it is released.
func (ac *addrConn) resetTransportAndUnlock() {
acCtx := ac.ctx
if acCtx.Err() != nil {
ac.mu.Unlock()
@@ -1245,6 +1275,8 @@ func (ac *addrConn) resetTransport() {
ac.mu.Unlock()
if err := ac.tryAllAddrs(acCtx, addrs, connectDeadline); err != nil {
// TODO: #7534 - Move re-resolution requests into the pick_first LB policy
// to ensure one resolution request per pass instead of per subconn failure.
ac.cc.resolveNow(resolver.ResolveNowOptions{})
ac.mu.Lock()
if acCtx.Err() != nil {
@@ -1286,19 +1318,20 @@ func (ac *addrConn) resetTransport() {
ac.mu.Unlock()
}
// tryAllAddrs tries to creates a connection to the addresses, and stop when at
// tryAllAddrs tries to create a connection to the addresses, stopping at the
// first successful one. It returns an error if no address was successfully
// connected, or updates ac appropriately with the new transport.
func (ac *addrConn) tryAllAddrs(ctx context.Context, addrs []resolver.Address, connectDeadline time.Time) error {
var firstConnErr error
for _, addr := range addrs {
ac.channelz.ChannelMetrics.Target.Store(&addr.Addr)
if ctx.Err() != nil {
return errConnClosing
}
ac.mu.Lock()
ac.cc.mu.RLock()
ac.dopts.copts.KeepaliveParams = ac.cc.mkp
ac.dopts.copts.KeepaliveParams = ac.cc.keepaliveParams
ac.cc.mu.RUnlock()
copts := ac.dopts.copts
@@ -1362,7 +1395,7 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
defer cancel()
copts.ChannelzParent = ac.channelz
newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, addr, copts, onClose)
newTr, err := transport.NewHTTP2Client(connectCtx, ac.cc.ctx, addr, copts, onClose)
if err != nil {
if logger.V(2) {
logger.Infof("Creating new client transport to %q: %v", addr, err)
@@ -1436,7 +1469,7 @@ func (ac *addrConn) startHealthCheck(ctx context.Context) {
if !ac.scopts.HealthCheckEnabled {
return
}
healthCheckFunc := ac.cc.dopts.healthCheckFunc
healthCheckFunc := internal.HealthCheckFunc
if healthCheckFunc == nil {
// The health package is not imported to set health check function.
//
@@ -1468,7 +1501,7 @@ func (ac *addrConn) startHealthCheck(ctx context.Context) {
}
// Start the health checking stream.
go func() {
err := ac.cc.dopts.healthCheckFunc(ctx, newStream, setConnectivityState, healthCheckConfig.ServiceName)
err := healthCheckFunc(ctx, newStream, setConnectivityState, healthCheckConfig.ServiceName)
if err != nil {
if status.Code(err) == codes.Unimplemented {
channelz.Error(logger, ac.channelz, "Subchannel health check is unimplemented at server side, thus health check is disabled")
@@ -1497,29 +1530,6 @@ func (ac *addrConn) getReadyTransport() transport.ClientTransport {
return nil
}
// getTransport waits until the addrconn is ready and returns the transport.
// If the context expires first, returns an appropriate status. If the
// addrConn is stopped first, returns an Unavailable status error.
func (ac *addrConn) getTransport(ctx context.Context) (transport.ClientTransport, error) {
for ctx.Err() == nil {
ac.mu.Lock()
t, state, sc := ac.transport, ac.state, ac.stateChan
ac.mu.Unlock()
if state == connectivity.Ready {
return t, nil
}
if state == connectivity.Shutdown {
return nil, status.Errorf(codes.Unavailable, "SubConn shutting down")
}
select {
case <-ctx.Done():
case <-sc:
}
}
return nil, status.FromContextError(ctx.Err()).Err()
}
// tearDown starts to tear down the addrConn.
//
// Note that tearDown doesn't remove ac from ac.cc.conns, so the addrConn struct
@@ -1566,7 +1576,7 @@ func (ac *addrConn) tearDown(err error) {
} else {
// Hard close the transport when the channel is entering idle or is
// being shutdown. In the case where the channel is being shutdown,
// closing of transports is also taken care of by cancelation of cc.ctx.
// closing of transports is also taken care of by cancellation of cc.ctx.
// But in the case where the channel is entering idle, we need to
// explicitly close the transports here. Instead of distinguishing
// between these two cases, it is simpler to close the transport
@@ -1657,22 +1667,19 @@ func (cc *ClientConn) connectionError() error {
return cc.lastConnectionError
}
// parseTargetAndFindResolver parses the user's dial target and stores the
// parsed target in `cc.parsedTarget`.
// initParsedTargetAndResolverBuilder parses the user's dial target and stores
// the parsed target in `cc.parsedTarget`.
//
// The resolver to use is determined based on the scheme in the parsed target
// and the same is stored in `cc.resolverBuilder`.
//
// Doesn't grab cc.mu as this method is expected to be called only at Dial time.
func (cc *ClientConn) parseTargetAndFindResolver() error {
channelz.Infof(logger, cc.channelz, "original dial target is: %q", cc.target)
func (cc *ClientConn) initParsedTargetAndResolverBuilder() error {
logger.Infof("original dial target is: %q", cc.target)
var rb resolver.Builder
parsedTarget, err := parseTarget(cc.target)
if err != nil {
channelz.Infof(logger, cc.channelz, "dial target %q parse failed: %v", cc.target, err)
} else {
channelz.Infof(logger, cc.channelz, "parsed dial target is: %#v", parsedTarget)
if err == nil {
rb = cc.getResolver(parsedTarget.URL.Scheme)
if rb != nil {
cc.parsedTarget = parsedTarget
@@ -1691,15 +1698,12 @@ func (cc *ClientConn) parseTargetAndFindResolver() error {
defScheme = resolver.GetDefaultScheme()
}
channelz.Infof(logger, cc.channelz, "fallback to scheme %q", defScheme)
canonicalTarget := defScheme + ":///" + cc.target
parsedTarget, err = parseTarget(canonicalTarget)
if err != nil {
channelz.Infof(logger, cc.channelz, "dial target %q parse failed: %v", canonicalTarget, err)
return err
}
channelz.Infof(logger, cc.channelz, "parsed dial target is: %+v", parsedTarget)
rb = cc.getResolver(parsedTarget.URL.Scheme)
if rb == nil {
return fmt.Errorf("could not get resolver for default scheme: %q", parsedTarget.URL.Scheme)
@@ -1739,7 +1743,7 @@ func encodeAuthority(authority string) string {
return false
case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=': // Subdelim characters
return false
case ':', '[', ']', '@': // Authority related delimeters
case ':', '[', ']', '@': // Authority related delimiters
return false
}
// Everything else must be escaped.
@@ -1789,7 +1793,7 @@ func encodeAuthority(authority string) string {
// credentials do not match the authority configured through the dial option.
//
// Doesn't grab cc.mu as this method is expected to be called only at Dial time.
func (cc *ClientConn) determineAuthority() error {
func (cc *ClientConn) initAuthority() error {
dopts := cc.dopts
// Historically, we had two options for users to specify the serverName or
// authority for a channel. One was through the transport credentials
@@ -1822,6 +1826,5 @@ func (cc *ClientConn) determineAuthority() error {
} else {
cc.authority = encodeAuthority(endpoint)
}
channelz.Infof(logger, cc.channelz, "Channel authority set to %q", cc.authority)
return nil
}

View File

@@ -21,18 +21,73 @@ package grpc
import (
"google.golang.org/grpc/encoding"
_ "google.golang.org/grpc/encoding/proto" // to register the Codec for "proto"
"google.golang.org/grpc/mem"
)
// baseCodec contains the functionality of both Codec and encoding.Codec, but
// omits the name/string, which vary between the two and are not needed for
// anything besides the registry in the encoding package.
// baseCodec captures the new encoding.CodecV2 interface without the Name
// function, allowing it to be implemented by older Codec and encoding.Codec
// implementations. The omitted Name function is only needed for the registry in
// the encoding package and is not part of the core functionality.
type baseCodec interface {
Marshal(v any) ([]byte, error)
Unmarshal(data []byte, v any) error
Marshal(v any) (mem.BufferSlice, error)
Unmarshal(data mem.BufferSlice, v any) error
}
var _ baseCodec = Codec(nil)
var _ baseCodec = encoding.Codec(nil)
// getCodec returns an encoding.CodecV2 for the codec of the given name (if
// registered). It first checks the V1 registry with encoding.GetCodec and, if
// a V1 codec is registered, wraps it with newCodecV1Bridge to turn it into an
// encoding.CodecV2. Otherwise it returns the result of encoding.GetCodecV2,
// which is nil if nothing is registered under the name.
func getCodec(name string) encoding.CodecV2 {
if codecV1 := encoding.GetCodec(name); codecV1 != nil {
return newCodecV1Bridge(codecV1)
}
return encoding.GetCodecV2(name)
}
func newCodecV0Bridge(c Codec) baseCodec {
return codecV0Bridge{codec: c}
}
func newCodecV1Bridge(c encoding.Codec) encoding.CodecV2 {
return codecV1Bridge{
codecV0Bridge: codecV0Bridge{codec: c},
name: c.Name(),
}
}
var _ baseCodec = codecV0Bridge{}
type codecV0Bridge struct {
codec interface {
Marshal(v any) ([]byte, error)
Unmarshal(data []byte, v any) error
}
}
func (c codecV0Bridge) Marshal(v any) (mem.BufferSlice, error) {
data, err := c.codec.Marshal(v)
if err != nil {
return nil, err
}
return mem.BufferSlice{mem.SliceBuffer(data)}, nil
}
func (c codecV0Bridge) Unmarshal(data mem.BufferSlice, v any) (err error) {
return c.codec.Unmarshal(data.Materialize(), v)
}
var _ encoding.CodecV2 = codecV1Bridge{}
type codecV1Bridge struct {
codecV0Bridge
name string
}
func (c codecV1Bridge) Name() string {
return c.name
}
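A sketch of how a legacy V1 codec flows through the bridge, assuming encoding/json and a made-up codec:

type jsonV1Codec struct{} // hypothetical legacy codec

func (jsonV1Codec) Marshal(v any) ([]byte, error)      { return json.Marshal(v) }
func (jsonV1Codec) Unmarshal(data []byte, v any) error { return json.Unmarshal(data, v) }
func (jsonV1Codec) Name() string                       { return "json" }

func init() {
	encoding.RegisterCodec(jsonV1Codec{})
	// getCodec("json") now returns newCodecV1Bridge(jsonV1Codec{}): its
	// Marshal wraps the produced []byte in a mem.BufferSlice, and its
	// Unmarshal materializes the BufferSlice back into contiguous bytes.
}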
// Codec defines the interface gRPC uses to encode and decode messages.
// Note that implementations of this interface must be thread safe;

View File

@@ -1,17 +0,0 @@
#!/usr/bin/env bash
# This script serves as an example to demonstrate how to generate the gRPC-Go
# interface and the related messages from .proto file.
#
# It assumes the installation of i) Google proto buffer compiler at
# https://github.com/google/protobuf (after v2.6.1) and ii) the Go codegen
# plugin at https://github.com/golang/protobuf (after 2015-02-20). If you have
# not, please install them first.
#
# We recommend running this script at $GOPATH/src.
#
# If this is not what you need, feel free to make your own scripts. Again, this
# script is for demonstration purpose.
#
proto=$1
protoc --go_out=plugins=grpc:. $proto

View File

@@ -235,7 +235,7 @@ func (c *Code) UnmarshalJSON(b []byte) error {
if ci, err := strconv.ParseUint(string(b), 10, 32); err == nil {
if ci >= _maxCode {
return fmt.Errorf("invalid code: %q", ci)
return fmt.Errorf("invalid code: %d", ci)
}
*c = Code(ci)

View File

@@ -30,7 +30,7 @@ import (
"google.golang.org/grpc/attributes"
icredentials "google.golang.org/grpc/internal/credentials"
"google.golang.org/protobuf/protoadapt"
"google.golang.org/protobuf/proto"
)
// PerRPCCredentials defines the common interface for the credentials which need to
@@ -237,7 +237,7 @@ func ClientHandshakeInfoFromContext(ctx context.Context) ClientHandshakeInfo {
}
// CheckSecurityLevel checks if a connection's security level is greater than or equal to the specified one.
// It returns success if 1) the condition is satisified or 2) AuthInfo struct does not implement GetCommonAuthInfo() method
// It returns success if 1) the condition is satisfied or 2) AuthInfo struct does not implement GetCommonAuthInfo() method
// or 3) CommonAuthInfo.SecurityLevel has an invalid zero value. For 2) and 3), it is for the purpose of backward-compatibility.
//
// This API is experimental.
@@ -287,5 +287,5 @@ type ChannelzSecurityValue interface {
type OtherChannelzSecurityValue struct {
ChannelzSecurityValue
Name string
Value protoadapt.MessageV1
Value proto.Message
}

View File

@@ -40,7 +40,7 @@ func NewCredentials() credentials.TransportCredentials {
// NoSecurity.
type insecureTC struct{}
func (insecureTC) ClientHandshake(ctx context.Context, _ string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
func (insecureTC) ClientHandshake(_ context.Context, _ string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return conn, info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
}

View File

@@ -27,9 +27,15 @@ import (
"net/url"
"os"
"google.golang.org/grpc/grpclog"
credinternal "google.golang.org/grpc/internal/credentials"
"google.golang.org/grpc/internal/envconfig"
)
const alpnFailureHelpMessage = "If you upgraded from a grpc-go version earlier than 1.67, your TLS connections may have stopped working due to ALPN enforcement. For more details, see: https://github.com/grpc/grpc-go/issues/434"
var logger = grpclog.Component("credentials")
// TLSInfo contains the auth information for a TLS authenticated connection.
// It implements the AuthInfo interface.
type TLSInfo struct {
@@ -112,6 +118,22 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawCon
conn.Close()
return nil, nil, ctx.Err()
}
// The negotiated protocol can be either of the following:
// 1. h2: When the server supports ALPN. Only HTTP/2 can be negotiated since
// it is the only protocol advertised by the client during the handshake.
// The tls library ensures that the server chooses a protocol advertised
// by the client.
// 2. "" (empty string): If the server doesn't support ALPN. ALPN is a requirement
// for using HTTP/2 over TLS. We can terminate the connection immediately.
np := conn.ConnectionState().NegotiatedProtocol
if np == "" {
if envconfig.EnforceALPNEnabled {
conn.Close()
return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property. %s", alpnFailureHelpMessage)
}
logger.Warningf("Allowing TLS connection to server %q with ALPN disabled. TLS connections to servers with ALPN disabled will be disallowed in future grpc-go releases", cfg.ServerName)
}
tlsInfo := TLSInfo{
State: conn.ConnectionState(),
CommonAuthInfo: CommonAuthInfo{
@@ -131,8 +153,20 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
conn.Close()
return nil, nil, err
}
cs := conn.ConnectionState()
// The negotiated application protocol can be empty only if the client doesn't
// support ALPN. In such cases, we can close the connection since ALPN is required
// for using HTTP/2 over TLS.
if cs.NegotiatedProtocol == "" {
if envconfig.EnforceALPNEnabled {
conn.Close()
return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property. %s", alpnFailureHelpMessage)
} else if logger.V(2) {
logger.Info("Allowing TLS connection from client with ALPN disabled. TLS connections with ALPN disabled will be disallowed in future grpc-go releases")
}
}
tlsInfo := TLSInfo{
State: conn.ConnectionState(),
State: cs,
CommonAuthInfo: CommonAuthInfo{
SecurityLevel: PrivacyAndIntegrity,
},
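As a sketch, a hand-rolled tls.Config on the peer side should advertise "h2" explicitly so these handshakes pass; the enforcement toggle appears to be the GRPC_ENFORCE_ALPN_ENABLED environment variable backing envconfig.EnforceALPNEnabled (an assumption worth verifying before relying on it):

cfg := &tls.Config{
	Certificates: []tls.Certificate{cert}, // cert loaded elsewhere
	NextProtos:   []string{"h2"},          // advertise ALPN for HTTP/2
	MinVersion:   tls.VersionTLS12,
}
// Temporary escape hatch while migrating peers without ALPN support
// (assumed env var, subject to removal):
//	export GRPC_ENFORCE_ALPN_ENABLED=false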
@@ -168,25 +202,40 @@ var tls12ForbiddenCipherSuites = map[uint16]struct{}{
// NewTLS uses c to construct a TransportCredentials based on TLS.
func NewTLS(c *tls.Config) TransportCredentials {
tc := &tlsCreds{credinternal.CloneTLSConfig(c)}
tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos)
config := applyDefaults(c)
if config.GetConfigForClient != nil {
oldFn := config.GetConfigForClient
config.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
cfgForClient, err := oldFn(hello)
if err != nil || cfgForClient == nil {
return cfgForClient, err
}
return applyDefaults(cfgForClient), nil
}
}
return &tlsCreds{config: config}
}
func applyDefaults(c *tls.Config) *tls.Config {
config := credinternal.CloneTLSConfig(c)
config.NextProtos = credinternal.AppendH2ToNextProtos(config.NextProtos)
// If the user did not configure a MinVersion and did not configure a
// MaxVersion < 1.2, use MinVersion=1.2, which is required by
// https://datatracker.ietf.org/doc/html/rfc7540#section-9.2
if tc.config.MinVersion == 0 && (tc.config.MaxVersion == 0 || tc.config.MaxVersion >= tls.VersionTLS12) {
tc.config.MinVersion = tls.VersionTLS12
if config.MinVersion == 0 && (config.MaxVersion == 0 || config.MaxVersion >= tls.VersionTLS12) {
config.MinVersion = tls.VersionTLS12
}
// If the user did not configure CipherSuites, use all "secure" cipher
// suites reported by the TLS package, but remove some explicitly forbidden
// by https://datatracker.ietf.org/doc/html/rfc7540#appendix-A
if tc.config.CipherSuites == nil {
if config.CipherSuites == nil {
for _, cs := range tls.CipherSuites() {
if _, ok := tls12ForbiddenCipherSuites[cs.ID]; !ok {
tc.config.CipherSuites = append(tc.config.CipherSuites, cs.ID)
config.CipherSuites = append(config.CipherSuites, cs.ID)
}
}
}
return tc
return config
}
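A short sketch of the effect: starting from an empty config, the returned credentials already carry the HTTP/2-safe defaults.

creds := credentials.NewTLS(&tls.Config{}) // empty config: defaults apply
// The cloned config now advertises "h2" in NextProtos, sets MinVersion to
// TLS 1.2, and drops the RFC 7540 appendix-A forbidden cipher suites.
conn, err := grpc.NewClient("dns:///svc.example.com:443", // placeholder
	grpc.WithTransportCredentials(creds))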
// NewClientTLSFromCert constructs TLS credentials from the provided root

View File

@@ -21,6 +21,7 @@ package grpc
import (
"context"
"net"
"net/url"
"time"
"google.golang.org/grpc/backoff"
@@ -32,10 +33,16 @@ import (
"google.golang.org/grpc/internal/binarylog"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/stats"
)
const (
// https://github.com/grpc/proposal/blob/master/A6-client-retries.md#limits-on-retries-and-hedges
defaultMaxCallAttempts = 5
)
func init() {
internal.AddGlobalDialOptions = func(opt ...DialOption) {
globalDialOptions = append(globalDialOptions, opt...)
@@ -43,10 +50,18 @@ func init() {
internal.ClearGlobalDialOptions = func() {
globalDialOptions = nil
}
internal.AddGlobalPerTargetDialOptions = func(opt any) {
if ptdo, ok := opt.(perTargetDialOption); ok {
globalPerTargetDialOptions = append(globalPerTargetDialOptions, ptdo)
}
}
internal.ClearGlobalPerTargetDialOptions = func() {
globalPerTargetDialOptions = nil
}
internal.WithBinaryLogger = withBinaryLogger
internal.JoinDialOptions = newJoinDialOption
internal.DisableGlobalDialOptions = newDisableGlobalDialOptions
internal.WithRecvBufferPool = withRecvBufferPool
internal.WithBufferPool = withBufferPool
}
// dialOptions configure a Dial call. dialOptions are set by the DialOption
@@ -58,7 +73,7 @@ type dialOptions struct {
chainUnaryInts []UnaryClientInterceptor
chainStreamInts []StreamClientInterceptor
cp Compressor
compressorV0 Compressor
dc Decompressor
bs internalbackoff.Strategy
block bool
@@ -72,14 +87,15 @@ type dialOptions struct {
disableServiceConfig bool
disableRetry bool
disableHealthCheck bool
healthCheckFunc internal.HealthChecker
minConnectTimeout func() time.Duration
defaultServiceConfig *ServiceConfig // defaultServiceConfig is parsed from defaultServiceConfigRawJSON.
defaultServiceConfigRawJSON *string
resolvers []resolver.Builder
idleTimeout time.Duration
recvBufferPool SharedBufferPool
defaultScheme string
maxCallAttempts int
enableLocalDNSResolution bool // Specifies if target hostnames should be resolved when proxying is enabled.
useProxy bool // Specifies if a server should be connected via proxy.
}
// DialOption configures how we set up the connection.
@@ -89,6 +105,19 @@ type DialOption interface {
var globalDialOptions []DialOption
// perTargetDialOption takes a parsed target and returns a dial option to apply.
//
// This gets called after NewClient() parses the target, and allows per-target
// configuration to be set through a returned DialOption. The DialOption will
// not take effect if it specifies a resolver builder, as that DialOption is
// factored in while parsing the target.
type perTargetDialOption interface {
// DialOption returns a Dial Option to apply.
DialOptionForTarget(parsedTarget url.URL) DialOption
}
var globalPerTargetDialOptions []perTargetDialOption
// EmptyDialOption does not alter the dial configuration. It can be embedded in
// another structure to build custom dial options.
//
@@ -229,7 +258,7 @@ func WithCodec(c Codec) DialOption {
// Deprecated: use UseCompressor instead. Will be supported throughout 1.x.
func WithCompressor(cp Compressor) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.cp = cp
o.compressorV0 = cp
})
}
@@ -300,6 +329,9 @@ func withBackoff(bs internalbackoff.Strategy) DialOption {
//
// Use of this feature is not recommended. For more information, please see:
// https://github.com/grpc/grpc-go/blob/master/Documentation/anti-patterns.md
//
// Deprecated: this DialOption is not supported by NewClient.
// Will be supported throughout 1.x.
func WithBlock() DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.block = true
@@ -314,10 +346,8 @@ func WithBlock() DialOption {
// Use of this feature is not recommended. For more information, please see:
// https://github.com/grpc/grpc-go/blob/master/Documentation/anti-patterns.md
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
// Deprecated: this DialOption is not supported by NewClient.
// Will be supported throughout 1.x.
func WithReturnConnectionError() DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.block = true
@@ -349,7 +379,22 @@ func WithInsecure() DialOption {
// later release.
func WithNoProxy() DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.copts.UseProxy = false
o.useProxy = false
})
}
// WithLocalDNSResolution forces local DNS name resolution even when a proxy is
// specified in the environment. By default, the server name is provided
// directly to the proxy as part of the CONNECT handshake. This is ignored if
// WithNoProxy is used.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func WithLocalDNSResolution() DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.enableLocalDNSResolution = true
})
}
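For instance (placeholder target; proxy configuration comes from the environment):

conn, err := grpc.NewClient("dns:///svc.internal:443", // placeholder
	grpc.WithTransportCredentials(creds), // creds defined elsewhere
	grpc.WithLocalDNSResolution(),        // resolve locally, CONNECT to the IP
)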
@@ -387,8 +432,8 @@ func WithCredentialsBundle(b credentials.Bundle) DialOption {
// WithTimeout returns a DialOption that configures a timeout for dialing a
// ClientConn initially. This is valid if and only if WithBlock() is present.
//
// Deprecated: use DialContext instead of Dial and context.WithTimeout
// instead. Will be supported throughout 1.x.
// Deprecated: this DialOption is not supported by NewClient.
// Will be supported throughout 1.x.
func WithTimeout(d time.Duration) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.timeout = d
@@ -400,6 +445,11 @@ func WithTimeout(d time.Duration) DialOption {
// returned by f, gRPC checks the error's Temporary() method to decide if it
// should try to reconnect to the network address.
//
// Note that gRPC by default performs name resolution on the target passed to
// NewClient. To bypass name resolution and cause the target string to be
// passed directly to the dialer here instead, use the "passthrough" resolver
// by specifying it in the target string, e.g. "passthrough:target".
//
// Note: All supported releases of Go (as of December 2023) override the OS
// defaults for TCP keepalive time and interval to 15s. To enable TCP keepalive
// with OS defaults for keepalive time and interval, use a net.Dialer that sets
@@ -407,7 +457,7 @@ func WithTimeout(d time.Duration) DialOption {
// option to true from the Control field. For a concrete example of how to do
// this, see internal.NetDialerWithTCPKeepalive().
//
// For more information, please see [issue 23459] in the Go github repo.
// For more information, please see [issue 23459] in the Go GitHub repo.
//
// [issue 23459]: https://github.com/golang/go/issues/23459
func WithContextDialer(f func(context.Context, string) (net.Conn, error)) DialOption {
@@ -416,10 +466,6 @@ func WithContextDialer(f func(context.Context, string) (net.Conn, error)) DialOp
})
}
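A sketch of the passthrough-plus-dialer pattern described above:

dialer := func(ctx context.Context, addr string) (net.Conn, error) {
	// addr is the raw target string, courtesy of the passthrough resolver.
	var d net.Dialer
	return d.DialContext(ctx, "tcp", addr)
}
conn, err := grpc.NewClient("passthrough:///localhost:50051",
	grpc.WithTransportCredentials(insecure.NewCredentials()),
	grpc.WithContextDialer(dialer),
)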
func init() {
internal.WithHealthCheckFunc = withHealthCheckFunc
}
// WithDialer returns a DialOption that specifies a function to use for dialing
// network addresses. If FailOnNonTempDialError() is set to true, and an error
// is returned by f, gRPC checks the error's Temporary() method to decide if it
@@ -470,9 +516,8 @@ func withBinaryLogger(bl binarylog.Logger) DialOption {
// Use of this feature is not recommended. For more information, please see:
// https://github.com/grpc/grpc-go/blob/master/Documentation/anti-patterns.md
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// Deprecated: this DialOption is not supported by NewClient. This API may be
// changed or removed in a later release.
func FailOnNonTempDialError(f bool) DialOption {
return newFuncDialOption(func(o *dialOptions) {
@@ -490,6 +535,8 @@ func WithUserAgent(s string) DialOption {
// WithKeepaliveParams returns a DialOption that specifies keepalive parameters
// for the client transport.
//
// Keepalive is disabled by default.
func WithKeepaliveParams(kp keepalive.ClientParameters) DialOption {
if kp.Time < internal.KeepaliveMinPingTime {
logger.Warningf("Adjusting keepalive ping interval to minimum period of %v", internal.KeepaliveMinPingTime)
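A usage sketch with assumed values (imports google.golang.org/grpc/keepalive):

conn, err := grpc.NewClient(target, // placeholder
	grpc.WithTransportCredentials(creds), // creds defined elsewhere
	grpc.WithKeepaliveParams(keepalive.ClientParameters{
		Time:                30 * time.Second, // ping after 30s of inactivity
		Timeout:             10 * time.Second, // wait 10s for the ping ack
		PermitWithoutStream: true,             // ping even without active RPCs
	}),
)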
@@ -601,12 +648,22 @@ func WithDisableRetry() DialOption {
})
}
// MaxHeaderListSizeDialOption is a DialOption that specifies the maximum
// (uncompressed) size of header list that the client is prepared to accept.
type MaxHeaderListSizeDialOption struct {
MaxHeaderListSize uint32
}
func (o MaxHeaderListSizeDialOption) apply(do *dialOptions) {
do.copts.MaxHeaderListSize = &o.MaxHeaderListSize
}
// WithMaxHeaderListSize returns a DialOption that specifies the maximum
// (uncompressed) size of header list that the client is prepared to accept.
func WithMaxHeaderListSize(s uint32) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.copts.MaxHeaderListSize = &s
})
return MaxHeaderListSizeDialOption{
MaxHeaderListSize: s,
}
}
// WithDisableHealthCheck disables the LB channel health checking for all
@@ -622,33 +679,24 @@ func WithDisableHealthCheck() DialOption {
})
}
// withHealthCheckFunc replaces the default health check function with the
// provided one. It makes tests easier to change the health check function.
//
// For testing purpose only.
func withHealthCheckFunc(f internal.HealthChecker) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.healthCheckFunc = f
})
}
func defaultDialOptions() dialOptions {
return dialOptions{
copts: transport.ConnectOptions{
ReadBufferSize: defaultReadBufSize,
WriteBufferSize: defaultWriteBufSize,
UseProxy: true,
UserAgent: grpcUA,
BufferPool: mem.DefaultBufferPool(),
},
bs: internalbackoff.DefaultExponential,
healthCheckFunc: internal.HealthCheckFunc,
idleTimeout: 30 * time.Minute,
recvBufferPool: nopBufferPool{},
defaultScheme: "dns",
bs: internalbackoff.DefaultExponential,
idleTimeout: 30 * time.Minute,
defaultScheme: "dns",
maxCallAttempts: defaultMaxCallAttempts,
useProxy: true,
enableLocalDNSResolution: false,
}
}
// withGetMinConnectDeadline specifies the function that clientconn uses to
// withMinConnectDeadline specifies the function that clientconn uses to
// get minConnectDeadline. This can be used to make connection attempts happen
// faster/slower.
//
@@ -702,25 +750,25 @@ func WithIdleTimeout(d time.Duration) DialOption {
})
}
// WithRecvBufferPool returns a DialOption that configures the ClientConn
// to use the provided shared buffer pool for parsing incoming messages. Depending
// on the application's workload, this could result in reduced memory allocation.
// WithMaxCallAttempts returns a DialOption that configures the maximum number
// of attempts per call (including retries and hedging) using the channel.
// Service owners may specify a higher value for these parameters, but higher
// values will be treated as equal to the maximum value by the client
// implementation. This mitigates security concerns related to the service
// config being transferred to the client via DNS.
//
// If you are unsure about how to implement a memory pool but want to utilize one,
// begin with grpc.NewSharedBufferPool.
//
// Note: The shared buffer pool feature will not be active if any of the following
// options are used: WithStatsHandler, EnableTracing, or binary logging. In such
// cases, the shared buffer pool will be ignored.
//
// Deprecated: use experimental.WithRecvBufferPool instead. Will be deleted in
// v1.60.0 or later.
func WithRecvBufferPool(bufferPool SharedBufferPool) DialOption {
return withRecvBufferPool(bufferPool)
}
func withRecvBufferPool(bufferPool SharedBufferPool) DialOption {
// A value of 5 will be used if this dial option is not set or n < 2.
func WithMaxCallAttempts(n int) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.recvBufferPool = bufferPool
if n < 2 {
n = defaultMaxCallAttempts
}
o.maxCallAttempts = n
})
}
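For example, to lower the cap (values below 2 silently fall back to the default of 5):

conn, err := grpc.NewClient(target, // placeholder
	grpc.WithTransportCredentials(creds), // creds defined elsewhere
	grpc.WithMaxCallAttempts(3),          // at most 3 attempts per call, retries included
)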
func withBufferPool(bufferPool mem.BufferPool) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.copts.BufferPool = bufferPool
})
}

View File

@@ -16,7 +16,7 @@
*
*/
//go:generate ./regenerate.sh
//go:generate ./scripts/regenerate.sh
/*
Package grpc implements an RPC system called gRPC.

View File

@@ -94,7 +94,7 @@ type Codec interface {
Name() string
}
var registeredCodecs = make(map[string]Codec)
var registeredCodecs = make(map[string]any)
// RegisterCodec registers the provided Codec for use with all gRPC clients and
// servers.
@@ -126,5 +126,6 @@ func RegisterCodec(codec Codec) {
//
// The content-subtype is expected to be lowercase.
func GetCodec(contentSubtype string) Codec {
return registeredCodecs[contentSubtype]
c, _ := registeredCodecs[contentSubtype].(Codec)
return c
}

vendor/google.golang.org/grpc/encoding/encoding_v2.go generated vendored Normal file
View File

@@ -0,0 +1,81 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package encoding
import (
"strings"
"google.golang.org/grpc/mem"
)
// CodecV2 defines the interface gRPC uses to encode and decode messages. Note
// that implementations of this interface must be thread safe; a CodecV2's
// methods can be called from concurrent goroutines.
type CodecV2 interface {
// Marshal returns the wire format of v. The buffers in the returned
// [mem.BufferSlice] must have at least one reference each, which will be freed
// by gRPC when they are no longer needed.
Marshal(v any) (out mem.BufferSlice, err error)
// Unmarshal parses the wire format into v. Note that data will be freed as soon
// as this function returns. If the codec wishes to guarantee access to the data
// after this function, it must take its own reference that it frees when it is
// no longer needed.
Unmarshal(data mem.BufferSlice, v any) error
// Name returns the name of the Codec implementation. The returned string
// will be used as part of content type in transmission. The result must be
// static; the result cannot change between calls.
Name() string
}
// RegisterCodecV2 registers the provided CodecV2 for use with all gRPC clients and
// servers.
//
// The CodecV2 will be stored and looked up by result of its Name() method, which
// should match the content-subtype of the encoding handled by the CodecV2. This
// is case-insensitive, and is stored and looked up as lowercase. If the
// result of calling Name() is an empty string, RegisterCodecV2 will panic. See
// Content-Type on
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details.
//
// If both a Codec and CodecV2 are registered with the same name, the CodecV2
// will be used.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple Codecs are
// registered with the same name, the one registered last will take effect.
func RegisterCodecV2(codec CodecV2) {
if codec == nil {
panic("cannot register a nil CodecV2")
}
if codec.Name() == "" {
panic("cannot register CodecV2 with empty string result for Name()")
}
contentSubtype := strings.ToLower(codec.Name())
registeredCodecs[contentSubtype] = codec
}
// GetCodecV2 gets a registered CodecV2 by content-subtype, or nil if no CodecV2 is
// registered for the content-subtype.
//
// The content-subtype is expected to be lowercase.
func GetCodecV2(contentSubtype string) CodecV2 {
c, _ := registeredCodecs[contentSubtype].(CodecV2)
return c
}
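A minimal CodecV2 sketch under stated assumptions: a toy codec for raw *[]byte payloads, with an invented name and package.

package rawcodec // hypothetical package

import (
	"fmt"

	"google.golang.org/grpc/encoding"
	"google.golang.org/grpc/mem"
)

type rawCodec struct{}

func (rawCodec) Marshal(v any) (mem.BufferSlice, error) {
	b, ok := v.(*[]byte)
	if !ok {
		return nil, fmt.Errorf("rawcodec: got %T, want *[]byte", v)
	}
	return mem.BufferSlice{mem.SliceBuffer(*b)}, nil
}

func (rawCodec) Unmarshal(data mem.BufferSlice, v any) error {
	b, ok := v.(*[]byte)
	if !ok {
		return fmt.Errorf("rawcodec: got %T, want *[]byte", v)
	}
	*b = data.Materialize() // copy out: data is freed once Unmarshal returns
	return nil
}

func (rawCodec) Name() string { return "raw-bytes" }

func init() { encoding.RegisterCodecV2(rawCodec{}) }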

View File

@@ -1,6 +1,6 @@
/*
*
* Copyright 2018 gRPC authors.
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@ import (
"fmt"
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/mem"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/protoadapt"
)
@@ -32,28 +33,51 @@ import (
const Name = "proto"
func init() {
encoding.RegisterCodec(codec{})
encoding.RegisterCodecV2(&codecV2{})
}
// codec is a Codec implementation with protobuf. It is the default codec for gRPC.
type codec struct{}
// codec is a CodecV2 implementation with protobuf. It is the default codec for
// gRPC.
type codecV2 struct{}
func (codec) Marshal(v any) ([]byte, error) {
func (c *codecV2) Marshal(v any) (data mem.BufferSlice, err error) {
vv := messageV2Of(v)
if vv == nil {
return nil, fmt.Errorf("failed to marshal, message is %T, want proto.Message", v)
return nil, fmt.Errorf("proto: failed to marshal, message is %T, want proto.Message", v)
}
return proto.Marshal(vv)
size := proto.Size(vv)
if mem.IsBelowBufferPoolingThreshold(size) {
buf, err := proto.Marshal(vv)
if err != nil {
return nil, err
}
data = append(data, mem.SliceBuffer(buf))
} else {
pool := mem.DefaultBufferPool()
buf := pool.Get(size)
if _, err := (proto.MarshalOptions{}).MarshalAppend((*buf)[:0], vv); err != nil {
pool.Put(buf)
return nil, err
}
data = append(data, mem.NewBuffer(buf, pool))
}
return data, nil
}
func (codec) Unmarshal(data []byte, v any) error {
func (c *codecV2) Unmarshal(data mem.BufferSlice, v any) (err error) {
vv := messageV2Of(v)
if vv == nil {
return fmt.Errorf("failed to unmarshal, message is %T, want proto.Message", v)
}
return proto.Unmarshal(data, vv)
buf := data.MaterializeToBuffer(mem.DefaultBufferPool())
defer buf.Free()
// TODO: Upgrade proto.Unmarshal to support mem.BufferSlice. Right now, it's not
// really possible without a major overhaul of the proto package, but the
// vtprotobuf library may be able to support this.
return proto.Unmarshal(buf.ReadOnlyData(), vv)
}
func messageV2Of(v any) proto.Message {
@@ -67,6 +91,6 @@ func messageV2Of(v any) proto.Message {
return nil
}
func (codec) Name() string {
func (c *codecV2) Name() string {
return Name
}

View File

@@ -0,0 +1,270 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package stats
import (
"maps"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/stats"
)
func init() {
internal.SnapshotMetricRegistryForTesting = snapshotMetricsRegistryForTesting
}
var logger = grpclog.Component("metrics-registry")
// DefaultMetrics are the default metrics registered through global metrics
// registry. This is written to at initialization time only, and is read only
// after initialization.
var DefaultMetrics = stats.NewMetricSet()
// MetricDescriptor is the data for a registered metric.
type MetricDescriptor struct {
// The name of this metric. This name must be unique across the whole binary
// (including any per call metrics). See
// https://github.com/grpc/proposal/blob/master/A79-non-per-call-metrics-architecture.md#metric-instrument-naming-conventions
// for metric naming conventions.
Name string
// The description of this metric.
Description string
// The unit (e.g. entries, seconds) of this metric.
Unit string
// The required label keys for this metric. These are intended to be
// attached to metrics emitted from a stats handler.
Labels []string
// The optional label keys for this metric. These are intended to be
// attached to metrics emitted from a stats handler if configured.
OptionalLabels []string
// Whether this metric is on by default.
Default bool
// The type of metric. This is set by the metric registry, and not intended
// to be set by a component registering a metric.
Type MetricType
// Bounds are the bounds of this metric. This only applies to histogram
// metrics. If unset or set with length 0, stats handlers will fall back to
// default bounds.
Bounds []float64
}
// MetricType is the type of metric.
type MetricType int
// Type of metric supported by this instrument registry.
const (
MetricTypeIntCount MetricType = iota
MetricTypeFloatCount
MetricTypeIntHisto
MetricTypeFloatHisto
MetricTypeIntGauge
)
// Int64CountHandle is a typed handle for an int count metric. This handle
// is passed at the recording point in order to know which metric to record
// on.
type Int64CountHandle MetricDescriptor
// Descriptor returns the int64 count handle typecast to a pointer to a
// MetricDescriptor.
func (h *Int64CountHandle) Descriptor() *MetricDescriptor {
return (*MetricDescriptor)(h)
}
// Record records the int64 count value on the metrics recorder provided.
func (h *Int64CountHandle) Record(recorder MetricsRecorder, incr int64, labels ...string) {
recorder.RecordInt64Count(h, incr, labels...)
}
// Float64CountHandle is a typed handle for a float count metric. This handle is
// passed at the recording point in order to know which metric to record on.
type Float64CountHandle MetricDescriptor
// Descriptor returns the float64 count handle typecast to a pointer to a
// MetricDescriptor.
func (h *Float64CountHandle) Descriptor() *MetricDescriptor {
return (*MetricDescriptor)(h)
}
// Record records the float64 count value on the metrics recorder provided.
func (h *Float64CountHandle) Record(recorder MetricsRecorder, incr float64, labels ...string) {
recorder.RecordFloat64Count(h, incr, labels...)
}
// Int64HistoHandle is a typed handle for an int histogram metric. This handle
// is passed at the recording point in order to know which metric to record on.
type Int64HistoHandle MetricDescriptor
// Descriptor returns the int64 histo handle typecast to a pointer to a
// MetricDescriptor.
func (h *Int64HistoHandle) Descriptor() *MetricDescriptor {
return (*MetricDescriptor)(h)
}
// Record records the int64 histo value on the metrics recorder provided.
func (h *Int64HistoHandle) Record(recorder MetricsRecorder, incr int64, labels ...string) {
recorder.RecordInt64Histo(h, incr, labels...)
}
// Float64HistoHandle is a typed handle for a float histogram metric. This
// handle is passed at the recording point in order to know which metric to
// record on.
type Float64HistoHandle MetricDescriptor
// Descriptor returns the float64 histo handle typecast to a pointer to a
// MetricDescriptor.
func (h *Float64HistoHandle) Descriptor() *MetricDescriptor {
return (*MetricDescriptor)(h)
}
// Record records the float64 histo value on the metrics recorder provided.
func (h *Float64HistoHandle) Record(recorder MetricsRecorder, incr float64, labels ...string) {
recorder.RecordFloat64Histo(h, incr, labels...)
}
// Int64GaugeHandle is a typed handle for an int gauge metric. This handle is
// passed at the recording point in order to know which metric to record on.
type Int64GaugeHandle MetricDescriptor
// Descriptor returns the int64 gauge handle typecast to a pointer to a
// MetricDescriptor.
func (h *Int64GaugeHandle) Descriptor() *MetricDescriptor {
return (*MetricDescriptor)(h)
}
// Record records the int64 gauge value on the metrics recorder provided.
func (h *Int64GaugeHandle) Record(recorder MetricsRecorder, incr int64, labels ...string) {
recorder.RecordInt64Gauge(h, incr, labels...)
}
// registeredMetrics are the registered metric descriptor names.
var registeredMetrics = make(map[string]bool)
// metricsRegistry contains all of the registered metrics.
//
// This is written to only at init time, and read only after that.
var metricsRegistry = make(map[string]*MetricDescriptor)
// DescriptorForMetric returns the MetricDescriptor from the global registry.
//
// Returns nil if MetricDescriptor not present.
func DescriptorForMetric(metricName string) *MetricDescriptor {
return metricsRegistry[metricName]
}
func registerMetric(metricName string, def bool) {
if registeredMetrics[metricName] {
logger.Fatalf("metric %v already registered", metricName)
}
registeredMetrics[metricName] = true
if def {
DefaultMetrics = DefaultMetrics.Add(metricName)
}
}
// RegisterInt64Count registers the metric description onto the global registry.
// It returns a typed handle to use when recording data.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple metrics are
// registered with the same name, this function will panic.
func RegisterInt64Count(descriptor MetricDescriptor) *Int64CountHandle {
registerMetric(descriptor.Name, descriptor.Default)
descriptor.Type = MetricTypeIntCount
descPtr := &descriptor
metricsRegistry[descriptor.Name] = descPtr
return (*Int64CountHandle)(descPtr)
}
// RegisterFloat64Count registers the metric description onto the global
// registry. It returns a typed handle to use when recording data.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple metrics are
// registered with the same name, this function will panic.
func RegisterFloat64Count(descriptor MetricDescriptor) *Float64CountHandle {
registerMetric(descriptor.Name, descriptor.Default)
descriptor.Type = MetricTypeFloatCount
descPtr := &descriptor
metricsRegistry[descriptor.Name] = descPtr
return (*Float64CountHandle)(descPtr)
}
// RegisterInt64Histo registers the metric description onto the global registry.
// It returns a typed handle to use when recording data.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple metrics are
// registered with the same name, this function will panic.
func RegisterInt64Histo(descriptor MetricDescriptor) *Int64HistoHandle {
registerMetric(descriptor.Name, descriptor.Default)
descriptor.Type = MetricTypeIntHisto
descPtr := &descriptor
metricsRegistry[descriptor.Name] = descPtr
return (*Int64HistoHandle)(descPtr)
}
// RegisterFloat64Histo registers the metric description onto the global
// registry. It returns a typed handle to use when recording data.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple metrics are
// registered with the same name, this function will panic.
func RegisterFloat64Histo(descriptor MetricDescriptor) *Float64HistoHandle {
registerMetric(descriptor.Name, descriptor.Default)
descriptor.Type = MetricTypeFloatHisto
descPtr := &descriptor
metricsRegistry[descriptor.Name] = descPtr
return (*Float64HistoHandle)(descPtr)
}
// RegisterInt64Gauge registers the metric description onto the global registry.
// It returns a typed handle to use when recording data.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple metrics are
// registered with the same name, this function will panic.
func RegisterInt64Gauge(descriptor MetricDescriptor) *Int64GaugeHandle {
registerMetric(descriptor.Name, descriptor.Default)
descriptor.Type = MetricTypeIntGauge
descPtr := &descriptor
metricsRegistry[descriptor.Name] = descPtr
return (*Int64GaugeHandle)(descPtr)
}
// snapshotMetricsRegistryForTesting snapshots the global data of the metrics
// registry. Returns a cleanup function that sets the metrics registry to its
// original state.
func snapshotMetricsRegistryForTesting() func() {
oldDefaultMetrics := DefaultMetrics
oldRegisteredMetrics := registeredMetrics
oldMetricsRegistry := metricsRegistry
registeredMetrics = make(map[string]bool)
metricsRegistry = make(map[string]*MetricDescriptor)
maps.Copy(registeredMetrics, oldRegisteredMetrics)
maps.Copy(metricsRegistry, oldMetricsRegistry)
return func() {
DefaultMetrics = oldDefaultMetrics
registeredMetrics = oldRegisteredMetrics
metricsRegistry = oldMetricsRegistry
}
}
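A usage sketch under stated assumptions (the metric name, unit, and label are invented; estats aliases google.golang.org/grpc/experimental/stats):

var exampleRetries = estats.RegisterInt64Count(estats.MetricDescriptor{
	Name:        "grpc.example.retries", // hypothetical metric name
	Description: "Number of retries performed.",
	Unit:        "retry",
	Labels:      []string{"grpc.target"},
	Default:     false,
}) // must run at init time, per the NOTEs above

// At a recording point, given a MetricsRecorder (e.g. from a stats plugin):
//	exampleRetries.Record(recorder, 1, "dns:///svc.example.com")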

View File

@@ -0,0 +1,54 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package stats contains experimental metrics/stats APIs.
package stats
import "google.golang.org/grpc/stats"
// MetricsRecorder records on metrics derived from metric registry.
type MetricsRecorder interface {
// RecordInt64Count records the measurement alongside labels on the int
// count associated with the provided handle.
RecordInt64Count(handle *Int64CountHandle, incr int64, labels ...string)
// RecordFloat64Count records the measurement alongside labels on the float
// count associated with the provided handle.
RecordFloat64Count(handle *Float64CountHandle, incr float64, labels ...string)
// RecordInt64Histo records the measurement alongside labels on the int
// histo associated with the provided handle.
RecordInt64Histo(handle *Int64HistoHandle, incr int64, labels ...string)
// RecordFloat64Histo records the measurement alongside labels on the float
// histo associated with the provided handle.
RecordFloat64Histo(handle *Float64HistoHandle, incr float64, labels ...string)
// RecordInt64Gauge records the measurement alongside labels on the int
// gauge associated with the provided handle.
RecordInt64Gauge(handle *Int64GaugeHandle, incr int64, labels ...string)
}
// Metrics is an experimental legacy alias of the now-stable stats.MetricSet.
// Metrics will be deleted in a future release.
type Metrics = stats.MetricSet
// Metric was replaced by direct usage of strings.
type Metric = string
// NewMetrics is an experimental legacy alias of the now-stable
// stats.NewMetricSet. NewMetrics will be deleted in a future release.
func NewMetrics(metrics ...Metric) *Metrics {
return stats.NewMetricSet(metrics...)
}

View File

@@ -20,8 +20,6 @@ package grpclog
import (
"fmt"
"google.golang.org/grpc/internal/grpclog"
)
// componentData records the settings for a component.
@@ -33,22 +31,22 @@ var cache = map[string]*componentData{}
func (c *componentData) InfoDepth(depth int, args ...any) {
args = append([]any{"[" + string(c.name) + "]"}, args...)
grpclog.InfoDepth(depth+1, args...)
InfoDepth(depth+1, args...)
}
func (c *componentData) WarningDepth(depth int, args ...any) {
args = append([]any{"[" + string(c.name) + "]"}, args...)
grpclog.WarningDepth(depth+1, args...)
WarningDepth(depth+1, args...)
}
func (c *componentData) ErrorDepth(depth int, args ...any) {
args = append([]any{"[" + string(c.name) + "]"}, args...)
grpclog.ErrorDepth(depth+1, args...)
ErrorDepth(depth+1, args...)
}
func (c *componentData) FatalDepth(depth int, args ...any) {
args = append([]any{"[" + string(c.name) + "]"}, args...)
grpclog.FatalDepth(depth+1, args...)
FatalDepth(depth+1, args...)
}
func (c *componentData) Info(args ...any) {

View File

@@ -18,18 +18,15 @@
// Package grpclog defines logging for grpc.
//
// All logs in transport and grpclb packages only go to verbose level 2.
// All logs in other packages in grpc are logged regardless of the verbosity level.
//
// In the default logger,
// severity level can be set by environment variable GRPC_GO_LOG_SEVERITY_LEVEL,
// verbosity level can be set by GRPC_GO_LOG_VERBOSITY_LEVEL.
package grpclog // import "google.golang.org/grpc/grpclog"
// In the default logger, severity level can be set by environment variable
// GRPC_GO_LOG_SEVERITY_LEVEL, verbosity level can be set by
// GRPC_GO_LOG_VERBOSITY_LEVEL.
package grpclog
import (
"os"
"google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/grpclog/internal"
)
func init() {
@@ -38,58 +35,58 @@ func init() {
// V reports whether verbosity level l is at least the requested verbose level.
func V(l int) bool {
return grpclog.Logger.V(l)
return internal.LoggerV2Impl.V(l)
}
// Info logs to the INFO log.
func Info(args ...any) {
grpclog.Logger.Info(args...)
internal.LoggerV2Impl.Info(args...)
}
// Infof logs to the INFO log. Arguments are handled in the manner of fmt.Printf.
func Infof(format string, args ...any) {
grpclog.Logger.Infof(format, args...)
internal.LoggerV2Impl.Infof(format, args...)
}
// Infoln logs to the INFO log. Arguments are handled in the manner of fmt.Println.
func Infoln(args ...any) {
grpclog.Logger.Infoln(args...)
internal.LoggerV2Impl.Infoln(args...)
}
// Warning logs to the WARNING log.
func Warning(args ...any) {
grpclog.Logger.Warning(args...)
internal.LoggerV2Impl.Warning(args...)
}
// Warningf logs to the WARNING log. Arguments are handled in the manner of fmt.Printf.
func Warningf(format string, args ...any) {
grpclog.Logger.Warningf(format, args...)
internal.LoggerV2Impl.Warningf(format, args...)
}
// Warningln logs to the WARNING log. Arguments are handled in the manner of fmt.Println.
func Warningln(args ...any) {
grpclog.Logger.Warningln(args...)
internal.LoggerV2Impl.Warningln(args...)
}
// Error logs to the ERROR log.
func Error(args ...any) {
grpclog.Logger.Error(args...)
internal.LoggerV2Impl.Error(args...)
}
// Errorf logs to the ERROR log. Arguments are handled in the manner of fmt.Printf.
func Errorf(format string, args ...any) {
grpclog.Logger.Errorf(format, args...)
internal.LoggerV2Impl.Errorf(format, args...)
}
// Errorln logs to the ERROR log. Arguments are handled in the manner of fmt.Println.
func Errorln(args ...any) {
grpclog.Logger.Errorln(args...)
internal.LoggerV2Impl.Errorln(args...)
}
// Fatal logs to the FATAL log. Arguments are handled in the manner of fmt.Print.
// It calls os.Exit() with exit code 1.
func Fatal(args ...any) {
grpclog.Logger.Fatal(args...)
internal.LoggerV2Impl.Fatal(args...)
// Make sure fatal logs will exit.
os.Exit(1)
}
@@ -97,15 +94,15 @@ func Fatal(args ...any) {
// Fatalf logs to the FATAL log. Arguments are handled in the manner of fmt.Printf.
// It calls os.Exit() with exit code 1.
func Fatalf(format string, args ...any) {
grpclog.Logger.Fatalf(format, args...)
internal.LoggerV2Impl.Fatalf(format, args...)
// Make sure fatal logs will exit.
os.Exit(1)
}
// Fatalln logs to the FATAL log. Arguments are handled in the manner of fmt.Println.
// It calle os.Exit()) with exit code 1.
// It calls os.Exit() with exit code 1.
func Fatalln(args ...any) {
grpclog.Logger.Fatalln(args...)
internal.LoggerV2Impl.Fatalln(args...)
// Make sure fatal logs will exit.
os.Exit(1)
}
@@ -114,19 +111,76 @@ func Fatalln(args ...any) {
//
// Deprecated: use Info.
func Print(args ...any) {
grpclog.Logger.Info(args...)
internal.LoggerV2Impl.Info(args...)
}
// Printf prints to the logger. Arguments are handled in the manner of fmt.Printf.
//
// Deprecated: use Infof.
func Printf(format string, args ...any) {
grpclog.Logger.Infof(format, args...)
internal.LoggerV2Impl.Infof(format, args...)
}
// Println prints to the logger. Arguments are handled in the manner of fmt.Println.
//
// Deprecated: use Infoln.
func Println(args ...any) {
grpclog.Logger.Infoln(args...)
internal.LoggerV2Impl.Infoln(args...)
}
// InfoDepth logs to the INFO log at the specified depth.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func InfoDepth(depth int, args ...any) {
if internal.DepthLoggerV2Impl != nil {
internal.DepthLoggerV2Impl.InfoDepth(depth, args...)
} else {
internal.LoggerV2Impl.Infoln(args...)
}
}
// WarningDepth logs to the WARNING log at the specified depth.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func WarningDepth(depth int, args ...any) {
if internal.DepthLoggerV2Impl != nil {
internal.DepthLoggerV2Impl.WarningDepth(depth, args...)
} else {
internal.LoggerV2Impl.Warningln(args...)
}
}
// ErrorDepth logs to the ERROR log at the specified depth.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func ErrorDepth(depth int, args ...any) {
if internal.DepthLoggerV2Impl != nil {
internal.DepthLoggerV2Impl.ErrorDepth(depth, args...)
} else {
internal.LoggerV2Impl.Errorln(args...)
}
}
// FatalDepth logs to the FATAL log at the specified depth.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func FatalDepth(depth int, args ...any) {
if internal.DepthLoggerV2Impl != nil {
internal.DepthLoggerV2Impl.FatalDepth(depth, args...)
} else {
internal.LoggerV2Impl.Fatalln(args...)
}
os.Exit(1)
}
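The else branches above are the fallback path: only loggers installed via SetLoggerV2 that also implement DepthLoggerV2 receive the call-site depth. A minimal sketch of a depth-aware wrapper (the type is hypothetical and simply discards the depth):

// depthAwareLogger upgrades any LoggerV2 so that SetLoggerV2 will populate
// internal.DepthLoggerV2Impl and route the *Depth helpers through it.
type depthAwareLogger struct{ LoggerV2 }

func (d depthAwareLogger) InfoDepth(_ int, args ...any)    { d.Infoln(args...) }
func (d depthAwareLogger) WarningDepth(_ int, args ...any) { d.Warningln(args...) }
func (d depthAwareLogger) ErrorDepth(_ int, args ...any)   { d.Errorln(args...) }
func (d depthAwareLogger) FatalDepth(_ int, args ...any)   { d.Fatalln(args...) }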

View File

@@ -1,6 +1,6 @@
/*
*
* Copyright 2022 gRPC authors.
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,17 +16,11 @@
*
*/
package grpcsync
// Package internal contains functionality internal to the grpclog package.
package internal
import (
"sync"
)
// LoggerV2Impl is the logger used for the non-depth log functions.
var LoggerV2Impl LoggerV2
// OnceFunc returns a function wrapping f which ensures f is only executed
// once even if the returned function is executed multiple times.
func OnceFunc(f func()) func() {
var once sync.Once
return func() {
once.Do(f)
}
}
// DepthLoggerV2Impl is the logger used for the depth log functions.
var DepthLoggerV2Impl DepthLoggerV2

View File

@@ -0,0 +1,87 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package internal
// Logger mimics golang's standard Logger as an interface.
//
// Deprecated: use LoggerV2.
type Logger interface {
Fatal(args ...any)
Fatalf(format string, args ...any)
Fatalln(args ...any)
Print(args ...any)
Printf(format string, args ...any)
Println(args ...any)
}
// LoggerWrapper wraps Logger into a LoggerV2.
type LoggerWrapper struct {
Logger
}
// Info logs to INFO log. Arguments are handled in the manner of fmt.Print.
func (l *LoggerWrapper) Info(args ...any) {
l.Logger.Print(args...)
}
// Infoln logs to INFO log. Arguments are handled in the manner of fmt.Println.
func (l *LoggerWrapper) Infoln(args ...any) {
l.Logger.Println(args...)
}
// Infof logs to INFO log. Arguments are handled in the manner of fmt.Printf.
func (l *LoggerWrapper) Infof(format string, args ...any) {
l.Logger.Printf(format, args...)
}
// Warning logs to WARNING log. Arguments are handled in the manner of fmt.Print.
func (l *LoggerWrapper) Warning(args ...any) {
l.Logger.Print(args...)
}
// Warningln logs to WARNING log. Arguments are handled in the manner of fmt.Println.
func (l *LoggerWrapper) Warningln(args ...any) {
l.Logger.Println(args...)
}
// Warningf logs to WARNING log. Arguments are handled in the manner of fmt.Printf.
func (l *LoggerWrapper) Warningf(format string, args ...any) {
l.Logger.Printf(format, args...)
}
// Error logs to ERROR log. Arguments are handled in the manner of fmt.Print.
func (l *LoggerWrapper) Error(args ...any) {
l.Logger.Print(args...)
}
// Errorln logs to ERROR log. Arguments are handled in the manner of fmt.Println.
func (l *LoggerWrapper) Errorln(args ...any) {
l.Logger.Println(args...)
}
// Errorf logs to ERROR log. Arguments are handled in the manner of fmt.Printf.
func (l *LoggerWrapper) Errorf(format string, args ...any) {
l.Logger.Printf(format, args...)
}
// V reports whether verbosity level l is at least the requested verbose level.
func (*LoggerWrapper) V(int) bool {
// Returns true for all verbose levels.
return true
}

View File

@@ -0,0 +1,267 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package internal
import (
"encoding/json"
"fmt"
"io"
"log"
"os"
)
// LoggerV2 does underlying logging work for grpclog.
type LoggerV2 interface {
// Info logs to INFO log. Arguments are handled in the manner of fmt.Print.
Info(args ...any)
// Infoln logs to INFO log. Arguments are handled in the manner of fmt.Println.
Infoln(args ...any)
// Infof logs to INFO log. Arguments are handled in the manner of fmt.Printf.
Infof(format string, args ...any)
// Warning logs to WARNING log. Arguments are handled in the manner of fmt.Print.
Warning(args ...any)
// Warningln logs to WARNING log. Arguments are handled in the manner of fmt.Println.
Warningln(args ...any)
// Warningf logs to WARNING log. Arguments are handled in the manner of fmt.Printf.
Warningf(format string, args ...any)
// Error logs to ERROR log. Arguments are handled in the manner of fmt.Print.
Error(args ...any)
// Errorln logs to ERROR log. Arguments are handled in the manner of fmt.Println.
Errorln(args ...any)
// Errorf logs to ERROR log. Arguments are handled in the manner of fmt.Printf.
Errorf(format string, args ...any)
// Fatal logs to ERROR log. Arguments are handled in the manner of fmt.Print.
// gRPC ensures that all Fatal logs will exit with os.Exit(1).
// Implementations may also call os.Exit() with a non-zero exit code.
Fatal(args ...any)
// Fatalln logs to ERROR log. Arguments are handled in the manner of fmt.Println.
// gRPC ensures that all Fatal logs will exit with os.Exit(1).
// Implementations may also call os.Exit() with a non-zero exit code.
Fatalln(args ...any)
// Fatalf logs to ERROR log. Arguments are handled in the manner of fmt.Printf.
// gRPC ensures that all Fatal logs will exit with os.Exit(1).
// Implementations may also call os.Exit() with a non-zero exit code.
Fatalf(format string, args ...any)
// V reports whether verbosity level l is at least the requested verbose level.
V(l int) bool
}
// DepthLoggerV2 logs at a specified call frame. If a LoggerV2 also implements
// DepthLoggerV2, the functions below will be called with the appropriate stack
// depth set, so the logger can skip over trivial wrapper frames.
//
// # Experimental
//
// Notice: This type is EXPERIMENTAL and may be changed or removed in a
// later release.
type DepthLoggerV2 interface {
LoggerV2
// InfoDepth logs to INFO log at the specified depth. Arguments are handled in the manner of fmt.Println.
InfoDepth(depth int, args ...any)
// WarningDepth logs to WARNING log at the specified depth. Arguments are handled in the manner of fmt.Println.
WarningDepth(depth int, args ...any)
// ErrorDepth logs to ERROR log at the specified depth. Arguments are handled in the manner of fmt.Println.
ErrorDepth(depth int, args ...any)
// FatalDepth logs to FATAL log at the specified depth. Arguments are handled in the manner of fmt.Println.
FatalDepth(depth int, args ...any)
}
const (
// infoLog indicates Info severity.
infoLog int = iota
// warningLog indicates Warning severity.
warningLog
// errorLog indicates Error severity.
errorLog
// fatalLog indicates Fatal severity.
fatalLog
)
// severityName contains the string representation of each severity.
var severityName = []string{
infoLog: "INFO",
warningLog: "WARNING",
errorLog: "ERROR",
fatalLog: "FATAL",
}
// sprintf is fmt.Sprintf.
// These vars exist to make it possible to test that expensive format calls aren't made unnecessarily.
var sprintf = fmt.Sprintf
// sprint is fmt.Sprint.
// These vars exist to make it possible to test that expensive format calls aren't made unnecessarily.
var sprint = fmt.Sprint
// sprintln is fmt.Sprintln.
// These vars exist to make it possible to test that expensive format calls aren't made unnecessarily.
var sprintln = fmt.Sprintln
// exit is os.Exit.
// This var exists to make it possible to test functions calling os.Exit.
var exit = os.Exit
// loggerT is the default logger used by grpclog.
type loggerT struct {
m []*log.Logger
v int
jsonFormat bool
}
func (g *loggerT) output(severity int, s string) {
sevStr := severityName[severity]
if !g.jsonFormat {
g.m[severity].Output(2, sevStr+": "+s)
return
}
// TODO: we can also include the logging component, but that needs more
// (API) changes.
b, _ := json.Marshal(map[string]string{
"severity": sevStr,
"message": s,
})
g.m[severity].Output(2, string(b))
}
func (g *loggerT) printf(severity int, format string, args ...any) {
// Note the discard check is duplicated in each print func, rather than in
// output, to avoid the expensive Sprint calls.
// De-duplicating this by moving to output would be a significant performance regression!
if lg := g.m[severity]; lg.Writer() == io.Discard {
return
}
g.output(severity, sprintf(format, args...))
}
func (g *loggerT) print(severity int, v ...any) {
if lg := g.m[severity]; lg.Writer() == io.Discard {
return
}
g.output(severity, sprint(v...))
}
func (g *loggerT) println(severity int, v ...any) {
if lg := g.m[severity]; lg.Writer() == io.Discard {
return
}
g.output(severity, sprintln(v...))
}
func (g *loggerT) Info(args ...any) {
g.print(infoLog, args...)
}
func (g *loggerT) Infoln(args ...any) {
g.println(infoLog, args...)
}
func (g *loggerT) Infof(format string, args ...any) {
g.printf(infoLog, format, args...)
}
func (g *loggerT) Warning(args ...any) {
g.print(warningLog, args...)
}
func (g *loggerT) Warningln(args ...any) {
g.println(warningLog, args...)
}
func (g *loggerT) Warningf(format string, args ...any) {
g.printf(warningLog, format, args...)
}
func (g *loggerT) Error(args ...any) {
g.print(errorLog, args...)
}
func (g *loggerT) Errorln(args ...any) {
g.println(errorLog, args...)
}
func (g *loggerT) Errorf(format string, args ...any) {
g.printf(errorLog, format, args...)
}
func (g *loggerT) Fatal(args ...any) {
g.print(fatalLog, args...)
exit(1)
}
func (g *loggerT) Fatalln(args ...any) {
g.println(fatalLog, args...)
exit(1)
}
func (g *loggerT) Fatalf(format string, args ...any) {
g.printf(fatalLog, format, args...)
exit(1)
}
func (g *loggerT) V(l int) bool {
return l <= g.v
}
// LoggerV2Config configures the LoggerV2 implementation.
type LoggerV2Config struct {
// Verbosity sets the verbosity level of the logger.
Verbosity int
// FormatJSON controls whether the logger should output logs in JSON format.
FormatJSON bool
}
// combineLoggers returns a combined logger for both higher & lower severity logs,
// or only one if the other is io.Discard.
//
// This uses io.Discard instead of io.MultiWriter when all loggers
// are set to io.Discard. Both this package and the standard log package have
// significant optimizations for io.Discard, which io.MultiWriter lacks (as of
// this writing).
func combineLoggers(lower, higher io.Writer) io.Writer {
if lower == io.Discard {
return higher
}
if higher == io.Discard {
return lower
}
return io.MultiWriter(lower, higher)
}
// NewLoggerV2 creates a new LoggerV2 instance with the provided configuration.
// The infoW, warningW, and errorW writers are used to write log messages of
// different severity levels.
func NewLoggerV2(infoW, warningW, errorW io.Writer, c LoggerV2Config) LoggerV2 {
flag := log.LstdFlags
if c.FormatJSON {
flag = 0
}
warningW = combineLoggers(infoW, warningW)
errorW = combineLoggers(errorW, warningW)
fatalW := errorW
m := []*log.Logger{
log.New(infoW, "", flag),
log.New(warningW, "", flag),
log.New(errorW, "", flag),
log.New(fatalW, "", flag),
}
return &loggerT{m: m, v: c.Verbosity, jsonFormat: c.FormatJSON}
}
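As a usage sketch (normally reached through the public grpclog constructors rather than by importing this internal package directly): a verbosity-2 JSON logger where INFO goes to stdout and WARNING/ERROR to stderr. With FormatJSON set, the log.LstdFlags timestamp prefix is dropped and each line is a {"severity": ..., "message": ...} object:

import (
	"os"

	"google.golang.org/grpc/grpclog/internal"
)

// INFO -> stdout; WARNING and ERROR are combined onto stderr by
// combineLoggers above.
var logger = internal.NewLoggerV2(os.Stdout, os.Stderr, os.Stderr, internal.LoggerV2Config{
	Verbosity:  2,
	FormatJSON: true,
})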

View File

@@ -18,70 +18,17 @@
package grpclog
import "google.golang.org/grpc/internal/grpclog"
import "google.golang.org/grpc/grpclog/internal"
// Logger mimics golang's standard Logger as an interface.
//
// Deprecated: use LoggerV2.
type Logger interface {
Fatal(args ...any)
Fatalf(format string, args ...any)
Fatalln(args ...any)
Print(args ...any)
Printf(format string, args ...any)
Println(args ...any)
}
type Logger internal.Logger
// SetLogger sets the logger that is used in grpc. Call only from
// init() functions.
//
// Deprecated: use SetLoggerV2.
func SetLogger(l Logger) {
grpclog.Logger = &loggerWrapper{Logger: l}
}
// loggerWrapper wraps Logger into a LoggerV2.
type loggerWrapper struct {
Logger
}
func (g *loggerWrapper) Info(args ...any) {
g.Logger.Print(args...)
}
func (g *loggerWrapper) Infoln(args ...any) {
g.Logger.Println(args...)
}
func (g *loggerWrapper) Infof(format string, args ...any) {
g.Logger.Printf(format, args...)
}
func (g *loggerWrapper) Warning(args ...any) {
g.Logger.Print(args...)
}
func (g *loggerWrapper) Warningln(args ...any) {
g.Logger.Println(args...)
}
func (g *loggerWrapper) Warningf(format string, args ...any) {
g.Logger.Printf(format, args...)
}
func (g *loggerWrapper) Error(args ...any) {
g.Logger.Print(args...)
}
func (g *loggerWrapper) Errorln(args ...any) {
g.Logger.Println(args...)
}
func (g *loggerWrapper) Errorf(format string, args ...any) {
g.Logger.Printf(format, args...)
}
func (g *loggerWrapper) V(l int) bool {
// Returns true for all verbose level.
return true
internal.LoggerV2Impl = &internal.LoggerWrapper{Logger: l}
}
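Since *log.Logger from the standard library already provides Print*, Fatal* and friends, it satisfies this deprecated interface directly; a minimal sketch of the legacy wiring, which the wrapper then fans out to every severity with V reporting true at all levels:

import (
	"log"
	"os"

	"google.golang.org/grpc/grpclog"
)

func init() {
	// Deprecated path; prefer SetLoggerV2 in new code.
	grpclog.SetLogger(log.New(os.Stderr, "grpc: ", log.LstdFlags))
}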

View File

@@ -19,52 +19,16 @@
package grpclog
import (
"encoding/json"
"fmt"
"io"
"log"
"os"
"strconv"
"strings"
"google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/grpclog/internal"
)
// LoggerV2 does underlying logging work for grpclog.
type LoggerV2 interface {
// Info logs to INFO log. Arguments are handled in the manner of fmt.Print.
Info(args ...any)
// Infoln logs to INFO log. Arguments are handled in the manner of fmt.Println.
Infoln(args ...any)
// Infof logs to INFO log. Arguments are handled in the manner of fmt.Printf.
Infof(format string, args ...any)
// Warning logs to WARNING log. Arguments are handled in the manner of fmt.Print.
Warning(args ...any)
// Warningln logs to WARNING log. Arguments are handled in the manner of fmt.Println.
Warningln(args ...any)
// Warningf logs to WARNING log. Arguments are handled in the manner of fmt.Printf.
Warningf(format string, args ...any)
// Error logs to ERROR log. Arguments are handled in the manner of fmt.Print.
Error(args ...any)
// Errorln logs to ERROR log. Arguments are handled in the manner of fmt.Println.
Errorln(args ...any)
// Errorf logs to ERROR log. Arguments are handled in the manner of fmt.Printf.
Errorf(format string, args ...any)
// Fatal logs to ERROR log. Arguments are handled in the manner of fmt.Print.
// gRPC ensures that all Fatal logs will exit with os.Exit(1).
// Implementations may also call os.Exit() with a non-zero exit code.
Fatal(args ...any)
// Fatalln logs to ERROR log. Arguments are handled in the manner of fmt.Println.
// gRPC ensures that all Fatal logs will exit with os.Exit(1).
// Implementations may also call os.Exit() with a non-zero exit code.
Fatalln(args ...any)
// Fatalf logs to ERROR log. Arguments are handled in the manner of fmt.Printf.
// gRPC ensures that all Fatal logs will exit with os.Exit(1).
// Implementations may also call os.Exit() with a non-zero exit code.
Fatalf(format string, args ...any)
// V reports whether verbosity level l is at least the requested verbose level.
V(l int) bool
}
type LoggerV2 internal.LoggerV2
// SetLoggerV2 sets logger that is used in grpc to a V2 logger.
// Not mutex-protected, should be called before any gRPC functions.
@@ -72,34 +36,8 @@ func SetLoggerV2(l LoggerV2) {
if _, ok := l.(*componentData); ok {
panic("cannot use component logger as grpclog logger")
}
grpclog.Logger = l
grpclog.DepthLogger, _ = l.(grpclog.DepthLoggerV2)
}
const (
// infoLog indicates Info severity.
infoLog int = iota
// warningLog indicates Warning severity.
warningLog
// errorLog indicates Error severity.
errorLog
// fatalLog indicates Fatal severity.
fatalLog
)
// severityName contains the string representation of each severity.
var severityName = []string{
infoLog: "INFO",
warningLog: "WARNING",
errorLog: "ERROR",
fatalLog: "FATAL",
}
// loggerT is the default logger used by grpclog.
type loggerT struct {
m []*log.Logger
v int
jsonFormat bool
internal.LoggerV2Impl = l
internal.DepthLoggerV2Impl, _ = l.(internal.DepthLoggerV2)
}
// NewLoggerV2 creates a loggerV2 with the provided writers.
@@ -108,32 +46,13 @@ type loggerT struct {
// Warning logs will be written to warningW and infoW.
// Info logs will be written to infoW.
func NewLoggerV2(infoW, warningW, errorW io.Writer) LoggerV2 {
return newLoggerV2WithConfig(infoW, warningW, errorW, loggerV2Config{})
return internal.NewLoggerV2(infoW, warningW, errorW, internal.LoggerV2Config{})
}
// NewLoggerV2WithVerbosity creates a loggerV2 with the provided writers and
// verbosity level.
func NewLoggerV2WithVerbosity(infoW, warningW, errorW io.Writer, v int) LoggerV2 {
return newLoggerV2WithConfig(infoW, warningW, errorW, loggerV2Config{verbose: v})
}
type loggerV2Config struct {
verbose int
jsonFormat bool
}
func newLoggerV2WithConfig(infoW, warningW, errorW io.Writer, c loggerV2Config) LoggerV2 {
var m []*log.Logger
flag := log.LstdFlags
if c.jsonFormat {
flag = 0
}
m = append(m, log.New(infoW, "", flag))
m = append(m, log.New(io.MultiWriter(infoW, warningW), "", flag))
ew := io.MultiWriter(infoW, warningW, errorW) // ew will be used for error and fatal.
m = append(m, log.New(ew, "", flag))
m = append(m, log.New(ew, "", flag))
return &loggerT{m: m, v: c.verbose, jsonFormat: c.jsonFormat}
return internal.NewLoggerV2(infoW, warningW, errorW, internal.LoggerV2Config{Verbosity: v})
}
// newLoggerV2 creates a loggerV2 to be used as default logger.
@@ -161,82 +80,12 @@ func newLoggerV2() LoggerV2 {
jsonFormat := strings.EqualFold(os.Getenv("GRPC_GO_LOG_FORMATTER"), "json")
return newLoggerV2WithConfig(infoW, warningW, errorW, loggerV2Config{
verbose: v,
jsonFormat: jsonFormat,
return internal.NewLoggerV2(infoW, warningW, errorW, internal.LoggerV2Config{
Verbosity: v,
FormatJSON: jsonFormat,
})
}
func (g *loggerT) output(severity int, s string) {
sevStr := severityName[severity]
if !g.jsonFormat {
g.m[severity].Output(2, fmt.Sprintf("%v: %v", sevStr, s))
return
}
// TODO: we can also include the logging component, but that needs more
// (API) changes.
b, _ := json.Marshal(map[string]string{
"severity": sevStr,
"message": s,
})
g.m[severity].Output(2, string(b))
}
func (g *loggerT) Info(args ...any) {
g.output(infoLog, fmt.Sprint(args...))
}
func (g *loggerT) Infoln(args ...any) {
g.output(infoLog, fmt.Sprintln(args...))
}
func (g *loggerT) Infof(format string, args ...any) {
g.output(infoLog, fmt.Sprintf(format, args...))
}
func (g *loggerT) Warning(args ...any) {
g.output(warningLog, fmt.Sprint(args...))
}
func (g *loggerT) Warningln(args ...any) {
g.output(warningLog, fmt.Sprintln(args...))
}
func (g *loggerT) Warningf(format string, args ...any) {
g.output(warningLog, fmt.Sprintf(format, args...))
}
func (g *loggerT) Error(args ...any) {
g.output(errorLog, fmt.Sprint(args...))
}
func (g *loggerT) Errorln(args ...any) {
g.output(errorLog, fmt.Sprintln(args...))
}
func (g *loggerT) Errorf(format string, args ...any) {
g.output(errorLog, fmt.Sprintf(format, args...))
}
func (g *loggerT) Fatal(args ...any) {
g.output(fatalLog, fmt.Sprint(args...))
os.Exit(1)
}
func (g *loggerT) Fatalln(args ...any) {
g.output(fatalLog, fmt.Sprintln(args...))
os.Exit(1)
}
func (g *loggerT) Fatalf(format string, args ...any) {
g.output(fatalLog, fmt.Sprintf(format, args...))
os.Exit(1)
}
func (g *loggerT) V(l int) bool {
return l <= g.v
}
// DepthLoggerV2 logs at a specified call frame. If a LoggerV2 also implements
// DepthLoggerV2, the functions below will be called with the appropriate stack
// depth set, so the logger can skip over trivial wrapper frames.
@@ -245,14 +94,4 @@ func (g *loggerT) V(l int) bool {
//
// Notice: This type is EXPERIMENTAL and may be changed or removed in a
// later release.
type DepthLoggerV2 interface {
LoggerV2
// InfoDepth logs to INFO log at the specified depth. Arguments are handled in the manner of fmt.Println.
InfoDepth(depth int, args ...any)
// WarningDepth logs to WARNING log at the specified depth. Arguments are handled in the manner of fmt.Println.
WarningDepth(depth int, args ...any)
// ErrorDepth logs to ERROR log at the specified depth. Arguments are handled in the manner of fmt.Println.
ErrorDepth(depth int, args ...any)
// FatalDepth logs to FATAL log at the specified depth. Arguments are handled in the manner of fmt.Println.
FatalDepth(depth int, args ...any)
}
type DepthLoggerV2 internal.DepthLoggerV2

View File

@@ -17,8 +17,8 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.32.0
// protoc v4.25.2
// protoc-gen-go v1.36.5
// protoc v5.27.1
// source: grpc/health/v1/health.proto
package grpc_health_v1
@@ -28,6 +28,7 @@ import (
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
@@ -90,20 +91,17 @@ func (HealthCheckResponse_ServingStatus) EnumDescriptor() ([]byte, []int) {
}
type HealthCheckRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
state protoimpl.MessageState `protogen:"open.v1"`
Service string `protobuf:"bytes,1,opt,name=service,proto3" json:"service,omitempty"`
unknownFields protoimpl.UnknownFields
Service string `protobuf:"bytes,1,opt,name=service,proto3" json:"service,omitempty"`
sizeCache protoimpl.SizeCache
}
func (x *HealthCheckRequest) Reset() {
*x = HealthCheckRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_grpc_health_v1_health_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_grpc_health_v1_health_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *HealthCheckRequest) String() string {
@@ -114,7 +112,7 @@ func (*HealthCheckRequest) ProtoMessage() {}
func (x *HealthCheckRequest) ProtoReflect() protoreflect.Message {
mi := &file_grpc_health_v1_health_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -137,20 +135,17 @@ func (x *HealthCheckRequest) GetService() string {
}
type HealthCheckResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
state protoimpl.MessageState `protogen:"open.v1"`
Status HealthCheckResponse_ServingStatus `protobuf:"varint,1,opt,name=status,proto3,enum=grpc.health.v1.HealthCheckResponse_ServingStatus" json:"status,omitempty"`
unknownFields protoimpl.UnknownFields
Status HealthCheckResponse_ServingStatus `protobuf:"varint,1,opt,name=status,proto3,enum=grpc.health.v1.HealthCheckResponse_ServingStatus" json:"status,omitempty"`
sizeCache protoimpl.SizeCache
}
func (x *HealthCheckResponse) Reset() {
*x = HealthCheckResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_grpc_health_v1_health_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_grpc_health_v1_health_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *HealthCheckResponse) String() string {
@@ -161,7 +156,7 @@ func (*HealthCheckResponse) ProtoMessage() {}
func (x *HealthCheckResponse) ProtoReflect() protoreflect.Message {
mi := &file_grpc_health_v1_health_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -183,9 +178,90 @@ func (x *HealthCheckResponse) GetStatus() HealthCheckResponse_ServingStatus {
return HealthCheckResponse_UNKNOWN
}
type HealthListRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *HealthListRequest) Reset() {
*x = HealthListRequest{}
mi := &file_grpc_health_v1_health_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *HealthListRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*HealthListRequest) ProtoMessage() {}
func (x *HealthListRequest) ProtoReflect() protoreflect.Message {
mi := &file_grpc_health_v1_health_proto_msgTypes[2]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use HealthListRequest.ProtoReflect.Descriptor instead.
func (*HealthListRequest) Descriptor() ([]byte, []int) {
return file_grpc_health_v1_health_proto_rawDescGZIP(), []int{2}
}
type HealthListResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
// statuses contains all the services and their respective status.
Statuses map[string]*HealthCheckResponse `protobuf:"bytes,1,rep,name=statuses,proto3" json:"statuses,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *HealthListResponse) Reset() {
*x = HealthListResponse{}
mi := &file_grpc_health_v1_health_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *HealthListResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*HealthListResponse) ProtoMessage() {}
func (x *HealthListResponse) ProtoReflect() protoreflect.Message {
mi := &file_grpc_health_v1_health_proto_msgTypes[3]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use HealthListResponse.ProtoReflect.Descriptor instead.
func (*HealthListResponse) Descriptor() ([]byte, []int) {
return file_grpc_health_v1_health_proto_rawDescGZIP(), []int{3}
}
func (x *HealthListResponse) GetStatuses() map[string]*HealthCheckResponse {
if x != nil {
return x.Statuses
}
return nil
}
var File_grpc_health_v1_health_proto protoreflect.FileDescriptor
var file_grpc_health_v1_health_proto_rawDesc = []byte{
var file_grpc_health_v1_health_proto_rawDesc = string([]byte{
0x0a, 0x1b, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x68, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x2f, 0x76, 0x31,
0x2f, 0x68, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0e, 0x67,
0x72, 0x70, 0x63, 0x2e, 0x68, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x22, 0x2e, 0x0a,
@@ -203,56 +279,83 @@ var file_grpc_health_v1_health_proto_rawDesc = []byte{
0x0a, 0x07, 0x53, 0x45, 0x52, 0x56, 0x49, 0x4e, 0x47, 0x10, 0x01, 0x12, 0x0f, 0x0a, 0x0b, 0x4e,
0x4f, 0x54, 0x5f, 0x53, 0x45, 0x52, 0x56, 0x49, 0x4e, 0x47, 0x10, 0x02, 0x12, 0x13, 0x0a, 0x0f,
0x53, 0x45, 0x52, 0x56, 0x49, 0x43, 0x45, 0x5f, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10,
0x03, 0x32, 0xae, 0x01, 0x0a, 0x06, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x12, 0x50, 0x0a, 0x05,
0x43, 0x68, 0x65, 0x63, 0x6b, 0x12, 0x22, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x68, 0x65, 0x61,
0x03, 0x22, 0x13, 0x0a, 0x11, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x4c, 0x69, 0x73, 0x74, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xc4, 0x01, 0x0a, 0x12, 0x48, 0x65, 0x61, 0x6c, 0x74,
0x68, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4c, 0x0a,
0x08, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32,
0x30, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x68, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x2e, 0x76, 0x31,
0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f,
0x6e, 0x73, 0x65, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x65, 0x73, 0x45, 0x6e, 0x74, 0x72,
0x79, 0x52, 0x08, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x65, 0x73, 0x1a, 0x60, 0x0a, 0x0d, 0x53,
0x74, 0x61, 0x74, 0x75, 0x73, 0x65, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03,
0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x39,
0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x23, 0x2e,
0x67, 0x72, 0x70, 0x63, 0x2e, 0x68, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x48,
0x65, 0x61, 0x6c, 0x74, 0x68, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x32, 0xfd, 0x01,
0x0a, 0x06, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x12, 0x50, 0x0a, 0x05, 0x43, 0x68, 0x65, 0x63,
0x6b, 0x12, 0x22, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x68, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x2e,
0x76, 0x31, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x68, 0x65, 0x61,
0x6c, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x43, 0x68, 0x65,
0x63, 0x6b, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x67, 0x72, 0x70, 0x63,
0x2e, 0x68, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74,
0x68, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x52,
0x0a, 0x05, 0x57, 0x61, 0x74, 0x63, 0x68, 0x12, 0x22, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x68,
0x65, 0x61, 0x6c, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x43,
0x68, 0x65, 0x63, 0x6b, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x67, 0x72,
0x70, 0x63, 0x2e, 0x68, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x48, 0x65, 0x61,
0x6c, 0x74, 0x68, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
0x30, 0x01, 0x42, 0x61, 0x0a, 0x11, 0x69, 0x6f, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x68, 0x65,
0x61, 0x6c, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x42, 0x0b, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x50,
0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x2c, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x67,
0x6f, 0x6c, 0x61, 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x68,
0x65, 0x61, 0x6c, 0x74, 0x68, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x5f, 0x68, 0x65, 0x61, 0x6c, 0x74,
0x68, 0x5f, 0x76, 0x31, 0xaa, 0x02, 0x0e, 0x47, 0x72, 0x70, 0x63, 0x2e, 0x48, 0x65, 0x61, 0x6c,
0x74, 0x68, 0x2e, 0x56, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
0x63, 0x6b, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4d, 0x0a, 0x04, 0x4c, 0x69,
0x73, 0x74, 0x12, 0x21, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x68, 0x65, 0x61, 0x6c, 0x74, 0x68,
0x2e, 0x76, 0x31, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x22, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x68, 0x65, 0x61,
0x6c, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x4c, 0x69, 0x73,
0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x52, 0x0a, 0x05, 0x57, 0x61, 0x74,
0x63, 0x68, 0x12, 0x22, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x68, 0x65, 0x61, 0x6c, 0x74, 0x68,
0x2e, 0x76, 0x31, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x68, 0x65,
0x61, 0x6c, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x43, 0x68,
0x65, 0x63, 0x6b, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x30, 0x01, 0x42, 0x70, 0x0a,
0x11, 0x69, 0x6f, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x68, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x2e,
0x76, 0x31, 0x42, 0x0b, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50,
0x01, 0x5a, 0x2c, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x67, 0x6f, 0x6c, 0x61, 0x6e, 0x67,
0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x68, 0x65, 0x61, 0x6c, 0x74, 0x68,
0x2f, 0x67, 0x72, 0x70, 0x63, 0x5f, 0x68, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x5f, 0x76, 0x31, 0xa2,
0x02, 0x0c, 0x47, 0x72, 0x70, 0x63, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x56, 0x31, 0xaa, 0x02,
0x0e, 0x47, 0x72, 0x70, 0x63, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x2e, 0x56, 0x31, 0x62,
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
})
var (
file_grpc_health_v1_health_proto_rawDescOnce sync.Once
file_grpc_health_v1_health_proto_rawDescData = file_grpc_health_v1_health_proto_rawDesc
file_grpc_health_v1_health_proto_rawDescData []byte
)
func file_grpc_health_v1_health_proto_rawDescGZIP() []byte {
file_grpc_health_v1_health_proto_rawDescOnce.Do(func() {
file_grpc_health_v1_health_proto_rawDescData = protoimpl.X.CompressGZIP(file_grpc_health_v1_health_proto_rawDescData)
file_grpc_health_v1_health_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_grpc_health_v1_health_proto_rawDesc), len(file_grpc_health_v1_health_proto_rawDesc)))
})
return file_grpc_health_v1_health_proto_rawDescData
}
var file_grpc_health_v1_health_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_grpc_health_v1_health_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_grpc_health_v1_health_proto_goTypes = []interface{}{
var file_grpc_health_v1_health_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
var file_grpc_health_v1_health_proto_goTypes = []any{
(HealthCheckResponse_ServingStatus)(0), // 0: grpc.health.v1.HealthCheckResponse.ServingStatus
(*HealthCheckRequest)(nil), // 1: grpc.health.v1.HealthCheckRequest
(*HealthCheckResponse)(nil), // 2: grpc.health.v1.HealthCheckResponse
(*HealthListRequest)(nil), // 3: grpc.health.v1.HealthListRequest
(*HealthListResponse)(nil), // 4: grpc.health.v1.HealthListResponse
nil, // 5: grpc.health.v1.HealthListResponse.StatusesEntry
}
var file_grpc_health_v1_health_proto_depIdxs = []int32{
0, // 0: grpc.health.v1.HealthCheckResponse.status:type_name -> grpc.health.v1.HealthCheckResponse.ServingStatus
1, // 1: grpc.health.v1.Health.Check:input_type -> grpc.health.v1.HealthCheckRequest
1, // 2: grpc.health.v1.Health.Watch:input_type -> grpc.health.v1.HealthCheckRequest
2, // 3: grpc.health.v1.Health.Check:output_type -> grpc.health.v1.HealthCheckResponse
2, // 4: grpc.health.v1.Health.Watch:output_type -> grpc.health.v1.HealthCheckResponse
3, // [3:5] is the sub-list for method output_type
1, // [1:3] is the sub-list for method input_type
1, // [1:1] is the sub-list for extension type_name
1, // [1:1] is the sub-list for extension extendee
0, // [0:1] is the sub-list for field type_name
5, // 1: grpc.health.v1.HealthListResponse.statuses:type_name -> grpc.health.v1.HealthListResponse.StatusesEntry
2, // 2: grpc.health.v1.HealthListResponse.StatusesEntry.value:type_name -> grpc.health.v1.HealthCheckResponse
1, // 3: grpc.health.v1.Health.Check:input_type -> grpc.health.v1.HealthCheckRequest
3, // 4: grpc.health.v1.Health.List:input_type -> grpc.health.v1.HealthListRequest
1, // 5: grpc.health.v1.Health.Watch:input_type -> grpc.health.v1.HealthCheckRequest
2, // 6: grpc.health.v1.Health.Check:output_type -> grpc.health.v1.HealthCheckResponse
4, // 7: grpc.health.v1.Health.List:output_type -> grpc.health.v1.HealthListResponse
2, // 8: grpc.health.v1.Health.Watch:output_type -> grpc.health.v1.HealthCheckResponse
6, // [6:9] is the sub-list for method output_type
3, // [3:6] is the sub-list for method input_type
3, // [3:3] is the sub-list for extension type_name
3, // [3:3] is the sub-list for extension extendee
0, // [0:3] is the sub-list for field type_name
}
func init() { file_grpc_health_v1_health_proto_init() }
@@ -260,39 +363,13 @@ func file_grpc_health_v1_health_proto_init() {
if File_grpc_health_v1_health_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_grpc_health_v1_health_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*HealthCheckRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_grpc_health_v1_health_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*HealthCheckResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_grpc_health_v1_health_proto_rawDesc,
RawDescriptor: unsafe.Slice(unsafe.StringData(file_grpc_health_v1_health_proto_rawDesc), len(file_grpc_health_v1_health_proto_rawDesc)),
NumEnums: 1,
NumMessages: 2,
NumMessages: 5,
NumExtensions: 0,
NumServices: 1,
},
@@ -302,7 +379,6 @@ func file_grpc_health_v1_health_proto_init() {
MessageInfos: file_grpc_health_v1_health_proto_msgTypes,
}.Build()
File_grpc_health_v1_health_proto = out.File
file_grpc_health_v1_health_proto_rawDesc = nil
file_grpc_health_v1_health_proto_goTypes = nil
file_grpc_health_v1_health_proto_depIdxs = nil
}

View File

@@ -17,8 +17,8 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.3.0
// - protoc v4.25.2
// - protoc-gen-go-grpc v1.5.1
// - protoc v5.27.1
// source: grpc/health/v1/health.proto
package grpc_health_v1
@@ -32,17 +32,22 @@ import (
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// Requires gRPC-Go v1.64.0 or later.
const _ = grpc.SupportPackageIsVersion9
const (
Health_Check_FullMethodName = "/grpc.health.v1.Health/Check"
Health_List_FullMethodName = "/grpc.health.v1.Health/List"
Health_Watch_FullMethodName = "/grpc.health.v1.Health/Watch"
)
// HealthClient is the client API for Health service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
//
// Health is gRPC's mechanism for checking whether a server is able to handle
// RPCs. Its semantics are documented in
// https://github.com/grpc/grpc/blob/master/doc/health-checking.md.
type HealthClient interface {
// Check gets the health of the specified service. If the requested service
// is unknown, the call will fail with status NOT_FOUND. If the caller does
@@ -51,9 +56,19 @@ type HealthClient interface {
//
// Clients should set a deadline when calling Check, and can declare the
// server unhealthy if they do not receive a timely response.
//
// Check implementations should be idempotent and side effect free.
Check(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (*HealthCheckResponse, error)
// List provides a non-atomic snapshot of the health of all the available
// services.
//
// The server may respond with a RESOURCE_EXHAUSTED error if too many services
// exist.
//
// Clients should set a deadline when calling List, and can declare the server
// unhealthy if they do not receive a timely response.
//
// Clients should keep in mind that the list of health services exposed by an
// application can change over the lifetime of the process.
List(ctx context.Context, in *HealthListRequest, opts ...grpc.CallOption) (*HealthListResponse, error)
// Performs a watch for the serving status of the requested service.
// The server will immediately send back a message indicating the current
// serving status. It will then subsequently send a new message whenever
@@ -69,7 +84,7 @@ type HealthClient interface {
// should assume this method is not supported and should not retry the
// call. If the call terminates with any other status (including OK),
// clients should retry the call with appropriate exponential backoff.
Watch(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (Health_WatchClient, error)
Watch(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[HealthCheckResponse], error)
}
type healthClient struct {
@@ -81,20 +96,32 @@ func NewHealthClient(cc grpc.ClientConnInterface) HealthClient {
}
func (c *healthClient) Check(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (*HealthCheckResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(HealthCheckResponse)
err := c.cc.Invoke(ctx, Health_Check_FullMethodName, in, out, opts...)
err := c.cc.Invoke(ctx, Health_Check_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *healthClient) Watch(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (Health_WatchClient, error) {
stream, err := c.cc.NewStream(ctx, &Health_ServiceDesc.Streams[0], Health_Watch_FullMethodName, opts...)
func (c *healthClient) List(ctx context.Context, in *HealthListRequest, opts ...grpc.CallOption) (*HealthListResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(HealthListResponse)
err := c.cc.Invoke(ctx, Health_List_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
x := &healthWatchClient{stream}
return out, nil
}
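A hedged client-side sketch of the new List RPC (connection setup elided; the helper name is hypothetical). Per the method's documentation, callers should set a deadline and treat the result as a non-atomic snapshot of per-service statuses:

import (
	"context"
	"time"

	"google.golang.org/grpc"
	healthpb "google.golang.org/grpc/health/grpc_health_v1"
)

func listHealth(ctx context.Context, cc *grpc.ClientConn) (map[string]*healthpb.HealthCheckResponse, error) {
	ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
	resp, err := healthpb.NewHealthClient(cc).List(ctx, &healthpb.HealthListRequest{})
	if err != nil {
		return nil, err
	}
	return resp.GetStatuses(), nil
}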
func (c *healthClient) Watch(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[HealthCheckResponse], error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
stream, err := c.cc.NewStream(ctx, &Health_ServiceDesc.Streams[0], Health_Watch_FullMethodName, cOpts...)
if err != nil {
return nil, err
}
x := &grpc.GenericClientStream[HealthCheckRequest, HealthCheckResponse]{ClientStream: stream}
if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err
}
@@ -104,26 +131,16 @@ func (c *healthClient) Watch(ctx context.Context, in *HealthCheckRequest, opts .
return x, nil
}
type Health_WatchClient interface {
Recv() (*HealthCheckResponse, error)
grpc.ClientStream
}
type healthWatchClient struct {
grpc.ClientStream
}
func (x *healthWatchClient) Recv() (*HealthCheckResponse, error) {
m := new(HealthCheckResponse)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type Health_WatchClient = grpc.ServerStreamingClient[HealthCheckResponse]
// HealthServer is the server API for Health service.
// All implementations should embed UnimplementedHealthServer
// for forward compatibility
// for forward compatibility.
//
// Health is gRPC's mechanism for checking whether a server is able to handle
// RPCs. Its semantics are documented in
// https://github.com/grpc/grpc/blob/master/doc/health-checking.md.
type HealthServer interface {
// Check gets the health of the specified service. If the requested service
// is unknown, the call will fail with status NOT_FOUND. If the caller does
@@ -132,9 +149,19 @@ type HealthServer interface {
//
// Clients should set a deadline when calling Check, and can declare the
// server unhealthy if they do not receive a timely response.
//
// Check implementations should be idempotent and side effect free.
Check(context.Context, *HealthCheckRequest) (*HealthCheckResponse, error)
// List provides a non-atomic snapshot of the health of all the available
// services.
//
// The server may respond with a RESOURCE_EXHAUSTED error if too many services
// exist.
//
// Clients should set a deadline when calling List, and can declare the server
// unhealthy if they do not receive a timely response.
//
// Clients should keep in mind that the list of health services exposed by an
// application can change over the lifetime of the process.
List(context.Context, *HealthListRequest) (*HealthListResponse, error)
// Performs a watch for the serving status of the requested service.
// The server will immediately send back a message indicating the current
// serving status. It will then subsequently send a new message whenever
@@ -150,19 +177,26 @@ type HealthServer interface {
// should assume this method is not supported and should not retry the
// call. If the call terminates with any other status (including OK),
// clients should retry the call with appropriate exponential backoff.
Watch(*HealthCheckRequest, Health_WatchServer) error
Watch(*HealthCheckRequest, grpc.ServerStreamingServer[HealthCheckResponse]) error
}
// UnimplementedHealthServer should be embedded to have forward compatible implementations.
type UnimplementedHealthServer struct {
}
// UnimplementedHealthServer should be embedded to have
// forward compatible implementations.
//
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
type UnimplementedHealthServer struct{}
func (UnimplementedHealthServer) Check(context.Context, *HealthCheckRequest) (*HealthCheckResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Check not implemented")
}
func (UnimplementedHealthServer) Watch(*HealthCheckRequest, Health_WatchServer) error {
func (UnimplementedHealthServer) List(context.Context, *HealthListRequest) (*HealthListResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method List not implemented")
}
func (UnimplementedHealthServer) Watch(*HealthCheckRequest, grpc.ServerStreamingServer[HealthCheckResponse]) error {
return status.Errorf(codes.Unimplemented, "method Watch not implemented")
}
func (UnimplementedHealthServer) testEmbeddedByValue() {}
// UnsafeHealthServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to HealthServer will
@@ -172,6 +206,13 @@ type UnsafeHealthServer interface {
}
func RegisterHealthServer(s grpc.ServiceRegistrar, srv HealthServer) {
// If the following call panics, it indicates UnimplementedHealthServer was
// embedded by pointer and is nil. This will cause panics if an
// unimplemented method is ever invoked, so we test this at initialization
// time to prevent it from happening at runtime later due to I/O.
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
t.testEmbeddedByValue()
}
s.RegisterService(&Health_ServiceDesc, srv)
}
@@ -193,26 +234,34 @@ func _Health_Check_Handler(srv interface{}, ctx context.Context, dec func(interf
return interceptor(ctx, in, info, handler)
}
func _Health_List_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(HealthListRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(HealthServer).List(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: Health_List_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(HealthServer).List(ctx, req.(*HealthListRequest))
}
return interceptor(ctx, in, info, handler)
}
func _Health_Watch_Handler(srv interface{}, stream grpc.ServerStream) error {
m := new(HealthCheckRequest)
if err := stream.RecvMsg(m); err != nil {
return err
}
return srv.(HealthServer).Watch(m, &healthWatchServer{stream})
return srv.(HealthServer).Watch(m, &grpc.GenericServerStream[HealthCheckRequest, HealthCheckResponse]{ServerStream: stream})
}
type Health_WatchServer interface {
Send(*HealthCheckResponse) error
grpc.ServerStream
}
type healthWatchServer struct {
grpc.ServerStream
}
func (x *healthWatchServer) Send(m *HealthCheckResponse) error {
return x.ServerStream.SendMsg(m)
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type Health_WatchServer = grpc.ServerStreamingServer[HealthCheckResponse]
// Health_ServiceDesc is the grpc.ServiceDesc for Health service.
// It's only intended for direct use with grpc.RegisterService,
@@ -225,6 +274,10 @@ var Health_ServiceDesc = grpc.ServiceDesc{
MethodName: "Check",
Handler: _Health_Check_Handler,
},
{
MethodName: "List",
Handler: _Health_List_Handler,
},
},
Streams: []grpc.StreamDesc{
{

View File

@@ -25,10 +25,10 @@ package backoff
import (
"context"
"errors"
rand "math/rand/v2"
"time"
grpcbackoff "google.golang.org/grpc/backoff"
"google.golang.org/grpc/internal/grpcrand"
)
// Strategy defines the methodology for backing off after a grpc connection
@@ -67,7 +67,7 @@ func (bc Exponential) Backoff(retries int) time.Duration {
}
// Randomize backoff delays so that if a cluster of requests start at
// the same time, they won't operate in lockstep.
backoff *= 1 + bc.Config.Jitter*(grpcrand.Float64()*2-1)
backoff *= 1 + bc.Config.Jitter*(rand.Float64()*2-1)
if backoff < 0 {
return 0
}
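As a worked check of the jitter line above: with Jitter = 0.2, rand.Float64()*2-1 is uniform in [-1, 1), so the multiplier falls in [0.8, 1.2) and a 1s delay becomes anywhere from 0.8s to just under 1.2s. A standalone sketch of the same computation:

import (
	rand "math/rand/v2"
	"time"
)

// jittered scales d by a factor uniform in [1-jitter, 1+jitter), mirroring
// the randomization in Exponential.Backoff above.
func jittered(d time.Duration, jitter float64) time.Duration {
	return time.Duration(float64(d) * (1 + jitter*(rand.Float64()*2-1)))
}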

View File

@@ -33,6 +33,8 @@ type lbConfig struct {
childConfig serviceconfig.LoadBalancingConfig
}
// ChildName returns the name of the child balancer of the gracefulswitch
// Balancer.
func ChildName(l serviceconfig.LoadBalancingConfig) string {
return l.(*lbConfig).childBuilder.Name()
}
@@ -75,7 +77,6 @@ func ParseConfig(cfg json.RawMessage) (serviceconfig.LoadBalancingConfig, error)
if err != nil {
return nil, fmt.Errorf("error parsing config for policy %q: %v", name, err)
}
return &lbConfig{childBuilder: builder, childConfig: cfg}, nil
}

View File

@@ -109,8 +109,9 @@ func (gsb *Balancer) switchTo(builder balancer.Builder) (*balancerWrapper, error
return nil, errBalancerClosed
}
bw := &balancerWrapper{
builder: builder,
gsb: gsb,
ClientConn: gsb.cc,
builder: builder,
gsb: gsb,
lastState: balancer.State{
ConnectivityState: connectivity.Connecting,
Picker: base.NewErrPicker(balancer.ErrNoSubConnAvailable),
@@ -169,7 +170,6 @@ func (gsb *Balancer) latestBalancer() *balancerWrapper {
func (gsb *Balancer) UpdateClientConnState(state balancer.ClientConnState) error {
// The resolver data is only relevant to the most recent LB Policy.
balToUpdate := gsb.latestBalancer()
gsbCfg, ok := state.BalancerConfig.(*lbConfig)
if ok {
// Switch to the child in the config unless it is already active.
@@ -294,6 +294,7 @@ func (gsb *Balancer) Close() {
// State updates from the wrapped balancer can result in invocation of the
// graceful switch logic.
type balancerWrapper struct {
balancer.ClientConn
balancer.Balancer
gsb *Balancer
builder balancer.Builder
@@ -414,7 +415,3 @@ func (bw *balancerWrapper) UpdateAddresses(sc balancer.SubConn, addrs []resolver
bw.gsb.mu.Unlock()
bw.gsb.cc.UpdateAddresses(sc, addrs)
}
func (bw *balancerWrapper) Target() string {
return bw.gsb.cc.Target()
}
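The Target override can be deleted because the wrapper now embeds balancer.ClientConn (initialized to gsb.cc in switchTo above): Go promotes the embedded interface's methods, so Target and any other method not explicitly overridden forwards automatically. A generic sketch of the pattern, with hypothetical names:

import "google.golang.org/grpc/balancer"

// wrapper relies on method promotion: calls to Target (and anything else
// not redefined on wrapper) go straight to the embedded ClientConn.
type wrapper struct {
	balancer.ClientConn
}

func newWrapper(cc balancer.ClientConn) *wrapper { return &wrapper{ClientConn: cc} }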

View File

@@ -65,7 +65,7 @@ type TruncatingMethodLogger struct {
callID uint64
idWithinCallGen *callIDGenerator
sink Sink // TODO(blog): make this plugable.
sink Sink // TODO(blog): make this pluggable.
}
// NewTruncatingMethodLogger returns a new truncating method logger.
@@ -80,7 +80,7 @@ func NewTruncatingMethodLogger(h, m uint64) *TruncatingMethodLogger {
callID: idGen.next(),
idWithinCallGen: &callIDGenerator{},
sink: DefaultSink, // TODO(blog): make it plugable.
sink: DefaultSink, // TODO(blog): make it pluggable.
}
}
@@ -106,7 +106,7 @@ func (ml *TruncatingMethodLogger) Build(c LogEntryConfig) *binlogpb.GrpcLogEntry
}
// Log creates a proto binary log entry, and logs it to the sink.
func (ml *TruncatingMethodLogger) Log(ctx context.Context, c LogEntryConfig) {
func (ml *TruncatingMethodLogger) Log(_ context.Context, c LogEntryConfig) {
ml.sink.Write(ml.Build(c))
}
@@ -397,7 +397,7 @@ func metadataKeyOmit(key string) bool {
switch key {
case "lb-token", ":path", ":authority", "content-encoding", "content-type", "user-agent", "te":
return true
case "grpc-trace-bin": // grpc-trace-bin is special because it's visiable to users.
case "grpc-trace-bin": // grpc-trace-bin is special because it's visible to users.
return false
}
return strings.HasPrefix(key, "grpc-")

View File

@@ -43,6 +43,8 @@ type Channel struct {
// Non-zero traceRefCount means the trace of this channel cannot be deleted.
traceRefCount int32
// ChannelMetrics holds connectivity state, target and call metrics for the
// channel within channelz.
ChannelMetrics ChannelMetrics
}
@@ -50,6 +52,8 @@ type Channel struct {
// nesting.
func (c *Channel) channelzIdentifier() {}
// String returns a string representation of the Channel, including its parent
// entity and ID.
func (c *Channel) String() string {
if c.Parent == nil {
return fmt.Sprintf("Channel #%d", c.ID)
@@ -61,24 +65,31 @@ func (c *Channel) id() int64 {
return c.ID
}
// SubChans returns a copy of the map of sub-channels associated with the
// Channel.
func (c *Channel) SubChans() map[int64]string {
db.mu.RLock()
defer db.mu.RUnlock()
return copyMap(c.subChans)
}
// NestedChans returns a copy of the map of nested channels associated with the
// Channel.
func (c *Channel) NestedChans() map[int64]string {
db.mu.RLock()
defer db.mu.RUnlock()
return copyMap(c.nestedChans)
}
// Trace returns a copy of the Channel's trace data.
func (c *Channel) Trace() *ChannelTrace {
db.mu.RLock()
defer db.mu.RUnlock()
return c.trace.copy()
}
// ChannelMetrics holds connectivity state, target and call metrics for the
// channel within channelz.
type ChannelMetrics struct {
// The current connectivity state of the channel.
State atomic.Pointer[connectivity.State]
@@ -136,12 +147,16 @@ func strFromPointer(s *string) string {
return *s
}
// String returns a string representation of the ChannelMetrics, including its
// state, target, and call metrics.
func (c *ChannelMetrics) String() string {
return fmt.Sprintf("State: %v, Target: %s, CallsStarted: %v, CallsSucceeded: %v, CallsFailed: %v, LastCallStartedTimestamp: %v",
c.State.Load(), strFromPointer(c.Target.Load()), c.CallsStarted.Load(), c.CallsSucceeded.Load(), c.CallsFailed.Load(), c.LastCallStartedTimestamp.Load(),
)
}
// NewChannelMetricForTesting creates a new instance of ChannelMetrics with
// specified initial values for testing purposes.
func NewChannelMetricForTesting(state connectivity.State, target string, started, succeeded, failed, timestamp int64) *ChannelMetrics {
c := &ChannelMetrics{}
c.State.Store(&state)

View File

@@ -46,7 +46,7 @@ type entry interface {
// channelMap is the storage data structure for channelz.
//
// Methods of channelMap can be divided in two two categories with respect to
// Methods of channelMap can be divided into two categories with respect to
// locking.
//
// 1. Methods acquire the global lock.
@@ -234,13 +234,6 @@ func copyMap(m map[int64]string) map[int64]string {
return n
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
func (c *channelMap) getTopChannels(id int64, maxResults int) ([]*Channel, bool) {
if maxResults <= 0 {
maxResults = EntriesPerPage
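The deleted `min` helper is redundant on modern Go: `min` and `max` have been predeclared builtins since Go 1.21, so call sites like the paging logic here can use them directly. A trivial sketch:

```go
package main

import "fmt"

func main() {
	// Predeclared builtin since Go 1.21; no local helper needed.
	fmt.Println(min(3, 7)) // 3
}
```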

View File

@@ -33,7 +33,7 @@ var (
// outside this package except by tests.
IDGen IDGenerator
db *channelMap = newChannelMap()
db = newChannelMap()
// EntriesPerPage defines the number of channelz entries to be shown on a web page.
EntriesPerPage = 50
curState int32

View File

@@ -59,6 +59,8 @@ func NewServerMetricsForTesting(started, succeeded, failed, timestamp int64) *Se
return sm
}
// CopyFrom copies the metrics data from the provided ServerMetrics
// instance into the current instance.
func (sm *ServerMetrics) CopyFrom(o *ServerMetrics) {
sm.CallsStarted.Store(o.CallsStarted.Load())
sm.CallsSucceeded.Store(o.CallsSucceeded.Load())

View File

@@ -70,13 +70,18 @@ type EphemeralSocketMetrics struct {
RemoteFlowControlWindow int64
}
// SocketType represents the type of socket.
type SocketType string
// SocketType can be one of these.
const (
SocketTypeNormal = "NormalSocket"
SocketTypeListen = "ListenSocket"
)
// Socket represents a socket within channelz which includes socket
// metrics and data related to socket activity and provides methods
// for managing and interacting with sockets.
type Socket struct {
Entity
SocketType SocketType
@@ -100,6 +105,8 @@ type Socket struct {
Security credentials.ChannelzSecurityValue
}
// String returns a string representation of the Socket, including its parent
// entity, socket type, and ID.
func (ls *Socket) String() string {
return fmt.Sprintf("%s %s #%d", ls.Parent, ls.SocketType, ls.ID)
}

View File

@@ -47,12 +47,14 @@ func (sc *SubChannel) id() int64 {
return sc.ID
}
// Sockets returns a copy of the sockets map associated with the SubChannel.
func (sc *SubChannel) Sockets() map[int64]string {
db.mu.RLock()
defer db.mu.RUnlock()
return copyMap(sc.sockets)
}
// Trace returns a copy of the ChannelTrace associated with the SubChannel.
func (sc *SubChannel) Trace() *ChannelTrace {
db.mu.RLock()
defer db.mu.RUnlock()

View File

@@ -35,13 +35,13 @@ type SocketOptionData struct {
// Getsockopt defines the function to get socket options requested by channelz.
// It is to be passed to syscall.RawConn.Control().
// Windows OS doesn't support Socket Option
func (s *SocketOptionData) Getsockopt(fd uintptr) {
func (s *SocketOptionData) Getsockopt(uintptr) {
once.Do(func() {
logger.Warning("Channelz: socket options are not supported on non-linux environments")
})
}
// GetSocketOption gets the socket option info of the conn.
func GetSocketOption(c any) *SocketOptionData {
func GetSocketOption(any) *SocketOptionData {
return nil
}

View File

@@ -79,13 +79,21 @@ type TraceEvent struct {
Parent *TraceEvent
}
// ChannelTrace provides tracing information for a channel.
// It tracks various events and metadata related to the channel's lifecycle
// and operations.
type ChannelTrace struct {
cm *channelMap
clearCalled bool
cm *channelMap
clearCalled bool
// The time when the trace was created.
CreationTime time.Time
EventNum int64
mu sync.Mutex
Events []*traceEvent
// A counter for the number of events recorded in the
// trace.
EventNum int64
mu sync.Mutex
// A slice of traceEvent pointers representing the events recorded for
// this channel.
Events []*traceEvent
}
func (c *ChannelTrace) copy() *ChannelTrace {
@@ -175,6 +183,7 @@ var refChannelTypeToString = map[RefChannelType]string{
RefNormalSocket: "NormalSocket",
}
// String returns a string representation of the RefChannelType
func (r RefChannelType) String() string {
return refChannelTypeToString[r]
}

View File

@@ -28,9 +28,6 @@ import (
var (
// TXTErrIgnore is set if TXT errors should be ignored ("GRPC_GO_IGNORE_TXT_ERRORS" is not "false").
TXTErrIgnore = boolFromEnv("GRPC_GO_IGNORE_TXT_ERRORS", true)
// AdvertiseCompressors is set if registered compressor should be advertised
// ("GRPC_GO_ADVERTISE_COMPRESSORS" is not "false").
AdvertiseCompressors = boolFromEnv("GRPC_GO_ADVERTISE_COMPRESSORS", true)
// RingHashCap indicates the maximum ring size which defaults to 4096
// entries but may be overridden by setting the environment variable
// "GRPC_RING_HASH_CAP". This does not override the default bounds
@@ -43,6 +40,35 @@ var (
// ALTSMaxConcurrentHandshakes is the maximum number of concurrent ALTS
// handshakes that can be performed.
ALTSMaxConcurrentHandshakes = uint64FromEnv("GRPC_ALTS_MAX_CONCURRENT_HANDSHAKES", 100, 1, 100)
// EnforceALPNEnabled is set if TLS connections to servers with ALPN disabled
// should be rejected. The HTTP/2 protocol requires ALPN to be enabled, this
// option is present for backward compatibility. This option may be overridden
// by setting the environment variable "GRPC_ENFORCE_ALPN_ENABLED" to "true"
// or "false".
EnforceALPNEnabled = boolFromEnv("GRPC_ENFORCE_ALPN_ENABLED", true)
// XDSFallbackSupport is the env variable that controls whether support for
// xDS fallback is turned on. If this is unset or is false, only the first
// xDS server in the list of server configs will be used.
XDSFallbackSupport = boolFromEnv("GRPC_EXPERIMENTAL_XDS_FALLBACK", true)
// NewPickFirstEnabled is set if the new pickfirst leaf policy is to be used
// instead of the existing pickfirst implementation. This can be disabled by
// setting the environment variable "GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST"
// to "false".
NewPickFirstEnabled = boolFromEnv("GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST", true)
// XDSEndpointHashKeyBackwardCompat controls the parsing of the endpoint hash
// key from EDS LbEndpoint metadata. Endpoint hash keys can be disabled by
// setting "GRPC_XDS_ENDPOINT_HASH_KEY_BACKWARD_COMPAT" to "true". When the
// implementation of A76 is stable, we will flip the default value to false
// in a subsequent release. A final release will remove this environment
// variable, enabling the new behavior unconditionally.
XDSEndpointHashKeyBackwardCompat = boolFromEnv("GRPC_XDS_ENDPOINT_HASH_KEY_BACKWARD_COMPAT", true)
// RingHashSetRequestHashKey is set if the ring hash balancer can get the
// request hash header by setting the "requestHashHeader" field, according
// to gRFC A76. It can be enabled by setting the environment variable
// "GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY" to "true".
RingHashSetRequestHashKey = boolFromEnv("GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY", false)
)
func boolFromEnv(envVar string, def bool) bool {
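`boolFromEnv` is truncated in this hunk. Judging from the doc comments above (a true default flips only on an explicit "false", a false default flips only on an explicit "true"), a plausible sketch looks like this; treat it as an illustration rather than the vendored source:

```go
// Sketch only; assumes "os" and "strings" are imported.
func boolFromEnv(envVar string, def bool) bool {
	if def {
		// Default true: stay true unless explicitly set to "false".
		return !strings.EqualFold(os.Getenv(envVar), "false")
	}
	// Default false: become true only when explicitly set to "true".
	return strings.EqualFold(os.Getenv(envVar), "true")
}
```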

View File

@@ -53,4 +53,14 @@ var (
// C2PResolverTestOnlyTrafficDirectorURI is the TD URI for testing.
C2PResolverTestOnlyTrafficDirectorURI = os.Getenv("GRPC_TEST_ONLY_GOOGLE_C2P_RESOLVER_TRAFFIC_DIRECTOR_URI")
// XDSDualstackEndpointsEnabled is true if gRPC should read the
// "additional addresses" in the xDS endpoint resource.
XDSDualstackEndpointsEnabled = boolFromEnv("GRPC_EXPERIMENTAL_XDS_DUALSTACK_ENDPOINTS", true)
// XDSSystemRootCertsEnabled is true when xDS enabled gRPC clients can use
// the system's default root certificates for TLS certificate validation.
// For more details, see:
// https://github.com/grpc/proposal/blob/master/A82-xds-system-root-certs.md.
XDSSystemRootCertsEnabled = boolFromEnv("GRPC_EXPERIMENTAL_XDS_SYSTEM_ROOT_CERTS", false)
)

View File

@@ -18,11 +18,11 @@
package internal
var (
// WithRecvBufferPool is implemented by the grpc package and returns a dial
// WithBufferPool is implemented by the grpc package and returns a dial
// option to configure a shared buffer pool for a grpc.ClientConn.
WithRecvBufferPool any // func (grpc.SharedBufferPool) grpc.DialOption
WithBufferPool any // func (grpc.SharedBufferPool) grpc.DialOption
// RecvBufferPool is implemented by the grpc package and returns a server
// BufferPool is implemented by the grpc package and returns a server
// option to configure a shared buffer pool for a grpc.Server.
RecvBufferPool any // func (grpc.SharedBufferPool) grpc.ServerOption
BufferPool any // func (grpc.SharedBufferPool) grpc.ServerOption
)

View File

@@ -1,126 +0,0 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package grpclog (internal) defines depth logging for grpc.
package grpclog
import (
"os"
)
// Logger is the logger used for the non-depth log functions.
var Logger LoggerV2
// DepthLogger is the logger used for the depth log functions.
var DepthLogger DepthLoggerV2
// InfoDepth logs to the INFO log at the specified depth.
func InfoDepth(depth int, args ...any) {
if DepthLogger != nil {
DepthLogger.InfoDepth(depth, args...)
} else {
Logger.Infoln(args...)
}
}
// WarningDepth logs to the WARNING log at the specified depth.
func WarningDepth(depth int, args ...any) {
if DepthLogger != nil {
DepthLogger.WarningDepth(depth, args...)
} else {
Logger.Warningln(args...)
}
}
// ErrorDepth logs to the ERROR log at the specified depth.
func ErrorDepth(depth int, args ...any) {
if DepthLogger != nil {
DepthLogger.ErrorDepth(depth, args...)
} else {
Logger.Errorln(args...)
}
}
// FatalDepth logs to the FATAL log at the specified depth.
func FatalDepth(depth int, args ...any) {
if DepthLogger != nil {
DepthLogger.FatalDepth(depth, args...)
} else {
Logger.Fatalln(args...)
}
os.Exit(1)
}
// LoggerV2 does underlying logging work for grpclog.
// This is a copy of the LoggerV2 defined in the external grpclog package. It
// is defined here to avoid a circular dependency.
type LoggerV2 interface {
// Info logs to INFO log. Arguments are handled in the manner of fmt.Print.
Info(args ...any)
// Infoln logs to INFO log. Arguments are handled in the manner of fmt.Println.
Infoln(args ...any)
// Infof logs to INFO log. Arguments are handled in the manner of fmt.Printf.
Infof(format string, args ...any)
// Warning logs to WARNING log. Arguments are handled in the manner of fmt.Print.
Warning(args ...any)
// Warningln logs to WARNING log. Arguments are handled in the manner of fmt.Println.
Warningln(args ...any)
// Warningf logs to WARNING log. Arguments are handled in the manner of fmt.Printf.
Warningf(format string, args ...any)
// Error logs to ERROR log. Arguments are handled in the manner of fmt.Print.
Error(args ...any)
// Errorln logs to ERROR log. Arguments are handled in the manner of fmt.Println.
Errorln(args ...any)
// Errorf logs to ERROR log. Arguments are handled in the manner of fmt.Printf.
Errorf(format string, args ...any)
// Fatal logs to ERROR log. Arguments are handled in the manner of fmt.Print.
// gRPC ensures that all Fatal logs will exit with os.Exit(1).
// Implementations may also call os.Exit() with a non-zero exit code.
Fatal(args ...any)
// Fatalln logs to ERROR log. Arguments are handled in the manner of fmt.Println.
// gRPC ensures that all Fatal logs will exit with os.Exit(1).
// Implementations may also call os.Exit() with a non-zero exit code.
Fatalln(args ...any)
// Fatalf logs to ERROR log. Arguments are handled in the manner of fmt.Printf.
// gRPC ensures that all Fatal logs will exit with os.Exit(1).
// Implementations may also call os.Exit() with a non-zero exit code.
Fatalf(format string, args ...any)
// V reports whether verbosity level l is at least the requested verbose level.
V(l int) bool
}
// DepthLoggerV2 logs at a specified call frame. If a LoggerV2 also implements
// DepthLoggerV2, the below functions will be called with the appropriate stack
// depth set for trivial functions the logger may ignore.
// This is a copy of the DepthLoggerV2 defined in the external grpclog package.
// It is defined here to avoid a circular dependency.
//
// # Experimental
//
// Notice: This type is EXPERIMENTAL and may be changed or removed in a
// later release.
type DepthLoggerV2 interface {
// InfoDepth logs to INFO log at the specified depth. Arguments are handled in the manner of fmt.Println.
InfoDepth(depth int, args ...any)
// WarningDepth logs to WARNING log at the specified depth. Arguments are handled in the manner of fmt.Println.
WarningDepth(depth int, args ...any)
// ErrorDepth logs to ERROR log at the specified depth. Arguments are handled in the manner of fmt.Println.
ErrorDepth(depth int, args ...any)
// FatalDepth logs to FATAL log at the specified depth. Arguments are handled in the manner of fmt.Println.
FatalDepth(depth int, args ...any)
}

View File

@@ -16,17 +16,21 @@
*
*/
// Package grpclog provides logging functionality for internal gRPC packages,
// outside of the functionality provided by the external `grpclog` package.
package grpclog
import (
"fmt"
"google.golang.org/grpc/grpclog"
)
// PrefixLogger does logging with a prefix.
//
// Logging methods on a nil PrefixLogger log without any prefix.
type PrefixLogger struct {
logger DepthLoggerV2
logger grpclog.DepthLoggerV2
prefix string
}
@@ -38,7 +42,7 @@ func (pl *PrefixLogger) Infof(format string, args ...any) {
pl.logger.InfoDepth(1, fmt.Sprintf(format, args...))
return
}
InfoDepth(1, fmt.Sprintf(format, args...))
grpclog.InfoDepth(1, fmt.Sprintf(format, args...))
}
// Warningf does warning logging.
@@ -48,7 +52,7 @@ func (pl *PrefixLogger) Warningf(format string, args ...any) {
pl.logger.WarningDepth(1, fmt.Sprintf(format, args...))
return
}
WarningDepth(1, fmt.Sprintf(format, args...))
grpclog.WarningDepth(1, fmt.Sprintf(format, args...))
}
// Errorf does error logging.
@@ -58,36 +62,18 @@ func (pl *PrefixLogger) Errorf(format string, args ...any) {
pl.logger.ErrorDepth(1, fmt.Sprintf(format, args...))
return
}
ErrorDepth(1, fmt.Sprintf(format, args...))
}
// Debugf does info logging at verbose level 2.
func (pl *PrefixLogger) Debugf(format string, args ...any) {
// TODO(6044): Refactor interfaces LoggerV2 and DepthLogger, and maybe
// rewrite PrefixLogger a little to ensure that we don't use the global
// `Logger` here, and instead use the `logger` field.
if !Logger.V(2) {
return
}
if pl != nil {
// Handle nil, so the tests can pass in a nil logger.
format = pl.prefix + format
pl.logger.InfoDepth(1, fmt.Sprintf(format, args...))
return
}
InfoDepth(1, fmt.Sprintf(format, args...))
grpclog.ErrorDepth(1, fmt.Sprintf(format, args...))
}
// V reports whether verbosity level l is at least the requested verbose level.
func (pl *PrefixLogger) V(l int) bool {
// TODO(6044): Refactor interfaces LoggerV2 and DepthLogger, and maybe
// rewrite PrefixLogger a little to ensure that we don't use the global
// `Logger` here, and instead use the `logger` field.
return Logger.V(l)
if pl != nil {
return pl.logger.V(l)
}
return true
}
// NewPrefixLogger creates a prefix logger with the given prefix.
func NewPrefixLogger(logger DepthLoggerV2, prefix string) *PrefixLogger {
func NewPrefixLogger(logger grpclog.DepthLoggerV2, prefix string) *PrefixLogger {
return &PrefixLogger{logger: logger, prefix: prefix}
}
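A short usage sketch of the reworked constructor; the package is internal to grpc-go, so this compiles only inside the module, and the prefix string is made up:

```go
import (
	"google.golang.org/grpc/grpclog"
	internalgrpclog "google.golang.org/grpc/internal/grpclog"
)

func example() {
	// The public grpclog.Component returns a DepthLoggerV2, which is what
	// NewPrefixLogger now accepts directly.
	logger := grpclog.Component("delegating-resolver")
	pl := internalgrpclog.NewPrefixLogger(logger, "[resolver ] ")
	pl.Infof("built for target %q", "dns:///example.com")
}
```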

View File

@@ -1,100 +0,0 @@
//go:build !go1.21
// TODO: when this file is deleted (after Go 1.20 support is dropped), delete
// all of grpcrand and call the rand package directly.
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package grpcrand implements math/rand functions in a concurrent-safe way
// with a global random source, independent of math/rand's global source.
package grpcrand
import (
"math/rand"
"sync"
"time"
)
var (
r = rand.New(rand.NewSource(time.Now().UnixNano()))
mu sync.Mutex
)
// Int implements rand.Int on the grpcrand global source.
func Int() int {
mu.Lock()
defer mu.Unlock()
return r.Int()
}
// Int63n implements rand.Int63n on the grpcrand global source.
func Int63n(n int64) int64 {
mu.Lock()
defer mu.Unlock()
return r.Int63n(n)
}
// Intn implements rand.Intn on the grpcrand global source.
func Intn(n int) int {
mu.Lock()
defer mu.Unlock()
return r.Intn(n)
}
// Int31n implements rand.Int31n on the grpcrand global source.
func Int31n(n int32) int32 {
mu.Lock()
defer mu.Unlock()
return r.Int31n(n)
}
// Float64 implements rand.Float64 on the grpcrand global source.
func Float64() float64 {
mu.Lock()
defer mu.Unlock()
return r.Float64()
}
// Uint64 implements rand.Uint64 on the grpcrand global source.
func Uint64() uint64 {
mu.Lock()
defer mu.Unlock()
return r.Uint64()
}
// Uint32 implements rand.Uint32 on the grpcrand global source.
func Uint32() uint32 {
mu.Lock()
defer mu.Unlock()
return r.Uint32()
}
// ExpFloat64 implements rand.ExpFloat64 on the grpcrand global source.
func ExpFloat64() float64 {
mu.Lock()
defer mu.Unlock()
return r.ExpFloat64()
}
// Shuffle implements rand.Shuffle on the grpcrand global source.
var Shuffle = func(n int, f func(int, int)) {
mu.Lock()
defer mu.Unlock()
r.Shuffle(n, f)
}

View File

@@ -1,73 +0,0 @@
//go:build go1.21
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package grpcrand implements math/rand functions in a concurrent-safe way
// with a global random source, independent of math/rand's global source.
package grpcrand
import "math/rand"
// This implementation will be used for Go version 1.21 or newer.
// For older versions, the original implementation with mutex will be used.
// Int implements rand.Int on the grpcrand global source.
func Int() int {
return rand.Int()
}
// Int63n implements rand.Int63n on the grpcrand global source.
func Int63n(n int64) int64 {
return rand.Int63n(n)
}
// Intn implements rand.Intn on the grpcrand global source.
func Intn(n int) int {
return rand.Intn(n)
}
// Int31n implements rand.Int31n on the grpcrand global source.
func Int31n(n int32) int32 {
return rand.Int31n(n)
}
// Float64 implements rand.Float64 on the grpcrand global source.
func Float64() float64 {
return rand.Float64()
}
// Uint64 implements rand.Uint64 on the grpcrand global source.
func Uint64() uint64 {
return rand.Uint64()
}
// Uint32 implements rand.Uint32 on the grpcrand global source.
func Uint32() uint32 {
return rand.Uint32()
}
// ExpFloat64 implements rand.ExpFloat64 on the grpcrand global source.
func ExpFloat64() float64 {
return rand.ExpFloat64()
}
// Shuffle implements rand.Shuffle on the grpcrand global source.
var Shuffle = func(n int, f func(int, int)) {
rand.Shuffle(n, f)
}

View File

@@ -53,16 +53,28 @@ func NewCallbackSerializer(ctx context.Context) *CallbackSerializer {
return cs
}
// Schedule adds a callback to be scheduled after existing callbacks are run.
// TrySchedule tries to schedule the provided callback function f to be
// executed in the order it was added. This is a best-effort operation. If the
// context passed to NewCallbackSerializer was canceled before this method is
// called, the callback will not be scheduled.
//
// Callbacks are expected to honor the context when performing any blocking
// operations, and should return early when the context is canceled.
func (cs *CallbackSerializer) TrySchedule(f func(ctx context.Context)) {
cs.callbacks.Put(f)
}
// ScheduleOr schedules the provided callback function f to be executed in the
// order it was added. If the context passed to NewCallbackSerializer has been
// canceled before this method is called, the onFailure callback will be
// executed inline instead.
//
// Return value indicates if the callback was successfully added to the list of
// callbacks to be executed by the serializer. It is not possible to add
// callbacks once the context passed to NewCallbackSerializer is cancelled.
func (cs *CallbackSerializer) Schedule(f func(ctx context.Context)) bool {
return cs.callbacks.Put(f) == nil
// Callbacks are expected to honor the context when performing any blocking
// operations, and should return early when the context is canceled.
func (cs *CallbackSerializer) ScheduleOr(f func(ctx context.Context), onFailure func()) {
if cs.callbacks.Put(f) != nil {
onFailure()
}
}
func (cs *CallbackSerializer) run(ctx context.Context) {
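The rename makes the best-effort nature explicit: TrySchedule drops the callback silently after cancellation, while ScheduleOr guarantees one of the two functions runs. A usage sketch (grpcsync is internal to grpc-go, so this is illustrative):

```go
import (
	"context"
	"fmt"

	"google.golang.org/grpc/internal/grpcsync"
)

func example() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	cs := grpcsync.NewCallbackSerializer(ctx)

	// Best effort: silently dropped if ctx was already canceled.
	cs.TrySchedule(func(context.Context) { fmt.Println("first") })

	// Guaranteed notification: onFailure runs inline if scheduling fails.
	cs.ScheduleOr(
		func(context.Context) { fmt.Println("second") },
		func() { fmt.Println("serializer closed; fallback ran") },
	)
}
```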

View File

@@ -77,7 +77,7 @@ func (ps *PubSub) Subscribe(sub Subscriber) (cancel func()) {
if ps.msg != nil {
msg := ps.msg
ps.cs.Schedule(func(context.Context) {
ps.cs.TrySchedule(func(context.Context) {
ps.mu.Lock()
defer ps.mu.Unlock()
if !ps.subscribers[sub] {
@@ -103,7 +103,7 @@ func (ps *PubSub) Publish(msg any) {
ps.msg = msg
for sub := range ps.subscribers {
s := sub
ps.cs.Schedule(func(context.Context) {
ps.cs.TrySchedule(func(context.Context) {
ps.mu.Lock()
defer ps.mu.Unlock()
if !ps.subscribers[s] {

View File

@@ -20,8 +20,6 @@ package grpcutil
import (
"strings"
"google.golang.org/grpc/internal/envconfig"
)
// RegisteredCompressorNames holds names of the registered compressors.
@@ -40,8 +38,5 @@ func IsCompressorNameRegistered(name string) bool {
// RegisteredCompressors returns a string of registered compressor names
// separated by comma.
func RegisteredCompressors() string {
if !envconfig.AdvertiseCompressors {
return ""
}
return strings.Join(RegisteredCompressorNames, ",")
}

View File

@@ -39,7 +39,7 @@ func ParseMethod(methodName string) (service, method string, _ error) {
}
// baseContentType is the base content-type for gRPC. This is a valid
// content-type on it's own, but can also include a content-subtype such as
// content-type on its own, but can also include a content-subtype such as
// "proto" as a suffix after "+" or ";". See
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests
// for more details.

View File

@@ -182,6 +182,7 @@ func (m *Manager) tryEnterIdleMode() bool {
return true
}
// EnterIdleModeForTesting instructs the channel to enter idle mode.
func (m *Manager) EnterIdleModeForTesting() {
m.tryEnterIdleMode()
}
@@ -225,7 +226,7 @@ func (m *Manager) ExitIdleMode() error {
// came in and OnCallBegin() noticed that the calls count is negative.
// - Channel is in idle mode, and multiple new RPCs come in at the same
// time, all of them notice a negative calls count in OnCallBegin and get
// here. The first one to get the lock would got the channel to exit idle.
// here. The first one to get the lock would get the channel to exit idle.
// - Channel is not in idle mode, and the user calls Connect which calls
// m.ExitIdleMode.
//
@@ -266,6 +267,7 @@ func (m *Manager) isClosed() bool {
return atomic.LoadInt32(&m.closed) == 1
}
// Close stops the timer associated with the Manager, if it exists.
func (m *Manager) Close() {
atomic.StoreInt32(&m.closed, 1)

View File

@@ -29,10 +29,12 @@ import (
)
var (
// WithHealthCheckFunc is set by dialoptions.go
WithHealthCheckFunc any // func (HealthChecker) DialOption
// HealthCheckFunc is used to provide client-side LB channel health checking
HealthCheckFunc HealthChecker
// RegisterClientHealthCheckListener is used to provide a listener for
// updates from the client-side health checking service. It returns a
// function that can be called to stop the health producer.
RegisterClientHealthCheckListener any // func(ctx context.Context, sc balancer.SubConn, serviceName string, listener func(balancer.SubConnState)) func()
// BalancerUnregister is exported by package balancer to unregister a balancer.
BalancerUnregister func(name string)
// KeepaliveMinPingTime is the minimum ping interval. This must be 10s by
@@ -62,6 +64,9 @@ var (
// gRPC server. An xDS-enabled server needs to know what type of credentials
// is configured on the underlying gRPC server. This is set by server.go.
GetServerCredentials any // func (*grpc.Server) credentials.TransportCredentials
// MetricsRecorderForServer returns the MetricsRecorderList derived from a
// server's stats handlers.
MetricsRecorderForServer any // func (*grpc.Server) estats.MetricsRecorder
// CanonicalString returns the canonical string of the code defined here:
// https://github.com/grpc/grpc/blob/master/doc/statuscodes.md.
//
@@ -106,6 +111,14 @@ var (
// This is used in the 1.0 release of gcp/observability, and thus must not be
// deleted or changed.
ClearGlobalDialOptions func()
// AddGlobalPerTargetDialOptions adds a PerTargetDialOption that will be
// configured for newly created ClientConns.
AddGlobalPerTargetDialOptions any // func (opt any)
// ClearGlobalPerTargetDialOptions clears the slice of global late apply
// dial options.
ClearGlobalPerTargetDialOptions func()
// JoinDialOptions combines the dial options passed as arguments into a
// single dial option.
JoinDialOptions any // func(...grpc.DialOption) grpc.DialOption
@@ -126,7 +139,8 @@ var (
// deleted or changed.
BinaryLogger any // func(binarylog.Logger) grpc.ServerOption
// SubscribeToConnectivityStateChanges adds a grpcsync.Subscriber to a provided grpc.ClientConn
// SubscribeToConnectivityStateChanges adds a grpcsync.Subscriber to a
// provided grpc.ClientConn.
SubscribeToConnectivityStateChanges any // func(*grpc.ClientConn, grpcsync.Subscriber)
// NewXDSResolverWithConfigForTesting creates a new xds resolver builder using
@@ -140,6 +154,34 @@ var (
// other features, including the CSDS service.
NewXDSResolverWithConfigForTesting any // func([]byte) (resolver.Builder, error)
// NewXDSResolverWithPoolForTesting creates a new xDS resolver builder
// using the provided xDS pool instead of creating a new one using the
// bootstrap configuration specified by the supported environment variables.
// The resolver.Builder is meant to be used in conjunction with the
// grpc.WithResolvers DialOption. The resolver.Builder does not take
// ownership of the provided xDS client and it is the responsibility of the
// caller to close the client when no longer required.
//
// Testing Only
//
// This function should ONLY be used for testing and may not work with some
// other features, including the CSDS service.
NewXDSResolverWithPoolForTesting any // func(*xdsclient.Pool) (resolver.Builder, error)
// NewXDSResolverWithClientForTesting creates a new xDS resolver builder
// using the provided xDS client instead of creating a new one using the
// bootstrap configuration specified by the supported environment variables.
// The resolver.Builder is meant to be used in conjunction with the
// grpc.WithResolvers DialOption. The resolver.Builder does not take
// ownership of the provided xDS client and it is the responsibility of the
// caller to close the client when no longer required.
//
// Testing Only
//
// This function should ONLY be used for testing and may not work with some
// other features, including the CSDS service.
NewXDSResolverWithClientForTesting any // func(xdsclient.XDSClient) (resolver.Builder, error)
// RegisterRLSClusterSpecifierPluginForTesting registers the RLS Cluster
// Specifier Plugin for testing purposes, regardless of the XDSRLS environment
// variable.
@@ -174,7 +216,7 @@ var (
// GRPCResolverSchemeExtraMetadata determines when gRPC will add extra
// metadata to RPCs.
GRPCResolverSchemeExtraMetadata string = "xds"
GRPCResolverSchemeExtraMetadata = "xds"
// EnterIdleModeForTesting gets the ClientConn to enter IDLE mode.
EnterIdleModeForTesting any // func(*grpc.ClientConn)
@@ -182,31 +224,56 @@ var (
// ExitIdleModeForTesting gets the ClientConn to exit IDLE mode.
ExitIdleModeForTesting any // func(*grpc.ClientConn) error
// ChannelzTurnOffForTesting disables the Channelz service for testing
// purposes.
ChannelzTurnOffForTesting func()
// TriggerXDSResourceNameNotFoundForTesting triggers the resource-not-found
// error for a given resource type and name. This is usually triggered when
// the associated watch timer fires. For testing purposes, having this
// function makes events more predictable than relying on timer events.
TriggerXDSResourceNameNotFoundForTesting any // func(func(xdsresource.Type, string), string, string) error
// TriggerXDSResourceNotFoundForTesting causes the provided xDS Client to
// invoke resource-not-found error for the given resource type and name.
TriggerXDSResourceNotFoundForTesting any // func(xdsclient.XDSClient, xdsresource.Type, string) error
// TriggerXDSResourceNameNotFoundClient invokes the testing xDS Client
// singleton to invoke resource not found for a resource type name and
// resource name.
TriggerXDSResourceNameNotFoundClient any // func(string, string) error
// FromOutgoingContextRaw returns the un-merged, intermediary contents of metadata.rawMD.
// FromOutgoingContextRaw returns the un-merged, intermediary contents of
// metadata.rawMD.
FromOutgoingContextRaw any // func(context.Context) (metadata.MD, [][]string, bool)
// UserSetDefaultScheme is set to true if the user has overridden the default resolver scheme.
UserSetDefaultScheme bool = false
// UserSetDefaultScheme is set to true if the user has overridden the
// default resolver scheme.
UserSetDefaultScheme = false
// ConnectedAddress returns the connected address for a SubConnState. The
// address is only valid if the state is READY.
ConnectedAddress any // func (scs SubConnState) resolver.Address
// SetConnectedAddress sets the connected address for a SubConnState.
SetConnectedAddress any // func(scs *SubConnState, addr resolver.Address)
// SnapshotMetricRegistryForTesting snapshots the global data of the metric
// registry. Returns a cleanup function that sets the metric registry to its
// original state. Only called in testing functions.
SnapshotMetricRegistryForTesting func() func()
// SetDefaultBufferPoolForTesting updates the default buffer pool, for
// testing purposes.
SetDefaultBufferPoolForTesting any // func(mem.BufferPool)
// SetBufferPoolingThresholdForTesting updates the buffer pooling threshold, for
// testing purposes.
SetBufferPoolingThresholdForTesting any // func(int)
// TimeAfterFunc is used to create timers. During tests the function is
// replaced to track allocated timers and fail the test if a timer isn't
// cancelled.
TimeAfterFunc = func(d time.Duration, f func()) Timer {
return time.AfterFunc(d, f)
}
)
// HealthChecker defines the signature of the client-side LB channel health checking function.
// HealthChecker defines the signature of the client-side LB channel health
// checking function.
//
// The implementation is expected to create a health checking RPC stream by
// calling newStream(), watch for the health status of serviceName, and report
// it's health back by calling setConnectivityState().
// its health back by calling setConnectivityState().
//
// The health checking protocol is defined at:
// https://github.com/grpc/grpc/blob/master/doc/health-checking.md
@@ -228,3 +295,21 @@ const (
// It currently has an experimental suffix which would be removed once
// end-to-end testing of the policy is completed.
const RLSLoadBalancingPolicyName = "rls_experimental"
// EnforceSubConnEmbedding is used to enforce proper SubConn implementation
// embedding.
type EnforceSubConnEmbedding interface {
enforceSubConnEmbedding()
}
// EnforceClientConnEmbedding is used to enforce proper ClientConn implementation
// embedding.
type EnforceClientConnEmbedding interface {
enforceClientConnEmbedding()
}
// Timer is an interface to allow injecting different time.Timer implementations
// during tests.
type Timer interface {
Stop() bool
}

View File

@@ -97,13 +97,11 @@ func hasNotPrintable(msg string) bool {
return false
}
// ValidatePair validate a key-value pair with the following rules (the pseudo-header will be skipped) :
//
// - key must contain one or more characters.
// - the characters in the key must be contained in [0-9 a-z _ - .].
// - if the key ends with a "-bin" suffix, no validation of the corresponding value is performed.
// - the characters in the every value must be printable (in [%x20-%x7E]).
func ValidatePair(key string, vals ...string) error {
// ValidateKey validates a key with the following rules (pseudo-headers are
// skipped):
// - the key must contain one or more characters.
// - the characters in the key must be in [0-9 a-z _ - .].
func ValidateKey(key string) error {
// key should not be empty
if key == "" {
return fmt.Errorf("there is an empty key in the header")
@@ -119,6 +117,20 @@ func ValidatePair(key string, vals ...string) error {
return fmt.Errorf("header key %q contains illegal characters not in [0-9a-z-_.]", key)
}
}
return nil
}
// ValidatePair validates a key-value pair with the following rules
// (pseudo-headers are skipped):
// - the key must contain one or more characters.
// - the characters in the key must be in [0-9 a-z _ - .].
// - if the key ends with a "-bin" suffix, no validation of the corresponding
// value is performed.
// - the characters in every value must be printable (in [%x20-%x7E]).
func ValidatePair(key string, vals ...string) error {
if err := ValidateKey(key); err != nil {
return err
}
if strings.HasSuffix(key, "-bin") {
return nil
}
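A sketch of the split in action; the package is internal to grpc-go and the header values are made up:

```go
import (
	"log"

	"google.golang.org/grpc/internal/metadata"
)

func example() {
	// Key-only check, useful before values are even assembled.
	if err := metadata.ValidateKey("x-request-id"); err != nil {
		log.Fatalf("bad key: %v", err)
	}
	// Full pair check; "-bin" keys skip value validation, so raw bytes pass.
	if err := metadata.ValidatePair("payload-bin", string([]byte{0x00, 0xff})); err != nil {
		log.Fatalf("bad pair: %v", err)
	}
}
```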

View File

@@ -0,0 +1,54 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package proxyattributes contains functions for getting and setting proxy
// attributes like the CONNECT address and user info.
package proxyattributes
import (
"net/url"
"google.golang.org/grpc/resolver"
)
type keyType string
const proxyOptionsKey = keyType("grpc.resolver.delegatingresolver.proxyOptions")
// Options holds the proxy connection details needed during the CONNECT
// handshake.
type Options struct {
User *url.Userinfo
ConnectAddr string
}
// Set returns a copy of addr with opts set in its attributes.
func Set(addr resolver.Address, opts Options) resolver.Address {
addr.Attributes = addr.Attributes.WithValue(proxyOptionsKey, opts)
return addr
}
// Get returns the Options for the proxy [resolver.Address] and a boolean
// value representing if the attribute is present or not. The returned data
// should not be mutated.
func Get(addr resolver.Address) (Options, bool) {
if a := addr.Attributes.Value(proxyOptionsKey); a != nil {
return a.(Options), true
}
return Options{}, false
}
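Round-tripping the attributes looks like this (internal package; addresses are made up):

```go
import (
	"fmt"
	"net/url"

	"google.golang.org/grpc/internal/proxyattributes"
	"google.golang.org/grpc/resolver"
)

func example() {
	addr := resolver.Address{Addr: "proxy.internal:3128"}
	addr = proxyattributes.Set(addr, proxyattributes.Options{
		User:        url.User("alice"),
		ConnectAddr: "backend.internal:443",
	})
	if opts, ok := proxyattributes.Get(addr); ok {
		fmt.Println(opts.ConnectAddr) // backend.internal:443
	}
}
```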

View File

@@ -0,0 +1,427 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package delegatingresolver implements a resolver capable of resolving both
// target URIs and proxy addresses.
package delegatingresolver
import (
"fmt"
"net/http"
"net/url"
"sync"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/proxyattributes"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/internal/transport/networktype"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
)
var (
logger = grpclog.Component("delegating-resolver")
// HTTPSProxyFromEnvironment will be overwritten in the tests
HTTPSProxyFromEnvironment = http.ProxyFromEnvironment
)
// delegatingResolver manages both target URI and proxy address resolution by
// delegating these tasks to separate child resolvers. Essentially, it acts as
// an intermediary between the gRPC ClientConn and the child resolvers.
//
// It implements the [resolver.Resolver] interface.
type delegatingResolver struct {
target resolver.Target // parsed target URI to be resolved
cc resolver.ClientConn // gRPC ClientConn
proxyURL *url.URL // proxy URL, derived from proxy environment and target
// We do not hold both mu and childMu in the same goroutine. Avoid holding
// both locks when calling into the child, as the child resolver may
// synchronously callback into the channel.
mu sync.Mutex // protects all the fields below
targetResolverState *resolver.State // state of the target resolver
proxyAddrs []resolver.Address // resolved proxy addresses; empty if no proxy is configured
// childMu serializes calls into child resolvers. It also protects access to
// the following fields.
childMu sync.Mutex
targetResolver resolver.Resolver // resolver for the target URI, based on its scheme
proxyResolver resolver.Resolver // resolver for the proxy URI; nil if no proxy is configured
}
// nopResolver is a resolver that does nothing.
type nopResolver struct{}
func (nopResolver) ResolveNow(resolver.ResolveNowOptions) {}
func (nopResolver) Close() {}
// proxyURLForTarget determines the proxy URL for the given address based on the
// environment. It can return the following:
// - nil URL, nil error: No proxy is configured or the address is excluded
// using the `NO_PROXY` environment variable or if req.URL.Host is
// "localhost" (with or without // a port number)
// - nil URL, non-nil error: An error occurred while retrieving the proxy URL.
// - non-nil URL, nil error: A proxy is configured, and the proxy URL was
// retrieved successfully without any errors.
func proxyURLForTarget(address string) (*url.URL, error) {
req := &http.Request{URL: &url.URL{
Scheme: "https",
Host: address,
}}
return HTTPSProxyFromEnvironment(req)
}
// New creates a new delegating resolver that can create up to two child
// resolvers:
// - one to resolve the proxy address specified using the supported
// environment variables. This uses the registered resolver for the "dns"
// scheme. It is lazily built when a target resolver update contains at least
// one TCP address.
// - one to resolve the target URI using the resolver specified by the scheme
// in the target URI or specified by the user using the WithResolvers dial
// option. As a special case, if the target URI's scheme is "dns" and a
// proxy is specified using the supported environment variables, the target
// URI's path portion is used as the resolved address unless target
// resolution is enabled using the dial option.
func New(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions, targetResolverBuilder resolver.Builder, targetResolutionEnabled bool) (resolver.Resolver, error) {
r := &delegatingResolver{
target: target,
cc: cc,
proxyResolver: nopResolver{},
targetResolver: nopResolver{},
}
var err error
r.proxyURL, err = proxyURLForTarget(target.Endpoint())
if err != nil {
return nil, fmt.Errorf("delegating_resolver: failed to determine proxy URL for target %s: %v", target, err)
}
// proxy is not configured or proxy address excluded using `NO_PROXY` env
// var, so only target resolver is used.
if r.proxyURL == nil {
return targetResolverBuilder.Build(target, cc, opts)
}
if logger.V(2) {
logger.Infof("Proxy URL detected : %s", r.proxyURL)
}
// Resolver updates from one child may trigger calls into the other. Block
// updates until the children are initialized.
r.childMu.Lock()
defer r.childMu.Unlock()
// When the scheme is 'dns' and target resolution on client is not enabled,
// resolution should be handled by the proxy, not the client. Therefore, we
// bypass the target resolver and store the unresolved target address.
if target.URL.Scheme == "dns" && !targetResolutionEnabled {
r.targetResolverState = &resolver.State{
Addresses: []resolver.Address{{Addr: target.Endpoint()}},
Endpoints: []resolver.Endpoint{{Addresses: []resolver.Address{{Addr: target.Endpoint()}}}},
}
r.updateTargetResolverState(*r.targetResolverState)
return r, nil
}
wcc := &wrappingClientConn{
stateListener: r.updateTargetResolverState,
parent: r,
}
if r.targetResolver, err = targetResolverBuilder.Build(target, wcc, opts); err != nil {
return nil, fmt.Errorf("delegating_resolver: unable to build the resolver for target %s: %v", target, err)
}
return r, nil
}
// proxyURIResolver creates a resolver for resolving proxy URIs using the "dns"
// scheme. It adjusts the proxyURL to conform to the "dns:///" format and builds
// a resolver with a wrappingClientConn to capture resolved addresses.
func (r *delegatingResolver) proxyURIResolver(opts resolver.BuildOptions) (resolver.Resolver, error) {
proxyBuilder := resolver.Get("dns")
if proxyBuilder == nil {
panic("delegating_resolver: resolver for proxy not found for scheme dns")
}
url := *r.proxyURL
url.Scheme = "dns"
url.Path = "/" + r.proxyURL.Host
url.Host = "" // Clear the Host field to conform to the "dns:///" format
proxyTarget := resolver.Target{URL: url}
wcc := &wrappingClientConn{
stateListener: r.updateProxyResolverState,
parent: r,
}
return proxyBuilder.Build(proxyTarget, wcc, opts)
}
func (r *delegatingResolver) ResolveNow(o resolver.ResolveNowOptions) {
r.childMu.Lock()
defer r.childMu.Unlock()
r.targetResolver.ResolveNow(o)
r.proxyResolver.ResolveNow(o)
}
func (r *delegatingResolver) Close() {
r.childMu.Lock()
defer r.childMu.Unlock()
r.targetResolver.Close()
r.targetResolver = nil
r.proxyResolver.Close()
r.proxyResolver = nil
}
func needsProxyResolver(state *resolver.State) bool {
for _, addr := range state.Addresses {
if !skipProxy(addr) {
return true
}
}
for _, endpoint := range state.Endpoints {
for _, addr := range endpoint.Addresses {
if !skipProxy(addr) {
return true
}
}
}
return false
}
func skipProxy(address resolver.Address) bool {
// Avoid proxy when network is not tcp.
networkType, ok := networktype.Get(address)
if !ok {
networkType, _ = transport.ParseDialTarget(address.Addr)
}
if networkType != "tcp" {
return true
}
req := &http.Request{URL: &url.URL{
Scheme: "https",
Host: address.Addr,
}}
// Avoid proxy when address included in `NO_PROXY` environment variable or
// fails to get the proxy address.
url, err := HTTPSProxyFromEnvironment(req)
if err != nil || url == nil {
return true
}
return false
}
// updateClientConnStateLocked constructs a combined list of addresses by
// pairing each proxy address with every target address of type TCP. For each
// pair, it creates a new [resolver.Address] using the proxy address and
// attaches the corresponding target address and user info as attributes. Target
// addresses that are not of type TCP are appended to the list as-is. The
// function returns nil if either resolver has not yet provided an update, and
// returns the result of ClientConn.UpdateState once both resolvers have
// provided at least one update.
func (r *delegatingResolver) updateClientConnStateLocked() error {
if r.targetResolverState == nil || r.proxyAddrs == nil {
return nil
}
// If multiple resolved proxy addresses are present, we send only the
// unresolved proxy host and let net.Dial handle the proxy host name
// resolution when creating the transport. Sending all resolved addresses
// would increase the number of addresses passed to the ClientConn and
// subsequently to load balancing (LB) policies like Round Robin, leading
// to additional TCP connections. However, if there's only one resolved
// proxy address, we send it directly, as it doesn't affect the address
// count returned by the target resolver and the address count sent to the
// ClientConn.
var proxyAddr resolver.Address
if len(r.proxyAddrs) == 1 {
proxyAddr = r.proxyAddrs[0]
} else {
proxyAddr = resolver.Address{Addr: r.proxyURL.Host}
}
var addresses []resolver.Address
for _, targetAddr := range (*r.targetResolverState).Addresses {
if skipProxy(targetAddr) {
addresses = append(addresses, targetAddr)
continue
}
addresses = append(addresses, proxyattributes.Set(proxyAddr, proxyattributes.Options{
User: r.proxyURL.User,
ConnectAddr: targetAddr.Addr,
}))
}
// For each target endpoint, construct a new [resolver.Endpoint] that
// includes all addresses from all proxy endpoints and the addresses from
// that target endpoint, preserving the number of target endpoints.
var endpoints []resolver.Endpoint
for _, endpt := range (*r.targetResolverState).Endpoints {
var addrs []resolver.Address
for _, targetAddr := range endpt.Addresses {
// Avoid proxy when network is not tcp.
if skipProxy(targetAddr) {
addrs = append(addrs, targetAddr)
continue
}
for _, proxyAddr := range r.proxyAddrs {
addrs = append(addrs, proxyattributes.Set(proxyAddr, proxyattributes.Options{
User: r.proxyURL.User,
ConnectAddr: targetAddr.Addr,
}))
}
}
endpoints = append(endpoints, resolver.Endpoint{Addresses: addrs})
}
// Use the targetResolverState for its service config and attributes
// contents. The state update is only sent after both the target and proxy
// resolvers have sent their updates, and curState has been updated with the
// combined addresses.
curState := *r.targetResolverState
curState.Addresses = addresses
curState.Endpoints = endpoints
return r.cc.UpdateState(curState)
}
// updateProxyResolverState updates the proxy resolver state by storing proxy
// addresses and endpoints, marking the resolver as ready, and triggering a
// state update if both proxy and target resolvers are ready. If the ClientConn
// returns a non-nil error, it calls `ResolveNow()` on the target resolver. It
// is a StateListener function of wrappingClientConn passed to the proxy
// resolver.
func (r *delegatingResolver) updateProxyResolverState(state resolver.State) error {
r.mu.Lock()
defer r.mu.Unlock()
if logger.V(2) {
logger.Infof("Addresses received from proxy resolver: %s", state.Addresses)
}
if len(state.Endpoints) > 0 {
// We expect exactly one address per endpoint because the proxy resolver
// uses "dns" resolution.
r.proxyAddrs = make([]resolver.Address, 0, len(state.Endpoints))
for _, endpoint := range state.Endpoints {
r.proxyAddrs = append(r.proxyAddrs, endpoint.Addresses...)
}
} else if state.Addresses != nil {
r.proxyAddrs = state.Addresses
} else {
r.proxyAddrs = []resolver.Address{} // ensure proxyAddrs is non-nil to indicate an update has been received
}
err := r.updateClientConnStateLocked()
// Another possible approach was to block until updates are received from
// both resolvers. But this is not used because calling `New()` triggers
// `Build()` for the first resolver, which calls `UpdateState()`. And the
// second resolver hasn't sent an update yet, so it would cause `New()` to
// block indefinitely.
if err != nil {
go func() {
r.childMu.Lock()
defer r.childMu.Unlock()
if r.targetResolver != nil {
r.targetResolver.ResolveNow(resolver.ResolveNowOptions{})
}
}()
}
return err
}
// updateTargetResolverState is the StateListener function provided to the
// target resolver via wrappingClientConn. It updates the resolver state and
// marks the target resolver as ready. If the update includes at least one TCP
// address and the proxy resolver has not yet been constructed, it initializes
// the proxy resolver. A combined state update is triggered once both resolvers
// are ready. If all addresses are non-TCP, it proceeds without waiting for the
// proxy resolver. If ClientConn.UpdateState returns a non-nil error,
// ResolveNow() is called on the proxy resolver.
func (r *delegatingResolver) updateTargetResolverState(state resolver.State) error {
r.mu.Lock()
defer r.mu.Unlock()
if logger.V(2) {
logger.Infof("Addresses received from target resolver: %v", state.Addresses)
}
r.targetResolverState = &state
// If all addresses returned by the target resolver have a non-TCP network
// type, or are listed in the `NO_PROXY` environment variable, do not wait
// for proxy update.
if !needsProxyResolver(r.targetResolverState) {
return r.cc.UpdateState(*r.targetResolverState)
}
// The proxy resolver may be rebuilt multiple times, specifically each time
// the target resolver sends an update, even if the target resolver is built
// successfully but building the proxy resolver fails.
if len(r.proxyAddrs) == 0 {
go func() {
r.childMu.Lock()
defer r.childMu.Unlock()
if _, ok := r.proxyResolver.(nopResolver); !ok {
return
}
proxyResolver, err := r.proxyURIResolver(resolver.BuildOptions{})
if err != nil {
r.cc.ReportError(fmt.Errorf("delegating_resolver: unable to build the proxy resolver: %v", err))
return
}
r.proxyResolver = proxyResolver
}()
}
err := r.updateClientConnStateLocked()
if err != nil {
go func() {
r.childMu.Lock()
defer r.childMu.Unlock()
if r.proxyResolver != nil {
r.proxyResolver.ResolveNow(resolver.ResolveNowOptions{})
}
}()
}
return nil
}
// wrappingClientConn serves as an intermediary between the parent ClientConn
// and the child resolvers created here. It implements the resolver.ClientConn
// interface and is passed in that capacity to the child resolvers.
type wrappingClientConn struct {
// Callback to deliver resolver state updates
stateListener func(state resolver.State) error
parent *delegatingResolver
}
// UpdateState receives resolver state updates and forwards them to the
// appropriate listener function (either for the proxy or target resolver).
func (wcc *wrappingClientConn) UpdateState(state resolver.State) error {
return wcc.stateListener(state)
}
// ReportError intercepts errors from the child resolvers and passes them to
// ClientConn.
func (wcc *wrappingClientConn) ReportError(err error) {
wcc.parent.cc.ReportError(err)
}
// NewAddress intercepts the new resolved address from the child resolvers and
// passes them to ClientConn.
func (wcc *wrappingClientConn) NewAddress(addrs []resolver.Address) {
wcc.UpdateState(resolver.State{Addresses: addrs})
}
// ParseServiceConfig parses the provided service config and returns an object
// that provides the parsed config.
func (wcc *wrappingClientConn) ParseServiceConfig(serviceConfigJSON string) *serviceconfig.ParseResult {
return wcc.parent.cc.ParseServiceConfig(serviceConfigJSON)
}
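To make the pairing performed by updateClientConnStateLocked concrete: with one resolved proxy address and two TCP targets, the combined list has the following shape (conceptual sketch, made-up addresses):

```go
// Reuses the proxyattributes/resolver imports shown above.
proxyAddr := resolver.Address{Addr: "proxy.internal:3128"}
var combined []resolver.Address
for _, target := range []string{"10.0.0.1:443", "10.0.0.2:443"} {
	combined = append(combined, proxyattributes.Set(proxyAddr, proxyattributes.Options{
		ConnectAddr: target, // real backend, reached via CONNECT
	}))
}
// combined now holds two copies of the proxy address, each carrying one
// backend in its attributes; the LB policy dials the proxy both times.
```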

View File

@@ -24,7 +24,9 @@ import (
"context"
"encoding/json"
"fmt"
rand "math/rand/v2"
"net"
"net/netip"
"os"
"strconv"
"strings"
@@ -35,28 +37,35 @@ import (
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/backoff"
"google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/resolver/dns/internal"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
)
// EnableSRVLookups controls whether the DNS resolver attempts to fetch gRPCLB
// addresses from SRV records. Must not be changed after init time.
var EnableSRVLookups = false
var (
// EnableSRVLookups controls whether the DNS resolver attempts to fetch gRPCLB
// addresses from SRV records. Must not be changed after init time.
EnableSRVLookups = false
// ResolvingTimeout specifies the maximum duration for a DNS resolution request.
// If the timeout expires before a response is received, the request will be canceled.
//
// It is recommended to set this value at application startup. Avoid modifying this variable
// after initialization as it's not thread-safe for concurrent modification.
var ResolvingTimeout = 30 * time.Second
// MinResolutionInterval is the minimum interval at which re-resolutions are
// allowed. This helps to prevent excessive re-resolution.
MinResolutionInterval = 30 * time.Second
var logger = grpclog.Component("dns")
// ResolvingTimeout specifies the maximum duration for a DNS resolution request.
// If the timeout expires before a response is received, the request will be canceled.
//
// It is recommended to set this value at application startup. Avoid modifying this variable
// after initialization as it's not thread-safe for concurrent modification.
ResolvingTimeout = 30 * time.Second
logger = grpclog.Component("dns")
)
func init() {
resolver.Register(NewBuilder())
internal.TimeAfterFunc = time.After
internal.TimeNowFunc = time.Now
internal.TimeUntilFunc = time.Until
internal.NewNetResolver = newNetResolver
internal.AddressDialer = addressDialer
}
@@ -114,7 +123,7 @@ func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts
}
// IP address.
if ipAddr, ok := formatIP(host); ok {
if ipAddr, err := formatIP(host); err == nil {
addr := []resolver.Address{{Addr: ipAddr + ":" + port}}
cc.UpdateState(resolver.State{Addresses: addr})
return deadResolver{}, nil
@@ -169,7 +178,7 @@ type dnsResolver struct {
// finished. Otherwise, data race will be possible. [Race Example] in
// dns_resolver_test we replace the real lookup functions with mocked ones to
// facilitate testing. If Close() doesn't wait for watcher() goroutine
// finishes, race detector sometimes will warns lookup (READ the lookup
// finishes, race detector sometimes will warn lookup (READ the lookup
// function pointers) inside watcher() goroutine has data race with
// replaceNetFunc (WRITE the lookup function pointers).
wg sync.WaitGroup
@@ -203,12 +212,12 @@ func (d *dnsResolver) watcher() {
err = d.cc.UpdateState(*state)
}
var waitTime time.Duration
var nextResolutionTime time.Time
if err == nil {
// Success resolving, wait for the next ResolveNow. However, also wait 30
// seconds at the very least to prevent constantly re-resolving.
backoffIndex = 1
waitTime = internal.MinResolutionRate
nextResolutionTime = internal.TimeNowFunc().Add(MinResolutionInterval)
select {
case <-d.ctx.Done():
return
@@ -217,19 +226,21 @@ func (d *dnsResolver) watcher() {
} else {
// Poll on an error found in DNS Resolver or an error received from
// ClientConn.
waitTime = backoff.DefaultExponential.Backoff(backoffIndex)
nextResolutionTime = internal.TimeNowFunc().Add(backoff.DefaultExponential.Backoff(backoffIndex))
backoffIndex++
}
select {
case <-d.ctx.Done():
return
case <-internal.TimeAfterFunc(waitTime):
case <-internal.TimeAfterFunc(internal.TimeUntilFunc(nextResolutionTime)):
}
}
}
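
The switch from a relative waitTime to an absolute nextResolutionTime, together with the injectable TimeNowFunc/TimeUntilFunc hooks, makes this wait loop testable with fake clocks. A minimal sketch of the same pattern (names here are illustrative, not the vendored internals):

package main

import (
	"fmt"
	"time"
)

// Injection points, patterned after the resolver's internal hooks.
var (
	timeNow   = time.Now
	timeUntil = time.Until
	timeAfter = time.After
)

func waitUntilNextResolution(minInterval time.Duration, done <-chan struct{}) {
	next := timeNow().Add(minInterval) // absolute deadline, trivial to fake in tests
	select {
	case <-done:
	case <-timeAfter(timeUntil(next)):
	}
}

func main() {
	start := time.Now()
	waitUntilNextResolution(50*time.Millisecond, make(chan struct{}))
	fmt.Println("waited:", time.Since(start).Round(10*time.Millisecond))
}

A test can swap timeNow/timeUntil for deterministic stand-ins, which is exactly what the internal package's hooks enable.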
func (d *dnsResolver) lookupSRV(ctx context.Context) ([]resolver.Address, error) {
if !EnableSRVLookups {
// Skip this particular host to avoid timeouts with some versions of
// systemd-resolved.
if !EnableSRVLookups || d.host == "metadata.google.internal." {
return nil, nil
}
var newAddrs []resolver.Address
@@ -250,9 +261,9 @@ func (d *dnsResolver) lookupSRV(ctx context.Context) ([]resolver.Address, error)
return nil, err
}
for _, a := range lbAddrs {
ip, ok := formatIP(a)
if !ok {
return nil, fmt.Errorf("dns: error parsing A record IP address %v", a)
ip, err := formatIP(a)
if err != nil {
return nil, fmt.Errorf("dns: error parsing A record IP address %v: %v", a, err)
}
addr := ip + ":" + strconv.Itoa(int(s.Port))
newAddrs = append(newAddrs, resolver.Address{Addr: addr, ServerName: s.Target})
@@ -312,9 +323,9 @@ func (d *dnsResolver) lookupHost(ctx context.Context) ([]resolver.Address, error
}
newAddrs := make([]resolver.Address, 0, len(addrs))
for _, a := range addrs {
ip, ok := formatIP(a)
if !ok {
return nil, fmt.Errorf("dns: error parsing A record IP address %v", a)
ip, err := formatIP(a)
if err != nil {
return nil, fmt.Errorf("dns: error parsing A record IP address %v: %v", a, err)
}
addr := ip + ":" + d.port
newAddrs = append(newAddrs, resolver.Address{Addr: addr})
@@ -341,19 +352,19 @@ func (d *dnsResolver) lookup() (*resolver.State, error) {
return &state, nil
}
// formatIP returns ok = false if addr is not a valid textual representation of
// an IP address. If addr is an IPv4 address, return the addr and ok = true.
// formatIP returns an error if addr is not a valid textual representation of
// an IP address. If addr is an IPv4 address, return the addr and error = nil.
// If addr is an IPv6 address, return the addr enclosed in square brackets and
// ok = true.
func formatIP(addr string) (addrIP string, ok bool) {
ip := net.ParseIP(addr)
if ip == nil {
return "", false
// error = nil.
func formatIP(addr string) (string, error) {
ip, err := netip.ParseAddr(addr)
if err != nil {
return "", err
}
if ip.To4() != nil {
return addr, true
if ip.Is4() {
return addr, nil
}
return "[" + addr + "]", true
return "[" + addr + "]", nil
}
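
For reference, the netip-based formatIP can be exercised standalone; this is a self-contained reimplementation of the contract described in the comment above, not the vendored code itself:

package main

import (
	"fmt"
	"net/netip"
)

// formatIP mirrors the documented contract: IPv4 passes through unchanged,
// IPv6 is wrapped in square brackets, and invalid input returns the parse error.
func formatIP(addr string) (string, error) {
	ip, err := netip.ParseAddr(addr)
	if err != nil {
		return "", err
	}
	if ip.Is4() {
		return addr, nil
	}
	return "[" + addr + "]", nil
}

func main() {
	for _, a := range []string{"10.0.0.1", "2001:db8::1", "not-an-ip"} {
		out, err := formatIP(a)
		fmt.Printf("%q -> %q (err: %v)\n", a, out, err)
	}
}

Unlike net.ParseIP, netip.ParseAddr returns a concrete error, which is what lets the callers above include the cause in their A-record error messages.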
// parseTarget takes the user input target string and default port, returns
@@ -369,7 +380,7 @@ func parseTarget(target, defaultPort string) (host, port string, err error) {
if target == "" {
return "", "", internal.ErrMissingAddr
}
if ip := net.ParseIP(target); ip != nil {
if _, err := netip.ParseAddr(target); err == nil {
// target is an IPv4 or IPv6(without brackets) address
return target, defaultPort, nil
}
@@ -417,7 +428,7 @@ func chosenByPercentage(a *int) bool {
if a == nil {
return true
}
return grpcrand.Intn(100)+1 <= *a
return rand.IntN(100)+1 <= *a
}
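
The grpcrand shim could be dropped because math/rand/v2 (Go 1.22+) is safe for concurrent use and self-seeding. A quick sketch of the percentage gate above:

package main

import (
	"fmt"
	rand "math/rand/v2"
)

// chosenByPercentage returns true roughly pct percent of the time; nil means always.
func chosenByPercentage(pct *int) bool {
	if pct == nil {
		return true
	}
	return rand.IntN(100)+1 <= *pct
}

func main() {
	p := 25
	hits := 0
	for i := 0; i < 10000; i++ {
		if chosenByPercentage(&p) {
			hits++
		}
	}
	fmt.Printf("~%.1f%% selected\n", float64(hits)/100)
}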
func canaryingSC(js string) string {

View File

@@ -28,7 +28,7 @@ import (
// NetResolver groups the methods on net.Resolver that are used by the DNS
// resolver implementation. This allows the default net.Resolver instance to be
// overidden from tests.
// overridden from tests.
type NetResolver interface {
LookupHost(ctx context.Context, host string) (addrs []string, err error)
LookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*net.SRV, err error)
@@ -50,16 +50,23 @@ var (
// The following vars are overridden from tests.
var (
// MinResolutionRate is the minimum rate at which re-resolutions are
// allowed. This helps to prevent excessive re-resolution.
MinResolutionRate = 30 * time.Second
// TimeAfterFunc is used by the DNS resolver to wait for the given duration
// to elapse. In non-test code, this is implemented by time.After. In test
// to elapse. In non-test code, this is implemented by time.After. In test
// code, this can be used to control the amount of time the resolver is
// blocked waiting for the duration to elapse.
TimeAfterFunc func(time.Duration) <-chan time.Time
// TimeNowFunc is used by the DNS resolver to get the current time.
// In non-test code, this is implemented by time.Now. In test code,
// this can be used to control the current time for the resolver.
TimeNowFunc func() time.Time
// TimeUntilFunc is used by the DNS resolver to calculate the remaining
// wait time for re-resolution. In non-test code, this is implemented by
// time.Until. In test code, this can be used to control the remaining
// time for resolver to wait for re-resolution.
TimeUntilFunc func(time.Time) time.Duration
// NewNetResolver returns the net.Resolver instance for the given target.
NewNetResolver func(string) (NetResolver, error)
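
Since these hooks exist so tests can swap implementations, a test double for the two NetResolver methods shown here might look like the following (hypothetical test code, not part of this diff; the upstream interface may carry additional methods):

package dnstest

import (
	"context"
	"net"
)

// fakeResolver returns canned answers so resolver behavior can be driven
// deterministically from tests.
type fakeResolver struct {
	hosts map[string][]string
}

func (f *fakeResolver) LookupHost(_ context.Context, host string) ([]string, error) {
	return f.hosts[host], nil
}

func (f *fakeResolver) LookupSRV(context.Context, string, string, string) (string, []*net.SRV, error) {
	return "", nil, nil
}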

View File

@@ -55,7 +55,7 @@ func (r *passthroughResolver) start() {
r.cc.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: r.target.Endpoint()}}})
}
func (*passthroughResolver) ResolveNow(o resolver.ResolveNowOptions) {}
func (*passthroughResolver) ResolveNow(resolver.ResolveNowOptions) {}
func (*passthroughResolver) Close() {}

vendor/google.golang.org/grpc/internal/stats/labels.go generated vendored Normal file
View File

@@ -0,0 +1,42 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package stats provides internal stats related functionality.
package stats
import "context"
// Labels are the labels for metrics.
type Labels struct {
// TelemetryLabels are the telemetry labels to record.
TelemetryLabels map[string]string
}
type labelsKey struct{}
// GetLabels returns the Labels stored in the context, or nil if there are none.
func GetLabels(ctx context.Context) *Labels {
labels, _ := ctx.Value(labelsKey{}).(*Labels)
return labels
}
// SetLabels sets the Labels in the context.
func SetLabels(ctx context.Context, labels *Labels) context.Context {
// could also append
return context.WithValue(ctx, labelsKey{}, labels)
}
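
Usage is plain context plumbing. The types below are re-declared locally only because the real package is internal to gRPC; the shape is the same:

package main

import (
	"context"
	"fmt"
)

type Labels struct {
	TelemetryLabels map[string]string
}

// An unexported key type guarantees no collision with other context values.
type labelsKey struct{}

func SetLabels(ctx context.Context, labels *Labels) context.Context {
	return context.WithValue(ctx, labelsKey{}, labels)
}

func GetLabels(ctx context.Context) *Labels {
	labels, _ := ctx.Value(labelsKey{}).(*Labels)
	return labels
}

func main() {
	ctx := SetLabels(context.Background(), &Labels{
		TelemetryLabels: map[string]string{"grpc.lb.locality": "us-east1-b"},
	})
	fmt.Println(GetLabels(ctx).TelemetryLabels["grpc.lb.locality"])
}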

View File

@@ -0,0 +1,105 @@
/*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package stats
import (
"fmt"
estats "google.golang.org/grpc/experimental/stats"
"google.golang.org/grpc/stats"
)
// MetricsRecorderList forwards Record calls to all of its metricsRecorders.
//
// It panics if a Record call provides a number of label values that does not
// match the number of label keys (see verifyLabels below).
type MetricsRecorderList struct {
// metricsRecorders are the metrics recorders this list will forward to.
metricsRecorders []estats.MetricsRecorder
}
// NewMetricsRecorderList creates a new metric recorder list with all the stats
// handlers provided which implement the MetricsRecorder interface.
// If no stats handlers provided implement the MetricsRecorder interface,
// the MetricsRecorder list returned is a no-op.
func NewMetricsRecorderList(shs []stats.Handler) *MetricsRecorderList {
var mrs []estats.MetricsRecorder
for _, sh := range shs {
if mr, ok := sh.(estats.MetricsRecorder); ok {
mrs = append(mrs, mr)
}
}
return &MetricsRecorderList{
metricsRecorders: mrs,
}
}
func verifyLabels(desc *estats.MetricDescriptor, labelsRecv ...string) {
if got, want := len(labelsRecv), len(desc.Labels)+len(desc.OptionalLabels); got != want {
panic(fmt.Sprintf("Received %d labels in call to record metric %q, but expected %d.", got, desc.Name, want))
}
}
// RecordInt64Count records the measurement alongside labels on the int
// count associated with the provided handle.
func (l *MetricsRecorderList) RecordInt64Count(handle *estats.Int64CountHandle, incr int64, labels ...string) {
verifyLabels(handle.Descriptor(), labels...)
for _, metricRecorder := range l.metricsRecorders {
metricRecorder.RecordInt64Count(handle, incr, labels...)
}
}
// RecordFloat64Count records the measurement alongside labels on the float
// count associated with the provided handle.
func (l *MetricsRecorderList) RecordFloat64Count(handle *estats.Float64CountHandle, incr float64, labels ...string) {
verifyLabels(handle.Descriptor(), labels...)
for _, metricRecorder := range l.metricsRecorders {
metricRecorder.RecordFloat64Count(handle, incr, labels...)
}
}
// RecordInt64Histo records the measurement alongside labels on the int
// histo associated with the provided handle.
func (l *MetricsRecorderList) RecordInt64Histo(handle *estats.Int64HistoHandle, incr int64, labels ...string) {
verifyLabels(handle.Descriptor(), labels...)
for _, metricRecorder := range l.metricsRecorders {
metricRecorder.RecordInt64Histo(handle, incr, labels...)
}
}
// RecordFloat64Histo records the measurement alongside labels on the float
// histo associated with the provided handle.
func (l *MetricsRecorderList) RecordFloat64Histo(handle *estats.Float64HistoHandle, incr float64, labels ...string) {
verifyLabels(handle.Descriptor(), labels...)
for _, metricRecorder := range l.metricsRecorders {
metricRecorder.RecordFloat64Histo(handle, incr, labels...)
}
}
// RecordInt64Gauge records the measurement alongside labels on the int
// gauge associated with the provided handle.
func (l *MetricsRecorderList) RecordInt64Gauge(handle *estats.Int64GaugeHandle, incr int64, labels ...string) {
verifyLabels(handle.Descriptor(), labels...)
for _, metricRecorder := range l.metricsRecorders {
metricRecorder.RecordInt64Gauge(handle, incr, labels...)
}
}
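
The list is a plain fan-out: after the label-arity check, every Record call is forwarded to each recorder. The same shape in miniature, with toy types rather than the gRPC ones:

package main

import "fmt"

type recorder interface {
	RecordInt64Count(name string, incr int64)
}

type printRecorder struct{ id string }

func (p printRecorder) RecordInt64Count(name string, incr int64) {
	fmt.Printf("[%s] %s += %d\n", p.id, name, incr)
}

// recorderList forwards each call to every recorder it holds.
type recorderList struct{ recorders []recorder }

func (l recorderList) RecordInt64Count(name string, incr int64) {
	for _, r := range l.recorders {
		r.RecordInt64Count(name, incr)
	}
}

func main() {
	l := recorderList{recorders: []recorder{printRecorder{"otel"}, printRecorder{"custom"}}}
	l.RecordInt64Count("grpc.client.attempt.started", 1)
}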

View File

@@ -138,17 +138,19 @@ func (s *Status) WithDetails(details ...protoadapt.MessageV1) (*Status, error) {
// s.Code() != OK implies that s.Proto() != nil.
p := s.Proto()
for _, detail := range details {
any, err := anypb.New(protoadapt.MessageV2Of(detail))
m, err := anypb.New(protoadapt.MessageV2Of(detail))
if err != nil {
return nil, err
}
p.Details = append(p.Details, any)
p.Details = append(p.Details, m)
}
return &Status{s: p}, nil
}
// Details returns a slice of details messages attached to the status.
// If a detail cannot be decoded, the error is returned in place of the detail.
// If the detail can be decoded, the proto message returned is of the same
// type that was given to WithDetails().
func (s *Status) Details() []any {
if s == nil || s.s == nil {
return nil
@@ -160,7 +162,38 @@ func (s *Status) Details() []any {
details = append(details, err)
continue
}
details = append(details, detail)
// The call to MessageV1Of is required to unwrap the proto message if
// it implemented only the MessageV1 API. The proto message would have
// been wrapped in a V2 wrapper in Status.WithDetails. V2 messages are
// added to a global registry used by any.UnmarshalNew().
// MessageV1Of has the following behaviour:
// 1. If the given message is a wrapped MessageV1, it returns the
// unwrapped value.
// 2. If the given message already implements MessageV1, it returns it
// as is.
// 3. Else, it wraps the MessageV2 in a MessageV1 wrapper.
//
// Since the Status.WithDetails() API only accepts MessageV1, calling
// MessageV1Of ensures we return the same type that was given to
// WithDetails:
// * If the given type implemented only MessageV1, the unwrapping from
// point 1 above will restore the type.
// * If the given type implemented both MessageV1 and MessageV2, point 2
// above will ensure no wrapping is performed.
// * If the given type implemented only MessageV2 and was wrapped using
// MessageV1Of before passing to WithDetails(), it would be unwrapped
// in WithDetails by calling MessageV2Of(). Point 3 above will ensure
// that the type is wrapped in a MessageV1 wrapper again before
// returning. Note that protoc-gen-go doesn't generate code which
// implements ONLY MessageV2 at the time of writing.
//
// NOTE: Status details can also be added using the FromProto method.
// This could theoretically allow passing a Detail message that only
// implements the V2 API. In such a case the message will be wrapped in
// a MessageV1 wrapper when fetched using Details().
// Since protoc-gen-go generates only code that implements both V1 and
// V2 APIs for backward compatibility, this is not a concern.
details = append(details, protoadapt.MessageV1Of(detail))
}
return details
}
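
A typical round-trip through the public status package shows why this wrapping matters: the concrete type recovered from Details() matches what was handed to WithDetails(). A sketch, assuming the standard errdetails protos:

package main

import (
	"fmt"

	epb "google.golang.org/genproto/googleapis/rpc/errdetails"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

func main() {
	st := status.New(codes.FailedPrecondition, "out of quota")
	st, err := st.WithDetails(&epb.QuotaFailure{})
	if err != nil {
		panic(err)
	}
	for _, d := range st.Details() {
		// Thanks to the MessageV1Of round-trip described above, the type
		// assertion succeeds for the same concrete type passed in.
		if qf, ok := d.(*epb.QuotaFailure); ok {
			fmt.Println("got quota failure detail:", qf)
		}
	}
}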

View File

@@ -58,20 +58,20 @@ func GetRusage() *Rusage {
// CPUTimeDiff returns the differences of user CPU time and system CPU time used
// between two Rusage structs. It is a no-op function for non-linux environments.
func CPUTimeDiff(first *Rusage, latest *Rusage) (float64, float64) {
func CPUTimeDiff(*Rusage, *Rusage) (float64, float64) {
log()
return 0, 0
}
// SetTCPUserTimeout is a no-op function under non-linux environments.
func SetTCPUserTimeout(conn net.Conn, timeout time.Duration) error {
func SetTCPUserTimeout(net.Conn, time.Duration) error {
log()
return nil
}
// GetTCPUserTimeout is a no-op function under non-linux environments.
// A negative return value indicates the operation is not supported
func GetTCPUserTimeout(conn net.Conn) (int, error) {
func GetTCPUserTimeout(net.Conn) (int, error) {
log()
return -1, nil
}

View File

@@ -44,7 +44,7 @@ func NetDialerWithTCPKeepalive() *net.Dialer {
// combination of unconditionally enabling TCP keepalives here, and
// disabling the overriding of TCP keepalive parameters by setting the
// KeepAlive field to a negative value above, results in OS defaults for
// the TCP keealive interval and time parameters.
// the TCP keepalive interval and time parameters.
Control: func(_, _ string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_KEEPALIVE, 1)

View File

@@ -44,7 +44,7 @@ func NetDialerWithTCPKeepalive() *net.Dialer {
// combination of unconditionally enabling TCP keepalives here, and
// disabling the overriding of TCP keepalive parameters by setting the
// KeepAlive field to a negative value above, results in OS defaults for
// the TCP keealive interval and time parameters.
// the TCP keepalive interval and time parameters.
Control: func(_, _ string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_KEEPALIVE, 1)

View File

@@ -0,0 +1,144 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package transport
import (
"sync/atomic"
"golang.org/x/net/http2"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// ClientStream implements streaming functionality for a gRPC client.
type ClientStream struct {
*Stream // Embed for common stream functionality.
ct *http2Client
done chan struct{} // closed at the end of stream to unblock writers.
doneFunc func() // invoked at the end of stream.
headerChan chan struct{} // closed to indicate the end of header metadata.
headerChanClosed uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times.
// headerValid indicates whether a valid header was received. Only
// meaningful after headerChan is closed (always call waitOnHeader() before
// reading its value).
headerValid bool
header metadata.MD // the received header metadata
noHeaders bool // set if the client never received headers (set only after the stream is done).
bytesReceived atomic.Bool // indicates whether any bytes have been received on this stream
unprocessed atomic.Bool // set if the server sends a refused stream or GOAWAY including this stream
status *status.Status // the status error received from the server
}
// Read reads an n byte message from the input stream.
func (s *ClientStream) Read(n int) (mem.BufferSlice, error) {
b, err := s.Stream.read(n)
if err == nil {
s.ct.incrMsgRecv()
}
return b, err
}
// Close closes the stream and propagates err to any readers.
func (s *ClientStream) Close(err error) {
var (
rst bool
rstCode http2.ErrCode
)
if err != nil {
rst = true
rstCode = http2.ErrCodeCancel
}
s.ct.closeStream(s, err, rst, rstCode, status.Convert(err), nil, false)
}
// Write writes the hdr and data bytes to the output stream.
func (s *ClientStream) Write(hdr []byte, data mem.BufferSlice, opts *WriteOptions) error {
return s.ct.write(s, hdr, data, opts)
}
// BytesReceived indicates whether any bytes have been received on this stream.
func (s *ClientStream) BytesReceived() bool {
return s.bytesReceived.Load()
}
// Unprocessed indicates whether the server did not process this stream --
// i.e. it sent a refused stream or GOAWAY including this stream ID.
func (s *ClientStream) Unprocessed() bool {
return s.unprocessed.Load()
}
func (s *ClientStream) waitOnHeader() {
select {
case <-s.ctx.Done():
// Close the stream to prevent headers/trailers from changing after
// this function returns.
s.Close(ContextErr(s.ctx.Err()))
// headerChan could possibly not be closed yet if closeStream raced
// with operateHeaders; wait until it is closed explicitly here.
<-s.headerChan
case <-s.headerChan:
}
}
// RecvCompress returns the compression algorithm applied to the inbound
// message. It is empty string if there is no compression applied.
func (s *ClientStream) RecvCompress() string {
s.waitOnHeader()
return s.recvCompress
}
// Done returns a channel which is closed when it receives the final status
// from the server.
func (s *ClientStream) Done() <-chan struct{} {
return s.done
}
// Header returns the header metadata of the stream. Acquires the key-value
// pairs of header metadata once it is available. It blocks until i) the
// metadata is ready or ii) there is no header metadata or iii) the stream is
// canceled/expired.
func (s *ClientStream) Header() (metadata.MD, error) {
s.waitOnHeader()
if !s.headerValid || s.noHeaders {
return nil, s.status.Err()
}
return s.header.Copy(), nil
}
// TrailersOnly blocks until a header or trailers-only frame is received and
// then returns true if the stream was trailers-only. If the stream ends
// before headers are received, returns true.
func (s *ClientStream) TrailersOnly() bool {
s.waitOnHeader()
return s.noHeaders
}
// Status returns the status received from the server.
// Status can be read safely only after the stream has ended,
// that is, after Done() is closed.
func (s *ClientStream) Status() *status.Status {
return s.status
}
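
The done-channel discipline above (write the status, then close(done), and only read the status after done is closed) is a standard Go happens-before pattern. In miniature:

package main

import "fmt"

type stream struct {
	done   chan struct{}
	status string // safe to read only after done is closed
}

func (s *stream) finish(st string) {
	s.status = st
	close(s.done) // the close happens-before any receive on done
}

func main() {
	s := &stream{done: make(chan struct{})}
	go s.finish("OK")
	<-s.done
	fmt.Println(s.status)
}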

View File

@@ -32,6 +32,7 @@ import (
"golang.org/x/net/http2/hpack"
"google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/status"
)
@@ -148,9 +149,9 @@ type dataFrame struct {
streamID uint32
endStream bool
h []byte
d []byte
reader mem.Reader
// onEachWrite is called every time
// a part of d is written out.
// a part of data is written out.
onEachWrite func()
}
@@ -193,7 +194,7 @@ type goAway struct {
code http2.ErrCode
debugData []byte
headsUp bool
closeConn error // if set, loopyWriter will exit, resulting in conn closure
closeConn error // if set, loopyWriter will exit with this error
}
func (*goAway) isTransportResponseFrame() bool { return false }
@@ -289,18 +290,22 @@ func (l *outStreamList) dequeue() *outStream {
}
// controlBuffer is a way to pass information to loopy.
// Information is passed as specific struct types called control frames.
// A control frame not only represents data, messages or headers to be sent out
// but can also be used to instruct loopy to update its internal state.
// It shouldn't be confused with an HTTP2 frame, although some of the control frames
// like dataFrame and headerFrame do go out on wire as HTTP2 frames.
//
// Information is passed as specific struct types called control frames. A
// control frame not only represents data, messages or headers to be sent out
// but can also be used to instruct loopy to update its internal state. It
// shouldn't be confused with an HTTP2 frame, although some of the control
// frames like dataFrame and headerFrame do go out on wire as HTTP2 frames.
type controlBuffer struct {
ch chan struct{}
done <-chan struct{}
wakeupCh chan struct{} // Unblocks readers waiting for something to read.
done <-chan struct{} // Closed when the transport is done.
// Mutex guards all the fields below, except trfChan which can be read
// atomically without holding mu.
mu sync.Mutex
consumerWaiting bool
list *itemList
err error
consumerWaiting bool // True when readers are blocked waiting for new data.
closed bool // True when the controlbuf is finished.
list *itemList // List of queued control frames.
// transportResponseFrames counts the number of queued items that represent
// the response of an action initiated by the peer. trfChan is created
@@ -308,47 +313,59 @@ type controlBuffer struct {
// closed and nilled when transportResponseFrames drops below the
// threshold. Both fields are protected by mu.
transportResponseFrames int
trfChan atomic.Value // chan struct{}
trfChan atomic.Pointer[chan struct{}]
}
func newControlBuffer(done <-chan struct{}) *controlBuffer {
return &controlBuffer{
ch: make(chan struct{}, 1),
list: &itemList{},
done: done,
wakeupCh: make(chan struct{}, 1),
list: &itemList{},
done: done,
}
}
// throttle blocks if there are too many incomingSettings/cleanupStreams in the
// controlbuf.
// throttle blocks if there are too many frames in the control buf that
// represent the response of an action initiated by the peer, like
// incomingSettings, cleanupStreams, etc.
func (c *controlBuffer) throttle() {
ch, _ := c.trfChan.Load().(chan struct{})
if ch != nil {
if ch := c.trfChan.Load(); ch != nil {
select {
case <-ch:
case <-(*ch):
case <-c.done:
}
}
}
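
The move from atomic.Value to atomic.Pointer[chan struct{}] trades runtime type assertions for compile-time typing, and Swap(nil) later lets finish() clear and close the channel in one step. The pattern in isolation:

package main

import (
	"fmt"
	"sync/atomic"
)

func main() {
	var trf atomic.Pointer[chan struct{}]

	// Arm the throttle.
	ch := make(chan struct{})
	trf.Store(&ch)

	// Disarm exactly once: Swap returns the previous pointer, or nil.
	if old := trf.Swap(nil); old != nil {
		close(*old)
	}
	if trf.Load() == nil {
		fmt.Println("throttle cleared")
	}
}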
// put adds an item to the controlbuf.
func (c *controlBuffer) put(it cbItem) error {
_, err := c.executeAndPut(nil, it)
return err
}
func (c *controlBuffer) executeAndPut(f func(it any) bool, it cbItem) (bool, error) {
var wakeUp bool
// executeAndPut runs f, and if the return value is true, adds the given item to
// the controlbuf. The item could be nil, in which case, this method simply
// executes f and does not add the item to the controlbuf.
//
// The first return value indicates whether the item was successfully added to
// the control buffer. A non-nil error, specifically ErrConnClosing, is returned
// if the control buffer is already closed.
func (c *controlBuffer) executeAndPut(f func() bool, it cbItem) (bool, error) {
c.mu.Lock()
if c.err != nil {
c.mu.Unlock()
return false, c.err
defer c.mu.Unlock()
if c.closed {
return false, ErrConnClosing
}
if f != nil {
if !f(it) { // f wasn't successful
c.mu.Unlock()
if !f() { // f wasn't successful
return false, nil
}
}
if it == nil {
return true, nil
}
var wakeUp bool
if c.consumerWaiting {
wakeUp = true
c.consumerWaiting = false
@@ -359,98 +376,102 @@ func (c *controlBuffer) executeAndPut(f func(it any) bool, it cbItem) (bool, err
if c.transportResponseFrames == maxQueuedTransportResponseFrames {
// We are adding the frame that puts us over the threshold; create
// a throttling channel.
c.trfChan.Store(make(chan struct{}))
ch := make(chan struct{})
c.trfChan.Store(&ch)
}
}
c.mu.Unlock()
if wakeUp {
select {
case c.ch <- struct{}{}:
case c.wakeupCh <- struct{}{}:
default:
}
}
return true, nil
}
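
Callers rely on executeAndPut to make a check and an enqueue atomic under the buffer's lock; the stream-quota logic in NewStream further below is the main user. A stripped-down analogue of the contract:

package main

import (
	"errors"
	"fmt"
	"sync"
)

// buffer is a toy stand-in for controlBuffer: the predicate and the enqueue
// run under one mutex, so the state the predicate observes cannot change
// before the item is queued.
type buffer struct {
	mu     sync.Mutex
	closed bool
	items  []any
}

func (b *buffer) executeAndPut(f func() bool, it any) (bool, error) {
	b.mu.Lock()
	defer b.mu.Unlock()
	if b.closed {
		return false, errors.New("conn closing")
	}
	if f != nil && !f() {
		return false, nil
	}
	if it != nil {
		b.items = append(b.items, it)
	}
	return true, nil
}

func main() {
	b := &buffer{}
	quota := 1
	ok, _ := b.executeAndPut(func() bool { quota--; return quota >= 0 }, "headerFrame")
	fmt.Println("queued:", ok, "items:", len(b.items))
}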
// Note argument f should never be nil.
func (c *controlBuffer) execute(f func(it any) bool, it any) (bool, error) {
c.mu.Lock()
if c.err != nil {
c.mu.Unlock()
return false, c.err
}
if !f(it) { // f wasn't successful
c.mu.Unlock()
return false, nil
}
c.mu.Unlock()
return true, nil
}
// get returns the next control frame from the control buffer. If block is true
// **and** there are no control frames in the control buffer, the call blocks
// until one of the conditions is met: there is a frame to return or the
// transport is closed.
func (c *controlBuffer) get(block bool) (any, error) {
for {
c.mu.Lock()
if c.err != nil {
frame, err := c.getOnceLocked()
if frame != nil || err != nil || !block {
// If we read a frame or an error, we can return to the caller. The
// call to getOnceLocked() returns a nil frame and a nil error if
// there is nothing to read, and in that case, if the caller asked
// us not to block, we can return now as well.
c.mu.Unlock()
return nil, c.err
}
if !c.list.isEmpty() {
h := c.list.dequeue().(cbItem)
if h.isTransportResponseFrame() {
if c.transportResponseFrames == maxQueuedTransportResponseFrames {
// We are removing the frame that put us over the
// threshold; close and clear the throttling channel.
ch := c.trfChan.Load().(chan struct{})
close(ch)
c.trfChan.Store((chan struct{})(nil))
}
c.transportResponseFrames--
}
c.mu.Unlock()
return h, nil
}
if !block {
c.mu.Unlock()
return nil, nil
return frame, err
}
c.consumerWaiting = true
c.mu.Unlock()
// Release the lock above and wait to be woken up.
select {
case <-c.ch:
case <-c.wakeupCh:
case <-c.done:
return nil, errors.New("transport closed by client")
}
}
}
// Callers must not use this method, but should instead use get().
//
// Caller must hold c.mu.
func (c *controlBuffer) getOnceLocked() (any, error) {
if c.closed {
return false, ErrConnClosing
}
if c.list.isEmpty() {
return nil, nil
}
h := c.list.dequeue().(cbItem)
if h.isTransportResponseFrame() {
if c.transportResponseFrames == maxQueuedTransportResponseFrames {
// We are removing the frame that put us over the
// threshold; close and clear the throttling channel.
ch := c.trfChan.Swap(nil)
close(*ch)
}
c.transportResponseFrames--
}
return h, nil
}
// finish closes the control buffer, cleaning up any streams that have queued
// header frames. Once this method returns, no more frames can be added to the
// control buffer, and attempts to do so will return ErrConnClosing.
func (c *controlBuffer) finish() {
c.mu.Lock()
if c.err != nil {
c.mu.Unlock()
defer c.mu.Unlock()
if c.closed {
return
}
c.err = ErrConnClosing
c.closed = true
// There may be headers for streams in the control buffer.
// These streams need to be cleaned out since the transport
// is still not aware of these yet.
for head := c.list.dequeueAll(); head != nil; head = head.next {
hdr, ok := head.it.(*headerFrame)
if !ok {
continue
}
if hdr.onOrphaned != nil { // It will be nil on the server-side.
hdr.onOrphaned(ErrConnClosing)
switch v := head.it.(type) {
case *headerFrame:
if v.onOrphaned != nil { // It will be nil on the server-side.
v.onOrphaned(ErrConnClosing)
}
case *dataFrame:
_ = v.reader.Close()
}
}
// In case throttle() is currently in flight, it needs to be unblocked.
// Otherwise, the transport may not close, since the transport is closed by
// the reader encountering the connection error.
ch, _ := c.trfChan.Load().(chan struct{})
ch := c.trfChan.Swap(nil)
if ch != nil {
close(ch)
close(*ch)
}
c.trfChan.Store((chan struct{})(nil))
c.mu.Unlock()
}
type side int
@@ -466,7 +487,7 @@ const (
// stream maintains a queue of data frames; as loopy receives data frames
// it gets added to the queue of the relevant stream.
// Loopy goes over this list of active streams by processing one node every iteration,
// thereby closely resemebling to a round-robin scheduling over all streams. While
// thereby closely resembling a round-robin scheduling over all streams. While
// processing a stream, loopy writes out data bytes from this stream capped by the min
// of http2MaxFrameLen, connection-level flow control and stream-level flow control.
type loopyWriter struct {
@@ -490,26 +511,29 @@ type loopyWriter struct {
draining bool
conn net.Conn
logger *grpclog.PrefixLogger
bufferPool mem.BufferPool
// Side-specific handlers
ssGoAwayHandler func(*goAway) (bool, error)
}
func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimator, conn net.Conn, logger *grpclog.PrefixLogger) *loopyWriter {
func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimator, conn net.Conn, logger *grpclog.PrefixLogger, goAwayHandler func(*goAway) (bool, error), bufferPool mem.BufferPool) *loopyWriter {
var buf bytes.Buffer
l := &loopyWriter{
side: s,
cbuf: cbuf,
sendQuota: defaultWindowSize,
oiws: defaultWindowSize,
estdStreams: make(map[uint32]*outStream),
activeStreams: newOutStreamList(),
framer: fr,
hBuf: &buf,
hEnc: hpack.NewEncoder(&buf),
bdpEst: bdpEst,
conn: conn,
logger: logger,
side: s,
cbuf: cbuf,
sendQuota: defaultWindowSize,
oiws: defaultWindowSize,
estdStreams: make(map[uint32]*outStream),
activeStreams: newOutStreamList(),
framer: fr,
hBuf: &buf,
hEnc: hpack.NewEncoder(&buf),
bdpEst: bdpEst,
conn: conn,
logger: logger,
ssGoAwayHandler: goAwayHandler,
bufferPool: bufferPool,
}
return l
}
@@ -767,6 +791,11 @@ func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error {
// not be established yet.
delete(l.estdStreams, c.streamID)
str.deleteSelf()
for head := str.itl.dequeueAll(); head != nil; head = head.next {
if df, ok := head.it.(*dataFrame); ok {
_ = df.reader.Close()
}
}
}
if c.rst { // If RST_STREAM needs to be sent.
if err := l.framer.fr.WriteRSTStream(c.streamID, c.rstCode); err != nil {
@@ -902,16 +931,18 @@ func (l *loopyWriter) processData() (bool, error) {
dataItem := str.itl.peek().(*dataFrame) // Peek at the first data item on this stream.
// A data item is represented by a dataFrame, since it later translates into
// multiple HTTP2 data frames.
// Every dataFrame has two buffers; h that keeps grpc-message header and d that is actual data.
// As an optimization to keep wire traffic low, data from d is copied to h to make as big as the
// maximum possible HTTP2 frame size.
// Every dataFrame has two buffers; h that keeps the grpc-message header and
// data that holds the actual message. As an optimization to keep wire traffic
// low, bytes from data are copied into h to make the frame as big as the
// maximum possible HTTP2 frame size.
if len(dataItem.h) == 0 && len(dataItem.d) == 0 { // Empty data frame
if len(dataItem.h) == 0 && dataItem.reader.Remaining() == 0 { // Empty data frame
// Client sends out empty data frame with endStream = true
if err := l.framer.fr.WriteData(dataItem.streamID, dataItem.endStream, nil); err != nil {
return false, err
}
str.itl.dequeue() // remove the empty data item from stream
_ = dataItem.reader.Close()
if str.itl.isEmpty() {
str.state = empty
} else if trailer, ok := str.itl.peek().(*headerFrame); ok { // the next item is trailers.
@@ -926,9 +957,7 @@ func (l *loopyWriter) processData() (bool, error) {
}
return false, nil
}
var (
buf []byte
)
// Figure out the maximum size we can send
maxSize := http2MaxFrameLen
if strQuota := int(l.oiws) - str.bytesOutStanding; strQuota <= 0 { // stream-level flow control.
@@ -942,43 +971,50 @@ func (l *loopyWriter) processData() (bool, error) {
}
// Compute how much of the header and data we can send within quota and max frame length
hSize := min(maxSize, len(dataItem.h))
dSize := min(maxSize-hSize, len(dataItem.d))
if hSize != 0 {
if dSize == 0 {
buf = dataItem.h
} else {
// We can add some data to grpc message header to distribute bytes more equally across frames.
// Copy on the stack to avoid generating garbage
var localBuf [http2MaxFrameLen]byte
copy(localBuf[:hSize], dataItem.h)
copy(localBuf[hSize:], dataItem.d[:dSize])
buf = localBuf[:hSize+dSize]
}
} else {
buf = dataItem.d
}
dSize := min(maxSize-hSize, dataItem.reader.Remaining())
remainingBytes := len(dataItem.h) + dataItem.reader.Remaining() - hSize - dSize
size := hSize + dSize
var buf *[]byte
if hSize != 0 && dSize == 0 {
buf = &dataItem.h
} else {
// Note: this is only necessary because the http2.Framer does not support
// partially writing a frame, so the sequence must be materialized into a buffer.
// TODO: Revisit once https://github.com/golang/go/issues/66655 is addressed.
pool := l.bufferPool
if pool == nil {
// Note that this is only supposed to be nil in tests. Otherwise, stream is
// always initialized with a BufferPool.
pool = mem.DefaultBufferPool()
}
buf = pool.Get(size)
defer pool.Put(buf)
copy((*buf)[:hSize], dataItem.h)
_, _ = dataItem.reader.Read((*buf)[hSize:])
}
// Now that outgoing flow controls are checked we can replenish str's write quota
str.wq.replenish(size)
var endStream bool
// If this is the last data message on this stream and all of it can be written in this iteration.
if dataItem.endStream && len(dataItem.h)+len(dataItem.d) <= size {
if dataItem.endStream && remainingBytes == 0 {
endStream = true
}
if dataItem.onEachWrite != nil {
dataItem.onEachWrite()
}
if err := l.framer.fr.WriteData(dataItem.streamID, endStream, buf[:size]); err != nil {
if err := l.framer.fr.WriteData(dataItem.streamID, endStream, (*buf)[:size]); err != nil {
return false, err
}
str.bytesOutStanding += size
l.sendQuota -= uint32(size)
dataItem.h = dataItem.h[hSize:]
dataItem.d = dataItem.d[dSize:]
if len(dataItem.h) == 0 && len(dataItem.d) == 0 { // All the data from that message was written out.
if remainingBytes == 0 { // All the data from that message was written out.
_ = dataItem.reader.Close()
str.itl.dequeue()
}
if str.itl.isEmpty() {
@@ -997,10 +1033,3 @@ func (l *loopyWriter) processData() (bool, error) {
}
return false, nil
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
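
Two side notes on this hunk: the local min helper could be deleted because Go 1.21 made min a builtin, and the data path now borrows its scratch buffer from a mem.BufferPool instead of a fixed stack array. A toy stand-in for the borrowing pattern, assuming only the Get/Put shapes visible in this diff:

package main

import (
	"fmt"
	"sync"
)

// pool approximates mem.BufferPool: Get hands out a *[]byte of at least the
// requested size, Put returns it for reuse.
type pool struct{ p sync.Pool }

func (b *pool) Get(size int) *[]byte {
	if v, ok := b.p.Get().(*[]byte); ok && cap(*v) >= size {
		*v = (*v)[:size]
		return v
	}
	buf := make([]byte, size)
	return &buf
}

func (b *pool) Put(buf *[]byte) { b.p.Put(buf) }

func main() {
	var bp pool
	buf := bp.Get(16)
	defer bp.Put(buf)
	copy(*buf, "header+payload")
	fmt.Println(len(*buf), min(len(*buf), 9)) // min is a builtin since Go 1.21
}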

View File

@@ -92,14 +92,11 @@ func (f *trInFlow) newLimit(n uint32) uint32 {
func (f *trInFlow) onData(n uint32) uint32 {
f.unacked += n
if f.unacked >= f.limit/4 {
w := f.unacked
f.unacked = 0
if f.unacked < f.limit/4 {
f.updateEffectiveWindowSize()
return w
return 0
}
f.updateEffectiveWindowSize()
return 0
return f.reset()
}
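
The rewritten onData inverts the condition and funnels the update through reset(): acknowledge nothing until at least a quarter of the window has accumulated, then acknowledge everything at once. The policy in isolation:

package main

import "fmt"

type inFlow struct{ limit, unacked uint32 }

// onData returns how many bytes to acknowledge now: zero below the
// quarter-window threshold, the full unacked count once it is crossed.
func (f *inFlow) onData(n uint32) uint32 {
	f.unacked += n
	if f.unacked < f.limit/4 {
		return 0
	}
	w := f.unacked
	f.unacked = 0
	return w
}

func main() {
	f := &inFlow{limit: 64 * 1024}
	fmt.Println(f.onData(10_000)) // 0: below the 16 KiB threshold
	fmt.Println(f.onData(10_000)) // 20000: threshold crossed, ack all
}

Batching WINDOW_UPDATE frames this way keeps acknowledgment traffic proportional to the window size rather than to the number of DATA frames.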
func (f *trInFlow) reset() uint32 {

View File

@@ -24,7 +24,6 @@
package transport
import (
"bytes"
"context"
"errors"
"fmt"
@@ -40,6 +39,7 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
@@ -50,7 +50,7 @@ import (
// NewServerHandlerTransport returns a ServerTransport handling gRPC from
// inside an http.Handler, or writes an HTTP error to w and returns an error.
// It requires that the http Server supports HTTP/2.
func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []stats.Handler) (ServerTransport, error) {
func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []stats.Handler, bufferPool mem.BufferPool) (ServerTransport, error) {
if r.Method != http.MethodPost {
w.Header().Set("Allow", http.MethodPost)
msg := fmt.Sprintf("invalid gRPC request method %q", r.Method)
@@ -98,6 +98,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s
contentType: contentType,
contentSubtype: contentSubtype,
stats: stats,
bufferPool: bufferPool,
}
st.logger = prefixLoggerForServerHandlerTransport(st)
@@ -171,6 +172,8 @@ type serverHandlerTransport struct {
stats []stats.Handler
logger *grpclog.PrefixLogger
bufferPool mem.BufferPool
}
func (ht *serverHandlerTransport) Close(err error) {
@@ -222,7 +225,7 @@ func (ht *serverHandlerTransport) do(fn func()) error {
}
}
func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) error {
func (ht *serverHandlerTransport) writeStatus(s *ServerStream, st *status.Status) error {
ht.writeStatusMu.Lock()
defer ht.writeStatusMu.Unlock()
@@ -244,6 +247,7 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro
}
s.hdrMu.Lock()
defer s.hdrMu.Unlock()
if p := st.Proto(); p != nil && len(p.Details) > 0 {
delete(s.trailer, grpcStatusDetailsBinHeader)
stBytes, err := proto.Marshal(p)
@@ -268,7 +272,6 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro
}
}
}
s.hdrMu.Unlock()
})
if err == nil { // transport has not been closed
@@ -286,14 +289,14 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro
// writePendingHeaders sets common and custom headers on the first
// write call (Write, WriteHeader, or WriteStatus)
func (ht *serverHandlerTransport) writePendingHeaders(s *Stream) {
func (ht *serverHandlerTransport) writePendingHeaders(s *ServerStream) {
ht.writeCommonHeaders(s)
ht.writeCustomHeaders(s)
}
// writeCommonHeaders sets common headers on the first write
// call (Write, WriteHeader, or WriteStatus).
func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
func (ht *serverHandlerTransport) writeCommonHeaders(s *ServerStream) {
h := ht.rw.Header()
h["Date"] = nil // suppress Date to make tests happy; TODO: restore
h.Set("Content-Type", ht.contentType)
@@ -314,7 +317,7 @@ func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
// writeCustomHeaders sets custom headers set on the stream via SetHeader
// on the first write call (Write, WriteHeader, or WriteStatus)
func (ht *serverHandlerTransport) writeCustomHeaders(s *Stream) {
func (ht *serverHandlerTransport) writeCustomHeaders(s *ServerStream) {
h := ht.rw.Header()
s.hdrMu.Lock()
@@ -330,19 +333,31 @@ func (ht *serverHandlerTransport) writeCustomHeaders(s *Stream) {
s.hdrMu.Unlock()
}
func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
func (ht *serverHandlerTransport) write(s *ServerStream, hdr []byte, data mem.BufferSlice, _ *WriteOptions) error {
// Always take a reference because otherwise there is no guarantee the data will
// be available after this function returns. This is what callers to Write
// expect.
data.Ref()
headersWritten := s.updateHeaderSent()
return ht.do(func() {
err := ht.do(func() {
defer data.Free()
if !headersWritten {
ht.writePendingHeaders(s)
}
ht.rw.Write(hdr)
ht.rw.Write(data)
for _, b := range data {
_, _ = ht.rw.Write(b.ReadOnlyData())
}
ht.rw.(http.Flusher).Flush()
})
if err != nil {
data.Free()
return err
}
return nil
}
func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
func (ht *serverHandlerTransport) writeHeader(s *ServerStream, md metadata.MD) error {
if err := s.SetHeader(md); err != nil {
return err
}
@@ -370,7 +385,7 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
return err
}
func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream func(*Stream)) {
func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream func(*ServerStream)) {
// With this transport type there will be exactly 1 stream: this HTTP request.
var cancel context.CancelFunc
if ht.timeoutSet {
@@ -393,20 +408,22 @@ func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
req := ht.req
s := &Stream{
id: 0, // irrelevant
ctx: ctx,
requestRead: func(int) {},
s := &ServerStream{
Stream: &Stream{
id: 0, // irrelevant
ctx: ctx,
requestRead: func(int) {},
buf: newRecvBuffer(),
method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"),
contentSubtype: ht.contentSubtype,
},
cancel: cancel,
buf: newRecvBuffer(),
st: ht,
method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"),
contentSubtype: ht.contentSubtype,
headerWireLength: 0, // won't have access to header wire length until golang/go#18997.
}
s.trReader = &transportReader{
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}},
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf},
windowHandler: func(int) {},
}
@@ -415,21 +432,19 @@ func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream
go func() {
defer close(readerDone)
// TODO: minimize garbage, optimize recvBuffer code/ownership
const readSize = 8196
for buf := make([]byte, readSize); ; {
n, err := req.Body.Read(buf)
for {
buf := ht.bufferPool.Get(http2MaxFrameLen)
n, err := req.Body.Read(*buf)
if n > 0 {
s.buf.put(recvMsg{buffer: bytes.NewBuffer(buf[:n:n])})
buf = buf[n:]
*buf = (*buf)[:n]
s.buf.put(recvMsg{buffer: mem.NewBuffer(buf, ht.bufferPool)})
} else {
ht.bufferPool.Put(buf)
}
if err != nil {
s.buf.put(recvMsg{err: mapRecvMsgError(err)})
return
}
if len(buf) == 0 {
buf = make([]byte, readSize)
}
}
}()
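
The new loop borrows a pooled buffer per iteration and either hands ownership to the receive queue (truncated to n bytes) or returns the buffer immediately. A free-standing version of that ownership dance, with the pool elided:

package main

import (
	"fmt"
	"io"
	"strings"
)

func readLoop(r io.Reader, deliver func([]byte)) {
	const frameLen = 16384 // stand-in for http2MaxFrameLen
	for {
		buf := make([]byte, frameLen) // real code: bufferPool.Get(frameLen)
		n, err := r.Read(buf)
		if n > 0 {
			deliver(buf[:n]) // ownership moves to the consumer
		} // real code: otherwise return the unused buffer to the pool
		if err != nil {
			return
		}
	}
}

func main() {
	readLoop(strings.NewReader("hello gRPC"), func(b []byte) {
		fmt.Printf("got %d bytes: %q\n", len(b), b)
	})
}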
@@ -458,11 +473,9 @@ func (ht *serverHandlerTransport) runStream() {
}
}
func (ht *serverHandlerTransport) IncrMsgSent() {}
func (ht *serverHandlerTransport) incrMsgRecv() {}
func (ht *serverHandlerTransport) IncrMsgRecv() {}
func (ht *serverHandlerTransport) Drain(debugData string) {
func (ht *serverHandlerTransport) Drain(string) {
panic("Drain() is not implemented")
}
@@ -485,5 +498,5 @@ func mapRecvMsgError(err error) error {
if strings.Contains(err.Error(), "body closed by handler") {
return status.Error(codes.Canceled, err.Error())
}
return connectionErrorf(true, err, err.Error())
return connectionErrorf(true, err, "%s", err.Error())
}

View File

@@ -43,10 +43,12 @@ import (
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/grpcutil"
imetadata "google.golang.org/grpc/internal/metadata"
"google.golang.org/grpc/internal/proxyattributes"
istatus "google.golang.org/grpc/internal/status"
isyscall "google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/internal/transport/networktype"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver"
@@ -59,6 +61,8 @@ import (
// atomically.
var clientConnectionCounter uint64
var goAwayLoopyWriterTimeout = 5 * time.Second
var metadataFromOutgoingContextRaw = internal.FromOutgoingContextRaw.(func(context.Context) (metadata.MD, [][]string, bool))
// http2Client implements the ClientTransport interface with HTTP2.
@@ -83,9 +87,9 @@ type http2Client struct {
writerDone chan struct{} // sync point to enable testing.
// goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor)
// that the server sent GoAway on this transport.
goAway chan struct{}
framer *framer
goAway chan struct{}
keepaliveDone chan struct{} // Closed when the keepalive goroutine exits.
framer *framer
// controlBuf delivers all the control related tasks (e.g., window
// updates, reset streams, and various settings) to the controller.
// Do not access controlBuf with mu held.
@@ -114,13 +118,13 @@ type http2Client struct {
streamQuota int64
streamsQuotaAvailable chan struct{}
waitingStreams uint32
nextID uint32
registeredCompressors string
// Do not access controlBuf with mu held.
mu sync.Mutex // guard the following variables
nextID uint32
state transportState
activeStreams map[uint32]*Stream
activeStreams map[uint32]*ClientStream
// prevGoAway ID records the Last-Stream-ID in the previous GOAway frame.
prevGoAwayID uint32
// goAwayReason records the http2.ErrCode and debug data received with the
@@ -144,13 +148,13 @@ type http2Client struct {
onClose func(GoAwayReason)
bufferPool *bufferPool
bufferPool mem.BufferPool
connectionID uint64
logger *grpclog.PrefixLogger
}
func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr resolver.Address, useProxy bool, grpcUA string) (net.Conn, error) {
func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr resolver.Address, grpcUA string) (net.Conn, error) {
address := addr.Addr
networkType, ok := networktype.Get(addr)
if fn != nil {
@@ -172,10 +176,10 @@ func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error
return fn(ctx, address)
}
if !ok {
networkType, address = parseDialTarget(address)
networkType, address = ParseDialTarget(address)
}
if networkType == "tcp" && useProxy {
return proxyDial(ctx, address, grpcUA)
if opts, present := proxyattributes.Get(addr); present {
return proxyDial(ctx, addr, grpcUA, opts)
}
return internal.NetDialerWithTCPKeepalive().DialContext(ctx, networkType, address)
}
@@ -196,10 +200,10 @@ func isTemporary(err error) bool {
return true
}
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
// NewHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
// and starts to receive messages on it. A non-nil error is returned if
// construction fails.
func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts ConnectOptions, onClose func(GoAwayReason)) (_ *http2Client, err error) {
func NewHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts ConnectOptions, onClose func(GoAwayReason)) (_ ClientTransport, err error) {
scheme := "http"
ctx, cancel := context.WithCancel(ctx)
defer func() {
@@ -214,7 +218,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
// address specific arbitrary data to reach custom dialers and credential handshakers.
connectCtx = icredentials.NewClientHandshakeInfoContext(connectCtx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
conn, err := dial(connectCtx, opts.Dialer, addr, opts.UseProxy, opts.UserAgent)
conn, err := dial(connectCtx, opts.Dialer, addr, opts.UserAgent)
if err != nil {
if opts.FailOnNonTempDialError {
return nil, connectionErrorf(isTemporary(err), err, "transport: error while dialing: %v", err)
@@ -229,7 +233,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
}
}(conn)
// The following defer and goroutine monitor the connectCtx for cancelation
// The following defer and goroutine monitor the connectCtx for cancellation
// and deadline. On context expiration, the connection is hard closed and
// this function will naturally fail as a result. Otherwise, the defer
// waits for the goroutine to exit to prevent the context from being
@@ -332,10 +336,11 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
readerDone: make(chan struct{}),
writerDone: make(chan struct{}),
goAway: make(chan struct{}),
keepaliveDone: make(chan struct{}),
framer: newFramer(conn, writeBufSize, readBufSize, opts.SharedWriteBuffer, maxHeaderListSize),
fc: &trInFlow{limit: uint32(icwz)},
scheme: scheme,
activeStreams: make(map[uint32]*Stream),
activeStreams: make(map[uint32]*ClientStream),
isSecure: isSecure,
perRPCCreds: perRPCCreds,
kp: kp,
@@ -346,7 +351,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
streamQuota: defaultMaxStreamsClient,
streamsQuotaAvailable: make(chan struct{}, 1),
keepaliveEnabled: keepaliveEnabled,
bufferPool: newBufferPool(),
bufferPool: opts.BufferPool,
onClose: onClose,
}
var czSecurity credentials.ChannelzSecurityValue
@@ -408,10 +413,10 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
readerErrCh := make(chan error, 1)
go t.reader(readerErrCh)
defer func() {
if err == nil {
err = <-readerErrCh
}
if err != nil {
// writerDone should be closed since the loopy goroutine
// wouldn't have started in the case this function returns an error.
close(t.writerDone)
t.Close(err)
}
}()
@@ -458,8 +463,12 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
if err := t.framer.writer.Flush(); err != nil {
return nil, err
}
// Block until the server preface is received successfully or an error occurs.
if err = <-readerErrCh; err != nil {
return nil, err
}
go func() {
t.loopy = newLoopyWriter(clientSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger)
t.loopy = newLoopyWriter(clientSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger, t.outgoingGoAwayHandler, t.bufferPool)
if err := t.loopy.run(); !isIOError(err) {
// Immediately close the connection, as the loopy writer returns
// when there are no more active streams and we were draining (the
@@ -472,17 +481,19 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
return t, nil
}
func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *ClientStream {
// TODO(zhaoq): Handle uint32 overflow of Stream.id.
s := &Stream{
ct: t,
done: make(chan struct{}),
method: callHdr.Method,
sendCompress: callHdr.SendCompress,
buf: newRecvBuffer(),
headerChan: make(chan struct{}),
contentSubtype: callHdr.ContentSubtype,
doneFunc: callHdr.DoneFunc,
s := &ClientStream{
Stream: &Stream{
method: callHdr.Method,
sendCompress: callHdr.SendCompress,
buf: newRecvBuffer(),
contentSubtype: callHdr.ContentSubtype,
},
ct: t,
done: make(chan struct{}),
headerChan: make(chan struct{}),
doneFunc: callHdr.DoneFunc,
}
s.wq = newWriteQuota(defaultWriteQuota, s.done)
s.requestRead = func(n int) {
@@ -498,9 +509,8 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
ctxDone: s.ctx.Done(),
recv: s.buf,
closeStream: func(err error) {
t.CloseStream(s, err)
s.Close(err)
},
freeBuffer: t.bufferPool.put,
},
windowHandler: func(n int) {
t.updateWindow(s, uint32(n))
@@ -517,6 +527,18 @@ func (t *http2Client) getPeer() *peer.Peer {
}
}
// OutgoingGoAwayHandler writes a GOAWAY to the connection. Always returns (false, err) as we want the GoAway
// to be the last frame loopy writes to the transport.
func (t *http2Client) outgoingGoAwayHandler(g *goAway) (bool, error) {
t.mu.Lock()
maxStreamID := t.nextID - 2
t.mu.Unlock()
if err := t.framer.fr.WriteGoAway(maxStreamID, http2.ErrCodeNo, g.debugData); err != nil {
return false, err
}
return false, g.closeConn
}
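
Close() further below relies on this handler: it queues a goAway control frame, then waits on writerDone with a timer so a blackholed TCP connection cannot stall shutdown forever. The wait shape, using the 5s goAwayLoopyWriterTimeout declared above:

package main

import (
	"fmt"
	"time"
)

func main() {
	writerDone := make(chan struct{}) // closed by the writer goroutine in real code
	timeout := 5 * time.Second        // mirrors goAwayLoopyWriterTimeout

	// Simulate the writer finishing promptly.
	go func() { close(writerDone) }()

	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case <-writerDone:
		fmt.Println("GOAWAY flushed; closing connection")
	case <-timer.C:
		fmt.Println("writer stuck; giving up and closing anyway")
	}
}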
func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) ([]hpack.HeaderField, error) {
aud := t.createAudience(callHdr)
ri := credentials.RequestInfo{
@@ -578,12 +600,6 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
for k, v := range callAuthData {
headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
if b := stats.OutgoingTags(ctx); b != nil {
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-tags-bin", Value: encodeBinHeader(b)})
}
if b := stats.OutgoingTrace(ctx); b != nil {
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-trace-bin", Value: encodeBinHeader(b)})
}
if md, added, ok := metadataFromOutgoingContextRaw(ctx); ok {
var k string
@@ -719,7 +735,7 @@ func (e NewStreamError) Error() string {
// NewStream creates a stream and registers it into the transport as "active"
// streams. All non-nil errors returned will be *NewStreamError.
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error) {
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientStream, error) {
ctx = peer.NewContext(ctx, t.getPeer())
// ServerName field of the resolver returned address takes precedence over
@@ -744,7 +760,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
return
}
// The stream was unprocessed by the server.
atomic.StoreUint32(&s.unprocessed, 1)
s.unprocessed.Store(true)
s.write(recvMsg{err: err})
close(s.done)
// If headerChan isn't closed, then close it.
@@ -755,7 +771,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
hdr := &headerFrame{
hf: headerFields,
endStream: false,
initStream: func(id uint32) error {
initStream: func(uint32) error {
t.mu.Lock()
// TODO: handle transport closure in loopy instead and remove this
// initStream is never called when transport is draining.
@@ -781,7 +797,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
firstTry := true
var ch chan struct{}
transportDrainRequired := false
checkForStreamQuota := func(it any) bool {
checkForStreamQuota := func() bool {
if t.streamQuota <= 0 { // Can go negative if server decreases it.
if firstTry {
t.waitingStreams++
@@ -793,23 +809,24 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
t.waitingStreams--
}
t.streamQuota--
h := it.(*headerFrame)
h.streamID = t.nextID
t.nextID += 2
// Drain client transport if nextID > MaxStreamID which signals gRPC that
// the connection is closed and a new one must be created for subsequent RPCs.
transportDrainRequired = t.nextID > MaxStreamID
s.id = h.streamID
s.fc = &inFlow{limit: uint32(t.initialWindowSize)}
t.mu.Lock()
if t.state == draining || t.activeStreams == nil { // Can be niled from Close().
t.mu.Unlock()
return false // Don't create a stream if the transport is already closed.
}
hdr.streamID = t.nextID
t.nextID += 2
// Drain client transport if nextID > MaxStreamID which signals gRPC that
// the connection is closed and a new one must be created for subsequent RPCs.
transportDrainRequired = t.nextID > MaxStreamID
s.id = hdr.streamID
s.fc = &inFlow{limit: uint32(t.initialWindowSize)}
t.activeStreams[s.id] = s
t.mu.Unlock()
if t.streamQuota > 0 && t.waitingStreams > 0 {
select {
case t.streamsQuotaAvailable <- struct{}{}:
@@ -819,13 +836,12 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
return true
}
var hdrListSizeErr error
checkForHeaderListSize := func(it any) bool {
checkForHeaderListSize := func() bool {
if t.maxSendHeaderListSize == nil {
return true
}
hdrFrame := it.(*headerFrame)
var sz int64
for _, f := range hdrFrame.hf {
for _, f := range hdr.hf {
if sz += int64(f.Size()); sz > int64(*t.maxSendHeaderListSize) {
hdrListSizeErr = status.Errorf(codes.Internal, "header list size to send violates the maximum size (%d bytes) set by server", *t.maxSendHeaderListSize)
return false
@@ -834,8 +850,8 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
return true
}
for {
success, err := t.controlBuf.executeAndPut(func(it any) bool {
return checkForHeaderListSize(it) && checkForStreamQuota(it)
success, err := t.controlBuf.executeAndPut(func() bool {
return checkForHeaderListSize() && checkForStreamQuota()
}, hdr)
if err != nil {
// Connection closed.
@@ -889,21 +905,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
return s, nil
}
// CloseStream clears the footprint of a stream when the stream is not needed any more.
// This must not be executed in reader's goroutine.
func (t *http2Client) CloseStream(s *Stream, err error) {
var (
rst bool
rstCode http2.ErrCode
)
if err != nil {
rst = true
rstCode = http2.ErrCodeCancel
}
t.closeStream(s, err, rst, rstCode, status.Convert(err), nil, false)
}
func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.ErrCode, st *status.Status, mdata map[string][]string, eosReceived bool) {
func (t *http2Client) closeStream(s *ClientStream, err error, rst bool, rstCode http2.ErrCode, st *status.Status, mdata map[string][]string, eosReceived bool) {
// Set stream status to done.
if s.swapState(streamDone) == streamDone {
// If it was already done, return. If multiple closeStream calls
@@ -946,7 +948,7 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.
rst: rst,
rstCode: rstCode,
}
addBackStreamQuota := func(any) bool {
addBackStreamQuota := func() bool {
t.streamQuota++
if t.streamQuota > 0 && t.waitingStreams > 0 {
select {
@@ -966,8 +968,9 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.
// Close kicks off the shutdown process of the transport. This should be called
// only once on a transport. Once it is called, the transport should not be
// accessed any more.
// accessed anymore.
func (t *http2Client) Close(err error) {
t.conn.SetWriteDeadline(time.Now().Add(time.Second * 10))
t.mu.Lock()
// Make sure we only close once.
if t.state == closing {
@@ -990,15 +993,33 @@ func (t *http2Client) Close(err error) {
// should unblock it so that the goroutine eventually exits.
t.kpDormancyCond.Signal()
}
t.mu.Unlock()
t.controlBuf.finish()
t.cancel()
t.conn.Close()
channelz.RemoveEntry(t.channelz.ID)
// Append info about previous goaways if there were any, since this may be important
// for understanding the root cause of this connection being closed.
_, goAwayDebugMessage := t.GetGoAwayReason()
goAwayDebugMessage := t.goAwayDebugMessage
t.mu.Unlock()
// Per HTTP/2 spec, a GOAWAY frame must be sent before closing the
// connection. See https://httpwg.org/specs/rfc7540.html#GOAWAY. Close
// also waits for loopyWriter to exit, bounded by a timer, to avoid
// blocking for a long time in case the connection is blackholed, i.e.
// TCP is just stuck.
t.controlBuf.put(&goAway{code: http2.ErrCodeNo, debugData: []byte("client transport shutdown"), closeConn: err})
timer := time.NewTimer(goAwayLoopyWriterTimeout)
defer timer.Stop()
select {
case <-t.writerDone: // success
case <-timer.C:
t.logger.Infof("Failed to write a GOAWAY frame as part of connection close after %s. Giving up and closing the transport.", goAwayLoopyWriterTimeout)
}
t.cancel()
t.conn.Close()
// Waits for the reader and keepalive goroutines to exit before returning to
// ensure all resources are cleaned up before Close can return.
<-t.readerDone
if t.keepaliveEnabled {
<-t.keepaliveDone
}
channelz.RemoveEntry(t.channelz.ID)
var st *status.Status
if len(goAwayDebugMessage) > 0 {
st = status.Newf(codes.Unavailable, "closing transport due to: %v, received prior goaway: %v", err, goAwayDebugMessage)
@@ -1047,30 +1068,40 @@ func (t *http2Client) GracefulClose() {
// Write formats the data into HTTP2 data frame(s) and sends it out. The caller
// should proceed only if Write returns nil.
func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
func (t *http2Client) write(s *ClientStream, hdr []byte, data mem.BufferSlice, opts *WriteOptions) error {
reader := data.Reader()
if opts.Last {
// If it's the last message, update stream state.
if !s.compareAndSwapState(streamActive, streamWriteDone) {
_ = reader.Close()
return errStreamDone
}
} else if s.getState() != streamActive {
_ = reader.Close()
return errStreamDone
}
df := &dataFrame{
streamID: s.id,
endStream: opts.Last,
h: hdr,
d: data,
reader: reader,
}
if hdr != nil || data != nil { // If it's not an empty data frame, check quota.
if err := s.wq.get(int32(len(hdr) + len(data))); err != nil {
if hdr != nil || df.reader.Remaining() != 0 { // If it's not an empty data frame, check quota.
if err := s.wq.get(int32(len(hdr) + df.reader.Remaining())); err != nil {
_ = reader.Close()
return err
}
}
return t.controlBuf.put(df)
if err := t.controlBuf.put(df); err != nil {
_ = reader.Close()
return err
}
t.incrMsgSent()
return nil
}
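// A minimal usage sketch for the method above (hypothetical payload;
// mem.SliceBuffer is assumed to be the plain byte-slice Buffer
// implementation from the mem package):
//
//	data := mem.BufferSlice{mem.SliceBuffer([]byte("hello"))}
//	if err := t.write(s, nil, data, &WriteOptions{Last: true}); err != nil {
//		// write closes its reader on every error path, so only err
//		// handling is needed here.
//	}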
func (t *http2Client) getStream(f http2.Frame) *Stream {
func (t *http2Client) getStream(f http2.Frame) *ClientStream {
t.mu.Lock()
s := t.activeStreams[f.Header().StreamID]
t.mu.Unlock()
@@ -1080,7 +1111,7 @@ func (t *http2Client) getStream(f http2.Frame) *Stream {
// adjustWindow sends out an extra window update over the initial window size
// of the stream if the application is requesting data larger in size than
// the window.
func (t *http2Client) adjustWindow(s *Stream, n uint32) {
func (t *http2Client) adjustWindow(s *ClientStream, n uint32) {
if w := s.fc.maybeAdjust(n); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w})
}
@@ -1089,7 +1120,7 @@ func (t *http2Client) adjustWindow(s *Stream, n uint32) {
// updateWindow adjusts the inbound quota for the stream.
// Window updates will be sent out when the cumulative quota
// exceeds the corresponding threshold.
func (t *http2Client) updateWindow(s *Stream, n uint32) {
func (t *http2Client) updateWindow(s *ClientStream, n uint32) {
if w := s.fc.onRead(n); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w})
}
@@ -1099,7 +1130,7 @@ func (t *http2Client) updateWindow(s *Stream, n uint32) {
// for the transport and the stream based on the current bdp
// estimation.
func (t *http2Client) updateFlowControl(n uint32) {
updateIWS := func(any) bool {
updateIWS := func() bool {
t.initialWindowSize = int32(n)
t.mu.Lock()
for _, s := range t.activeStreams {
@@ -1172,10 +1203,13 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
// guarantee f.Data() is consumed before the arrival of next frame.
// Can this copy be eliminated?
if len(f.Data()) > 0 {
buffer := t.bufferPool.get()
buffer.Reset()
buffer.Write(f.Data())
s.write(recvMsg{buffer: buffer})
pool := t.bufferPool
if pool == nil {
// Note that this is only supposed to be nil in tests. Otherwise, stream is
// always initialized with a BufferPool.
pool = mem.DefaultBufferPool()
}
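// mem.Copy copies the frame payload into a buffer obtained from the pool,
// so the framer can reuse f.Data()'s backing array; the receiving side
// releases it later via Buffer.Free(). A minimal sketch (hypothetical bytes):
//
//	buf := mem.Copy([]byte("payload"), mem.DefaultBufferPool())
//	defer buf.Free()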
s.write(recvMsg{buffer: mem.Copy(f.Data(), pool)})
}
}
// The server has closed the stream without sending trailers. Record that
@@ -1192,7 +1226,7 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
}
if f.ErrCode == http2.ErrCodeRefusedStream {
// The stream was unprocessed by the server.
atomic.StoreUint32(&s.unprocessed, 1)
s.unprocessed.Store(true)
}
statusCode, ok := http2ErrConvTab[f.ErrCode]
if !ok {
@@ -1204,11 +1238,12 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
if statusCode == codes.Canceled {
if d, ok := s.ctx.Deadline(); ok && !d.After(time.Now()) {
// Our deadline was already exceeded, and that was likely the cause
// of this cancelation. Alter the status code accordingly.
// of this cancellation. Alter the status code accordingly.
statusCode = codes.DeadlineExceeded
}
}
t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode), nil, false)
st := status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode)
t.closeStream(s, st.Err(), false, http2.ErrCodeNo, st, nil, false)
}
func (t *http2Client) handleSettings(f *http2.SettingsFrame, isFirst bool) {
@@ -1252,7 +1287,7 @@ func (t *http2Client) handleSettings(f *http2.SettingsFrame, isFirst bool) {
}
updateFuncs = append(updateFuncs, updateStreamQuota)
}
t.controlBuf.executeAndPut(func(any) bool {
t.controlBuf.executeAndPut(func() bool {
for _, f := range updateFuncs {
f()
}
@@ -1273,11 +1308,11 @@ func (t *http2Client) handlePing(f *http2.PingFrame) {
t.controlBuf.put(pingAck)
}
func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) error {
t.mu.Lock()
if t.state == closing {
t.mu.Unlock()
return
return nil
}
if f.ErrCode == http2.ErrCodeEnhanceYourCalm && string(f.DebugData()) == "too_many_pings" {
// When a client receives a GOAWAY with error code ENHANCE_YOUR_CALM and debug
@@ -1289,8 +1324,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
id := f.LastStreamID
if id > 0 && id%2 == 0 {
t.mu.Unlock()
t.Close(connectionErrorf(true, nil, "received goaway with non-zero even-numbered numbered stream id: %v", id))
return
return connectionErrorf(true, nil, "received goaway with non-zero even-numbered stream id: %v", id)
}
// A client can receive multiple GoAways from the server (see
// https://github.com/grpc/grpc-go/issues/1387). The idea is that the first
@@ -1307,8 +1341,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
// If there are multiple GoAways the first one should always have an ID greater than the following ones.
if id > t.prevGoAwayID {
t.mu.Unlock()
t.Close(connectionErrorf(true, nil, "received goaway with stream id: %v, which exceeds stream id of previous goaway: %v", id, t.prevGoAwayID))
return
return connectionErrorf(true, nil, "received goaway with stream id: %v, which exceeds stream id of previous goaway: %v", id, t.prevGoAwayID)
}
default:
t.setGoAwayReason(f)
@@ -1332,15 +1365,14 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
t.prevGoAwayID = id
if len(t.activeStreams) == 0 {
t.mu.Unlock()
t.Close(connectionErrorf(true, nil, "received goaway and there are no active streams"))
return
return connectionErrorf(true, nil, "received goaway and there are no active streams")
}
streamsToClose := make([]*Stream, 0)
streamsToClose := make([]*ClientStream, 0)
for streamID, stream := range t.activeStreams {
if streamID > id && streamID <= upperLimit {
// The stream was unprocessed by the server.
atomic.StoreUint32(&stream.unprocessed, 1)
stream.unprocessed.Store(true)
streamsToClose = append(streamsToClose, stream)
}
}
@@ -1350,6 +1382,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
for _, stream := range streamsToClose {
t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false)
}
return nil
}
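// Returning the error (rather than calling t.Close inline, as before) lets
// the reader goroutine close the transport after it has signaled readerDone;
// Close now blocks on <-t.readerDone, so closing from inside the reader
// would deadlock.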
// setGoAwayReason sets the value of t.goAwayReason based
@@ -1358,8 +1391,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
// the caller.
func (t *http2Client) setGoAwayReason(f *http2.GoAwayFrame) {
t.goAwayReason = GoAwayNoReason
switch f.ErrCode {
case http2.ErrCodeEnhanceYourCalm:
if f.ErrCode == http2.ErrCodeEnhanceYourCalm {
if string(f.DebugData()) == "too_many_pings" {
t.goAwayReason = GoAwayTooManyPings
}
@@ -1391,7 +1423,7 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
return
}
endStream := frame.StreamEnded()
atomic.StoreUint32(&s.bytesReceived, 1)
s.bytesReceived.Store(true)
initialHeader := atomic.LoadUint32(&s.headerChanClosed) == 0
if !initialHeader && !endStream {
@@ -1585,7 +1617,13 @@ func (t *http2Client) readServerPreface() error {
// network connection. If the server preface is not read successfully, an
// error is pushed to errCh; otherwise errCh is closed with no error.
func (t *http2Client) reader(errCh chan<- error) {
defer close(t.readerDone)
var errClose error
defer func() {
close(t.readerDone)
if errClose != nil {
t.Close(errClose)
}
}()
if err := t.readServerPreface(); err != nil {
errCh <- err
@@ -1624,11 +1662,10 @@ func (t *http2Client) reader(errCh chan<- error) {
t.closeStream(s, status.Error(code, msg), true, http2.ErrCodeProtocol, status.New(code, msg), nil, false)
}
continue
} else {
// Transport error.
t.Close(connectionErrorf(true, err, "error reading from server: %v", err))
return
}
// Transport error.
errClose = connectionErrorf(true, err, "error reading from server: %v", err)
return
}
switch frame := frame.(type) {
case *http2.MetaHeadersFrame:
@@ -1642,7 +1679,7 @@ func (t *http2Client) reader(errCh chan<- error) {
case *http2.PingFrame:
t.handlePing(frame)
case *http2.GoAwayFrame:
t.handleGoAway(frame)
errClose = t.handleGoAway(frame)
case *http2.WindowUpdateFrame:
t.handleWindowUpdate(frame)
default:
@@ -1653,15 +1690,15 @@ func (t *http2Client) reader(errCh chan<- error) {
}
}
func minTime(a, b time.Duration) time.Duration {
if a < b {
return a
}
return b
}
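// minTime is no longer needed: the generic built-in min (available since
// Go 1.21) is used below instead.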
// keepalive running in a separate goroutine makes sure the connection is alive by sending pings.
func (t *http2Client) keepalive() {
var err error
defer func() {
close(t.keepaliveDone)
if err != nil {
t.Close(err)
}
}()
p := &ping{data: [8]byte{}}
// True iff a ping has been sent, and no data has been received since then.
outstandingPing := false
@@ -1685,7 +1722,7 @@ func (t *http2Client) keepalive() {
continue
}
if outstandingPing && timeoutLeft <= 0 {
t.Close(connectionErrorf(true, nil, "keepalive ping failed to receive ACK within timeout"))
err = connectionErrorf(true, nil, "keepalive ping failed to receive ACK within timeout")
return
}
t.mu.Lock()
@@ -1727,7 +1764,7 @@ func (t *http2Client) keepalive() {
// timeoutLeft. This will ensure that we wait only for kp.Time
// before sending out the next ping (for cases where the ping is
// acked).
sleepDuration := minTime(t.kp.Time, timeoutLeft)
sleepDuration := min(t.kp.Time, timeoutLeft)
timeoutLeft -= sleepDuration
timer.Reset(sleepDuration)
case <-t.ctx.Done():
@@ -1756,14 +1793,18 @@ func (t *http2Client) socketMetrics() *channelz.EphemeralSocketMetrics {
func (t *http2Client) RemoteAddr() net.Addr { return t.remoteAddr }
func (t *http2Client) IncrMsgSent() {
t.channelz.SocketMetrics.MessagesSent.Add(1)
t.channelz.SocketMetrics.LastMessageSentTimestamp.Store(time.Now().UnixNano())
func (t *http2Client) incrMsgSent() {
if channelz.IsOn() {
t.channelz.SocketMetrics.MessagesSent.Add(1)
t.channelz.SocketMetrics.LastMessageSentTimestamp.Store(time.Now().UnixNano())
}
}
func (t *http2Client) IncrMsgRecv() {
t.channelz.SocketMetrics.MessagesReceived.Add(1)
t.channelz.SocketMetrics.LastMessageReceivedTimestamp.Store(time.Now().UnixNano())
func (t *http2Client) incrMsgRecv() {
if channelz.IsOn() {
t.channelz.SocketMetrics.MessagesReceived.Add(1)
t.channelz.SocketMetrics.LastMessageReceivedTimestamp.Store(time.Now().UnixNano())
}
}
func (t *http2Client) getOutFlowWindow() int64 {


@@ -25,6 +25,7 @@ import (
"fmt"
"io"
"math"
rand "math/rand/v2"
"net"
"net/http"
"strconv"
@@ -34,16 +35,17 @@ import (
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/internal/pretty"
"google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/mem"
"google.golang.org/protobuf/proto"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
@@ -110,7 +112,7 @@ type http2Server struct {
// already initialized since draining is already underway.
drainEvent *grpcsync.Event
state transportState
activeStreams map[uint32]*Stream
activeStreams map[uint32]*ServerStream
// idle is the time instant when the connection went idle.
// This is either the beginning of the connection or when the number of
// RPCs go down to 0.
@@ -119,7 +121,7 @@ type http2Server struct {
// Fields below are for channelz metric collection.
channelz *channelz.Socket
bufferPool *bufferPool
bufferPool mem.BufferPool
connectionID uint64
@@ -255,13 +257,13 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
inTapHandle: config.InTapHandle,
fc: &trInFlow{limit: uint32(icwz)},
state: reachable,
activeStreams: make(map[uint32]*Stream),
activeStreams: make(map[uint32]*ServerStream),
stats: config.StatsHandlers,
kp: kp,
idle: time.Now(),
kep: kep,
initialWindowSize: iwz,
bufferPool: newBufferPool(),
bufferPool: config.BufferPool,
}
var czSecurity credentials.ChannelzSecurityValue
if au, ok := authInfo.(credentials.ChannelzSecurityInfo); ok {
@@ -330,8 +332,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
t.handleSettings(sf)
go func() {
t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger)
t.loopy.ssGoAwayHandler = t.outgoingGoAwayHandler
t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger, t.outgoingGoAwayHandler, t.bufferPool)
err := t.loopy.run()
close(t.loopyWriterDone)
if !isIOError(err) {
@@ -359,7 +360,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
// operateHeaders takes action on the decoded headers. Returns an error if fatal
// error encountered and transport needs to close, otherwise returns nil.
func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeadersFrame, handle func(*Stream)) error {
func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeadersFrame, handle func(*ServerStream)) error {
// Acquire max stream ID lock for entire duration
t.maxStreamMu.Lock()
defer t.maxStreamMu.Unlock()
@@ -385,11 +386,13 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
t.maxStreamID = streamID
buf := newRecvBuffer()
s := &Stream{
id: streamID,
s := &ServerStream{
Stream: &Stream{
id: streamID,
buf: buf,
fc: &inFlow{limit: uint32(t.initialWindowSize)},
},
st: t,
buf: buf,
fc: &inFlow{limit: uint32(t.initialWindowSize)},
headerWireLength: int(frame.Header().Length),
}
var (
@@ -537,12 +540,6 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
// Attach the received metadata to the context.
if len(mdata) > 0 {
s.ctx = metadata.NewIncomingContext(s.ctx, mdata)
if statsTags := mdata["grpc-tags-bin"]; len(statsTags) > 0 {
s.ctx = stats.SetIncomingTags(s.ctx, []byte(statsTags[len(statsTags)-1]))
}
if statsTrace := mdata["grpc-trace-bin"]; len(statsTrace) > 0 {
s.ctx = stats.SetIncomingTrace(s.ctx, []byte(statsTrace[len(statsTrace)-1]))
}
}
t.mu.Lock()
if t.state != reachable {
@@ -568,7 +565,7 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
t.logger.Infof("Aborting the stream early: %v", errMsg)
}
t.controlBuf.put(&earlyAbortStream{
httpStatus: 405,
httpStatus: http.StatusMethodNotAllowed,
streamID: streamID,
contentSubtype: s.contentSubtype,
status: status.New(codes.Internal, errMsg),
@@ -589,7 +586,7 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
stat = status.New(codes.PermissionDenied, err.Error())
}
t.controlBuf.put(&earlyAbortStream{
httpStatus: 200,
httpStatus: http.StatusOK,
streamID: s.id,
contentSubtype: s.contentSubtype,
status: stat,
@@ -602,6 +599,22 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
if len(t.activeStreams) == 1 {
t.idle = time.Time{}
}
// Start a timer to close the stream on reaching the deadline.
if timeoutSet {
// We need to wait for s.cancel to be updated before calling
// t.closeStream to avoid data races.
cancelUpdated := make(chan struct{})
timer := internal.TimeAfterFunc(timeout, func() {
<-cancelUpdated
t.closeStream(s, true, http2.ErrCodeCancel, false)
})
oldCancel := s.cancel
s.cancel = func() {
oldCancel()
timer.Stop()
}
close(cancelUpdated)
}
t.mu.Unlock()
if channelz.IsOn() {
t.channelz.SocketMetrics.StreamsStarted.Add(1)
@@ -614,10 +627,9 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
s.trReader = &transportReader{
reader: &recvBufferReader{
ctx: s.ctx,
ctxDone: s.ctxDone,
recv: s.buf,
freeBuffer: t.bufferPool.put,
ctx: s.ctx,
ctxDone: s.ctxDone,
recv: s.buf,
},
windowHandler: func(n int) {
t.updateWindow(s, uint32(n))
@@ -635,7 +647,7 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
// HandleStreams receives incoming streams using the given handler. This is
// typically run in a separate goroutine.
// traceCtx attaches trace to ctx and returns the new context.
func (t *http2Server) HandleStreams(ctx context.Context, handle func(*Stream)) {
func (t *http2Server) HandleStreams(ctx context.Context, handle func(*ServerStream)) {
defer func() {
close(t.readerDone)
<-t.loopyWriterDone
@@ -699,7 +711,7 @@ func (t *http2Server) HandleStreams(ctx context.Context, handle func(*Stream)) {
}
}
func (t *http2Server) getStream(f http2.Frame) (*Stream, bool) {
func (t *http2Server) getStream(f http2.Frame) (*ServerStream, bool) {
t.mu.Lock()
defer t.mu.Unlock()
if t.activeStreams == nil {
@@ -717,7 +729,7 @@ func (t *http2Server) getStream(f http2.Frame) (*Stream, bool) {
// adjustWindow sends out an extra window update over the initial window size
// of the stream if the application is requesting data larger in size than
// the window.
func (t *http2Server) adjustWindow(s *Stream, n uint32) {
func (t *http2Server) adjustWindow(s *ServerStream, n uint32) {
if w := s.fc.maybeAdjust(n); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w})
}
@@ -727,7 +739,7 @@ func (t *http2Server) adjustWindow(s *Stream, n uint32) {
// updateWindow adjusts the inbound quota for the stream and the transport.
// Window updates will be delivered to the controller for sending when
// the cumulative quota exceeds the corresponding threshold.
func (t *http2Server) updateWindow(s *Stream, n uint32) {
func (t *http2Server) updateWindow(s *ServerStream, n uint32) {
if w := s.fc.onRead(n); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id,
increment: w,
@@ -814,10 +826,13 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
// guarantee f.Data() is consumed before the arrival of next frame.
// Can this copy be eliminated?
if len(f.Data()) > 0 {
buffer := t.bufferPool.get()
buffer.Reset()
buffer.Write(f.Data())
s.write(recvMsg{buffer: buffer})
pool := t.bufferPool
if pool == nil {
// Note that this is only supposed to be nil in tests. Otherwise, stream is
// always initialized with a BufferPool.
pool = mem.DefaultBufferPool()
}
s.write(recvMsg{buffer: mem.Copy(f.Data(), pool)})
}
}
if f.StreamEnded() {
@@ -860,7 +875,7 @@ func (t *http2Server) handleSettings(f *http2.SettingsFrame) {
}
return nil
})
t.controlBuf.executeAndPut(func(any) bool {
t.controlBuf.executeAndPut(func() bool {
for _, f := range updateFuncs {
f()
}
@@ -961,7 +976,7 @@ func (t *http2Server) checkForHeaderListSize(it any) bool {
return true
}
func (t *http2Server) streamContextErr(s *Stream) error {
func (t *http2Server) streamContextErr(s *ServerStream) error {
select {
case <-t.done:
return ErrConnClosing
@@ -971,7 +986,7 @@ func (t *http2Server) streamContextErr(s *Stream) error {
}
// WriteHeader sends the header metadata md back to the client.
func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
func (t *http2Server) writeHeader(s *ServerStream, md metadata.MD) error {
s.hdrMu.Lock()
defer s.hdrMu.Unlock()
if s.getState() == streamDone {
@@ -1004,7 +1019,7 @@ func (t *http2Server) setResetPingStrikes() {
atomic.StoreUint32(&t.resetPingStrikes, 1)
}
func (t *http2Server) writeHeaderLocked(s *Stream) error {
func (t *http2Server) writeHeaderLocked(s *ServerStream) error {
// TODO(mmukhi): Benchmark if the performance gets better if we count the metadata and other header fields
// first and create a slice of that exact size.
headerFields := make([]hpack.HeaderField, 0, 2) // at least :status and content-type will be there if nothing else.
@@ -1014,12 +1029,13 @@ func (t *http2Server) writeHeaderLocked(s *Stream) error {
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
}
headerFields = appendHeaderFieldsFromMD(headerFields, s.header)
success, err := t.controlBuf.executeAndPut(t.checkForHeaderListSize, &headerFrame{
hf := &headerFrame{
streamID: s.id,
hf: headerFields,
endStream: false,
onWrite: t.setResetPingStrikes,
})
}
success, err := t.controlBuf.executeAndPut(func() bool { return t.checkForHeaderListSize(hf) }, hf)
if !success {
if err != nil {
return err
@@ -1043,7 +1059,7 @@ func (t *http2Server) writeHeaderLocked(s *Stream) error {
// No further I/O operations can be performed on this stream.
// TODO(zhaoq): Now it indicates the end of the entire stream. Revisit if early
// OK is adopted.
func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
func (t *http2Server) writeStatus(s *ServerStream, st *status.Status) error {
s.hdrMu.Lock()
defer s.hdrMu.Unlock()
@@ -1089,7 +1105,9 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
onWrite: t.setResetPingStrikes,
}
success, err := t.controlBuf.execute(t.checkForHeaderListSize, trailingHeader)
success, err := t.controlBuf.executeAndPut(func() bool {
return t.checkForHeaderListSize(trailingHeader)
}, nil)
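// With a nil item, executeAndPut only runs the size check under the
// controlbuf lock; the trailer frame itself is enqueued later via
// finishStream.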
if !success {
if err != nil {
return err
@@ -1112,27 +1130,38 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
// Write converts the data into HTTP2 data frame and sends it out. A non-nil
// error is returned if it fails (e.g., framing error, transport error).
func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
func (t *http2Server) write(s *ServerStream, hdr []byte, data mem.BufferSlice, _ *WriteOptions) error {
reader := data.Reader()
if !s.isHeaderSent() { // Headers haven't been written yet.
if err := t.WriteHeader(s, nil); err != nil {
if err := t.writeHeader(s, nil); err != nil {
_ = reader.Close()
return err
}
} else {
// Writing headers checks for this condition.
if s.getState() == streamDone {
_ = reader.Close()
return t.streamContextErr(s)
}
}
df := &dataFrame{
streamID: s.id,
h: hdr,
d: data,
reader: reader,
onEachWrite: t.setResetPingStrikes,
}
if err := s.wq.get(int32(len(hdr) + len(data))); err != nil {
if err := s.wq.get(int32(len(hdr) + df.reader.Remaining())); err != nil {
_ = reader.Close()
return t.streamContextErr(s)
}
return t.controlBuf.put(df)
if err := t.controlBuf.put(df); err != nil {
_ = reader.Close()
return err
}
t.incrMsgSent()
return nil
}
// keepalive running in a separate goroutine does the following:
@@ -1208,7 +1237,7 @@ func (t *http2Server) keepalive() {
continue
}
if outstandingPing && kpTimeoutLeft <= 0 {
t.Close(fmt.Errorf("keepalive ping not acked within timeout %s", t.kp.Time))
t.Close(fmt.Errorf("keepalive ping not acked within timeout %s", t.kp.Timeout))
return
}
if !outstandingPing {
@@ -1223,7 +1252,7 @@ func (t *http2Server) keepalive() {
// timeoutLeft. This will ensure that we wait only for kp.Time
// before sending out the next ping (for cases where the ping is
// acked).
sleepDuration := minTime(t.kp.Time, kpTimeoutLeft)
sleepDuration := min(t.kp.Time, kpTimeoutLeft)
kpTimeoutLeft -= sleepDuration
kpTimer.Reset(sleepDuration)
case <-t.done:
@@ -1261,8 +1290,7 @@ func (t *http2Server) Close(err error) {
}
// deleteStream deletes the stream s from transport's active streams.
func (t *http2Server) deleteStream(s *Stream, eosReceived bool) {
func (t *http2Server) deleteStream(s *ServerStream, eosReceived bool) {
t.mu.Lock()
if _, ok := t.activeStreams[s.id]; ok {
delete(t.activeStreams, s.id)
@@ -1282,7 +1310,7 @@ func (t *http2Server) deleteStream(s *Stream, eosReceived bool) {
}
// finishStream closes the stream and puts the trailing headerFrame into controlbuf.
func (t *http2Server) finishStream(s *Stream, rst bool, rstCode http2.ErrCode, hdr *headerFrame, eosReceived bool) {
func (t *http2Server) finishStream(s *ServerStream, rst bool, rstCode http2.ErrCode, hdr *headerFrame, eosReceived bool) {
// In case stream sending and receiving are invoked in separate
// goroutines (e.g., bi-directional streaming), cancel needs to be
// called to interrupt the potential blocking on other goroutines.
@@ -1306,13 +1334,16 @@ func (t *http2Server) finishStream(s *Stream, rst bool, rstCode http2.ErrCode, h
}
// closeStream clears the footprint of a stream when the stream is not needed any more.
func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, eosReceived bool) {
func (t *http2Server) closeStream(s *ServerStream, rst bool, rstCode http2.ErrCode, eosReceived bool) {
// In case stream sending and receiving are invoked in separate
// goroutines (e.g., bi-directional streaming), cancel needs to be
// called to interrupt the potential blocking on other goroutines.
s.cancel()
s.swapState(streamDone)
oldState := s.swapState(streamDone)
if oldState == streamDone {
return
}
t.deleteStream(s, eosReceived)
t.controlBuf.put(&cleanupStream{
@@ -1400,14 +1431,18 @@ func (t *http2Server) socketMetrics() *channelz.EphemeralSocketMetrics {
}
}
func (t *http2Server) IncrMsgSent() {
t.channelz.SocketMetrics.MessagesSent.Add(1)
t.channelz.SocketMetrics.LastMessageSentTimestamp.Add(1)
func (t *http2Server) incrMsgSent() {
if channelz.IsOn() {
t.channelz.SocketMetrics.MessagesSent.Add(1)
t.channelz.SocketMetrics.LastMessageSentTimestamp.Add(1)
}
}
func (t *http2Server) IncrMsgRecv() {
t.channelz.SocketMetrics.MessagesReceived.Add(1)
t.channelz.SocketMetrics.LastMessageReceivedTimestamp.Add(1)
func (t *http2Server) incrMsgRecv() {
if channelz.IsOn() {
t.channelz.SocketMetrics.MessagesReceived.Add(1)
t.channelz.SocketMetrics.LastMessageReceivedTimestamp.Add(1)
}
}
func (t *http2Server) getOutFlowWindow() int64 {
@@ -1440,7 +1475,7 @@ func getJitter(v time.Duration) time.Duration {
}
// Generate a jitter between +/- 10% of the value.
r := int64(v / 10)
j := grpcrand.Int63n(2*r) - r
j := rand.Int64N(2*r) - r
return time.Duration(j)
}
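// With r = v/10, rand.Int64N(2*r) is uniform over [0, 2r), so j is uniform
// over [-r, r): e.g. v = 10s yields a jitter in [-1s, 1s).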


@@ -317,28 +317,32 @@ func newBufWriter(conn net.Conn, batchSize int, pool *sync.Pool) *bufWriter {
return w
}
func (w *bufWriter) Write(b []byte) (n int, err error) {
func (w *bufWriter) Write(b []byte) (int, error) {
if w.err != nil {
return 0, w.err
}
if w.batchSize == 0 { // Buffer has been disabled.
n, err = w.conn.Write(b)
n, err := w.conn.Write(b)
return n, toIOError(err)
}
if w.buf == nil {
b := w.pool.Get().(*[]byte)
w.buf = *b
}
written := 0
for len(b) > 0 {
nn := copy(w.buf[w.offset:], b)
b = b[nn:]
w.offset += nn
n += nn
if w.offset >= w.batchSize {
err = w.flushKeepBuffer()
copied := copy(w.buf[w.offset:], b)
b = b[copied:]
written += copied
w.offset += copied
if w.offset < w.batchSize {
continue
}
if err := w.flushKeepBuffer(); err != nil {
return written, err
}
}
return n, err
return written, nil
}
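// A sketch of the batching behavior above, with hypothetical sizes: given a
// bufWriter w with batchSize = 4, an 11-byte Write copies into the pooled
// buffer and flushes twice (at each 4-byte boundary), leaving 3 bytes
// buffered until the next Write or an explicit Flush:
//
//	n, err := w.Write([]byte("hello world")) // n == 11 once both flushes succeed
//	err = w.Flush()                          // sends the remaining 3 bytes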
func (w *bufWriter) Flush() error {
@@ -389,7 +393,7 @@ type framer struct {
fr *http2.Framer
}
var writeBufferPoolMap map[int]*sync.Pool = make(map[int]*sync.Pool)
var writeBufferPoolMap = make(map[int]*sync.Pool)
var writeBufferMutex sync.Mutex
func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, sharedWriteBuffer bool, maxHeaderListSize uint32) *framer {
@@ -435,8 +439,8 @@ func getWriteBufferPool(size int) *sync.Pool {
return pool
}
// parseDialTarget returns the network and address to pass to dialer.
func parseDialTarget(target string) (string, string) {
// ParseDialTarget returns the network and address to pass to dialer.
func ParseDialTarget(target string) (string, string) {
net := "tcp"
m1 := strings.Index(target, ":")
m2 := strings.Index(target, ":/")
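// Illustrative results, inferred from the scheme handling in the rest of
// this function (truncated in this hunk):
//
//	ParseDialTarget("localhost:50051")     // ("tcp", "localhost:50051")
//	ParseDialTarget("unix:/tmp/grpc.sock") // ("unix", "/tmp/grpc.sock")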


@@ -30,34 +30,16 @@ import (
"net/url"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/proxyattributes"
"google.golang.org/grpc/resolver"
)
const proxyAuthHeaderKey = "Proxy-Authorization"
var (
// The following variable will be overwritten in the tests.
httpProxyFromEnvironment = http.ProxyFromEnvironment
)
func mapAddress(address string) (*url.URL, error) {
req := &http.Request{
URL: &url.URL{
Scheme: "https",
Host: address,
},
}
url, err := httpProxyFromEnvironment(req)
if err != nil {
return nil, err
}
return url, nil
}
// To read a response from a net.Conn, http.ReadResponse() takes a bufio.Reader.
// It's possible that this reader reads more than what's needed for the response and stores
// those bytes in the buffer.
// bufConn wraps the original net.Conn and the bufio.Reader to make sure we don't lose the
// bytes in the buffer.
// It's possible that this reader reads more than what's needed for the response
// and stores those bytes in the buffer. bufConn wraps the original net.Conn
// and the bufio.Reader to make sure we don't lose the bytes in the buffer.
type bufConn struct {
net.Conn
r io.Reader
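// A sketch of the pass-through read (the method itself is elided from this
// hunk, so treat the body as an assumption):
//
//	func (c *bufConn) Read(b []byte) (int, error) { return c.r.Read(b) }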
@@ -72,7 +54,7 @@ func basicAuth(username, password string) string {
return base64.StdEncoding.EncodeToString([]byte(auth))
}
func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr string, proxyURL *url.URL, grpcUA string) (_ net.Conn, err error) {
func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, grpcUA string, opts proxyattributes.Options) (_ net.Conn, err error) {
defer func() {
if err != nil {
conn.Close()
@@ -81,15 +63,14 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr stri
req := &http.Request{
Method: http.MethodConnect,
URL: &url.URL{Host: backendAddr},
URL: &url.URL{Host: opts.ConnectAddr},
Header: map[string][]string{"User-Agent": {grpcUA}},
}
if t := proxyURL.User; t != nil {
u := t.Username()
p, _ := t.Password()
if user := opts.User; user != nil {
u := user.Username()
p, _ := user.Password()
req.Header.Add(proxyAuthHeaderKey, "Basic "+basicAuth(u, p))
}
if err := sendHTTPRequest(ctx, req, conn); err != nil {
return nil, fmt.Errorf("failed to write the HTTP request: %v", err)
}
@@ -107,32 +88,23 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr stri
}
return nil, fmt.Errorf("failed to do connect handshake, response: %q", dump)
}
return &bufConn{Conn: conn, r: r}, nil
// The buffer could contain extra bytes from the target server, so we can't
// discard it. However, in many cases where the server waits for the client
// to send the first message (e.g. when TLS is being used), the buffer will
// be empty, so we can avoid the overhead of reading through this buffer.
if r.Buffered() != 0 {
return &bufConn{Conn: conn, r: r}, nil
}
return conn, nil
}
// proxyDial dials, connecting to a proxy first if necessary. Checks if a proxy
// is necessary, dials, does the HTTP CONNECT handshake, and returns the
// connection.
func proxyDial(ctx context.Context, addr string, grpcUA string) (net.Conn, error) {
newAddr := addr
proxyURL, err := mapAddress(addr)
// proxyDial establishes a TCP connection to the specified address and performs an HTTP CONNECT handshake.
func proxyDial(ctx context.Context, addr resolver.Address, grpcUA string, opts proxyattributes.Options) (net.Conn, error) {
conn, err := internal.NetDialerWithTCPKeepalive().DialContext(ctx, "tcp", addr.Addr)
if err != nil {
return nil, err
}
if proxyURL != nil {
newAddr = proxyURL.Host
}
conn, err := internal.NetDialerWithTCPKeepalive().DialContext(ctx, "tcp", newAddr)
if err != nil {
return nil, err
}
if proxyURL == nil {
// proxy is disabled if proxyURL is nil.
return conn, err
}
return doHTTPConnectHandshake(ctx, conn, addr, proxyURL, grpcUA)
return doHTTPConnectHandshake(ctx, conn, grpcUA, opts)
}
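// A minimal usage sketch (hypothetical addresses; the Options fields are
// inferred from their use in doHTTPConnectHandshake above):
//
//	conn, err := proxyDial(ctx,
//		resolver.Address{Addr: "proxy.internal:3128"}, // the proxy itself
//		"grpc-go/1.x",
//		proxyattributes.Options{ConnectAddr: "backend.internal:443"})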
func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) error {


@@ -0,0 +1,180 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package transport
import (
"context"
"errors"
"strings"
"sync"
"sync/atomic"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// ServerStream implements streaming functionality for a gRPC server.
type ServerStream struct {
*Stream // Embed for common stream functionality.
st internalServerTransport
ctxDone <-chan struct{} // closed at the end of stream. Cache of ctx.Done() (for performance)
// cancel is invoked at the end of stream to cancel ctx. It also stops the
// timer for monitoring the rpc deadline if configured.
cancel func()
// Holds compressor names passed in grpc-accept-encoding metadata from the
// client.
clientAdvertisedCompressors string
headerWireLength int
// hdrMu protects outgoing header and trailer metadata.
hdrMu sync.Mutex
header metadata.MD // the outgoing header metadata. Updated by WriteHeader.
headerSent atomic.Bool // atomically set when the headers are sent out.
}
// Read reads an n byte message from the input stream.
func (s *ServerStream) Read(n int) (mem.BufferSlice, error) {
b, err := s.Stream.read(n)
if err == nil {
s.st.incrMsgRecv()
}
return b, err
}
// SendHeader sends the header metadata for the given stream.
func (s *ServerStream) SendHeader(md metadata.MD) error {
return s.st.writeHeader(s, md)
}
// Write writes the hdr and data bytes to the output stream.
func (s *ServerStream) Write(hdr []byte, data mem.BufferSlice, opts *WriteOptions) error {
return s.st.write(s, hdr, data, opts)
}
// WriteStatus sends the status of a stream to the client. WriteStatus is
// the final call made on a stream and always occurs.
func (s *ServerStream) WriteStatus(st *status.Status) error {
return s.st.writeStatus(s, st)
}
// isHeaderSent indicates whether headers have been sent.
func (s *ServerStream) isHeaderSent() bool {
return s.headerSent.Load()
}
// updateHeaderSent updates headerSent and returns true
// if it was already set.
func (s *ServerStream) updateHeaderSent() bool {
return s.headerSent.Swap(true)
}
// RecvCompress returns the compression algorithm applied to the inbound
// message. It is empty string if there is no compression applied.
func (s *ServerStream) RecvCompress() string {
return s.recvCompress
}
// SendCompress returns the send compressor name.
func (s *ServerStream) SendCompress() string {
return s.sendCompress
}
// ContentSubtype returns the content-subtype for a request. For example, a
// content-subtype of "proto" will result in a content-type of
// "application/grpc+proto". This will always be lowercase. See
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details.
func (s *ServerStream) ContentSubtype() string {
return s.contentSubtype
}
// SetSendCompress sets the compression algorithm to the stream.
func (s *ServerStream) SetSendCompress(name string) error {
if s.isHeaderSent() || s.getState() == streamDone {
return errors.New("transport: set send compressor called after headers sent or stream done")
}
s.sendCompress = name
return nil
}
// SetContext sets the context of the stream. This will be deleted once the
// stats handler callouts all move to gRPC layer.
func (s *ServerStream) SetContext(ctx context.Context) {
s.ctx = ctx
}
// ClientAdvertisedCompressors returns the compressor names advertised by the
// client via grpc-accept-encoding header.
func (s *ServerStream) ClientAdvertisedCompressors() []string {
values := strings.Split(s.clientAdvertisedCompressors, ",")
for i, v := range values {
values[i] = strings.TrimSpace(v)
}
return values
}
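// For example, a client that sent "grpc-accept-encoding: gzip, zstd" yields
// []string{"gzip", "zstd"} after the split-and-trim above.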
// Header returns the header metadata of the stream. It returns the out header
// after t.WriteHeader is called. It does not block and must not be called
// until after WriteHeader.
func (s *ServerStream) Header() (metadata.MD, error) {
// Return the header in stream. It will be the out
// header after t.WriteHeader is called.
return s.header.Copy(), nil
}
// HeaderWireLength returns the size of the headers of the stream as received
// from the wire.
func (s *ServerStream) HeaderWireLength() int {
return s.headerWireLength
}
// SetHeader sets the header metadata. This can be called multiple times.
// This should not be called in parallel to other data writes.
func (s *ServerStream) SetHeader(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
if s.isHeaderSent() || s.getState() == streamDone {
return ErrIllegalHeaderWrite
}
s.hdrMu.Lock()
s.header = metadata.Join(s.header, md)
s.hdrMu.Unlock()
return nil
}
// SetTrailer sets the trailer metadata which will be sent with the RPC status
// by the server. This can be called multiple times.
// This should not be called parallel to other data writes.
func (s *ServerStream) SetTrailer(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
if s.getState() == streamDone {
return ErrIllegalHeaderWrite
}
s.hdrMu.Lock()
s.trailer = metadata.Join(s.trailer, md)
s.hdrMu.Unlock()
return nil
}


@@ -22,13 +22,11 @@
package transport
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"strings"
"sync"
"sync/atomic"
"time"
@@ -37,9 +35,9 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
"google.golang.org/grpc/tap"
@@ -47,32 +45,10 @@ import (
const logLevel = 2
type bufferPool struct {
pool sync.Pool
}
func newBufferPool() *bufferPool {
return &bufferPool{
pool: sync.Pool{
New: func() any {
return new(bytes.Buffer)
},
},
}
}
func (p *bufferPool) get() *bytes.Buffer {
return p.pool.Get().(*bytes.Buffer)
}
func (p *bufferPool) put(b *bytes.Buffer) {
p.pool.Put(b)
}
// recvMsg represents the received msg from the transport. All transport
// protocol specific info has been removed.
type recvMsg struct {
buffer *bytes.Buffer
buffer mem.Buffer
// nil: received some data
// io.EOF: stream is completed. data is nil.
// other non-nil error: transport failure. data is nil.
@@ -102,6 +78,9 @@ func newRecvBuffer() *recvBuffer {
func (b *recvBuffer) put(r recvMsg) {
b.mu.Lock()
if b.err != nil {
// drop the buffer on the floor. Since b.err is not nil, any subsequent reads
// will always return an error, making this buffer inaccessible.
r.buffer.Free()
b.mu.Unlock()
// An error had occurred earlier, don't accept more
// data or errors.
@@ -148,45 +127,70 @@ type recvBufferReader struct {
ctx context.Context
ctxDone <-chan struct{} // cache of ctx.Done() (for performance).
recv *recvBuffer
last *bytes.Buffer // Stores the remaining data in the previous calls.
last mem.Buffer // Stores the remaining data in the previous calls.
err error
freeBuffer func(*bytes.Buffer)
}
// Read reads the next len(p) bytes from last. If last is drained, it tries to
// read additional data from recv. It blocks if there is no additional data available
// in recv. If Read returns any non-nil error, it will continue to return that error.
func (r *recvBufferReader) Read(p []byte) (n int, err error) {
func (r *recvBufferReader) ReadMessageHeader(header []byte) (n int, err error) {
if r.err != nil {
return 0, r.err
}
if r.last != nil {
// Read remaining data left in last call.
copied, _ := r.last.Read(p)
if r.last.Len() == 0 {
r.freeBuffer(r.last)
r.last = nil
}
return copied, nil
n, r.last = mem.ReadUnsafe(header, r.last)
return n, nil
}
if r.closeStream != nil {
n, r.err = r.readClient(p)
n, r.err = r.readMessageHeaderClient(header)
} else {
n, r.err = r.read(p)
n, r.err = r.readMessageHeader(header)
}
return n, r.err
}
func (r *recvBufferReader) read(p []byte) (n int, err error) {
// Read reads the next n bytes from last. If last is drained, it tries to read
// additional data from recv. It blocks if there is no additional data available in
// recv. If Read returns any non-nil error, it will continue to return that
// error.
func (r *recvBufferReader) Read(n int) (buf mem.Buffer, err error) {
if r.err != nil {
return nil, r.err
}
if r.last != nil {
buf = r.last
if r.last.Len() > n {
buf, r.last = mem.SplitUnsafe(buf, n)
} else {
r.last = nil
}
return buf, nil
}
if r.closeStream != nil {
buf, r.err = r.readClient(n)
} else {
buf, r.err = r.read(n)
}
return buf, r.err
}
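// mem.SplitUnsafe splits the buffer at n without copying: the first n bytes
// are returned and r.last keeps the remainder for the next Read. (The
// "Unsafe" suffix appears to signal that the caller must hold the only
// reference; this is an inference from the mem package naming.)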
func (r *recvBufferReader) readMessageHeader(header []byte) (n int, err error) {
select {
case <-r.ctxDone:
return 0, ContextErr(r.ctx.Err())
case m := <-r.recv.get():
return r.readAdditional(m, p)
return r.readMessageHeaderAdditional(m, header)
}
}
func (r *recvBufferReader) readClient(p []byte) (n int, err error) {
func (r *recvBufferReader) read(n int) (buf mem.Buffer, err error) {
select {
case <-r.ctxDone:
return nil, ContextErr(r.ctx.Err())
case m := <-r.recv.get():
return r.readAdditional(m, n)
}
}
func (r *recvBufferReader) readMessageHeaderClient(header []byte) (n int, err error) {
// If the context is canceled, then closes the stream with nil metadata.
// closeStream writes its error parameter to r.recv as a recvMsg.
// r.readAdditional acts on that message and returns the necessary error.
@@ -207,25 +211,67 @@ func (r *recvBufferReader) readClient(p []byte) (n int, err error) {
// faster.
r.closeStream(ContextErr(r.ctx.Err()))
m := <-r.recv.get()
return r.readAdditional(m, p)
return r.readMessageHeaderAdditional(m, header)
case m := <-r.recv.get():
return r.readAdditional(m, p)
return r.readMessageHeaderAdditional(m, header)
}
}
func (r *recvBufferReader) readAdditional(m recvMsg, p []byte) (n int, err error) {
func (r *recvBufferReader) readClient(n int) (buf mem.Buffer, err error) {
// If the context is canceled, then closes the stream with nil metadata.
// closeStream writes its error parameter to r.recv as a recvMsg.
// r.readAdditional acts on that message and returns the necessary error.
select {
case <-r.ctxDone:
// Note that this adds the ctx error to the end of recv buffer, and
// reads from the head. This will delay the error until recv buffer is
// empty, thus will delay ctx cancellation in Recv().
//
// It's done this way to fix a race between ctx cancel and trailer. The
// race was, stream.Recv() may return ctx error if ctxDone wins the
// race, but stream.Trailer() may return a non-nil md because the stream
// was not marked as done when trailer is received. This closeStream
// call will mark stream as done, thus fix the race.
//
// TODO: delaying ctx error seems like an unnecessary side effect. What
// we really want is to mark the stream as done, and return ctx error
// faster.
r.closeStream(ContextErr(r.ctx.Err()))
m := <-r.recv.get()
return r.readAdditional(m, n)
case m := <-r.recv.get():
return r.readAdditional(m, n)
}
}
func (r *recvBufferReader) readMessageHeaderAdditional(m recvMsg, header []byte) (n int, err error) {
r.recv.load()
if m.err != nil {
if m.buffer != nil {
m.buffer.Free()
}
return 0, m.err
}
copied, _ := m.buffer.Read(p)
if m.buffer.Len() == 0 {
r.freeBuffer(m.buffer)
r.last = nil
} else {
r.last = m.buffer
n, r.last = mem.ReadUnsafe(header, m.buffer)
return n, nil
}
func (r *recvBufferReader) readAdditional(m recvMsg, n int) (b mem.Buffer, err error) {
r.recv.load()
if m.err != nil {
if m.buffer != nil {
m.buffer.Free()
}
return nil, m.err
}
return copied, nil
if m.buffer.Len() > n {
m.buffer, r.last = mem.SplitUnsafe(m.buffer, n)
}
return m.buffer, nil
}
type streamState uint32
@@ -240,73 +286,26 @@ const (
// Stream represents an RPC in the transport layer.
type Stream struct {
id uint32
st ServerTransport // nil for client side Stream
ct *http2Client // nil for server side Stream
ctx context.Context // the associated context of the stream
cancel context.CancelFunc // always nil for client side Stream
done chan struct{} // closed at the end of stream to unblock writers. On the client side.
doneFunc func() // invoked at the end of stream on client side.
ctxDone <-chan struct{} // same as done chan but for server side. Cache of ctx.Done() (for performance)
method string // the associated RPC method of the stream
ctx context.Context // the associated context of the stream
method string // the associated RPC method of the stream
recvCompress string
sendCompress string
buf *recvBuffer
trReader io.Reader
trReader *transportReader
fc *inFlow
wq *writeQuota
// Holds compressor names passed in grpc-accept-encoding metadata from the
// client. This is empty for the client side stream.
clientAdvertisedCompressors string
// Callback to state application's intentions to read data. This
// is used to adjust flow control, if needed.
requestRead func(int)
headerChan chan struct{} // closed to indicate the end of header metadata.
headerChanClosed uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times.
// headerValid indicates whether a valid header was received. Only
// meaningful after headerChan is closed (always call waitOnHeader() before
// reading its value). Not valid on server side.
headerValid bool
headerWireLength int // Only set on server side.
// hdrMu protects header and trailer metadata on the server-side.
hdrMu sync.Mutex
// On client side, header keeps the received header metadata.
//
// On server side, header keeps the header set by SetHeader(). The complete
// header will merged into this after t.WriteHeader() is called.
header metadata.MD
trailer metadata.MD // the key-value map of trailer metadata.
noHeaders bool // set if the client never received headers (set only after the stream is done).
// On the server-side, headerSent is atomically set to 1 when the headers are sent out.
headerSent uint32
state streamState
// On client-side it is the status error received from the server.
// On server-side it is unused.
status *status.Status
bytesReceived uint32 // indicates whether any bytes have been received on this stream
unprocessed uint32 // set if the server sends a refused stream or GOAWAY including this stream
// contentSubtype is the content-subtype for requests.
// this must be lowercase or the behavior is undefined.
contentSubtype string
}
// isHeaderSent is only valid on the server-side.
func (s *Stream) isHeaderSent() bool {
return atomic.LoadUint32(&s.headerSent) == 1
}
// updateHeaderSent updates headerSent and returns true
// if it was already set. It is valid only on server-side.
func (s *Stream) updateHeaderSent() bool {
return atomic.SwapUint32(&s.headerSent, 1) == 1
trailer metadata.MD // the key-value map of trailer metadata.
}
func (s *Stream) swapState(st streamState) streamState {
@@ -321,110 +320,12 @@ func (s *Stream) getState() streamState {
return streamState(atomic.LoadUint32((*uint32)(&s.state)))
}
func (s *Stream) waitOnHeader() {
if s.headerChan == nil {
// On the server headerChan is always nil since a stream originates
// only after having received headers.
return
}
select {
case <-s.ctx.Done():
// Close the stream to prevent headers/trailers from changing after
// this function returns.
s.ct.CloseStream(s, ContextErr(s.ctx.Err()))
// headerChan could possibly not be closed yet if closeStream raced
// with operateHeaders; wait until it is closed explicitly here.
<-s.headerChan
case <-s.headerChan:
}
}
// RecvCompress returns the compression algorithm applied to the inbound
// message. It is empty string if there is no compression applied.
func (s *Stream) RecvCompress() string {
s.waitOnHeader()
return s.recvCompress
}
// SetSendCompress sets the compression algorithm to the stream.
func (s *Stream) SetSendCompress(name string) error {
if s.isHeaderSent() || s.getState() == streamDone {
return errors.New("transport: set send compressor called after headers sent or stream done")
}
s.sendCompress = name
return nil
}
// SendCompress returns the send compressor name.
func (s *Stream) SendCompress() string {
return s.sendCompress
}
// ClientAdvertisedCompressors returns the compressor names advertised by the
// client via grpc-accept-encoding header.
func (s *Stream) ClientAdvertisedCompressors() []string {
values := strings.Split(s.clientAdvertisedCompressors, ",")
for i, v := range values {
values[i] = strings.TrimSpace(v)
}
return values
}
// Done returns a channel which is closed when it receives the final status
// from the server.
func (s *Stream) Done() <-chan struct{} {
return s.done
}
// Header returns the header metadata of the stream.
//
// On client side, it acquires the key-value pairs of header metadata once it is
// available. It blocks until i) the metadata is ready or ii) there is no header
// metadata or iii) the stream is canceled/expired.
//
// On server side, it returns the out header after t.WriteHeader is called. It
// does not block and must not be called until after WriteHeader.
func (s *Stream) Header() (metadata.MD, error) {
if s.headerChan == nil {
// On server side, return the header in stream. It will be the out
// header after t.WriteHeader is called.
return s.header.Copy(), nil
}
s.waitOnHeader()
if !s.headerValid || s.noHeaders {
return nil, s.status.Err()
}
return s.header.Copy(), nil
}
// TrailersOnly blocks until a header or trailers-only frame is received and
// then returns true if the stream was trailers-only. If the stream ends
// before headers are received, returns true. Client-side only.
func (s *Stream) TrailersOnly() bool {
s.waitOnHeader()
return s.noHeaders
}
// Trailer returns the cached trailer metedata. Note that if it is not called
// after the entire stream is done, it could return an empty MD. Client
// side only.
// Trailer returns the cached trailer metadata. Note that if it is not called
// after the entire stream is done, it could return an empty MD.
// It can be safely read only after the stream has ended, that is, after
// either read or write has returned io.EOF.
func (s *Stream) Trailer() metadata.MD {
c := s.trailer.Copy()
return c
}
// ContentSubtype returns the content-subtype for a request. For example, a
// content-subtype of "proto" will result in a content-type of
// "application/grpc+proto". This will always be lowercase. See
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details.
func (s *Stream) ContentSubtype() string {
return s.contentSubtype
return s.trailer.Copy()
}
// Context returns the context of the stream.
@@ -432,114 +333,104 @@ func (s *Stream) Context() context.Context {
return s.ctx
}
// SetContext sets the context of the stream. This will be deleted once the
// stats handler callouts all move to gRPC layer.
func (s *Stream) SetContext(ctx context.Context) {
s.ctx = ctx
}
// Method returns the method for the stream.
func (s *Stream) Method() string {
return s.method
}
// Status returns the status received from the server.
// Status can be read safely only after the stream has ended,
// that is, after Done() is closed.
func (s *Stream) Status() *status.Status {
return s.status
}
// HeaderWireLength returns the size of the headers of the stream as received
// from the wire. Valid only on the server.
func (s *Stream) HeaderWireLength() int {
return s.headerWireLength
}
// SetHeader sets the header metadata. This can be called multiple times.
// Server side only.
// This should not be called in parallel to other data writes.
func (s *Stream) SetHeader(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
if s.isHeaderSent() || s.getState() == streamDone {
return ErrIllegalHeaderWrite
}
s.hdrMu.Lock()
s.header = metadata.Join(s.header, md)
s.hdrMu.Unlock()
return nil
}
// SendHeader sends the given header metadata. The given metadata is
// combined with any metadata set by previous calls to SetHeader and
// then written to the transport stream.
func (s *Stream) SendHeader(md metadata.MD) error {
return s.st.WriteHeader(s, md)
}
// SetTrailer sets the trailer metadata which will be sent with the RPC status
// by the server. This can be called multiple times. Server side only.
// This should not be called parallel to other data writes.
func (s *Stream) SetTrailer(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
if s.getState() == streamDone {
return ErrIllegalHeaderWrite
}
s.hdrMu.Lock()
s.trailer = metadata.Join(s.trailer, md)
s.hdrMu.Unlock()
return nil
}
func (s *Stream) write(m recvMsg) {
s.buf.put(m)
}
// Read reads all p bytes from the wire for this stream.
func (s *Stream) Read(p []byte) (n int, err error) {
// ReadMessageHeader reads data into the provided header slice from the stream.
// It first checks if there was an error during a previous read operation and
// returns it if present. It then requests a read operation for the length of
// the header. It continues to read from the stream until the entire header
// slice is filled or an error occurs. If an `io.EOF` error is encountered with
// partially read data, it is converted to `io.ErrUnexpectedEOF` to indicate an
// unexpected end of the stream. The method returns any error encountered during
// the read process or nil if the header was successfully read.
func (s *Stream) ReadMessageHeader(header []byte) (err error) {
// Don't request a read if there was an error earlier
if er := s.trReader.(*transportReader).er; er != nil {
return 0, er
if er := s.trReader.er; er != nil {
return er
}
s.requestRead(len(p))
return io.ReadFull(s.trReader, p)
s.requestRead(len(header))
for len(header) != 0 {
n, err := s.trReader.ReadMessageHeader(header)
header = header[n:]
if len(header) == 0 {
err = nil
}
if err != nil {
if n > 0 && err == io.EOF {
err = io.ErrUnexpectedEOF
}
return err
}
}
return nil
}
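// A minimal sketch of the intended call pattern (gRPC wire framing uses a
// 1-byte compression flag plus a 4-byte big-endian length, read with
// encoding/binary; variable names are illustrative):
//
//	var prefix [5]byte
//	if err := s.ReadMessageHeader(prefix[:]); err != nil {
//		return err
//	}
//	msgLen := binary.BigEndian.Uint32(prefix[1:])
//	data, err := s.read(int(msgLen)) // data is a mem.BufferSlice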
// transportReader reads all the data available for this Stream from the transport and
// Read reads n bytes from the wire for this stream.
func (s *Stream) read(n int) (data mem.BufferSlice, err error) {
// Don't request a read if there was an error earlier
if er := s.trReader.er; er != nil {
return nil, er
}
s.requestRead(n)
for n != 0 {
buf, err := s.trReader.Read(n)
var bufLen int
if buf != nil {
bufLen = buf.Len()
}
n -= bufLen
if n == 0 {
err = nil
}
if err != nil {
if bufLen > 0 && err == io.EOF {
err = io.ErrUnexpectedEOF
}
data.Free()
return nil, err
}
data = append(data, buf)
}
return data, nil
}
// transportReader reads all the data available for this Stream from the transport and
// passes them into the decoder, which converts them into a gRPC message stream.
// The error is io.EOF when the stream is done or another non-nil error if
// the stream broke.
type transportReader struct {
reader io.Reader
reader *recvBufferReader
// The handler to control the window update procedure for both this
// particular stream and the associated transport.
windowHandler func(int)
er error
}
func (t *transportReader) Read(p []byte) (n int, err error) {
n, err = t.reader.Read(p)
func (t *transportReader) ReadMessageHeader(header []byte) (int, error) {
n, err := t.reader.ReadMessageHeader(header)
if err != nil {
t.er = err
return
return 0, err
}
t.windowHandler(n)
return
return n, nil
}
// BytesReceived indicates whether any bytes have been received on this stream.
func (s *Stream) BytesReceived() bool {
return atomic.LoadUint32(&s.bytesReceived) == 1
}
// Unprocessed indicates whether the server did not process this stream --
// i.e. it sent a refused stream or GOAWAY including this stream ID.
func (s *Stream) Unprocessed() bool {
return atomic.LoadUint32(&s.unprocessed) == 1
func (t *transportReader) Read(n int) (mem.Buffer, error) {
buf, err := t.reader.Read(n)
if err != nil {
t.er = err
return buf, err
}
t.windowHandler(buf.Len())
return buf, nil
}
// GoString is implemented by Stream so context.String() won't
@@ -574,6 +465,7 @@ type ServerConfig struct {
ChannelzParent *channelz.Server
MaxHeaderListSize *uint32
HeaderTableSize *uint32
BufferPool mem.BufferPool
}
// ConnectOptions covers all relevant options for communicating with the server.
@@ -610,19 +502,13 @@ type ConnectOptions struct {
ChannelzParent *channelz.SubChannel
// MaxHeaderListSize sets the max (uncompressed) size of header list that is prepared to be received.
MaxHeaderListSize *uint32
// UseProxy specifies if a proxy should be used.
UseProxy bool
// The mem.BufferPool to use when reading/writing to the wire.
BufferPool mem.BufferPool
}
// NewClientTransport establishes the transport with the required ConnectOptions
// and returns it to the caller.
func NewClientTransport(connectCtx, ctx context.Context, addr resolver.Address, opts ConnectOptions, onClose func(GoAwayReason)) (ClientTransport, error) {
return newHTTP2Client(connectCtx, ctx, addr, opts, onClose)
}
// Options provides additional hints and information for message
// WriteOptions provides additional hints and information for message
// transmission.
type Options struct {
type WriteOptions struct {
// Last indicates whether this write is the last piece for
// this stream.
Last bool
@@ -671,18 +557,8 @@ type ClientTransport interface {
// It does not block.
GracefulClose()
// Write sends the data for the given stream. A nil stream indicates
// the write is to be performed on the transport as a whole.
Write(s *Stream, hdr []byte, data []byte, opts *Options) error
// NewStream creates a Stream for an RPC.
NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error)
// CloseStream clears the footprint of a stream when the stream is
// not needed any more. The err indicates the error incurred when
// CloseStream is called. Must be called when a stream is finished
// unless the associated transport is closing.
CloseStream(stream *Stream, err error)
NewStream(ctx context.Context, callHdr *CallHdr) (*ClientStream, error)
// Error returns a channel that is closed when some I/O error
// happens. Typically the caller should have a goroutine to monitor
@@ -702,12 +578,6 @@ type ClientTransport interface {
// RemoteAddr returns the remote network address.
RemoteAddr() net.Addr
// IncrMsgSent increments the number of message sent through this transport.
IncrMsgSent()
// IncrMsgRecv increments the number of message received through this transport.
IncrMsgRecv()
}
// ServerTransport is the common interface for all gRPC server-side transport
@@ -717,19 +587,7 @@ type ClientTransport interface {
// Write methods for a given Stream will be called serially.
type ServerTransport interface {
// HandleStreams receives incoming streams using the given handler.
HandleStreams(context.Context, func(*Stream))
// WriteHeader sends the header metadata for the given stream.
// WriteHeader may not be called on all streams.
WriteHeader(s *Stream, md metadata.MD) error
// Write sends the data for the given stream.
// Write may not be called on all streams.
Write(s *Stream, hdr []byte, data []byte, opts *Options) error
// WriteStatus sends the status of a stream to the client. WriteStatus is
// the final call made on a stream and always occurs.
WriteStatus(s *Stream, st *status.Status) error
HandleStreams(context.Context, func(*ServerStream))
// Close tears down the transport. Once it is called, the transport
// should not be accessed any more. All the pending streams and their
@@ -741,12 +599,14 @@ type ServerTransport interface {
// Drain notifies the client this ServerTransport stops accepting new RPCs.
Drain(debugData string)
}
// IncrMsgSent increments the number of message sent through this transport.
IncrMsgSent()
// IncrMsgRecv increments the number of message received through this transport.
IncrMsgRecv()
type internalServerTransport interface {
ServerTransport
writeHeader(s *ServerStream, md metadata.MD) error
write(s *ServerStream, hdr []byte, data mem.BufferSlice, opts *WriteOptions) error
writeStatus(s *ServerStream, st *status.Status) error
incrMsgRecv()
}
// connectionErrorf creates an ConnectionError with the specified error description.
@@ -798,7 +658,7 @@ var (
// connection is draining. This could be caused by goaway or balancer
// removing the address.
errStreamDrain = status.Error(codes.Unavailable, "the connection is draining")
// errStreamDone is returned from write at the client side to indiacte application
// errStreamDone is returned from write at the client side to indicate application
// layer of an error.
errStreamDone = errors.New("the stream is done")
// StatusGoAway indicates that the server sent a GOAWAY that included this

View File

@@ -34,15 +34,29 @@ type ClientParameters struct {
// After a duration of this time if the client doesn't see any activity it
// pings the server to see if the transport is still alive.
// If set below 10s, a minimum value of 10s will be used instead.
Time time.Duration // The current default value is infinity.
//
// Note that gRPC servers have a default EnforcementPolicy.MinTime of 5
// minutes (which means the client shouldn't ping more frequently than every
// 5 minutes).
//
// Though not ideal, it's not a strong requirement for Time to be less than
// EnforcementPolicy.MinTime. Time will automatically double if the server
// disconnects due to its enforcement policy.
//
// For more details, see
// https://github.com/grpc/proposal/blob/master/A8-client-side-keepalive.md
Time time.Duration
// After having pinged for keepalive check, the client waits for a duration
// of Timeout and if no activity is seen even after that the connection is
// closed.
Timeout time.Duration // The current default value is 20 seconds.
//
// If keepalive is enabled, and this value is not explicitly set, the default
// is 20 seconds.
Timeout time.Duration
// If true, client sends keepalive pings even with no active RPCs. If false,
// when there are no active RPCs, Time and Timeout will be ignored and no
// keepalive pings will be sent.
PermitWithoutStream bool // false by default.
PermitWithoutStream bool
}
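For reference, a minimal sketch of wiring these client-side knobs into a dial, under the assumptions noted in the comments (the target address is a placeholder):

import (
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/keepalive"
)

func dialWithKeepalive() (*grpc.ClientConn, error) {
	kacp := keepalive.ClientParameters{
		Time:                30 * time.Second, // ping after 30s of inactivity (respects the server's MinTime)
		Timeout:             20 * time.Second, // wait 20s for a ping ack before closing the connection
		PermitWithoutStream: true,             // ping even when there are no active RPCs
	}
	// "localhost:50051" is a placeholder target.
	return grpc.NewClient("localhost:50051",
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithKeepaliveParams(kacp))
}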
// ServerParameters is used to set keepalive and max-age parameters on the

194
vendor/google.golang.org/grpc/mem/buffer_pool.go generated vendored Normal file
View File

@@ -0,0 +1,194 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package mem
import (
"sort"
"sync"
"google.golang.org/grpc/internal"
)
// BufferPool is a pool of buffers that can be shared and reused, resulting in
// decreased memory allocation.
type BufferPool interface {
// Get returns a buffer with specified length from the pool.
Get(length int) *[]byte
// Put returns a buffer to the pool.
Put(*[]byte)
}
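From a caller's perspective, the Get/Put contract looks roughly like this (a sketch assuming the DefaultBufferPool defined below; note that Get returns a *[]byte whose pointee already has the requested length):

func useBufferPool(payload []byte) {
	pool := mem.DefaultBufferPool()
	buf := pool.Get(len(payload)) // len(*buf) == len(payload)
	copy(*buf, payload)
	// ... use *buf ...
	pool.Put(buf) // return the buffer for reuse once it is no longer needed
}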
var defaultBufferPoolSizes = []int{
256,
4 << 10, // 4KB (go page size)
16 << 10, // 16KB (max HTTP/2 frame size used by gRPC)
32 << 10, // 32KB (default buffer size for io.Copy)
1 << 20, // 1MB
}
var defaultBufferPool BufferPool
func init() {
defaultBufferPool = NewTieredBufferPool(defaultBufferPoolSizes...)
internal.SetDefaultBufferPoolForTesting = func(pool BufferPool) {
defaultBufferPool = pool
}
internal.SetBufferPoolingThresholdForTesting = func(threshold int) {
bufferPoolingThreshold = threshold
}
}
// DefaultBufferPool returns the current default buffer pool. It is a BufferPool
// created with NewBufferPool that uses a set of default sizes optimized for
// expected workflows.
func DefaultBufferPool() BufferPool {
return defaultBufferPool
}
// NewTieredBufferPool returns a BufferPool implementation that uses multiple
// underlying pools of the given pool sizes.
func NewTieredBufferPool(poolSizes ...int) BufferPool {
sort.Ints(poolSizes)
pools := make([]*sizedBufferPool, len(poolSizes))
for i, s := range poolSizes {
pools[i] = newSizedBufferPool(s)
}
return &tieredBufferPool{
sizedPools: pools,
}
}
// tieredBufferPool implements the BufferPool interface with multiple tiers of
// buffer pools for different sizes of buffers.
type tieredBufferPool struct {
sizedPools []*sizedBufferPool
fallbackPool simpleBufferPool
}
func (p *tieredBufferPool) Get(size int) *[]byte {
return p.getPool(size).Get(size)
}
func (p *tieredBufferPool) Put(buf *[]byte) {
p.getPool(cap(*buf)).Put(buf)
}
func (p *tieredBufferPool) getPool(size int) BufferPool {
poolIdx := sort.Search(len(p.sizedPools), func(i int) bool {
return p.sizedPools[i].defaultSize >= size
})
if poolIdx == len(p.sizedPools) {
return &p.fallbackPool
}
return p.sizedPools[poolIdx]
}
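To illustrate the routing above: Get is served by the smallest sized pool whose defaultSize fits the request, and anything larger than every tier falls through to the fallback simpleBufferPool. A sketch with hypothetical tier sizes:

pool := mem.NewTieredBufferPool(256, 4<<10, 16<<10)
a := pool.Get(1000)     // 4KiB tier: smallest defaultSize >= 1000
b := pool.Get(64 << 10) // larger than all tiers: served by the fallback pool
pool.Put(a)
pool.Put(b)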
// sizedBufferPool is a BufferPool implementation that is optimized for specific
// buffer sizes. For example, HTTP/2 frames within gRPC have a default max size
// of 16kb and a sizedBufferPool can be configured to only return buffers with a
capacity of 16kb. Note, however, that it does not support returning larger
// buffers and in fact panics if such a buffer is requested. Because of this,
// this BufferPool implementation is not meant to be used on its own and rather
// is intended to be embedded in a tieredBufferPool such that Get is only
// invoked when the required size is smaller than or equal to defaultSize.
type sizedBufferPool struct {
pool sync.Pool
defaultSize int
}
func (p *sizedBufferPool) Get(size int) *[]byte {
buf := p.pool.Get().(*[]byte)
b := *buf
clear(b[:cap(b)])
*buf = b[:size]
return buf
}
func (p *sizedBufferPool) Put(buf *[]byte) {
if cap(*buf) < p.defaultSize {
// Ignore buffers that are too small to fit in the pool. Otherwise, when
// Get is called it will panic as it tries to index outside the bounds
// of the buffer.
return
}
p.pool.Put(buf)
}
func newSizedBufferPool(size int) *sizedBufferPool {
return &sizedBufferPool{
pool: sync.Pool{
New: func() any {
buf := make([]byte, size)
return &buf
},
},
defaultSize: size,
}
}
var _ BufferPool = (*simpleBufferPool)(nil)
// simpleBufferPool is an implementation of the BufferPool interface that
// attempts to pool buffers with a sync.Pool. When Get is invoked, it tries to
// acquire a buffer from the pool but if that buffer is too small, it returns it
// to the pool and creates a new one.
type simpleBufferPool struct {
pool sync.Pool
}
func (p *simpleBufferPool) Get(size int) *[]byte {
bs, ok := p.pool.Get().(*[]byte)
if ok && cap(*bs) >= size {
*bs = (*bs)[:size]
return bs
}
// A buffer was pulled from the pool, but it is too small. Put it back in
// the pool and create one large enough.
if ok {
p.pool.Put(bs)
}
b := make([]byte, size)
return &b
}
func (p *simpleBufferPool) Put(buf *[]byte) {
p.pool.Put(buf)
}
var _ BufferPool = NopBufferPool{}
// NopBufferPool is a buffer pool that returns new buffers without pooling.
type NopBufferPool struct{}
// Get returns a buffer with specified length from the pool.
func (NopBufferPool) Get(length int) *[]byte {
b := make([]byte, length)
return &b
}
// Put returns a buffer to the pool.
func (NopBufferPool) Put(*[]byte) {
}

281
vendor/google.golang.org/grpc/mem/buffer_slice.go generated vendored Normal file
View File

@@ -0,0 +1,281 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package mem
import (
"io"
)
const (
// 32 KiB is what io.Copy uses.
readAllBufSize = 32 * 1024
)
// BufferSlice offers a means to represent data that spans one or more Buffer
// instances. A BufferSlice is meant to be immutable after creation, and methods
// like Ref create and return copies of the slice. This is why all methods have
// value receivers rather than pointer receivers.
//
// Note that any of the methods that read the underlying buffers such as Ref,
// Len or CopyTo etc., will panic if any underlying buffers have already been
// freed. It is recommended not to interact with any of the underlying
// buffers directly; rather, such interactions should be mediated through the
// various methods on this type.
//
// By convention, any APIs that return (mem.BufferSlice, error) should reduce
// the burden on the caller by never returning a mem.BufferSlice that needs to
// be freed if the error is non-nil, unless explicitly stated.
type BufferSlice []Buffer
// Len returns the sum of the length of all the Buffers in this slice.
//
// # Warning
//
// Invoking the built-in len on a BufferSlice will return the number of buffers
// in the slice, and *not* the value returned by this function.
func (s BufferSlice) Len() int {
var length int
for _, b := range s {
length += b.Len()
}
return length
}
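The warning above matters in practice; a small sketch of the difference:

s := mem.BufferSlice{mem.SliceBuffer("hello"), mem.SliceBuffer("world")}
_ = len(s)  // 2: the number of Buffers in the slice
_ = s.Len() // 10: the total number of bytes across all Buffers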
// Ref invokes Ref on each buffer in the slice.
func (s BufferSlice) Ref() {
for _, b := range s {
b.Ref()
}
}
// Free invokes Buffer.Free() on each Buffer in the slice.
func (s BufferSlice) Free() {
for _, b := range s {
b.Free()
}
}
// CopyTo copies each of the underlying Buffer's data into the given buffer,
// returning the number of bytes copied. Has the same semantics as the copy
// builtin in that it will copy as many bytes as it can, stopping when either dst
// is full or s runs out of data, returning the minimum of s.Len() and len(dst).
func (s BufferSlice) CopyTo(dst []byte) int {
off := 0
for _, b := range s {
off += copy(dst[off:], b.ReadOnlyData())
}
return off
}
// Materialize concatenates all the underlying Buffer's data into a single
// contiguous buffer using CopyTo.
func (s BufferSlice) Materialize() []byte {
l := s.Len()
if l == 0 {
return nil
}
out := make([]byte, l)
s.CopyTo(out)
return out
}
// MaterializeToBuffer functions like Materialize except that it writes the data
// to a single Buffer pulled from the given BufferPool.
//
// As a special case, if the input BufferSlice only actually has one Buffer, this
// function simply increases the refcount before returning said Buffer. Freeing this
// buffer won't release it until the BufferSlice is itself released.
func (s BufferSlice) MaterializeToBuffer(pool BufferPool) Buffer {
if len(s) == 1 {
s[0].Ref()
return s[0]
}
sLen := s.Len()
if sLen == 0 {
return emptyBuffer{}
}
buf := pool.Get(sLen)
s.CopyTo(*buf)
return NewBuffer(buf, pool)
}
// Reader returns a new Reader for the input slice after taking references to
// each underlying buffer.
func (s BufferSlice) Reader() Reader {
s.Ref()
return &sliceReader{
data: s,
len: s.Len(),
}
}
// Reader exposes a BufferSlice's data as an io.Reader, allowing it to interface
// with other parts of the system. It also provides an additional convenience method
// Remaining(), which returns the number of unread bytes remaining in the slice.
// Buffers will be freed as they are read.
type Reader interface {
io.Reader
io.ByteReader
// Close frees the underlying BufferSlice and never returns an error. Subsequent
// calls to Read will return (0, io.EOF).
Close() error
// Remaining returns the number of unread bytes remaining in the slice.
Remaining() int
}
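A sketch of reading through this interface; the contents are illustrative:

s := mem.BufferSlice{mem.SliceBuffer("abc"), mem.SliceBuffer("def")}
r := s.Reader()        // takes its own references to the underlying buffers
var dst [4]byte
n, _ := r.Read(dst[:]) // n == 4; the read crosses the buffer boundary ("abcd")
_ = r.Remaining()      // 2 unread bytes ("ef") remain
_ = r.Close()          // frees whatever is left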
type sliceReader struct {
data BufferSlice
len int
// The index into data[0].ReadOnlyData().
bufferIdx int
}
func (r *sliceReader) Remaining() int {
return r.len
}
func (r *sliceReader) Close() error {
r.data.Free()
r.data = nil
r.len = 0
return nil
}
func (r *sliceReader) freeFirstBufferIfEmpty() bool {
if len(r.data) == 0 || r.bufferIdx != len(r.data[0].ReadOnlyData()) {
return false
}
r.data[0].Free()
r.data = r.data[1:]
r.bufferIdx = 0
return true
}
func (r *sliceReader) Read(buf []byte) (n int, _ error) {
if r.len == 0 {
return 0, io.EOF
}
for len(buf) != 0 && r.len != 0 {
// Copy as much as possible from the first Buffer in the slice into the
// given byte slice.
data := r.data[0].ReadOnlyData()
copied := copy(buf, data[r.bufferIdx:])
r.len -= copied // Reduce len by the number of bytes copied.
r.bufferIdx += copied // Increment the buffer index.
n += copied // Increment the total number of bytes read.
buf = buf[copied:] // Shrink the given byte slice.
// If we have copied all the data from the first Buffer, free it and advance to
// the next in the slice.
r.freeFirstBufferIfEmpty()
}
return n, nil
}
func (r *sliceReader) ReadByte() (byte, error) {
if r.len == 0 {
return 0, io.EOF
}
// There may be any number of empty buffers in the slice, clear them all until a
// non-empty buffer is reached. This is guaranteed to exit since r.len is not 0.
for r.freeFirstBufferIfEmpty() {
}
b := r.data[0].ReadOnlyData()[r.bufferIdx]
r.len--
r.bufferIdx++
// Free the first buffer in the slice if the last byte was read
r.freeFirstBufferIfEmpty()
return b, nil
}
var _ io.Writer = (*writer)(nil)
type writer struct {
buffers *BufferSlice
pool BufferPool
}
func (w *writer) Write(p []byte) (n int, err error) {
b := Copy(p, w.pool)
*w.buffers = append(*w.buffers, b)
return b.Len(), nil
}
// NewWriter wraps the given BufferSlice and BufferPool to implement the
// io.Writer interface. Every call to Write copies the contents of the given
// buffer into a new Buffer pulled from the given pool and the Buffer is
// added to the given BufferSlice.
func NewWriter(buffers *BufferSlice, pool BufferPool) io.Writer {
return &writer{buffers: buffers, pool: pool}
}
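A sketch of accumulating writes into a BufferSlice through this writer:

var bs mem.BufferSlice
w := mem.NewWriter(&bs, mem.DefaultBufferPool())
w.Write([]byte("hello ")) // each Write copies into a new Buffer appended to bs
w.Write([]byte("world"))
_ = bs.Len() // 11
bs.Free()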
// ReadAll reads from r until an error or EOF and returns the data it read.
// A successful call returns err == nil, not err == EOF. Because ReadAll is
// defined to read from src until EOF, it does not treat an EOF from Read
// as an error to be reported.
//
// Important: A failed call returns a non-nil error and may also return
// partially read buffers. It is the responsibility of the caller to free the
// BufferSlice returned, or its memory will not be reused.
func ReadAll(r io.Reader, pool BufferPool) (BufferSlice, error) {
var result BufferSlice
if wt, ok := r.(io.WriterTo); ok {
// This is more optimal since wt knows the size of chunks it wants to
// write and, hence, we can allocate buffers of an optimal size to fit
// them. E.g. might be a single big chunk, and we wouldn't chop it
// into pieces.
w := NewWriter(&result, pool)
_, err := wt.WriteTo(w)
return result, err
}
nextBuffer:
for {
buf := pool.Get(readAllBufSize)
// We asked for 32KiB but may have been given a bigger buffer.
// Use all of it if that's the case.
*buf = (*buf)[:cap(*buf)]
usedCap := 0
for {
n, err := r.Read((*buf)[usedCap:])
usedCap += n
if err != nil {
if usedCap == 0 {
// Nothing in this buf, put it back
pool.Put(buf)
} else {
*buf = (*buf)[:usedCap]
result = append(result, NewBuffer(buf, pool))
}
if err == io.EOF {
err = nil
}
return result, err
}
if len(*buf) == usedCap {
result = append(result, NewBuffer(buf, pool))
continue nextBuffer
}
}
}
}
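A sketch of the caller obligation described above: the returned slice must be freed even when err is non-nil, since partially read buffers may be returned.

func readBody(r io.Reader) ([]byte, error) {
	data, err := mem.ReadAll(r, mem.DefaultBufferPool())
	defer data.Free() // required even on error
	if err != nil {
		return nil, err
	}
	return data.Materialize(), nil // copy out before the buffers are freed
}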

268
vendor/google.golang.org/grpc/mem/buffers.go generated vendored Normal file
View File

@@ -0,0 +1,268 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package mem provides utilities that facilitate memory reuse in byte slices
// that are used as buffers.
//
// # Experimental
//
// Notice: All APIs in this package are EXPERIMENTAL and may be changed or
// removed in a later release.
package mem
import (
"fmt"
"sync"
"sync/atomic"
)
// A Buffer represents a reference counted piece of data (in bytes) that can be
// acquired by a call to NewBuffer() or Copy(). A reference to a Buffer may be
// released by calling Free(), which invokes the free function given at creation
// only after all references are released.
//
// Note that a Buffer is not safe for concurrent access and instead each
// goroutine should use its own reference to the data, which can be acquired via
// a call to Ref().
//
// Attempts to access the underlying data after releasing the reference to the
// Buffer will panic.
type Buffer interface {
// ReadOnlyData returns the underlying byte slice. Note that it is undefined
// behavior to modify the contents of this slice in any way.
ReadOnlyData() []byte
// Ref increases the reference counter for this Buffer.
Ref()
// Free decrements this Buffer's reference counter and frees the underlying
// byte slice if the counter reaches 0 as a result of this call.
Free()
// Len returns the Buffer's size.
Len() int
split(n int) (left, right Buffer)
read(buf []byte) (int, Buffer)
}
var (
bufferPoolingThreshold = 1 << 10
bufferObjectPool = sync.Pool{New: func() any { return new(buffer) }}
refObjectPool = sync.Pool{New: func() any { return new(atomic.Int32) }}
)
// IsBelowBufferPoolingThreshold returns true if the given size is less than or
// equal to the threshold for buffer pooling. This is used to determine whether
// to pool buffers or allocate them directly.
func IsBelowBufferPoolingThreshold(size int) bool {
return size <= bufferPoolingThreshold
}
type buffer struct {
origData *[]byte
data []byte
refs *atomic.Int32
pool BufferPool
}
func newBuffer() *buffer {
return bufferObjectPool.Get().(*buffer)
}
// NewBuffer creates a new Buffer from the given data, initializing the reference
// counter to 1. The data will then be returned to the given pool when all
// references to the returned Buffer are released. As a special case to avoid
// additional allocations, if the given buffer pool is nil, the returned buffer
// will be a "no-op" Buffer where invoking Buffer.Free() does nothing and the
// underlying data is never freed.
//
// Note that the backing array of the given data is not copied.
func NewBuffer(data *[]byte, pool BufferPool) Buffer {
// Use the buffer's capacity instead of the length, otherwise buffers may
// not be reused under certain conditions. For example, if a large buffer
// is acquired from the pool, but fewer bytes than the buffering threshold
// are written to it, the buffer will not be returned to the pool.
if pool == nil || IsBelowBufferPoolingThreshold(cap(*data)) {
return (SliceBuffer)(*data)
}
b := newBuffer()
b.origData = data
b.data = *data
b.pool = pool
b.refs = refObjectPool.Get().(*atomic.Int32)
b.refs.Add(1)
return b
}
// Copy creates a new Buffer from the given data, initializing the reference
// counter to 1.
//
// It acquires a []byte from the given pool and copies over the backing array
// of the given data. The []byte acquired from the pool is returned to the
// pool when all references to the returned Buffer are released.
func Copy(data []byte, pool BufferPool) Buffer {
if IsBelowBufferPoolingThreshold(len(data)) {
buf := make(SliceBuffer, len(data))
copy(buf, data)
return buf
}
buf := pool.Get(len(data))
copy(*buf, data)
return NewBuffer(buf, pool)
}
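A sketch of the reference-counting lifecycle, assuming the default 1KiB pooling threshold so that a 4KiB Copy takes the pooled path:

pool := mem.DefaultBufferPool()
b := mem.Copy(make([]byte, 4096), pool) // refcount starts at 1
b.Ref()                                 // refcount 2, e.g. when handing off a reference
b.Free()                                // refcount 1; the data is still readable
_ = b.ReadOnlyData()
b.Free() // refcount 0; the backing []byte is returned to the pool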
func (b *buffer) ReadOnlyData() []byte {
if b.refs == nil {
panic("Cannot read freed buffer")
}
return b.data
}
func (b *buffer) Ref() {
if b.refs == nil {
panic("Cannot ref freed buffer")
}
b.refs.Add(1)
}
func (b *buffer) Free() {
if b.refs == nil {
panic("Cannot free freed buffer")
}
refs := b.refs.Add(-1)
switch {
case refs > 0:
return
case refs == 0:
if b.pool != nil {
b.pool.Put(b.origData)
}
refObjectPool.Put(b.refs)
b.origData = nil
b.data = nil
b.refs = nil
b.pool = nil
bufferObjectPool.Put(b)
default:
panic("Cannot free freed buffer")
}
}
func (b *buffer) Len() int {
return len(b.ReadOnlyData())
}
func (b *buffer) split(n int) (Buffer, Buffer) {
if b.refs == nil {
panic("Cannot split freed buffer")
}
b.refs.Add(1)
split := newBuffer()
split.origData = b.origData
split.data = b.data[n:]
split.refs = b.refs
split.pool = b.pool
b.data = b.data[:n]
return b, split
}
func (b *buffer) read(buf []byte) (int, Buffer) {
if b.refs == nil {
panic("Cannot read freed buffer")
}
n := copy(buf, b.data)
if n == len(b.data) {
b.Free()
return n, nil
}
b.data = b.data[n:]
return n, b
}
func (b *buffer) String() string {
return fmt.Sprintf("mem.Buffer(%p, data: %p, length: %d)", b, b.ReadOnlyData(), len(b.ReadOnlyData()))
}
// ReadUnsafe reads bytes from the given Buffer into the provided slice.
// It does not perform safety checks.
func ReadUnsafe(dst []byte, buf Buffer) (int, Buffer) {
return buf.read(dst)
}
// SplitUnsafe modifies the receiver to point to the first n bytes while it
// returns a new reference to the remaining bytes. The returned Buffer
// functions just like a normal reference acquired using Ref().
func SplitUnsafe(buf Buffer, n int) (left, right Buffer) {
return buf.split(n)
}
type emptyBuffer struct{}
func (e emptyBuffer) ReadOnlyData() []byte {
return nil
}
func (e emptyBuffer) Ref() {}
func (e emptyBuffer) Free() {}
func (e emptyBuffer) Len() int {
return 0
}
func (e emptyBuffer) split(int) (left, right Buffer) {
return e, e
}
func (e emptyBuffer) read([]byte) (int, Buffer) {
return 0, e
}
// SliceBuffer is a Buffer implementation that wraps a byte slice. It provides
// methods for reading, splitting, and managing the byte slice.
type SliceBuffer []byte
// ReadOnlyData returns the byte slice.
func (s SliceBuffer) ReadOnlyData() []byte { return s }
// Ref is a noop implementation of Ref.
func (s SliceBuffer) Ref() {}
// Free is a noop implementation of Free.
func (s SliceBuffer) Free() {}
// Len returns the length of the underlying byte slice.
func (s SliceBuffer) Len() int { return len(s) }
func (s SliceBuffer) split(n int) (left, right Buffer) {
return s[:n], s[n:]
}
func (s SliceBuffer) read(buf []byte) (int, Buffer) {
n := copy(buf, s)
if n == len(s) {
return n, nil
}
return n, s[n:]
}

View File

@@ -213,11 +213,6 @@ func FromIncomingContext(ctx context.Context) (MD, bool) {
// ValueFromIncomingContext returns the metadata value corresponding to the metadata
// key from the incoming metadata if it exists. Keys are matched in a case insensitive
// manner.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func ValueFromIncomingContext(ctx context.Context, key string) []string {
md, ok := ctx.Value(mdIncomingKey{}).(MD)
if !ok {
@@ -228,7 +223,7 @@ func ValueFromIncomingContext(ctx context.Context, key string) []string {
return copyOf(v)
}
for k, v := range md {
// Case insenitive comparison: MD is a map, and there's no guarantee
// Case insensitive comparison: MD is a map, and there's no guarantee
// that the MD attached to the context is created using our helper
// functions.
if strings.EqualFold(k, key) {

View File

@@ -22,7 +22,9 @@ package peer
import (
"context"
"fmt"
"net"
"strings"
"google.golang.org/grpc/credentials"
)
@@ -39,6 +41,34 @@ type Peer struct {
AuthInfo credentials.AuthInfo
}
// String ensures the Peer type implements the Stringer interface so that a
// context carrying a peerKey value can be printed effectively.
func (p *Peer) String() string {
if p == nil {
return "Peer<nil>"
}
sb := &strings.Builder{}
sb.WriteString("Peer{")
if p.Addr != nil {
fmt.Fprintf(sb, "Addr: '%s', ", p.Addr.String())
} else {
fmt.Fprintf(sb, "Addr: <nil>, ")
}
if p.LocalAddr != nil {
fmt.Fprintf(sb, "LocalAddr: '%s', ", p.LocalAddr.String())
} else {
fmt.Fprintf(sb, "LocalAddr: <nil>, ")
}
if p.AuthInfo != nil {
fmt.Fprintf(sb, "AuthInfo: '%s'", p.AuthInfo.AuthType())
} else {
fmt.Fprintf(sb, "AuthInfo: <nil>")
}
sb.WriteString("}")
return sb.String()
}
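Given the formatting above, a Peer carrying only a remote address prints roughly as follows (the address is hypothetical):

p := &peer.Peer{Addr: &net.TCPAddr{IP: net.IPv4(10, 0, 0, 1), Port: 443}}
fmt.Println(p) // Peer{Addr: '10.0.0.1:443', LocalAddr: <nil>, AuthInfo: <nil>}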
type peerKey struct{}
// NewContext creates a new context with peer information attached.

View File

@@ -20,8 +20,9 @@ package grpc
import (
"context"
"fmt"
"io"
"sync"
"sync/atomic"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/codes"
@@ -32,35 +33,43 @@ import (
"google.golang.org/grpc/status"
)
// pickerGeneration stores a picker and a channel used to signal that a picker
// newer than this one is available.
type pickerGeneration struct {
// picker is the picker produced by the LB policy. May be nil if a picker
// has never been produced.
picker balancer.Picker
// blockingCh is closed when the picker has been invalidated because there
// is a new one available.
blockingCh chan struct{}
}
// pickerWrapper is a wrapper of balancer.Picker. It blocks on certain pick
// actions and unblock when there's a picker update.
type pickerWrapper struct {
mu sync.Mutex
done bool
blockingCh chan struct{}
picker balancer.Picker
// If pickerGen holds a nil pointer, the pickerWrapper is closed.
pickerGen atomic.Pointer[pickerGeneration]
statsHandlers []stats.Handler // to record blocking picker calls
}
func newPickerWrapper(statsHandlers []stats.Handler) *pickerWrapper {
return &pickerWrapper{
blockingCh: make(chan struct{}),
pw := &pickerWrapper{
statsHandlers: statsHandlers,
}
pw.pickerGen.Store(&pickerGeneration{
blockingCh: make(chan struct{}),
})
return pw
}
// updatePicker is called by UpdateBalancerState. It unblocks all blocked pick.
// updatePicker is called by UpdateState calls from the LB policy. It
// unblocks all blocked picks.
func (pw *pickerWrapper) updatePicker(p balancer.Picker) {
pw.mu.Lock()
if pw.done {
pw.mu.Unlock()
return
}
pw.picker = p
// pw.blockingCh should never be nil.
close(pw.blockingCh)
pw.blockingCh = make(chan struct{})
pw.mu.Unlock()
old := pw.pickerGen.Swap(&pickerGeneration{
picker: p,
blockingCh: make(chan struct{}),
})
close(old.blockingCh)
}
// doneChannelzWrapper performs the following:
@@ -97,27 +106,24 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
var lastPickErr error
for {
pw.mu.Lock()
if pw.done {
pw.mu.Unlock()
pg := pw.pickerGen.Load()
if pg == nil {
return nil, balancer.PickResult{}, ErrClientConnClosing
}
if pw.picker == nil {
ch = pw.blockingCh
if pg.picker == nil {
ch = pg.blockingCh
}
if ch == pw.blockingCh {
if ch == pg.blockingCh {
// This could happen when either:
// - pw.picker is nil (the previous if condition), or
// - has called pick on the current picker.
pw.mu.Unlock()
// - we have already called pick on the current picker.
select {
case <-ctx.Done():
var errStr string
if lastPickErr != nil {
errStr = "latest balancer error: " + lastPickErr.Error()
} else {
errStr = ctx.Err().Error()
errStr = fmt.Sprintf("%v while waiting for connections to become ready", ctx.Err())
}
switch ctx.Err() {
case context.DeadlineExceeded:
@@ -144,9 +150,8 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
}
}
ch = pw.blockingCh
p := pw.picker
pw.mu.Unlock()
ch = pg.blockingCh
p := pg.picker
pickResult, err := p.Pick(info)
if err != nil {
@@ -196,24 +201,15 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
}
func (pw *pickerWrapper) close() {
pw.mu.Lock()
defer pw.mu.Unlock()
if pw.done {
return
}
pw.done = true
close(pw.blockingCh)
old := pw.pickerGen.Swap(nil)
close(old.blockingCh)
}
// reset clears the pickerWrapper and prepares it for being used again when idle
// mode is exited.
func (pw *pickerWrapper) reset() {
pw.mu.Lock()
defer pw.mu.Unlock()
if pw.done {
return
}
pw.blockingCh = make(chan struct{})
old := pw.pickerGen.Swap(&pickerGeneration{blockingCh: make(chan struct{})})
close(old.blockingCh)
}
// dropError is a wrapper error that indicates the LB policy wishes to drop the

View File

@@ -20,6 +20,7 @@ package grpc
import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/status"
)
@@ -31,9 +32,10 @@ import (
// later release.
type PreparedMsg struct {
// Struct for preparing msg before sending them
encodedData []byte
encodedData mem.BufferSlice
hdr []byte
payload []byte
payload mem.BufferSlice
pf payloadFormat
}
// Encode marshalls and compresses the message using the codec and compressor for the stream.
@@ -57,11 +59,27 @@ func (p *PreparedMsg) Encode(s Stream, msg any) error {
if err != nil {
return err
}
p.encodedData = data
compData, err := compress(data, rpcInfo.preloaderInfo.cp, rpcInfo.preloaderInfo.comp)
materializedData := data.Materialize()
data.Free()
p.encodedData = mem.BufferSlice{mem.SliceBuffer(materializedData)}
// TODO: it should be possible to grab the bufferPool from the underlying
// stream implementation with a type cast to its actual type (such as
// addrConnStream) and accessing the buffer pool directly.
var compData mem.BufferSlice
compData, p.pf, err = compress(p.encodedData, rpcInfo.preloaderInfo.cp, rpcInfo.preloaderInfo.comp, mem.DefaultBufferPool())
if err != nil {
return err
}
p.hdr, p.payload = msgHeader(data, compData)
if p.pf.isCompressed() {
materializedCompData := compData.Materialize()
compData.Free()
compData = mem.BufferSlice{mem.SliceBuffer(materializedCompData)}
}
p.hdr, p.payload = msgHeader(p.encodedData, compData, p.pf)
return nil
}

View File

@@ -1,123 +0,0 @@
#!/bin/bash
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -eu -o pipefail
WORKDIR=$(mktemp -d)
function finish {
rm -rf "$WORKDIR"
}
trap finish EXIT
export GOBIN=${WORKDIR}/bin
export PATH=${GOBIN}:${PATH}
mkdir -p ${GOBIN}
echo "remove existing generated files"
# grpc_testing_not_regenerate/*.pb.go is not re-generated,
# see grpc_testing_not_regenerate/README.md for details.
rm -f $(find . -name '*.pb.go' | grep -v 'grpc_testing_not_regenerate')
echo "go install google.golang.org/protobuf/cmd/protoc-gen-go"
(cd test/tools && go install google.golang.org/protobuf/cmd/protoc-gen-go)
echo "go install cmd/protoc-gen-go-grpc"
(cd cmd/protoc-gen-go-grpc && go install .)
echo "git clone https://github.com/grpc/grpc-proto"
git clone --quiet https://github.com/grpc/grpc-proto ${WORKDIR}/grpc-proto
echo "git clone https://github.com/protocolbuffers/protobuf"
git clone --quiet https://github.com/protocolbuffers/protobuf ${WORKDIR}/protobuf
# Pull in code.proto as a proto dependency
mkdir -p ${WORKDIR}/googleapis/google/rpc
echo "curl https://raw.githubusercontent.com/googleapis/googleapis/master/google/rpc/code.proto"
curl --silent https://raw.githubusercontent.com/googleapis/googleapis/master/google/rpc/code.proto > ${WORKDIR}/googleapis/google/rpc/code.proto
mkdir -p ${WORKDIR}/out
# Generates sources without the embed requirement
LEGACY_SOURCES=(
${WORKDIR}/grpc-proto/grpc/binlog/v1/binarylog.proto
${WORKDIR}/grpc-proto/grpc/channelz/v1/channelz.proto
${WORKDIR}/grpc-proto/grpc/health/v1/health.proto
${WORKDIR}/grpc-proto/grpc/lb/v1/load_balancer.proto
profiling/proto/service.proto
${WORKDIR}/grpc-proto/grpc/reflection/v1alpha/reflection.proto
${WORKDIR}/grpc-proto/grpc/reflection/v1/reflection.proto
)
# Generates only the new gRPC Service symbols
SOURCES=(
$(git ls-files --exclude-standard --cached --others "*.proto" | grep -v '^\(profiling/proto/service.proto\|reflection/grpc_reflection_v1alpha/reflection.proto\)$')
${WORKDIR}/grpc-proto/grpc/gcp/altscontext.proto
${WORKDIR}/grpc-proto/grpc/gcp/handshaker.proto
${WORKDIR}/grpc-proto/grpc/gcp/transport_security_common.proto
${WORKDIR}/grpc-proto/grpc/lookup/v1/rls.proto
${WORKDIR}/grpc-proto/grpc/lookup/v1/rls_config.proto
${WORKDIR}/grpc-proto/grpc/testing/*.proto
${WORKDIR}/grpc-proto/grpc/core/*.proto
)
# These options of the form 'Mfoo.proto=bar' instruct the codegen to use an
# import path of 'bar' in the generated code when 'foo.proto' is imported in
# one of the sources.
#
# Note that the protos listed here are all for testing purposes. All protos to
# be used externally should have a go_package option (and they don't need to be
# listed here).
OPTS=Mgrpc/core/stats.proto=google.golang.org/grpc/interop/grpc_testing/core,\
Mgrpc/testing/benchmark_service.proto=google.golang.org/grpc/interop/grpc_testing,\
Mgrpc/testing/stats.proto=google.golang.org/grpc/interop/grpc_testing,\
Mgrpc/testing/report_qps_scenario_service.proto=google.golang.org/grpc/interop/grpc_testing,\
Mgrpc/testing/messages.proto=google.golang.org/grpc/interop/grpc_testing,\
Mgrpc/testing/worker_service.proto=google.golang.org/grpc/interop/grpc_testing,\
Mgrpc/testing/control.proto=google.golang.org/grpc/interop/grpc_testing,\
Mgrpc/testing/test.proto=google.golang.org/grpc/interop/grpc_testing,\
Mgrpc/testing/payloads.proto=google.golang.org/grpc/interop/grpc_testing,\
Mgrpc/testing/empty.proto=google.golang.org/grpc/interop/grpc_testing
for src in ${SOURCES[@]}; do
echo "protoc ${src}"
protoc --go_out=${OPTS}:${WORKDIR}/out --go-grpc_out=${OPTS}:${WORKDIR}/out \
-I"." \
-I${WORKDIR}/grpc-proto \
-I${WORKDIR}/googleapis \
-I${WORKDIR}/protobuf/src \
${src}
done
for src in ${LEGACY_SOURCES[@]}; do
echo "protoc ${src}"
protoc --go_out=${OPTS}:${WORKDIR}/out --go-grpc_out=${OPTS},require_unimplemented_servers=false:${WORKDIR}/out \
-I"." \
-I${WORKDIR}/grpc-proto \
-I${WORKDIR}/googleapis \
-I${WORKDIR}/protobuf/src \
${src}
done
# The go_package option in grpc/lookup/v1/rls.proto doesn't match the
# current location. Move it into the right place.
mkdir -p ${WORKDIR}/out/google.golang.org/grpc/internal/proto/grpc_lookup_v1
mv ${WORKDIR}/out/google.golang.org/grpc/lookup/grpc_lookup_v1/* ${WORKDIR}/out/google.golang.org/grpc/internal/proto/grpc_lookup_v1
# grpc_testing_not_regenerate/*.pb.go are not re-generated,
# see grpc_testing_not_regenerate/README.md for details.
rm ${WORKDIR}/out/google.golang.org/grpc/reflection/grpc_testing_not_regenerate/*.pb.go
cp -R ${WORKDIR}/out/google.golang.org/grpc/* .

View File

@@ -18,9 +18,6 @@
// Package dns implements a dns resolver to be installed as the default resolver
// in grpc.
//
// Deprecated: this package is imported by grpc and should not need to be
// imported directly by users.
package dns
import (
@@ -52,3 +49,12 @@ func SetResolvingTimeout(timeout time.Duration) {
func NewBuilder() resolver.Builder {
return dns.NewBuilder()
}
// SetMinResolutionInterval sets the default minimum interval at which DNS
// re-resolutions are allowed. This helps to prevent excessive re-resolution.
//
// It must be called only at application startup, before any gRPC calls are
// made. Modifying this value after initialization is not thread-safe.
func SetMinResolutionInterval(d time.Duration) {
dns.MinResolutionInterval = d
}

View File

@@ -18,16 +18,28 @@
package resolver
type addressMapEntry struct {
import (
"encoding/base64"
"sort"
"strings"
)
type addressMapEntry[T any] struct {
addr Address
value any
value T
}
// AddressMap is a map of addresses to arbitrary values taking into account
// AddressMap is an AddressMapV2[any]. It will be deleted in an upcoming
// release of grpc-go.
//
// Deprecated: use the generic AddressMapV2 type instead.
type AddressMap = AddressMapV2[any]
// AddressMapV2 is a map of addresses to arbitrary values taking into account
// Attributes. BalancerAttributes are ignored, as are Metadata and Type.
// Multiple accesses may not be performed concurrently. Must be created via
// NewAddressMap; do not construct directly.
type AddressMap struct {
type AddressMapV2[T any] struct {
// The underlying map is keyed by an Address with fields that we don't care
// about being set to their zero values. The only fields that we care about
// are `Addr`, `ServerName` and `Attributes`. Since we need to be able to
@@ -41,23 +53,30 @@ type AddressMap struct {
// The value type of the map contains a slice of addresses which match the key
// in their `Addr` and `ServerName` fields and contain the corresponding value
// associated with them.
m map[Address]addressMapEntryList
m map[Address]addressMapEntryList[T]
}
func toMapKey(addr *Address) Address {
return Address{Addr: addr.Addr, ServerName: addr.ServerName}
}
type addressMapEntryList []*addressMapEntry
type addressMapEntryList[T any] []*addressMapEntry[T]
// NewAddressMap creates a new AddressMap.
// NewAddressMap creates a new AddressMapV2[any].
//
// Deprecated: use the generic NewAddressMapV2 constructor instead.
func NewAddressMap() *AddressMap {
return &AddressMap{m: make(map[Address]addressMapEntryList)}
return NewAddressMapV2[any]()
}
// NewAddressMapV2 creates a new AddressMapV2.
func NewAddressMapV2[T any]() *AddressMapV2[T] {
return &AddressMapV2[T]{m: make(map[Address]addressMapEntryList[T])}
}
// find returns the index of addr in the addressMapEntry slice, or -1 if not
// present.
func (l addressMapEntryList) find(addr Address) int {
func (l addressMapEntryList[T]) find(addr Address) int {
for i, entry := range l {
// Attributes are the only thing to match on here, since `Addr` and
// `ServerName` are already equal.
@@ -69,28 +88,28 @@ func (l addressMapEntryList) find(addr Address) int {
}
// Get returns the value for the address in the map, if present.
func (a *AddressMap) Get(addr Address) (value any, ok bool) {
func (a *AddressMapV2[T]) Get(addr Address) (value T, ok bool) {
addrKey := toMapKey(&addr)
entryList := a.m[addrKey]
if entry := entryList.find(addr); entry != -1 {
return entryList[entry].value, true
}
return nil, false
return value, false
}
// Set updates or adds the value to the address in the map.
func (a *AddressMap) Set(addr Address, value any) {
func (a *AddressMapV2[T]) Set(addr Address, value T) {
addrKey := toMapKey(&addr)
entryList := a.m[addrKey]
if entry := entryList.find(addr); entry != -1 {
entryList[entry].value = value
return
}
a.m[addrKey] = append(entryList, &addressMapEntry{addr: addr, value: value})
a.m[addrKey] = append(entryList, &addressMapEntry[T]{addr: addr, value: value})
}
// Delete removes addr from the map.
func (a *AddressMap) Delete(addr Address) {
func (a *AddressMapV2[T]) Delete(addr Address) {
addrKey := toMapKey(&addr)
entryList := a.m[addrKey]
entry := entryList.find(addr)
@@ -107,7 +126,7 @@ func (a *AddressMap) Delete(addr Address) {
}
// Len returns the number of entries in the map.
func (a *AddressMap) Len() int {
func (a *AddressMapV2[T]) Len() int {
ret := 0
for _, entryList := range a.m {
ret += len(entryList)
@@ -116,7 +135,7 @@ func (a *AddressMap) Len() int {
}
// Keys returns a slice of all current map keys.
func (a *AddressMap) Keys() []Address {
func (a *AddressMapV2[T]) Keys() []Address {
ret := make([]Address, 0, a.Len())
for _, entryList := range a.m {
for _, entry := range entryList {
@@ -127,8 +146,8 @@ func (a *AddressMap) Keys() []Address {
}
// Values returns a slice of all current map values.
func (a *AddressMap) Values() []any {
ret := make([]any, 0, a.Len())
func (a *AddressMapV2[T]) Values() []T {
ret := make([]T, 0, a.Len())
for _, entryList := range a.m {
for _, entry := range entryList {
ret = append(ret, entry.value)
@@ -137,70 +156,65 @@ func (a *AddressMap) Values() []any {
return ret
}
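A sketch of the generic map in use (the address is hypothetical); unlike the deprecated any-valued AddressMap, Get needs no type assertion:

m := resolver.NewAddressMapV2[int]()
addr := resolver.Address{Addr: "10.0.0.1:443"}
m.Set(addr, 7)
if v, ok := m.Get(addr); ok {
	_ = v // 7, typed as int
}
m.Delete(addr)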
type endpointNode struct {
addrs map[string]struct{}
}
// Equal returns whether the unordered set of addrs are the same between the
// endpoint nodes.
func (en *endpointNode) Equal(en2 *endpointNode) bool {
if len(en.addrs) != len(en2.addrs) {
return false
}
for addr := range en.addrs {
if _, ok := en2.addrs[addr]; !ok {
return false
}
}
return true
}
func toEndpointNode(endpoint Endpoint) endpointNode {
en := make(map[string]struct{})
for _, addr := range endpoint.Addresses {
en[addr.Addr] = struct{}{}
}
return endpointNode{
addrs: en,
}
}
type endpointMapKey string
// EndpointMap is a map of endpoints to arbitrary values keyed on only the
// unordered set of address strings within an endpoint. This map is not thread
// safe, thus it is unsafe to access concurrently. Must be created via
// NewEndpointMap; do not construct directly.
type EndpointMap struct {
endpoints map[*endpointNode]any
type EndpointMap[T any] struct {
endpoints map[endpointMapKey]endpointData[T]
}
type endpointData[T any] struct {
// decodedKey stores the original key to avoid decoding when iterating on
// EndpointMap keys.
decodedKey Endpoint
value T
}
// NewEndpointMap creates a new EndpointMap.
func NewEndpointMap() *EndpointMap {
return &EndpointMap{
endpoints: make(map[*endpointNode]any),
func NewEndpointMap[T any]() *EndpointMap[T] {
return &EndpointMap[T]{
endpoints: make(map[endpointMapKey]endpointData[T]),
}
}
// encodeEndpoint returns a string that uniquely identifies the unordered set of
// addresses within an endpoint.
func encodeEndpoint(e Endpoint) endpointMapKey {
addrs := make([]string, 0, len(e.Addresses))
// base64 encoding the address strings restricts the characters present
// within the strings. This allows us to use a delimiter without the need of
// escape characters.
for _, addr := range e.Addresses {
addrs = append(addrs, base64.StdEncoding.EncodeToString([]byte(addr.Addr)))
}
sort.Strings(addrs)
// " " should not appear in base64 encoded strings.
return endpointMapKey(strings.Join(addrs, " "))
}
// Get returns the value for the endpoint in the map, if present.
func (em *EndpointMap) Get(e Endpoint) (value any, ok bool) {
en := toEndpointNode(e)
if endpoint := em.find(en); endpoint != nil {
return em.endpoints[endpoint], true
func (em *EndpointMap[T]) Get(e Endpoint) (value T, ok bool) {
val, found := em.endpoints[encodeEndpoint(e)]
if found {
return val.value, true
}
return nil, false
return value, false
}
// Set updates or adds the value to the endpoint in the map.
func (em *EndpointMap) Set(e Endpoint, value any) {
en := toEndpointNode(e)
if endpoint := em.find(en); endpoint != nil {
em.endpoints[endpoint] = value
return
func (em *EndpointMap[T]) Set(e Endpoint, value T) {
en := encodeEndpoint(e)
em.endpoints[en] = endpointData[T]{
decodedKey: Endpoint{Addresses: e.Addresses},
value: value,
}
em.endpoints[&en] = value
}
// Len returns the number of entries in the map.
func (em *EndpointMap) Len() int {
func (em *EndpointMap[T]) Len() int {
return len(em.endpoints)
}
@@ -209,43 +223,25 @@ func (em *EndpointMap) Len() int {
// the unordered set of addresses. Thus, endpoint information returned is not
// the full endpoint data (drops duplicated addresses and attributes) but can be
// used for EndpointMap accesses.
func (em *EndpointMap) Keys() []Endpoint {
func (em *EndpointMap[T]) Keys() []Endpoint {
ret := make([]Endpoint, 0, len(em.endpoints))
for en := range em.endpoints {
var endpoint Endpoint
for addr := range en.addrs {
endpoint.Addresses = append(endpoint.Addresses, Address{Addr: addr})
}
ret = append(ret, endpoint)
for _, en := range em.endpoints {
ret = append(ret, en.decodedKey)
}
return ret
}
// Values returns a slice of all current map values.
func (em *EndpointMap) Values() []any {
ret := make([]any, 0, len(em.endpoints))
func (em *EndpointMap[T]) Values() []T {
ret := make([]T, 0, len(em.endpoints))
for _, val := range em.endpoints {
ret = append(ret, val)
ret = append(ret, val.value)
}
return ret
}
// find returns a pointer to the endpoint node in em if the endpoint node is
// already present. If not found, nil is returned. The comparisons are done on
// the unordered set of addresses within an endpoint.
func (em EndpointMap) find(e endpointNode) *endpointNode {
for endpoint := range em.endpoints {
if e.Equal(endpoint) {
return endpoint
}
}
return nil
}
// Delete removes the specified endpoint from the map.
func (em *EndpointMap) Delete(e Endpoint) {
en := toEndpointNode(e)
if entry := em.find(en); entry != nil {
delete(em.endpoints, entry)
}
func (em *EndpointMap[T]) Delete(e Endpoint) {
en := encodeEndpoint(e)
delete(em.endpoints, en)
}
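A sketch showing that the map keys on the unordered set of address strings, so lookups are order-insensitive (addresses hypothetical):

em := resolver.NewEndpointMap[string]()
em.Set(resolver.Endpoint{Addresses: []resolver.Address{
	{Addr: "10.0.0.1:443"}, {Addr: "10.0.0.2:443"},
}}, "primary")
v, ok := em.Get(resolver.Endpoint{Addresses: []resolver.Address{
	{Addr: "10.0.0.2:443"}, {Addr: "10.0.0.1:443"}, // reversed order, same key
}})
_, _ = v, ok // v == "primary", ok == true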

View File

@@ -22,6 +22,7 @@ package resolver
import (
"context"
"errors"
"fmt"
"net"
"net/url"
@@ -29,6 +30,7 @@ import (
"google.golang.org/grpc/attributes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/experimental/stats"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/serviceconfig"
)
@@ -174,6 +176,8 @@ type BuildOptions struct {
// Authority is the effective authority of the clientconn for which the
// resolver is built.
Authority string
// MetricsRecorder is the metrics recorder to do recording.
MetricsRecorder stats.MetricsRecorder
}
// An Endpoint is one network endpoint, or server, which may have multiple
@@ -237,8 +241,8 @@ type ClientConn interface {
// UpdateState can be omitted.
UpdateState(State) error
// ReportError notifies the ClientConn that the Resolver encountered an
// error. The ClientConn will notify the load balancer and begin calling
// ResolveNow on the Resolver with exponential backoff.
// error. The ClientConn then forwards this error to the load balancing
// policy.
ReportError(error)
// NewAddress is called by resolver to notify ClientConn a new list
// of resolved addresses.
@@ -330,3 +334,20 @@ type AuthorityOverrider interface {
// typically in line, and must keep it unchanged.
OverrideAuthority(Target) string
}
// ValidateEndpoints validates endpoints from a petiole policy's perspective.
// Petiole policies should call this before calling into their children. See
// [gRPC A61](https://github.com/grpc/proposal/blob/master/A61-IPv4-IPv6-dualstack-backends.md)
// for details.
func ValidateEndpoints(endpoints []Endpoint) error {
if len(endpoints) == 0 {
return errors.New("endpoints list is empty")
}
for _, endpoint := range endpoints {
for range endpoint.Addresses {
return nil
}
}
return errors.New("endpoints list contains no addresses")
}

View File

@@ -26,6 +26,7 @@ import (
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/pretty"
"google.golang.org/grpc/internal/resolver/delegatingresolver"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
)
@@ -66,7 +67,7 @@ func newCCResolverWrapper(cc *ClientConn) *ccResolverWrapper {
// any newly created ccResolverWrapper, except that close may be called instead.
func (ccr *ccResolverWrapper) start() error {
errCh := make(chan error)
ccr.serializer.Schedule(func(ctx context.Context) {
ccr.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil {
return
}
@@ -76,16 +77,26 @@ func (ccr *ccResolverWrapper) start() error {
CredsBundle: ccr.cc.dopts.copts.CredsBundle,
Dialer: ccr.cc.dopts.copts.Dialer,
Authority: ccr.cc.authority,
MetricsRecorder: ccr.cc.metricsRecorderList,
}
var err error
ccr.resolver, err = ccr.cc.resolverBuilder.Build(ccr.cc.parsedTarget, ccr, opts)
// The delegating resolver is used unless:
// - A custom dialer is provided via WithContextDialer dialoption or
// - Proxy usage is disabled through WithNoProxy dialoption.
// In these cases, the resolver is built based on the scheme of target,
// using the appropriate resolver builder.
if ccr.cc.dopts.copts.Dialer != nil || !ccr.cc.dopts.useProxy {
ccr.resolver, err = ccr.cc.resolverBuilder.Build(ccr.cc.parsedTarget, ccr, opts)
} else {
ccr.resolver, err = delegatingresolver.New(ccr.cc.parsedTarget, ccr, opts, ccr.cc.resolverBuilder, ccr.cc.dopts.enableLocalDNSResolution)
}
errCh <- err
})
return <-errCh
}
func (ccr *ccResolverWrapper) resolveNow(o resolver.ResolveNowOptions) {
ccr.serializer.Schedule(func(ctx context.Context) {
ccr.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil || ccr.resolver == nil {
return
}
@@ -102,7 +113,7 @@ func (ccr *ccResolverWrapper) close() {
ccr.closed = true
ccr.mu.Unlock()
ccr.serializer.Schedule(func(context.Context) {
ccr.serializer.TrySchedule(func(context.Context) {
if ccr.resolver == nil {
return
}
@@ -123,12 +134,7 @@ func (ccr *ccResolverWrapper) UpdateState(s resolver.State) error {
return nil
}
if s.Endpoints == nil {
s.Endpoints = make([]resolver.Endpoint, 0, len(s.Addresses))
for _, a := range s.Addresses {
ep := resolver.Endpoint{Addresses: []resolver.Address{a}, Attributes: a.BalancerAttributes}
ep.Addresses[0].BalancerAttributes = nil
s.Endpoints = append(s.Endpoints, ep)
}
s.Endpoints = addressesToEndpoints(s.Addresses)
}
ccr.addChannelzTraceEvent(s)
ccr.curState = s
@@ -161,7 +167,11 @@ func (ccr *ccResolverWrapper) NewAddress(addrs []resolver.Address) {
ccr.cc.mu.Unlock()
return
}
s := resolver.State{Addresses: addrs, ServiceConfig: ccr.curState.ServiceConfig}
s := resolver.State{
Addresses: addrs,
ServiceConfig: ccr.curState.ServiceConfig,
Endpoints: addressesToEndpoints(addrs),
}
ccr.addChannelzTraceEvent(s)
ccr.curState = s
ccr.mu.Unlock()
@@ -171,12 +181,15 @@ func (ccr *ccResolverWrapper) NewAddress(addrs []resolver.Address) {
// ParseServiceConfig is called by resolver implementations to parse a JSON
// representation of the service config.
func (ccr *ccResolverWrapper) ParseServiceConfig(scJSON string) *serviceconfig.ParseResult {
return parseServiceConfig(scJSON)
return parseServiceConfig(scJSON, ccr.cc.dopts.maxCallAttempts)
}
// addChannelzTraceEvent adds a channelz trace event containing the new
// state received from resolver implementations.
func (ccr *ccResolverWrapper) addChannelzTraceEvent(s resolver.State) {
if !logger.V(0) && !channelz.IsOn() {
return
}
var updates []string
var oldSC, newSC *ServiceConfig
var oldOK, newOK bool
@@ -196,3 +209,13 @@ func (ccr *ccResolverWrapper) addChannelzTraceEvent(s resolver.State) {
}
channelz.Infof(logger, ccr.cc.channelz, "Resolver state updated: %s (%v)", pretty.ToJSON(s), strings.Join(updates, "; "))
}
func addressesToEndpoints(addrs []resolver.Address) []resolver.Endpoint {
endpoints := make([]resolver.Endpoint, 0, len(addrs))
for _, a := range addrs {
ep := resolver.Endpoint{Addresses: []resolver.Address{a}, Attributes: a.BalancerAttributes}
ep.Addresses[0].BalancerAttributes = nil
endpoints = append(endpoints, ep)
}
return endpoints
}

View File

@@ -19,7 +19,6 @@
package grpc
import (
"bytes"
"compress/gzip"
"context"
"encoding/binary"
@@ -35,6 +34,7 @@ import (
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/encoding/proto"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
@@ -151,7 +151,7 @@ func (d *gzipDecompressor) Type() string {
// callInfo contains all related configuration and information about an RPC.
type callInfo struct {
compressorType string
compressorName string
failFast bool
maxReceiveMessageSize *int
maxSendMessageSize *int
@@ -220,9 +220,9 @@ type HeaderCallOption struct {
HeaderAddr *metadata.MD
}
func (o HeaderCallOption) before(c *callInfo) error { return nil }
func (o HeaderCallOption) after(c *callInfo, attempt *csAttempt) {
*o.HeaderAddr, _ = attempt.s.Header()
func (o HeaderCallOption) before(*callInfo) error { return nil }
func (o HeaderCallOption) after(_ *callInfo, attempt *csAttempt) {
*o.HeaderAddr, _ = attempt.transportStream.Header()
}
// Trailer returns a CallOptions that retrieves the trailer metadata
@@ -242,9 +242,9 @@ type TrailerCallOption struct {
TrailerAddr *metadata.MD
}
func (o TrailerCallOption) before(c *callInfo) error { return nil }
func (o TrailerCallOption) after(c *callInfo, attempt *csAttempt) {
*o.TrailerAddr = attempt.s.Trailer()
func (o TrailerCallOption) before(*callInfo) error { return nil }
func (o TrailerCallOption) after(_ *callInfo, attempt *csAttempt) {
*o.TrailerAddr = attempt.transportStream.Trailer()
}
// Peer returns a CallOption that retrieves peer information for a unary RPC.
@@ -264,24 +264,20 @@ type PeerCallOption struct {
PeerAddr *peer.Peer
}
func (o PeerCallOption) before(c *callInfo) error { return nil }
func (o PeerCallOption) after(c *callInfo, attempt *csAttempt) {
if x, ok := peer.FromContext(attempt.s.Context()); ok {
func (o PeerCallOption) before(*callInfo) error { return nil }
func (o PeerCallOption) after(_ *callInfo, attempt *csAttempt) {
if x, ok := peer.FromContext(attempt.transportStream.Context()); ok {
*o.PeerAddr = *x
}
}
// WaitForReady configures the action to take when an RPC is attempted on broken
// connections or unreachable servers. If waitForReady is false and the
// connection is in the TRANSIENT_FAILURE state, the RPC will fail
// immediately. Otherwise, the RPC client will block the call until a
// connection is available (or the call is canceled or times out) and will
// retry the call if it fails due to a transient error. gRPC will not retry if
// data was written to the wire unless the server indicates it did not process
// the data. Please refer to
// https://github.com/grpc/grpc/blob/master/doc/wait-for-ready.md.
// WaitForReady configures the RPC's behavior when the client is in
// TRANSIENT_FAILURE, which occurs when all addresses fail to connect. If
// waitForReady is false, the RPC will fail immediately. Otherwise, the client
// will wait until a connection becomes available or the RPC's deadline is
// reached.
//
// By default, RPCs don't "wait for ready".
// By default, RPCs do not "wait for ready".
func WaitForReady(waitForReady bool) CallOption {
return FailFastCallOption{FailFast: !waitForReady}
}
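A minimal usage sketch, assuming a generated client in the style of the grpc-go route guide example (RouteGuideClient, GetFeature, Point, and Feature are borrowed names): with wait-for-ready enabled, the call queues until a connection becomes READY or the deadline expires, instead of failing fast.

func getFeature(client RouteGuideClient, req *Point) (*Feature, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	// Without WaitForReady(true), this call would fail immediately while
	// the channel is in TRANSIENT_FAILURE.
	return client.GetFeature(ctx, req, grpc.WaitForReady(true))
}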
@@ -308,7 +304,7 @@ func (o FailFastCallOption) before(c *callInfo) error {
c.failFast = o.FailFast
return nil
}
func (o FailFastCallOption) after(c *callInfo, attempt *csAttempt) {}
func (o FailFastCallOption) after(*callInfo, *csAttempt) {}
// OnFinish returns a CallOption that configures a callback to be called when
// the call completes. The error passed to the callback is the status of the
@@ -343,7 +339,7 @@ func (o OnFinishCallOption) before(c *callInfo) error {
return nil
}
func (o OnFinishCallOption) after(c *callInfo, attempt *csAttempt) {}
func (o OnFinishCallOption) after(*callInfo, *csAttempt) {}
// MaxCallRecvMsgSize returns a CallOption which sets the maximum message size
// in bytes the client can receive. If this is not set, gRPC uses the default
@@ -367,7 +363,7 @@ func (o MaxRecvMsgSizeCallOption) before(c *callInfo) error {
c.maxReceiveMessageSize = &o.MaxRecvMsgSize
return nil
}
func (o MaxRecvMsgSizeCallOption) after(c *callInfo, attempt *csAttempt) {}
func (o MaxRecvMsgSizeCallOption) after(*callInfo, *csAttempt) {}
// MaxCallSendMsgSize returns a CallOption which sets the maximum message size
// in bytes the client can send. If this is not set, gRPC uses the default
@@ -391,7 +387,7 @@ func (o MaxSendMsgSizeCallOption) before(c *callInfo) error {
c.maxSendMessageSize = &o.MaxSendMsgSize
return nil
}
func (o MaxSendMsgSizeCallOption) after(c *callInfo, attempt *csAttempt) {}
func (o MaxSendMsgSizeCallOption) after(*callInfo, *csAttempt) {}
// PerRPCCredentials returns a CallOption that sets credentials.PerRPCCredentials
// for a call.
@@ -414,7 +410,7 @@ func (o PerRPCCredsCallOption) before(c *callInfo) error {
c.creds = o.Creds
return nil
}
func (o PerRPCCredsCallOption) after(c *callInfo, attempt *csAttempt) {}
func (o PerRPCCredsCallOption) after(*callInfo, *csAttempt) {}
// UseCompressor returns a CallOption which sets the compressor used when
// sending the request. If WithCompressor is also set, UseCompressor has
@@ -439,10 +435,10 @@ type CompressorCallOption struct {
}
func (o CompressorCallOption) before(c *callInfo) error {
c.compressorType = o.CompressorType
c.compressorName = o.CompressorType
return nil
}
func (o CompressorCallOption) after(c *callInfo, attempt *csAttempt) {}
func (o CompressorCallOption) after(*callInfo, *csAttempt) {}
// CallContentSubtype returns a CallOption that will set the content-subtype
// for a call. For example, if content-subtype is "json", the Content-Type over
@@ -479,7 +475,7 @@ func (o ContentSubtypeCallOption) before(c *callInfo) error {
c.contentSubtype = o.ContentSubtype
return nil
}
func (o ContentSubtypeCallOption) after(c *callInfo, attempt *csAttempt) {}
func (o ContentSubtypeCallOption) after(*callInfo, *csAttempt) {}
// ForceCodec returns a CallOption that will set the codec used for all
// request and response messages for a call. The result of calling Name() will
@@ -515,10 +511,50 @@ type ForceCodecCallOption struct {
}
func (o ForceCodecCallOption) before(c *callInfo) error {
c.codec = o.Codec
c.codec = newCodecV1Bridge(o.Codec)
return nil
}
func (o ForceCodecCallOption) after(c *callInfo, attempt *csAttempt) {}
func (o ForceCodecCallOption) after(*callInfo, *csAttempt) {}
// ForceCodecV2 returns a CallOption that will set the codec used for all
// request and response messages for a call. The result of calling Name() will
// be used as the content-subtype after converting to lowercase, unless
// CallContentSubtype is also used.
//
// See Content-Type on
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details. Also see the documentation on RegisterCodec and
// CallContentSubtype for more details on the interaction between Codec and
// content-subtype.
//
// This function is provided for advanced users; prefer to use only
// CallContentSubtype to select a registered codec instead.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func ForceCodecV2(codec encoding.CodecV2) CallOption {
return ForceCodecV2CallOption{CodecV2: codec}
}
// ForceCodecV2CallOption is a CallOption that indicates the codec used for
// marshaling messages.
//
// # Experimental
//
// Notice: This type is EXPERIMENTAL and may be changed or removed in a
// later release.
type ForceCodecV2CallOption struct {
CodecV2 encoding.CodecV2
}
func (o ForceCodecV2CallOption) before(c *callInfo) error {
c.codec = o.CodecV2
return nil
}
func (o ForceCodecV2CallOption) after(*callInfo, *csAttempt) {}
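A hedged sketch of selecting a V2 codec per call; client, ctx, req, and myCodec are assumed, and the commented signatures reflect the encoding.CodecV2 contract this option expects:

// myCodec is assumed to implement encoding.CodecV2:
//   Marshal(v any) (mem.BufferSlice, error)
//   Unmarshal(data mem.BufferSlice, v any) error
//   Name() string
resp, err := client.Echo(ctx, req, grpc.ForceCodecV2(myCodec))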
// CallCustomCodec behaves like ForceCodec, but accepts a grpc.Codec instead of
// an encoding.Codec.
@@ -540,10 +576,10 @@ type CustomCodecCallOption struct {
}
func (o CustomCodecCallOption) before(c *callInfo) error {
c.codec = o.Codec
c.codec = newCodecV0Bridge(o.Codec)
return nil
}
func (o CustomCodecCallOption) after(c *callInfo, attempt *csAttempt) {}
func (o CustomCodecCallOption) after(*callInfo, *csAttempt) {}
// MaxRetryRPCBufferSize returns a CallOption that limits the amount of memory
// used for buffering this RPC's requests for retry purposes.
@@ -571,7 +607,7 @@ func (o MaxRetryRPCBufferSizeCallOption) before(c *callInfo) error {
c.maxRetryRPCBufferSize = o.MaxRetryRPCBufferSize
return nil
}
func (o MaxRetryRPCBufferSizeCallOption) after(c *callInfo, attempt *csAttempt) {}
func (o MaxRetryRPCBufferSizeCallOption) after(*callInfo, *csAttempt) {}
// The format of the payload: compressed or not?
type payloadFormat uint8
@@ -581,19 +617,28 @@ const (
compressionMade payloadFormat = 1 // compressed
)
func (pf payloadFormat) isCompressed() bool {
return pf == compressionMade
}
type streamReader interface {
ReadMessageHeader(header []byte) error
Read(n int) (mem.BufferSlice, error)
}
// parser reads complete gRPC messages from the underlying reader.
type parser struct {
// r is the underlying reader.
// See the comment on recvMsg for the permissible
// error types.
r io.Reader
r streamReader
// The header of a gRPC message. Find more detail at
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
header [5]byte
// recvBufferPool is the pool of shared receive buffers.
recvBufferPool SharedBufferPool
// bufferPool is the pool of shared receive buffers.
bufferPool mem.BufferPool
}
// recvMsg reads a complete gRPC message from the stream.
@@ -608,39 +653,38 @@ type parser struct {
// - an error from the status package
//
// No other error values or types may be returned, which also means
// that the underlying io.Reader must not return an incompatible
// that the underlying streamReader must not return an incompatible
// error.
func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byte, err error) {
if _, err := p.r.Read(p.header[:]); err != nil {
func (p *parser) recvMsg(maxReceiveMessageSize int) (payloadFormat, mem.BufferSlice, error) {
err := p.r.ReadMessageHeader(p.header[:])
if err != nil {
return 0, nil, err
}
pf = payloadFormat(p.header[0])
pf := payloadFormat(p.header[0])
length := binary.BigEndian.Uint32(p.header[1:])
if length == 0 {
return pf, nil, nil
}
if int64(length) > int64(maxInt) {
return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max length allowed on current machine (%d vs. %d)", length, maxInt)
}
if int(length) > maxReceiveMessageSize {
return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize)
}
msg = p.recvBufferPool.Get(int(length))
if _, err := p.r.Read(msg); err != nil {
data, err := p.r.Read(int(length))
if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return 0, nil, err
}
return pf, msg, nil
return pf, data, nil
}
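The framing recvMsg decodes is the standard gRPC message prefix: one compressed-flag byte followed by a big-endian uint32 payload length. A standalone, runnable sketch of the same decode:

package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	// 1-byte compressed flag, 4-byte big-endian length, then the payload.
	frame := []byte{0x00, 0x00, 0x00, 0x00, 0x05, 'h', 'e', 'l', 'l', 'o'}
	compressed := frame[0] == 1
	length := binary.BigEndian.Uint32(frame[1:5])
	fmt.Println(compressed, length, string(frame[5:5+length])) // false 5 hello
}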
// encode serializes msg and returns a buffer containing the message, or an
// error if it is too large to be transmitted by grpc. If msg is nil, it
// generates an empty message.
func encode(c baseCodec, msg any) ([]byte, error) {
func encode(c baseCodec, msg any) (mem.BufferSlice, error) {
if msg == nil { // NOTE: typed nils will not be caught by this check
return nil, nil
}
@@ -648,8 +692,9 @@ func encode(c baseCodec, msg any) ([]byte, error) {
if err != nil {
return nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error())
}
if uint(len(b)) > math.MaxUint32 {
return nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b))
if bufSize := uint(b.Len()); bufSize > math.MaxUint32 {
b.Free()
return nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", bufSize)
}
return b, nil
}
@@ -659,34 +704,41 @@ func encode(c baseCodec, msg any) ([]byte, error) {
// indicating no compression was done.
//
// TODO(dfawley): eliminate cp parameter by wrapping Compressor in an encoding.Compressor.
func compress(in []byte, cp Compressor, compressor encoding.Compressor) ([]byte, error) {
if compressor == nil && cp == nil {
return nil, nil
}
if len(in) == 0 {
return nil, nil
func compress(in mem.BufferSlice, cp Compressor, compressor encoding.Compressor, pool mem.BufferPool) (mem.BufferSlice, payloadFormat, error) {
if (compressor == nil && cp == nil) || in.Len() == 0 {
return nil, compressionNone, nil
}
var out mem.BufferSlice
w := mem.NewWriter(&out, pool)
wrapErr := func(err error) error {
out.Free()
return status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
}
cbuf := &bytes.Buffer{}
if compressor != nil {
z, err := compressor.Compress(cbuf)
z, err := compressor.Compress(w)
if err != nil {
return nil, wrapErr(err)
return nil, 0, wrapErr(err)
}
if _, err := z.Write(in); err != nil {
return nil, wrapErr(err)
for _, b := range in {
if _, err := z.Write(b.ReadOnlyData()); err != nil {
return nil, 0, wrapErr(err)
}
}
if err := z.Close(); err != nil {
return nil, wrapErr(err)
return nil, 0, wrapErr(err)
}
} else {
if err := cp.Do(cbuf, in); err != nil {
return nil, wrapErr(err)
// This is unavoidably inefficient since it fully materializes the data; there
// is no way around that with the old Compressor API. At least it attempts to
// return the buffer to the pool, in the hope that it can be reused (maybe
// even by a subsequent call to this very function).
buf := in.MaterializeToBuffer(pool)
defer buf.Free()
if err := cp.Do(w, buf.ReadOnlyData()); err != nil {
return nil, 0, wrapErr(err)
}
}
return cbuf.Bytes(), nil
return out, compressionMade, nil
}
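The compressor branch above is normally reached by registering an encoding.Compressor and selecting it per call. A minimal sketch using the real encoding/gzip package; EchoClient and the message types are hypothetical generated stubs:

import (
	"context"

	"google.golang.org/grpc"
	_ "google.golang.org/grpc/encoding/gzip" // registers the "gzip" compressor
)

// callWithGzip compresses the request with gzip; the receiving side must
// have a matching decompressor registered or checkRecvPayload rejects it.
func callWithGzip(ctx context.Context, client EchoClient, req *EchoRequest) (*EchoReply, error) {
	return client.Echo(ctx, req, grpc.UseCompressor("gzip"))
}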
const (
@@ -697,33 +749,36 @@ const (
// msgHeader returns a 5-byte header for the message being transmitted and the
// payload, which is compData when pf indicates compression, or data otherwise.
func msgHeader(data, compData []byte) (hdr []byte, payload []byte) {
func msgHeader(data, compData mem.BufferSlice, pf payloadFormat) (hdr []byte, payload mem.BufferSlice) {
hdr = make([]byte, headerLen)
if compData != nil {
hdr[0] = byte(compressionMade)
data = compData
hdr[0] = byte(pf)
var length uint32
if pf.isCompressed() {
length = uint32(compData.Len())
payload = compData
} else {
hdr[0] = byte(compressionNone)
length = uint32(data.Len())
payload = data
}
// Write the payload length into the header.
binary.BigEndian.PutUint32(hdr[payloadLen:], uint32(len(data)))
return hdr, data
binary.BigEndian.PutUint32(hdr[payloadLen:], length)
return hdr, payload
}
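A worked example of the 5-byte header this builds:

// A compressed payload of 268 (0x10C) bytes yields
//   hdr = []byte{0x01, 0x00, 0x00, 0x01, 0x0C}
// that is, byte(compressionMade) followed by the length as a big-endian uint32.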
func outPayload(client bool, msg any, data, payload []byte, t time.Time) *stats.OutPayload {
func outPayload(client bool, msg any, dataLength, payloadLength int, t time.Time) *stats.OutPayload {
return &stats.OutPayload{
Client: client,
Payload: msg,
Data: data,
Length: len(data),
WireLength: len(payload) + headerLen,
CompressedLength: len(payload),
Length: dataLength,
WireLength: payloadLength + headerLen,
CompressedLength: payloadLength,
SentTime: t,
}
}
func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool) *status.Status {
func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool, isServer bool) *status.Status {
switch pf {
case compressionNone:
case compressionMade:
@@ -731,7 +786,10 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool
return status.New(codes.Internal, "grpc: compressed flag set with identity or empty encoding")
}
if !haveCompressor {
return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
if isServer {
return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
}
return status.Newf(codes.Internal, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
}
default:
return status.Newf(codes.Internal, "grpc: received unexpected payload format %d", pf)
@@ -741,104 +799,119 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool
type payloadInfo struct {
compressedLength int // The compressed length received from the wire.
uncompressedBytes []byte
uncompressedBytes mem.BufferSlice
}
func (p *payloadInfo) free() {
if p != nil && p.uncompressedBytes != nil {
p.uncompressedBytes.Free()
}
}
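The Ref/Free calls in recvAndDecompress below follow mem.BufferSlice's reference-counting discipline; a sketch of the ownership pattern, assuming the mem package's documented semantics (names illustrative):

out := someBufferSlice          // this scope owns one reference
out.Ref()                       // take an extra reference before storing it
payInfo.uncompressedBytes = out // stored reference, released later by payInfo.free()
out.Free()                      // drop the local reference when done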
// recvAndDecompress reads a message from the stream, decompressing it if necessary.
//
// Calling the returned cancel function releases the buffer back to the pool, so the caller should cancel
// as soon as the buffer is no longer needed.
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor,
) (uncompressedBuf []byte, cancel func(), err error) {
pf, compressedBuf, err := p.recvMsg(maxReceiveMessageSize)
// TODO: Refactor this function to reduce the number of arguments.
// See: https://google.github.io/styleguide/go/best-practices.html#function-argument-lists
func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool,
) (out mem.BufferSlice, err error) {
pf, compressed, err := p.recvMsg(maxReceiveMessageSize)
if err != nil {
return nil, nil, err
return nil, err
}
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil {
return nil, nil, st.Err()
compressedLength := compressed.Len()
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil, isServer); st != nil {
compressed.Free()
return nil, st.Err()
}
var size int
if pf == compressionMade {
if pf.isCompressed() {
defer compressed.Free()
// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
// use this decompressor as the default.
if dc != nil {
uncompressedBuf, err = dc.Do(bytes.NewReader(compressedBuf))
size = len(uncompressedBuf)
} else {
uncompressedBuf, size, err = decompress(compressor, compressedBuf, maxReceiveMessageSize)
}
out, err = decompress(compressor, compressed, dc, maxReceiveMessageSize, p.bufferPool)
if err != nil {
return nil, nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
}
if size > maxReceiveMessageSize {
// TODO: Revisit the error code. It is currently kept consistent with the Java
// implementation.
return nil, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
return nil, err
}
} else {
uncompressedBuf = compressedBuf
out = compressed
}
if payInfo != nil {
payInfo.compressedLength = len(compressedBuf)
payInfo.uncompressedBytes = uncompressedBuf
cancel = func() {}
} else {
cancel = func() {
p.recvBufferPool.Put(&compressedBuf)
}
payInfo.compressedLength = compressedLength
out.Ref()
payInfo.uncompressedBytes = out
}
return uncompressedBuf, cancel, nil
return out, nil
}
// Using compressor, decompress d, returning the data and its size.
// Optionally, if the data would exceed maxReceiveMessageSize, just return the size.
func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize int) ([]byte, int, error) {
dcReader, err := compressor.Decompress(bytes.NewReader(d))
if err != nil {
return nil, 0, err
}
if sizer, ok := compressor.(interface {
DecompressedSize(compressedBytes []byte) int
}); ok {
if size := sizer.DecompressedSize(d); size >= 0 {
if size > maxReceiveMessageSize {
return nil, size, nil
}
// size is used as an estimate to size the buffer, but we
// will read more data if available.
// +MinRead so ReadFrom will not reallocate if size is correct.
//
// TODO: If we ensure that the buffer size is the same as the DecompressedSize,
// we can also utilize the recv buffer pool here.
buf := bytes.NewBuffer(make([]byte, 0, size+bytes.MinRead))
bytesRead, err := buf.ReadFrom(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
return buf.Bytes(), int(bytesRead), err
// decompress decompresses the given data using either the custom Decompressor dc or the provided
// encoding.Compressor; if dc is set, it takes precedence. The decompressed data is validated
// against maxReceiveMessageSize, and the function returns the decompressed data on success, or
// an error if decompression fails or the size limit is exceeded.
func decompress(compressor encoding.Compressor, d mem.BufferSlice, dc Decompressor, maxReceiveMessageSize int, pool mem.BufferPool) (mem.BufferSlice, error) {
if dc != nil {
uncompressed, err := dc.Do(d.Reader())
if err != nil {
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
}
if len(uncompressed) > maxReceiveMessageSize {
return nil, status.Errorf(codes.ResourceExhausted, "grpc: message after decompression larger than max (%d vs. %d)", len(uncompressed), maxReceiveMessageSize)
}
return mem.BufferSlice{mem.SliceBuffer(uncompressed)}, nil
}
// Read from LimitReader with limit max+1, so if the underlying
// reader exceeds the limit, the result will be bigger than max.
d, err = io.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
return d, len(d), err
if compressor != nil {
dcReader, err := compressor.Decompress(d.Reader())
if err != nil {
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the message: %v", err)
}
// Read at most one byte more than the limit from the decompressor,
// unless the limit is MaxInt64, in which case that's impossible, so
// apply no limit.
if limit := int64(maxReceiveMessageSize); limit < math.MaxInt64 {
dcReader = io.LimitReader(dcReader, limit+1)
}
out, err := mem.ReadAll(dcReader, pool)
if err != nil {
out.Free()
return nil, status.Errorf(codes.Internal, "grpc: failed to read decompressed data: %v", err)
}
if out.Len() > maxReceiveMessageSize {
out.Free()
return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max %d", maxReceiveMessageSize)
}
return out, nil
}
return nil, status.Errorf(codes.Internal, "grpc: no decompressor available for compressed payload")
}
type recvCompressor interface {
RecvCompress() string
}
// Of the two compressor parameters, at most one should be set, but if both are,
// dc takes precedence over compressor.
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) error {
buf, cancel, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
func recv(p *parser, c baseCodec, s recvCompressor, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) error {
data, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor, isServer)
if err != nil {
return err
}
defer cancel()
if err := c.Unmarshal(buf, m); err != nil {
// If the codec wants its own reference to the data, it can get it. Otherwise, always
// free the buffers.
defer data.Free()
if err := c.Unmarshal(data, m); err != nil {
return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message: %v", err)
}
return nil
}
@@ -941,7 +1014,7 @@ func setCallInfoCodec(c *callInfo) error {
// encoding.Codec (Name vs. String method name). We only support
// setting content subtype from encoding.Codec to avoid a behavior
// change with the deprecated version.
if ec, ok := c.codec.(encoding.Codec); ok {
if ec, ok := c.codec.(encoding.CodecV2); ok {
c.contentSubtype = strings.ToLower(ec.Name())
}
}
@@ -950,12 +1023,12 @@ func setCallInfoCodec(c *callInfo) error {
if c.contentSubtype == "" {
// No codec specified in CallOptions; use proto by default.
c.codec = encoding.GetCodec(proto.Name)
c.codec = getCodec(proto.Name)
return nil
}
// c.contentSubtype is already lowercased in CallContentSubtype
c.codec = encoding.GetCodec(c.contentSubtype)
c.codec = getCodec(c.contentSubtype)
if c.codec == nil {
return status.Errorf(codes.Internal, "no codec registered for content-subtype %s", c.contentSubtype)
}
@@ -964,7 +1037,7 @@ func setCallInfoCodec(c *callInfo) error {
// The SupportPackageIsVersion variables are referenced from generated protocol
// buffer files to ensure compatibility with the gRPC version used. The latest
// support package version is 7.
// support package version is 9.
//
// Older versions are kept for compatibility.
//
@@ -976,6 +1049,7 @@ const (
SupportPackageIsVersion6 = true
SupportPackageIsVersion7 = true
SupportPackageIsVersion8 = true
SupportPackageIsVersion9 = true
)
const grpcUA = "grpc-go/" + Version
