TUN-5989: Add in-memory otlp exporter

This commit is contained in:
Devin Carr
2022-04-06 16:20:29 -07:00
parent 9cde11f8e0
commit def8f57dbc
236 changed files with 33015 additions and 177 deletions

29
vendor/github.com/go-logr/logr/.golangci.yaml generated vendored Normal file
View File

@@ -0,0 +1,29 @@
run:
timeout: 1m
tests: true
linters:
disable-all: true
enable:
- asciicheck
- deadcode
- errcheck
- forcetypeassert
- gocritic
- gofmt
- goimports
- gosimple
- govet
- ineffassign
- misspell
- revive
- staticcheck
- structcheck
- typecheck
- unused
- varcheck
issues:
exclude-use-default: false
max-issues-per-linter: 0
max-same-issues: 10

6
vendor/github.com/go-logr/logr/CHANGELOG.md generated vendored Normal file
View File

@@ -0,0 +1,6 @@
# CHANGELOG
## v1.0.0-rc1
This is the first logged release. Major changes (including breaking changes)
have occurred since earlier tags.

17
vendor/github.com/go-logr/logr/CONTRIBUTING.md generated vendored Normal file
View File

@@ -0,0 +1,17 @@
# Contributing
Logr is open to pull-requests, provided they fit within the intended scope of
the project. Specifically, this library aims to be VERY small and minimalist,
with no external dependencies.
## Compatibility
This project intends to follow [semantic versioning](http://semver.org) and
is very strict about compatibility. Any proposed changes MUST follow those
rules.
## Performance
As a logging library, logr must be as light-weight as possible. Any proposed
code change must include results of running the [benchmark](./benchmark)
before and after the change.

201
vendor/github.com/go-logr/logr/LICENSE generated vendored Normal file
View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright {yyyy} {name of copyright owner}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

282
vendor/github.com/go-logr/logr/README.md generated vendored Normal file
View File

@@ -0,0 +1,282 @@
# A minimal logging API for Go
[![Go Reference](https://pkg.go.dev/badge/github.com/go-logr/logr.svg)](https://pkg.go.dev/github.com/go-logr/logr)
logr offers an(other) opinion on how Go programs and libraries can do logging
without becoming coupled to a particular logging implementation. This is not
an implementation of logging - it is an API. In fact it is two APIs with two
different sets of users.
The `Logger` type is intended for application and library authors. It provides
a relatively small API which can be used everywhere you want to emit logs. It
defers the actual act of writing logs (to files, to stdout, or whatever) to the
`LogSink` interface.
The `LogSink` interface is intended for logging library implementers. It is a
pure interface which can be implemented by logging frameworks to provide the actual logging
functionality.
This decoupling allows application and library developers to write code in
terms of `logr.Logger` (which has very low dependency fan-out) while the
implementation of logging is managed "up stack" (e.g. in or near `main()`.)
Application developers can then switch out implementations as necessary.
Many people assert that libraries should not be logging, and as such efforts
like this are pointless. Those people are welcome to convince the authors of
the tens-of-thousands of libraries that *DO* write logs that they are all
wrong. In the meantime, logr takes a more practical approach.
## Typical usage
Somewhere, early in an application's life, it will make a decision about which
logging library (implementation) it actually wants to use. Something like:
```
func main() {
// ... other setup code ...
// Create the "root" logger. We have chosen the "logimpl" implementation,
// which takes some initial parameters and returns a logr.Logger.
logger := logimpl.New(param1, param2)
// ... other setup code ...
```
Most apps will call into other libraries, create structures to govern the flow,
etc. The `logr.Logger` object can be passed to these other libraries, stored
in structs, or even used as a package-global variable, if needed. For example:
```
app := createTheAppObject(logger)
app.Run()
```
Outside of this early setup, no other packages need to know about the choice of
implementation. They write logs in terms of the `logr.Logger` that they
received:
```
type appObject struct {
// ... other fields ...
logger logr.Logger
// ... other fields ...
}
func (app *appObject) Run() {
app.logger.Info("starting up", "timestamp", time.Now())
// ... app code ...
```
## Background
If the Go standard library had defined an interface for logging, this project
probably would not be needed. Alas, here we are.
### Inspiration
Before you consider this package, please read [this blog post by the
inimitable Dave Cheney][warning-makes-no-sense]. We really appreciate what
he has to say, and it largely aligns with our own experiences.
### Differences from Dave's ideas
The main differences are:
1. Dave basically proposes doing away with the notion of a logging API in favor
of `fmt.Printf()`. We disagree, especially when you consider things like output
locations, timestamps, file and line decorations, and structured logging. This
package restricts the logging API to just 2 types of logs: info and error.
Info logs are things you want to tell the user which are not errors. Error
logs are, well, errors. If your code receives an `error` from a subordinate
function call and is logging that `error` *and not returning it*, use error
logs.
2. Verbosity-levels on info logs. This gives developers a chance to indicate
arbitrary grades of importance for info logs, without assigning names with
semantic meaning such as "warning", "trace", and "debug." Superficially this
may feel very similar, but the primary difference is the lack of semantics.
Because verbosity is a numerical value, it's safe to assume that an app running
with higher verbosity means more (and less important) logs will be generated.
## Implementations (non-exhaustive)
There are implementations for the following logging libraries:
- **a function** (can bridge to non-structured libraries): [funcr](https://github.com/go-logr/logr/tree/master/funcr)
- **a testing.T** (for use in Go tests, with JSON-like output): [testr](https://github.com/go-logr/logr/tree/master/testr)
- **github.com/google/glog**: [glogr](https://github.com/go-logr/glogr)
- **k8s.io/klog** (for Kubernetes): [klogr](https://git.k8s.io/klog/klogr)
- **a testing.T** (with klog-like text output): [ktesting](https://git.k8s.io/klog/ktesting)
- **go.uber.org/zap**: [zapr](https://github.com/go-logr/zapr)
- **log** (the Go standard library logger): [stdr](https://github.com/go-logr/stdr)
- **github.com/sirupsen/logrus**: [logrusr](https://github.com/bombsimon/logrusr)
- **github.com/wojas/genericr**: [genericr](https://github.com/wojas/genericr) (makes it easy to implement your own backend)
- **logfmt** (Heroku style [logging](https://www.brandur.org/logfmt)): [logfmtr](https://github.com/iand/logfmtr)
- **github.com/rs/zerolog**: [zerologr](https://github.com/go-logr/zerologr)
- **github.com/go-kit/log**: [gokitlogr](https://github.com/tonglil/gokitlogr) (also compatible with github.com/go-kit/kit/log since v0.12.0)
- **bytes.Buffer** (writing to a buffer): [bufrlogr](https://github.com/tonglil/buflogr) (useful for ensuring values were logged, like during testing)
## FAQ
### Conceptual
#### Why structured logging?
- **Structured logs are more easily queryable**: Since you've got
key-value pairs, it's much easier to query your structured logs for
particular values by filtering on the contents of a particular key --
think searching request logs for error codes, Kubernetes reconcilers for
the name and namespace of the reconciled object, etc.
- **Structured logging makes it easier to have cross-referenceable logs**:
Similarly to searchability, if you maintain conventions around your
keys, it becomes easy to gather all log lines related to a particular
concept.
- **Structured logs allow better dimensions of filtering**: if you have
structure to your logs, you've got more precise control over how much
information is logged -- you might choose in a particular configuration
to log certain keys but not others, only log lines where a certain key
matches a certain value, etc., instead of just having v-levels and names
to key off of.
- **Structured logs better represent structured data**: sometimes, the
data that you want to log is inherently structured (think tuple-link
objects.) Structured logs allow you to preserve that structure when
outputting.
#### Why V-levels?
**V-levels give operators an easy way to control the chattiness of log
operations**. V-levels provide a way for a given package to distinguish
the relative importance or verbosity of a given log message. Then, if
a particular logger or package is logging too many messages, the user
of the package can simply change the v-levels for that library.
#### Why not named levels, like Info/Warning/Error?
Read [Dave Cheney's post][warning-makes-no-sense]. Then read [Differences
from Dave's ideas](#differences-from-daves-ideas).
#### Why not allow format strings, too?
**Format strings negate many of the benefits of structured logs**:
- They're not easily searchable without resorting to fuzzy searching,
regular expressions, etc.
- They don't store structured data well, since contents are flattened into
a string.
- They're not cross-referenceable.
- They don't compress easily, since the message is not constant.
(Unless you turn positional parameters into key-value pairs with numerical
keys, at which point you've gotten key-value logging with meaningless
keys.)
### Practical
#### Why key-value pairs, and not a map?
Key-value pairs are *much* easier to optimize, especially around
allocations. Zap (a structured logger that inspired logr's interface) has
[performance measurements](https://github.com/uber-go/zap#performance)
that show this quite nicely.
While the interface ends up being a little less obvious, you get
potentially better performance, plus avoid making users type
`map[string]string{}` every time they want to log.
#### What if my V-levels differ between libraries?
That's fine. Control your V-levels on a per-logger basis, and use the
`WithName` method to pass different loggers to different libraries.
Generally, you should take care to ensure that you have relatively
consistent V-levels within a given logger, however, as this makes deciding
on what verbosity of logs to request easier.
#### But I really want to use a format string!
That's not actually a question. Assuming your question is "how do
I convert my mental model of logging with format strings to logging with
constant messages":
1. Figure out what the error actually is, as you'd write in a TL;DR style,
and use that as a message.
2. For every place you'd write a format specifier, look to the word before
it, and add that as a key value pair.
For instance, consider the following examples (all taken from spots in the
Kubernetes codebase):
- `klog.V(4).Infof("Client is returning errors: code %v, error %v",
responseCode, err)` becomes `logger.Error(err, "client returned an
error", "code", responseCode)`
- `klog.V(4).Infof("Got a Retry-After %ds response for attempt %d to %v",
seconds, retries, url)` becomes `logger.V(4).Info("got a retry-after
response when requesting url", "attempt", retries, "after
seconds", seconds, "url", url)`
If you *really* must use a format string, use it in a key's value, and
call `fmt.Sprintf` yourself. For instance: `log.Printf("unable to
reflect over type %T")` becomes `logger.Info("unable to reflect over
type", "type", fmt.Sprintf("%T"))`. In general though, the cases where
this is necessary should be few and far between.
#### How do I choose my V-levels?
This is basically the only hard constraint: increase V-levels to denote
more verbose or more debug-y logs.
Otherwise, you can start out with `0` as "you always want to see this",
`1` as "common logging that you might *possibly* want to turn off", and
`10` as "I would like to performance-test your log collection stack."
Then gradually choose levels in between as you need them, working your way
down from 10 (for debug and trace style logs) and up from 1 (for chattier
info-type logs.)
#### How do I choose my keys?
Keys are fairly flexible, and can hold more or less any string
value. For best compatibility with implementations and consistency
with existing code in other projects, there are a few conventions you
should consider.
- Make your keys human-readable.
- Constant keys are generally a good idea.
- Be consistent across your codebase.
- Keys should naturally match parts of the message string.
- Use lower case for simple keys and
[lowerCamelCase](https://en.wiktionary.org/wiki/lowerCamelCase) for
more complex ones. Kubernetes is one example of a project that has
[adopted that
convention](https://github.com/kubernetes/community/blob/HEAD/contributors/devel/sig-instrumentation/migration-to-structured-logging.md#name-arguments).
While key names are mostly unrestricted (and spaces are acceptable),
it's generally a good idea to stick to printable ascii characters, or at
least match the general character set of your log lines.
#### Why should keys be constant values?
The point of structured logging is to make later log processing easier. Your
keys are, effectively, the schema of each log message. If you use different
keys across instances of the same log line, you will make your structured logs
much harder to use. `Sprintf()` is for values, not for keys!
#### Why is this not a pure interface?
The Logger type is implemented as a struct in order to allow the Go compiler to
optimize things like high-V `Info` logs that are not triggered. Not all of
these implementations are implemented yet, but this structure was suggested as
a way to ensure they *can* be implemented. All of the real work is behind the
`LogSink` interface.
[warning-makes-no-sense]: http://dave.cheney.net/2015/11/05/lets-talk-about-logging

54
vendor/github.com/go-logr/logr/discard.go generated vendored Normal file
View File

@@ -0,0 +1,54 @@
/*
Copyright 2020 The logr Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package logr
// Discard returns a Logger that discards all messages logged to it. It can be
// used whenever the caller is not interested in the logs. Logger instances
// produced by this function always compare as equal.
func Discard() Logger {
return Logger{
level: 0,
sink: discardLogSink{},
}
}
// discardLogSink is a LogSink that discards all messages.
type discardLogSink struct{}
// Verify that it actually implements the interface
var _ LogSink = discardLogSink{}
func (l discardLogSink) Init(RuntimeInfo) {
}
func (l discardLogSink) Enabled(int) bool {
return false
}
func (l discardLogSink) Info(int, string, ...interface{}) {
}
func (l discardLogSink) Error(error, string, ...interface{}) {
}
func (l discardLogSink) WithValues(...interface{}) LogSink {
return l
}
func (l discardLogSink) WithName(string) LogSink {
return l
}

787
vendor/github.com/go-logr/logr/funcr/funcr.go generated vendored Normal file
View File

@@ -0,0 +1,787 @@
/*
Copyright 2021 The logr Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package funcr implements formatting of structured log messages and
// optionally captures the call site and timestamp.
//
// The simplest way to use it is via its implementation of a
// github.com/go-logr/logr.LogSink with output through an arbitrary
// "write" function. See New and NewJSON for details.
//
// Custom LogSinks
//
// For users who need more control, a funcr.Formatter can be embedded inside
// your own custom LogSink implementation. This is useful when the LogSink
// needs to implement additional methods, for example.
//
// Formatting
//
// This will respect logr.Marshaler, fmt.Stringer, and error interfaces for
// values which are being logged. When rendering a struct, funcr will use Go's
// standard JSON tags (all except "string").
package funcr
import (
"bytes"
"encoding"
"fmt"
"path/filepath"
"reflect"
"runtime"
"strconv"
"strings"
"time"
"github.com/go-logr/logr"
)
// New returns a logr.Logger which is implemented by an arbitrary function.
func New(fn func(prefix, args string), opts Options) logr.Logger {
return logr.New(newSink(fn, NewFormatter(opts)))
}
// NewJSON returns a logr.Logger which is implemented by an arbitrary function
// and produces JSON output.
func NewJSON(fn func(obj string), opts Options) logr.Logger {
fnWrapper := func(_, obj string) {
fn(obj)
}
return logr.New(newSink(fnWrapper, NewFormatterJSON(opts)))
}
// Underlier exposes access to the underlying logging function. Since
// callers only have a logr.Logger, they have to know which
// implementation is in use, so this interface is less of an
// abstraction and more of a way to test type conversion.
type Underlier interface {
GetUnderlying() func(prefix, args string)
}
func newSink(fn func(prefix, args string), formatter Formatter) logr.LogSink {
l := &fnlogger{
Formatter: formatter,
write: fn,
}
// For skipping fnlogger.Info and fnlogger.Error.
l.Formatter.AddCallDepth(1)
return l
}
// Options carries parameters which influence the way logs are generated.
type Options struct {
// LogCaller tells funcr to add a "caller" key to some or all log lines.
// This has some overhead, so some users might not want it.
LogCaller MessageClass
// LogCallerFunc tells funcr to also log the calling function name. This
// has no effect if caller logging is not enabled (see Options.LogCaller).
LogCallerFunc bool
// LogTimestamp tells funcr to add a "ts" key to log lines. This has some
// overhead, so some users might not want it.
LogTimestamp bool
// TimestampFormat tells funcr how to render timestamps when LogTimestamp
// is enabled. If not specified, a default format will be used. For more
// details, see docs for Go's time.Layout.
TimestampFormat string
// Verbosity tells funcr which V logs to produce. Higher values enable
// more logs. Info logs at or below this level will be written, while logs
// above this level will be discarded.
Verbosity int
// RenderBuiltinsHook allows users to mutate the list of key-value pairs
// while a log line is being rendered. The kvList argument follows logr
// conventions - each pair of slice elements is comprised of a string key
// and an arbitrary value (verified and sanitized before calling this
// hook). The value returned must follow the same conventions. This hook
// can be used to audit or modify logged data. For example, you might want
// to prefix all of funcr's built-in keys with some string. This hook is
// only called for built-in (provided by funcr itself) key-value pairs.
// Equivalent hooks are offered for key-value pairs saved via
// logr.Logger.WithValues or Formatter.AddValues (see RenderValuesHook) and
// for user-provided pairs (see RenderArgsHook).
RenderBuiltinsHook func(kvList []interface{}) []interface{}
// RenderValuesHook is the same as RenderBuiltinsHook, except that it is
// only called for key-value pairs saved via logr.Logger.WithValues. See
// RenderBuiltinsHook for more details.
RenderValuesHook func(kvList []interface{}) []interface{}
// RenderArgsHook is the same as RenderBuiltinsHook, except that it is only
// called for key-value pairs passed directly to Info and Error. See
// RenderBuiltinsHook for more details.
RenderArgsHook func(kvList []interface{}) []interface{}
// MaxLogDepth tells funcr how many levels of nested fields (e.g. a struct
// that contains a struct, etc.) it may log. Every time it finds a struct,
// slice, array, or map the depth is increased by one. When the maximum is
// reached, the value will be converted to a string indicating that the max
// depth has been exceeded. If this field is not specified, a default
// value will be used.
MaxLogDepth int
}
// MessageClass indicates which category or categories of messages to consider.
type MessageClass int
const (
// None ignores all message classes.
None MessageClass = iota
// All considers all message classes.
All
// Info only considers info messages.
Info
// Error only considers error messages.
Error
)
// fnlogger inherits some of its LogSink implementation from Formatter
// and just needs to add some glue code.
type fnlogger struct {
Formatter
write func(prefix, args string)
}
func (l fnlogger) WithName(name string) logr.LogSink {
l.Formatter.AddName(name)
return &l
}
func (l fnlogger) WithValues(kvList ...interface{}) logr.LogSink {
l.Formatter.AddValues(kvList)
return &l
}
func (l fnlogger) WithCallDepth(depth int) logr.LogSink {
l.Formatter.AddCallDepth(depth)
return &l
}
func (l fnlogger) Info(level int, msg string, kvList ...interface{}) {
prefix, args := l.FormatInfo(level, msg, kvList)
l.write(prefix, args)
}
func (l fnlogger) Error(err error, msg string, kvList ...interface{}) {
prefix, args := l.FormatError(err, msg, kvList)
l.write(prefix, args)
}
func (l fnlogger) GetUnderlying() func(prefix, args string) {
return l.write
}
// Assert conformance to the interfaces.
var _ logr.LogSink = &fnlogger{}
var _ logr.CallDepthLogSink = &fnlogger{}
var _ Underlier = &fnlogger{}
// NewFormatter constructs a Formatter which emits a JSON-like key=value format.
func NewFormatter(opts Options) Formatter {
return newFormatter(opts, outputKeyValue)
}
// NewFormatterJSON constructs a Formatter which emits strict JSON.
func NewFormatterJSON(opts Options) Formatter {
return newFormatter(opts, outputJSON)
}
// Defaults for Options.
const defaultTimestampFormat = "2006-01-02 15:04:05.000000"
const defaultMaxLogDepth = 16
func newFormatter(opts Options, outfmt outputFormat) Formatter {
if opts.TimestampFormat == "" {
opts.TimestampFormat = defaultTimestampFormat
}
if opts.MaxLogDepth == 0 {
opts.MaxLogDepth = defaultMaxLogDepth
}
f := Formatter{
outputFormat: outfmt,
prefix: "",
values: nil,
depth: 0,
opts: opts,
}
return f
}
// Formatter is an opaque struct which can be embedded in a LogSink
// implementation. It should be constructed with NewFormatter. Some of
// its methods directly implement logr.LogSink.
type Formatter struct {
outputFormat outputFormat
prefix string
values []interface{}
valuesStr string
depth int
opts Options
}
// outputFormat indicates which outputFormat to use.
type outputFormat int
const (
// outputKeyValue emits a JSON-like key=value format, but not strict JSON.
outputKeyValue outputFormat = iota
// outputJSON emits strict JSON.
outputJSON
)
// PseudoStruct is a list of key-value pairs that gets logged as a struct.
type PseudoStruct []interface{}
// render produces a log line, ready to use.
func (f Formatter) render(builtins, args []interface{}) string {
// Empirically bytes.Buffer is faster than strings.Builder for this.
buf := bytes.NewBuffer(make([]byte, 0, 1024))
if f.outputFormat == outputJSON {
buf.WriteByte('{')
}
vals := builtins
if hook := f.opts.RenderBuiltinsHook; hook != nil {
vals = hook(f.sanitize(vals))
}
f.flatten(buf, vals, false, false) // keys are ours, no need to escape
continuing := len(builtins) > 0
if len(f.valuesStr) > 0 {
if continuing {
if f.outputFormat == outputJSON {
buf.WriteByte(',')
} else {
buf.WriteByte(' ')
}
}
continuing = true
buf.WriteString(f.valuesStr)
}
vals = args
if hook := f.opts.RenderArgsHook; hook != nil {
vals = hook(f.sanitize(vals))
}
f.flatten(buf, vals, continuing, true) // escape user-provided keys
if f.outputFormat == outputJSON {
buf.WriteByte('}')
}
return buf.String()
}
// flatten renders a list of key-value pairs into a buffer. If continuing is
// true, it assumes that the buffer has previous values and will emit a
// separator (which depends on the output format) before the first pair it
// writes. If escapeKeys is true, the keys are assumed to have
// non-JSON-compatible characters in them and must be evaluated for escapes.
//
// This function returns a potentially modified version of kvList, which
// ensures that there is a value for every key (adding a value if needed) and
// that each key is a string (substituting a key if needed).
func (f Formatter) flatten(buf *bytes.Buffer, kvList []interface{}, continuing bool, escapeKeys bool) []interface{} {
// This logic overlaps with sanitize() but saves one type-cast per key,
// which can be measurable.
if len(kvList)%2 != 0 {
kvList = append(kvList, noValue)
}
for i := 0; i < len(kvList); i += 2 {
k, ok := kvList[i].(string)
if !ok {
k = f.nonStringKey(kvList[i])
kvList[i] = k
}
v := kvList[i+1]
if i > 0 || continuing {
if f.outputFormat == outputJSON {
buf.WriteByte(',')
} else {
// In theory the format could be something we don't understand. In
// practice, we control it, so it won't be.
buf.WriteByte(' ')
}
}
if escapeKeys {
buf.WriteString(prettyString(k))
} else {
// this is faster
buf.WriteByte('"')
buf.WriteString(k)
buf.WriteByte('"')
}
if f.outputFormat == outputJSON {
buf.WriteByte(':')
} else {
buf.WriteByte('=')
}
buf.WriteString(f.pretty(v))
}
return kvList
}
func (f Formatter) pretty(value interface{}) string {
return f.prettyWithFlags(value, 0, 0)
}
const (
flagRawStruct = 0x1 // do not print braces on structs
)
// TODO: This is not fast. Most of the overhead goes here.
func (f Formatter) prettyWithFlags(value interface{}, flags uint32, depth int) string {
if depth > f.opts.MaxLogDepth {
return `"<max-log-depth-exceeded>"`
}
// Handle types that take full control of logging.
if v, ok := value.(logr.Marshaler); ok {
// Replace the value with what the type wants to get logged.
// That then gets handled below via reflection.
value = invokeMarshaler(v)
}
// Handle types that want to format themselves.
switch v := value.(type) {
case fmt.Stringer:
value = invokeStringer(v)
case error:
value = invokeError(v)
}
// Handling the most common types without reflect is a small perf win.
switch v := value.(type) {
case bool:
return strconv.FormatBool(v)
case string:
return prettyString(v)
case int:
return strconv.FormatInt(int64(v), 10)
case int8:
return strconv.FormatInt(int64(v), 10)
case int16:
return strconv.FormatInt(int64(v), 10)
case int32:
return strconv.FormatInt(int64(v), 10)
case int64:
return strconv.FormatInt(int64(v), 10)
case uint:
return strconv.FormatUint(uint64(v), 10)
case uint8:
return strconv.FormatUint(uint64(v), 10)
case uint16:
return strconv.FormatUint(uint64(v), 10)
case uint32:
return strconv.FormatUint(uint64(v), 10)
case uint64:
return strconv.FormatUint(v, 10)
case uintptr:
return strconv.FormatUint(uint64(v), 10)
case float32:
return strconv.FormatFloat(float64(v), 'f', -1, 32)
case float64:
return strconv.FormatFloat(v, 'f', -1, 64)
case complex64:
return `"` + strconv.FormatComplex(complex128(v), 'f', -1, 64) + `"`
case complex128:
return `"` + strconv.FormatComplex(v, 'f', -1, 128) + `"`
case PseudoStruct:
buf := bytes.NewBuffer(make([]byte, 0, 1024))
v = f.sanitize(v)
if flags&flagRawStruct == 0 {
buf.WriteByte('{')
}
for i := 0; i < len(v); i += 2 {
if i > 0 {
buf.WriteByte(',')
}
k, _ := v[i].(string) // sanitize() above means no need to check success
// arbitrary keys might need escaping
buf.WriteString(prettyString(k))
buf.WriteByte(':')
buf.WriteString(f.prettyWithFlags(v[i+1], 0, depth+1))
}
if flags&flagRawStruct == 0 {
buf.WriteByte('}')
}
return buf.String()
}
buf := bytes.NewBuffer(make([]byte, 0, 256))
t := reflect.TypeOf(value)
if t == nil {
return "null"
}
v := reflect.ValueOf(value)
switch t.Kind() {
case reflect.Bool:
return strconv.FormatBool(v.Bool())
case reflect.String:
return prettyString(v.String())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.FormatInt(int64(v.Int()), 10)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return strconv.FormatUint(uint64(v.Uint()), 10)
case reflect.Float32:
return strconv.FormatFloat(float64(v.Float()), 'f', -1, 32)
case reflect.Float64:
return strconv.FormatFloat(v.Float(), 'f', -1, 64)
case reflect.Complex64:
return `"` + strconv.FormatComplex(complex128(v.Complex()), 'f', -1, 64) + `"`
case reflect.Complex128:
return `"` + strconv.FormatComplex(v.Complex(), 'f', -1, 128) + `"`
case reflect.Struct:
if flags&flagRawStruct == 0 {
buf.WriteByte('{')
}
for i := 0; i < t.NumField(); i++ {
fld := t.Field(i)
if fld.PkgPath != "" {
// reflect says this field is only defined for non-exported fields.
continue
}
if !v.Field(i).CanInterface() {
// reflect isn't clear exactly what this means, but we can't use it.
continue
}
name := ""
omitempty := false
if tag, found := fld.Tag.Lookup("json"); found {
if tag == "-" {
continue
}
if comma := strings.Index(tag, ","); comma != -1 {
if n := tag[:comma]; n != "" {
name = n
}
rest := tag[comma:]
if strings.Contains(rest, ",omitempty,") || strings.HasSuffix(rest, ",omitempty") {
omitempty = true
}
} else {
name = tag
}
}
if omitempty && isEmpty(v.Field(i)) {
continue
}
if i > 0 {
buf.WriteByte(',')
}
if fld.Anonymous && fld.Type.Kind() == reflect.Struct && name == "" {
buf.WriteString(f.prettyWithFlags(v.Field(i).Interface(), flags|flagRawStruct, depth+1))
continue
}
if name == "" {
name = fld.Name
}
// field names can't contain characters which need escaping
buf.WriteByte('"')
buf.WriteString(name)
buf.WriteByte('"')
buf.WriteByte(':')
buf.WriteString(f.prettyWithFlags(v.Field(i).Interface(), 0, depth+1))
}
if flags&flagRawStruct == 0 {
buf.WriteByte('}')
}
return buf.String()
case reflect.Slice, reflect.Array:
buf.WriteByte('[')
for i := 0; i < v.Len(); i++ {
if i > 0 {
buf.WriteByte(',')
}
e := v.Index(i)
buf.WriteString(f.prettyWithFlags(e.Interface(), 0, depth+1))
}
buf.WriteByte(']')
return buf.String()
case reflect.Map:
buf.WriteByte('{')
// This does not sort the map keys, for best perf.
it := v.MapRange()
i := 0
for it.Next() {
if i > 0 {
buf.WriteByte(',')
}
// If a map key supports TextMarshaler, use it.
keystr := ""
if m, ok := it.Key().Interface().(encoding.TextMarshaler); ok {
txt, err := m.MarshalText()
if err != nil {
keystr = fmt.Sprintf("<error-MarshalText: %s>", err.Error())
} else {
keystr = string(txt)
}
keystr = prettyString(keystr)
} else {
// prettyWithFlags will produce already-escaped values
keystr = f.prettyWithFlags(it.Key().Interface(), 0, depth+1)
if t.Key().Kind() != reflect.String {
// JSON only does string keys. Unlike Go's standard JSON, we'll
// convert just about anything to a string.
keystr = prettyString(keystr)
}
}
buf.WriteString(keystr)
buf.WriteByte(':')
buf.WriteString(f.prettyWithFlags(it.Value().Interface(), 0, depth+1))
i++
}
buf.WriteByte('}')
return buf.String()
case reflect.Ptr, reflect.Interface:
if v.IsNil() {
return "null"
}
return f.prettyWithFlags(v.Elem().Interface(), 0, depth)
}
return fmt.Sprintf(`"<unhandled-%s>"`, t.Kind().String())
}
func prettyString(s string) string {
// Avoid escaping (which does allocations) if we can.
if needsEscape(s) {
return strconv.Quote(s)
}
b := bytes.NewBuffer(make([]byte, 0, 1024))
b.WriteByte('"')
b.WriteString(s)
b.WriteByte('"')
return b.String()
}
// needsEscape determines whether the input string needs to be escaped or not,
// without doing any allocations.
func needsEscape(s string) bool {
for _, r := range s {
if !strconv.IsPrint(r) || r == '\\' || r == '"' {
return true
}
}
return false
}
func isEmpty(v reflect.Value) bool {
switch v.Kind() {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
return v.Len() == 0
case reflect.Bool:
return !v.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return v.Float() == 0
case reflect.Complex64, reflect.Complex128:
return v.Complex() == 0
case reflect.Interface, reflect.Ptr:
return v.IsNil()
}
return false
}
func invokeMarshaler(m logr.Marshaler) (ret interface{}) {
defer func() {
if r := recover(); r != nil {
ret = fmt.Sprintf("<panic: %s>", r)
}
}()
return m.MarshalLog()
}
func invokeStringer(s fmt.Stringer) (ret string) {
defer func() {
if r := recover(); r != nil {
ret = fmt.Sprintf("<panic: %s>", r)
}
}()
return s.String()
}
func invokeError(e error) (ret string) {
defer func() {
if r := recover(); r != nil {
ret = fmt.Sprintf("<panic: %s>", r)
}
}()
return e.Error()
}
// Caller represents the original call site for a log line, after considering
// logr.Logger.WithCallDepth and logr.Logger.WithCallStackHelper. The File and
// Line fields will always be provided, while the Func field is optional.
// Users can set the render hook fields in Options to examine logged key-value
// pairs, one of which will be {"caller", Caller} if the Options.LogCaller
// field is enabled for the given MessageClass.
type Caller struct {
// File is the basename of the file for this call site.
File string `json:"file"`
// Line is the line number in the file for this call site.
Line int `json:"line"`
// Func is the function name for this call site, or empty if
// Options.LogCallerFunc is not enabled.
Func string `json:"function,omitempty"`
}
func (f Formatter) caller() Caller {
// +1 for this frame, +1 for Info/Error.
pc, file, line, ok := runtime.Caller(f.depth + 2)
if !ok {
return Caller{"<unknown>", 0, ""}
}
fn := ""
if f.opts.LogCallerFunc {
if fp := runtime.FuncForPC(pc); fp != nil {
fn = fp.Name()
}
}
return Caller{filepath.Base(file), line, fn}
}
const noValue = "<no-value>"
func (f Formatter) nonStringKey(v interface{}) string {
return fmt.Sprintf("<non-string-key: %s>", f.snippet(v))
}
// snippet produces a short snippet string of an arbitrary value.
func (f Formatter) snippet(v interface{}) string {
const snipLen = 16
snip := f.pretty(v)
if len(snip) > snipLen {
snip = snip[:snipLen]
}
return snip
}
// sanitize ensures that a list of key-value pairs has a value for every key
// (adding a value if needed) and that each key is a string (substituting a key
// if needed).
func (f Formatter) sanitize(kvList []interface{}) []interface{} {
if len(kvList)%2 != 0 {
kvList = append(kvList, noValue)
}
for i := 0; i < len(kvList); i += 2 {
_, ok := kvList[i].(string)
if !ok {
kvList[i] = f.nonStringKey(kvList[i])
}
}
return kvList
}
// Init configures this Formatter from runtime info, such as the call depth
// imposed by logr itself.
// Note that this receiver is a pointer, so depth can be saved.
func (f *Formatter) Init(info logr.RuntimeInfo) {
f.depth += info.CallDepth
}
// Enabled checks whether an info message at the given level should be logged.
func (f Formatter) Enabled(level int) bool {
return level <= f.opts.Verbosity
}
// GetDepth returns the current depth of this Formatter. This is useful for
// implementations which do their own caller attribution.
func (f Formatter) GetDepth() int {
return f.depth
}
// FormatInfo renders an Info log message into strings. The prefix will be
// empty when no names were set (via AddNames), or when the output is
// configured for JSON.
func (f Formatter) FormatInfo(level int, msg string, kvList []interface{}) (prefix, argsStr string) {
args := make([]interface{}, 0, 64) // using a constant here impacts perf
prefix = f.prefix
if f.outputFormat == outputJSON {
args = append(args, "logger", prefix)
prefix = ""
}
if f.opts.LogTimestamp {
args = append(args, "ts", time.Now().Format(f.opts.TimestampFormat))
}
if policy := f.opts.LogCaller; policy == All || policy == Info {
args = append(args, "caller", f.caller())
}
args = append(args, "level", level, "msg", msg)
return prefix, f.render(args, kvList)
}
// FormatError renders an Error log message into strings. The prefix will be
// empty when no names were set (via AddNames), or when the output is
// configured for JSON.
func (f Formatter) FormatError(err error, msg string, kvList []interface{}) (prefix, argsStr string) {
args := make([]interface{}, 0, 64) // using a constant here impacts perf
prefix = f.prefix
if f.outputFormat == outputJSON {
args = append(args, "logger", prefix)
prefix = ""
}
if f.opts.LogTimestamp {
args = append(args, "ts", time.Now().Format(f.opts.TimestampFormat))
}
if policy := f.opts.LogCaller; policy == All || policy == Error {
args = append(args, "caller", f.caller())
}
args = append(args, "msg", msg)
var loggableErr interface{}
if err != nil {
loggableErr = err.Error()
}
args = append(args, "error", loggableErr)
return f.prefix, f.render(args, kvList)
}
// AddName appends the specified name. funcr uses '/' characters to separate
// name elements. Callers should not pass '/' in the provided name string, but
// this library does not actually enforce that.
func (f *Formatter) AddName(name string) {
if len(f.prefix) > 0 {
f.prefix += "/"
}
f.prefix += name
}
// AddValues adds key-value pairs to the set of saved values to be logged with
// each log line.
func (f *Formatter) AddValues(kvList []interface{}) {
// Three slice args forces a copy.
n := len(f.values)
f.values = append(f.values[:n:n], kvList...)
vals := f.values
if hook := f.opts.RenderValuesHook; hook != nil {
vals = hook(f.sanitize(vals))
}
// Pre-render values, so we don't have to do it on each Info/Error call.
buf := bytes.NewBuffer(make([]byte, 0, 1024))
f.flatten(buf, vals, false, true) // escape user-provided keys
f.valuesStr = buf.String()
}
// AddCallDepth increases the number of stack-frames to skip when attributing
// the log line to a file and line.
func (f *Formatter) AddCallDepth(depth int) {
f.depth += depth
}

510
vendor/github.com/go-logr/logr/logr.go generated vendored Normal file
View File

@@ -0,0 +1,510 @@
/*
Copyright 2019 The logr Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// This design derives from Dave Cheney's blog:
// http://dave.cheney.net/2015/11/05/lets-talk-about-logging
// Package logr defines a general-purpose logging API and abstract interfaces
// to back that API. Packages in the Go ecosystem can depend on this package,
// while callers can implement logging with whatever backend is appropriate.
//
// Usage
//
// Logging is done using a Logger instance. Logger is a concrete type with
// methods, which defers the actual logging to a LogSink interface. The main
// methods of Logger are Info() and Error(). Arguments to Info() and Error()
// are key/value pairs rather than printf-style formatted strings, emphasizing
// "structured logging".
//
// With Go's standard log package, we might write:
// log.Printf("setting target value %s", targetValue)
//
// With logr's structured logging, we'd write:
// logger.Info("setting target", "value", targetValue)
//
// Errors are much the same. Instead of:
// log.Printf("failed to open the pod bay door for user %s: %v", user, err)
//
// We'd write:
// logger.Error(err, "failed to open the pod bay door", "user", user)
//
// Info() and Error() are very similar, but they are separate methods so that
// LogSink implementations can choose to do things like attach additional
// information (such as stack traces) on calls to Error(). Error() messages are
// always logged, regardless of the current verbosity. If there is no error
// instance available, passing nil is valid.
//
// Verbosity
//
// Often we want to log information only when the application in "verbose
// mode". To write log lines that are more verbose, Logger has a V() method.
// The higher the V-level of a log line, the less critical it is considered.
// Log-lines with V-levels that are not enabled (as per the LogSink) will not
// be written. Level V(0) is the default, and logger.V(0).Info() has the same
// meaning as logger.Info(). Negative V-levels have the same meaning as V(0).
// Error messages do not have a verbosity level and are always logged.
//
// Where we might have written:
// if flVerbose >= 2 {
// log.Printf("an unusual thing happened")
// }
//
// We can write:
// logger.V(2).Info("an unusual thing happened")
//
// Logger Names
//
// Logger instances can have name strings so that all messages logged through
// that instance have additional context. For example, you might want to add
// a subsystem name:
//
// logger.WithName("compactor").Info("started", "time", time.Now())
//
// The WithName() method returns a new Logger, which can be passed to
// constructors or other functions for further use. Repeated use of WithName()
// will accumulate name "segments". These name segments will be joined in some
// way by the LogSink implementation. It is strongly recommended that name
// segments contain simple identifiers (letters, digits, and hyphen), and do
// not contain characters that could muddle the log output or confuse the
// joining operation (e.g. whitespace, commas, periods, slashes, brackets,
// quotes, etc).
//
// Saved Values
//
// Logger instances can store any number of key/value pairs, which will be
// logged alongside all messages logged through that instance. For example,
// you might want to create a Logger instance per managed object:
//
// With the standard log package, we might write:
// log.Printf("decided to set field foo to value %q for object %s/%s",
// targetValue, object.Namespace, object.Name)
//
// With logr we'd write:
// // Elsewhere: set up the logger to log the object name.
// obj.logger = mainLogger.WithValues(
// "name", obj.name, "namespace", obj.namespace)
//
// // later on...
// obj.logger.Info("setting foo", "value", targetValue)
//
// Best Practices
//
// Logger has very few hard rules, with the goal that LogSink implementations
// might have a lot of freedom to differentiate. There are, however, some
// things to consider.
//
// The log message consists of a constant message attached to the log line.
// This should generally be a simple description of what's occurring, and should
// never be a format string. Variable information can then be attached using
// named values.
//
// Keys are arbitrary strings, but should generally be constant values. Values
// may be any Go value, but how the value is formatted is determined by the
// LogSink implementation.
//
// Logger instances are meant to be passed around by value. Code that receives
// such a value can call its methods without having to check whether the
// instance is ready for use.
//
// Calling methods with the null logger (Logger{}) as instance will crash
// because it has no LogSink. Therefore this null logger should never be passed
// around. For cases where passing a logger is optional, a pointer to Logger
// should be used.
//
// Key Naming Conventions
//
// Keys are not strictly required to conform to any specification or regex, but
// it is recommended that they:
// * be human-readable and meaningful (not auto-generated or simple ordinals)
// * be constant (not dependent on input data)
// * contain only printable characters
// * not contain whitespace or punctuation
// * use lower case for simple keys and lowerCamelCase for more complex ones
//
// These guidelines help ensure that log data is processed properly regardless
// of the log implementation. For example, log implementations will try to
// output JSON data or will store data for later database (e.g. SQL) queries.
//
// While users are generally free to use key names of their choice, it's
// generally best to avoid using the following keys, as they're frequently used
// by implementations:
// * "caller": the calling information (file/line) of a particular log line
// * "error": the underlying error value in the `Error` method
// * "level": the log level
// * "logger": the name of the associated logger
// * "msg": the log message
// * "stacktrace": the stack trace associated with a particular log line or
// error (often from the `Error` message)
// * "ts": the timestamp for a log line
//
// Implementations are encouraged to make use of these keys to represent the
// above concepts, when necessary (for example, in a pure-JSON output form, it
// would be necessary to represent at least message and timestamp as ordinary
// named values).
//
// Break Glass
//
// Implementations may choose to give callers access to the underlying
// logging implementation. The recommended pattern for this is:
// // Underlier exposes access to the underlying logging implementation.
// // Since callers only have a logr.Logger, they have to know which
// // implementation is in use, so this interface is less of an abstraction
// // and more of way to test type conversion.
// type Underlier interface {
// GetUnderlying() <underlying-type>
// }
//
// Logger grants access to the sink to enable type assertions like this:
// func DoSomethingWithImpl(log logr.Logger) {
// if underlier, ok := log.GetSink()(impl.Underlier) {
// implLogger := underlier.GetUnderlying()
// ...
// }
// }
//
// Custom `With*` functions can be implemented by copying the complete
// Logger struct and replacing the sink in the copy:
// // WithFooBar changes the foobar parameter in the log sink and returns a
// // new logger with that modified sink. It does nothing for loggers where
// // the sink doesn't support that parameter.
// func WithFoobar(log logr.Logger, foobar int) logr.Logger {
// if foobarLogSink, ok := log.GetSink()(FoobarSink); ok {
// log = log.WithSink(foobarLogSink.WithFooBar(foobar))
// }
// return log
// }
//
// Don't use New to construct a new Logger with a LogSink retrieved from an
// existing Logger. Source code attribution might not work correctly and
// unexported fields in Logger get lost.
//
// Beware that the same LogSink instance may be shared by different logger
// instances. Calling functions that modify the LogSink will affect all of
// those.
package logr
import (
"context"
)
// New returns a new Logger instance. This is primarily used by libraries
// implementing LogSink, rather than end users.
func New(sink LogSink) Logger {
logger := Logger{}
logger.setSink(sink)
sink.Init(runtimeInfo)
return logger
}
// setSink stores the sink and updates any related fields. It mutates the
// logger and thus is only safe to use for loggers that are not currently being
// used concurrently.
func (l *Logger) setSink(sink LogSink) {
l.sink = sink
}
// GetSink returns the stored sink.
func (l Logger) GetSink() LogSink {
return l.sink
}
// WithSink returns a copy of the logger with the new sink.
func (l Logger) WithSink(sink LogSink) Logger {
l.setSink(sink)
return l
}
// Logger is an interface to an abstract logging implementation. This is a
// concrete type for performance reasons, but all the real work is passed on to
// a LogSink. Implementations of LogSink should provide their own constructors
// that return Logger, not LogSink.
//
// The underlying sink can be accessed through GetSink and be modified through
// WithSink. This enables the implementation of custom extensions (see "Break
// Glass" in the package documentation). Normally the sink should be used only
// indirectly.
type Logger struct {
sink LogSink
level int
}
// Enabled tests whether this Logger is enabled. For example, commandline
// flags might be used to set the logging verbosity and disable some info logs.
func (l Logger) Enabled() bool {
return l.sink.Enabled(l.level)
}
// Info logs a non-error message with the given key/value pairs as context.
//
// The msg argument should be used to add some constant description to the log
// line. The key/value pairs can then be used to add additional variable
// information. The key/value pairs must alternate string keys and arbitrary
// values.
func (l Logger) Info(msg string, keysAndValues ...interface{}) {
if l.Enabled() {
if withHelper, ok := l.sink.(CallStackHelperLogSink); ok {
withHelper.GetCallStackHelper()()
}
l.sink.Info(l.level, msg, keysAndValues...)
}
}
// Error logs an error, with the given message and key/value pairs as context.
// It functions similarly to Info, but may have unique behavior, and should be
// preferred for logging errors (see the package documentations for more
// information). The log message will always be emitted, regardless of
// verbosity level.
//
// The msg argument should be used to add context to any underlying error,
// while the err argument should be used to attach the actual error that
// triggered this log line, if present. The err parameter is optional
// and nil may be passed instead of an error instance.
func (l Logger) Error(err error, msg string, keysAndValues ...interface{}) {
if withHelper, ok := l.sink.(CallStackHelperLogSink); ok {
withHelper.GetCallStackHelper()()
}
l.sink.Error(err, msg, keysAndValues...)
}
// V returns a new Logger instance for a specific verbosity level, relative to
// this Logger. In other words, V-levels are additive. A higher verbosity
// level means a log message is less important. Negative V-levels are treated
// as 0.
func (l Logger) V(level int) Logger {
if level < 0 {
level = 0
}
l.level += level
return l
}
// WithValues returns a new Logger instance with additional key/value pairs.
// See Info for documentation on how key/value pairs work.
func (l Logger) WithValues(keysAndValues ...interface{}) Logger {
l.setSink(l.sink.WithValues(keysAndValues...))
return l
}
// WithName returns a new Logger instance with the specified name element added
// to the Logger's name. Successive calls with WithName append additional
// suffixes to the Logger's name. It's strongly recommended that name segments
// contain only letters, digits, and hyphens (see the package documentation for
// more information).
func (l Logger) WithName(name string) Logger {
l.setSink(l.sink.WithName(name))
return l
}
// WithCallDepth returns a Logger instance that offsets the call stack by the
// specified number of frames when logging call site information, if possible.
// This is useful for users who have helper functions between the "real" call
// site and the actual calls to Logger methods. If depth is 0 the attribution
// should be to the direct caller of this function. If depth is 1 the
// attribution should skip 1 call frame, and so on. Successive calls to this
// are additive.
//
// If the underlying log implementation supports a WithCallDepth(int) method,
// it will be called and the result returned. If the implementation does not
// support CallDepthLogSink, the original Logger will be returned.
//
// To skip one level, WithCallStackHelper() should be used instead of
// WithCallDepth(1) because it works with implementions that support the
// CallDepthLogSink and/or CallStackHelperLogSink interfaces.
func (l Logger) WithCallDepth(depth int) Logger {
if withCallDepth, ok := l.sink.(CallDepthLogSink); ok {
l.setSink(withCallDepth.WithCallDepth(depth))
}
return l
}
// WithCallStackHelper returns a new Logger instance that skips the direct
// caller when logging call site information, if possible. This is useful for
// users who have helper functions between the "real" call site and the actual
// calls to Logger methods and want to support loggers which depend on marking
// each individual helper function, like loggers based on testing.T.
//
// In addition to using that new logger instance, callers also must call the
// returned function.
//
// If the underlying log implementation supports a WithCallDepth(int) method,
// WithCallDepth(1) will be called to produce a new logger. If it supports a
// WithCallStackHelper() method, that will be also called. If the
// implementation does not support either of these, the original Logger will be
// returned.
func (l Logger) WithCallStackHelper() (func(), Logger) {
var helper func()
if withCallDepth, ok := l.sink.(CallDepthLogSink); ok {
l.setSink(withCallDepth.WithCallDepth(1))
}
if withHelper, ok := l.sink.(CallStackHelperLogSink); ok {
helper = withHelper.GetCallStackHelper()
} else {
helper = func() {}
}
return helper, l
}
// contextKey is how we find Loggers in a context.Context.
type contextKey struct{}
// FromContext returns a Logger from ctx or an error if no Logger is found.
func FromContext(ctx context.Context) (Logger, error) {
if v, ok := ctx.Value(contextKey{}).(Logger); ok {
return v, nil
}
return Logger{}, notFoundError{}
}
// notFoundError exists to carry an IsNotFound method.
type notFoundError struct{}
func (notFoundError) Error() string {
return "no logr.Logger was present"
}
func (notFoundError) IsNotFound() bool {
return true
}
// FromContextOrDiscard returns a Logger from ctx. If no Logger is found, this
// returns a Logger that discards all log messages.
func FromContextOrDiscard(ctx context.Context) Logger {
if v, ok := ctx.Value(contextKey{}).(Logger); ok {
return v
}
return Discard()
}
// NewContext returns a new Context, derived from ctx, which carries the
// provided Logger.
func NewContext(ctx context.Context, logger Logger) context.Context {
return context.WithValue(ctx, contextKey{}, logger)
}
// RuntimeInfo holds information that the logr "core" library knows which
// LogSinks might want to know.
type RuntimeInfo struct {
// CallDepth is the number of call frames the logr library adds between the
// end-user and the LogSink. LogSink implementations which choose to print
// the original logging site (e.g. file & line) should climb this many
// additional frames to find it.
CallDepth int
}
// runtimeInfo is a static global. It must not be changed at run time.
var runtimeInfo = RuntimeInfo{
CallDepth: 1,
}
// LogSink represents a logging implementation. End-users will generally not
// interact with this type.
type LogSink interface {
// Init receives optional information about the logr library for LogSink
// implementations that need it.
Init(info RuntimeInfo)
// Enabled tests whether this LogSink is enabled at the specified V-level.
// For example, commandline flags might be used to set the logging
// verbosity and disable some info logs.
Enabled(level int) bool
// Info logs a non-error message with the given key/value pairs as context.
// The level argument is provided for optional logging. This method will
// only be called when Enabled(level) is true. See Logger.Info for more
// details.
Info(level int, msg string, keysAndValues ...interface{})
// Error logs an error, with the given message and key/value pairs as
// context. See Logger.Error for more details.
Error(err error, msg string, keysAndValues ...interface{})
// WithValues returns a new LogSink with additional key/value pairs. See
// Logger.WithValues for more details.
WithValues(keysAndValues ...interface{}) LogSink
// WithName returns a new LogSink with the specified name appended. See
// Logger.WithName for more details.
WithName(name string) LogSink
}
// CallDepthLogSink represents a Logger that knows how to climb the call stack
// to identify the original call site and can offset the depth by a specified
// number of frames. This is useful for users who have helper functions
// between the "real" call site and the actual calls to Logger methods.
// Implementations that log information about the call site (such as file,
// function, or line) would otherwise log information about the intermediate
// helper functions.
//
// This is an optional interface and implementations are not required to
// support it.
type CallDepthLogSink interface {
// WithCallDepth returns a LogSink that will offset the call
// stack by the specified number of frames when logging call
// site information.
//
// If depth is 0, the LogSink should skip exactly the number
// of call frames defined in RuntimeInfo.CallDepth when Info
// or Error are called, i.e. the attribution should be to the
// direct caller of Logger.Info or Logger.Error.
//
// If depth is 1 the attribution should skip 1 call frame, and so on.
// Successive calls to this are additive.
WithCallDepth(depth int) LogSink
}
// CallStackHelperLogSink represents a Logger that knows how to climb
// the call stack to identify the original call site and can skip
// intermediate helper functions if they mark themselves as
// helper. Go's testing package uses that approach.
//
// This is useful for users who have helper functions between the
// "real" call site and the actual calls to Logger methods.
// Implementations that log information about the call site (such as
// file, function, or line) would otherwise log information about the
// intermediate helper functions.
//
// This is an optional interface and implementations are not required
// to support it. Implementations that choose to support this must not
// simply implement it as WithCallDepth(1), because
// Logger.WithCallStackHelper will call both methods if they are
// present. This should only be implemented for LogSinks that actually
// need it, as with testing.T.
type CallStackHelperLogSink interface {
// GetCallStackHelper returns a function that must be called
// to mark the direct caller as helper function when logging
// call site information.
GetCallStackHelper() func()
}
// Marshaler is an optional interface that logged values may choose to
// implement. Loggers with structured output, such as JSON, should
// log the object return by the MarshalLog method instead of the
// original value.
type Marshaler interface {
// MarshalLog can be used to:
// - ensure that structs are not logged as strings when the original
// value has a String method: return a different type without a
// String method
// - select which fields of a complex type should get logged:
// return a simpler struct with fewer fields
// - log unexported fields: return a different struct
// with exported fields
//
// It may return any value of any type.
MarshalLog() interface{}
}

201
vendor/github.com/go-logr/stdr/LICENSE generated vendored Normal file
View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

6
vendor/github.com/go-logr/stdr/README.md generated vendored Normal file
View File

@@ -0,0 +1,6 @@
# Minimal Go logging using logr and Go's standard library
[![Go Reference](https://pkg.go.dev/badge/github.com/go-logr/stdr.svg)](https://pkg.go.dev/github.com/go-logr/stdr)
This package implements the [logr interface](https://github.com/go-logr/logr)
in terms of Go's standard log package(https://pkg.go.dev/log).

170
vendor/github.com/go-logr/stdr/stdr.go generated vendored Normal file
View File

@@ -0,0 +1,170 @@
/*
Copyright 2019 The logr Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package stdr implements github.com/go-logr/logr.Logger in terms of
// Go's standard log package.
package stdr
import (
"log"
"os"
"github.com/go-logr/logr"
"github.com/go-logr/logr/funcr"
)
// The global verbosity level. See SetVerbosity().
var globalVerbosity int
// SetVerbosity sets the global level against which all info logs will be
// compared. If this is greater than or equal to the "V" of the logger, the
// message will be logged. A higher value here means more logs will be written.
// The previous verbosity value is returned. This is not concurrent-safe -
// callers must be sure to call it from only one goroutine.
func SetVerbosity(v int) int {
old := globalVerbosity
globalVerbosity = v
return old
}
// New returns a logr.Logger which is implemented by Go's standard log package,
// or something like it. If std is nil, this will use a default logger
// instead.
//
// Example: stdr.New(log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile)))
func New(std StdLogger) logr.Logger {
return NewWithOptions(std, Options{})
}
// NewWithOptions returns a logr.Logger which is implemented by Go's standard
// log package, or something like it. See New for details.
func NewWithOptions(std StdLogger, opts Options) logr.Logger {
if std == nil {
// Go's log.Default() is only available in 1.16 and higher.
std = log.New(os.Stderr, "", log.LstdFlags)
}
if opts.Depth < 0 {
opts.Depth = 0
}
fopts := funcr.Options{
LogCaller: funcr.MessageClass(opts.LogCaller),
}
sl := &logger{
Formatter: funcr.NewFormatter(fopts),
std: std,
}
// For skipping our own logger.Info/Error.
sl.Formatter.AddCallDepth(1 + opts.Depth)
return logr.New(sl)
}
// Options carries parameters which influence the way logs are generated.
type Options struct {
// Depth biases the assumed number of call frames to the "true" caller.
// This is useful when the calling code calls a function which then calls
// stdr (e.g. a logging shim to another API). Values less than zero will
// be treated as zero.
Depth int
// LogCaller tells stdr to add a "caller" key to some or all log lines.
// Go's log package has options to log this natively, too.
LogCaller MessageClass
// TODO: add an option to log the date/time
}
// MessageClass indicates which category or categories of messages to consider.
type MessageClass int
const (
// None ignores all message classes.
None MessageClass = iota
// All considers all message classes.
All
// Info only considers info messages.
Info
// Error only considers error messages.
Error
)
// StdLogger is the subset of the Go stdlib log.Logger API that is needed for
// this adapter.
type StdLogger interface {
// Output is the same as log.Output and log.Logger.Output.
Output(calldepth int, logline string) error
}
type logger struct {
funcr.Formatter
std StdLogger
}
var _ logr.LogSink = &logger{}
var _ logr.CallDepthLogSink = &logger{}
func (l logger) Enabled(level int) bool {
return globalVerbosity >= level
}
func (l logger) Info(level int, msg string, kvList ...interface{}) {
prefix, args := l.FormatInfo(level, msg, kvList)
if prefix != "" {
args = prefix + ": " + args
}
_ = l.std.Output(l.Formatter.GetDepth()+1, args)
}
func (l logger) Error(err error, msg string, kvList ...interface{}) {
prefix, args := l.FormatError(err, msg, kvList)
if prefix != "" {
args = prefix + ": " + args
}
_ = l.std.Output(l.Formatter.GetDepth()+1, args)
}
func (l logger) WithName(name string) logr.LogSink {
l.Formatter.AddName(name)
return &l
}
func (l logger) WithValues(kvList ...interface{}) logr.LogSink {
l.Formatter.AddValues(kvList)
return &l
}
func (l logger) WithCallDepth(depth int) logr.LogSink {
l.Formatter.AddCallDepth(depth)
return &l
}
// Underlier exposes access to the underlying logging implementation. Since
// callers only have a logr.Logger, they have to know which implementation is
// in use, so this interface is less of an abstraction and more of way to test
// type conversion.
type Underlier interface {
GetUnderlying() StdLogger
}
// GetUnderlying returns the StdLogger underneath this logger. Since StdLogger
// is itself an interface, the result may or may not be a Go log.Logger.
func (l logger) GetUnderlying() StdLogger {
return l.std
}

View File

@@ -0,0 +1,27 @@
Copyright (c) 2015, Gengo, Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of Gengo, Inc. nor the names of its
contributors may be used to endorse or promote products derived from this
software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -0,0 +1,35 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
package(default_visibility = ["//visibility:public"])
go_library(
name = "httprule",
srcs = [
"compile.go",
"parse.go",
"types.go",
],
importpath = "github.com/grpc-ecosystem/grpc-gateway/v2/internal/httprule",
deps = ["//utilities"],
)
go_test(
name = "httprule_test",
size = "small",
srcs = [
"compile_test.go",
"parse_test.go",
"types_test.go",
],
embed = [":httprule"],
deps = [
"//utilities",
"@com_github_golang_glog//:glog",
],
)
alias(
name = "go_default_library",
actual = ":httprule",
visibility = ["//:__subpackages__"],
)

View File

@@ -0,0 +1,121 @@
package httprule
import (
"github.com/grpc-ecosystem/grpc-gateway/v2/utilities"
)
const (
opcodeVersion = 1
)
// Template is a compiled representation of path templates.
type Template struct {
// Version is the version number of the format.
Version int
// OpCodes is a sequence of operations.
OpCodes []int
// Pool is a constant pool
Pool []string
// Verb is a VERB part in the template.
Verb string
// Fields is a list of field paths bound in this template.
Fields []string
// Original template (example: /v1/a_bit_of_everything)
Template string
}
// Compiler compiles utilities representation of path templates into marshallable operations.
// They can be unmarshalled by runtime.NewPattern.
type Compiler interface {
Compile() Template
}
type op struct {
// code is the opcode of the operation
code utilities.OpCode
// str is a string operand of the code.
// num is ignored if str is not empty.
str string
// num is a numeric operand of the code.
num int
}
func (w wildcard) compile() []op {
return []op{
{code: utilities.OpPush},
}
}
func (w deepWildcard) compile() []op {
return []op{
{code: utilities.OpPushM},
}
}
func (l literal) compile() []op {
return []op{
{
code: utilities.OpLitPush,
str: string(l),
},
}
}
func (v variable) compile() []op {
var ops []op
for _, s := range v.segments {
ops = append(ops, s.compile()...)
}
ops = append(ops, op{
code: utilities.OpConcatN,
num: len(v.segments),
}, op{
code: utilities.OpCapture,
str: v.path,
})
return ops
}
func (t template) Compile() Template {
var rawOps []op
for _, s := range t.segments {
rawOps = append(rawOps, s.compile()...)
}
var (
ops []int
pool []string
fields []string
)
consts := make(map[string]int)
for _, op := range rawOps {
ops = append(ops, int(op.code))
if op.str == "" {
ops = append(ops, op.num)
} else {
// eof segment literal represents the "/" path pattern
if op.str == eof {
op.str = ""
}
if _, ok := consts[op.str]; !ok {
consts[op.str] = len(pool)
pool = append(pool, op.str)
}
ops = append(ops, consts[op.str])
}
if op.code == utilities.OpCapture {
fields = append(fields, op.str)
}
}
return Template{
Version: opcodeVersion,
OpCodes: ops,
Pool: pool,
Verb: t.verb,
Fields: fields,
Template: t.template,
}
}

View File

@@ -0,0 +1,11 @@
// +build gofuzz
package httprule
func Fuzz(data []byte) int {
_, err := Parse(string(data))
if err != nil {
return 0
}
return 0
}

View File

@@ -0,0 +1,368 @@
package httprule
import (
"fmt"
"strings"
)
// InvalidTemplateError indicates that the path template is not valid.
type InvalidTemplateError struct {
tmpl string
msg string
}
func (e InvalidTemplateError) Error() string {
return fmt.Sprintf("%s: %s", e.msg, e.tmpl)
}
// Parse parses the string representation of path template
func Parse(tmpl string) (Compiler, error) {
if !strings.HasPrefix(tmpl, "/") {
return template{}, InvalidTemplateError{tmpl: tmpl, msg: "no leading /"}
}
tokens, verb := tokenize(tmpl[1:])
p := parser{tokens: tokens}
segs, err := p.topLevelSegments()
if err != nil {
return template{}, InvalidTemplateError{tmpl: tmpl, msg: err.Error()}
}
return template{
segments: segs,
verb: verb,
template: tmpl,
}, nil
}
func tokenize(path string) (tokens []string, verb string) {
if path == "" {
return []string{eof}, ""
}
const (
init = iota
field
nested
)
st := init
for path != "" {
var idx int
switch st {
case init:
idx = strings.IndexAny(path, "/{")
case field:
idx = strings.IndexAny(path, ".=}")
case nested:
idx = strings.IndexAny(path, "/}")
}
if idx < 0 {
tokens = append(tokens, path)
break
}
switch r := path[idx]; r {
case '/', '.':
case '{':
st = field
case '=':
st = nested
case '}':
st = init
}
if idx == 0 {
tokens = append(tokens, path[idx:idx+1])
} else {
tokens = append(tokens, path[:idx], path[idx:idx+1])
}
path = path[idx+1:]
}
l := len(tokens)
// See
// https://github.com/grpc-ecosystem/grpc-gateway/pull/1947#issuecomment-774523693 ;
// although normal and backwards-compat logic here is to use the last index
// of a colon, if the final segment is a variable followed by a colon, the
// part following the colon must be a verb. Hence if the previous token is
// an end var marker, we switch the index we're looking for to Index instead
// of LastIndex, so that we correctly grab the remaining part of the path as
// the verb.
var penultimateTokenIsEndVar bool
switch l {
case 0, 1:
// Not enough to be variable so skip this logic and don't result in an
// invalid index
default:
penultimateTokenIsEndVar = tokens[l-2] == "}"
}
t := tokens[l-1]
var idx int
if penultimateTokenIsEndVar {
idx = strings.Index(t, ":")
} else {
idx = strings.LastIndex(t, ":")
}
if idx == 0 {
tokens, verb = tokens[:l-1], t[1:]
} else if idx > 0 {
tokens[l-1], verb = t[:idx], t[idx+1:]
}
tokens = append(tokens, eof)
return tokens, verb
}
// parser is a parser of the template syntax defined in github.com/googleapis/googleapis/google/api/http.proto.
type parser struct {
tokens []string
accepted []string
}
// topLevelSegments is the target of this parser.
func (p *parser) topLevelSegments() ([]segment, error) {
if _, err := p.accept(typeEOF); err == nil {
p.tokens = p.tokens[:0]
return []segment{literal(eof)}, nil
}
segs, err := p.segments()
if err != nil {
return nil, err
}
if _, err := p.accept(typeEOF); err != nil {
return nil, fmt.Errorf("unexpected token %q after segments %q", p.tokens[0], strings.Join(p.accepted, ""))
}
return segs, nil
}
func (p *parser) segments() ([]segment, error) {
s, err := p.segment()
if err != nil {
return nil, err
}
segs := []segment{s}
for {
if _, err := p.accept("/"); err != nil {
return segs, nil
}
s, err := p.segment()
if err != nil {
return segs, err
}
segs = append(segs, s)
}
}
func (p *parser) segment() (segment, error) {
if _, err := p.accept("*"); err == nil {
return wildcard{}, nil
}
if _, err := p.accept("**"); err == nil {
return deepWildcard{}, nil
}
if l, err := p.literal(); err == nil {
return l, nil
}
v, err := p.variable()
if err != nil {
return nil, fmt.Errorf("segment neither wildcards, literal or variable: %v", err)
}
return v, err
}
func (p *parser) literal() (segment, error) {
lit, err := p.accept(typeLiteral)
if err != nil {
return nil, err
}
return literal(lit), nil
}
func (p *parser) variable() (segment, error) {
if _, err := p.accept("{"); err != nil {
return nil, err
}
path, err := p.fieldPath()
if err != nil {
return nil, err
}
var segs []segment
if _, err := p.accept("="); err == nil {
segs, err = p.segments()
if err != nil {
return nil, fmt.Errorf("invalid segment in variable %q: %v", path, err)
}
} else {
segs = []segment{wildcard{}}
}
if _, err := p.accept("}"); err != nil {
return nil, fmt.Errorf("unterminated variable segment: %s", path)
}
return variable{
path: path,
segments: segs,
}, nil
}
func (p *parser) fieldPath() (string, error) {
c, err := p.accept(typeIdent)
if err != nil {
return "", err
}
components := []string{c}
for {
if _, err = p.accept("."); err != nil {
return strings.Join(components, "."), nil
}
c, err := p.accept(typeIdent)
if err != nil {
return "", fmt.Errorf("invalid field path component: %v", err)
}
components = append(components, c)
}
}
// A termType is a type of terminal symbols.
type termType string
// These constants define some of valid values of termType.
// They improve readability of parse functions.
//
// You can also use "/", "*", "**", "." or "=" as valid values.
const (
typeIdent = termType("ident")
typeLiteral = termType("literal")
typeEOF = termType("$")
)
const (
// eof is the terminal symbol which always appears at the end of token sequence.
eof = "\u0000"
)
// accept tries to accept a token in "p".
// This function consumes a token and returns it if it matches to the specified "term".
// If it doesn't match, the function does not consume any tokens and return an error.
func (p *parser) accept(term termType) (string, error) {
t := p.tokens[0]
switch term {
case "/", "*", "**", ".", "=", "{", "}":
if t != string(term) && t != "/" {
return "", fmt.Errorf("expected %q but got %q", term, t)
}
case typeEOF:
if t != eof {
return "", fmt.Errorf("expected EOF but got %q", t)
}
case typeIdent:
if err := expectIdent(t); err != nil {
return "", err
}
case typeLiteral:
if err := expectPChars(t); err != nil {
return "", err
}
default:
return "", fmt.Errorf("unknown termType %q", term)
}
p.tokens = p.tokens[1:]
p.accepted = append(p.accepted, t)
return t, nil
}
// expectPChars determines if "t" consists of only pchars defined in RFC3986.
//
// https://www.ietf.org/rfc/rfc3986.txt, P.49
// pchar = unreserved / pct-encoded / sub-delims / ":" / "@"
// unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~"
// sub-delims = "!" / "$" / "&" / "'" / "(" / ")"
// / "*" / "+" / "," / ";" / "="
// pct-encoded = "%" HEXDIG HEXDIG
func expectPChars(t string) error {
const (
init = iota
pct1
pct2
)
st := init
for _, r := range t {
if st != init {
if !isHexDigit(r) {
return fmt.Errorf("invalid hexdigit: %c(%U)", r, r)
}
switch st {
case pct1:
st = pct2
case pct2:
st = init
}
continue
}
// unreserved
switch {
case 'A' <= r && r <= 'Z':
continue
case 'a' <= r && r <= 'z':
continue
case '0' <= r && r <= '9':
continue
}
switch r {
case '-', '.', '_', '~':
// unreserved
case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=':
// sub-delims
case ':', '@':
// rest of pchar
case '%':
// pct-encoded
st = pct1
default:
return fmt.Errorf("invalid character in path segment: %q(%U)", r, r)
}
}
if st != init {
return fmt.Errorf("invalid percent-encoding in %q", t)
}
return nil
}
// expectIdent determines if "ident" is a valid identifier in .proto schema ([[:alpha:]_][[:alphanum:]_]*).
func expectIdent(ident string) error {
if ident == "" {
return fmt.Errorf("empty identifier")
}
for pos, r := range ident {
switch {
case '0' <= r && r <= '9':
if pos == 0 {
return fmt.Errorf("identifier starting with digit: %s", ident)
}
continue
case 'A' <= r && r <= 'Z':
continue
case 'a' <= r && r <= 'z':
continue
case r == '_':
continue
default:
return fmt.Errorf("invalid character %q(%U) in identifier: %s", r, r, ident)
}
}
return nil
}
func isHexDigit(r rune) bool {
switch {
case '0' <= r && r <= '9':
return true
case 'A' <= r && r <= 'F':
return true
case 'a' <= r && r <= 'f':
return true
}
return false
}

View File

@@ -0,0 +1,60 @@
package httprule
import (
"fmt"
"strings"
)
type template struct {
segments []segment
verb string
template string
}
type segment interface {
fmt.Stringer
compile() (ops []op)
}
type wildcard struct{}
type deepWildcard struct{}
type literal string
type variable struct {
path string
segments []segment
}
func (wildcard) String() string {
return "*"
}
func (deepWildcard) String() string {
return "**"
}
func (l literal) String() string {
return string(l)
}
func (v variable) String() string {
var segs []string
for _, s := range v.segments {
segs = append(segs, s.String())
}
return fmt.Sprintf("{%s=%s}", v.path, strings.Join(segs, "/"))
}
func (t template) String() string {
var segs []string
for _, s := range t.segments {
segs = append(segs, s.String())
}
str := strings.Join(segs, "/")
if t.verb != "" {
str = fmt.Sprintf("%s:%s", str, t.verb)
}
return "/" + str
}

View File

@@ -0,0 +1,91 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
package(default_visibility = ["//visibility:public"])
go_library(
name = "runtime",
srcs = [
"context.go",
"convert.go",
"doc.go",
"errors.go",
"fieldmask.go",
"handler.go",
"marshal_httpbodyproto.go",
"marshal_json.go",
"marshal_jsonpb.go",
"marshal_proto.go",
"marshaler.go",
"marshaler_registry.go",
"mux.go",
"pattern.go",
"proto2_convert.go",
"query.go",
],
importpath = "github.com/grpc-ecosystem/grpc-gateway/v2/runtime",
deps = [
"//internal/httprule",
"//utilities",
"@go_googleapis//google/api:httpbody_go_proto",
"@io_bazel_rules_go//proto/wkt:field_mask_go_proto",
"@org_golang_google_grpc//codes",
"@org_golang_google_grpc//grpclog",
"@org_golang_google_grpc//metadata",
"@org_golang_google_grpc//status",
"@org_golang_google_protobuf//encoding/protojson",
"@org_golang_google_protobuf//proto",
"@org_golang_google_protobuf//reflect/protoreflect",
"@org_golang_google_protobuf//reflect/protoregistry",
"@org_golang_google_protobuf//types/known/durationpb",
"@org_golang_google_protobuf//types/known/timestamppb",
"@org_golang_google_protobuf//types/known/wrapperspb",
],
)
go_test(
name = "runtime_test",
size = "small",
srcs = [
"context_test.go",
"convert_test.go",
"errors_test.go",
"fieldmask_test.go",
"handler_test.go",
"marshal_httpbodyproto_test.go",
"marshal_json_test.go",
"marshal_jsonpb_test.go",
"marshal_proto_test.go",
"marshaler_registry_test.go",
"mux_test.go",
"pattern_test.go",
"query_test.go",
],
embed = [":runtime"],
deps = [
"//runtime/internal/examplepb",
"//utilities",
"@com_github_google_go_cmp//cmp",
"@com_github_google_go_cmp//cmp/cmpopts",
"@go_googleapis//google/api:httpbody_go_proto",
"@go_googleapis//google/rpc:errdetails_go_proto",
"@go_googleapis//google/rpc:status_go_proto",
"@io_bazel_rules_go//proto/wkt:field_mask_go_proto",
"@org_golang_google_grpc//codes",
"@org_golang_google_grpc//metadata",
"@org_golang_google_grpc//status",
"@org_golang_google_protobuf//encoding/protojson",
"@org_golang_google_protobuf//proto",
"@org_golang_google_protobuf//testing/protocmp",
"@org_golang_google_protobuf//types/known/durationpb",
"@org_golang_google_protobuf//types/known/emptypb",
"@org_golang_google_protobuf//types/known/structpb",
"@org_golang_google_protobuf//types/known/timestamppb",
"@org_golang_google_protobuf//types/known/wrapperspb",
],
)
alias(
name = "go_default_library",
actual = ":runtime",
visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,345 @@
package runtime
import (
"context"
"encoding/base64"
"fmt"
"net"
"net/http"
"net/textproto"
"strconv"
"strings"
"sync"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// MetadataHeaderPrefix is the http prefix that represents custom metadata
// parameters to or from a gRPC call.
const MetadataHeaderPrefix = "Grpc-Metadata-"
// MetadataPrefix is prepended to permanent HTTP header keys (as specified
// by the IANA) when added to the gRPC context.
const MetadataPrefix = "grpcgateway-"
// MetadataTrailerPrefix is prepended to gRPC metadata as it is converted to
// HTTP headers in a response handled by grpc-gateway
const MetadataTrailerPrefix = "Grpc-Trailer-"
const metadataGrpcTimeout = "Grpc-Timeout"
const metadataHeaderBinarySuffix = "-Bin"
const xForwardedFor = "X-Forwarded-For"
const xForwardedHost = "X-Forwarded-Host"
var (
// DefaultContextTimeout is used for gRPC call context.WithTimeout whenever a Grpc-Timeout inbound
// header isn't present. If the value is 0 the sent `context` will not have a timeout.
DefaultContextTimeout = 0 * time.Second
)
type (
rpcMethodKey struct{}
httpPathPatternKey struct{}
AnnotateContextOption func(ctx context.Context) context.Context
)
func WithHTTPPathPattern(pattern string) AnnotateContextOption {
return func(ctx context.Context) context.Context {
return withHTTPPathPattern(ctx, pattern)
}
}
func decodeBinHeader(v string) ([]byte, error) {
if len(v)%4 == 0 {
// Input was padded, or padding was not necessary.
return base64.StdEncoding.DecodeString(v)
}
return base64.RawStdEncoding.DecodeString(v)
}
/*
AnnotateContext adds context information such as metadata from the request.
At a minimum, the RemoteAddr is included in the fashion of "X-Forwarded-For",
except that the forwarded destination is not another HTTP service but rather
a gRPC service.
*/
func AnnotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, error) {
ctx, md, err := annotateContext(ctx, mux, req, rpcMethodName, options...)
if err != nil {
return nil, err
}
if md == nil {
return ctx, nil
}
return metadata.NewOutgoingContext(ctx, md), nil
}
// AnnotateIncomingContext adds context information such as metadata from the request.
// Attach metadata as incoming context.
func AnnotateIncomingContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, error) {
ctx, md, err := annotateContext(ctx, mux, req, rpcMethodName, options...)
if err != nil {
return nil, err
}
if md == nil {
return ctx, nil
}
return metadata.NewIncomingContext(ctx, md), nil
}
func annotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, metadata.MD, error) {
ctx = withRPCMethod(ctx, rpcMethodName)
for _, o := range options {
ctx = o(ctx)
}
var pairs []string
timeout := DefaultContextTimeout
if tm := req.Header.Get(metadataGrpcTimeout); tm != "" {
var err error
timeout, err = timeoutDecode(tm)
if err != nil {
return nil, nil, status.Errorf(codes.InvalidArgument, "invalid grpc-timeout: %s", tm)
}
}
for key, vals := range req.Header {
key = textproto.CanonicalMIMEHeaderKey(key)
for _, val := range vals {
// For backwards-compatibility, pass through 'authorization' header with no prefix.
if key == "Authorization" {
pairs = append(pairs, "authorization", val)
}
if h, ok := mux.incomingHeaderMatcher(key); ok {
// Handles "-bin" metadata in grpc, since grpc will do another base64
// encode before sending to server, we need to decode it first.
if strings.HasSuffix(key, metadataHeaderBinarySuffix) {
b, err := decodeBinHeader(val)
if err != nil {
return nil, nil, status.Errorf(codes.InvalidArgument, "invalid binary header %s: %s", key, err)
}
val = string(b)
}
pairs = append(pairs, h, val)
}
}
}
if host := req.Header.Get(xForwardedHost); host != "" {
pairs = append(pairs, strings.ToLower(xForwardedHost), host)
} else if req.Host != "" {
pairs = append(pairs, strings.ToLower(xForwardedHost), req.Host)
}
if addr := req.RemoteAddr; addr != "" {
if remoteIP, _, err := net.SplitHostPort(addr); err == nil {
if fwd := req.Header.Get(xForwardedFor); fwd == "" {
pairs = append(pairs, strings.ToLower(xForwardedFor), remoteIP)
} else {
pairs = append(pairs, strings.ToLower(xForwardedFor), fmt.Sprintf("%s, %s", fwd, remoteIP))
}
}
}
if timeout != 0 {
//nolint:govet // The context outlives this function
ctx, _ = context.WithTimeout(ctx, timeout)
}
if len(pairs) == 0 {
return ctx, nil, nil
}
md := metadata.Pairs(pairs...)
for _, mda := range mux.metadataAnnotators {
md = metadata.Join(md, mda(ctx, req))
}
return ctx, md, nil
}
// ServerMetadata consists of metadata sent from gRPC server.
type ServerMetadata struct {
HeaderMD metadata.MD
TrailerMD metadata.MD
}
type serverMetadataKey struct{}
// NewServerMetadataContext creates a new context with ServerMetadata
func NewServerMetadataContext(ctx context.Context, md ServerMetadata) context.Context {
return context.WithValue(ctx, serverMetadataKey{}, md)
}
// ServerMetadataFromContext returns the ServerMetadata in ctx
func ServerMetadataFromContext(ctx context.Context) (md ServerMetadata, ok bool) {
md, ok = ctx.Value(serverMetadataKey{}).(ServerMetadata)
return
}
// ServerTransportStream implements grpc.ServerTransportStream.
// It should only be used by the generated files to support grpc.SendHeader
// outside of gRPC server use.
type ServerTransportStream struct {
mu sync.Mutex
header metadata.MD
trailer metadata.MD
}
// Method returns the method for the stream.
func (s *ServerTransportStream) Method() string {
return ""
}
// Header returns the header metadata of the stream.
func (s *ServerTransportStream) Header() metadata.MD {
s.mu.Lock()
defer s.mu.Unlock()
return s.header.Copy()
}
// SetHeader sets the header metadata.
func (s *ServerTransportStream) SetHeader(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
s.mu.Lock()
s.header = metadata.Join(s.header, md)
s.mu.Unlock()
return nil
}
// SendHeader sets the header metadata.
func (s *ServerTransportStream) SendHeader(md metadata.MD) error {
return s.SetHeader(md)
}
// Trailer returns the cached trailer metadata.
func (s *ServerTransportStream) Trailer() metadata.MD {
s.mu.Lock()
defer s.mu.Unlock()
return s.trailer.Copy()
}
// SetTrailer sets the trailer metadata.
func (s *ServerTransportStream) SetTrailer(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
s.mu.Lock()
s.trailer = metadata.Join(s.trailer, md)
s.mu.Unlock()
return nil
}
func timeoutDecode(s string) (time.Duration, error) {
size := len(s)
if size < 2 {
return 0, fmt.Errorf("timeout string is too short: %q", s)
}
d, ok := timeoutUnitToDuration(s[size-1])
if !ok {
return 0, fmt.Errorf("timeout unit is not recognized: %q", s)
}
t, err := strconv.ParseInt(s[:size-1], 10, 64)
if err != nil {
return 0, err
}
return d * time.Duration(t), nil
}
func timeoutUnitToDuration(u uint8) (d time.Duration, ok bool) {
switch u {
case 'H':
return time.Hour, true
case 'M':
return time.Minute, true
case 'S':
return time.Second, true
case 'm':
return time.Millisecond, true
case 'u':
return time.Microsecond, true
case 'n':
return time.Nanosecond, true
default:
}
return
}
// isPermanentHTTPHeader checks whether hdr belongs to the list of
// permanent request headers maintained by IANA.
// http://www.iana.org/assignments/message-headers/message-headers.xml
func isPermanentHTTPHeader(hdr string) bool {
switch hdr {
case
"Accept",
"Accept-Charset",
"Accept-Language",
"Accept-Ranges",
"Authorization",
"Cache-Control",
"Content-Type",
"Cookie",
"Date",
"Expect",
"From",
"Host",
"If-Match",
"If-Modified-Since",
"If-None-Match",
"If-Schedule-Tag-Match",
"If-Unmodified-Since",
"Max-Forwards",
"Origin",
"Pragma",
"Referer",
"User-Agent",
"Via",
"Warning":
return true
}
return false
}
// RPCMethod returns the method string for the server context. The returned
// string is in the format of "/package.service/method".
func RPCMethod(ctx context.Context) (string, bool) {
m := ctx.Value(rpcMethodKey{})
if m == nil {
return "", false
}
ms, ok := m.(string)
if !ok {
return "", false
}
return ms, true
}
func withRPCMethod(ctx context.Context, rpcMethodName string) context.Context {
return context.WithValue(ctx, rpcMethodKey{}, rpcMethodName)
}
// HTTPPathPattern returns the HTTP path pattern string relating to the HTTP handler, if one exists.
// The format of the returned string is defined by the google.api.http path template type.
func HTTPPathPattern(ctx context.Context) (string, bool) {
m := ctx.Value(httpPathPatternKey{})
if m == nil {
return "", false
}
ms, ok := m.(string)
if !ok {
return "", false
}
return ms, true
}
func withHTTPPathPattern(ctx context.Context, httpPathPattern string) context.Context {
return context.WithValue(ctx, httpPathPatternKey{}, httpPathPattern)
}

View File

@@ -0,0 +1,322 @@
package runtime
import (
"encoding/base64"
"fmt"
"strconv"
"strings"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb"
)
// String just returns the given string.
// It is just for compatibility to other types.
func String(val string) (string, error) {
return val, nil
}
// StringSlice converts 'val' where individual strings are separated by
// 'sep' into a string slice.
func StringSlice(val, sep string) ([]string, error) {
return strings.Split(val, sep), nil
}
// Bool converts the given string representation of a boolean value into bool.
func Bool(val string) (bool, error) {
return strconv.ParseBool(val)
}
// BoolSlice converts 'val' where individual booleans are separated by
// 'sep' into a bool slice.
func BoolSlice(val, sep string) ([]bool, error) {
s := strings.Split(val, sep)
values := make([]bool, len(s))
for i, v := range s {
value, err := Bool(v)
if err != nil {
return values, err
}
values[i] = value
}
return values, nil
}
// Float64 converts the given string representation into representation of a floating point number into float64.
func Float64(val string) (float64, error) {
return strconv.ParseFloat(val, 64)
}
// Float64Slice converts 'val' where individual floating point numbers are separated by
// 'sep' into a float64 slice.
func Float64Slice(val, sep string) ([]float64, error) {
s := strings.Split(val, sep)
values := make([]float64, len(s))
for i, v := range s {
value, err := Float64(v)
if err != nil {
return values, err
}
values[i] = value
}
return values, nil
}
// Float32 converts the given string representation of a floating point number into float32.
func Float32(val string) (float32, error) {
f, err := strconv.ParseFloat(val, 32)
if err != nil {
return 0, err
}
return float32(f), nil
}
// Float32Slice converts 'val' where individual floating point numbers are separated by
// 'sep' into a float32 slice.
func Float32Slice(val, sep string) ([]float32, error) {
s := strings.Split(val, sep)
values := make([]float32, len(s))
for i, v := range s {
value, err := Float32(v)
if err != nil {
return values, err
}
values[i] = value
}
return values, nil
}
// Int64 converts the given string representation of an integer into int64.
func Int64(val string) (int64, error) {
return strconv.ParseInt(val, 0, 64)
}
// Int64Slice converts 'val' where individual integers are separated by
// 'sep' into a int64 slice.
func Int64Slice(val, sep string) ([]int64, error) {
s := strings.Split(val, sep)
values := make([]int64, len(s))
for i, v := range s {
value, err := Int64(v)
if err != nil {
return values, err
}
values[i] = value
}
return values, nil
}
// Int32 converts the given string representation of an integer into int32.
func Int32(val string) (int32, error) {
i, err := strconv.ParseInt(val, 0, 32)
if err != nil {
return 0, err
}
return int32(i), nil
}
// Int32Slice converts 'val' where individual integers are separated by
// 'sep' into a int32 slice.
func Int32Slice(val, sep string) ([]int32, error) {
s := strings.Split(val, sep)
values := make([]int32, len(s))
for i, v := range s {
value, err := Int32(v)
if err != nil {
return values, err
}
values[i] = value
}
return values, nil
}
// Uint64 converts the given string representation of an integer into uint64.
func Uint64(val string) (uint64, error) {
return strconv.ParseUint(val, 0, 64)
}
// Uint64Slice converts 'val' where individual integers are separated by
// 'sep' into a uint64 slice.
func Uint64Slice(val, sep string) ([]uint64, error) {
s := strings.Split(val, sep)
values := make([]uint64, len(s))
for i, v := range s {
value, err := Uint64(v)
if err != nil {
return values, err
}
values[i] = value
}
return values, nil
}
// Uint32 converts the given string representation of an integer into uint32.
func Uint32(val string) (uint32, error) {
i, err := strconv.ParseUint(val, 0, 32)
if err != nil {
return 0, err
}
return uint32(i), nil
}
// Uint32Slice converts 'val' where individual integers are separated by
// 'sep' into a uint32 slice.
func Uint32Slice(val, sep string) ([]uint32, error) {
s := strings.Split(val, sep)
values := make([]uint32, len(s))
for i, v := range s {
value, err := Uint32(v)
if err != nil {
return values, err
}
values[i] = value
}
return values, nil
}
// Bytes converts the given string representation of a byte sequence into a slice of bytes
// A bytes sequence is encoded in URL-safe base64 without padding
func Bytes(val string) ([]byte, error) {
b, err := base64.StdEncoding.DecodeString(val)
if err != nil {
b, err = base64.URLEncoding.DecodeString(val)
if err != nil {
return nil, err
}
}
return b, nil
}
// BytesSlice converts 'val' where individual bytes sequences, encoded in URL-safe
// base64 without padding, are separated by 'sep' into a slice of bytes slices slice.
func BytesSlice(val, sep string) ([][]byte, error) {
s := strings.Split(val, sep)
values := make([][]byte, len(s))
for i, v := range s {
value, err := Bytes(v)
if err != nil {
return values, err
}
values[i] = value
}
return values, nil
}
// Timestamp converts the given RFC3339 formatted string into a timestamp.Timestamp.
func Timestamp(val string) (*timestamppb.Timestamp, error) {
var r timestamppb.Timestamp
val = strconv.Quote(strings.Trim(val, `"`))
unmarshaler := &protojson.UnmarshalOptions{}
err := unmarshaler.Unmarshal([]byte(val), &r)
if err != nil {
return nil, err
}
return &r, nil
}
// Duration converts the given string into a timestamp.Duration.
func Duration(val string) (*durationpb.Duration, error) {
var r durationpb.Duration
val = strconv.Quote(strings.Trim(val, `"`))
unmarshaler := &protojson.UnmarshalOptions{}
err := unmarshaler.Unmarshal([]byte(val), &r)
if err != nil {
return nil, err
}
return &r, nil
}
// Enum converts the given string into an int32 that should be type casted into the
// correct enum proto type.
func Enum(val string, enumValMap map[string]int32) (int32, error) {
e, ok := enumValMap[val]
if ok {
return e, nil
}
i, err := Int32(val)
if err != nil {
return 0, fmt.Errorf("%s is not valid", val)
}
for _, v := range enumValMap {
if v == i {
return i, nil
}
}
return 0, fmt.Errorf("%s is not valid", val)
}
// EnumSlice converts 'val' where individual enums are separated by 'sep'
// into a int32 slice. Each individual int32 should be type casted into the
// correct enum proto type.
func EnumSlice(val, sep string, enumValMap map[string]int32) ([]int32, error) {
s := strings.Split(val, sep)
values := make([]int32, len(s))
for i, v := range s {
value, err := Enum(v, enumValMap)
if err != nil {
return values, err
}
values[i] = value
}
return values, nil
}
/*
Support fot google.protobuf.wrappers on top of primitive types
*/
// StringValue well-known type support as wrapper around string type
func StringValue(val string) (*wrapperspb.StringValue, error) {
return &wrapperspb.StringValue{Value: val}, nil
}
// FloatValue well-known type support as wrapper around float32 type
func FloatValue(val string) (*wrapperspb.FloatValue, error) {
parsedVal, err := Float32(val)
return &wrapperspb.FloatValue{Value: parsedVal}, err
}
// DoubleValue well-known type support as wrapper around float64 type
func DoubleValue(val string) (*wrapperspb.DoubleValue, error) {
parsedVal, err := Float64(val)
return &wrapperspb.DoubleValue{Value: parsedVal}, err
}
// BoolValue well-known type support as wrapper around bool type
func BoolValue(val string) (*wrapperspb.BoolValue, error) {
parsedVal, err := Bool(val)
return &wrapperspb.BoolValue{Value: parsedVal}, err
}
// Int32Value well-known type support as wrapper around int32 type
func Int32Value(val string) (*wrapperspb.Int32Value, error) {
parsedVal, err := Int32(val)
return &wrapperspb.Int32Value{Value: parsedVal}, err
}
// UInt32Value well-known type support as wrapper around uint32 type
func UInt32Value(val string) (*wrapperspb.UInt32Value, error) {
parsedVal, err := Uint32(val)
return &wrapperspb.UInt32Value{Value: parsedVal}, err
}
// Int64Value well-known type support as wrapper around int64 type
func Int64Value(val string) (*wrapperspb.Int64Value, error) {
parsedVal, err := Int64(val)
return &wrapperspb.Int64Value{Value: parsedVal}, err
}
// UInt64Value well-known type support as wrapper around uint64 type
func UInt64Value(val string) (*wrapperspb.UInt64Value, error) {
parsedVal, err := Uint64(val)
return &wrapperspb.UInt64Value{Value: parsedVal}, err
}
// BytesValue well-known type support as wrapper around bytes[] type
func BytesValue(val string) (*wrapperspb.BytesValue, error) {
parsedVal, err := Bytes(val)
return &wrapperspb.BytesValue{Value: parsedVal}, err
}

View File

@@ -0,0 +1,5 @@
/*
Package runtime contains runtime helper functions used by
servers which protoc-gen-grpc-gateway generates.
*/
package runtime

View File

@@ -0,0 +1,180 @@
package runtime
import (
"context"
"errors"
"io"
"net/http"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/status"
)
// ErrorHandlerFunc is the signature used to configure error handling.
type ErrorHandlerFunc func(context.Context, *ServeMux, Marshaler, http.ResponseWriter, *http.Request, error)
// StreamErrorHandlerFunc is the signature used to configure stream error handling.
type StreamErrorHandlerFunc func(context.Context, error) *status.Status
// RoutingErrorHandlerFunc is the signature used to configure error handling for routing errors.
type RoutingErrorHandlerFunc func(context.Context, *ServeMux, Marshaler, http.ResponseWriter, *http.Request, int)
// HTTPStatusError is the error to use when needing to provide a different HTTP status code for an error
// passed to the DefaultRoutingErrorHandler.
type HTTPStatusError struct {
HTTPStatus int
Err error
}
func (e *HTTPStatusError) Error() string {
return e.Err.Error()
}
// HTTPStatusFromCode converts a gRPC error code into the corresponding HTTP response status.
// See: https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto
func HTTPStatusFromCode(code codes.Code) int {
switch code {
case codes.OK:
return http.StatusOK
case codes.Canceled:
return http.StatusRequestTimeout
case codes.Unknown:
return http.StatusInternalServerError
case codes.InvalidArgument:
return http.StatusBadRequest
case codes.DeadlineExceeded:
return http.StatusGatewayTimeout
case codes.NotFound:
return http.StatusNotFound
case codes.AlreadyExists:
return http.StatusConflict
case codes.PermissionDenied:
return http.StatusForbidden
case codes.Unauthenticated:
return http.StatusUnauthorized
case codes.ResourceExhausted:
return http.StatusTooManyRequests
case codes.FailedPrecondition:
// Note, this deliberately doesn't translate to the similarly named '412 Precondition Failed' HTTP response status.
return http.StatusBadRequest
case codes.Aborted:
return http.StatusConflict
case codes.OutOfRange:
return http.StatusBadRequest
case codes.Unimplemented:
return http.StatusNotImplemented
case codes.Internal:
return http.StatusInternalServerError
case codes.Unavailable:
return http.StatusServiceUnavailable
case codes.DataLoss:
return http.StatusInternalServerError
}
grpclog.Infof("Unknown gRPC error code: %v", code)
return http.StatusInternalServerError
}
// HTTPError uses the mux-configured error handler.
func HTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, err error) {
mux.errorHandler(ctx, mux, marshaler, w, r, err)
}
// DefaultHTTPErrorHandler is the default error handler.
// If "err" is a gRPC Status, the function replies with the status code mapped by HTTPStatusFromCode.
// If "err" is a HTTPStatusError, the function replies with the status code provide by that struct. This is
// intended to allow passing through of specific statuses via the function set via WithRoutingErrorHandler
// for the ServeMux constructor to handle edge cases which the standard mappings in HTTPStatusFromCode
// are insufficient for.
// If otherwise, it replies with http.StatusInternalServerError.
//
// The response body written by this function is a Status message marshaled by the Marshaler.
func DefaultHTTPErrorHandler(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, err error) {
// return Internal when Marshal failed
const fallback = `{"code": 13, "message": "failed to marshal error message"}`
var customStatus *HTTPStatusError
if errors.As(err, &customStatus) {
err = customStatus.Err
}
s := status.Convert(err)
pb := s.Proto()
w.Header().Del("Trailer")
w.Header().Del("Transfer-Encoding")
contentType := marshaler.ContentType(pb)
w.Header().Set("Content-Type", contentType)
if s.Code() == codes.Unauthenticated {
w.Header().Set("WWW-Authenticate", s.Message())
}
buf, merr := marshaler.Marshal(pb)
if merr != nil {
grpclog.Infof("Failed to marshal error message %q: %v", s, merr)
w.WriteHeader(http.StatusInternalServerError)
if _, err := io.WriteString(w, fallback); err != nil {
grpclog.Infof("Failed to write response: %v", err)
}
return
}
md, ok := ServerMetadataFromContext(ctx)
if !ok {
grpclog.Infof("Failed to extract ServerMetadata from context")
}
handleForwardResponseServerMetadata(w, mux, md)
// RFC 7230 https://tools.ietf.org/html/rfc7230#section-4.1.2
// Unless the request includes a TE header field indicating "trailers"
// is acceptable, as described in Section 4.3, a server SHOULD NOT
// generate trailer fields that it believes are necessary for the user
// agent to receive.
doForwardTrailers := requestAcceptsTrailers(r)
if doForwardTrailers {
handleForwardResponseTrailerHeader(w, md)
w.Header().Set("Transfer-Encoding", "chunked")
}
st := HTTPStatusFromCode(s.Code())
if customStatus != nil {
st = customStatus.HTTPStatus
}
w.WriteHeader(st)
if _, err := w.Write(buf); err != nil {
grpclog.Infof("Failed to write response: %v", err)
}
if doForwardTrailers {
handleForwardResponseTrailer(w, md)
}
}
func DefaultStreamErrorHandler(_ context.Context, err error) *status.Status {
return status.Convert(err)
}
// DefaultRoutingErrorHandler is our default handler for routing errors.
// By default http error codes mapped on the following error codes:
// NotFound -> grpc.NotFound
// StatusBadRequest -> grpc.InvalidArgument
// MethodNotAllowed -> grpc.Unimplemented
// Other -> grpc.Internal, method is not expecting to be called for anything else
func DefaultRoutingErrorHandler(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, httpStatus int) {
sterr := status.Error(codes.Internal, "Unexpected routing error")
switch httpStatus {
case http.StatusBadRequest:
sterr = status.Error(codes.InvalidArgument, http.StatusText(httpStatus))
case http.StatusMethodNotAllowed:
sterr = status.Error(codes.Unimplemented, http.StatusText(httpStatus))
case http.StatusNotFound:
sterr = status.Error(codes.NotFound, http.StatusText(httpStatus))
}
mux.errorHandler(ctx, mux, marshaler, w, r, sterr)
}

View File

@@ -0,0 +1,165 @@
package runtime
import (
"encoding/json"
"fmt"
"io"
"sort"
"google.golang.org/genproto/protobuf/field_mask"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
)
func getFieldByName(fields protoreflect.FieldDescriptors, name string) protoreflect.FieldDescriptor {
fd := fields.ByName(protoreflect.Name(name))
if fd != nil {
return fd
}
return fields.ByJSONName(name)
}
// FieldMaskFromRequestBody creates a FieldMask printing all complete paths from the JSON body.
func FieldMaskFromRequestBody(r io.Reader, msg proto.Message) (*field_mask.FieldMask, error) {
fm := &field_mask.FieldMask{}
var root interface{}
if err := json.NewDecoder(r).Decode(&root); err != nil {
if err == io.EOF {
return fm, nil
}
return nil, err
}
queue := []fieldMaskPathItem{{node: root, msg: msg.ProtoReflect()}}
for len(queue) > 0 {
// dequeue an item
item := queue[0]
queue = queue[1:]
m, ok := item.node.(map[string]interface{})
switch {
case ok:
// if the item is an object, then enqueue all of its children
for k, v := range m {
if item.msg == nil {
return nil, fmt.Errorf("JSON structure did not match request type")
}
fd := getFieldByName(item.msg.Descriptor().Fields(), k)
if fd == nil {
return nil, fmt.Errorf("could not find field %q in %q", k, item.msg.Descriptor().FullName())
}
if isDynamicProtoMessage(fd.Message()) {
for _, p := range buildPathsBlindly(k, v) {
newPath := p
if item.path != "" {
newPath = item.path + "." + newPath
}
queue = append(queue, fieldMaskPathItem{path: newPath})
}
continue
}
if isProtobufAnyMessage(fd.Message()) {
_, hasTypeField := v.(map[string]interface{})["@type"]
if hasTypeField {
queue = append(queue, fieldMaskPathItem{path: k})
continue
} else {
return nil, fmt.Errorf("could not find field @type in %q in message %q", k, item.msg.Descriptor().FullName())
}
}
child := fieldMaskPathItem{
node: v,
}
if item.path == "" {
child.path = string(fd.FullName().Name())
} else {
child.path = item.path + "." + string(fd.FullName().Name())
}
switch {
case fd.IsList(), fd.IsMap():
// As per: https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/field_mask.proto#L85-L86
// Do not recurse into repeated fields. The repeated field goes on the end of the path and we stop.
fm.Paths = append(fm.Paths, child.path)
case fd.Message() != nil:
child.msg = item.msg.Get(fd).Message()
fallthrough
default:
queue = append(queue, child)
}
}
case len(item.path) > 0:
// otherwise, it's a leaf node so print its path
fm.Paths = append(fm.Paths, item.path)
}
}
// Sort for deterministic output in the presence
// of repeated fields.
sort.Strings(fm.Paths)
return fm, nil
}
func isProtobufAnyMessage(md protoreflect.MessageDescriptor) bool {
return md != nil && (md.FullName() == "google.protobuf.Any")
}
func isDynamicProtoMessage(md protoreflect.MessageDescriptor) bool {
return md != nil && (md.FullName() == "google.protobuf.Struct" || md.FullName() == "google.protobuf.Value")
}
// buildPathsBlindly does not attempt to match proto field names to the
// json value keys. Instead it relies completely on the structure of
// the unmarshalled json contained within in.
// Returns a slice containing all subpaths with the root at the
// passed in name and json value.
func buildPathsBlindly(name string, in interface{}) []string {
m, ok := in.(map[string]interface{})
if !ok {
return []string{name}
}
var paths []string
queue := []fieldMaskPathItem{{path: name, node: m}}
for len(queue) > 0 {
cur := queue[0]
queue = queue[1:]
m, ok := cur.node.(map[string]interface{})
if !ok {
// This should never happen since we should always check that we only add
// nodes of type map[string]interface{} to the queue.
continue
}
for k, v := range m {
if mi, ok := v.(map[string]interface{}); ok {
queue = append(queue, fieldMaskPathItem{path: cur.path + "." + k, node: mi})
} else {
// This is not a struct, so there are no more levels to descend.
curPath := cur.path + "." + k
paths = append(paths, curPath)
}
}
}
return paths
}
// fieldMaskPathItem stores a in-progress deconstruction of a path for a fieldmask
type fieldMaskPathItem struct {
// the list of prior fields leading up to node connected by dots
path string
// a generic decoded json object the current item to inspect for further path extraction
node interface{}
// parent message
msg protoreflect.Message
}

View File

@@ -0,0 +1,223 @@
package runtime
import (
"context"
"fmt"
"io"
"net/http"
"net/textproto"
"strings"
"google.golang.org/genproto/googleapis/api/httpbody"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)
// ForwardResponseStream forwards the stream from gRPC server to REST client.
func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
f, ok := w.(http.Flusher)
if !ok {
grpclog.Infof("Flush not supported in %T", w)
http.Error(w, "unexpected type of web server", http.StatusInternalServerError)
return
}
md, ok := ServerMetadataFromContext(ctx)
if !ok {
grpclog.Infof("Failed to extract ServerMetadata from context")
http.Error(w, "unexpected error", http.StatusInternalServerError)
return
}
handleForwardResponseServerMetadata(w, mux, md)
w.Header().Set("Transfer-Encoding", "chunked")
if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil {
HTTPError(ctx, mux, marshaler, w, req, err)
return
}
var delimiter []byte
if d, ok := marshaler.(Delimited); ok {
delimiter = d.Delimiter()
} else {
delimiter = []byte("\n")
}
var wroteHeader bool
for {
resp, err := recv()
if err == io.EOF {
return
}
if err != nil {
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
return
}
if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
return
}
if !wroteHeader {
w.Header().Set("Content-Type", marshaler.ContentType(resp))
}
var buf []byte
httpBody, isHTTPBody := resp.(*httpbody.HttpBody)
switch {
case resp == nil:
buf, err = marshaler.Marshal(errorChunk(status.New(codes.Internal, "empty response")))
case isHTTPBody:
buf = httpBody.GetData()
default:
result := map[string]interface{}{"result": resp}
if rb, ok := resp.(responseBody); ok {
result["result"] = rb.XXX_ResponseBody()
}
buf, err = marshaler.Marshal(result)
}
if err != nil {
grpclog.Infof("Failed to marshal response chunk: %v", err)
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
return
}
if _, err = w.Write(buf); err != nil {
grpclog.Infof("Failed to send response chunk: %v", err)
return
}
wroteHeader = true
if _, err = w.Write(delimiter); err != nil {
grpclog.Infof("Failed to send delimiter chunk: %v", err)
return
}
f.Flush()
}
}
func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
for k, vs := range md.HeaderMD {
if h, ok := mux.outgoingHeaderMatcher(k); ok {
for _, v := range vs {
w.Header().Add(h, v)
}
}
}
}
func handleForwardResponseTrailerHeader(w http.ResponseWriter, md ServerMetadata) {
for k := range md.TrailerMD {
tKey := textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", MetadataTrailerPrefix, k))
w.Header().Add("Trailer", tKey)
}
}
func handleForwardResponseTrailer(w http.ResponseWriter, md ServerMetadata) {
for k, vs := range md.TrailerMD {
tKey := fmt.Sprintf("%s%s", MetadataTrailerPrefix, k)
for _, v := range vs {
w.Header().Add(tKey, v)
}
}
}
// responseBody interface contains method for getting field for marshaling to the response body
// this method is generated for response struct from the value of `response_body` in the `google.api.HttpRule`
type responseBody interface {
XXX_ResponseBody() interface{}
}
// ForwardResponseMessage forwards the message "resp" from gRPC server to REST client.
func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
md, ok := ServerMetadataFromContext(ctx)
if !ok {
grpclog.Infof("Failed to extract ServerMetadata from context")
}
handleForwardResponseServerMetadata(w, mux, md)
// RFC 7230 https://tools.ietf.org/html/rfc7230#section-4.1.2
// Unless the request includes a TE header field indicating "trailers"
// is acceptable, as described in Section 4.3, a server SHOULD NOT
// generate trailer fields that it believes are necessary for the user
// agent to receive.
doForwardTrailers := requestAcceptsTrailers(req)
if doForwardTrailers {
handleForwardResponseTrailerHeader(w, md)
w.Header().Set("Transfer-Encoding", "chunked")
}
handleForwardResponseTrailerHeader(w, md)
contentType := marshaler.ContentType(resp)
w.Header().Set("Content-Type", contentType)
if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
HTTPError(ctx, mux, marshaler, w, req, err)
return
}
var buf []byte
var err error
if rb, ok := resp.(responseBody); ok {
buf, err = marshaler.Marshal(rb.XXX_ResponseBody())
} else {
buf, err = marshaler.Marshal(resp)
}
if err != nil {
grpclog.Infof("Marshal error: %v", err)
HTTPError(ctx, mux, marshaler, w, req, err)
return
}
if _, err = w.Write(buf); err != nil {
grpclog.Infof("Failed to write response: %v", err)
}
if doForwardTrailers {
handleForwardResponseTrailer(w, md)
}
}
func requestAcceptsTrailers(req *http.Request) bool {
te := req.Header.Get("TE")
return strings.Contains(strings.ToLower(te), "trailers")
}
func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error {
if len(opts) == 0 {
return nil
}
for _, opt := range opts {
if err := opt(ctx, w, resp); err != nil {
grpclog.Infof("Error handling ForwardResponseOptions: %v", err)
return err
}
}
return nil
}
func handleForwardResponseStreamError(ctx context.Context, wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, req *http.Request, mux *ServeMux, err error) {
st := mux.streamErrorHandler(ctx, err)
msg := errorChunk(st)
if !wroteHeader {
w.Header().Set("Content-Type", marshaler.ContentType(msg))
w.WriteHeader(HTTPStatusFromCode(st.Code()))
}
buf, merr := marshaler.Marshal(msg)
if merr != nil {
grpclog.Infof("Failed to marshal an error: %v", merr)
return
}
if _, werr := w.Write(buf); werr != nil {
grpclog.Infof("Failed to notify error to client: %v", werr)
return
}
}
func errorChunk(st *status.Status) map[string]proto.Message {
return map[string]proto.Message{"error": st.Proto()}
}

View File

@@ -0,0 +1,32 @@
package runtime
import (
"google.golang.org/genproto/googleapis/api/httpbody"
)
// HTTPBodyMarshaler is a Marshaler which supports marshaling of a
// google.api.HttpBody message as the full response body if it is
// the actual message used as the response. If not, then this will
// simply fallback to the Marshaler specified as its default Marshaler.
type HTTPBodyMarshaler struct {
Marshaler
}
// ContentType returns its specified content type in case v is a
// google.api.HttpBody message, otherwise it will fall back to the default Marshalers
// content type.
func (h *HTTPBodyMarshaler) ContentType(v interface{}) string {
if httpBody, ok := v.(*httpbody.HttpBody); ok {
return httpBody.GetContentType()
}
return h.Marshaler.ContentType(v)
}
// Marshal marshals "v" by returning the body bytes if v is a
// google.api.HttpBody message, otherwise it falls back to the default Marshaler.
func (h *HTTPBodyMarshaler) Marshal(v interface{}) ([]byte, error) {
if httpBody, ok := v.(*httpbody.HttpBody); ok {
return httpBody.Data, nil
}
return h.Marshaler.Marshal(v)
}

View File

@@ -0,0 +1,45 @@
package runtime
import (
"encoding/json"
"io"
)
// JSONBuiltin is a Marshaler which marshals/unmarshals into/from JSON
// with the standard "encoding/json" package of Golang.
// Although it is generally faster for simple proto messages than JSONPb,
// it does not support advanced features of protobuf, e.g. map, oneof, ....
//
// The NewEncoder and NewDecoder types return *json.Encoder and
// *json.Decoder respectively.
type JSONBuiltin struct{}
// ContentType always Returns "application/json".
func (*JSONBuiltin) ContentType(_ interface{}) string {
return "application/json"
}
// Marshal marshals "v" into JSON
func (j *JSONBuiltin) Marshal(v interface{}) ([]byte, error) {
return json.Marshal(v)
}
// Unmarshal unmarshals JSON data into "v".
func (j *JSONBuiltin) Unmarshal(data []byte, v interface{}) error {
return json.Unmarshal(data, v)
}
// NewDecoder returns a Decoder which reads JSON stream from "r".
func (j *JSONBuiltin) NewDecoder(r io.Reader) Decoder {
return json.NewDecoder(r)
}
// NewEncoder returns an Encoder which writes JSON stream into "w".
func (j *JSONBuiltin) NewEncoder(w io.Writer) Encoder {
return json.NewEncoder(w)
}
// Delimiter for newline encoded JSON streams.
func (j *JSONBuiltin) Delimiter() []byte {
return []byte("\n")
}

View File

@@ -0,0 +1,344 @@
package runtime
import (
"bytes"
"encoding/json"
"fmt"
"io"
"reflect"
"strconv"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
)
// JSONPb is a Marshaler which marshals/unmarshals into/from JSON
// with the "google.golang.org/protobuf/encoding/protojson" marshaler.
// It supports the full functionality of protobuf unlike JSONBuiltin.
//
// The NewDecoder method returns a DecoderWrapper, so the underlying
// *json.Decoder methods can be used.
type JSONPb struct {
protojson.MarshalOptions
protojson.UnmarshalOptions
}
// ContentType always returns "application/json".
func (*JSONPb) ContentType(_ interface{}) string {
return "application/json"
}
// Marshal marshals "v" into JSON.
func (j *JSONPb) Marshal(v interface{}) ([]byte, error) {
if _, ok := v.(proto.Message); !ok {
return j.marshalNonProtoField(v)
}
var buf bytes.Buffer
if err := j.marshalTo(&buf, v); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func (j *JSONPb) marshalTo(w io.Writer, v interface{}) error {
p, ok := v.(proto.Message)
if !ok {
buf, err := j.marshalNonProtoField(v)
if err != nil {
return err
}
_, err = w.Write(buf)
return err
}
b, err := j.MarshalOptions.Marshal(p)
if err != nil {
return err
}
_, err = w.Write(b)
return err
}
var (
// protoMessageType is stored to prevent constant lookup of the same type at runtime.
protoMessageType = reflect.TypeOf((*proto.Message)(nil)).Elem()
)
// marshalNonProto marshals a non-message field of a protobuf message.
// This function does not correctly marshal arbitrary data structures into JSON,
// it is only capable of marshaling non-message field values of protobuf,
// i.e. primitive types, enums; pointers to primitives or enums; maps from
// integer/string types to primitives/enums/pointers to messages.
func (j *JSONPb) marshalNonProtoField(v interface{}) ([]byte, error) {
if v == nil {
return []byte("null"), nil
}
rv := reflect.ValueOf(v)
for rv.Kind() == reflect.Ptr {
if rv.IsNil() {
return []byte("null"), nil
}
rv = rv.Elem()
}
if rv.Kind() == reflect.Slice {
if rv.IsNil() {
if j.EmitUnpopulated {
return []byte("[]"), nil
}
return []byte("null"), nil
}
if rv.Type().Elem().Implements(protoMessageType) {
var buf bytes.Buffer
err := buf.WriteByte('[')
if err != nil {
return nil, err
}
for i := 0; i < rv.Len(); i++ {
if i != 0 {
err = buf.WriteByte(',')
if err != nil {
return nil, err
}
}
if err = j.marshalTo(&buf, rv.Index(i).Interface().(proto.Message)); err != nil {
return nil, err
}
}
err = buf.WriteByte(']')
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
if rv.Type().Elem().Implements(typeProtoEnum) {
var buf bytes.Buffer
err := buf.WriteByte('[')
if err != nil {
return nil, err
}
for i := 0; i < rv.Len(); i++ {
if i != 0 {
err = buf.WriteByte(',')
if err != nil {
return nil, err
}
}
if j.UseEnumNumbers {
_, err = buf.WriteString(strconv.FormatInt(rv.Index(i).Int(), 10))
} else {
_, err = buf.WriteString("\"" + rv.Index(i).Interface().(protoEnum).String() + "\"")
}
if err != nil {
return nil, err
}
}
err = buf.WriteByte(']')
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
}
if rv.Kind() == reflect.Map {
m := make(map[string]*json.RawMessage)
for _, k := range rv.MapKeys() {
buf, err := j.Marshal(rv.MapIndex(k).Interface())
if err != nil {
return nil, err
}
m[fmt.Sprintf("%v", k.Interface())] = (*json.RawMessage)(&buf)
}
if j.Indent != "" {
return json.MarshalIndent(m, "", j.Indent)
}
return json.Marshal(m)
}
if enum, ok := rv.Interface().(protoEnum); ok && !j.UseEnumNumbers {
return json.Marshal(enum.String())
}
return json.Marshal(rv.Interface())
}
// Unmarshal unmarshals JSON "data" into "v"
func (j *JSONPb) Unmarshal(data []byte, v interface{}) error {
return unmarshalJSONPb(data, j.UnmarshalOptions, v)
}
// NewDecoder returns a Decoder which reads JSON stream from "r".
func (j *JSONPb) NewDecoder(r io.Reader) Decoder {
d := json.NewDecoder(r)
return DecoderWrapper{
Decoder: d,
UnmarshalOptions: j.UnmarshalOptions,
}
}
// DecoderWrapper is a wrapper around a *json.Decoder that adds
// support for protos to the Decode method.
type DecoderWrapper struct {
*json.Decoder
protojson.UnmarshalOptions
}
// Decode wraps the embedded decoder's Decode method to support
// protos using a jsonpb.Unmarshaler.
func (d DecoderWrapper) Decode(v interface{}) error {
return decodeJSONPb(d.Decoder, d.UnmarshalOptions, v)
}
// NewEncoder returns an Encoder which writes JSON stream into "w".
func (j *JSONPb) NewEncoder(w io.Writer) Encoder {
return EncoderFunc(func(v interface{}) error {
if err := j.marshalTo(w, v); err != nil {
return err
}
// mimic json.Encoder by adding a newline (makes output
// easier to read when it contains multiple encoded items)
_, err := w.Write(j.Delimiter())
return err
})
}
func unmarshalJSONPb(data []byte, unmarshaler protojson.UnmarshalOptions, v interface{}) error {
d := json.NewDecoder(bytes.NewReader(data))
return decodeJSONPb(d, unmarshaler, v)
}
func decodeJSONPb(d *json.Decoder, unmarshaler protojson.UnmarshalOptions, v interface{}) error {
p, ok := v.(proto.Message)
if !ok {
return decodeNonProtoField(d, unmarshaler, v)
}
// Decode into bytes for marshalling
var b json.RawMessage
err := d.Decode(&b)
if err != nil {
return err
}
return unmarshaler.Unmarshal([]byte(b), p)
}
func decodeNonProtoField(d *json.Decoder, unmarshaler protojson.UnmarshalOptions, v interface{}) error {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr {
return fmt.Errorf("%T is not a pointer", v)
}
for rv.Kind() == reflect.Ptr {
if rv.IsNil() {
rv.Set(reflect.New(rv.Type().Elem()))
}
if rv.Type().ConvertibleTo(typeProtoMessage) {
// Decode into bytes for marshalling
var b json.RawMessage
err := d.Decode(&b)
if err != nil {
return err
}
return unmarshaler.Unmarshal([]byte(b), rv.Interface().(proto.Message))
}
rv = rv.Elem()
}
if rv.Kind() == reflect.Map {
if rv.IsNil() {
rv.Set(reflect.MakeMap(rv.Type()))
}
conv, ok := convFromType[rv.Type().Key().Kind()]
if !ok {
return fmt.Errorf("unsupported type of map field key: %v", rv.Type().Key())
}
m := make(map[string]*json.RawMessage)
if err := d.Decode(&m); err != nil {
return err
}
for k, v := range m {
result := conv.Call([]reflect.Value{reflect.ValueOf(k)})
if err := result[1].Interface(); err != nil {
return err.(error)
}
bk := result[0]
bv := reflect.New(rv.Type().Elem())
if v == nil {
null := json.RawMessage("null")
v = &null
}
if err := unmarshalJSONPb([]byte(*v), unmarshaler, bv.Interface()); err != nil {
return err
}
rv.SetMapIndex(bk, bv.Elem())
}
return nil
}
if rv.Kind() == reflect.Slice {
var sl []json.RawMessage
if err := d.Decode(&sl); err != nil {
return err
}
if sl != nil {
rv.Set(reflect.MakeSlice(rv.Type(), 0, 0))
}
for _, item := range sl {
bv := reflect.New(rv.Type().Elem())
if err := unmarshalJSONPb([]byte(item), unmarshaler, bv.Interface()); err != nil {
return err
}
rv.Set(reflect.Append(rv, bv.Elem()))
}
return nil
}
if _, ok := rv.Interface().(protoEnum); ok {
var repr interface{}
if err := d.Decode(&repr); err != nil {
return err
}
switch v := repr.(type) {
case string:
// TODO(yugui) Should use proto.StructProperties?
return fmt.Errorf("unmarshaling of symbolic enum %q not supported: %T", repr, rv.Interface())
case float64:
rv.Set(reflect.ValueOf(int32(v)).Convert(rv.Type()))
return nil
default:
return fmt.Errorf("cannot assign %#v into Go type %T", repr, rv.Interface())
}
}
return d.Decode(v)
}
type protoEnum interface {
fmt.Stringer
EnumDescriptor() ([]byte, []int)
}
var typeProtoEnum = reflect.TypeOf((*protoEnum)(nil)).Elem()
var typeProtoMessage = reflect.TypeOf((*proto.Message)(nil)).Elem()
// Delimiter for newline encoded JSON streams.
func (j *JSONPb) Delimiter() []byte {
return []byte("\n")
}
var (
convFromType = map[reflect.Kind]reflect.Value{
reflect.String: reflect.ValueOf(String),
reflect.Bool: reflect.ValueOf(Bool),
reflect.Float64: reflect.ValueOf(Float64),
reflect.Float32: reflect.ValueOf(Float32),
reflect.Int64: reflect.ValueOf(Int64),
reflect.Int32: reflect.ValueOf(Int32),
reflect.Uint64: reflect.ValueOf(Uint64),
reflect.Uint32: reflect.ValueOf(Uint32),
reflect.Slice: reflect.ValueOf(Bytes),
}
)

View File

@@ -0,0 +1,63 @@
package runtime
import (
"io"
"errors"
"io/ioutil"
"google.golang.org/protobuf/proto"
)
// ProtoMarshaller is a Marshaller which marshals/unmarshals into/from serialize proto bytes
type ProtoMarshaller struct{}
// ContentType always returns "application/octet-stream".
func (*ProtoMarshaller) ContentType(_ interface{}) string {
return "application/octet-stream"
}
// Marshal marshals "value" into Proto
func (*ProtoMarshaller) Marshal(value interface{}) ([]byte, error) {
message, ok := value.(proto.Message)
if !ok {
return nil, errors.New("unable to marshal non proto field")
}
return proto.Marshal(message)
}
// Unmarshal unmarshals proto "data" into "value"
func (*ProtoMarshaller) Unmarshal(data []byte, value interface{}) error {
message, ok := value.(proto.Message)
if !ok {
return errors.New("unable to unmarshal non proto field")
}
return proto.Unmarshal(data, message)
}
// NewDecoder returns a Decoder which reads proto stream from "reader".
func (marshaller *ProtoMarshaller) NewDecoder(reader io.Reader) Decoder {
return DecoderFunc(func(value interface{}) error {
buffer, err := ioutil.ReadAll(reader)
if err != nil {
return err
}
return marshaller.Unmarshal(buffer, value)
})
}
// NewEncoder returns an Encoder which writes proto stream into "writer".
func (marshaller *ProtoMarshaller) NewEncoder(writer io.Writer) Encoder {
return EncoderFunc(func(value interface{}) error {
buffer, err := marshaller.Marshal(value)
if err != nil {
return err
}
_, err = writer.Write(buffer)
if err != nil {
return err
}
return nil
})
}

View File

@@ -0,0 +1,50 @@
package runtime
import (
"io"
)
// Marshaler defines a conversion between byte sequence and gRPC payloads / fields.
type Marshaler interface {
// Marshal marshals "v" into byte sequence.
Marshal(v interface{}) ([]byte, error)
// Unmarshal unmarshals "data" into "v".
// "v" must be a pointer value.
Unmarshal(data []byte, v interface{}) error
// NewDecoder returns a Decoder which reads byte sequence from "r".
NewDecoder(r io.Reader) Decoder
// NewEncoder returns an Encoder which writes bytes sequence into "w".
NewEncoder(w io.Writer) Encoder
// ContentType returns the Content-Type which this marshaler is responsible for.
// The parameter describes the type which is being marshalled, which can sometimes
// affect the content type returned.
ContentType(v interface{}) string
}
// Decoder decodes a byte sequence
type Decoder interface {
Decode(v interface{}) error
}
// Encoder encodes gRPC payloads / fields into byte sequence.
type Encoder interface {
Encode(v interface{}) error
}
// DecoderFunc adapts an decoder function into Decoder.
type DecoderFunc func(v interface{}) error
// Decode delegates invocations to the underlying function itself.
func (f DecoderFunc) Decode(v interface{}) error { return f(v) }
// EncoderFunc adapts an encoder function into Encoder
type EncoderFunc func(v interface{}) error
// Encode delegates invocations to the underlying function itself.
func (f EncoderFunc) Encode(v interface{}) error { return f(v) }
// Delimited defines the streaming delimiter.
type Delimited interface {
// Delimiter returns the record separator for the stream.
Delimiter() []byte
}

View File

@@ -0,0 +1,109 @@
package runtime
import (
"errors"
"mime"
"net/http"
"google.golang.org/grpc/grpclog"
"google.golang.org/protobuf/encoding/protojson"
)
// MIMEWildcard is the fallback MIME type used for requests which do not match
// a registered MIME type.
const MIMEWildcard = "*"
var (
acceptHeader = http.CanonicalHeaderKey("Accept")
contentTypeHeader = http.CanonicalHeaderKey("Content-Type")
defaultMarshaler = &HTTPBodyMarshaler{
Marshaler: &JSONPb{
MarshalOptions: protojson.MarshalOptions{
EmitUnpopulated: true,
},
UnmarshalOptions: protojson.UnmarshalOptions{
DiscardUnknown: true,
},
},
}
)
// MarshalerForRequest returns the inbound/outbound marshalers for this request.
// It checks the registry on the ServeMux for the MIME type set by the Content-Type header.
// If it isn't set (or the request Content-Type is empty), checks for "*".
// If there are multiple Content-Type headers set, choose the first one that it can
// exactly match in the registry.
// Otherwise, it follows the above logic for "*"/InboundMarshaler/OutboundMarshaler.
func MarshalerForRequest(mux *ServeMux, r *http.Request) (inbound Marshaler, outbound Marshaler) {
for _, acceptVal := range r.Header[acceptHeader] {
if m, ok := mux.marshalers.mimeMap[acceptVal]; ok {
outbound = m
break
}
}
for _, contentTypeVal := range r.Header[contentTypeHeader] {
contentType, _, err := mime.ParseMediaType(contentTypeVal)
if err != nil {
grpclog.Infof("Failed to parse Content-Type %s: %v", contentTypeVal, err)
continue
}
if m, ok := mux.marshalers.mimeMap[contentType]; ok {
inbound = m
break
}
}
if inbound == nil {
inbound = mux.marshalers.mimeMap[MIMEWildcard]
}
if outbound == nil {
outbound = inbound
}
return inbound, outbound
}
// marshalerRegistry is a mapping from MIME types to Marshalers.
type marshalerRegistry struct {
mimeMap map[string]Marshaler
}
// add adds a marshaler for a case-sensitive MIME type string ("*" to match any
// MIME type).
func (m marshalerRegistry) add(mime string, marshaler Marshaler) error {
if len(mime) == 0 {
return errors.New("empty MIME type")
}
m.mimeMap[mime] = marshaler
return nil
}
// makeMarshalerMIMERegistry returns a new registry of marshalers.
// It allows for a mapping of case-sensitive Content-Type MIME type string to runtime.Marshaler interfaces.
//
// For example, you could allow the client to specify the use of the runtime.JSONPb marshaler
// with a "application/jsonpb" Content-Type and the use of the runtime.JSONBuiltin marshaler
// with a "application/json" Content-Type.
// "*" can be used to match any Content-Type.
// This can be attached to a ServerMux with the marshaler option.
func makeMarshalerMIMERegistry() marshalerRegistry {
return marshalerRegistry{
mimeMap: map[string]Marshaler{
MIMEWildcard: defaultMarshaler,
},
}
}
// WithMarshalerOption returns a ServeMuxOption which associates inbound and outbound
// Marshalers to a MIME type in mux.
func WithMarshalerOption(mime string, marshaler Marshaler) ServeMuxOption {
return func(mux *ServeMux) {
if err := mux.marshalers.add(mime, marshaler); err != nil {
panic(err)
}
}
}

View File

@@ -0,0 +1,356 @@
package runtime
import (
"context"
"errors"
"fmt"
"net/http"
"net/textproto"
"strings"
"github.com/grpc-ecosystem/grpc-gateway/v2/internal/httprule"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)
// UnescapingMode defines the behavior of ServeMux when unescaping path parameters.
type UnescapingMode int
const (
// UnescapingModeLegacy is the default V2 behavior, which escapes the entire
// path string before doing any routing.
UnescapingModeLegacy UnescapingMode = iota
// EscapingTypeExceptReserved unescapes all path parameters except RFC 6570
// reserved characters.
UnescapingModeAllExceptReserved
// EscapingTypeExceptSlash unescapes URL path parameters except path
// seperators, which will be left as "%2F".
UnescapingModeAllExceptSlash
// URL path parameters will be fully decoded.
UnescapingModeAllCharacters
// UnescapingModeDefault is the default escaping type.
// TODO(v3): default this to UnescapingModeAllExceptReserved per grpc-httpjson-transcoding's
// reference implementation
UnescapingModeDefault = UnescapingModeLegacy
)
// A HandlerFunc handles a specific pair of path pattern and HTTP method.
type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)
// ServeMux is a request multiplexer for grpc-gateway.
// It matches http requests to patterns and invokes the corresponding handler.
type ServeMux struct {
// handlers maps HTTP method to a list of handlers.
handlers map[string][]handler
forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
marshalers marshalerRegistry
incomingHeaderMatcher HeaderMatcherFunc
outgoingHeaderMatcher HeaderMatcherFunc
metadataAnnotators []func(context.Context, *http.Request) metadata.MD
errorHandler ErrorHandlerFunc
streamErrorHandler StreamErrorHandlerFunc
routingErrorHandler RoutingErrorHandlerFunc
disablePathLengthFallback bool
unescapingMode UnescapingMode
}
// ServeMuxOption is an option that can be given to a ServeMux on construction.
type ServeMuxOption func(*ServeMux)
// WithForwardResponseOption returns a ServeMuxOption representing the forwardResponseOption.
//
// forwardResponseOption is an option that will be called on the relevant context.Context,
// http.ResponseWriter, and proto.Message before every forwarded response.
//
// The message may be nil in the case where just a header is being sent.
func WithForwardResponseOption(forwardResponseOption func(context.Context, http.ResponseWriter, proto.Message) error) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.forwardResponseOptions = append(serveMux.forwardResponseOptions, forwardResponseOption)
}
}
// WithEscapingType sets the escaping type. See the definitions of UnescapingMode
// for more information.
func WithUnescapingMode(mode UnescapingMode) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.unescapingMode = mode
}
}
// SetQueryParameterParser sets the query parameter parser, used to populate message from query parameters.
// Configuring this will mean the generated OpenAPI output is no longer correct, and it should be
// done with careful consideration.
func SetQueryParameterParser(queryParameterParser QueryParameterParser) ServeMuxOption {
return func(serveMux *ServeMux) {
currentQueryParser = queryParameterParser
}
}
// HeaderMatcherFunc checks whether a header key should be forwarded to/from gRPC context.
type HeaderMatcherFunc func(string) (string, bool)
// DefaultHeaderMatcher is used to pass http request headers to/from gRPC context. This adds permanent HTTP header
// keys (as specified by the IANA) to gRPC context with grpcgateway- prefix. HTTP headers that start with
// 'Grpc-Metadata-' are mapped to gRPC metadata after removing prefix 'Grpc-Metadata-'.
func DefaultHeaderMatcher(key string) (string, bool) {
key = textproto.CanonicalMIMEHeaderKey(key)
if isPermanentHTTPHeader(key) {
return MetadataPrefix + key, true
} else if strings.HasPrefix(key, MetadataHeaderPrefix) {
return key[len(MetadataHeaderPrefix):], true
}
return "", false
}
// WithIncomingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for incoming request to gateway.
//
// This matcher will be called with each header in http.Request. If matcher returns true, that header will be
// passed to gRPC context. To transform the header before passing to gRPC context, matcher should return modified header.
func WithIncomingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
return func(mux *ServeMux) {
mux.incomingHeaderMatcher = fn
}
}
// WithOutgoingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway.
//
// This matcher will be called with each header in response header metadata. If matcher returns true, that header will be
// passed to http response returned from gateway. To transform the header before passing to response,
// matcher should return modified header.
func WithOutgoingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
return func(mux *ServeMux) {
mux.outgoingHeaderMatcher = fn
}
}
// WithMetadata returns a ServeMuxOption for passing metadata to a gRPC context.
//
// This can be used by services that need to read from http.Request and modify gRPC context. A common use case
// is reading token from cookie and adding it in gRPC context.
func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.metadataAnnotators = append(serveMux.metadataAnnotators, annotator)
}
}
// WithErrorHandler returns a ServeMuxOption for configuring a custom error handler.
//
// This can be used to configure a custom error response.
func WithErrorHandler(fn ErrorHandlerFunc) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.errorHandler = fn
}
}
// WithStreamErrorHandler returns a ServeMuxOption that will use the given custom stream
// error handler, which allows for customizing the error trailer for server-streaming
// calls.
//
// For stream errors that occur before any response has been written, the mux's
// ErrorHandler will be invoked. However, once data has been written, the errors must
// be handled differently: they must be included in the response body. The response body's
// final message will include the error details returned by the stream error handler.
func WithStreamErrorHandler(fn StreamErrorHandlerFunc) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.streamErrorHandler = fn
}
}
// WithRoutingErrorHandler returns a ServeMuxOption for configuring a custom error handler to handle http routing errors.
//
// Method called for errors which can happen before gRPC route selected or executed.
// The following error codes: StatusMethodNotAllowed StatusNotFound StatusBadRequest
func WithRoutingErrorHandler(fn RoutingErrorHandlerFunc) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.routingErrorHandler = fn
}
}
// WithDisablePathLengthFallback returns a ServeMuxOption for disable path length fallback.
func WithDisablePathLengthFallback() ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.disablePathLengthFallback = true
}
}
// NewServeMux returns a new ServeMux whose internal mapping is empty.
func NewServeMux(opts ...ServeMuxOption) *ServeMux {
serveMux := &ServeMux{
handlers: make(map[string][]handler),
forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0),
marshalers: makeMarshalerMIMERegistry(),
errorHandler: DefaultHTTPErrorHandler,
streamErrorHandler: DefaultStreamErrorHandler,
routingErrorHandler: DefaultRoutingErrorHandler,
unescapingMode: UnescapingModeDefault,
}
for _, opt := range opts {
opt(serveMux)
}
if serveMux.incomingHeaderMatcher == nil {
serveMux.incomingHeaderMatcher = DefaultHeaderMatcher
}
if serveMux.outgoingHeaderMatcher == nil {
serveMux.outgoingHeaderMatcher = func(key string) (string, bool) {
return fmt.Sprintf("%s%s", MetadataHeaderPrefix, key), true
}
}
return serveMux
}
// Handle associates "h" to the pair of HTTP method and path pattern.
func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) {
s.handlers[meth] = append([]handler{{pat: pat, h: h}}, s.handlers[meth]...)
}
// HandlePath allows users to configure custom path handlers.
// refer: https://grpc-ecosystem.github.io/grpc-gateway/docs/operations/inject_router/
func (s *ServeMux) HandlePath(meth string, pathPattern string, h HandlerFunc) error {
compiler, err := httprule.Parse(pathPattern)
if err != nil {
return fmt.Errorf("parsing path pattern: %w", err)
}
tp := compiler.Compile()
pattern, err := NewPattern(tp.Version, tp.OpCodes, tp.Pool, tp.Verb)
if err != nil {
return fmt.Errorf("creating new pattern: %w", err)
}
s.Handle(meth, pattern, h)
return nil
}
// ServeHTTP dispatches the request to the first handler whose pattern matches to r.Method and r.Path.
func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
path := r.URL.Path
if !strings.HasPrefix(path, "/") {
_, outboundMarshaler := MarshalerForRequest(s, r)
s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusBadRequest)
return
}
// TODO(v3): remove UnescapingModeLegacy
if s.unescapingMode != UnescapingModeLegacy && r.URL.RawPath != "" {
path = r.URL.RawPath
}
components := strings.Split(path[1:], "/")
if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && s.isPathLengthFallback(r) {
r.Method = strings.ToUpper(override)
if err := r.ParseForm(); err != nil {
_, outboundMarshaler := MarshalerForRequest(s, r)
sterr := status.Error(codes.InvalidArgument, err.Error())
s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr)
return
}
}
// Verb out here is to memoize for the fallback case below
var verb string
for _, h := range s.handlers[r.Method] {
// If the pattern has a verb, explicitly look for a suffix in the last
// component that matches a colon plus the verb. This allows us to
// handle some cases that otherwise can't be correctly handled by the
// former LastIndex case, such as when the verb literal itself contains
// a colon. This should work for all cases that have run through the
// parser because we know what verb we're looking for, however, there
// are still some cases that the parser itself cannot disambiguate. See
// the comment there if interested.
patVerb := h.pat.Verb()
l := len(components)
lastComponent := components[l-1]
var idx int = -1
if patVerb != "" && strings.HasSuffix(lastComponent, ":"+patVerb) {
idx = len(lastComponent) - len(patVerb) - 1
}
if idx == 0 {
_, outboundMarshaler := MarshalerForRequest(s, r)
s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusNotFound)
return
}
if idx > 0 {
components[l-1], verb = lastComponent[:idx], lastComponent[idx+1:]
}
pathParams, err := h.pat.MatchAndEscape(components, verb, s.unescapingMode)
if err != nil {
var mse MalformedSequenceError
if ok := errors.As(err, &mse); ok {
_, outboundMarshaler := MarshalerForRequest(s, r)
s.errorHandler(ctx, s, outboundMarshaler, w, r, &HTTPStatusError{
HTTPStatus: http.StatusBadRequest,
Err: mse,
})
}
continue
}
h.h(w, r, pathParams)
return
}
// lookup other methods to handle fallback from GET to POST and
// to determine if it is NotImplemented or NotFound.
for m, handlers := range s.handlers {
if m == r.Method {
continue
}
for _, h := range handlers {
pathParams, err := h.pat.MatchAndEscape(components, verb, s.unescapingMode)
if err != nil {
var mse MalformedSequenceError
if ok := errors.As(err, &mse); ok {
_, outboundMarshaler := MarshalerForRequest(s, r)
s.errorHandler(ctx, s, outboundMarshaler, w, r, &HTTPStatusError{
HTTPStatus: http.StatusBadRequest,
Err: mse,
})
}
continue
}
// X-HTTP-Method-Override is optional. Always allow fallback to POST.
if s.isPathLengthFallback(r) {
if err := r.ParseForm(); err != nil {
_, outboundMarshaler := MarshalerForRequest(s, r)
sterr := status.Error(codes.InvalidArgument, err.Error())
s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr)
return
}
h.h(w, r, pathParams)
return
}
_, outboundMarshaler := MarshalerForRequest(s, r)
s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusMethodNotAllowed)
return
}
}
_, outboundMarshaler := MarshalerForRequest(s, r)
s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusNotFound)
}
// GetForwardResponseOptions returns the ForwardResponseOptions associated with this ServeMux.
func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.ResponseWriter, proto.Message) error {
return s.forwardResponseOptions
}
func (s *ServeMux) isPathLengthFallback(r *http.Request) bool {
return !s.disablePathLengthFallback && r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded"
}
type handler struct {
pat Pattern
h HandlerFunc
}

View File

@@ -0,0 +1,383 @@
package runtime
import (
"errors"
"fmt"
"strconv"
"strings"
"github.com/grpc-ecosystem/grpc-gateway/v2/utilities"
"google.golang.org/grpc/grpclog"
)
var (
// ErrNotMatch indicates that the given HTTP request path does not match to the pattern.
ErrNotMatch = errors.New("not match to the path pattern")
// ErrInvalidPattern indicates that the given definition of Pattern is not valid.
ErrInvalidPattern = errors.New("invalid pattern")
// ErrMalformedSequence indicates that an escape sequence was malformed.
ErrMalformedSequence = errors.New("malformed escape sequence")
)
type MalformedSequenceError string
func (e MalformedSequenceError) Error() string {
return "malformed path escape " + strconv.Quote(string(e))
}
type op struct {
code utilities.OpCode
operand int
}
// Pattern is a template pattern of http request paths defined in
// https://github.com/googleapis/googleapis/blob/master/google/api/http.proto
type Pattern struct {
// ops is a list of operations
ops []op
// pool is a constant pool indexed by the operands or vars.
pool []string
// vars is a list of variables names to be bound by this pattern
vars []string
// stacksize is the max depth of the stack
stacksize int
// tailLen is the length of the fixed-size segments after a deep wildcard
tailLen int
// verb is the VERB part of the path pattern. It is empty if the pattern does not have VERB part.
verb string
}
// NewPattern returns a new Pattern from the given definition values.
// "ops" is a sequence of op codes. "pool" is a constant pool.
// "verb" is the verb part of the pattern. It is empty if the pattern does not have the part.
// "version" must be 1 for now.
// It returns an error if the given definition is invalid.
func NewPattern(version int, ops []int, pool []string, verb string) (Pattern, error) {
if version != 1 {
grpclog.Infof("unsupported version: %d", version)
return Pattern{}, ErrInvalidPattern
}
l := len(ops)
if l%2 != 0 {
grpclog.Infof("odd number of ops codes: %d", l)
return Pattern{}, ErrInvalidPattern
}
var (
typedOps []op
stack, maxstack int
tailLen int
pushMSeen bool
vars []string
)
for i := 0; i < l; i += 2 {
op := op{code: utilities.OpCode(ops[i]), operand: ops[i+1]}
switch op.code {
case utilities.OpNop:
continue
case utilities.OpPush:
if pushMSeen {
tailLen++
}
stack++
case utilities.OpPushM:
if pushMSeen {
grpclog.Infof("pushM appears twice")
return Pattern{}, ErrInvalidPattern
}
pushMSeen = true
stack++
case utilities.OpLitPush:
if op.operand < 0 || len(pool) <= op.operand {
grpclog.Infof("negative literal index: %d", op.operand)
return Pattern{}, ErrInvalidPattern
}
if pushMSeen {
tailLen++
}
stack++
case utilities.OpConcatN:
if op.operand <= 0 {
grpclog.Infof("negative concat size: %d", op.operand)
return Pattern{}, ErrInvalidPattern
}
stack -= op.operand
if stack < 0 {
grpclog.Info("stack underflow")
return Pattern{}, ErrInvalidPattern
}
stack++
case utilities.OpCapture:
if op.operand < 0 || len(pool) <= op.operand {
grpclog.Infof("variable name index out of bound: %d", op.operand)
return Pattern{}, ErrInvalidPattern
}
v := pool[op.operand]
op.operand = len(vars)
vars = append(vars, v)
stack--
if stack < 0 {
grpclog.Infof("stack underflow")
return Pattern{}, ErrInvalidPattern
}
default:
grpclog.Infof("invalid opcode: %d", op.code)
return Pattern{}, ErrInvalidPattern
}
if maxstack < stack {
maxstack = stack
}
typedOps = append(typedOps, op)
}
return Pattern{
ops: typedOps,
pool: pool,
vars: vars,
stacksize: maxstack,
tailLen: tailLen,
verb: verb,
}, nil
}
// MustPattern is a helper function which makes it easier to call NewPattern in variable initialization.
func MustPattern(p Pattern, err error) Pattern {
if err != nil {
grpclog.Fatalf("Pattern initialization failed: %v", err)
}
return p
}
// MatchAndEscape examines components to determine if they match to a Pattern.
// MatchAndEscape will return an error if no Patterns matched or if a pattern
// matched but contained malformed escape sequences. If successful, the function
// returns a mapping from field paths to their captured values.
func (p Pattern) MatchAndEscape(components []string, verb string, unescapingMode UnescapingMode) (map[string]string, error) {
if p.verb != verb {
if p.verb != "" {
return nil, ErrNotMatch
}
if len(components) == 0 {
components = []string{":" + verb}
} else {
components = append([]string{}, components...)
components[len(components)-1] += ":" + verb
}
}
var pos int
stack := make([]string, 0, p.stacksize)
captured := make([]string, len(p.vars))
l := len(components)
for _, op := range p.ops {
var err error
switch op.code {
case utilities.OpNop:
continue
case utilities.OpPush, utilities.OpLitPush:
if pos >= l {
return nil, ErrNotMatch
}
c := components[pos]
if op.code == utilities.OpLitPush {
if lit := p.pool[op.operand]; c != lit {
return nil, ErrNotMatch
}
} else if op.code == utilities.OpPush {
if c, err = unescape(c, unescapingMode, false); err != nil {
return nil, err
}
}
stack = append(stack, c)
pos++
case utilities.OpPushM:
end := len(components)
if end < pos+p.tailLen {
return nil, ErrNotMatch
}
end -= p.tailLen
c := strings.Join(components[pos:end], "/")
if c, err = unescape(c, unescapingMode, true); err != nil {
return nil, err
}
stack = append(stack, c)
pos = end
case utilities.OpConcatN:
n := op.operand
l := len(stack) - n
stack = append(stack[:l], strings.Join(stack[l:], "/"))
case utilities.OpCapture:
n := len(stack) - 1
captured[op.operand] = stack[n]
stack = stack[:n]
}
}
if pos < l {
return nil, ErrNotMatch
}
bindings := make(map[string]string)
for i, val := range captured {
bindings[p.vars[i]] = val
}
return bindings, nil
}
// MatchAndEscape examines components to determine if they match to a Pattern.
// It will never perform per-component unescaping (see: UnescapingModeLegacy).
// MatchAndEscape will return an error if no Patterns matched. If successful,
// the function returns a mapping from field paths to their captured values.
//
// Deprecated: Use MatchAndEscape.
func (p Pattern) Match(components []string, verb string) (map[string]string, error) {
return p.MatchAndEscape(components, verb, UnescapingModeDefault)
}
// Verb returns the verb part of the Pattern.
func (p Pattern) Verb() string { return p.verb }
func (p Pattern) String() string {
var stack []string
for _, op := range p.ops {
switch op.code {
case utilities.OpNop:
continue
case utilities.OpPush:
stack = append(stack, "*")
case utilities.OpLitPush:
stack = append(stack, p.pool[op.operand])
case utilities.OpPushM:
stack = append(stack, "**")
case utilities.OpConcatN:
n := op.operand
l := len(stack) - n
stack = append(stack[:l], strings.Join(stack[l:], "/"))
case utilities.OpCapture:
n := len(stack) - 1
stack[n] = fmt.Sprintf("{%s=%s}", p.vars[op.operand], stack[n])
}
}
segs := strings.Join(stack, "/")
if p.verb != "" {
return fmt.Sprintf("/%s:%s", segs, p.verb)
}
return "/" + segs
}
/*
* The following code is adopted and modified from Go's standard library
* and carries the attached license.
*
* Copyright 2009 The Go Authors. All rights reserved.
* Use of this source code is governed by a BSD-style
* license that can be found in the LICENSE file.
*/
// ishex returns whether or not the given byte is a valid hex character
func ishex(c byte) bool {
switch {
case '0' <= c && c <= '9':
return true
case 'a' <= c && c <= 'f':
return true
case 'A' <= c && c <= 'F':
return true
}
return false
}
func isRFC6570Reserved(c byte) bool {
switch c {
case '!', '#', '$', '&', '\'', '(', ')', '*',
'+', ',', '/', ':', ';', '=', '?', '@', '[', ']':
return true
default:
return false
}
}
// unhex converts a hex point to the bit representation
func unhex(c byte) byte {
switch {
case '0' <= c && c <= '9':
return c - '0'
case 'a' <= c && c <= 'f':
return c - 'a' + 10
case 'A' <= c && c <= 'F':
return c - 'A' + 10
}
return 0
}
// shouldUnescapeWithMode returns true if the character is escapable with the
// given mode
func shouldUnescapeWithMode(c byte, mode UnescapingMode) bool {
switch mode {
case UnescapingModeAllExceptReserved:
if isRFC6570Reserved(c) {
return false
}
case UnescapingModeAllExceptSlash:
if c == '/' {
return false
}
case UnescapingModeAllCharacters:
return true
}
return true
}
// unescape unescapes a path string using the provided mode
func unescape(s string, mode UnescapingMode, multisegment bool) (string, error) {
// TODO(v3): remove UnescapingModeLegacy
if mode == UnescapingModeLegacy {
return s, nil
}
if !multisegment {
mode = UnescapingModeAllCharacters
}
// Count %, check that they're well-formed.
n := 0
for i := 0; i < len(s); {
if s[i] == '%' {
n++
if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
s = s[i:]
if len(s) > 3 {
s = s[:3]
}
return "", MalformedSequenceError(s)
}
i += 3
} else {
i++
}
}
if n == 0 {
return s, nil
}
var t strings.Builder
t.Grow(len(s))
for i := 0; i < len(s); i++ {
switch s[i] {
case '%':
c := unhex(s[i+1])<<4 | unhex(s[i+2])
if shouldUnescapeWithMode(c, mode) {
t.WriteByte(c)
i += 2
continue
}
fallthrough
default:
t.WriteByte(s[i])
}
}
return t.String(), nil
}

View File

@@ -0,0 +1,80 @@
package runtime
import (
"google.golang.org/protobuf/proto"
)
// StringP returns a pointer to a string whose pointee is same as the given string value.
func StringP(val string) (*string, error) {
return proto.String(val), nil
}
// BoolP parses the given string representation of a boolean value,
// and returns a pointer to a bool whose value is same as the parsed value.
func BoolP(val string) (*bool, error) {
b, err := Bool(val)
if err != nil {
return nil, err
}
return proto.Bool(b), nil
}
// Float64P parses the given string representation of a floating point number,
// and returns a pointer to a float64 whose value is same as the parsed number.
func Float64P(val string) (*float64, error) {
f, err := Float64(val)
if err != nil {
return nil, err
}
return proto.Float64(f), nil
}
// Float32P parses the given string representation of a floating point number,
// and returns a pointer to a float32 whose value is same as the parsed number.
func Float32P(val string) (*float32, error) {
f, err := Float32(val)
if err != nil {
return nil, err
}
return proto.Float32(f), nil
}
// Int64P parses the given string representation of an integer
// and returns a pointer to a int64 whose value is same as the parsed integer.
func Int64P(val string) (*int64, error) {
i, err := Int64(val)
if err != nil {
return nil, err
}
return proto.Int64(i), nil
}
// Int32P parses the given string representation of an integer
// and returns a pointer to a int32 whose value is same as the parsed integer.
func Int32P(val string) (*int32, error) {
i, err := Int32(val)
if err != nil {
return nil, err
}
return proto.Int32(i), err
}
// Uint64P parses the given string representation of an integer
// and returns a pointer to a uint64 whose value is same as the parsed integer.
func Uint64P(val string) (*uint64, error) {
i, err := Uint64(val)
if err != nil {
return nil, err
}
return proto.Uint64(i), err
}
// Uint32P parses the given string representation of an integer
// and returns a pointer to a uint32 whose value is same as the parsed integer.
func Uint32P(val string) (*uint32, error) {
i, err := Uint32(val)
if err != nil {
return nil, err
}
return proto.Uint32(i), err
}

View File

@@ -0,0 +1,329 @@
package runtime
import (
"encoding/base64"
"errors"
"fmt"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"github.com/grpc-ecosystem/grpc-gateway/v2/utilities"
"google.golang.org/genproto/protobuf/field_mask"
"google.golang.org/grpc/grpclog"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb"
)
var valuesKeyRegexp = regexp.MustCompile(`^(.*)\[(.*)\]$`)
var currentQueryParser QueryParameterParser = &defaultQueryParser{}
// QueryParameterParser defines interface for all query parameter parsers
type QueryParameterParser interface {
Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error
}
// PopulateQueryParameters parses query parameters
// into "msg" using current query parser
func PopulateQueryParameters(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
return currentQueryParser.Parse(msg, values, filter)
}
type defaultQueryParser struct{}
// Parse populates "values" into "msg".
// A value is ignored if its key starts with one of the elements in "filter".
func (*defaultQueryParser) Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
for key, values := range values {
match := valuesKeyRegexp.FindStringSubmatch(key)
if len(match) == 3 {
key = match[1]
values = append([]string{match[2]}, values...)
}
fieldPath := strings.Split(key, ".")
if filter.HasCommonPrefix(fieldPath) {
continue
}
if err := populateFieldValueFromPath(msg.ProtoReflect(), fieldPath, values); err != nil {
return err
}
}
return nil
}
// PopulateFieldFromPath sets a value in a nested Protobuf structure.
func PopulateFieldFromPath(msg proto.Message, fieldPathString string, value string) error {
fieldPath := strings.Split(fieldPathString, ".")
return populateFieldValueFromPath(msg.ProtoReflect(), fieldPath, []string{value})
}
func populateFieldValueFromPath(msgValue protoreflect.Message, fieldPath []string, values []string) error {
if len(fieldPath) < 1 {
return errors.New("no field path")
}
if len(values) < 1 {
return errors.New("no value provided")
}
var fieldDescriptor protoreflect.FieldDescriptor
for i, fieldName := range fieldPath {
fields := msgValue.Descriptor().Fields()
// Get field by name
fieldDescriptor = fields.ByName(protoreflect.Name(fieldName))
if fieldDescriptor == nil {
fieldDescriptor = fields.ByJSONName(fieldName)
if fieldDescriptor == nil {
// We're not returning an error here because this could just be
// an extra query parameter that isn't part of the request.
grpclog.Infof("field not found in %q: %q", msgValue.Descriptor().FullName(), strings.Join(fieldPath, "."))
return nil
}
}
// If this is the last element, we're done
if i == len(fieldPath)-1 {
break
}
// Only singular message fields are allowed
if fieldDescriptor.Message() == nil || fieldDescriptor.Cardinality() == protoreflect.Repeated {
return fmt.Errorf("invalid path: %q is not a message", fieldName)
}
// Get the nested message
msgValue = msgValue.Mutable(fieldDescriptor).Message()
}
// Check if oneof already set
if of := fieldDescriptor.ContainingOneof(); of != nil {
if f := msgValue.WhichOneof(of); f != nil {
return fmt.Errorf("field already set for oneof %q", of.FullName().Name())
}
}
switch {
case fieldDescriptor.IsList():
return populateRepeatedField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).List(), values)
case fieldDescriptor.IsMap():
return populateMapField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).Map(), values)
}
if len(values) > 1 {
return fmt.Errorf("too many values for field %q: %s", fieldDescriptor.FullName().Name(), strings.Join(values, ", "))
}
return populateField(fieldDescriptor, msgValue, values[0])
}
func populateField(fieldDescriptor protoreflect.FieldDescriptor, msgValue protoreflect.Message, value string) error {
v, err := parseField(fieldDescriptor, value)
if err != nil {
return fmt.Errorf("parsing field %q: %w", fieldDescriptor.FullName().Name(), err)
}
msgValue.Set(fieldDescriptor, v)
return nil
}
func populateRepeatedField(fieldDescriptor protoreflect.FieldDescriptor, list protoreflect.List, values []string) error {
for _, value := range values {
v, err := parseField(fieldDescriptor, value)
if err != nil {
return fmt.Errorf("parsing list %q: %w", fieldDescriptor.FullName().Name(), err)
}
list.Append(v)
}
return nil
}
func populateMapField(fieldDescriptor protoreflect.FieldDescriptor, mp protoreflect.Map, values []string) error {
if len(values) != 2 {
return fmt.Errorf("more than one value provided for key %q in map %q", values[0], fieldDescriptor.FullName())
}
key, err := parseField(fieldDescriptor.MapKey(), values[0])
if err != nil {
return fmt.Errorf("parsing map key %q: %w", fieldDescriptor.FullName().Name(), err)
}
value, err := parseField(fieldDescriptor.MapValue(), values[1])
if err != nil {
return fmt.Errorf("parsing map value %q: %w", fieldDescriptor.FullName().Name(), err)
}
mp.Set(key.MapKey(), value)
return nil
}
func parseField(fieldDescriptor protoreflect.FieldDescriptor, value string) (protoreflect.Value, error) {
switch fieldDescriptor.Kind() {
case protoreflect.BoolKind:
v, err := strconv.ParseBool(value)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfBool(v), nil
case protoreflect.EnumKind:
enum, err := protoregistry.GlobalTypes.FindEnumByName(fieldDescriptor.Enum().FullName())
switch {
case errors.Is(err, protoregistry.NotFound):
return protoreflect.Value{}, fmt.Errorf("enum %q is not registered", fieldDescriptor.Enum().FullName())
case err != nil:
return protoreflect.Value{}, fmt.Errorf("failed to look up enum: %w", err)
}
// Look for enum by name
v := enum.Descriptor().Values().ByName(protoreflect.Name(value))
if v == nil {
i, err := strconv.Atoi(value)
if err != nil {
return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
}
// Look for enum by number
v = enum.Descriptor().Values().ByNumber(protoreflect.EnumNumber(i))
if v == nil {
return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
}
}
return protoreflect.ValueOfEnum(v.Number()), nil
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
v, err := strconv.ParseInt(value, 10, 32)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfInt32(int32(v)), nil
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
v, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfInt64(v), nil
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
v, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfUint32(uint32(v)), nil
case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
v, err := strconv.ParseUint(value, 10, 64)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfUint64(v), nil
case protoreflect.FloatKind:
v, err := strconv.ParseFloat(value, 32)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfFloat32(float32(v)), nil
case protoreflect.DoubleKind:
v, err := strconv.ParseFloat(value, 64)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfFloat64(v), nil
case protoreflect.StringKind:
return protoreflect.ValueOfString(value), nil
case protoreflect.BytesKind:
v, err := base64.URLEncoding.DecodeString(value)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfBytes(v), nil
case protoreflect.MessageKind, protoreflect.GroupKind:
return parseMessage(fieldDescriptor.Message(), value)
default:
panic(fmt.Sprintf("unknown field kind: %v", fieldDescriptor.Kind()))
}
}
func parseMessage(msgDescriptor protoreflect.MessageDescriptor, value string) (protoreflect.Value, error) {
var msg proto.Message
switch msgDescriptor.FullName() {
case "google.protobuf.Timestamp":
if value == "null" {
break
}
t, err := time.Parse(time.RFC3339Nano, value)
if err != nil {
return protoreflect.Value{}, err
}
msg = timestamppb.New(t)
case "google.protobuf.Duration":
if value == "null" {
break
}
d, err := time.ParseDuration(value)
if err != nil {
return protoreflect.Value{}, err
}
msg = durationpb.New(d)
case "google.protobuf.DoubleValue":
v, err := strconv.ParseFloat(value, 64)
if err != nil {
return protoreflect.Value{}, err
}
msg = &wrapperspb.DoubleValue{Value: v}
case "google.protobuf.FloatValue":
v, err := strconv.ParseFloat(value, 32)
if err != nil {
return protoreflect.Value{}, err
}
msg = &wrapperspb.FloatValue{Value: float32(v)}
case "google.protobuf.Int64Value":
v, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return protoreflect.Value{}, err
}
msg = &wrapperspb.Int64Value{Value: v}
case "google.protobuf.Int32Value":
v, err := strconv.ParseInt(value, 10, 32)
if err != nil {
return protoreflect.Value{}, err
}
msg = &wrapperspb.Int32Value{Value: int32(v)}
case "google.protobuf.UInt64Value":
v, err := strconv.ParseUint(value, 10, 64)
if err != nil {
return protoreflect.Value{}, err
}
msg = &wrapperspb.UInt64Value{Value: v}
case "google.protobuf.UInt32Value":
v, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return protoreflect.Value{}, err
}
msg = &wrapperspb.UInt32Value{Value: uint32(v)}
case "google.protobuf.BoolValue":
v, err := strconv.ParseBool(value)
if err != nil {
return protoreflect.Value{}, err
}
msg = &wrapperspb.BoolValue{Value: v}
case "google.protobuf.StringValue":
msg = &wrapperspb.StringValue{Value: value}
case "google.protobuf.BytesValue":
v, err := base64.URLEncoding.DecodeString(value)
if err != nil {
return protoreflect.Value{}, err
}
msg = &wrapperspb.BytesValue{Value: v}
case "google.protobuf.FieldMask":
fm := &field_mask.FieldMask{}
fm.Paths = append(fm.Paths, strings.Split(value, ",")...)
msg = fm
default:
return protoreflect.Value{}, fmt.Errorf("unsupported message type: %q", string(msgDescriptor.FullName()))
}
return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil
}

View File

@@ -0,0 +1,27 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
package(default_visibility = ["//visibility:public"])
go_library(
name = "utilities",
srcs = [
"doc.go",
"pattern.go",
"readerfactory.go",
"trie.go",
],
importpath = "github.com/grpc-ecosystem/grpc-gateway/v2/utilities",
)
go_test(
name = "utilities_test",
size = "small",
srcs = ["trie_test.go"],
deps = [":utilities"],
)
alias(
name = "go_default_library",
actual = ":utilities",
visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,2 @@
// Package utilities provides members for internal use in grpc-gateway.
package utilities

View File

@@ -0,0 +1,22 @@
package utilities
// An OpCode is a opcode of compiled path patterns.
type OpCode int
// These constants are the valid values of OpCode.
const (
// OpNop does nothing
OpNop = OpCode(iota)
// OpPush pushes a component to stack
OpPush
// OpLitPush pushes a component to stack if it matches to the literal
OpLitPush
// OpPushM concatenates the remaining components and pushes it to stack
OpPushM
// OpConcatN pops N items from stack, concatenates them and pushes it back to stack
OpConcatN
// OpCapture pops an item and binds it to the variable
OpCapture
// OpEnd is the least positive invalid opcode.
OpEnd
)

View File

@@ -0,0 +1,20 @@
package utilities
import (
"bytes"
"io"
"io/ioutil"
)
// IOReaderFactory takes in an io.Reader and returns a function that will allow you to create a new reader that begins
// at the start of the stream
func IOReaderFactory(r io.Reader) (func() io.Reader, error) {
b, err := ioutil.ReadAll(r)
if err != nil {
return nil, err
}
return func() io.Reader {
return bytes.NewReader(b)
}, nil
}

View File

@@ -0,0 +1,174 @@
package utilities
import (
"sort"
)
// DoubleArray is a Double Array implementation of trie on sequences of strings.
type DoubleArray struct {
// Encoding keeps an encoding from string to int
Encoding map[string]int
// Base is the base array of Double Array
Base []int
// Check is the check array of Double Array
Check []int
}
// NewDoubleArray builds a DoubleArray from a set of sequences of strings.
func NewDoubleArray(seqs [][]string) *DoubleArray {
da := &DoubleArray{Encoding: make(map[string]int)}
if len(seqs) == 0 {
return da
}
encoded := registerTokens(da, seqs)
sort.Sort(byLex(encoded))
root := node{row: -1, col: -1, left: 0, right: len(encoded)}
addSeqs(da, encoded, 0, root)
for i := len(da.Base); i > 0; i-- {
if da.Check[i-1] != 0 {
da.Base = da.Base[:i]
da.Check = da.Check[:i]
break
}
}
return da
}
func registerTokens(da *DoubleArray, seqs [][]string) [][]int {
var result [][]int
for _, seq := range seqs {
var encoded []int
for _, token := range seq {
if _, ok := da.Encoding[token]; !ok {
da.Encoding[token] = len(da.Encoding)
}
encoded = append(encoded, da.Encoding[token])
}
result = append(result, encoded)
}
for i := range result {
result[i] = append(result[i], len(da.Encoding))
}
return result
}
type node struct {
row, col int
left, right int
}
func (n node) value(seqs [][]int) int {
return seqs[n.row][n.col]
}
func (n node) children(seqs [][]int) []*node {
var result []*node
lastVal := int(-1)
last := new(node)
for i := n.left; i < n.right; i++ {
if lastVal == seqs[i][n.col+1] {
continue
}
last.right = i
last = &node{
row: i,
col: n.col + 1,
left: i,
}
result = append(result, last)
}
last.right = n.right
return result
}
func addSeqs(da *DoubleArray, seqs [][]int, pos int, n node) {
ensureSize(da, pos)
children := n.children(seqs)
var i int
for i = 1; ; i++ {
ok := func() bool {
for _, child := range children {
code := child.value(seqs)
j := i + code
ensureSize(da, j)
if da.Check[j] != 0 {
return false
}
}
return true
}()
if ok {
break
}
}
da.Base[pos] = i
for _, child := range children {
code := child.value(seqs)
j := i + code
da.Check[j] = pos + 1
}
terminator := len(da.Encoding)
for _, child := range children {
code := child.value(seqs)
if code == terminator {
continue
}
j := i + code
addSeqs(da, seqs, j, *child)
}
}
func ensureSize(da *DoubleArray, i int) {
for i >= len(da.Base) {
da.Base = append(da.Base, make([]int, len(da.Base)+1)...)
da.Check = append(da.Check, make([]int, len(da.Check)+1)...)
}
}
type byLex [][]int
func (l byLex) Len() int { return len(l) }
func (l byLex) Swap(i, j int) { l[i], l[j] = l[j], l[i] }
func (l byLex) Less(i, j int) bool {
si := l[i]
sj := l[j]
var k int
for k = 0; k < len(si) && k < len(sj); k++ {
if si[k] < sj[k] {
return true
}
if si[k] > sj[k] {
return false
}
}
return k < len(sj)
}
// HasCommonPrefix determines if any sequence in the DoubleArray is a prefix of the given sequence.
func (da *DoubleArray) HasCommonPrefix(seq []string) bool {
if len(da.Base) == 0 {
return false
}
var i int
for _, t := range seq {
code, ok := da.Encoding[t]
if !ok {
break
}
j := da.Base[i] + code
if len(da.Check) <= j || da.Check[j] != i+1 {
break
}
i = j
}
j := da.Base[i] + len(da.Encoding)
if len(da.Check) <= j || da.Check[j] != i+1 {
return false
}
return true
}

View File

@@ -3,6 +3,7 @@ package assert
import (
"fmt"
"reflect"
"time"
)
type CompareType int
@@ -30,6 +31,8 @@ var (
float64Type = reflect.TypeOf(float64(1))
stringType = reflect.TypeOf("")
timeType = reflect.TypeOf(time.Time{})
)
func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
@@ -299,6 +302,27 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
return compareLess, true
}
}
// Check for known struct types we can check for compare results.
case reflect.Struct:
{
// All structs enter here. We're not interested in most types.
if !canConvert(obj1Value, timeType) {
break
}
// time.Time can compared!
timeObj1, ok := obj1.(time.Time)
if !ok {
timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time)
}
timeObj2, ok := obj2.(time.Time)
if !ok {
timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time)
}
return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64)
}
}
return compareEqual, false
@@ -310,7 +334,10 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
// assert.Greater(t, float64(2), float64(1))
// assert.Greater(t, "b", "a")
func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs)
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
}
// GreaterOrEqual asserts that the first element is greater than or equal to the second
@@ -320,7 +347,10 @@ func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface
// assert.GreaterOrEqual(t, "b", "a")
// assert.GreaterOrEqual(t, "b", "b")
func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs)
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
}
// Less asserts that the first element is less than the second
@@ -329,7 +359,10 @@ func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...in
// assert.Less(t, float64(1), float64(2))
// assert.Less(t, "a", "b")
func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs)
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
}
// LessOrEqual asserts that the first element is less than or equal to the second
@@ -339,7 +372,10 @@ func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{})
// assert.LessOrEqual(t, "a", "b")
// assert.LessOrEqual(t, "b", "b")
func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs)
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
}
// Positive asserts that the specified element is positive
@@ -347,8 +383,11 @@ func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...inter
// assert.Positive(t, 1)
// assert.Positive(t, 1.23)
func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
zero := reflect.Zero(reflect.TypeOf(e))
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs)
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs...)
}
// Negative asserts that the specified element is negative
@@ -356,8 +395,11 @@ func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
// assert.Negative(t, -1)
// assert.Negative(t, -1.23)
func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
zero := reflect.Zero(reflect.TypeOf(e))
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs)
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs...)
}
func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool {

View File

@@ -0,0 +1,16 @@
//go:build go1.17
// +build go1.17
// TODO: once support for Go 1.16 is dropped, this file can be
// merged/removed with assertion_compare_go1.17_test.go and
// assertion_compare_legacy.go
package assert
import "reflect"
// Wrapper around reflect.Value.CanConvert, for compatability
// reasons.
func canConvert(value reflect.Value, to reflect.Type) bool {
return value.CanConvert(to)
}

View File

@@ -0,0 +1,16 @@
//go:build !go1.17
// +build !go1.17
// TODO: once support for Go 1.16 is dropped, this file can be
// merged/removed with assertion_compare_go1.17_test.go and
// assertion_compare_can_convert.go
package assert
import "reflect"
// Older versions of Go does not have the reflect.Value.CanConvert
// method.
func canConvert(value reflect.Value, to reflect.Type) bool {
return false
}

View File

@@ -123,6 +123,18 @@ func ErrorAsf(t TestingT, err error, target interface{}, msg string, args ...int
return ErrorAs(t, err, target, append([]interface{}{msg}, args...)...)
}
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted")
func ErrorContainsf(t TestingT, theError error, contains string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return ErrorContains(t, theError, contains, append([]interface{}{msg}, args...)...)
}
// ErrorIsf asserts that at least one of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func ErrorIsf(t TestingT, err error, target error, msg string, args ...interface{}) bool {

View File

@@ -222,6 +222,30 @@ func (a *Assertions) ErrorAsf(err error, target interface{}, msg string, args ..
return ErrorAsf(a.t, err, target, msg, args...)
}
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// a.ErrorContains(err, expectedErrorSubString)
func (a *Assertions) ErrorContains(theError error, contains string, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return ErrorContains(a.t, theError, contains, msgAndArgs...)
}
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// a.ErrorContainsf(err, expectedErrorSubString, "error message %s", "formatted")
func (a *Assertions) ErrorContainsf(theError error, contains string, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return ErrorContainsf(a.t, theError, contains, msg, args...)
}
// ErrorIs asserts that at least one of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func (a *Assertions) ErrorIs(err error, target error, msgAndArgs ...interface{}) bool {

View File

@@ -50,7 +50,7 @@ func isOrdered(t TestingT, object interface{}, allowedComparesResults []CompareT
// assert.IsIncreasing(t, []float{1, 2})
// assert.IsIncreasing(t, []string{"a", "b"})
func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs)
return isOrdered(t, object, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
}
// IsNonIncreasing asserts that the collection is not increasing
@@ -59,7 +59,7 @@ func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) boo
// assert.IsNonIncreasing(t, []float{2, 1})
// assert.IsNonIncreasing(t, []string{"b", "a"})
func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareEqual, compareGreater}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs)
return isOrdered(t, object, []CompareType{compareEqual, compareGreater}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
}
// IsDecreasing asserts that the collection is decreasing
@@ -68,7 +68,7 @@ func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{})
// assert.IsDecreasing(t, []float{2, 1})
// assert.IsDecreasing(t, []string{"b", "a"})
func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs)
return isOrdered(t, object, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
}
// IsNonDecreasing asserts that the collection is not decreasing
@@ -77,5 +77,5 @@ func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) boo
// assert.IsNonDecreasing(t, []float{1, 2})
// assert.IsNonDecreasing(t, []string{"a", "b"})
func IsNonDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs)
return isOrdered(t, object, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
}

View File

@@ -718,10 +718,14 @@ func NotEqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...inte
// return (false, false) if impossible.
// return (true, false) if element was not found.
// return (true, true) if element was found.
func includeElement(list interface{}, element interface{}) (ok, found bool) {
func containsElement(list interface{}, element interface{}) (ok, found bool) {
listValue := reflect.ValueOf(list)
listKind := reflect.TypeOf(list).Kind()
listType := reflect.TypeOf(list)
if listType == nil {
return false, false
}
listKind := listType.Kind()
defer func() {
if e := recover(); e != nil {
ok = false
@@ -764,7 +768,7 @@ func Contains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bo
h.Helper()
}
ok, found := includeElement(s, contains)
ok, found := containsElement(s, contains)
if !ok {
return Fail(t, fmt.Sprintf("%#v could not be applied builtin len()", s), msgAndArgs...)
}
@@ -787,7 +791,7 @@ func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{})
h.Helper()
}
ok, found := includeElement(s, contains)
ok, found := containsElement(s, contains)
if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", s), msgAndArgs...)
}
@@ -831,7 +835,7 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
for i := 0; i < subsetValue.Len(); i++ {
element := subsetValue.Index(i).Interface()
ok, found := includeElement(list, element)
ok, found := containsElement(list, element)
if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...)
}
@@ -852,7 +856,7 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{})
h.Helper()
}
if subset == nil {
return Fail(t, fmt.Sprintf("nil is the empty set which is a subset of every set"), msgAndArgs...)
return Fail(t, "nil is the empty set which is a subset of every set", msgAndArgs...)
}
subsetValue := reflect.ValueOf(subset)
@@ -875,7 +879,7 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{})
for i := 0; i < subsetValue.Len(); i++ {
element := subsetValue.Index(i).Interface()
ok, found := includeElement(list, element)
ok, found := containsElement(list, element)
if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...)
}
@@ -1000,27 +1004,21 @@ func Condition(t TestingT, comp Comparison, msgAndArgs ...interface{}) bool {
type PanicTestFunc func()
// didPanic returns true if the function passed to it panics. Otherwise, it returns false.
func didPanic(f PanicTestFunc) (bool, interface{}, string) {
didPanic := false
var message interface{}
var stack string
func() {
defer func() {
if message = recover(); message != nil {
didPanic = true
stack = string(debug.Stack())
}
}()
// call the target function
f()
func didPanic(f PanicTestFunc) (didPanic bool, message interface{}, stack string) {
didPanic = true
defer func() {
message = recover()
if didPanic {
stack = string(debug.Stack())
}
}()
return didPanic, message, stack
// call the target function
f()
didPanic = false
return
}
// Panics asserts that the code inside the specified PanicTestFunc panics.
@@ -1161,11 +1159,15 @@ func InDelta(t TestingT, expected, actual interface{}, delta float64, msgAndArgs
bf, bok := toFloat(actual)
if !aok || !bok {
return Fail(t, fmt.Sprintf("Parameters must be numerical"), msgAndArgs...)
return Fail(t, "Parameters must be numerical", msgAndArgs...)
}
if math.IsNaN(af) && math.IsNaN(bf) {
return true
}
if math.IsNaN(af) {
return Fail(t, fmt.Sprintf("Expected must not be NaN"), msgAndArgs...)
return Fail(t, "Expected must not be NaN", msgAndArgs...)
}
if math.IsNaN(bf) {
@@ -1188,7 +1190,7 @@ func InDeltaSlice(t TestingT, expected, actual interface{}, delta float64, msgAn
if expected == nil || actual == nil ||
reflect.TypeOf(actual).Kind() != reflect.Slice ||
reflect.TypeOf(expected).Kind() != reflect.Slice {
return Fail(t, fmt.Sprintf("Parameters must be slice"), msgAndArgs...)
return Fail(t, "Parameters must be slice", msgAndArgs...)
}
actualSlice := reflect.ValueOf(actual)
@@ -1250,8 +1252,12 @@ func InDeltaMapValues(t TestingT, expected, actual interface{}, delta float64, m
func calcRelativeError(expected, actual interface{}) (float64, error) {
af, aok := toFloat(expected)
if !aok {
return 0, fmt.Errorf("expected value %q cannot be converted to float", expected)
bf, bok := toFloat(actual)
if !aok || !bok {
return 0, fmt.Errorf("Parameters must be numerical")
}
if math.IsNaN(af) && math.IsNaN(bf) {
return 0, nil
}
if math.IsNaN(af) {
return 0, errors.New("expected value must not be NaN")
@@ -1259,10 +1265,6 @@ func calcRelativeError(expected, actual interface{}) (float64, error) {
if af == 0 {
return 0, fmt.Errorf("expected value must have a value other than zero to calculate the relative error")
}
bf, bok := toFloat(actual)
if !bok {
return 0, fmt.Errorf("actual value %q cannot be converted to float", actual)
}
if math.IsNaN(bf) {
return 0, errors.New("actual value must not be NaN")
}
@@ -1298,7 +1300,7 @@ func InEpsilonSlice(t TestingT, expected, actual interface{}, epsilon float64, m
if expected == nil || actual == nil ||
reflect.TypeOf(actual).Kind() != reflect.Slice ||
reflect.TypeOf(expected).Kind() != reflect.Slice {
return Fail(t, fmt.Sprintf("Parameters must be slice"), msgAndArgs...)
return Fail(t, "Parameters must be slice", msgAndArgs...)
}
actualSlice := reflect.ValueOf(actual)
@@ -1375,6 +1377,27 @@ func EqualError(t TestingT, theError error, errString string, msgAndArgs ...inte
return true
}
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// assert.ErrorContains(t, err, expectedErrorSubString)
func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if !Error(t, theError, msgAndArgs...) {
return false
}
actual := theError.Error()
if !strings.Contains(actual, contains) {
return Fail(t, fmt.Sprintf("Error %#v does not contain %#v", actual, contains), msgAndArgs...)
}
return true
}
// matchRegexp return true if a specified regexp matches a string.
func matchRegexp(rx interface{}, str interface{}) bool {
@@ -1588,12 +1611,17 @@ func diff(expected interface{}, actual interface{}) string {
}
var e, a string
if et != reflect.TypeOf("") {
e = spewConfig.Sdump(expected)
a = spewConfig.Sdump(actual)
} else {
switch et {
case reflect.TypeOf(""):
e = reflect.ValueOf(expected).String()
a = reflect.ValueOf(actual).String()
case reflect.TypeOf(time.Time{}):
e = spewConfigStringerEnabled.Sdump(expected)
a = spewConfigStringerEnabled.Sdump(actual)
default:
e = spewConfig.Sdump(expected)
a = spewConfig.Sdump(actual)
}
diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{
@@ -1625,6 +1653,14 @@ var spewConfig = spew.ConfigState{
MaxDepth: 10,
}
var spewConfigStringerEnabled = spew.ConfigState{
Indent: " ",
DisablePointerAddresses: true,
DisableCapacities: true,
SortKeys: true,
MaxDepth: 10,
}
type tHelper interface {
Helper()
}

View File

@@ -280,6 +280,36 @@ func ErrorAsf(t TestingT, err error, target interface{}, msg string, args ...int
t.FailNow()
}
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// assert.ErrorContains(t, err, expectedErrorSubString)
func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if assert.ErrorContains(t, theError, contains, msgAndArgs...) {
return
}
t.FailNow()
}
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted")
func ErrorContainsf(t TestingT, theError error, contains string, msg string, args ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if assert.ErrorContainsf(t, theError, contains, msg, args...) {
return
}
t.FailNow()
}
// ErrorIs asserts that at least one of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func ErrorIs(t TestingT, err error, target error, msgAndArgs ...interface{}) {

View File

@@ -223,6 +223,30 @@ func (a *Assertions) ErrorAsf(err error, target interface{}, msg string, args ..
ErrorAsf(a.t, err, target, msg, args...)
}
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// a.ErrorContains(err, expectedErrorSubString)
func (a *Assertions) ErrorContains(theError error, contains string, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
ErrorContains(a.t, theError, contains, msgAndArgs...)
}
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// a.ErrorContainsf(err, expectedErrorSubString, "error message %s", "formatted")
func (a *Assertions) ErrorContainsf(theError error, contains string, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
ErrorContainsf(a.t, theError, contains, msg, args...)
}
// ErrorIs asserts that at least one of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func (a *Assertions) ErrorIs(err error, target error, msgAndArgs ...interface{}) {