TUN-2243: Revert "STOR-519: Add db-connect, a SQL over HTTPS server"

This reverts commit 5da2109811.
This commit is contained in:
Adam Chalmers
2019-08-26 16:45:49 -05:00
parent c3c88cc31e
commit 4e1df1a211
410 changed files with 666 additions and 362649 deletions

View File

@@ -1,145 +0,0 @@
package dbconnect
import (
"context"
"encoding/json"
"fmt"
"net/url"
"strings"
"time"
"unicode"
"unicode/utf8"
)
// Client is an interface to talk to any database.
//
// Currently, the only implementation is SQLClient, but its structure
// should be designed to handle a MongoClient or RedisClient in the future.
type Client interface {
Ping(context.Context) error
Submit(context.Context, *Command) (interface{}, error)
}
// NewClient creates a database client based on its URL scheme.
func NewClient(ctx context.Context, originURL *url.URL) (Client, error) {
return NewSQLClient(ctx, originURL)
}
// Command is a standard, non-vendor format for submitting database commands.
//
// When determining the scope of this struct, refer to the following litmus test:
// Could this (roughly) conform to SQL, Document-based, and Key-value command formats?
type Command struct {
Statement string `json:"statement"`
Arguments Arguments `json:"arguments,omitempty"`
Mode string `json:"mode,omitempty"`
Isolation string `json:"isolation,omitempty"`
Timeout time.Duration `json:"timeout,omitempty"`
}
// Validate enforces the contract of Command: non empty statement (both in length and logic),
// lowercase mode and isolation, non-zero timeout, and valid Arguments.
func (cmd *Command) Validate() error {
if cmd.Statement == "" {
return fmt.Errorf("cannot provide an empty statement")
}
if strings.Map(func(char rune) rune {
if char == ';' || unicode.IsSpace(char) {
return -1
}
return char
}, cmd.Statement) == "" {
return fmt.Errorf("cannot provide a statement with no logic: '%s'", cmd.Statement)
}
cmd.Mode = strings.ToLower(cmd.Mode)
cmd.Isolation = strings.ToLower(cmd.Isolation)
if cmd.Timeout.Nanoseconds() <= 0 {
cmd.Timeout = 24 * time.Hour
}
return cmd.Arguments.Validate()
}
// UnmarshalJSON converts a byte representation of JSON into a Command, which is also validated.
func (cmd *Command) UnmarshalJSON(data []byte) error {
// Alias is required to avoid infinite recursion from the default UnmarshalJSON.
type Alias Command
alias := &struct {
*Alias
}{
Alias: (*Alias)(cmd),
}
err := json.Unmarshal(data, &alias)
if err == nil {
err = cmd.Validate()
}
return err
}
// Arguments is a wrapper for either map-based or array-based Command arguments.
//
// Each field is mutually-exclusive and some Client implementations may not
// support both fields (eg. MySQL does not accept named arguments).
type Arguments struct {
Named map[string]interface{}
Positional []interface{}
}
// Validate enforces the contract of Arguments: non nil, mutually exclusive, and no empty or reserved keys.
func (args *Arguments) Validate() error {
if args.Named == nil {
args.Named = map[string]interface{}{}
}
if args.Positional == nil {
args.Positional = []interface{}{}
}
if len(args.Named) > 0 && len(args.Positional) > 0 {
return fmt.Errorf("both named and positional arguments cannot be specified: %+v and %+v", args.Named, args.Positional)
}
for key := range args.Named {
if key == "" {
return fmt.Errorf("named arguments cannot contain an empty key: %+v", args.Named)
}
if !utf8.ValidString(key) {
return fmt.Errorf("named argument does not conform to UTF-8 encoding: %s", key)
}
if strings.HasPrefix(key, "_") {
return fmt.Errorf("named argument cannot start with a reserved keyword '_': %s", key)
}
if unicode.IsNumber([]rune(key)[0]) {
return fmt.Errorf("named argument cannot start with a number: %s", key)
}
}
return nil
}
// UnmarshalJSON converts a byte representation of JSON into Arguments, which is also validated.
func (args *Arguments) UnmarshalJSON(data []byte) error {
var obj interface{}
err := json.Unmarshal(data, &obj)
if err != nil {
return err
}
named, ok := obj.(map[string]interface{})
if ok {
args.Named = named
} else {
positional, ok := obj.([]interface{})
if ok {
args.Positional = positional
} else {
return fmt.Errorf("arguments must either be an object {\"0\":\"val\"} or an array [\"val\"]: %s", string(data))
}
}
return args.Validate()
}

View File

@@ -1,183 +0,0 @@
package dbconnect
import (
"encoding/json"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestCommandValidateEmpty(t *testing.T) {
stmts := []string{
"",
";",
" \n\t",
";\n;\t;",
}
for _, stmt := range stmts {
cmd := Command{Statement: stmt}
assert.Error(t, cmd.Validate(), stmt)
}
}
func TestCommandValidateMode(t *testing.T) {
modes := []string{
"",
"query",
"ExEc",
"PREPARE",
}
for _, mode := range modes {
cmd := Command{Statement: "Ok", Mode: mode}
assert.NoError(t, cmd.Validate(), mode)
assert.Equal(t, strings.ToLower(mode), cmd.Mode)
}
}
func TestCommandValidateIsolation(t *testing.T) {
isos := []string{
"",
"default",
"read_committed",
"SNAPshot",
}
for _, iso := range isos {
cmd := Command{Statement: "Ok", Isolation: iso}
assert.NoError(t, cmd.Validate(), iso)
assert.Equal(t, strings.ToLower(iso), cmd.Isolation)
}
}
func TestCommandValidateTimeout(t *testing.T) {
cmd := Command{Statement: "Ok", Timeout: 0}
assert.NoError(t, cmd.Validate())
assert.NotZero(t, cmd.Timeout)
cmd = Command{Statement: "Ok", Timeout: 1 * time.Second}
assert.NoError(t, cmd.Validate())
assert.Equal(t, 1*time.Second, cmd.Timeout)
}
func TestCommandValidateArguments(t *testing.T) {
cmd := Command{Statement: "Ok", Arguments: Arguments{
Named: map[string]interface{}{"key": "val"},
Positional: []interface{}{"val"},
}}
assert.Error(t, cmd.Validate())
}
func TestCommandUnmarshalJSON(t *testing.T) {
strs := []string{
"{\"statement\":\"Ok\"}",
"{\"statement\":\"Ok\",\"arguments\":[0, 3.14, \"apple\"],\"mode\":\"query\"}",
"{\"statement\":\"Ok\",\"isolation\":\"read_uncommitted\",\"timeout\":1000}",
}
for _, str := range strs {
var cmd Command
assert.NoError(t, json.Unmarshal([]byte(str), &cmd), str)
}
strs = []string{
"",
"\"",
"{}",
"{\"argument\":{\"key\":\"val\"}}",
"{\"statement\":[\"Ok\"]}",
}
for _, str := range strs {
var cmd Command
assert.Error(t, json.Unmarshal([]byte(str), &cmd), str)
}
}
func TestArgumentsValidateNotNil(t *testing.T) {
args := Arguments{}
assert.NoError(t, args.Validate())
assert.NotNil(t, args.Named)
assert.NotNil(t, args.Positional)
}
func TestArgumentsValidateMutuallyExclusive(t *testing.T) {
args := []Arguments{
Arguments{},
Arguments{Named: map[string]interface{}{"key": "val"}},
Arguments{Positional: []interface{}{"val"}},
}
for _, arg := range args {
assert.NoError(t, arg.Validate())
assert.False(t, len(arg.Named) > 0 && len(arg.Positional) > 0)
}
args = []Arguments{
Arguments{
Named: map[string]interface{}{"key": "val"},
Positional: []interface{}{"val"},
},
}
for _, arg := range args {
assert.Error(t, arg.Validate())
assert.True(t, len(arg.Named) > 0 && len(arg.Positional) > 0)
}
}
func TestArgumentsValidateKeys(t *testing.T) {
keys := []string{
"",
"_",
"_key",
"1",
"1key",
"\xf0\x28\x8c\xbc", // non-utf8
}
for _, key := range keys {
args := Arguments{Named: map[string]interface{}{key: "val"}}
assert.Error(t, args.Validate(), key)
}
}
func TestArgumentsUnmarshalJSON(t *testing.T) {
strs := []string{
"{}",
"{\"key\":\"val\"}",
"{\"key\":[1, 3.14, {\"key\":\"val\"}]}",
"[]",
"[\"key\",\"val\"]",
"[{}]",
}
for _, str := range strs {
var args Arguments
assert.NoError(t, json.Unmarshal([]byte(str), &args), str)
}
strs = []string{
"",
"\"",
"1",
"\"key\"",
"{\"key\",\"val\"}",
}
for _, str := range strs {
var args Arguments
assert.Error(t, json.Unmarshal([]byte(str), &args), str)
}
}

View File

@@ -1,157 +0,0 @@
package dbconnect
import (
"context"
"log"
"net"
"strconv"
"gopkg.in/urfave/cli.v2"
"gopkg.in/urfave/cli.v2/altsrc"
)
// Cmd is the entrypoint command for dbconnect.
//
// The tunnel package is responsible for appending this to tunnel.Commands().
func Cmd() *cli.Command {
return &cli.Command{
Category: "Database Connect (ALPHA)",
Name: "db-connect",
Usage: "Access your SQL database from Cloudflare Workers or the browser",
ArgsUsage: " ",
Description: `
Creates a connection between your database and the Cloudflare edge.
Now you can execute SQL commands anywhere you can send HTTPS requests.
Connect your database with any of the following commands, you can also try the "playground" without a database:
cloudflared db-connect --hostname sql.mysite.com --url postgres://user:pass@localhost?sslmode=disable \
--auth-domain mysite.cloudflareaccess.com --application-aud my-access-policy-tag
cloudflared db-connect --hostname sql-dev.mysite.com --url mysql://localhost --insecure
cloudflared db-connect --playground
Requests should be authenticated using Cloudflare Access, learn more about how to enable it here:
https://developers.cloudflare.com/access/service-auth/service-token/
`,
Flags: []cli.Flag{
altsrc.NewStringFlag(&cli.StringFlag{
Name: "url",
Usage: "URL to the database (eg. postgres://user:pass@localhost?sslmode=disable)",
EnvVars: []string{"TUNNEL_URL"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "hostname",
Usage: "Hostname to accept commands over HTTPS (eg. sql.mysite.com)",
EnvVars: []string{"TUNNEL_HOSTNAME"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "auth-domain",
Usage: "Cloudflare Access authentication domain for your account (eg. mysite.cloudflareaccess.com)",
EnvVars: []string{"TUNNEL_ACCESS_AUTH_DOMAIN"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "application-aud",
Usage: "Cloudflare Access application \"AUD\" to verify JWTs from requests",
EnvVars: []string{"TUNNEL_ACCESS_APPLICATION_AUD"},
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "insecure",
Usage: "Disable authentication, the database will be open to the Internet",
Value: false,
EnvVars: []string{"TUNNEL_ACCESS_INSECURE"},
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "playground",
Usage: "Run a temporary, in-memory SQLite3 database for testing",
Value: false,
EnvVars: []string{"TUNNEL_HELLO_WORLD"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "loglevel",
Value: "debug", // Make it more verbose than the tunnel default 'info'.
EnvVars: []string{"TUNNEL_LOGLEVEL"},
Hidden: true,
}),
},
Before: CmdBefore,
Action: CmdAction,
Hidden: true,
}
}
// CmdBefore runs some validation checks before running the command.
func CmdBefore(c *cli.Context) error {
// Show the help text is no flags are specified.
if c.NumFlags() == 0 {
return cli.ShowSubcommandHelp(c)
}
// Hello-world and playground are synonymous with each other,
// unset hello-world to prevent tunnel from initializing the hello package.
if c.IsSet("hello-world") {
c.Set("playground", "true")
c.Set("hello-world", "false")
}
// Unix-socket database urls are supported, but the logic is the same as url.
if c.IsSet("unix-socket") {
c.Set("url", c.String("unix-socket"))
c.Set("unix-socket", "")
}
// When playground mode is enabled, run with an in-memory database.
if c.IsSet("playground") {
c.Set("url", "sqlite3::memory:?cache=shared")
c.Set("insecure", strconv.FormatBool(!c.IsSet("auth-domain") && !c.IsSet("application-aud")))
}
// At this point, insecure configurations are valid.
if c.Bool("insecure") {
return nil
}
// Ensure that secure configurations specify a hostname, domain, and tag for JWT validation.
if !c.IsSet("hostname") || !c.IsSet("auth-domain") || !c.IsSet("application-aud") {
log.Fatal("must specify --hostname, --auth-domain, and --application-aud unless you want to run in --insecure mode")
}
return nil
}
// CmdAction starts the Proxy and sets the url in cli.Context to point to the Proxy address.
func CmdAction(c *cli.Context) error {
// STOR-612: sync with context in tunnel daemon.
ctx := context.Background()
var proxy *Proxy
var err error
if c.Bool("insecure") {
proxy, err = NewInsecureProxy(ctx, c.String("url"))
} else {
proxy, err = NewSecureProxy(ctx, c.String("url"), c.String("auth-domain"), c.String("application-aud"))
}
if err != nil {
log.Fatal(err)
return err
}
listenerC := make(chan net.Listener)
defer close(listenerC)
// Since the Proxy should only talk to the tunnel daemon, find the next available
// localhost port and start to listen to requests.
go func() {
err := proxy.Start(ctx, "127.0.0.1:", listenerC)
if err != nil {
log.Fatal(err)
}
}()
// Block until the the Proxy is online, retreive its address, and change the url to point to it.
// This is effectively "handing over" control to the tunnel package so it can run the tunnel daemon.
c.Set("url", "https://"+(<-listenerC).Addr().String())
return nil
}

View File

@@ -1,27 +0,0 @@
package dbconnect
import (
"testing"
"github.com/stretchr/testify/assert"
"gopkg.in/urfave/cli.v2"
)
func TestCmd(t *testing.T) {
tests := [][]string{
{"cloudflared", "db-connect", "--playground"},
{"cloudflared", "db-connect", "--playground", "--hostname", "sql.mysite.com"},
{"cloudflared", "db-connect", "--url", "sqlite3::memory:?cache=shared", "--insecure"},
{"cloudflared", "db-connect", "--url", "sqlite3::memory:?cache=shared", "--hostname", "sql.mysite.com", "--auth-domain", "mysite.cloudflareaccess.com", "--application-aud", "aud"},
}
app := &cli.App{
Name: "cloudflared",
Commands: []*cli.Command{Cmd()},
}
for _, test := range tests {
assert.NoError(t, app.Run(test))
}
}

View File

@@ -1,78 +0,0 @@
# docker-compose -f ./dbconnect/integration_test/dbconnect.yaml up --build --force-recreate --renew-anon-volumes --exit-code-from cloudflared
version: "2.3"
networks:
test-dbconnect-network:
driver: bridge
services:
cloudflared:
build:
context: ../../
dockerfile: dev.Dockerfile
command: go test github.com/cloudflare/cloudflared/dbconnect/integration_test -v
depends_on:
postgres:
condition: service_healthy
mysql:
condition: service_healthy
mssql:
condition: service_healthy
clickhouse:
condition: service_healthy
environment:
DBCONNECT_INTEGRATION_TEST: "true"
POSTGRESQL_URL: postgres://postgres:secret@postgres/db?sslmode=disable
MYSQL_URL: mysql://root:secret@mysql/db?tls=false
MSSQL_URL: mssql://sa:secret12345!@mssql
CLICKHOUSE_URL: clickhouse://clickhouse:9000/db
networks:
- test-dbconnect-network
postgres:
image: postgres:11.4-alpine
environment:
POSTGRES_DB: db
POSTGRES_PASSWORD: secret
healthcheck:
test: ["CMD", "pg_isready", "-U", "postgres"]
start_period: 3s
interval: 1s
timeout: 3s
retries: 10
networks:
- test-dbconnect-network
mysql:
image: mysql:8.0
environment:
MYSQL_DATABASE: db
MYSQL_ROOT_PASSWORD: secret
healthcheck:
test: ["CMD", "mysqladmin", "ping"]
start_period: 3s
interval: 1s
timeout: 3s
retries: 10
networks:
- test-dbconnect-network
mssql:
image: mcr.microsoft.com/mssql/server:2017-CU8-ubuntu
environment:
ACCEPT_EULA: "Y"
SA_PASSWORD: secret12345!
healthcheck:
test: ["CMD", "/opt/mssql-tools/bin/sqlcmd", "-S", "localhost", "-U", "sa", "-P", "secret12345!", "-Q", "SELECT 1"]
start_period: 3s
interval: 1s
timeout: 3s
retries: 10
networks:
- test-dbconnect-network
clickhouse:
image: yandex/clickhouse-server:19.11
healthcheck:
test: ["CMD", "clickhouse-client", "--query", "SELECT 1"]
start_period: 3s
interval: 1s
timeout: 3s
retries: 10
networks:
- test-dbconnect-network

View File

@@ -1,265 +0,0 @@
package dbconnect_test
import (
"context"
"log"
"net/url"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/dbconnect"
)
func TestIntegrationPostgreSQL(t *testing.T) {
ctx, pq := helperNewSQLClient(t, "POSTGRESQL_URL")
err := pq.Ping(ctx)
assert.NoError(t, err)
_, err = pq.Submit(ctx, &dbconnect.Command{
Statement: "CREATE TABLE t (a TEXT, b UUID, c JSON, d INET[], e SERIAL);",
Mode: "exec",
})
assert.NoError(t, err)
_, err = pq.Submit(ctx, &dbconnect.Command{
Statement: "INSERT INTO t VALUES ($1, $2, $3, $4);",
Mode: "exec",
Arguments: dbconnect.Arguments{
Positional: []interface{}{
"text",
"6b8d686d-bd8e-43bc-b09a-cfcbbe702c10",
"{\"bool\":true,\"array\":[\"a\", 1, 3.14],\"embed\":{\"num\":21}}",
[]string{"1.1.1.1", "1.0.0.1"},
},
},
})
assert.NoError(t, err)
_, err = pq.Submit(ctx, &dbconnect.Command{
Statement: "UPDATE t SET b = NULL;",
Mode: "exec",
})
assert.NoError(t, err)
res, err := pq.Submit(ctx, &dbconnect.Command{
Statement: "SELECT * FROM t;",
Mode: "query",
})
assert.NoError(t, err)
assert.IsType(t, make([]map[string]interface{}, 0), res)
actual := res.([]map[string]interface{})[0]
expected := map[string]interface{}{
"a": "text",
"b": nil,
"c": map[string]interface{}{
"bool": true,
"array": []interface{}{"a", float64(1), 3.14},
"embed": map[string]interface{}{"num": float64(21)},
},
"d": "{1.1.1.1,1.0.0.1}",
"e": int64(1),
}
assert.EqualValues(t, expected, actual)
_, err = pq.Submit(ctx, &dbconnect.Command{
Statement: "DROP TABLE t;",
Mode: "exec",
})
assert.NoError(t, err)
}
func TestIntegrationMySQL(t *testing.T) {
ctx, my := helperNewSQLClient(t, "MYSQL_URL")
err := my.Ping(ctx)
assert.NoError(t, err)
_, err = my.Submit(ctx, &dbconnect.Command{
Statement: "CREATE TABLE t (a CHAR, b TINYINT, c FLOAT, d JSON, e YEAR);",
Mode: "exec",
})
assert.NoError(t, err)
_, err = my.Submit(ctx, &dbconnect.Command{
Statement: "INSERT INTO t VALUES (?, ?, ?, ?, ?);",
Mode: "exec",
Arguments: dbconnect.Arguments{
Positional: []interface{}{
"a",
10,
3.14,
"{\"bool\":true}",
2000,
},
},
})
assert.NoError(t, err)
_, err = my.Submit(ctx, &dbconnect.Command{
Statement: "ALTER TABLE t ADD COLUMN f GEOMETRY;",
Mode: "exec",
})
assert.NoError(t, err)
res, err := my.Submit(ctx, &dbconnect.Command{
Statement: "SELECT * FROM t;",
Mode: "query",
})
assert.NoError(t, err)
assert.IsType(t, make([]map[string]interface{}, 0), res)
actual := res.([]map[string]interface{})[0]
expected := map[string]interface{}{
"a": "a",
"b": float64(10),
"c": 3.14,
"d": map[string]interface{}{"bool": true},
"e": float64(2000),
"f": nil,
}
assert.EqualValues(t, expected, actual)
_, err = my.Submit(ctx, &dbconnect.Command{
Statement: "DROP TABLE t;",
Mode: "exec",
})
assert.NoError(t, err)
}
func TestIntegrationMSSQL(t *testing.T) {
ctx, ms := helperNewSQLClient(t, "MSSQL_URL")
err := ms.Ping(ctx)
assert.NoError(t, err)
_, err = ms.Submit(ctx, &dbconnect.Command{
Statement: "CREATE TABLE t (a BIT, b DECIMAL, c MONEY, d TEXT);",
Mode: "exec"})
assert.NoError(t, err)
_, err = ms.Submit(ctx, &dbconnect.Command{
Statement: "INSERT INTO t VALUES (?, ?, ?, ?);",
Mode: "exec",
Arguments: dbconnect.Arguments{
Positional: []interface{}{
0,
3,
"$0.99",
"text",
},
},
})
assert.NoError(t, err)
_, err = ms.Submit(ctx, &dbconnect.Command{
Statement: "UPDATE t SET d = NULL;",
Mode: "exec",
})
assert.NoError(t, err)
res, err := ms.Submit(ctx, &dbconnect.Command{
Statement: "SELECT * FROM t;",
Mode: "query",
})
assert.NoError(t, err)
assert.IsType(t, make([]map[string]interface{}, 0), res)
actual := res.([]map[string]interface{})[0]
expected := map[string]interface{}{
"a": false,
"b": float64(3),
"c": float64(0.99),
"d": nil,
}
assert.EqualValues(t, expected, actual)
_, err = ms.Submit(ctx, &dbconnect.Command{
Statement: "DROP TABLE t;",
Mode: "exec",
})
assert.NoError(t, err)
}
func TestIntegrationClickhouse(t *testing.T) {
ctx, ch := helperNewSQLClient(t, "CLICKHOUSE_URL")
err := ch.Ping(ctx)
assert.NoError(t, err)
_, err = ch.Submit(ctx, &dbconnect.Command{
Statement: "CREATE TABLE t (a UUID, b String, c Float64, d UInt32, e Int16, f Array(Enum8('a'=1, 'b'=2, 'c'=3))) engine=Memory;",
Mode: "exec",
})
assert.NoError(t, err)
_, err = ch.Submit(ctx, &dbconnect.Command{
Statement: "INSERT INTO t VALUES (?, ?, ?, ?, ?, ?);",
Mode: "exec",
Arguments: dbconnect.Arguments{
Positional: []interface{}{
"ec65f626-6f50-4c86-9628-6314ef1edacd",
"",
3.14,
314,
-144,
[]string{"a", "b", "c"},
},
},
})
assert.NoError(t, err)
res, err := ch.Submit(ctx, &dbconnect.Command{
Statement: "SELECT * FROM t;",
Mode: "query",
})
assert.NoError(t, err)
assert.IsType(t, make([]map[string]interface{}, 0), res)
actual := res.([]map[string]interface{})[0]
expected := map[string]interface{}{
"a": "ec65f626-6f50-4c86-9628-6314ef1edacd",
"b": "",
"c": float64(3.14),
"d": uint32(314),
"e": int16(-144),
"f": []string{"a", "b", "c"},
}
assert.EqualValues(t, expected, actual)
_, err = ch.Submit(ctx, &dbconnect.Command{
Statement: "DROP TABLE t;",
Mode: "exec",
})
assert.NoError(t, err)
}
func helperNewSQLClient(t *testing.T, env string) (context.Context, dbconnect.Client) {
_, ok := os.LookupEnv("DBCONNECT_INTEGRATION_TEST")
if ok {
t.Helper()
} else {
t.SkipNow()
}
val, ok := os.LookupEnv(env)
if !ok {
log.Fatalf("must provide database url as environment variable: %s", env)
}
parsed, err := url.Parse(val)
if err != nil {
log.Fatalf("cannot provide invalid database url: %s=%s", env, val)
}
ctx := context.Background()
client, err := dbconnect.NewSQLClient(ctx, parsed)
if err != nil {
log.Fatalf("could not start test client: %s", err)
}
return ctx, client
}

View File

@@ -1,274 +0,0 @@
package dbconnect
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"time"
"github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/validation"
"github.com/gorilla/mux"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
timing "github.com/mitchellh/go-server-timing"
)
// Proxy is an HTTP server that proxies requests to a Client.
type Proxy struct {
client Client
accessValidator *validation.Access
logger *logrus.Logger
}
// NewInsecureProxy creates a Proxy that talks to a Client at an origin.
//
// In insecure mode, the Proxy will allow all Command requests.
func NewInsecureProxy(ctx context.Context, origin string) (*Proxy, error) {
originURL, err := url.Parse(origin)
if err != nil {
return nil, errors.Wrap(err, "must provide a valid database url")
}
client, err := NewClient(ctx, originURL)
if err != nil {
return nil, err
}
err = client.Ping(ctx)
if err != nil {
return nil, errors.Wrap(err, "could not connect to the database")
}
return &Proxy{client, nil, logrus.New()}, nil
}
// NewSecureProxy creates a Proxy that talks to a Client at an origin.
//
// In secure mode, the Proxy will reject any Command requests that are
// not authenticated by Cloudflare Access with a valid JWT.
func NewSecureProxy(ctx context.Context, origin, authDomain, applicationAUD string) (*Proxy, error) {
proxy, err := NewInsecureProxy(ctx, origin)
if err != nil {
return nil, err
}
validator, err := validation.NewAccessValidator(ctx, authDomain, authDomain, applicationAUD)
if err != nil {
return nil, err
}
proxy.accessValidator = validator
return proxy, err
}
// IsInsecure gets whether the Proxy will accept a Command from any source.
func (proxy *Proxy) IsInsecure() bool {
return proxy.accessValidator == nil
}
// IsAllowed checks whether a http.Request is allowed to receive data.
//
// By default, requests must pass through Cloudflare Access for authentication.
// If the proxy is explcitly set to insecure mode, all requests will be allowed.
func (proxy *Proxy) IsAllowed(r *http.Request, verbose ...bool) bool {
if proxy.IsInsecure() {
return true
}
// Access and Tunnel should prevent bad JWTs from even reaching the origin,
// but validate tokens anyway as an abundance of caution.
err := proxy.accessValidator.ValidateRequest(r.Context(), r)
if err == nil {
return true
}
// Warn administrators that invalid JWTs are being rejected. This is indicative
// of either a misconfiguration of the CLI or a massive failure of upstream systems.
if len(verbose) > 0 {
proxy.httpLog(r, err).Error("Failed JWT authentication")
}
return false
}
// Start the Proxy at a given address and notify the listener channel when the server is online.
func (proxy *Proxy) Start(ctx context.Context, addr string, listenerC chan<- net.Listener) error {
// STOR-611: use a seperate listener and consider web socket support.
httpListener, err := hello.CreateTLSListener(addr)
if err != nil {
return errors.Wrapf(err, "could not create listener at %s", addr)
}
errC := make(chan error)
defer close(errC)
// Starts the HTTP server and begins to serve requests.
go func() {
errC <- proxy.httpListen(ctx, httpListener)
}()
// Continually ping the server until it comes online or 10 attempts fail.
go func() {
var err error
for i := 0; i < 10; i++ {
_, err = http.Get("http://" + httpListener.Addr().String())
// Once no error was detected, notify the listener channel and return.
if err == nil {
listenerC <- httpListener
return
}
// Backoff between requests to ping the server.
<-time.After(1 * time.Second)
}
errC <- errors.Wrap(err, "took too long for the http server to start")
}()
return <-errC
}
// httpListen starts the httpServer and blocks until the context closes.
func (proxy *Proxy) httpListen(ctx context.Context, listener net.Listener) error {
httpServer := &http.Server{
Addr: listener.Addr().String(),
Handler: timing.Middleware(proxy.httpRouter(), nil),
ReadTimeout: 10 * time.Second,
WriteTimeout: 60 * time.Second,
IdleTimeout: 60 * time.Second,
}
go func() {
<-ctx.Done()
httpServer.Close()
listener.Close()
}()
return httpServer.Serve(listener)
}
// httpRouter creates a mux.Router for the Proxy.
func (proxy *Proxy) httpRouter() *mux.Router {
router := mux.NewRouter()
router.HandleFunc("/ping", proxy.httpPing()).Methods("GET", "HEAD")
router.HandleFunc("/submit", proxy.httpSubmit()).Methods("POST")
return router
}
// httpPing tests the connection to the database.
//
// By default, this endpoint is unauthenticated to allow for health checks.
// To enable authentication, Cloudflare Access must be enabled on this route.
func (proxy *Proxy) httpPing() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
metric := timing.FromContext(ctx).NewMetric("db").Start()
err := proxy.client.Ping(ctx)
metric.Stop()
if err == nil {
proxy.httpRespond(w, r, http.StatusOK, "")
} else {
proxy.httpRespondErr(w, r, http.StatusInternalServerError, err)
}
}
}
// httpSubmit sends a command to the database and returns its response.
//
// By default, this endpoint will reject requests that do not pass through Cloudflare Access.
// To disable authentication, the --insecure flag must be specified in the command line.
func (proxy *Proxy) httpSubmit() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !proxy.IsAllowed(r, true) {
proxy.httpRespondErr(w, r, http.StatusForbidden, fmt.Errorf(""))
return
}
var cmd Command
err := json.NewDecoder(r.Body).Decode(&cmd)
if err != nil {
proxy.httpRespondErr(w, r, http.StatusBadRequest, err)
return
}
ctx := r.Context()
metric := timing.FromContext(ctx).NewMetric("db").Start()
data, err := proxy.client.Submit(ctx, &cmd)
metric.Stop()
if err != nil {
proxy.httpRespondErr(w, r, http.StatusUnprocessableEntity, err)
return
}
w.Header().Set("Content-type", "application/json")
err = json.NewEncoder(w).Encode(data)
if err != nil {
proxy.httpRespondErr(w, r, http.StatusInternalServerError, err)
}
}
}
// httpRespond writes a status code and string response to the response writer.
func (proxy *Proxy) httpRespond(w http.ResponseWriter, r *http.Request, status int, message string) {
w.WriteHeader(status)
// Only expose the message detail of the reponse if the request is not HEAD
// and the user is authenticated. For example, this prevents an unauthenticated
// failed health check from accidentally leaking sensitive information about the Client.
if r.Method != http.MethodHead && proxy.IsAllowed(r) {
if message == "" {
message = http.StatusText(status)
}
fmt.Fprint(w, message)
}
}
// httpRespondErr is similar to httpRespond, except it formats errors to be more friendly.
func (proxy *Proxy) httpRespondErr(w http.ResponseWriter, r *http.Request, defaultStatus int, err error) {
status, err := httpError(defaultStatus, err)
proxy.httpRespond(w, r, status, err.Error())
proxy.httpLog(r, err).Warn("Database connect error")
}
// httpLog returns a logrus.Entry that is formatted to output a request Cf-ray.
func (proxy *Proxy) httpLog(r *http.Request, err error) *logrus.Entry {
return proxy.logger.WithContext(r.Context()).WithField("CF-RAY", r.Header.Get("Cf-ray")).WithError(err)
}
// httpError extracts common errors and returns an status code and friendly error.
func httpError(defaultStatus int, err error) (int, error) {
if err == nil {
return http.StatusNotImplemented, fmt.Errorf("error expected but found none")
}
if err == io.EOF {
return http.StatusBadRequest, fmt.Errorf("request body cannot be empty")
}
if err == context.DeadlineExceeded {
return http.StatusRequestTimeout, err
}
_, ok := err.(net.Error)
if ok {
return http.StatusRequestTimeout, err
}
if err == context.Canceled {
// Does not exist in Golang, but would be: http.StatusClientClosedWithoutResponse
return 444, err
}
return defaultStatus, err
}

View File

@@ -1,238 +0,0 @@
package dbconnect
import (
"context"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
)
func TestNewInsecureProxy(t *testing.T) {
origins := []string{
"",
":/",
"http://localhost",
"tcp://localhost:9000?debug=true",
"mongodb://127.0.0.1",
}
for _, origin := range origins {
proxy, err := NewInsecureProxy(context.Background(), origin)
assert.Error(t, err)
assert.Empty(t, proxy)
}
}
func TestProxyIsAllowed(t *testing.T) {
proxy := helperNewProxy(t)
req := httptest.NewRequest("GET", "https://1.1.1.1/ping", nil)
assert.True(t, proxy.IsAllowed(req))
proxy = helperNewProxy(t, true)
req.Header.Set("Cf-access-jwt-assertion", "xxx")
assert.False(t, proxy.IsAllowed(req))
}
func TestProxyStart(t *testing.T) {
proxy := helperNewProxy(t)
ctx := context.Background()
listenerC := make(chan net.Listener)
err := proxy.Start(ctx, "1.1.1.1:", listenerC)
assert.Error(t, err)
err = proxy.Start(ctx, "127.0.0.1:-1", listenerC)
assert.Error(t, err)
ctx, cancel := context.WithTimeout(ctx, 0)
defer cancel()
err = proxy.Start(ctx, "127.0.0.1:", listenerC)
assert.IsType(t, http.ErrServerClosed, err)
}
func TestProxyHTTPRouter(t *testing.T) {
proxy := helperNewProxy(t)
router := proxy.httpRouter()
tests := []struct {
path string
method string
valid bool
}{
{"", "GET", false},
{"/", "GET", false},
{"/ping", "GET", true},
{"/ping", "HEAD", true},
{"/ping", "POST", false},
{"/submit", "POST", true},
{"/submit", "GET", false},
{"/submit/extra", "POST", false},
}
for _, test := range tests {
match := &mux.RouteMatch{}
ok := router.Match(httptest.NewRequest(test.method, "https://1.1.1.1"+test.path, nil), match)
assert.True(t, ok == test.valid, test.path)
}
}
func TestProxyHTTPPing(t *testing.T) {
proxy := helperNewProxy(t)
server := httptest.NewServer(proxy.httpPing())
defer server.Close()
client := server.Client()
res, err := client.Get(server.URL)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, int64(2), res.ContentLength)
res, err = client.Head(server.URL)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, int64(-1), res.ContentLength)
}
func TestProxyHTTPSubmit(t *testing.T) {
proxy := helperNewProxy(t)
server := httptest.NewServer(proxy.httpSubmit())
defer server.Close()
client := server.Client()
tests := []struct {
input string
status int
output string
}{
{"", http.StatusBadRequest, "request body cannot be empty"},
{"{}", http.StatusBadRequest, "cannot provide an empty statement"},
{"{\"statement\":\"Ok\"}", http.StatusUnprocessableEntity, "cannot provide invalid sql mode: ''"},
{"{\"statement\":\"Ok\",\"mode\":\"query\"}", http.StatusUnprocessableEntity, "near \"Ok\": syntax error"},
{"{\"statement\":\"CREATE TABLE t (a INT);\",\"mode\":\"exec\"}", http.StatusOK, "{\"last_insert_id\":0,\"rows_affected\":0}\n"},
}
for _, test := range tests {
res, err := client.Post(server.URL, "application/json", strings.NewReader(test.input))
assert.NoError(t, err)
assert.Equal(t, test.status, res.StatusCode)
if res.StatusCode > http.StatusOK {
assert.Equal(t, "text/plain; charset=utf-8", res.Header.Get("Content-type"))
} else {
assert.Equal(t, "application/json", res.Header.Get("Content-type"))
}
data, err := ioutil.ReadAll(res.Body)
defer res.Body.Close()
str := string(data)
assert.NoError(t, err)
assert.Equal(t, test.output, str)
}
}
func TestProxyHTTPSubmitForbidden(t *testing.T) {
proxy := helperNewProxy(t, true)
server := httptest.NewServer(proxy.httpSubmit())
defer server.Close()
client := server.Client()
res, err := client.Get(server.URL)
assert.NoError(t, err)
assert.Equal(t, http.StatusForbidden, res.StatusCode)
assert.Zero(t, res.ContentLength)
}
func TestProxyHTTPRespond(t *testing.T) {
proxy := helperNewProxy(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proxy.httpRespond(w, r, http.StatusAccepted, "Hello")
}))
defer server.Close()
client := server.Client()
res, err := client.Get(server.URL)
assert.NoError(t, err)
assert.Equal(t, http.StatusAccepted, res.StatusCode)
assert.Equal(t, int64(5), res.ContentLength)
data, err := ioutil.ReadAll(res.Body)
defer res.Body.Close()
assert.Equal(t, []byte("Hello"), data)
}
func TestProxyHTTPRespondForbidden(t *testing.T) {
proxy := helperNewProxy(t, true)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proxy.httpRespond(w, r, http.StatusAccepted, "Hello")
}))
defer server.Close()
client := server.Client()
res, err := client.Get(server.URL)
assert.NoError(t, err)
assert.Equal(t, http.StatusAccepted, res.StatusCode)
assert.Equal(t, int64(0), res.ContentLength)
}
func TestHTTPError(t *testing.T) {
_, errTimeout := net.DialTimeout("tcp", "127.0.0.1", 0)
assert.Error(t, errTimeout)
tests := []struct {
input error
status int
output error
}{
{nil, http.StatusNotImplemented, fmt.Errorf("error expected but found none")},
{io.EOF, http.StatusBadRequest, fmt.Errorf("request body cannot be empty")},
{context.DeadlineExceeded, http.StatusRequestTimeout, nil},
{context.Canceled, 444, nil},
{errTimeout, http.StatusRequestTimeout, nil},
{fmt.Errorf(""), http.StatusInternalServerError, nil},
}
for _, test := range tests {
status, err := httpError(http.StatusInternalServerError, test.input)
assert.Error(t, err)
assert.Equal(t, test.status, status)
if test.output == nil {
test.output = test.input
}
assert.Equal(t, test.output, err)
}
}
func helperNewProxy(t *testing.T, secure ...bool) *Proxy {
t.Helper()
proxy, err := NewSecureProxy(context.Background(), "file::memory:?cache=shared", "test.cloudflareaccess.com", "")
assert.NoError(t, err)
assert.NotNil(t, proxy)
if len(secure) == 0 {
proxy.accessValidator = nil // Mark as insecure
}
return proxy
}

View File

@@ -1,318 +0,0 @@
package dbconnect
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"net/url"
"reflect"
"strings"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
"github.com/xo/dburl"
// SQL drivers self-register with the database/sql package.
// https://github.com/golang/go/wiki/SQLDrivers
_ "github.com/denisenkom/go-mssqldb"
_ "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3"
"github.com/kshvakov/clickhouse"
"github.com/lib/pq"
)
// SQLClient is a Client that talks to a SQL database.
type SQLClient struct {
Dialect string
driver *sqlx.DB
}
// NewSQLClient creates a SQL client based on its URL scheme.
func NewSQLClient(ctx context.Context, originURL *url.URL) (Client, error) {
res, err := dburl.Parse(originURL.String())
if err != nil {
helpText := fmt.Sprintf("supported drivers: %+q, see documentation for more details: %s", sql.Drivers(), "https://godoc.org/github.com/xo/dburl")
return nil, fmt.Errorf("could not parse sql database url '%s': %s\n%s", originURL, err.Error(), helpText)
}
// Establishes the driver, but does not test the connection.
driver, err := sqlx.Open(res.Driver, res.DSN)
if err != nil {
return nil, fmt.Errorf("could not open sql driver %s: %s\n%s", res.Driver, err.Error(), res.DSN)
}
// Closes the driver, will occur when the context finishes.
go func() {
<-ctx.Done()
driver.Close()
}()
return &SQLClient{driver.DriverName(), driver}, nil
}
// Ping verifies a connection to the database is still alive.
func (client *SQLClient) Ping(ctx context.Context) error {
return client.driver.PingContext(ctx)
}
// Submit queries or executes a command to the SQL database.
func (client *SQLClient) Submit(ctx context.Context, cmd *Command) (interface{}, error) {
txx, err := cmd.ValidateSQL(client.Dialect)
if err != nil {
return nil, err
}
ctx, cancel := context.WithTimeout(ctx, cmd.Timeout)
defer cancel()
var res interface{}
// Get the next available sql.Conn and submit the Command.
err = sqlConn(ctx, client.driver, txx, func(conn *sql.Conn) error {
stmt := cmd.Statement
args := cmd.Arguments.Positional
if cmd.Mode == "query" {
res, err = sqlQuery(ctx, conn, stmt, args)
} else {
res, err = sqlExec(ctx, conn, stmt, args)
}
return err
})
return res, err
}
// ValidateSQL extends the contract of Command for SQL dialects:
// mode is conformed, arguments are []sql.NamedArg, and isolation is a sql.IsolationLevel.
//
// When the command should not be wrapped in a transaction, *sql.TxOptions and error will both be nil.
func (cmd *Command) ValidateSQL(dialect string) (*sql.TxOptions, error) {
err := cmd.Validate()
if err != nil {
return nil, err
}
mode, err := sqlMode(cmd.Mode)
if err != nil {
return nil, err
}
// Mutates Arguments to only use positional arguments with the type sql.NamedArg.
// This is a required by the sql.Driver before submitting arguments.
cmd.Arguments.sql(dialect)
iso, err := sqlIsolation(cmd.Isolation)
if err != nil {
return nil, err
}
// When isolation is out-of-range, this is indicative that no
// transaction should be executed and sql.TxOptions should be nil.
if iso < sql.LevelDefault {
return nil, nil
}
// In query mode, execute the transaction in read-only, unless it's Microsoft SQL
// which does not support that type of transaction.
readOnly := mode == "query" && dialect != "mssql"
return &sql.TxOptions{Isolation: iso, ReadOnly: readOnly}, nil
}
// sqlConn gets the next available sql.Conn in the connection pool and runs a function to use it.
//
// If the transaction options are nil, run the useIt function outside a transaction.
// This is potentially an unsafe operation if the command does not clean up its state.
func sqlConn(ctx context.Context, driver *sqlx.DB, txx *sql.TxOptions, useIt func(*sql.Conn) error) error {
conn, err := driver.Conn(ctx)
if err != nil {
return err
}
defer conn.Close()
// If transaction options are specified, begin and defer a rollback to catch errors.
var tx *sql.Tx
if txx != nil {
tx, err = conn.BeginTx(ctx, txx)
if err != nil {
return err
}
defer tx.Rollback()
}
err = useIt(conn)
// Check if useIt was successful and a transaction exists before committing.
if err == nil && tx != nil {
err = tx.Commit()
}
return err
}
// sqlQuery queries rows on a sql.Conn and returns an array of result objects.
func sqlQuery(ctx context.Context, conn *sql.Conn, stmt string, args []interface{}) ([]map[string]interface{}, error) {
rows, err := conn.QueryContext(ctx, stmt, args...)
if err == nil {
return sqlRows(rows)
}
return nil, err
}
// sqlExec executes a command on a sql.Conn and returns the result of the operation.
func sqlExec(ctx context.Context, conn *sql.Conn, stmt string, args []interface{}) (sqlResult, error) {
exec, err := conn.ExecContext(ctx, stmt, args...)
if err == nil {
return sqlResultFrom(exec), nil
}
return sqlResult{}, err
}
// sql mutates Arguments to contain a positional []sql.NamedArg.
//
// The actual return type is []interface{} due to the native Golang
// function signatures for sql.Exec and sql.Query being generic.
func (args *Arguments) sql(dialect string) {
result := args.Positional
for i, val := range result {
result[i] = sqlArg("", val, dialect)
}
for key, val := range args.Named {
result = append(result, sqlArg(key, val, dialect))
}
args.Positional = result
args.Named = map[string]interface{}{}
}
// sqlArg creates a sql.NamedArg from a key-value pair and an optional dialect.
//
// Certain dialects will need to wrap objects, such as arrays, to conform its driver requirements.
func sqlArg(key, val interface{}, dialect string) sql.NamedArg {
switch reflect.ValueOf(val).Kind() {
// PostgreSQL and Clickhouse require arrays to be wrapped before
// being inserted into the driver interface.
case reflect.Slice, reflect.Array:
switch dialect {
case "postgres":
val = pq.Array(val)
case "clickhouse":
val = clickhouse.Array(val)
}
}
return sql.Named(fmt.Sprint(key), val)
}
// sqlIsolation tries to match a string to a sql.IsolationLevel.
func sqlIsolation(str string) (sql.IsolationLevel, error) {
if str == "none" {
return sql.IsolationLevel(-1), nil
}
for iso := sql.LevelDefault; ; iso++ {
if iso > sql.LevelLinearizable {
return -1, fmt.Errorf("cannot provide an invalid sql isolation level: '%s'", str)
}
if str == "" || strings.EqualFold(iso.String(), strings.ReplaceAll(str, "_", " ")) {
return iso, nil
}
}
}
// sqlMode tries to match a string to a command mode: 'query' or 'exec' for now.
func sqlMode(str string) (string, error) {
switch str {
case "query", "exec":
return str, nil
default:
return "", fmt.Errorf("cannot provide invalid sql mode: '%s'", str)
}
}
// sqlRows scans through a SQL result set and returns an array of objects.
func sqlRows(rows *sql.Rows) ([]map[string]interface{}, error) {
columns, err := rows.Columns()
if err != nil {
return nil, errors.Wrap(err, "could not extract columns from result")
}
defer rows.Close()
types, err := rows.ColumnTypes()
if err != nil {
// Some drivers do not support type extraction, so fail silently and continue.
types = make([]*sql.ColumnType, len(columns))
}
values := make([]interface{}, len(columns))
pointers := make([]interface{}, len(columns))
var results []map[string]interface{}
for rows.Next() {
for i := range columns {
pointers[i] = &values[i]
}
rows.Scan(pointers...)
// Convert a row, an array of values, into an object where
// each key is the name of its respective column.
entry := make(map[string]interface{})
for i, col := range columns {
entry[col] = sqlValue(values[i], types[i])
}
results = append(results, entry)
}
return results, nil
}
// sqlValue handles special cases where sql.Rows does not return a "human-readable" object.
func sqlValue(val interface{}, col *sql.ColumnType) interface{} {
bytes, ok := val.([]byte)
if ok {
// Opportunistically check for embeded JSON and convert it to a first-class object.
var embeded interface{}
if json.Unmarshal(bytes, &embeded) == nil {
return embeded
}
// STOR-604: investigate a way to coerce PostgreSQL arrays '{a, b, ...}' into JSON.
// Although easy with strings, it becomes more difficult with special types like INET[].
return string(bytes)
}
return val
}
// sqlResult is a thin wrapper around sql.Result.
type sqlResult struct {
LastInsertId int64 `json:"last_insert_id"`
RowsAffected int64 `json:"rows_affected"`
}
// sqlResultFrom converts sql.Result into a JSON-marshable sqlResult.
func sqlResultFrom(res sql.Result) sqlResult {
insertID, errID := res.LastInsertId()
rowsAffected, errRows := res.RowsAffected()
// If an error occurs when extracting the result, it is because the
// driver does not support that specific field. Instead of passing this
// to the user, omit the field in the response.
if errID != nil {
insertID = -1
}
if errRows != nil {
rowsAffected = -1
}
return sqlResult{insertID, rowsAffected}
}

View File

@@ -1,336 +0,0 @@
package dbconnect
import (
"context"
"database/sql"
"fmt"
"net/url"
"strings"
"testing"
"time"
"github.com/kshvakov/clickhouse"
"github.com/lib/pq"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
)
func TestNewSQLClient(t *testing.T) {
originURLs := []string{
"postgres://localhost",
"cockroachdb://localhost:1337",
"postgresql://user:pass@127.0.0.1",
"mysql://localhost",
"clickhouse://127.0.0.1:9000/?debug",
"sqlite3::memory:",
"file:test.db?cache=shared",
}
for _, originURL := range originURLs {
origin, _ := url.Parse(originURL)
_, err := NewSQLClient(context.Background(), origin)
assert.NoError(t, err, originURL)
}
originURLs = []string{
"",
"/",
"http://localhost",
"coolthing://user:pass@127.0.0.1",
}
for _, originURL := range originURLs {
origin, _ := url.Parse(originURL)
_, err := NewSQLClient(context.Background(), origin)
assert.Error(t, err, originURL)
}
}
func TestArgumentsSQL(t *testing.T) {
args := []Arguments{
Arguments{
Positional: []interface{}{
"val", 10, 3.14,
},
},
Arguments{
Named: map[string]interface{}{
"key": time.Unix(0, 0),
},
},
}
var nameType sql.NamedArg
for _, arg := range args {
arg.sql("")
for _, named := range arg.Positional {
assert.IsType(t, nameType, named)
}
}
}
func TestSQLArg(t *testing.T) {
tests := []struct {
key interface{}
val interface{}
dialect string
arg sql.NamedArg
}{
{"key", "val", "mssql", sql.Named("key", "val")},
{0, 1, "sqlite3", sql.Named("0", 1)},
{1, []string{"a", "b", "c"}, "postgres", sql.Named("1", pq.Array([]string{"a", "b", "c"}))},
{"in", []uint{0, 1}, "clickhouse", sql.Named("in", clickhouse.Array([]uint{0, 1}))},
{"", time.Unix(0, 0), "", sql.Named("", time.Unix(0, 0))},
}
for _, test := range tests {
arg := sqlArg(test.key, test.val, test.dialect)
assert.Equal(t, test.arg, arg, test.key)
}
}
func TestSQLIsolation(t *testing.T) {
tests := []struct {
str string
iso sql.IsolationLevel
}{
{"", sql.LevelDefault},
{"DEFAULT", sql.LevelDefault},
{"read_UNcommitted", sql.LevelReadUncommitted},
{"serializable", sql.LevelSerializable},
{"none", sql.IsolationLevel(-1)},
{"SNAP shot", -2},
{"blah", -2},
}
for _, test := range tests {
iso, err := sqlIsolation(test.str)
if test.iso < -1 {
assert.Error(t, err, test.str)
} else {
assert.NoError(t, err)
assert.Equal(t, test.iso, iso, test.str)
}
}
}
func TestSQLMode(t *testing.T) {
modes := []string{
"query",
"exec",
}
for _, mode := range modes {
actual, err := sqlMode(mode)
assert.NoError(t, err)
assert.Equal(t, strings.ToLower(mode), actual, mode)
}
modes = []string{
"",
"blah",
}
for _, mode := range modes {
_, err := sqlMode(mode)
assert.Error(t, err)
}
}
func helperRows(mockRows *sqlmock.Rows) *sql.Rows {
db, mock, _ := sqlmock.New()
mock.ExpectQuery("SELECT").WillReturnRows(mockRows)
rows, _ := db.Query("SELECT")
return rows
}
func TestSQLRows(t *testing.T) {
actual, err := sqlRows(helperRows(sqlmock.NewRows(
[]string{"name", "age", "dept"}).
AddRow("alice", 19, "prod")))
expected := []map[string]interface{}{map[string]interface{}{
"name": "alice",
"age": int64(19),
"dept": "prod"}}
assert.NoError(t, err)
assert.ElementsMatch(t, expected, actual)
}
func TestSQLValue(t *testing.T) {
tests := []struct {
input interface{}
output interface{}
}{
{"hello", "hello"},
{1, 1},
{false, false},
{[]byte("random"), "random"},
{[]byte("{\"json\":true}"), map[string]interface{}{"json": true}},
{[]byte("[]"), []interface{}{}},
}
for _, test := range tests {
assert.Equal(t, test.output, sqlValue(test.input, nil), test.input)
}
}
func TestSQLResultFrom(t *testing.T) {
res := sqlResultFrom(sqlmock.NewResult(1, 2))
assert.Equal(t, sqlResult{1, 2}, res)
res = sqlResultFrom(sqlmock.NewErrorResult(fmt.Errorf("")))
assert.Equal(t, sqlResult{-1, -1}, res)
}
func helperSQLite3(t *testing.T) (context.Context, Client) {
t.Helper()
ctx := context.Background()
url, _ := url.Parse("file::memory:?cache=shared")
sqlite3, err := NewSQLClient(ctx, url)
assert.NoError(t, err)
return ctx, sqlite3
}
func TestPing(t *testing.T) {
ctx, sqlite3 := helperSQLite3(t)
err := sqlite3.Ping(ctx)
assert.NoError(t, err)
}
func TestSubmit(t *testing.T) {
ctx, sqlite3 := helperSQLite3(t)
res, err := sqlite3.Submit(ctx, &Command{
Statement: "CREATE TABLE t (a INTEGER, b FLOAT, c TEXT, d BLOB);",
Mode: "exec",
})
assert.NoError(t, err)
assert.Equal(t, sqlResult{0, 0}, res)
res, err = sqlite3.Submit(ctx, &Command{
Statement: "SELECT * FROM t;",
Mode: "query",
})
assert.NoError(t, err)
assert.Empty(t, res)
res, err = sqlite3.Submit(ctx, &Command{
Statement: "INSERT INTO t VALUES (?, ?, ?, ?);",
Mode: "exec",
Arguments: Arguments{
Positional: []interface{}{
1,
3.14,
"text",
"blob",
},
},
})
assert.NoError(t, err)
assert.Equal(t, sqlResult{1, 1}, res)
res, err = sqlite3.Submit(ctx, &Command{
Statement: "UPDATE t SET c = NULL;",
Mode: "exec",
})
assert.NoError(t, err)
assert.Equal(t, sqlResult{1, 1}, res)
res, err = sqlite3.Submit(ctx, &Command{
Statement: "SELECT * FROM t WHERE a = ?;",
Mode: "query",
Arguments: Arguments{
Positional: []interface{}{1},
},
})
assert.NoError(t, err)
assert.Len(t, res, 1)
resf, ok := res.([]map[string]interface{})
assert.True(t, ok)
assert.EqualValues(t, map[string]interface{}{
"a": int64(1),
"b": 3.14,
"c": nil,
"d": "blob",
}, resf[0])
res, err = sqlite3.Submit(ctx, &Command{
Statement: "DROP TABLE t;",
Mode: "exec",
})
assert.NoError(t, err)
assert.Equal(t, sqlResult{1, 1}, res)
}
func TestSubmitTransaction(t *testing.T) {
ctx, sqlite3 := helperSQLite3(t)
res, err := sqlite3.Submit(ctx, &Command{
Statement: "BEGIN;",
Mode: "exec",
})
assert.Error(t, err)
assert.Empty(t, res)
res, err = sqlite3.Submit(ctx, &Command{
Statement: "BEGIN; CREATE TABLE tt (a INT); COMMIT;",
Mode: "exec",
Isolation: "none",
})
assert.NoError(t, err)
assert.Equal(t, sqlResult{0, 0}, res)
rows, err := sqlite3.Submit(ctx, &Command{
Statement: "SELECT * FROM tt;",
Mode: "query",
Isolation: "repeatable_read",
})
assert.NoError(t, err)
assert.Empty(t, rows)
}
func TestSubmitTimeout(t *testing.T) {
ctx, sqlite3 := helperSQLite3(t)
res, err := sqlite3.Submit(ctx, &Command{
Statement: "SELECT * FROM t;",
Mode: "query",
Timeout: 1 * time.Nanosecond,
})
assert.Error(t, err)
assert.Empty(t, res)
}
func TestSubmitMode(t *testing.T) {
ctx, sqlite3 := helperSQLite3(t)
res, err := sqlite3.Submit(ctx, &Command{
Statement: "SELECT * FROM t;",
Mode: "notanoption",
})
assert.Error(t, err)
assert.Empty(t, res)
}
func TestSubmitEmpty(t *testing.T) {
ctx, sqlite3 := helperSQLite3(t)
res, err := sqlite3.Submit(ctx, &Command{
Statement: "; ; ; ;",
Mode: "query",
})
assert.Error(t, err)
assert.Empty(t, res)
}