TUN-6772: Add a JWT Validator as an ingress verifier

This adds a new verifier interface that can be attached to ingress.Rule.
This would act as a middleware layer that gets executed at the start of
proxy.ProxyHTTP.

A jwt validator implementation for this verifier is also provided. The
validator downloads the public key from the access teams endpoint and
uses it to verify the JWT sent to cloudflared with the audtag (clientID)
information provided in the config.
This commit is contained in:
Sudarsan Reddy
2022-09-21 15:17:44 +01:00
parent e9a2c85671
commit de07da02cd
51 changed files with 4371 additions and 790 deletions

View File

@@ -0,0 +1,382 @@
/*
*
* Copyright 2022 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package gracefulswitch implements a graceful switch load balancer.
package gracefulswitch
import (
"errors"
"fmt"
"sync"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/resolver"
)
var errBalancerClosed = errors.New("gracefulSwitchBalancer is closed")
var _ balancer.Balancer = (*Balancer)(nil)
// NewBalancer returns a graceful switch Balancer.
func NewBalancer(cc balancer.ClientConn, opts balancer.BuildOptions) *Balancer {
return &Balancer{
cc: cc,
bOpts: opts,
}
}
// Balancer is a utility to gracefully switch from one balancer to
// a new balancer. It implements the balancer.Balancer interface.
type Balancer struct {
bOpts balancer.BuildOptions
cc balancer.ClientConn
// mu protects the following fields and all fields within balancerCurrent
// and balancerPending. mu does not need to be held when calling into the
// child balancers, as all calls into these children happen only as a direct
// result of a call into the gracefulSwitchBalancer, which are also
// guaranteed to be synchronous. There is one exception: an UpdateState call
// from a child balancer when current and pending are populated can lead to
// calling Close() on the current. To prevent that racing with an
// UpdateSubConnState from the channel, we hold currentMu during Close and
// UpdateSubConnState calls.
mu sync.Mutex
balancerCurrent *balancerWrapper
balancerPending *balancerWrapper
closed bool // set to true when this balancer is closed
// currentMu must be locked before mu. This mutex guards against this
// sequence of events: UpdateSubConnState() called, finds the
// balancerCurrent, gives up lock, updateState comes in, causes Close() on
// balancerCurrent before the UpdateSubConnState is called on the
// balancerCurrent.
currentMu sync.Mutex
}
// swap swaps out the current lb with the pending lb and updates the ClientConn.
// The caller must hold gsb.mu.
func (gsb *Balancer) swap() {
gsb.cc.UpdateState(gsb.balancerPending.lastState)
cur := gsb.balancerCurrent
gsb.balancerCurrent = gsb.balancerPending
gsb.balancerPending = nil
go func() {
gsb.currentMu.Lock()
defer gsb.currentMu.Unlock()
cur.Close()
}()
}
// Helper function that checks if the balancer passed in is current or pending.
// The caller must hold gsb.mu.
func (gsb *Balancer) balancerCurrentOrPending(bw *balancerWrapper) bool {
return bw == gsb.balancerCurrent || bw == gsb.balancerPending
}
// SwitchTo initializes the graceful switch process, which completes based on
// connectivity state changes on the current/pending balancer. Thus, the switch
// process is not complete when this method returns. This method must be called
// synchronously alongside the rest of the balancer.Balancer methods this
// Graceful Switch Balancer implements.
func (gsb *Balancer) SwitchTo(builder balancer.Builder) error {
gsb.mu.Lock()
if gsb.closed {
gsb.mu.Unlock()
return errBalancerClosed
}
bw := &balancerWrapper{
gsb: gsb,
lastState: balancer.State{
ConnectivityState: connectivity.Connecting,
Picker: base.NewErrPicker(balancer.ErrNoSubConnAvailable),
},
subconns: make(map[balancer.SubConn]bool),
}
balToClose := gsb.balancerPending // nil if there is no pending balancer
if gsb.balancerCurrent == nil {
gsb.balancerCurrent = bw
} else {
gsb.balancerPending = bw
}
gsb.mu.Unlock()
balToClose.Close()
// This function takes a builder instead of a balancer because builder.Build
// can call back inline, and this utility needs to handle the callbacks.
newBalancer := builder.Build(bw, gsb.bOpts)
if newBalancer == nil {
// This is illegal and should never happen; we clear the balancerWrapper
// we were constructing if it happens to avoid a potential panic.
gsb.mu.Lock()
if gsb.balancerPending != nil {
gsb.balancerPending = nil
} else {
gsb.balancerCurrent = nil
}
gsb.mu.Unlock()
return balancer.ErrBadResolverState
}
// This write doesn't need to take gsb.mu because this field never gets read
// or written to on any calls from the current or pending. Calls from grpc
// to this balancer are guaranteed to be called synchronously, so this
// bw.Balancer field will never be forwarded to until this SwitchTo()
// function returns.
bw.Balancer = newBalancer
return nil
}
// Returns nil if the graceful switch balancer is closed.
func (gsb *Balancer) latestBalancer() *balancerWrapper {
gsb.mu.Lock()
defer gsb.mu.Unlock()
if gsb.balancerPending != nil {
return gsb.balancerPending
}
return gsb.balancerCurrent
}
// UpdateClientConnState forwards the update to the latest balancer created.
func (gsb *Balancer) UpdateClientConnState(state balancer.ClientConnState) error {
// The resolver data is only relevant to the most recent LB Policy.
balToUpdate := gsb.latestBalancer()
if balToUpdate == nil {
return errBalancerClosed
}
// Perform this call without gsb.mu to prevent deadlocks if the child calls
// back into the channel. The latest balancer can never be closed during a
// call from the channel, even without gsb.mu held.
return balToUpdate.UpdateClientConnState(state)
}
// ResolverError forwards the error to the latest balancer created.
func (gsb *Balancer) ResolverError(err error) {
// The resolver data is only relevant to the most recent LB Policy.
balToUpdate := gsb.latestBalancer()
if balToUpdate == nil {
return
}
// Perform this call without gsb.mu to prevent deadlocks if the child calls
// back into the channel. The latest balancer can never be closed during a
// call from the channel, even without gsb.mu held.
balToUpdate.ResolverError(err)
}
// ExitIdle forwards the call to the latest balancer created.
//
// If the latest balancer does not support ExitIdle, the subConns are
// re-connected to manually.
func (gsb *Balancer) ExitIdle() {
balToUpdate := gsb.latestBalancer()
if balToUpdate == nil {
return
}
// There is no need to protect this read with a mutex, as the write to the
// Balancer field happens in SwitchTo, which completes before this can be
// called.
if ei, ok := balToUpdate.Balancer.(balancer.ExitIdler); ok {
ei.ExitIdle()
return
}
for sc := range balToUpdate.subconns {
sc.Connect()
}
}
// UpdateSubConnState forwards the update to the appropriate child.
func (gsb *Balancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
gsb.currentMu.Lock()
defer gsb.currentMu.Unlock()
gsb.mu.Lock()
// Forward update to the appropriate child. Even if there is a pending
// balancer, the current balancer should continue to get SubConn updates to
// maintain the proper state while the pending is still connecting.
var balToUpdate *balancerWrapper
if gsb.balancerCurrent != nil && gsb.balancerCurrent.subconns[sc] {
balToUpdate = gsb.balancerCurrent
} else if gsb.balancerPending != nil && gsb.balancerPending.subconns[sc] {
balToUpdate = gsb.balancerPending
}
gsb.mu.Unlock()
if balToUpdate == nil {
// SubConn belonged to a stale lb policy that has not yet fully closed,
// or the balancer was already closed.
return
}
balToUpdate.UpdateSubConnState(sc, state)
}
// Close closes any active child balancers.
func (gsb *Balancer) Close() {
gsb.mu.Lock()
gsb.closed = true
currentBalancerToClose := gsb.balancerCurrent
gsb.balancerCurrent = nil
pendingBalancerToClose := gsb.balancerPending
gsb.balancerPending = nil
gsb.mu.Unlock()
currentBalancerToClose.Close()
pendingBalancerToClose.Close()
}
// balancerWrapper wraps a balancer.Balancer, and overrides some Balancer
// methods to help cleanup SubConns created by the wrapped balancer.
//
// It implements the balancer.ClientConn interface and is passed down in that
// capacity to the wrapped balancer. It maintains a set of subConns created by
// the wrapped balancer and calls from the latter to create/update/remove
// SubConns update this set before being forwarded to the parent ClientConn.
// State updates from the wrapped balancer can result in invocation of the
// graceful switch logic.
type balancerWrapper struct {
balancer.Balancer
gsb *Balancer
lastState balancer.State
subconns map[balancer.SubConn]bool // subconns created by this balancer
}
func (bw *balancerWrapper) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
if state.ConnectivityState == connectivity.Shutdown {
bw.gsb.mu.Lock()
delete(bw.subconns, sc)
bw.gsb.mu.Unlock()
}
// There is no need to protect this read with a mutex, as the write to the
// Balancer field happens in SwitchTo, which completes before this can be
// called.
bw.Balancer.UpdateSubConnState(sc, state)
}
// Close closes the underlying LB policy and removes the subconns it created. bw
// must not be referenced via balancerCurrent or balancerPending in gsb when
// called. gsb.mu must not be held. Does not panic with a nil receiver.
func (bw *balancerWrapper) Close() {
// before Close is called.
if bw == nil {
return
}
// There is no need to protect this read with a mutex, as Close() is
// impossible to be called concurrently with the write in SwitchTo(). The
// callsites of Close() for this balancer in Graceful Switch Balancer will
// never be called until SwitchTo() returns.
bw.Balancer.Close()
bw.gsb.mu.Lock()
for sc := range bw.subconns {
bw.gsb.cc.RemoveSubConn(sc)
}
bw.gsb.mu.Unlock()
}
func (bw *balancerWrapper) UpdateState(state balancer.State) {
// Hold the mutex for this entire call to ensure it cannot occur
// concurrently with other updateState() calls. This causes updates to
// lastState and calls to cc.UpdateState to happen atomically.
bw.gsb.mu.Lock()
defer bw.gsb.mu.Unlock()
bw.lastState = state
if !bw.gsb.balancerCurrentOrPending(bw) {
return
}
if bw == bw.gsb.balancerCurrent {
// In the case that the current balancer exits READY, and there is a pending
// balancer, you can forward the pending balancer's cached State up to
// ClientConn and swap the pending into the current. This is because there
// is no reason to gracefully switch from and keep using the old policy as
// the ClientConn is not connected to any backends.
if state.ConnectivityState != connectivity.Ready && bw.gsb.balancerPending != nil {
bw.gsb.swap()
return
}
// Even if there is a pending balancer waiting to be gracefully switched to,
// continue to forward current balancer updates to the Client Conn. Ignoring
// state + picker from the current would cause undefined behavior/cause the
// system to behave incorrectly from the current LB policies perspective.
// Also, the current LB is still being used by grpc to choose SubConns per
// RPC, and thus should use the most updated form of the current balancer.
bw.gsb.cc.UpdateState(state)
return
}
// This method is now dealing with a state update from the pending balancer.
// If the current balancer is currently in a state other than READY, the new
// policy can be swapped into place immediately. This is because there is no
// reason to gracefully switch from and keep using the old policy as the
// ClientConn is not connected to any backends.
if state.ConnectivityState != connectivity.Connecting || bw.gsb.balancerCurrent.lastState.ConnectivityState != connectivity.Ready {
bw.gsb.swap()
}
}
func (bw *balancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
bw.gsb.mu.Lock()
if !bw.gsb.balancerCurrentOrPending(bw) {
bw.gsb.mu.Unlock()
return nil, fmt.Errorf("%T at address %p that called NewSubConn is deleted", bw, bw)
}
bw.gsb.mu.Unlock()
sc, err := bw.gsb.cc.NewSubConn(addrs, opts)
if err != nil {
return nil, err
}
bw.gsb.mu.Lock()
if !bw.gsb.balancerCurrentOrPending(bw) { // balancer was closed during this call
bw.gsb.cc.RemoveSubConn(sc)
bw.gsb.mu.Unlock()
return nil, fmt.Errorf("%T at address %p that called NewSubConn is deleted", bw, bw)
}
bw.subconns[sc] = true
bw.gsb.mu.Unlock()
return sc, nil
}
func (bw *balancerWrapper) ResolveNow(opts resolver.ResolveNowOptions) {
// Ignore ResolveNow requests from anything other than the most recent
// balancer, because older balancers were already removed from the config.
if bw != bw.gsb.latestBalancer() {
return
}
bw.gsb.cc.ResolveNow(opts)
}
func (bw *balancerWrapper) RemoveSubConn(sc balancer.SubConn) {
bw.gsb.mu.Lock()
if !bw.gsb.balancerCurrentOrPending(bw) {
bw.gsb.mu.Unlock()
return
}
bw.gsb.mu.Unlock()
bw.gsb.cc.RemoveSubConn(sc)
}
func (bw *balancerWrapper) UpdateAddresses(sc balancer.SubConn, addrs []resolver.Address) {
bw.gsb.mu.Lock()
if !bw.gsb.balancerCurrentOrPending(bw) {
bw.gsb.mu.Unlock()
return
}
bw.gsb.mu.Unlock()
bw.gsb.cc.UpdateAddresses(sc, addrs)
}
func (bw *balancerWrapper) Target() string {
return bw.gsb.cc.Target()
}

View File

@@ -31,7 +31,7 @@ import (
// Logger is the global binary logger. It can be used to get binary logger for
// each method.
type Logger interface {
getMethodLogger(methodName string) *MethodLogger
GetMethodLogger(methodName string) MethodLogger
}
// binLogger is the global binary logger for the binary. One of this should be
@@ -49,17 +49,24 @@ func SetLogger(l Logger) {
binLogger = l
}
// GetLogger gets the binarg logger.
//
// Only call this at init time.
func GetLogger() Logger {
return binLogger
}
// GetMethodLogger returns the methodLogger for the given methodName.
//
// methodName should be in the format of "/service/method".
//
// Each methodLogger returned by this method is a new instance. This is to
// generate sequence id within the call.
func GetMethodLogger(methodName string) *MethodLogger {
func GetMethodLogger(methodName string) MethodLogger {
if binLogger == nil {
return nil
}
return binLogger.getMethodLogger(methodName)
return binLogger.GetMethodLogger(methodName)
}
func init() {
@@ -68,17 +75,29 @@ func init() {
binLogger = NewLoggerFromConfigString(configStr)
}
type methodLoggerConfig struct {
// MethodLoggerConfig contains the setting for logging behavior of a method
// logger. Currently, it contains the max length of header and message.
type MethodLoggerConfig struct {
// Max length of header and message.
hdr, msg uint64
Header, Message uint64
}
// LoggerConfig contains the config for loggers to create method loggers.
type LoggerConfig struct {
All *MethodLoggerConfig
Services map[string]*MethodLoggerConfig
Methods map[string]*MethodLoggerConfig
Blacklist map[string]struct{}
}
type logger struct {
all *methodLoggerConfig
services map[string]*methodLoggerConfig
methods map[string]*methodLoggerConfig
config LoggerConfig
}
blacklist map[string]struct{}
// NewLoggerFromConfig builds a logger with the given LoggerConfig.
func NewLoggerFromConfig(config LoggerConfig) Logger {
return &logger{config: config}
}
// newEmptyLogger creates an empty logger. The map fields need to be filled in
@@ -88,57 +107,57 @@ func newEmptyLogger() *logger {
}
// Set method logger for "*".
func (l *logger) setDefaultMethodLogger(ml *methodLoggerConfig) error {
if l.all != nil {
func (l *logger) setDefaultMethodLogger(ml *MethodLoggerConfig) error {
if l.config.All != nil {
return fmt.Errorf("conflicting global rules found")
}
l.all = ml
l.config.All = ml
return nil
}
// Set method logger for "service/*".
//
// New methodLogger with same service overrides the old one.
func (l *logger) setServiceMethodLogger(service string, ml *methodLoggerConfig) error {
if _, ok := l.services[service]; ok {
func (l *logger) setServiceMethodLogger(service string, ml *MethodLoggerConfig) error {
if _, ok := l.config.Services[service]; ok {
return fmt.Errorf("conflicting service rules for service %v found", service)
}
if l.services == nil {
l.services = make(map[string]*methodLoggerConfig)
if l.config.Services == nil {
l.config.Services = make(map[string]*MethodLoggerConfig)
}
l.services[service] = ml
l.config.Services[service] = ml
return nil
}
// Set method logger for "service/method".
//
// New methodLogger with same method overrides the old one.
func (l *logger) setMethodMethodLogger(method string, ml *methodLoggerConfig) error {
if _, ok := l.blacklist[method]; ok {
func (l *logger) setMethodMethodLogger(method string, ml *MethodLoggerConfig) error {
if _, ok := l.config.Blacklist[method]; ok {
return fmt.Errorf("conflicting blacklist rules for method %v found", method)
}
if _, ok := l.methods[method]; ok {
if _, ok := l.config.Methods[method]; ok {
return fmt.Errorf("conflicting method rules for method %v found", method)
}
if l.methods == nil {
l.methods = make(map[string]*methodLoggerConfig)
if l.config.Methods == nil {
l.config.Methods = make(map[string]*MethodLoggerConfig)
}
l.methods[method] = ml
l.config.Methods[method] = ml
return nil
}
// Set blacklist method for "-service/method".
func (l *logger) setBlacklist(method string) error {
if _, ok := l.blacklist[method]; ok {
if _, ok := l.config.Blacklist[method]; ok {
return fmt.Errorf("conflicting blacklist rules for method %v found", method)
}
if _, ok := l.methods[method]; ok {
if _, ok := l.config.Methods[method]; ok {
return fmt.Errorf("conflicting method rules for method %v found", method)
}
if l.blacklist == nil {
l.blacklist = make(map[string]struct{})
if l.config.Blacklist == nil {
l.config.Blacklist = make(map[string]struct{})
}
l.blacklist[method] = struct{}{}
l.config.Blacklist[method] = struct{}{}
return nil
}
@@ -148,23 +167,23 @@ func (l *logger) setBlacklist(method string) error {
//
// Each methodLogger returned by this method is a new instance. This is to
// generate sequence id within the call.
func (l *logger) getMethodLogger(methodName string) *MethodLogger {
func (l *logger) GetMethodLogger(methodName string) MethodLogger {
s, m, err := grpcutil.ParseMethod(methodName)
if err != nil {
grpclogLogger.Infof("binarylogging: failed to parse %q: %v", methodName, err)
return nil
}
if ml, ok := l.methods[s+"/"+m]; ok {
return newMethodLogger(ml.hdr, ml.msg)
if ml, ok := l.config.Methods[s+"/"+m]; ok {
return newMethodLogger(ml.Header, ml.Message)
}
if _, ok := l.blacklist[s+"/"+m]; ok {
if _, ok := l.config.Blacklist[s+"/"+m]; ok {
return nil
}
if ml, ok := l.services[s]; ok {
return newMethodLogger(ml.hdr, ml.msg)
if ml, ok := l.config.Services[s]; ok {
return newMethodLogger(ml.Header, ml.Message)
}
if l.all == nil {
if l.config.All == nil {
return nil
}
return newMethodLogger(l.all.hdr, l.all.msg)
return newMethodLogger(l.config.All.Header, l.config.All.Message)
}

View File

@@ -89,7 +89,7 @@ func (l *logger) fillMethodLoggerWithConfigString(config string) error {
if err != nil {
return fmt.Errorf("invalid config: %q, %v", config, err)
}
if err := l.setDefaultMethodLogger(&methodLoggerConfig{hdr: hdr, msg: msg}); err != nil {
if err := l.setDefaultMethodLogger(&MethodLoggerConfig{Header: hdr, Message: msg}); err != nil {
return fmt.Errorf("invalid config: %v", err)
}
return nil
@@ -104,11 +104,11 @@ func (l *logger) fillMethodLoggerWithConfigString(config string) error {
return fmt.Errorf("invalid header/message length config: %q, %v", suffix, err)
}
if m == "*" {
if err := l.setServiceMethodLogger(s, &methodLoggerConfig{hdr: hdr, msg: msg}); err != nil {
if err := l.setServiceMethodLogger(s, &MethodLoggerConfig{Header: hdr, Message: msg}); err != nil {
return fmt.Errorf("invalid config: %v", err)
}
} else {
if err := l.setMethodMethodLogger(s+"/"+m, &methodLoggerConfig{hdr: hdr, msg: msg}); err != nil {
if err := l.setMethodMethodLogger(s+"/"+m, &MethodLoggerConfig{Header: hdr, Message: msg}); err != nil {
return fmt.Errorf("invalid config: %v", err)
}
}

View File

@@ -48,7 +48,11 @@ func (g *callIDGenerator) reset() {
var idGen callIDGenerator
// MethodLogger is the sub-logger for each method.
type MethodLogger struct {
type MethodLogger interface {
Log(LogEntryConfig)
}
type methodLogger struct {
headerMaxLen, messageMaxLen uint64
callID uint64
@@ -57,8 +61,8 @@ type MethodLogger struct {
sink Sink // TODO(blog): make this plugable.
}
func newMethodLogger(h, m uint64) *MethodLogger {
return &MethodLogger{
func newMethodLogger(h, m uint64) *methodLogger {
return &methodLogger{
headerMaxLen: h,
messageMaxLen: m,
@@ -69,8 +73,10 @@ func newMethodLogger(h, m uint64) *MethodLogger {
}
}
// Log creates a proto binary log entry, and logs it to the sink.
func (ml *MethodLogger) Log(c LogEntryConfig) {
// Build is an internal only method for building the proto message out of the
// input event. It's made public to enable other library to reuse as much logic
// in methodLogger as possible.
func (ml *methodLogger) Build(c LogEntryConfig) *pb.GrpcLogEntry {
m := c.toProto()
timestamp, _ := ptypes.TimestampProto(time.Now())
m.Timestamp = timestamp
@@ -85,11 +91,15 @@ func (ml *MethodLogger) Log(c LogEntryConfig) {
case *pb.GrpcLogEntry_Message:
m.PayloadTruncated = ml.truncateMessage(pay.Message)
}
ml.sink.Write(m)
return m
}
func (ml *MethodLogger) truncateMetadata(mdPb *pb.Metadata) (truncated bool) {
// Log creates a proto binary log entry, and logs it to the sink.
func (ml *methodLogger) Log(c LogEntryConfig) {
ml.sink.Write(ml.Build(c))
}
func (ml *methodLogger) truncateMetadata(mdPb *pb.Metadata) (truncated bool) {
if ml.headerMaxLen == maxUInt {
return false
}
@@ -119,7 +129,7 @@ func (ml *MethodLogger) truncateMetadata(mdPb *pb.Metadata) (truncated bool) {
return truncated
}
func (ml *MethodLogger) truncateMessage(msgPb *pb.Message) (truncated bool) {
func (ml *methodLogger) truncateMessage(msgPb *pb.Message) (truncated bool) {
if ml.messageMaxLen == maxUInt {
return false
}

View File

@@ -25,6 +25,7 @@ package channelz
import (
"context"
"errors"
"fmt"
"sort"
"sync"
@@ -184,54 +185,77 @@ func GetServer(id int64) *ServerMetric {
return db.get().GetServer(id)
}
// RegisterChannel registers the given channel c in channelz database with ref
// as its reference name, and add it to the child list of its parent (identified
// by pid). pid = 0 means no parent. It returns the unique channelz tracking id
// assigned to this channel.
func RegisterChannel(c Channel, pid int64, ref string) int64 {
// RegisterChannel registers the given channel c in the channelz database with
// ref as its reference name, and adds it to the child list of its parent
// (identified by pid). pid == nil means no parent.
//
// Returns a unique channelz identifier assigned to this channel.
//
// If channelz is not turned ON, the channelz database is not mutated.
func RegisterChannel(c Channel, pid *Identifier, ref string) *Identifier {
id := idGen.genID()
var parent int64
isTopChannel := true
if pid != nil {
isTopChannel = false
parent = pid.Int()
}
if !IsOn() {
return newIdentifer(RefChannel, id, pid)
}
cn := &channel{
refName: ref,
c: c,
subChans: make(map[int64]string),
nestedChans: make(map[int64]string),
id: id,
pid: pid,
pid: parent,
trace: &channelTrace{createdTime: time.Now(), events: make([]*TraceEvent, 0, getMaxTraceEntry())},
}
if pid == 0 {
db.get().addChannel(id, cn, true, pid)
} else {
db.get().addChannel(id, cn, false, pid)
}
return id
db.get().addChannel(id, cn, isTopChannel, parent)
return newIdentifer(RefChannel, id, pid)
}
// RegisterSubChannel registers the given channel c in channelz database with ref
// as its reference name, and add it to the child list of its parent (identified
// by pid). It returns the unique channelz tracking id assigned to this subchannel.
func RegisterSubChannel(c Channel, pid int64, ref string) int64 {
if pid == 0 {
logger.Error("a SubChannel's parent id cannot be 0")
return 0
// RegisterSubChannel registers the given subChannel c in the channelz database
// with ref as its reference name, and adds it to the child list of its parent
// (identified by pid).
//
// Returns a unique channelz identifier assigned to this subChannel.
//
// If channelz is not turned ON, the channelz database is not mutated.
func RegisterSubChannel(c Channel, pid *Identifier, ref string) (*Identifier, error) {
if pid == nil {
return nil, errors.New("a SubChannel's parent id cannot be nil")
}
id := idGen.genID()
if !IsOn() {
return newIdentifer(RefSubChannel, id, pid), nil
}
sc := &subChannel{
refName: ref,
c: c,
sockets: make(map[int64]string),
id: id,
pid: pid,
pid: pid.Int(),
trace: &channelTrace{createdTime: time.Now(), events: make([]*TraceEvent, 0, getMaxTraceEntry())},
}
db.get().addSubChannel(id, sc, pid)
return id
db.get().addSubChannel(id, sc, pid.Int())
return newIdentifer(RefSubChannel, id, pid), nil
}
// RegisterServer registers the given server s in channelz database. It returns
// the unique channelz tracking id assigned to this server.
func RegisterServer(s Server, ref string) int64 {
//
// If channelz is not turned ON, the channelz database is not mutated.
func RegisterServer(s Server, ref string) *Identifier {
id := idGen.genID()
if !IsOn() {
return newIdentifer(RefServer, id, nil)
}
svr := &server{
refName: ref,
s: s,
@@ -240,71 +264,92 @@ func RegisterServer(s Server, ref string) int64 {
id: id,
}
db.get().addServer(id, svr)
return id
return newIdentifer(RefServer, id, nil)
}
// RegisterListenSocket registers the given listen socket s in channelz database
// with ref as its reference name, and add it to the child list of its parent
// (identified by pid). It returns the unique channelz tracking id assigned to
// this listen socket.
func RegisterListenSocket(s Socket, pid int64, ref string) int64 {
if pid == 0 {
logger.Error("a ListenSocket's parent id cannot be 0")
return 0
//
// If channelz is not turned ON, the channelz database is not mutated.
func RegisterListenSocket(s Socket, pid *Identifier, ref string) (*Identifier, error) {
if pid == nil {
return nil, errors.New("a ListenSocket's parent id cannot be 0")
}
id := idGen.genID()
ls := &listenSocket{refName: ref, s: s, id: id, pid: pid}
db.get().addListenSocket(id, ls, pid)
return id
if !IsOn() {
return newIdentifer(RefListenSocket, id, pid), nil
}
ls := &listenSocket{refName: ref, s: s, id: id, pid: pid.Int()}
db.get().addListenSocket(id, ls, pid.Int())
return newIdentifer(RefListenSocket, id, pid), nil
}
// RegisterNormalSocket registers the given normal socket s in channelz database
// with ref as its reference name, and add it to the child list of its parent
// with ref as its reference name, and adds it to the child list of its parent
// (identified by pid). It returns the unique channelz tracking id assigned to
// this normal socket.
func RegisterNormalSocket(s Socket, pid int64, ref string) int64 {
if pid == 0 {
logger.Error("a NormalSocket's parent id cannot be 0")
return 0
//
// If channelz is not turned ON, the channelz database is not mutated.
func RegisterNormalSocket(s Socket, pid *Identifier, ref string) (*Identifier, error) {
if pid == nil {
return nil, errors.New("a NormalSocket's parent id cannot be 0")
}
id := idGen.genID()
ns := &normalSocket{refName: ref, s: s, id: id, pid: pid}
db.get().addNormalSocket(id, ns, pid)
return id
if !IsOn() {
return newIdentifer(RefNormalSocket, id, pid), nil
}
ns := &normalSocket{refName: ref, s: s, id: id, pid: pid.Int()}
db.get().addNormalSocket(id, ns, pid.Int())
return newIdentifer(RefNormalSocket, id, pid), nil
}
// RemoveEntry removes an entry with unique channelz tracking id to be id from
// channelz database.
func RemoveEntry(id int64) {
db.get().removeEntry(id)
//
// If channelz is not turned ON, this function is a no-op.
func RemoveEntry(id *Identifier) {
if !IsOn() {
return
}
db.get().removeEntry(id.Int())
}
// TraceEventDesc is what the caller of AddTraceEvent should provide to describe the event to be added
// to the channel trace.
// The Parent field is optional. It is used for event that will be recorded in the entity's parent
// trace also.
// TraceEventDesc is what the caller of AddTraceEvent should provide to describe
// the event to be added to the channel trace.
//
// The Parent field is optional. It is used for an event that will be recorded
// in the entity's parent trace.
type TraceEventDesc struct {
Desc string
Severity Severity
Parent *TraceEventDesc
}
// AddTraceEvent adds trace related to the entity with specified id, using the provided TraceEventDesc.
func AddTraceEvent(l grpclog.DepthLoggerV2, id int64, depth int, desc *TraceEventDesc) {
for d := desc; d != nil; d = d.Parent {
switch d.Severity {
case CtUnknown, CtInfo:
l.InfoDepth(depth+1, d.Desc)
case CtWarning:
l.WarningDepth(depth+1, d.Desc)
case CtError:
l.ErrorDepth(depth+1, d.Desc)
}
// AddTraceEvent adds trace related to the entity with specified id, using the
// provided TraceEventDesc.
//
// If channelz is not turned ON, this will simply log the event descriptions.
func AddTraceEvent(l grpclog.DepthLoggerV2, id *Identifier, depth int, desc *TraceEventDesc) {
// Log only the trace description associated with the bottom most entity.
switch desc.Severity {
case CtUnknown, CtInfo:
l.InfoDepth(depth+1, withParens(id)+desc.Desc)
case CtWarning:
l.WarningDepth(depth+1, withParens(id)+desc.Desc)
case CtError:
l.ErrorDepth(depth+1, withParens(id)+desc.Desc)
}
if getMaxTraceEntry() == 0 {
return
}
db.get().traceEvent(id, desc)
if IsOn() {
db.get().traceEvent(id.Int(), desc)
}
}
// channelMap is the storage data structure for channelz.

75
vendor/google.golang.org/grpc/internal/channelz/id.go generated vendored Normal file
View File

@@ -0,0 +1,75 @@
/*
*
* Copyright 2022 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package channelz
import "fmt"
// Identifier is an opaque identifier which uniquely identifies an entity in the
// channelz database.
type Identifier struct {
typ RefChannelType
id int64
str string
pid *Identifier
}
// Type returns the entity type corresponding to id.
func (id *Identifier) Type() RefChannelType {
return id.typ
}
// Int returns the integer identifier corresponding to id.
func (id *Identifier) Int() int64 {
return id.id
}
// String returns a string representation of the entity corresponding to id.
//
// This includes some information about the parent as well. Examples:
// Top-level channel: [Channel #channel-number]
// Nested channel: [Channel #parent-channel-number Channel #channel-number]
// Sub channel: [Channel #parent-channel SubChannel #subchannel-number]
func (id *Identifier) String() string {
return id.str
}
// Equal returns true if other is the same as id.
func (id *Identifier) Equal(other *Identifier) bool {
if (id != nil) != (other != nil) {
return false
}
if id == nil && other == nil {
return true
}
return id.typ == other.typ && id.id == other.id && id.pid == other.pid
}
// NewIdentifierForTesting returns a new opaque identifier to be used only for
// testing purposes.
func NewIdentifierForTesting(typ RefChannelType, id int64, pid *Identifier) *Identifier {
return newIdentifer(typ, id, pid)
}
func newIdentifer(typ RefChannelType, id int64, pid *Identifier) *Identifier {
str := fmt.Sprintf("%s #%d", typ, id)
if pid != nil {
str = fmt.Sprintf("%s %s", pid, str)
}
return &Identifier{typ: typ, id: id, str: str, pid: pid}
}

View File

@@ -26,77 +26,54 @@ import (
var logger = grpclog.Component("channelz")
func withParens(id *Identifier) string {
return "[" + id.String() + "] "
}
// Info logs and adds a trace event if channelz is on.
func Info(l grpclog.DepthLoggerV2, id int64, args ...interface{}) {
if IsOn() {
AddTraceEvent(l, id, 1, &TraceEventDesc{
Desc: fmt.Sprint(args...),
Severity: CtInfo,
})
} else {
l.InfoDepth(1, args...)
}
func Info(l grpclog.DepthLoggerV2, id *Identifier, args ...interface{}) {
AddTraceEvent(l, id, 1, &TraceEventDesc{
Desc: fmt.Sprint(args...),
Severity: CtInfo,
})
}
// Infof logs and adds a trace event if channelz is on.
func Infof(l grpclog.DepthLoggerV2, id int64, format string, args ...interface{}) {
msg := fmt.Sprintf(format, args...)
if IsOn() {
AddTraceEvent(l, id, 1, &TraceEventDesc{
Desc: msg,
Severity: CtInfo,
})
} else {
l.InfoDepth(1, msg)
}
func Infof(l grpclog.DepthLoggerV2, id *Identifier, format string, args ...interface{}) {
AddTraceEvent(l, id, 1, &TraceEventDesc{
Desc: fmt.Sprintf(format, args...),
Severity: CtInfo,
})
}
// Warning logs and adds a trace event if channelz is on.
func Warning(l grpclog.DepthLoggerV2, id int64, args ...interface{}) {
if IsOn() {
AddTraceEvent(l, id, 1, &TraceEventDesc{
Desc: fmt.Sprint(args...),
Severity: CtWarning,
})
} else {
l.WarningDepth(1, args...)
}
func Warning(l grpclog.DepthLoggerV2, id *Identifier, args ...interface{}) {
AddTraceEvent(l, id, 1, &TraceEventDesc{
Desc: fmt.Sprint(args...),
Severity: CtWarning,
})
}
// Warningf logs and adds a trace event if channelz is on.
func Warningf(l grpclog.DepthLoggerV2, id int64, format string, args ...interface{}) {
msg := fmt.Sprintf(format, args...)
if IsOn() {
AddTraceEvent(l, id, 1, &TraceEventDesc{
Desc: msg,
Severity: CtWarning,
})
} else {
l.WarningDepth(1, msg)
}
func Warningf(l grpclog.DepthLoggerV2, id *Identifier, format string, args ...interface{}) {
AddTraceEvent(l, id, 1, &TraceEventDesc{
Desc: fmt.Sprintf(format, args...),
Severity: CtWarning,
})
}
// Error logs and adds a trace event if channelz is on.
func Error(l grpclog.DepthLoggerV2, id int64, args ...interface{}) {
if IsOn() {
AddTraceEvent(l, id, 1, &TraceEventDesc{
Desc: fmt.Sprint(args...),
Severity: CtError,
})
} else {
l.ErrorDepth(1, args...)
}
func Error(l grpclog.DepthLoggerV2, id *Identifier, args ...interface{}) {
AddTraceEvent(l, id, 1, &TraceEventDesc{
Desc: fmt.Sprint(args...),
Severity: CtError,
})
}
// Errorf logs and adds a trace event if channelz is on.
func Errorf(l grpclog.DepthLoggerV2, id int64, format string, args ...interface{}) {
msg := fmt.Sprintf(format, args...)
if IsOn() {
AddTraceEvent(l, id, 1, &TraceEventDesc{
Desc: msg,
Severity: CtError,
})
} else {
l.ErrorDepth(1, msg)
}
func Errorf(l grpclog.DepthLoggerV2, id *Identifier, format string, args ...interface{}) {
AddTraceEvent(l, id, 1, &TraceEventDesc{
Desc: fmt.Sprintf(format, args...),
Severity: CtError,
})
}

View File

@@ -686,12 +686,33 @@ const (
type RefChannelType int
const (
// RefUnknown indicates an unknown entity type, the zero value for this type.
RefUnknown RefChannelType = iota
// RefChannel indicates the referenced entity is a Channel.
RefChannel RefChannelType = iota
RefChannel
// RefSubChannel indicates the referenced entity is a SubChannel.
RefSubChannel
// RefServer indicates the referenced entity is a Server.
RefServer
// RefListenSocket indicates the referenced entity is a ListenSocket.
RefListenSocket
// RefNormalSocket indicates the referenced entity is a NormalSocket.
RefNormalSocket
)
var refChannelTypeToString = map[RefChannelType]string{
RefUnknown: "Unknown",
RefChannel: "Channel",
RefSubChannel: "SubChannel",
RefServer: "Server",
RefListenSocket: "ListenSocket",
RefNormalSocket: "NormalSocket",
}
func (r RefChannelType) String() string {
return refChannelTypeToString[r]
}
func (c *channelTrace) dumpData() *ChannelTrace {
c.mu.Lock()
ct := &ChannelTrace{EventNum: c.eventCount, CreationTime: c.createdTime}

View File

@@ -85,3 +85,9 @@ const (
// that supports backend returned by grpclb balancer.
CredsBundleModeBackendFromBalancer = "backend-from-balancer"
)
// RLSLoadBalancingPolicyName is the name of the RLS LB policy.
//
// It currently has an experimental suffix which would be removed once
// end-to-end testing of the policy is completed.
const RLSLoadBalancingPolicyName = "rls_experimental"

View File

@@ -22,6 +22,9 @@
package metadata
import (
"fmt"
"strings"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
)
@@ -72,3 +75,46 @@ func Set(addr resolver.Address, md metadata.MD) resolver.Address {
addr.Attributes = addr.Attributes.WithValue(mdKey, mdValue(md))
return addr
}
// Validate returns an error if the input md contains invalid keys or values.
//
// If the header is not a pseudo-header, the following items are checked:
// - header names must contain one or more characters from this set [0-9 a-z _ - .].
// - if the header-name ends with a "-bin" suffix, no validation of the header value is performed.
// - otherwise, the header value must contain one or more characters from the set [%x20-%x7E].
func Validate(md metadata.MD) error {
for k, vals := range md {
// pseudo-header will be ignored
if k[0] == ':' {
continue
}
// check key, for i that saving a conversion if not using for range
for i := 0; i < len(k); i++ {
r := k[i]
if !(r >= 'a' && r <= 'z') && !(r >= '0' && r <= '9') && r != '.' && r != '-' && r != '_' {
return fmt.Errorf("header key %q contains illegal characters not in [0-9a-z-_.]", k)
}
}
if strings.HasSuffix(k, "-bin") {
continue
}
// check value
for _, val := range vals {
if hasNotPrintable(val) {
return fmt.Errorf("header key %q contains value with non-printable ASCII characters", k)
}
}
}
return nil
}
// hasNotPrintable return true if msg contains any characters which are not in %x20-%x7E
func hasNotPrintable(msg string) bool {
// for i that saving a conversion if not using for range
for i := 0; i < len(msg); i++ {
if msg[i] < 0x20 || msg[i] > 0x7E {
return true
}
}
return false
}

View File

@@ -0,0 +1,82 @@
/*
*
* Copyright 2021 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package pretty defines helper functions to pretty-print structs for logging.
package pretty
import (
"bytes"
"encoding/json"
"fmt"
"github.com/golang/protobuf/jsonpb"
protov1 "github.com/golang/protobuf/proto"
"google.golang.org/protobuf/encoding/protojson"
protov2 "google.golang.org/protobuf/proto"
)
const jsonIndent = " "
// ToJSON marshals the input into a json string.
//
// If marshal fails, it falls back to fmt.Sprintf("%+v").
func ToJSON(e interface{}) string {
switch ee := e.(type) {
case protov1.Message:
mm := jsonpb.Marshaler{Indent: jsonIndent}
ret, err := mm.MarshalToString(ee)
if err != nil {
// This may fail for proto.Anys, e.g. for xDS v2, LDS, the v2
// messages are not imported, and this will fail because the message
// is not found.
return fmt.Sprintf("%+v", ee)
}
return ret
case protov2.Message:
mm := protojson.MarshalOptions{
Multiline: true,
Indent: jsonIndent,
}
ret, err := mm.Marshal(ee)
if err != nil {
// This may fail for proto.Anys, e.g. for xDS v2, LDS, the v2
// messages are not imported, and this will fail because the message
// is not found.
return fmt.Sprintf("%+v", ee)
}
return string(ret)
default:
ret, err := json.MarshalIndent(ee, "", jsonIndent)
if err != nil {
return fmt.Sprintf("%+v", ee)
}
return string(ret)
}
}
// FormatJSON formats the input json bytes with indentation.
//
// If Indent fails, it returns the unchanged input as string.
func FormatJSON(b []byte) string {
var out bytes.Buffer
err := json.Indent(&out, b, "", jsonIndent)
if err != nil {
return string(b)
}
return out.String()
}

View File

@@ -137,6 +137,7 @@ type earlyAbortStream struct {
streamID uint32
contentSubtype string
status *status.Status
rst bool
}
func (*earlyAbortStream) isTransportResponseFrame() bool { return false }
@@ -786,6 +787,11 @@ func (l *loopyWriter) earlyAbortStreamHandler(eas *earlyAbortStream) error {
if err := l.writeHeader(eas.streamID, true, headerFields, nil); err != nil {
return err
}
if eas.rst {
if err := l.framer.fr.WriteRSTStream(eas.streamID, http2.ErrCodeNo); err != nil {
return err
}
}
return nil
}

View File

@@ -132,7 +132,7 @@ type http2Client struct {
kpDormant bool
// Fields below are for channelz metric collection.
channelzID int64 // channelz unique identification number
channelzID *channelz.Identifier
czData *channelzData
onGoAway func(GoAwayReason)
@@ -351,8 +351,9 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
}
t.statsHandler.HandleConn(t.ctx, connBegin)
}
if channelz.IsOn() {
t.channelzID = channelz.RegisterNormalSocket(t, opts.ChannelzParentID, fmt.Sprintf("%s -> %s", t.localAddr, t.remoteAddr))
t.channelzID, err = channelz.RegisterNormalSocket(t, opts.ChannelzParentID, fmt.Sprintf("%s -> %s", t.localAddr, t.remoteAddr))
if err != nil {
return nil, err
}
if t.keepaliveEnabled {
t.kpDormancyCond = sync.NewCond(&t.mu)
@@ -630,8 +631,8 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call
// the wire. However, there are two notable exceptions:
//
// 1. If the stream headers violate the max header list size allowed by the
// server. In this case there is no reason to retry at all, as it is
// assumed the RPC would continue to fail on subsequent attempts.
// server. It's possible this could succeed on another transport, even if
// it's unlikely, but do not transparently retry.
// 2. If the credentials errored when requesting their headers. In this case,
// it's possible a retry can fix the problem, but indefinitely transparently
// retrying is not appropriate as it is likely the credentials, if they can
@@ -639,8 +640,7 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call
type NewStreamError struct {
Err error
DoNotRetry bool
DoNotTransparentRetry bool
AllowTransparentRetry bool
}
func (e NewStreamError) Error() string {
@@ -649,11 +649,11 @@ func (e NewStreamError) Error() string {
// NewStream creates a stream and registers it into the transport as "active"
// streams. All non-nil errors returned will be *NewStreamError.
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) {
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error) {
ctx = peer.NewContext(ctx, t.getPeer())
headerFields, err := t.createHeaderFields(ctx, callHdr)
if err != nil {
return nil, &NewStreamError{Err: err, DoNotTransparentRetry: true}
return nil, &NewStreamError{Err: err, AllowTransparentRetry: false}
}
s := t.newStream(ctx, callHdr)
cleanup := func(err error) {
@@ -753,13 +753,14 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
return true
}, hdr)
if err != nil {
return nil, &NewStreamError{Err: err}
// Connection closed.
return nil, &NewStreamError{Err: err, AllowTransparentRetry: true}
}
if success {
break
}
if hdrListSizeErr != nil {
return nil, &NewStreamError{Err: hdrListSizeErr, DoNotRetry: true}
return nil, &NewStreamError{Err: hdrListSizeErr}
}
firstTry = false
select {
@@ -767,9 +768,9 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
case <-ctx.Done():
return nil, &NewStreamError{Err: ContextErr(ctx.Err())}
case <-t.goAway:
return nil, &NewStreamError{Err: errStreamDrain}
return nil, &NewStreamError{Err: errStreamDrain, AllowTransparentRetry: true}
case <-t.ctx.Done():
return nil, &NewStreamError{Err: ErrConnClosing}
return nil, &NewStreamError{Err: ErrConnClosing, AllowTransparentRetry: true}
}
}
if t.statsHandler != nil {
@@ -898,9 +899,7 @@ func (t *http2Client) Close(err error) {
t.controlBuf.finish()
t.cancel()
t.conn.Close()
if channelz.IsOn() {
channelz.RemoveEntry(t.channelzID)
}
channelz.RemoveEntry(t.channelzID)
// Append info about previous goaways if there were any, since this may be important
// for understanding the root cause for this connection to be closed.
_, goAwayDebugMessage := t.GetGoAwayReason()

View File

@@ -21,7 +21,6 @@ package transport
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"math"
@@ -36,6 +35,7 @@ import (
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
@@ -52,10 +52,10 @@ import (
var (
// ErrIllegalHeaderWrite indicates that setting header is illegal because of
// the stream's state.
ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHeader was already called")
ErrIllegalHeaderWrite = status.Error(codes.Internal, "transport: SendHeader called multiple times")
// ErrHeaderListSizeLimitViolation indicates that the header list size is larger
// than the limit set by peer.
ErrHeaderListSizeLimitViolation = errors.New("transport: trying to send header list size larger than the limit set by peer")
ErrHeaderListSizeLimitViolation = status.Error(codes.Internal, "transport: trying to send header list size larger than the limit set by peer")
)
// serverConnectionCounter counts the number of connections a server has seen
@@ -117,7 +117,7 @@ type http2Server struct {
idle time.Time
// Fields below are for channelz metric collection.
channelzID int64 // channelz unique identification number
channelzID *channelz.Identifier
czData *channelzData
bufferPool *bufferPool
@@ -231,6 +231,11 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
if kp.Timeout == 0 {
kp.Timeout = defaultServerKeepaliveTimeout
}
if kp.Time != infinity {
if err = syscall.SetTCPUserTimeout(conn, kp.Timeout); err != nil {
return nil, connectionErrorf(false, err, "transport: failed to set TCP_USER_TIMEOUT: %v", err)
}
}
kep := config.KeepalivePolicy
if kep.MinTime == 0 {
kep.MinTime = defaultKeepalivePolicyMinTime
@@ -275,12 +280,12 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
connBegin := &stats.ConnBegin{}
t.stats.HandleConn(t.ctx, connBegin)
}
if channelz.IsOn() {
t.channelzID = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.remoteAddr, t.localAddr))
t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.remoteAddr, t.localAddr))
if err != nil {
return nil, err
}
t.connectionID = atomic.AddUint64(&serverConnectionCounter, 1)
t.framer.writer.Flush()
defer func() {
@@ -443,6 +448,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
streamID: streamID,
contentSubtype: s.contentSubtype,
status: status.New(codes.Internal, errMsg),
rst: !frame.StreamEnded(),
})
return false
}
@@ -516,14 +522,16 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
}
if httpMethod != http.MethodPost {
t.mu.Unlock()
errMsg := fmt.Sprintf("http2Server.operateHeaders parsed a :method field: %v which should be POST", httpMethod)
if logger.V(logLevel) {
logger.Infof("transport: http2Server.operateHeaders parsed a :method field: %v which should be POST", httpMethod)
logger.Infof("transport: %v", errMsg)
}
t.controlBuf.put(&cleanupStream{
streamID: streamID,
rst: true,
rstCode: http2.ErrCodeProtocol,
onWrite: func() {},
t.controlBuf.put(&earlyAbortStream{
httpStatus: 405,
streamID: streamID,
contentSubtype: s.contentSubtype,
status: status.New(codes.Internal, errMsg),
rst: !frame.StreamEnded(),
})
s.cancel()
return false
@@ -544,6 +552,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
streamID: s.id,
contentSubtype: s.contentSubtype,
status: stat,
rst: !frame.StreamEnded(),
})
return false
}
@@ -925,11 +934,25 @@ func (t *http2Server) checkForHeaderListSize(it interface{}) bool {
return true
}
func (t *http2Server) streamContextErr(s *Stream) error {
select {
case <-t.done:
return ErrConnClosing
default:
}
return ContextErr(s.ctx.Err())
}
// WriteHeader sends the header metadata md back to the client.
func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
if s.updateHeaderSent() || s.getState() == streamDone {
if s.updateHeaderSent() {
return ErrIllegalHeaderWrite
}
if s.getState() == streamDone {
return t.streamContextErr(s)
}
s.hdrMu.Lock()
if md.Len() > 0 {
if s.header.Len() > 0 {
@@ -940,7 +963,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
}
if err := t.writeHeaderLocked(s); err != nil {
s.hdrMu.Unlock()
return err
return status.Convert(err).Err()
}
s.hdrMu.Unlock()
return nil
@@ -1056,23 +1079,12 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
if !s.isHeaderSent() { // Headers haven't been written yet.
if err := t.WriteHeader(s, nil); err != nil {
if _, ok := err.(ConnectionError); ok {
return err
}
// TODO(mmukhi, dfawley): Make sure this is the right code to return.
return status.Errorf(codes.Internal, "transport: %v", err)
return err
}
} else {
// Writing headers checks for this condition.
if s.getState() == streamDone {
// TODO(mmukhi, dfawley): Should the server write also return io.EOF?
s.cancel()
select {
case <-t.done:
return ErrConnClosing
default:
}
return ContextErr(s.ctx.Err())
return t.streamContextErr(s)
}
}
df := &dataFrame{
@@ -1082,12 +1094,7 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
onEachWrite: t.setResetPingStrikes,
}
if err := s.wq.get(int32(len(hdr) + len(data))); err != nil {
select {
case <-t.done:
return ErrConnClosing
default:
}
return ContextErr(s.ctx.Err())
return t.streamContextErr(s)
}
return t.controlBuf.put(df)
}
@@ -1210,9 +1217,7 @@ func (t *http2Server) Close() {
if err := t.conn.Close(); err != nil && logger.V(logLevel) {
logger.Infof("transport: error closing conn during Close: %v", err)
}
if channelz.IsOn() {
channelz.RemoveEntry(t.channelzID)
}
channelz.RemoveEntry(t.channelzID)
// Cancel all active streams.
for _, s := range streams {
s.cancel()
@@ -1225,10 +1230,6 @@ func (t *http2Server) Close() {
// deleteStream deletes the stream s from transport's active streams.
func (t *http2Server) deleteStream(s *Stream, eosReceived bool) {
// In case stream sending and receiving are invoked in separate
// goroutines (e.g., bi-directional streaming), cancel needs to be
// called to interrupt the potential blocking on other goroutines.
s.cancel()
t.mu.Lock()
if _, ok := t.activeStreams[s.id]; ok {
@@ -1250,6 +1251,11 @@ func (t *http2Server) deleteStream(s *Stream, eosReceived bool) {
// finishStream closes the stream and puts the trailing headerFrame into controlbuf.
func (t *http2Server) finishStream(s *Stream, rst bool, rstCode http2.ErrCode, hdr *headerFrame, eosReceived bool) {
// In case stream sending and receiving are invoked in separate
// goroutines (e.g., bi-directional streaming), cancel needs to be
// called to interrupt the potential blocking on other goroutines.
s.cancel()
oldState := s.swapState(streamDone)
if oldState == streamDone {
// If the stream was already done, return.
@@ -1269,6 +1275,11 @@ func (t *http2Server) finishStream(s *Stream, rst bool, rstCode http2.ErrCode, h
// closeStream clears the footprint of a stream when the stream is not needed any more.
func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, eosReceived bool) {
// In case stream sending and receiving are invoked in separate
// goroutines (e.g., bi-directional streaming), cancel needs to be
// called to interrupt the potential blocking on other goroutines.
s.cancel()
s.swapState(streamDone)
t.deleteStream(s, eosReceived)

View File

@@ -34,6 +34,7 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
@@ -529,7 +530,7 @@ type ServerConfig struct {
InitialConnWindowSize int32
WriteBufferSize int
ReadBufferSize int
ChannelzParentID int64
ChannelzParentID *channelz.Identifier
MaxHeaderListSize *uint32
HeaderTableSize *uint32
}
@@ -563,7 +564,7 @@ type ConnectOptions struct {
// ReadBufferSize sets the size of read buffer, which in turn determines how much data can be read at most for one read syscall.
ReadBufferSize int
// ChannelzParentID sets the addrConn id which initiate the creation of this client transport.
ChannelzParentID int64
ChannelzParentID *channelz.Identifier
// MaxHeaderListSize sets the max (uncompressed) size of header list that is prepared to be received.
MaxHeaderListSize *uint32
// UseProxy specifies if a proxy should be used.