mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 14:49:57 +00:00
TUN-528: Move cloudflared into a separate repo
This commit is contained in:
287
h2mux/muxwriter.go
Normal file
287
h2mux/muxwriter.go
Normal file
@@ -0,0 +1,287 @@
|
||||
package h2mux
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
)
|
||||
|
||||
type MuxWriter struct {
|
||||
// f is used to write HTTP2 frames.
|
||||
f *http2.Framer
|
||||
// streams tracks currently-open streams.
|
||||
streams *activeStreamMap
|
||||
// streamErrors receives stream errors raised by the MuxReader.
|
||||
streamErrors *StreamErrorMap
|
||||
// readyStreamChan is used to multiplex writable streams onto the single connection.
|
||||
// When a stream becomes writable its ID is sent on this channel.
|
||||
readyStreamChan <-chan uint32
|
||||
// newStreamChan is used to create new streams with a given set of headers.
|
||||
newStreamChan <-chan MuxedStreamRequest
|
||||
// goAwayChan is used to send a single GOAWAY message to the peer. The element received
|
||||
// is the HTTP/2 error code to send.
|
||||
goAwayChan <-chan http2.ErrCode
|
||||
// abortChan is used when shutting down ungracefully. When this becomes readable, all activity should stop.
|
||||
abortChan <-chan struct{}
|
||||
// pingTimestamp is an atomic value containing the latest received ping timestamp.
|
||||
pingTimestamp *PingTimestamp
|
||||
// A timer used to measure idle connection time. Reset after sending data.
|
||||
idleTimer *IdleTimer
|
||||
// connActiveChan receives a signal that the connection received some (read) activity.
|
||||
connActiveChan <-chan struct{}
|
||||
// Maximum size of all frames that can be sent on this connection.
|
||||
maxFrameSize uint32
|
||||
// headerEncoder is the stateful header encoder for this connection
|
||||
headerEncoder *hpack.Encoder
|
||||
// headerBuffer is the temporary buffer used by headerEncoder.
|
||||
headerBuffer bytes.Buffer
|
||||
// updateReceiveWindowChan is the channel to update receiveWindow size to muxerMetricsUpdater
|
||||
updateReceiveWindowChan chan<- uint32
|
||||
// updateSendWindowChan is the channel to update sendWindow size to muxerMetricsUpdater
|
||||
updateSendWindowChan chan<- uint32
|
||||
// bytesWrote is the amount of bytes wrote to data frame since the last time we send bytes wrote to metrics
|
||||
bytesWrote *AtomicCounter
|
||||
// updateOutBoundBytesChan is the channel to send bytesWrote to muxerMetricsUpdater
|
||||
updateOutBoundBytesChan chan<- uint64
|
||||
|
||||
useDictChan <-chan useDictRequest
|
||||
}
|
||||
|
||||
type MuxedStreamRequest struct {
|
||||
stream *MuxedStream
|
||||
body io.Reader
|
||||
}
|
||||
|
||||
func (r *MuxedStreamRequest) flushBody() {
|
||||
io.Copy(r.stream, r.body)
|
||||
r.stream.CloseWrite()
|
||||
}
|
||||
|
||||
func tsToPingData(ts int64) [8]byte {
|
||||
pingData := [8]byte{}
|
||||
binary.LittleEndian.PutUint64(pingData[:], uint64(ts))
|
||||
return pingData
|
||||
}
|
||||
|
||||
func (w *MuxWriter) run(parentLogger *log.Entry) error {
|
||||
logger := parentLogger.WithFields(log.Fields{
|
||||
"subsystem": "mux",
|
||||
"dir": "write",
|
||||
})
|
||||
defer logger.Debug("event loop finished")
|
||||
|
||||
// routine to periodically communicate bytesWrote
|
||||
go func() {
|
||||
tickC := time.Tick(updateFreq)
|
||||
for {
|
||||
select {
|
||||
case <-w.abortChan:
|
||||
return
|
||||
case <-tickC:
|
||||
w.updateOutBoundBytesChan <- w.bytesWrote.Count()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-w.abortChan:
|
||||
logger.Debug("aborting writer thread")
|
||||
return nil
|
||||
case errCode := <-w.goAwayChan:
|
||||
logger.Debug("sending GOAWAY code ", errCode)
|
||||
err := w.f.WriteGoAway(w.streams.LastPeerStreamID(), errCode, []byte{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.idleTimer.MarkActive()
|
||||
case <-w.pingTimestamp.GetUpdateChan():
|
||||
logger.Debug("sending PING ACK")
|
||||
err := w.f.WritePing(true, tsToPingData(w.pingTimestamp.Get()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.idleTimer.MarkActive()
|
||||
case <-w.idleTimer.C:
|
||||
if !w.idleTimer.Retry() {
|
||||
return ErrConnectionDropped
|
||||
}
|
||||
logger.Debug("sending PING")
|
||||
err := w.f.WritePing(false, tsToPingData(time.Now().UnixNano()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.idleTimer.ResetTimer()
|
||||
case <-w.connActiveChan:
|
||||
w.idleTimer.MarkActive()
|
||||
case <-w.streamErrors.GetSignalChan():
|
||||
for streamID, errCode := range w.streamErrors.GetErrors() {
|
||||
logger.WithField("stream", streamID).WithField("code", errCode).Debug("resetting stream")
|
||||
err := w.f.WriteRSTStream(streamID, errCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
w.idleTimer.MarkActive()
|
||||
case streamRequest := <-w.newStreamChan:
|
||||
streamID := w.streams.AcquireLocalID()
|
||||
streamRequest.stream.streamID = streamID
|
||||
if !w.streams.Set(streamRequest.stream) {
|
||||
// Race between OpenStream and Shutdown, and Shutdown won. Let Shutdown (and the eventual abort) take
|
||||
// care of this stream. Ideally we'd pass the error directly to the stream object somehow so the
|
||||
// caller can be unblocked sooner, but the value of that optimisation is minimal for most of the
|
||||
// reasons why you'd call Shutdown anyway.
|
||||
continue
|
||||
}
|
||||
if streamRequest.body != nil {
|
||||
go streamRequest.flushBody()
|
||||
}
|
||||
streamLogger := logger.WithField("stream", streamID)
|
||||
err := w.writeStreamData(streamRequest.stream, streamLogger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.idleTimer.MarkActive()
|
||||
case streamID := <-w.readyStreamChan:
|
||||
streamLogger := logger.WithField("stream", streamID)
|
||||
stream, ok := w.streams.Get(streamID)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
err := w.writeStreamData(stream, streamLogger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.idleTimer.MarkActive()
|
||||
case useDict := <-w.useDictChan:
|
||||
err := w.writeUseDictionary(useDict)
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("error writing use dictionary")
|
||||
return err
|
||||
}
|
||||
w.idleTimer.MarkActive()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) error {
|
||||
logger.Debug("writable")
|
||||
chunk := stream.getChunk()
|
||||
w.updateReceiveWindowChan <- stream.getReceiveWindow()
|
||||
w.updateSendWindowChan <- stream.getSendWindow()
|
||||
if chunk.sendHeadersFrame() {
|
||||
err := w.writeHeaders(chunk.streamID, chunk.headers)
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("error writing headers")
|
||||
return err
|
||||
}
|
||||
logger.Debug("output headers")
|
||||
}
|
||||
|
||||
if chunk.sendWindowUpdateFrame() {
|
||||
// Send a WINDOW_UPDATE frame to update our receive window.
|
||||
// If the Stream ID is zero, the window update applies to the connection as a whole
|
||||
// RFC7540 section-6.9.1 "A receiver that receives a flow-controlled frame MUST
|
||||
// always account for its contribution against the connection flow-control
|
||||
// window, unless the receiver treats this as a connection error"
|
||||
err := w.f.WriteWindowUpdate(chunk.streamID, chunk.windowUpdate)
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("error writing window update")
|
||||
return err
|
||||
}
|
||||
logger.Debugf("increment receive window by %d", chunk.windowUpdate)
|
||||
}
|
||||
|
||||
for chunk.sendDataFrame() {
|
||||
payload, sentEOF := chunk.nextDataFrame(int(w.maxFrameSize))
|
||||
err := w.f.WriteData(chunk.streamID, sentEOF, payload)
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("error writing data")
|
||||
return err
|
||||
}
|
||||
// update the amount of data wrote
|
||||
w.bytesWrote.IncrementBy(uint64(len(payload)))
|
||||
logger.WithField("len", len(payload)).Debug("output data")
|
||||
|
||||
if sentEOF {
|
||||
if stream.readBuffer.Closed() {
|
||||
// transition into closed state
|
||||
if !stream.gotReceiveEOF() {
|
||||
// the peer may send data that we no longer want to receive. Force them into the
|
||||
// closed state.
|
||||
logger.Debug("resetting stream")
|
||||
w.f.WriteRSTStream(chunk.streamID, http2.ErrCodeNo)
|
||||
} else {
|
||||
// Half-open stream transitioned into closed
|
||||
logger.Debug("closing stream")
|
||||
}
|
||||
w.streams.Delete(chunk.streamID)
|
||||
} else {
|
||||
logger.Debug("closing stream write side")
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *MuxWriter) encodeHeaders(headers []Header) ([]byte, error) {
|
||||
w.headerBuffer.Reset()
|
||||
for _, header := range headers {
|
||||
err := w.headerEncoder.WriteField(hpack.HeaderField{
|
||||
Name: header.Name,
|
||||
Value: header.Value,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return w.headerBuffer.Bytes(), nil
|
||||
}
|
||||
|
||||
// writeHeaders writes a block of encoded headers, splitting it into multiple frames if necessary.
|
||||
func (w *MuxWriter) writeHeaders(streamID uint32, headers []Header) error {
|
||||
encodedHeaders, err := w.encodeHeaders(headers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
blockSize := int(w.maxFrameSize)
|
||||
endHeaders := len(encodedHeaders) == 0
|
||||
for !endHeaders && err == nil {
|
||||
blockFragment := encodedHeaders
|
||||
if len(encodedHeaders) > blockSize {
|
||||
blockFragment = blockFragment[:blockSize]
|
||||
encodedHeaders = encodedHeaders[blockSize:]
|
||||
// Send CONTINUATION frame if the headers can't be fit into 1 frame
|
||||
err = w.f.WriteContinuation(streamID, endHeaders, blockFragment)
|
||||
} else {
|
||||
endHeaders = true
|
||||
err = w.f.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: streamID,
|
||||
EndHeaders: endHeaders,
|
||||
BlockFragment: blockFragment,
|
||||
})
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *MuxWriter) writeUseDictionary(dictRequest useDictRequest) error {
|
||||
err := w.f.WriteRawFrame(FrameUseDictionary, 0, dictRequest.streamID, []byte{byte(dictRequest.dictID)})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payload := make([]byte, 0, 64)
|
||||
for _, set := range dictRequest.setDict {
|
||||
payload = append(payload, byte(set.dictID))
|
||||
payload = appendVarInt(payload, 7, uint64(set.dictSZ))
|
||||
payload = append(payload, 0x80) // E = 1, D = 0, Truncate = 0
|
||||
}
|
||||
|
||||
err = w.f.WriteRawFrame(FrameSetDictionary, FlagSetDictionaryAppend, dictRequest.streamID, payload)
|
||||
return err
|
||||
}
|
Reference in New Issue
Block a user