mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 17:29:58 +00:00
Add db-connect, a SQL over HTTPS server
This commit is contained in:
@@ -26,15 +26,15 @@ var logger = log.CreateLogger()
|
||||
type lock struct {
|
||||
lockFilePath string
|
||||
backoff *origin.BackoffHandler
|
||||
sigHandler *signalHandler
|
||||
sigHandler *signalHandler
|
||||
}
|
||||
|
||||
type signalHandler struct {
|
||||
sigChannel chan os.Signal
|
||||
signals []os.Signal
|
||||
sigChannel chan os.Signal
|
||||
signals []os.Signal
|
||||
}
|
||||
|
||||
func (s *signalHandler) register(handler func()){
|
||||
func (s *signalHandler) register(handler func()) {
|
||||
s.sigChannel = make(chan os.Signal, 1)
|
||||
signal.Notify(s.sigChannel, s.signals...)
|
||||
go func(s *signalHandler) {
|
||||
@@ -59,8 +59,8 @@ func newLock(path string) *lock {
|
||||
return &lock{
|
||||
lockFilePath: lockPath,
|
||||
backoff: &origin.BackoffHandler{MaxRetries: 7},
|
||||
sigHandler: &signalHandler{
|
||||
signals: []os.Signal{syscall.SIGINT, syscall.SIGTERM},
|
||||
sigHandler: &signalHandler{
|
||||
signals: []os.Signal{syscall.SIGINT, syscall.SIGTERM},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -68,8 +68,8 @@ func newLock(path string) *lock {
|
||||
func (l *lock) Acquire() error {
|
||||
// Intercept SIGINT and SIGTERM to release lock before exiting
|
||||
l.sigHandler.register(func() {
|
||||
l.deleteLockFile()
|
||||
os.Exit(0)
|
||||
l.deleteLockFile()
|
||||
os.Exit(0)
|
||||
})
|
||||
|
||||
// Check for a path.lock file
|
||||
|
@@ -7,18 +7,18 @@ import (
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"runtime/trace"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/awsuploader"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
|
||||
"github.com/cloudflare/cloudflared/cmd/sqlgateway"
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
"github.com/cloudflare/cloudflared/dbconnect"
|
||||
"github.com/cloudflare/cloudflared/hello"
|
||||
"github.com/cloudflare/cloudflared/metrics"
|
||||
"github.com/cloudflare/cloudflared/origin"
|
||||
@@ -37,7 +37,6 @@ import (
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/ssh/terminal"
|
||||
"gopkg.in/urfave/cli.v2"
|
||||
"gopkg.in/urfave/cli.v2/altsrc"
|
||||
)
|
||||
@@ -138,43 +137,7 @@ func Commands() []*cli.Command {
|
||||
ArgsUsage: " ", // can't be the empty string or we get the default output
|
||||
Hidden: false,
|
||||
},
|
||||
{
|
||||
Name: "db",
|
||||
Action: func(c *cli.Context) error {
|
||||
tags := make(map[string]string)
|
||||
tags["hostname"] = c.String("hostname")
|
||||
raven.SetTagsContext(tags)
|
||||
|
||||
fmt.Printf("\nSQL Database Password: ")
|
||||
pass, err := terminal.ReadPassword(int(syscall.Stdin))
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
|
||||
go sqlgateway.StartProxy(c, logger, string(pass))
|
||||
|
||||
raven.CapturePanic(func() { err = tunnel(c) }, nil)
|
||||
if err != nil {
|
||||
raven.CaptureError(err, nil)
|
||||
}
|
||||
return err
|
||||
},
|
||||
Before: Before,
|
||||
Usage: "SQL Gateway is an SQL over HTTP reverse proxy",
|
||||
Flags: []cli.Flag{
|
||||
&cli.BoolFlag{
|
||||
Name: "db",
|
||||
Value: true,
|
||||
Usage: "Enable the SQL Gateway Proxy",
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "address",
|
||||
Value: "",
|
||||
Usage: "Database connection string: db://user:pass",
|
||||
},
|
||||
},
|
||||
Hidden: true,
|
||||
},
|
||||
dbConnectCmd(),
|
||||
}
|
||||
|
||||
var subcommands []*cli.Command
|
||||
@@ -644,6 +607,60 @@ func addPortIfMissing(uri *url.URL, port int) string {
|
||||
return fmt.Sprintf("%s:%d", uri.Hostname(), port)
|
||||
}
|
||||
|
||||
func dbConnectCmd() *cli.Command {
|
||||
cmd := dbconnect.Cmd()
|
||||
|
||||
// Append the tunnel commands so users can customize the daemon settings.
|
||||
cmd.Flags = appendFlags(Flags(), cmd.Flags...)
|
||||
|
||||
// Override before to run tunnel validation before dbconnect validation.
|
||||
cmd.Before = func(c *cli.Context) error {
|
||||
err := Before(c)
|
||||
if err == nil {
|
||||
err = dbconnect.CmdBefore(c)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Override action to setup the Proxy, then if successful, start the tunnel daemon.
|
||||
cmd.Action = func(c *cli.Context) error {
|
||||
err := dbconnect.CmdAction(c)
|
||||
if err == nil {
|
||||
err = tunnel(c)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// appendFlags will append extra flags to a slice of flags.
|
||||
//
|
||||
// The cli package will panic if two flags exist with the same name,
|
||||
// so if extraFlags contains a flag that was already defined, modify the
|
||||
// original flags to use the extra version.
|
||||
func appendFlags(flags []cli.Flag, extraFlags ...cli.Flag) []cli.Flag {
|
||||
for _, extra := range extraFlags {
|
||||
var found bool
|
||||
|
||||
// Check if an extra flag overrides an existing flag.
|
||||
for i, flag := range flags {
|
||||
if reflect.DeepEqual(extra.Names(), flag.Names()) {
|
||||
flags[i] = extra
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Append the extra flag if it has nothing to override.
|
||||
if !found {
|
||||
flags = append(flags, extra)
|
||||
}
|
||||
}
|
||||
|
||||
return flags
|
||||
}
|
||||
|
||||
func tunnelFlags(shouldHide bool) []cli.Flag {
|
||||
return []cli.Flag{
|
||||
&cli.StringFlag{
|
||||
|
@@ -1,148 +0,0 @@
|
||||
package sqlgateway
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
cli "gopkg.in/urfave/cli.v2"
|
||||
|
||||
"github.com/elgs/gosqljson"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Connection Connection `json:"connection"`
|
||||
Command string `json:"command"`
|
||||
Params []interface{} `json:"params"`
|
||||
}
|
||||
|
||||
type Connection struct {
|
||||
SSLMode string `json:"sslmode"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Columns []string `json:"columns"`
|
||||
Rows [][]string `json:"rows"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
type Proxy struct {
|
||||
Context *cli.Context
|
||||
Router *mux.Router
|
||||
Token string
|
||||
User string
|
||||
Password string
|
||||
Driver string
|
||||
Database string
|
||||
Logger *logrus.Logger
|
||||
}
|
||||
|
||||
func StartProxy(c *cli.Context, logger *logrus.Logger, password string) error {
|
||||
proxy := NewProxy(c, logger, password)
|
||||
|
||||
logger.Infof("Starting SQL Gateway Proxy on port %s", strings.Split(c.String("url"), ":")[1])
|
||||
|
||||
err := http.ListenAndServe(":"+strings.Split(c.String("url"), ":")[1], proxy.Router)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func randID(n int, c *cli.Context) string {
|
||||
charBytes := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890")
|
||||
b := make([]byte, n)
|
||||
for i := range b {
|
||||
b[i] = charBytes[rand.Intn(len(charBytes))]
|
||||
}
|
||||
return fmt.Sprintf("%s&%s", c.String("hostname"), b)
|
||||
}
|
||||
|
||||
// db://user@dbname
|
||||
func parseInfo(input string) (string, string, string) {
|
||||
p1 := strings.Split(input, "://")
|
||||
p2 := strings.Split(p1[1], "@")
|
||||
return p1[0], p2[0], p2[1]
|
||||
}
|
||||
|
||||
func NewProxy(c *cli.Context, logger *logrus.Logger, pass string) *Proxy {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
driver, user, dbname := parseInfo(c.String("address"))
|
||||
proxy := Proxy{
|
||||
Context: c,
|
||||
Router: mux.NewRouter(),
|
||||
Token: randID(64, c),
|
||||
Logger: logger,
|
||||
User: user,
|
||||
Password: pass,
|
||||
Database: dbname,
|
||||
Driver: driver,
|
||||
}
|
||||
|
||||
logger.Info(fmt.Sprintf(`
|
||||
|
||||
--------------------
|
||||
SQL Gateway Proxy
|
||||
Token: %s
|
||||
--------------------
|
||||
|
||||
`, proxy.Token))
|
||||
|
||||
proxy.Router.HandleFunc("/", proxy.proxyRequest).Methods("POST")
|
||||
return &proxy
|
||||
}
|
||||
|
||||
func (proxy *Proxy) proxyRequest(rw http.ResponseWriter, req *http.Request) {
|
||||
var message Message
|
||||
response := Response{}
|
||||
|
||||
err := json.NewDecoder(req.Body).Decode(&message)
|
||||
if err != nil {
|
||||
proxy.Logger.Error(err)
|
||||
http.Error(rw, fmt.Sprintf("400 - %s", err.Error()), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if message.Connection.Token != proxy.Token {
|
||||
proxy.Logger.Error("Invalid token")
|
||||
http.Error(rw, "400 - Invalid token", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
connStr := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=%s", proxy.User, proxy.Password, proxy.Database, message.Connection.SSLMode)
|
||||
|
||||
db, err := sql.Open(proxy.Driver, connStr)
|
||||
defer db.Close()
|
||||
|
||||
if err != nil {
|
||||
proxy.Logger.Error(err)
|
||||
http.Error(rw, fmt.Sprintf("400 - %s", err.Error()), http.StatusBadRequest)
|
||||
return
|
||||
|
||||
} else {
|
||||
proxy.Logger.Info("Forwarding SQL: ", message.Command)
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
|
||||
headers, data, err := gosqljson.QueryDbToArray(db, "lower", message.Command, message.Params...)
|
||||
|
||||
if err != nil {
|
||||
proxy.Logger.Error(err)
|
||||
http.Error(rw, fmt.Sprintf("400 - %s", err.Error()), http.StatusBadRequest)
|
||||
return
|
||||
|
||||
} else {
|
||||
response = Response{headers, data, ""}
|
||||
}
|
||||
}
|
||||
json.NewEncoder(rw).Encode(response)
|
||||
}
|
Reference in New Issue
Block a user