mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 01:09:57 +00:00
Add db-connect, a SQL over HTTPS server
This commit is contained in:
145
dbconnect/client.go
Normal file
145
dbconnect/client.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package dbconnect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// Client is an interface to talk to any database.
|
||||
//
|
||||
// Currently, the only implementation is SQLClient, but its structure
|
||||
// should be designed to handle a MongoClient or RedisClient in the future.
|
||||
type Client interface {
|
||||
Ping(context.Context) error
|
||||
Submit(context.Context, *Command) (interface{}, error)
|
||||
}
|
||||
|
||||
// NewClient creates a database client based on its URL scheme.
|
||||
func NewClient(ctx context.Context, originURL *url.URL) (Client, error) {
|
||||
return NewSQLClient(ctx, originURL)
|
||||
}
|
||||
|
||||
// Command is a standard, non-vendor format for submitting database commands.
|
||||
//
|
||||
// When determining the scope of this struct, refer to the following litmus test:
|
||||
// Could this (roughly) conform to SQL, Document-based, and Key-value command formats?
|
||||
type Command struct {
|
||||
Statement string `json:"statement"`
|
||||
Arguments Arguments `json:"arguments,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Isolation string `json:"isolation,omitempty"`
|
||||
Timeout time.Duration `json:"timeout,omitempty"`
|
||||
}
|
||||
|
||||
// Validate enforces the contract of Command: non empty statement (both in length and logic),
|
||||
// lowercase mode and isolation, non-zero timeout, and valid Arguments.
|
||||
func (cmd *Command) Validate() error {
|
||||
if cmd.Statement == "" {
|
||||
return fmt.Errorf("cannot provide an empty statement")
|
||||
}
|
||||
|
||||
if strings.Map(func(char rune) rune {
|
||||
if char == ';' || unicode.IsSpace(char) {
|
||||
return -1
|
||||
}
|
||||
return char
|
||||
}, cmd.Statement) == "" {
|
||||
return fmt.Errorf("cannot provide a statement with no logic: '%s'", cmd.Statement)
|
||||
}
|
||||
|
||||
cmd.Mode = strings.ToLower(cmd.Mode)
|
||||
cmd.Isolation = strings.ToLower(cmd.Isolation)
|
||||
|
||||
if cmd.Timeout.Nanoseconds() <= 0 {
|
||||
cmd.Timeout = 24 * time.Hour
|
||||
}
|
||||
|
||||
return cmd.Arguments.Validate()
|
||||
}
|
||||
|
||||
// UnmarshalJSON converts a byte representation of JSON into a Command, which is also validated.
|
||||
func (cmd *Command) UnmarshalJSON(data []byte) error {
|
||||
// Alias is required to avoid infinite recursion from the default UnmarshalJSON.
|
||||
type Alias Command
|
||||
alias := &struct {
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(cmd),
|
||||
}
|
||||
|
||||
err := json.Unmarshal(data, &alias)
|
||||
if err == nil {
|
||||
err = cmd.Validate()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Arguments is a wrapper for either map-based or array-based Command arguments.
|
||||
//
|
||||
// Each field is mutually-exclusive and some Client implementations may not
|
||||
// support both fields (eg. MySQL does not accept named arguments).
|
||||
type Arguments struct {
|
||||
Named map[string]interface{}
|
||||
Positional []interface{}
|
||||
}
|
||||
|
||||
// Validate enforces the contract of Arguments: non nil, mutually exclusive, and no empty or reserved keys.
|
||||
func (args *Arguments) Validate() error {
|
||||
if args.Named == nil {
|
||||
args.Named = map[string]interface{}{}
|
||||
}
|
||||
if args.Positional == nil {
|
||||
args.Positional = []interface{}{}
|
||||
}
|
||||
|
||||
if len(args.Named) > 0 && len(args.Positional) > 0 {
|
||||
return fmt.Errorf("both named and positional arguments cannot be specified: %+v and %+v", args.Named, args.Positional)
|
||||
}
|
||||
|
||||
for key := range args.Named {
|
||||
if key == "" {
|
||||
return fmt.Errorf("named arguments cannot contain an empty key: %+v", args.Named)
|
||||
}
|
||||
if !utf8.ValidString(key) {
|
||||
return fmt.Errorf("named argument does not conform to UTF-8 encoding: %s", key)
|
||||
}
|
||||
if strings.HasPrefix(key, "_") {
|
||||
return fmt.Errorf("named argument cannot start with a reserved keyword '_': %s", key)
|
||||
}
|
||||
if unicode.IsNumber([]rune(key)[0]) {
|
||||
return fmt.Errorf("named argument cannot start with a number: %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON converts a byte representation of JSON into Arguments, which is also validated.
|
||||
func (args *Arguments) UnmarshalJSON(data []byte) error {
|
||||
var obj interface{}
|
||||
err := json.Unmarshal(data, &obj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
named, ok := obj.(map[string]interface{})
|
||||
if ok {
|
||||
args.Named = named
|
||||
} else {
|
||||
positional, ok := obj.([]interface{})
|
||||
if ok {
|
||||
args.Positional = positional
|
||||
} else {
|
||||
return fmt.Errorf("arguments must either be an object {\"0\":\"val\"} or an array [\"val\"]: %s", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
return args.Validate()
|
||||
}
|
183
dbconnect/client_test.go
Normal file
183
dbconnect/client_test.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package dbconnect
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCommandValidateEmpty(t *testing.T) {
|
||||
stmts := []string{
|
||||
"",
|
||||
";",
|
||||
" \n\t",
|
||||
";\n;\t;",
|
||||
}
|
||||
|
||||
for _, stmt := range stmts {
|
||||
cmd := Command{Statement: stmt}
|
||||
|
||||
assert.Error(t, cmd.Validate(), stmt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommandValidateMode(t *testing.T) {
|
||||
modes := []string{
|
||||
"",
|
||||
"query",
|
||||
"ExEc",
|
||||
"PREPARE",
|
||||
}
|
||||
|
||||
for _, mode := range modes {
|
||||
cmd := Command{Statement: "Ok", Mode: mode}
|
||||
|
||||
assert.NoError(t, cmd.Validate(), mode)
|
||||
assert.Equal(t, strings.ToLower(mode), cmd.Mode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommandValidateIsolation(t *testing.T) {
|
||||
isos := []string{
|
||||
"",
|
||||
"default",
|
||||
"read_committed",
|
||||
"SNAPshot",
|
||||
}
|
||||
|
||||
for _, iso := range isos {
|
||||
cmd := Command{Statement: "Ok", Isolation: iso}
|
||||
|
||||
assert.NoError(t, cmd.Validate(), iso)
|
||||
assert.Equal(t, strings.ToLower(iso), cmd.Isolation)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommandValidateTimeout(t *testing.T) {
|
||||
cmd := Command{Statement: "Ok", Timeout: 0}
|
||||
|
||||
assert.NoError(t, cmd.Validate())
|
||||
assert.NotZero(t, cmd.Timeout)
|
||||
|
||||
cmd = Command{Statement: "Ok", Timeout: 1 * time.Second}
|
||||
|
||||
assert.NoError(t, cmd.Validate())
|
||||
assert.Equal(t, 1*time.Second, cmd.Timeout)
|
||||
}
|
||||
|
||||
func TestCommandValidateArguments(t *testing.T) {
|
||||
cmd := Command{Statement: "Ok", Arguments: Arguments{
|
||||
Named: map[string]interface{}{"key": "val"},
|
||||
Positional: []interface{}{"val"},
|
||||
}}
|
||||
|
||||
assert.Error(t, cmd.Validate())
|
||||
}
|
||||
|
||||
func TestCommandUnmarshalJSON(t *testing.T) {
|
||||
strs := []string{
|
||||
"{\"statement\":\"Ok\"}",
|
||||
"{\"statement\":\"Ok\",\"arguments\":[0, 3.14, \"apple\"],\"mode\":\"query\"}",
|
||||
"{\"statement\":\"Ok\",\"isolation\":\"read_uncommitted\",\"timeout\":1000}",
|
||||
}
|
||||
|
||||
for _, str := range strs {
|
||||
var cmd Command
|
||||
assert.NoError(t, json.Unmarshal([]byte(str), &cmd), str)
|
||||
}
|
||||
|
||||
strs = []string{
|
||||
"",
|
||||
"\"",
|
||||
"{}",
|
||||
"{\"argument\":{\"key\":\"val\"}}",
|
||||
"{\"statement\":[\"Ok\"]}",
|
||||
}
|
||||
|
||||
for _, str := range strs {
|
||||
var cmd Command
|
||||
assert.Error(t, json.Unmarshal([]byte(str), &cmd), str)
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgumentsValidateNotNil(t *testing.T) {
|
||||
args := Arguments{}
|
||||
|
||||
assert.NoError(t, args.Validate())
|
||||
assert.NotNil(t, args.Named)
|
||||
assert.NotNil(t, args.Positional)
|
||||
}
|
||||
|
||||
func TestArgumentsValidateMutuallyExclusive(t *testing.T) {
|
||||
args := []Arguments{
|
||||
Arguments{},
|
||||
Arguments{Named: map[string]interface{}{"key": "val"}},
|
||||
Arguments{Positional: []interface{}{"val"}},
|
||||
}
|
||||
|
||||
for _, arg := range args {
|
||||
assert.NoError(t, arg.Validate())
|
||||
assert.False(t, len(arg.Named) > 0 && len(arg.Positional) > 0)
|
||||
}
|
||||
|
||||
args = []Arguments{
|
||||
Arguments{
|
||||
Named: map[string]interface{}{"key": "val"},
|
||||
Positional: []interface{}{"val"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, arg := range args {
|
||||
assert.Error(t, arg.Validate())
|
||||
assert.True(t, len(arg.Named) > 0 && len(arg.Positional) > 0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgumentsValidateKeys(t *testing.T) {
|
||||
keys := []string{
|
||||
"",
|
||||
"_",
|
||||
"_key",
|
||||
"1",
|
||||
"1key",
|
||||
"\xf0\x28\x8c\xbc", // non-utf8
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
args := Arguments{Named: map[string]interface{}{key: "val"}}
|
||||
|
||||
assert.Error(t, args.Validate(), key)
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgumentsUnmarshalJSON(t *testing.T) {
|
||||
strs := []string{
|
||||
"{}",
|
||||
"{\"key\":\"val\"}",
|
||||
"{\"key\":[1, 3.14, {\"key\":\"val\"}]}",
|
||||
"[]",
|
||||
"[\"key\",\"val\"]",
|
||||
"[{}]",
|
||||
}
|
||||
|
||||
for _, str := range strs {
|
||||
var args Arguments
|
||||
assert.NoError(t, json.Unmarshal([]byte(str), &args), str)
|
||||
}
|
||||
|
||||
strs = []string{
|
||||
"",
|
||||
"\"",
|
||||
"1",
|
||||
"\"key\"",
|
||||
"{\"key\",\"val\"}",
|
||||
}
|
||||
|
||||
for _, str := range strs {
|
||||
var args Arguments
|
||||
assert.Error(t, json.Unmarshal([]byte(str), &args), str)
|
||||
}
|
||||
}
|
157
dbconnect/cmd.go
Normal file
157
dbconnect/cmd.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package dbconnect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"gopkg.in/urfave/cli.v2"
|
||||
"gopkg.in/urfave/cli.v2/altsrc"
|
||||
)
|
||||
|
||||
// Cmd is the entrypoint command for dbconnect.
|
||||
//
|
||||
// The tunnel package is responsible for appending this to tunnel.Commands().
|
||||
func Cmd() *cli.Command {
|
||||
return &cli.Command{
|
||||
Category: "Database Connect (ALPHA)",
|
||||
Name: "db-connect",
|
||||
Usage: "Access your SQL database from Cloudflare Workers or the browser",
|
||||
ArgsUsage: " ",
|
||||
Description: `
|
||||
Creates a connection between your database and the Cloudflare edge.
|
||||
Now you can execute SQL commands anywhere you can send HTTPS requests.
|
||||
|
||||
Connect your database with any of the following commands, you can also try the "playground" without a database:
|
||||
|
||||
cloudflared db-connect --hostname sql.mysite.com --url postgres://user:pass@localhost?sslmode=disable \
|
||||
--auth-domain mysite.cloudflareaccess.com --application-aud my-access-policy-tag
|
||||
cloudflared db-connect --hostname sql-dev.mysite.com --url mysql://localhost --insecure
|
||||
cloudflared db-connect --playground
|
||||
|
||||
Requests should be authenticated using Cloudflare Access, learn more about how to enable it here:
|
||||
|
||||
https://developers.cloudflare.com/access/service-auth/service-token/
|
||||
`,
|
||||
Flags: []cli.Flag{
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "url",
|
||||
Usage: "URL to the database (eg. postgres://user:pass@localhost?sslmode=disable)",
|
||||
EnvVars: []string{"TUNNEL_URL"},
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "hostname",
|
||||
Usage: "Hostname to accept commands over HTTPS (eg. sql.mysite.com)",
|
||||
EnvVars: []string{"TUNNEL_HOSTNAME"},
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "auth-domain",
|
||||
Usage: "Cloudflare Access authentication domain for your account (eg. mysite.cloudflareaccess.com)",
|
||||
EnvVars: []string{"TUNNEL_ACCESS_AUTH_DOMAIN"},
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "application-aud",
|
||||
Usage: "Cloudflare Access application \"AUD\" to verify JWTs from requests",
|
||||
EnvVars: []string{"TUNNEL_ACCESS_APPLICATION_AUD"},
|
||||
}),
|
||||
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||
Name: "insecure",
|
||||
Usage: "Disable authentication, the database will be open to the Internet",
|
||||
Value: false,
|
||||
EnvVars: []string{"TUNNEL_ACCESS_INSECURE"},
|
||||
}),
|
||||
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||
Name: "playground",
|
||||
Usage: "Run a temporary, in-memory SQLite3 database for testing",
|
||||
Value: false,
|
||||
EnvVars: []string{"TUNNEL_HELLO_WORLD"},
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "loglevel",
|
||||
Value: "debug", // Make it more verbose than the tunnel default 'info'.
|
||||
EnvVars: []string{"TUNNEL_LOGLEVEL"},
|
||||
Hidden: true,
|
||||
}),
|
||||
},
|
||||
Before: CmdBefore,
|
||||
Action: CmdAction,
|
||||
Hidden: true,
|
||||
}
|
||||
}
|
||||
|
||||
// CmdBefore runs some validation checks before running the command.
|
||||
func CmdBefore(c *cli.Context) error {
|
||||
// Show the help text is no flags are specified.
|
||||
if c.NumFlags() == 0 {
|
||||
return cli.ShowSubcommandHelp(c)
|
||||
}
|
||||
|
||||
// Hello-world and playground are synonymous with each other,
|
||||
// unset hello-world to prevent tunnel from initializing the hello package.
|
||||
if c.IsSet("hello-world") {
|
||||
c.Set("playground", "true")
|
||||
c.Set("hello-world", "false")
|
||||
}
|
||||
|
||||
// Unix-socket database urls are supported, but the logic is the same as url.
|
||||
if c.IsSet("unix-socket") {
|
||||
c.Set("url", c.String("unix-socket"))
|
||||
c.Set("unix-socket", "")
|
||||
}
|
||||
|
||||
// When playground mode is enabled, run with an in-memory database.
|
||||
if c.IsSet("playground") {
|
||||
c.Set("url", "sqlite3::memory:?cache=shared")
|
||||
c.Set("insecure", strconv.FormatBool(!c.IsSet("auth-domain") && !c.IsSet("application-aud")))
|
||||
}
|
||||
|
||||
// At this point, insecure configurations are valid.
|
||||
if c.Bool("insecure") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure that secure configurations specify a hostname, domain, and tag for JWT validation.
|
||||
if !c.IsSet("hostname") || !c.IsSet("auth-domain") || !c.IsSet("application-aud") {
|
||||
log.Fatal("must specify --hostname, --auth-domain, and --application-aud unless you want to run in --insecure mode")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CmdAction starts the Proxy and sets the url in cli.Context to point to the Proxy address.
|
||||
func CmdAction(c *cli.Context) error {
|
||||
// STOR-612: sync with context in tunnel daemon.
|
||||
ctx := context.Background()
|
||||
|
||||
var proxy *Proxy
|
||||
var err error
|
||||
if c.Bool("insecure") {
|
||||
proxy, err = NewInsecureProxy(ctx, c.String("url"))
|
||||
} else {
|
||||
proxy, err = NewSecureProxy(ctx, c.String("url"), c.String("auth-domain"), c.String("application-aud"))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
return err
|
||||
}
|
||||
|
||||
listenerC := make(chan net.Listener)
|
||||
defer close(listenerC)
|
||||
|
||||
// Since the Proxy should only talk to the tunnel daemon, find the next available
|
||||
// localhost port and start to listen to requests.
|
||||
go func() {
|
||||
err := proxy.Start(ctx, "127.0.0.1:", listenerC)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Block until the the Proxy is online, retreive its address, and change the url to point to it.
|
||||
// This is effectively "handing over" control to the tunnel package so it can run the tunnel daemon.
|
||||
c.Set("url", "https://"+(<-listenerC).Addr().String())
|
||||
|
||||
return nil
|
||||
}
|
27
dbconnect/cmd_test.go
Normal file
27
dbconnect/cmd_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package dbconnect
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"gopkg.in/urfave/cli.v2"
|
||||
)
|
||||
|
||||
func TestCmd(t *testing.T) {
|
||||
tests := [][]string{
|
||||
{"cloudflared", "db-connect", "--playground"},
|
||||
{"cloudflared", "db-connect", "--playground", "--hostname", "sql.mysite.com"},
|
||||
{"cloudflared", "db-connect", "--url", "sqlite3::memory:?cache=shared", "--insecure"},
|
||||
{"cloudflared", "db-connect", "--url", "sqlite3::memory:?cache=shared", "--hostname", "sql.mysite.com", "--auth-domain", "mysite.cloudflareaccess.com", "--application-aud", "aud"},
|
||||
}
|
||||
|
||||
app := &cli.App{
|
||||
Name: "cloudflared",
|
||||
Commands: []*cli.Command{Cmd()},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
assert.NoError(t, app.Run(test))
|
||||
}
|
||||
}
|
271
dbconnect/proxy.go
Normal file
271
dbconnect/proxy.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package dbconnect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/hello"
|
||||
"github.com/cloudflare/cloudflared/validation"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Proxy is an HTTP server that proxies requests to a Client.
|
||||
type Proxy struct {
|
||||
client Client
|
||||
accessValidator *validation.Access
|
||||
logger *logrus.Logger
|
||||
}
|
||||
|
||||
// NewInsecureProxy creates a Proxy that talks to a Client at an origin.
|
||||
//
|
||||
// In insecure mode, the Proxy will allow all Command requests.
|
||||
func NewInsecureProxy(ctx context.Context, origin string) (*Proxy, error) {
|
||||
originURL, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "must provide a valid database url")
|
||||
}
|
||||
|
||||
client, err := NewClient(ctx, originURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = client.Ping(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not connect to the database")
|
||||
}
|
||||
|
||||
return &Proxy{client, nil, logrus.New()}, nil
|
||||
}
|
||||
|
||||
// NewSecureProxy creates a Proxy that talks to a Client at an origin.
|
||||
//
|
||||
// In secure mode, the Proxy will reject any Command requests that are
|
||||
// not authenticated by Cloudflare Access with a valid JWT.
|
||||
func NewSecureProxy(ctx context.Context, origin, authDomain, applicationAUD string) (*Proxy, error) {
|
||||
proxy, err := NewInsecureProxy(ctx, origin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
validator, err := validation.NewAccessValidator(ctx, authDomain, authDomain, applicationAUD)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
proxy.accessValidator = validator
|
||||
|
||||
return proxy, err
|
||||
}
|
||||
|
||||
// IsInsecure gets whether the Proxy will accept a Command from any source.
|
||||
func (proxy *Proxy) IsInsecure() bool {
|
||||
return proxy.accessValidator == nil
|
||||
}
|
||||
|
||||
// IsAllowed checks whether a http.Request is allowed to receive data.
|
||||
//
|
||||
// By default, requests must pass through Cloudflare Access for authentication.
|
||||
// If the proxy is explcitly set to insecure mode, all requests will be allowed.
|
||||
func (proxy *Proxy) IsAllowed(r *http.Request, verbose ...bool) bool {
|
||||
if proxy.IsInsecure() {
|
||||
return true
|
||||
}
|
||||
|
||||
// Access and Tunnel should prevent bad JWTs from even reaching the origin,
|
||||
// but validate tokens anyway as an abundance of caution.
|
||||
err := proxy.accessValidator.ValidateRequest(r.Context(), r)
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
// Warn administrators that invalid JWTs are being rejected. This is indicative
|
||||
// of either a misconfiguration of the CLI or a massive failure of upstream systems.
|
||||
if len(verbose) > 0 {
|
||||
proxy.httpLog(r, err).Error("Failed JWT authentication")
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Start the Proxy at a given address and notify the listener channel when the server is online.
|
||||
func (proxy *Proxy) Start(ctx context.Context, addr string, listenerC chan<- net.Listener) error {
|
||||
// STOR-611: use a seperate listener and consider web socket support.
|
||||
httpListener, err := hello.CreateTLSListener(addr)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "could not create listener at %s", addr)
|
||||
}
|
||||
|
||||
errC := make(chan error)
|
||||
defer close(errC)
|
||||
|
||||
// Starts the HTTP server and begins to serve requests.
|
||||
go func() {
|
||||
errC <- proxy.httpListen(ctx, httpListener)
|
||||
}()
|
||||
|
||||
// Continually ping the server until it comes online or 10 attempts fail.
|
||||
go func() {
|
||||
var err error
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err = http.Get("http://" + httpListener.Addr().String())
|
||||
|
||||
// Once no error was detected, notify the listener channel and return.
|
||||
if err == nil {
|
||||
listenerC <- httpListener
|
||||
return
|
||||
}
|
||||
|
||||
// Backoff between requests to ping the server.
|
||||
<-time.After(1 * time.Second)
|
||||
}
|
||||
errC <- errors.Wrap(err, "took too long for the http server to start")
|
||||
}()
|
||||
|
||||
return <-errC
|
||||
}
|
||||
|
||||
// httpListen starts the httpServer and blocks until the context closes.
|
||||
func (proxy *Proxy) httpListen(ctx context.Context, listener net.Listener) error {
|
||||
httpServer := &http.Server{
|
||||
Addr: listener.Addr().String(),
|
||||
Handler: proxy.httpRouter(),
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 60 * time.Second,
|
||||
IdleTimeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
httpServer.Close()
|
||||
listener.Close()
|
||||
}()
|
||||
|
||||
return httpServer.Serve(listener)
|
||||
}
|
||||
|
||||
// httpRouter creates a mux.Router for the Proxy.
|
||||
func (proxy *Proxy) httpRouter() *mux.Router {
|
||||
router := mux.NewRouter()
|
||||
|
||||
router.HandleFunc("/ping", proxy.httpPing()).Methods("GET", "HEAD")
|
||||
router.HandleFunc("/submit", proxy.httpSubmit()).Methods("POST")
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
// httpPing tests the connection to the database.
|
||||
//
|
||||
// By default, this endpoint is unauthenticated to allow for health checks.
|
||||
// To enable authentication, Cloudflare Access must be enabled on this route.
|
||||
func (proxy *Proxy) httpPing() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
err := proxy.client.Ping(ctx)
|
||||
|
||||
if err == nil {
|
||||
proxy.httpRespond(w, r, http.StatusOK, "")
|
||||
} else {
|
||||
proxy.httpRespondErr(w, r, http.StatusInternalServerError, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// httpSubmit sends a command to the database and returns its response.
|
||||
//
|
||||
// By default, this endpoint will reject requests that do not pass through Cloudflare Access.
|
||||
// To disable authentication, the --insecure flag must be specified in the command line.
|
||||
func (proxy *Proxy) httpSubmit() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if !proxy.IsAllowed(r, true) {
|
||||
proxy.httpRespondErr(w, r, http.StatusForbidden, fmt.Errorf(""))
|
||||
return
|
||||
}
|
||||
|
||||
var cmd Command
|
||||
err := json.NewDecoder(r.Body).Decode(&cmd)
|
||||
if err != nil {
|
||||
proxy.httpRespondErr(w, r, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
data, err := proxy.client.Submit(ctx, &cmd)
|
||||
|
||||
if err != nil {
|
||||
proxy.httpRespondErr(w, r, http.StatusUnprocessableEntity, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(data)
|
||||
if err != nil {
|
||||
proxy.httpRespondErr(w, r, http.StatusInternalServerError, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// httpRespond writes a status code and string response to the response writer.
|
||||
func (proxy *Proxy) httpRespond(w http.ResponseWriter, r *http.Request, status int, message string) {
|
||||
w.WriteHeader(status)
|
||||
|
||||
// Only expose the message detail of the reponse if the request is not HEAD
|
||||
// and the user is authenticated. For example, this prevents an unauthenticated
|
||||
// failed health check from accidentally leaking sensitive information about the Client.
|
||||
if r.Method != http.MethodHead && proxy.IsAllowed(r) {
|
||||
if message == "" {
|
||||
message = http.StatusText(status)
|
||||
}
|
||||
fmt.Fprint(w, message)
|
||||
}
|
||||
}
|
||||
|
||||
// httpRespondErr is similar to httpRespond, except it formats errors to be more friendly.
|
||||
func (proxy *Proxy) httpRespondErr(w http.ResponseWriter, r *http.Request, defaultStatus int, err error) {
|
||||
status, err := httpError(defaultStatus, err)
|
||||
|
||||
proxy.httpRespond(w, r, status, err.Error())
|
||||
if len(err.Error()) > 0 {
|
||||
proxy.httpLog(r, err).Warn("Database proxy error")
|
||||
}
|
||||
}
|
||||
|
||||
// httpLog returns a logrus.Entry that is formatted to output a request Cf-ray.
|
||||
func (proxy *Proxy) httpLog(r *http.Request, err error) *logrus.Entry {
|
||||
return proxy.logger.WithContext(r.Context()).WithField("CF-RAY", r.Header.Get("Cf-ray")).WithError(err)
|
||||
}
|
||||
|
||||
// httpError extracts common errors and returns an status code and friendly error.
|
||||
func httpError(defaultStatus int, err error) (int, error) {
|
||||
if err == nil {
|
||||
return http.StatusNotImplemented, fmt.Errorf("error expected but found none")
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
return http.StatusBadRequest, fmt.Errorf("request body cannot be empty")
|
||||
}
|
||||
|
||||
if err == context.DeadlineExceeded {
|
||||
return http.StatusRequestTimeout, err
|
||||
}
|
||||
|
||||
_, ok := err.(net.Error)
|
||||
if ok {
|
||||
return http.StatusRequestTimeout, err
|
||||
}
|
||||
|
||||
if err == context.Canceled {
|
||||
// Does not exist in Golang, but would be: http.StatusClientClosedWithoutResponse
|
||||
return 444, err
|
||||
}
|
||||
|
||||
return defaultStatus, err
|
||||
}
|
238
dbconnect/proxy_test.go
Normal file
238
dbconnect/proxy_test.go
Normal file
@@ -0,0 +1,238 @@
|
||||
package dbconnect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewInsecureProxy(t *testing.T) {
|
||||
origins := []string{
|
||||
"",
|
||||
":/",
|
||||
"http://localhost",
|
||||
"tcp://localhost:9000?debug=true",
|
||||
"mongodb://127.0.0.1",
|
||||
}
|
||||
|
||||
for _, origin := range origins {
|
||||
proxy, err := NewInsecureProxy(context.Background(), origin)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, proxy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyIsAllowed(t *testing.T) {
|
||||
proxy := helperNewProxy(t)
|
||||
req := httptest.NewRequest("GET", "https://1.1.1.1/ping", nil)
|
||||
assert.True(t, proxy.IsAllowed(req))
|
||||
|
||||
proxy = helperNewProxy(t, true)
|
||||
req.Header.Set("Cf-access-jwt-assertion", "xxx")
|
||||
assert.False(t, proxy.IsAllowed(req))
|
||||
}
|
||||
|
||||
func TestProxyStart(t *testing.T) {
|
||||
proxy := helperNewProxy(t)
|
||||
ctx := context.Background()
|
||||
listenerC := make(chan net.Listener)
|
||||
|
||||
err := proxy.Start(ctx, "1.1.1.1:", listenerC)
|
||||
assert.Error(t, err)
|
||||
|
||||
err = proxy.Start(ctx, "127.0.0.1:-1", listenerC)
|
||||
assert.Error(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 0)
|
||||
defer cancel()
|
||||
|
||||
err = proxy.Start(ctx, "127.0.0.1:", listenerC)
|
||||
assert.IsType(t, http.ErrServerClosed, err)
|
||||
}
|
||||
|
||||
func TestProxyHTTPRouter(t *testing.T) {
|
||||
proxy := helperNewProxy(t)
|
||||
router := proxy.httpRouter()
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
method string
|
||||
valid bool
|
||||
}{
|
||||
{"", "GET", false},
|
||||
{"/", "GET", false},
|
||||
{"/ping", "GET", true},
|
||||
{"/ping", "HEAD", true},
|
||||
{"/ping", "POST", false},
|
||||
{"/submit", "POST", true},
|
||||
{"/submit", "GET", false},
|
||||
{"/submit/extra", "POST", false},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
match := &mux.RouteMatch{}
|
||||
ok := router.Match(httptest.NewRequest(test.method, "https://1.1.1.1"+test.path, nil), match)
|
||||
|
||||
assert.True(t, ok == test.valid, test.path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyHTTPPing(t *testing.T) {
|
||||
proxy := helperNewProxy(t)
|
||||
|
||||
server := httptest.NewServer(proxy.httpPing())
|
||||
defer server.Close()
|
||||
client := server.Client()
|
||||
|
||||
res, err := client.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
assert.Equal(t, int64(2), res.ContentLength)
|
||||
|
||||
res, err = client.Head(server.URL)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
assert.Equal(t, int64(-1), res.ContentLength)
|
||||
}
|
||||
|
||||
func TestProxyHTTPSubmit(t *testing.T) {
|
||||
proxy := helperNewProxy(t)
|
||||
|
||||
server := httptest.NewServer(proxy.httpSubmit())
|
||||
defer server.Close()
|
||||
client := server.Client()
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
status int
|
||||
output string
|
||||
}{
|
||||
{"", http.StatusBadRequest, "request body cannot be empty"},
|
||||
{"{}", http.StatusBadRequest, "cannot provide an empty statement"},
|
||||
{"{\"statement\":\"Ok\"}", http.StatusUnprocessableEntity, "cannot provide invalid sql mode: ''"},
|
||||
{"{\"statement\":\"Ok\",\"mode\":\"query\"}", http.StatusUnprocessableEntity, "near \"Ok\": syntax error"},
|
||||
{"{\"statement\":\"CREATE TABLE t (a INT);\",\"mode\":\"exec\"}", http.StatusOK, "{\"last_insert_id\":0,\"rows_affected\":0}\n"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
res, err := client.Post(server.URL, "application/json", strings.NewReader(test.input))
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, test.status, res.StatusCode)
|
||||
if res.StatusCode > http.StatusOK {
|
||||
assert.Equal(t, "text/plain; charset=utf-8", res.Header.Get("Content-type"))
|
||||
} else {
|
||||
assert.Equal(t, "application/json", res.Header.Get("Content-type"))
|
||||
}
|
||||
|
||||
data, err := ioutil.ReadAll(res.Body)
|
||||
defer res.Body.Close()
|
||||
str := string(data)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, test.output, str)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyHTTPSubmitForbidden(t *testing.T) {
|
||||
proxy := helperNewProxy(t, true)
|
||||
|
||||
server := httptest.NewServer(proxy.httpSubmit())
|
||||
defer server.Close()
|
||||
client := server.Client()
|
||||
|
||||
res, err := client.Get(server.URL)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusForbidden, res.StatusCode)
|
||||
assert.Zero(t, res.ContentLength)
|
||||
}
|
||||
|
||||
func TestProxyHTTPRespond(t *testing.T) {
|
||||
proxy := helperNewProxy(t)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
proxy.httpRespond(w, r, http.StatusAccepted, "Hello")
|
||||
}))
|
||||
defer server.Close()
|
||||
client := server.Client()
|
||||
|
||||
res, err := client.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusAccepted, res.StatusCode)
|
||||
assert.Equal(t, int64(5), res.ContentLength)
|
||||
|
||||
data, err := ioutil.ReadAll(res.Body)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, []byte("Hello"), data)
|
||||
}
|
||||
|
||||
func TestProxyHTTPRespondForbidden(t *testing.T) {
|
||||
proxy := helperNewProxy(t, true)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
proxy.httpRespond(w, r, http.StatusAccepted, "Hello")
|
||||
}))
|
||||
defer server.Close()
|
||||
client := server.Client()
|
||||
|
||||
res, err := client.Get(server.URL)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusAccepted, res.StatusCode)
|
||||
assert.Equal(t, int64(0), res.ContentLength)
|
||||
}
|
||||
|
||||
func TestHTTPError(t *testing.T) {
|
||||
_, errTimeout := net.DialTimeout("tcp", "127.0.0.1", 0)
|
||||
assert.Error(t, errTimeout)
|
||||
|
||||
tests := []struct {
|
||||
input error
|
||||
status int
|
||||
output error
|
||||
}{
|
||||
{nil, http.StatusNotImplemented, fmt.Errorf("error expected but found none")},
|
||||
{io.EOF, http.StatusBadRequest, fmt.Errorf("request body cannot be empty")},
|
||||
{context.DeadlineExceeded, http.StatusRequestTimeout, nil},
|
||||
{context.Canceled, 444, nil},
|
||||
{errTimeout, http.StatusRequestTimeout, nil},
|
||||
{fmt.Errorf(""), http.StatusInternalServerError, nil},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
status, err := httpError(http.StatusInternalServerError, test.input)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, test.status, status)
|
||||
if test.output == nil {
|
||||
test.output = test.input
|
||||
}
|
||||
assert.Equal(t, test.output, err)
|
||||
}
|
||||
}
|
||||
|
||||
func helperNewProxy(t *testing.T, secure ...bool) *Proxy {
|
||||
t.Helper()
|
||||
|
||||
proxy, err := NewSecureProxy(context.Background(), "file::memory:?cache=shared", "test.cloudflareaccess.com", "")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, proxy)
|
||||
|
||||
if len(secure) == 0 {
|
||||
proxy.accessValidator = nil // Mark as insecure
|
||||
}
|
||||
|
||||
return proxy
|
||||
}
|
318
dbconnect/sql.go
Normal file
318
dbconnect/sql.go
Normal file
@@ -0,0 +1,318 @@
|
||||
package dbconnect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/xo/dburl"
|
||||
|
||||
// SQL drivers self-register with the database/sql package.
|
||||
// https://github.com/golang/go/wiki/SQLDrivers
|
||||
_ "github.com/denisenkom/go-mssqldb"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
|
||||
"github.com/kshvakov/clickhouse"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
// SQLClient is a Client that talks to a SQL database.
|
||||
type SQLClient struct {
|
||||
Dialect string
|
||||
driver *sqlx.DB
|
||||
}
|
||||
|
||||
// NewSQLClient creates a SQL client based on its URL scheme.
|
||||
func NewSQLClient(ctx context.Context, originURL *url.URL) (Client, error) {
|
||||
res, err := dburl.Parse(originURL.String())
|
||||
if err != nil {
|
||||
helpText := fmt.Sprintf("supported drivers: %+q, see documentation for more details: %s", sql.Drivers(), "https://godoc.org/github.com/xo/dburl")
|
||||
return nil, fmt.Errorf("could not parse sql database url '%s': %s\n%s", originURL, err.Error(), helpText)
|
||||
}
|
||||
|
||||
// Establishes the driver, but does not test the connection.
|
||||
driver, err := sqlx.Open(res.Driver, res.DSN)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not open sql driver %s: %s\n%s", res.Driver, err.Error(), res.DSN)
|
||||
}
|
||||
|
||||
// Closes the driver, will occur when the context finishes.
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
driver.Close()
|
||||
}()
|
||||
|
||||
return &SQLClient{driver.DriverName(), driver}, nil
|
||||
}
|
||||
|
||||
// Ping verifies a connection to the database is still alive.
|
||||
func (client *SQLClient) Ping(ctx context.Context) error {
|
||||
return client.driver.PingContext(ctx)
|
||||
}
|
||||
|
||||
// Submit queries or executes a command to the SQL database.
|
||||
func (client *SQLClient) Submit(ctx context.Context, cmd *Command) (interface{}, error) {
|
||||
txx, err := cmd.ValidateSQL(client.Dialect)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, cmd.Timeout)
|
||||
defer cancel()
|
||||
|
||||
var res interface{}
|
||||
|
||||
// Get the next available sql.Conn and submit the Command.
|
||||
err = sqlConn(ctx, client.driver, txx, func(conn *sql.Conn) error {
|
||||
stmt := cmd.Statement
|
||||
args := cmd.Arguments.Positional
|
||||
|
||||
if cmd.Mode == "query" {
|
||||
res, err = sqlQuery(ctx, conn, stmt, args)
|
||||
} else {
|
||||
res, err = sqlExec(ctx, conn, stmt, args)
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
||||
// ValidateSQL extends the contract of Command for SQL dialects:
|
||||
// mode is conformed, arguments are []sql.NamedArg, and isolation is a sql.IsolationLevel.
|
||||
//
|
||||
// When the command should not be wrapped in a transaction, *sql.TxOptions and error will both be nil.
|
||||
func (cmd *Command) ValidateSQL(dialect string) (*sql.TxOptions, error) {
|
||||
err := cmd.Validate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mode, err := sqlMode(cmd.Mode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Mutates Arguments to only use positional arguments with the type sql.NamedArg.
|
||||
// This is a required by the sql.Driver before submitting arguments.
|
||||
cmd.Arguments.sql(dialect)
|
||||
|
||||
iso, err := sqlIsolation(cmd.Isolation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// When isolation is out-of-range, this is indicative that no
|
||||
// transaction should be executed and sql.TxOptions should be nil.
|
||||
if iso < sql.LevelDefault {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// In query mode, execute the transaction in read-only, unless it's Microsoft SQL
|
||||
// which does not support that type of transaction.
|
||||
readOnly := mode == "query" && dialect != "mssql"
|
||||
|
||||
return &sql.TxOptions{Isolation: iso, ReadOnly: readOnly}, nil
|
||||
}
|
||||
|
||||
// sqlConn gets the next available sql.Conn in the connection pool and runs a function to use it.
|
||||
//
|
||||
// If the transaction options are nil, run the useIt function outside a transaction.
|
||||
// This is potentially an unsafe operation if the command does not clean up its state.
|
||||
func sqlConn(ctx context.Context, driver *sqlx.DB, txx *sql.TxOptions, useIt func(*sql.Conn) error) error {
|
||||
conn, err := driver.Conn(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// If transaction options are specified, begin and defer a rollback to catch errors.
|
||||
var tx *sql.Tx
|
||||
if txx != nil {
|
||||
tx, err = conn.BeginTx(ctx, txx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
}
|
||||
|
||||
err = useIt(conn)
|
||||
|
||||
// Check if useIt was successful and a transaction exists before committing.
|
||||
if err == nil && tx != nil {
|
||||
err = tx.Commit()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// sqlQuery queries rows on a sql.Conn and returns an array of result objects.
|
||||
func sqlQuery(ctx context.Context, conn *sql.Conn, stmt string, args []interface{}) ([]map[string]interface{}, error) {
|
||||
rows, err := conn.QueryContext(ctx, stmt, args...)
|
||||
if err == nil {
|
||||
return sqlRows(rows)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// sqlExec executes a command on a sql.Conn and returns the result of the operation.
|
||||
func sqlExec(ctx context.Context, conn *sql.Conn, stmt string, args []interface{}) (sqlResult, error) {
|
||||
exec, err := conn.ExecContext(ctx, stmt, args...)
|
||||
if err == nil {
|
||||
return sqlResultFrom(exec), nil
|
||||
}
|
||||
return sqlResult{}, err
|
||||
}
|
||||
|
||||
// sql mutates Arguments to contain a positional []sql.NamedArg.
|
||||
//
|
||||
// The actual return type is []interface{} due to the native Golang
|
||||
// function signatures for sql.Exec and sql.Query being generic.
|
||||
func (args *Arguments) sql(dialect string) {
|
||||
result := args.Positional
|
||||
|
||||
for i, val := range result {
|
||||
result[i] = sqlArg("", val, dialect)
|
||||
}
|
||||
|
||||
for key, val := range args.Named {
|
||||
result = append(result, sqlArg(key, val, dialect))
|
||||
}
|
||||
|
||||
args.Positional = result
|
||||
args.Named = map[string]interface{}{}
|
||||
}
|
||||
|
||||
// sqlArg creates a sql.NamedArg from a key-value pair and an optional dialect.
|
||||
//
|
||||
// Certain dialects will need to wrap objects, such as arrays, to conform its driver requirements.
|
||||
func sqlArg(key, val interface{}, dialect string) sql.NamedArg {
|
||||
switch reflect.ValueOf(val).Kind() {
|
||||
|
||||
// PostgreSQL and Clickhouse require arrays to be wrapped before
|
||||
// being inserted into the driver interface.
|
||||
case reflect.Slice, reflect.Array:
|
||||
switch dialect {
|
||||
case "postgres":
|
||||
val = pq.Array(val)
|
||||
case "clickhouse":
|
||||
val = clickhouse.Array(val)
|
||||
}
|
||||
}
|
||||
|
||||
return sql.Named(fmt.Sprint(key), val)
|
||||
}
|
||||
|
||||
// sqlIsolation tries to match a string to a sql.IsolationLevel.
|
||||
func sqlIsolation(str string) (sql.IsolationLevel, error) {
|
||||
if str == "none" {
|
||||
return sql.IsolationLevel(-1), nil
|
||||
}
|
||||
|
||||
for iso := sql.LevelDefault; ; iso++ {
|
||||
if iso > sql.LevelLinearizable {
|
||||
return -1, fmt.Errorf("cannot provide an invalid sql isolation level: '%s'", str)
|
||||
}
|
||||
|
||||
if str == "" || strings.EqualFold(iso.String(), strings.ReplaceAll(str, "_", " ")) {
|
||||
return iso, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sqlMode tries to match a string to a command mode: 'query' or 'exec' for now.
|
||||
func sqlMode(str string) (string, error) {
|
||||
switch str {
|
||||
case "query", "exec":
|
||||
return str, nil
|
||||
default:
|
||||
return "", fmt.Errorf("cannot provide invalid sql mode: '%s'", str)
|
||||
}
|
||||
}
|
||||
|
||||
// sqlRows scans through a SQL result set and returns an array of objects.
|
||||
func sqlRows(rows *sql.Rows) ([]map[string]interface{}, error) {
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not extract columns from result")
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
types, err := rows.ColumnTypes()
|
||||
if err != nil {
|
||||
// Some drivers do not support type extraction, so fail silently and continue.
|
||||
types = make([]*sql.ColumnType, len(columns))
|
||||
}
|
||||
|
||||
values := make([]interface{}, len(columns))
|
||||
pointers := make([]interface{}, len(columns))
|
||||
|
||||
var results []map[string]interface{}
|
||||
for rows.Next() {
|
||||
for i := range columns {
|
||||
pointers[i] = &values[i]
|
||||
}
|
||||
rows.Scan(pointers...)
|
||||
|
||||
// Convert a row, an array of values, into an object where
|
||||
// each key is the name of its respective column.
|
||||
entry := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
entry[col] = sqlValue(values[i], types[i])
|
||||
}
|
||||
results = append(results, entry)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// sqlValue handles special cases where sql.Rows does not return a "human-readable" object.
|
||||
func sqlValue(val interface{}, col *sql.ColumnType) interface{} {
|
||||
bytes, ok := val.([]byte)
|
||||
if ok {
|
||||
// Opportunistically check for embeded JSON and convert it to a first-class object.
|
||||
var embeded interface{}
|
||||
if json.Unmarshal(bytes, &embeded) == nil {
|
||||
return embeded
|
||||
}
|
||||
|
||||
// STOR-604: investigate a way to coerce PostgreSQL arrays '{a, b, ...}' into JSON.
|
||||
// Although easy with strings, it becomes more difficult with special types like INET[].
|
||||
|
||||
return string(bytes)
|
||||
}
|
||||
|
||||
return val
|
||||
}
|
||||
|
||||
// sqlResult is a thin wrapper around sql.Result.
|
||||
type sqlResult struct {
|
||||
LastInsertId int64 `json:"last_insert_id"`
|
||||
RowsAffected int64 `json:"rows_affected"`
|
||||
}
|
||||
|
||||
// sqlResultFrom converts sql.Result into a JSON-marshable sqlResult.
|
||||
func sqlResultFrom(res sql.Result) sqlResult {
|
||||
insertID, errID := res.LastInsertId()
|
||||
rowsAffected, errRows := res.RowsAffected()
|
||||
|
||||
// If an error occurs when extracting the result, it is because the
|
||||
// driver does not support that specific field. Instead of passing this
|
||||
// to the user, omit the field in the response.
|
||||
if errID != nil {
|
||||
insertID = -1
|
||||
}
|
||||
if errRows != nil {
|
||||
rowsAffected = -1
|
||||
}
|
||||
|
||||
return sqlResult{insertID, rowsAffected}
|
||||
}
|
336
dbconnect/sql_test.go
Normal file
336
dbconnect/sql_test.go
Normal file
@@ -0,0 +1,336 @@
|
||||
package dbconnect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/kshvakov/clickhouse"
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewSQLClient(t *testing.T) {
|
||||
originURLs := []string{
|
||||
"postgres://localhost",
|
||||
"cockroachdb://localhost:1337",
|
||||
"postgresql://user:pass@127.0.0.1",
|
||||
"mysql://localhost",
|
||||
"clickhouse://127.0.0.1:9000/?debug",
|
||||
"sqlite3::memory:",
|
||||
"file:test.db?cache=shared",
|
||||
}
|
||||
|
||||
for _, originURL := range originURLs {
|
||||
origin, _ := url.Parse(originURL)
|
||||
_, err := NewSQLClient(context.Background(), origin)
|
||||
|
||||
assert.NoError(t, err, originURL)
|
||||
}
|
||||
|
||||
originURLs = []string{
|
||||
"",
|
||||
"/",
|
||||
"http://localhost",
|
||||
"coolthing://user:pass@127.0.0.1",
|
||||
}
|
||||
|
||||
for _, originURL := range originURLs {
|
||||
origin, _ := url.Parse(originURL)
|
||||
_, err := NewSQLClient(context.Background(), origin)
|
||||
|
||||
assert.Error(t, err, originURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgumentsSQL(t *testing.T) {
|
||||
args := []Arguments{
|
||||
Arguments{
|
||||
Positional: []interface{}{
|
||||
"val", 10, 3.14,
|
||||
},
|
||||
},
|
||||
Arguments{
|
||||
Named: map[string]interface{}{
|
||||
"key": time.Unix(0, 0),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var nameType sql.NamedArg
|
||||
for _, arg := range args {
|
||||
arg.sql("")
|
||||
for _, named := range arg.Positional {
|
||||
assert.IsType(t, nameType, named)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLArg(t *testing.T) {
|
||||
tests := []struct {
|
||||
key interface{}
|
||||
val interface{}
|
||||
dialect string
|
||||
arg sql.NamedArg
|
||||
}{
|
||||
{"key", "val", "mssql", sql.Named("key", "val")},
|
||||
{0, 1, "sqlite3", sql.Named("0", 1)},
|
||||
{1, []string{"a", "b", "c"}, "postgres", sql.Named("1", pq.Array([]string{"a", "b", "c"}))},
|
||||
{"in", []uint{0, 1}, "clickhouse", sql.Named("in", clickhouse.Array([]uint{0, 1}))},
|
||||
{"", time.Unix(0, 0), "", sql.Named("", time.Unix(0, 0))},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
arg := sqlArg(test.key, test.val, test.dialect)
|
||||
assert.Equal(t, test.arg, arg, test.key)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLIsolation(t *testing.T) {
|
||||
tests := []struct {
|
||||
str string
|
||||
iso sql.IsolationLevel
|
||||
}{
|
||||
{"", sql.LevelDefault},
|
||||
{"DEFAULT", sql.LevelDefault},
|
||||
{"read_UNcommitted", sql.LevelReadUncommitted},
|
||||
{"serializable", sql.LevelSerializable},
|
||||
{"none", sql.IsolationLevel(-1)},
|
||||
{"SNAP shot", -2},
|
||||
{"blah", -2},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
iso, err := sqlIsolation(test.str)
|
||||
|
||||
if test.iso < -1 {
|
||||
assert.Error(t, err, test.str)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, test.iso, iso, test.str)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLMode(t *testing.T) {
|
||||
modes := []string{
|
||||
"query",
|
||||
"exec",
|
||||
}
|
||||
|
||||
for _, mode := range modes {
|
||||
actual, err := sqlMode(mode)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, strings.ToLower(mode), actual, mode)
|
||||
}
|
||||
|
||||
modes = []string{
|
||||
"",
|
||||
"blah",
|
||||
}
|
||||
|
||||
for _, mode := range modes {
|
||||
_, err := sqlMode(mode)
|
||||
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func helperRows(mockRows *sqlmock.Rows) *sql.Rows {
|
||||
db, mock, _ := sqlmock.New()
|
||||
mock.ExpectQuery("SELECT").WillReturnRows(mockRows)
|
||||
rows, _ := db.Query("SELECT")
|
||||
return rows
|
||||
}
|
||||
|
||||
func TestSQLRows(t *testing.T) {
|
||||
actual, err := sqlRows(helperRows(sqlmock.NewRows(
|
||||
[]string{"name", "age", "dept"}).
|
||||
AddRow("alice", 19, "prod")))
|
||||
expected := []map[string]interface{}{map[string]interface{}{
|
||||
"name": "alice",
|
||||
"age": int64(19),
|
||||
"dept": "prod"}}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, expected, actual)
|
||||
}
|
||||
|
||||
func TestSQLValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
input interface{}
|
||||
output interface{}
|
||||
}{
|
||||
{"hello", "hello"},
|
||||
{1, 1},
|
||||
{false, false},
|
||||
{[]byte("random"), "random"},
|
||||
{[]byte("{\"json\":true}"), map[string]interface{}{"json": true}},
|
||||
{[]byte("[]"), []interface{}{}},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
assert.Equal(t, test.output, sqlValue(test.input, nil), test.input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLResultFrom(t *testing.T) {
|
||||
res := sqlResultFrom(sqlmock.NewResult(1, 2))
|
||||
assert.Equal(t, sqlResult{1, 2}, res)
|
||||
|
||||
res = sqlResultFrom(sqlmock.NewErrorResult(fmt.Errorf("")))
|
||||
assert.Equal(t, sqlResult{-1, -1}, res)
|
||||
}
|
||||
|
||||
func helperSQLite3(t *testing.T) (context.Context, Client) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
url, _ := url.Parse("file::memory:?cache=shared")
|
||||
|
||||
sqlite3, err := NewSQLClient(ctx, url)
|
||||
assert.NoError(t, err)
|
||||
|
||||
return ctx, sqlite3
|
||||
}
|
||||
|
||||
func TestPing(t *testing.T) {
|
||||
ctx, sqlite3 := helperSQLite3(t)
|
||||
err := sqlite3.Ping(ctx)
|
||||
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSubmit(t *testing.T) {
|
||||
ctx, sqlite3 := helperSQLite3(t)
|
||||
|
||||
res, err := sqlite3.Submit(ctx, &Command{
|
||||
Statement: "CREATE TABLE t (a INTEGER, b FLOAT, c TEXT, d BLOB);",
|
||||
Mode: "exec",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, sqlResult{0, 0}, res)
|
||||
|
||||
res, err = sqlite3.Submit(ctx, &Command{
|
||||
Statement: "SELECT * FROM t;",
|
||||
Mode: "query",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, res)
|
||||
|
||||
res, err = sqlite3.Submit(ctx, &Command{
|
||||
Statement: "INSERT INTO t VALUES (?, ?, ?, ?);",
|
||||
Mode: "exec",
|
||||
Arguments: Arguments{
|
||||
Positional: []interface{}{
|
||||
1,
|
||||
3.14,
|
||||
"text",
|
||||
"blob",
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, sqlResult{1, 1}, res)
|
||||
|
||||
res, err = sqlite3.Submit(ctx, &Command{
|
||||
Statement: "UPDATE t SET c = NULL;",
|
||||
Mode: "exec",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, sqlResult{1, 1}, res)
|
||||
|
||||
res, err = sqlite3.Submit(ctx, &Command{
|
||||
Statement: "SELECT * FROM t WHERE a = ?;",
|
||||
Mode: "query",
|
||||
Arguments: Arguments{
|
||||
Positional: []interface{}{1},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, res, 1)
|
||||
|
||||
resf, ok := res.([]map[string]interface{})
|
||||
assert.True(t, ok)
|
||||
assert.EqualValues(t, map[string]interface{}{
|
||||
"a": int64(1),
|
||||
"b": 3.14,
|
||||
"c": nil,
|
||||
"d": "blob",
|
||||
}, resf[0])
|
||||
|
||||
res, err = sqlite3.Submit(ctx, &Command{
|
||||
Statement: "DROP TABLE t;",
|
||||
Mode: "exec",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, sqlResult{1, 1}, res)
|
||||
}
|
||||
|
||||
func TestSubmitTransaction(t *testing.T) {
|
||||
ctx, sqlite3 := helperSQLite3(t)
|
||||
|
||||
res, err := sqlite3.Submit(ctx, &Command{
|
||||
Statement: "BEGIN;",
|
||||
Mode: "exec",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, res)
|
||||
|
||||
res, err = sqlite3.Submit(ctx, &Command{
|
||||
Statement: "BEGIN; CREATE TABLE tt (a INT); COMMIT;",
|
||||
Mode: "exec",
|
||||
Isolation: "none",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, sqlResult{0, 0}, res)
|
||||
|
||||
rows, err := sqlite3.Submit(ctx, &Command{
|
||||
Statement: "SELECT * FROM tt;",
|
||||
Mode: "query",
|
||||
Isolation: "repeatable_read",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, rows)
|
||||
}
|
||||
|
||||
func TestSubmitTimeout(t *testing.T) {
|
||||
ctx, sqlite3 := helperSQLite3(t)
|
||||
|
||||
res, err := sqlite3.Submit(ctx, &Command{
|
||||
Statement: "SELECT * FROM t;",
|
||||
Mode: "query",
|
||||
Timeout: 1 * time.Nanosecond,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, res)
|
||||
}
|
||||
|
||||
func TestSubmitMode(t *testing.T) {
|
||||
ctx, sqlite3 := helperSQLite3(t)
|
||||
|
||||
res, err := sqlite3.Submit(ctx, &Command{
|
||||
Statement: "SELECT * FROM t;",
|
||||
Mode: "notanoption",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, res)
|
||||
}
|
||||
|
||||
func TestSubmitEmpty(t *testing.T) {
|
||||
ctx, sqlite3 := helperSQLite3(t)
|
||||
|
||||
res, err := sqlite3.Submit(ctx, &Command{
|
||||
Statement: "; ; ; ;",
|
||||
Mode: "query",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, res)
|
||||
}
|
Reference in New Issue
Block a user