mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-05-11 23:16:35 +00:00

* Allow partial reads from a GorillaConn; add SetDeadline (from net.Conn). The current implementation of GorillaConn drops data if a websocket frame isn't read in full. For example, if a websocket frame has size 3 and Read() is called with a []byte of len 1, the other 2 bytes in the frame are lost forever. This is currently masked by the fact that GorillaConn is used primarily in io.Copy to another socket (in ingress.Stream): as long as the read buffer used by io.Copy is big enough, the data is copied over to the other socket. That buffer is 32*1024 bytes, so could we, in theory, already see this today? The client can then do partial reads just fine, as the kernel takes care of buffering from there on. I hit this while creating my own tunnel without ingress.Stream. I think this could be a real bug today if a websocket frame bigger than 32*1024 bytes were received. It is also possible that we are lucky and the upstream side (which I haven't checked) always uses a smaller buffer than that. The test I added hangs before this change and succeeds after it. Also add SetDeadline so that GorillaConn fully implements net.Conn. * Comment formatting; fast path * Avoid intermediate buffer for first len(p) bytes; import order
146 lines
3.6 KiB
Go
146 lines
3.6 KiB
Go
package websocket
|
|
|
|
import (
	"bytes"
	"context"
	"fmt"
	"io"
	"sync"
	"time"

	gobwas "github.com/gobwas/ws"
	"github.com/gobwas/ws/wsutil"
	"github.com/gorilla/websocket"
	"github.com/rs/zerolog"
)
|
|
|
|
const (
	// writeWait is the time allowed to write a message to the peer; the
	// pingers below use it as the deadline for each ping frame.
	writeWait = 10 * time.Second

	// pongWait is the time allowed to read the next pong message from the peer.
	pongWait = time.Minute

	// pingPeriod is how often pings are sent. It must stay below pongWait,
	// so it is pinned at nine tenths of it.
	pingPeriod = pongWait * 9 / 10
)
|
|
|
|
// GorillaConn wraps a gorilla websocket connection and implements a
// ReadWriter on top of it. It is still used by access carrier.
//
// The embedded *websocket.Conn promotes the gorilla connection's methods
// (e.g. WriteControl, SetReadDeadline); Read and Write present the
// message-oriented connection as a byte stream.
type GorillaConn struct {
	*websocket.Conn
	// log receives diagnostics, such as failed ping writes from pinger.
	log *zerolog.Logger
	// readBuf holds the unread remainder of the last message returned by
	// ReadMessage when that message did not fit in the caller's buffer.
	readBuf bytes.Buffer
}
|
|
|
|
// Read will read messages from the websocket connection
|
|
func (c *GorillaConn) Read(p []byte) (int, error) {
|
|
// Intermediate buffer may contain unread bytes from the last read, start there before blocking on a new frame
|
|
if c.readBuf.Len() > 0 {
|
|
return c.readBuf.Read(p)
|
|
}
|
|
|
|
_, message, err := c.Conn.ReadMessage()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
copied := copy(p, message)
|
|
|
|
// Write unread bytes to readBuf; if everything was read this is a no-op
|
|
// Write returns a nil error always and grows the buffer; everything is always written or panic
|
|
c.readBuf.Write(message[copied:])
|
|
|
|
return copied, nil
|
|
}
|
|
|
|
// Write will write messages to the websocket connection
|
|
func (c *GorillaConn) Write(p []byte) (int, error) {
|
|
if err := c.Conn.WriteMessage(websocket.BinaryMessage, p); err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return len(p), nil
|
|
}
|
|
|
|
// SetDeadline sets both read and write deadlines, as per net.Conn interface docs:
|
|
// "It is equivalent to calling both SetReadDeadline and SetWriteDeadline."
|
|
// Note there is no synchronization here, but the gorilla implementation isn't thread safe anyway
|
|
func (c *GorillaConn) SetDeadline(t time.Time) error {
|
|
if err := c.Conn.SetReadDeadline(t); err != nil {
|
|
return fmt.Errorf("error setting read deadline: %w", err)
|
|
}
|
|
if err := c.Conn.SetWriteDeadline(t); err != nil {
|
|
return fmt.Errorf("error setting write deadline: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// pinger simulates the websocket connection to keep it alive
|
|
func (c *GorillaConn) pinger(ctx context.Context) {
|
|
ticker := time.NewTicker(pingPeriod)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
if err := c.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil {
|
|
c.log.Debug().Msgf("failed to send ping message: %s", err)
|
|
}
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
type Conn struct {
|
|
rw io.ReadWriter
|
|
log *zerolog.Logger
|
|
}
|
|
|
|
func NewConn(ctx context.Context, rw io.ReadWriter, log *zerolog.Logger) *Conn {
|
|
c := &Conn{
|
|
rw: rw,
|
|
log: log,
|
|
}
|
|
go c.pinger(ctx)
|
|
return c
|
|
}
|
|
|
|
// Read will read messages from the websocket connection
|
|
func (c *Conn) Read(reader []byte) (int, error) {
|
|
data, err := wsutil.ReadClientBinary(c.rw)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return copy(reader, data), nil
|
|
}
|
|
|
|
// Write will write messages to the websocket connection
|
|
func (c *Conn) Write(p []byte) (int, error) {
|
|
if err := wsutil.WriteServerBinary(c.rw, p); err != nil {
|
|
return 0, err
|
|
}
|
|
return len(p), nil
|
|
}
|
|
|
|
func (c *Conn) pinger(ctx context.Context) {
|
|
pongMessge := wsutil.Message{
|
|
OpCode: gobwas.OpPong,
|
|
Payload: []byte{},
|
|
}
|
|
ticker := time.NewTicker(pingPeriod)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
if err := wsutil.WriteServerMessage(c.rw, gobwas.OpPing, []byte{}); err != nil {
|
|
c.log.Err(err).Msgf("failed to write ping message")
|
|
}
|
|
if err := wsutil.HandleClientControlMessage(c.rw, pongMessge); err != nil {
|
|
c.log.Err(err).Msgf("failed to write pong message")
|
|
}
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|