mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 23:49:57 +00:00
TUN-528: Move cloudflared into a separate repo
This commit is contained in:
49
vendor/zombiezen.com/go/capnproto2/rpc/BUILD.bazel
generated
vendored
Normal file
49
vendor/zombiezen.com/go/capnproto2/rpc/BUILD.bazel
generated
vendored
Normal file
@@ -0,0 +1,49 @@
|
||||
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
|
||||
|
||||
go_library(
|
||||
name = "go_default_library",
|
||||
srcs = [
|
||||
"answer.go",
|
||||
"errors.go",
|
||||
"introspect.go",
|
||||
"log.go",
|
||||
"question.go",
|
||||
"rpc.go",
|
||||
"tables.go",
|
||||
"transport.go",
|
||||
],
|
||||
importpath = "zombiezen.com/go/capnproto2/rpc",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//:go_default_library",
|
||||
"//internal/fulfiller:go_default_library",
|
||||
"//internal/queue:go_default_library",
|
||||
"//rpc/internal/refcount:go_default_library",
|
||||
"//std/capnp/rpc:go_default_library",
|
||||
"@org_golang_x_net//context:go_default_library",
|
||||
],
|
||||
)
|
||||
|
||||
go_test(
|
||||
name = "go_default_test",
|
||||
srcs = [
|
||||
"bench_test.go",
|
||||
"cancel_test.go",
|
||||
"embargo_test.go",
|
||||
"example_test.go",
|
||||
"issue3_test.go",
|
||||
"promise_test.go",
|
||||
"release_test.go",
|
||||
"rpc_test.go",
|
||||
],
|
||||
embed = [":go_default_library"],
|
||||
deps = [
|
||||
"//:go_default_library",
|
||||
"//rpc/internal/logtransport:go_default_library",
|
||||
"//rpc/internal/pipetransport:go_default_library",
|
||||
"//rpc/internal/testcapnp:go_default_library",
|
||||
"//server:go_default_library",
|
||||
"//std/capnp/rpc:go_default_library",
|
||||
"@org_golang_x_net//context:go_default_library",
|
||||
],
|
||||
)
|
498
vendor/zombiezen.com/go/capnproto2/rpc/answer.go
generated
vendored
Normal file
498
vendor/zombiezen.com/go/capnproto2/rpc/answer.go
generated
vendored
Normal file
@@ -0,0 +1,498 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"zombiezen.com/go/capnproto2"
|
||||
"zombiezen.com/go/capnproto2/internal/fulfiller"
|
||||
"zombiezen.com/go/capnproto2/internal/queue"
|
||||
rpccapnp "zombiezen.com/go/capnproto2/std/capnp/rpc"
|
||||
)
|
||||
|
||||
// callQueueSize is the maximum number of calls that can be queued per answer or client.
|
||||
// TODO(light): make this a ConnOption
|
||||
const callQueueSize = 64
|
||||
|
||||
// insertAnswer creates a new answer with the given ID, returning nil
|
||||
// if the ID is already in use.
|
||||
func (c *Conn) insertAnswer(id answerID, cancel context.CancelFunc) *answer {
|
||||
if c.answers == nil {
|
||||
c.answers = make(map[answerID]*answer)
|
||||
} else if _, exists := c.answers[id]; exists {
|
||||
return nil
|
||||
}
|
||||
a := &answer{
|
||||
id: id,
|
||||
cancel: cancel,
|
||||
conn: c,
|
||||
resolved: make(chan struct{}),
|
||||
queue: make([]pcall, 0, callQueueSize),
|
||||
}
|
||||
c.answers[id] = a
|
||||
return a
|
||||
}
|
||||
|
||||
func (c *Conn) popAnswer(id answerID) *answer {
|
||||
if c.answers == nil {
|
||||
return nil
|
||||
}
|
||||
a := c.answers[id]
|
||||
delete(c.answers, id)
|
||||
return a
|
||||
}
|
||||
|
||||
type answer struct {
|
||||
id answerID
|
||||
cancel context.CancelFunc
|
||||
resultCaps []exportID
|
||||
conn *Conn
|
||||
resolved chan struct{}
|
||||
|
||||
mu sync.RWMutex
|
||||
obj capnp.Ptr
|
||||
err error
|
||||
done bool
|
||||
queue []pcall
|
||||
}
|
||||
|
||||
// fulfill is called to resolve an answer successfully. It returns an
|
||||
// error if its connection is shut down while sending messages. The
|
||||
// caller must be holding onto a.conn.mu.
|
||||
func (a *answer) fulfill(obj capnp.Ptr) error {
|
||||
a.mu.Lock()
|
||||
if a.done {
|
||||
panic("answer.fulfill called more than once")
|
||||
}
|
||||
a.obj, a.done = obj, true
|
||||
// TODO(light): populate resultCaps
|
||||
|
||||
var firstErr error
|
||||
if err := a.conn.startWork(); err != nil {
|
||||
firstErr = err
|
||||
for i := range a.queue {
|
||||
a.queue[i].a.reject(err)
|
||||
}
|
||||
a.queue = nil
|
||||
} else {
|
||||
retmsg := newReturnMessage(nil, a.id)
|
||||
ret, _ := retmsg.Return()
|
||||
payload, _ := ret.NewResults()
|
||||
payload.SetContentPtr(obj)
|
||||
if payloadTab, err := a.conn.makeCapTable(ret.Segment()); err != nil {
|
||||
firstErr = err
|
||||
} else {
|
||||
payload.SetCapTable(payloadTab)
|
||||
if err := a.conn.sendMessage(retmsg); err != nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
|
||||
queues, err := a.emptyQueue(obj)
|
||||
if err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
ctab := obj.Segment().Message().CapTable
|
||||
for capIdx, q := range queues {
|
||||
ctab[capIdx] = newQueueClient(a.conn, ctab[capIdx], q)
|
||||
}
|
||||
a.conn.workers.Done()
|
||||
}
|
||||
close(a.resolved)
|
||||
a.mu.Unlock()
|
||||
return firstErr
|
||||
}
|
||||
|
||||
// reject is called to resolve an answer with failure. It returns an
|
||||
// error if its connection is shut down while sending messages. The
|
||||
// caller must be holding onto a.conn.mu.
|
||||
func (a *answer) reject(err error) error {
|
||||
if err == nil {
|
||||
panic("answer.reject called with nil")
|
||||
}
|
||||
a.mu.Lock()
|
||||
if a.done {
|
||||
panic("answer.reject called more than once")
|
||||
}
|
||||
a.err, a.done = err, true
|
||||
m := newReturnMessage(nil, a.id)
|
||||
mret, _ := m.Return()
|
||||
setReturnException(mret, err)
|
||||
var firstErr error
|
||||
if err := a.conn.sendMessage(m); err != nil {
|
||||
firstErr = err
|
||||
}
|
||||
for i := range a.queue {
|
||||
if err := a.queue[i].a.reject(err); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
a.queue = nil
|
||||
close(a.resolved)
|
||||
a.mu.Unlock()
|
||||
return firstErr
|
||||
}
|
||||
|
||||
// emptyQueue splits the queue by which capability it targets
|
||||
// and drops any invalid calls. Once this function returns, a.queue
|
||||
// will be nil.
|
||||
func (a *answer) emptyQueue(obj capnp.Ptr) (map[capnp.CapabilityID][]qcall, error) {
|
||||
var firstErr error
|
||||
qs := make(map[capnp.CapabilityID][]qcall, len(a.queue))
|
||||
for i, pc := range a.queue {
|
||||
c, err := capnp.TransformPtr(obj, pc.transform)
|
||||
if err != nil {
|
||||
if err := pc.a.reject(err); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
continue
|
||||
}
|
||||
ci := c.Interface()
|
||||
if !ci.IsValid() {
|
||||
if err := pc.a.reject(capnp.ErrNullClient); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
continue
|
||||
}
|
||||
cn := ci.Capability()
|
||||
if qs[cn] == nil {
|
||||
qs[cn] = make([]qcall, 0, len(a.queue)-i)
|
||||
}
|
||||
qs[cn] = append(qs[cn], pc.qcall)
|
||||
}
|
||||
a.queue = nil
|
||||
return qs, firstErr
|
||||
}
|
||||
|
||||
// queueCallLocked enqueues a call to be made after the answer has been
|
||||
// resolved. The answer must not be resolved yet. pc should have
|
||||
// transform and one of pc.a or pc.f to be set. The caller must be
|
||||
// holding onto a.mu.
|
||||
func (a *answer) queueCallLocked(call *capnp.Call, pc pcall) error {
|
||||
if len(a.queue) == cap(a.queue) {
|
||||
return errQueueFull
|
||||
}
|
||||
var err error
|
||||
pc.call, err = call.Copy(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
a.queue = append(a.queue, pc)
|
||||
return nil
|
||||
}
|
||||
|
||||
// queueDisembargo enqueues a disembargo message.
|
||||
func (a *answer) queueDisembargo(transform []capnp.PipelineOp, id embargoID, target rpccapnp.MessageTarget) (queued bool, err error) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
if !a.done {
|
||||
return false, errDisembargoOngoingAnswer
|
||||
}
|
||||
if a.err != nil {
|
||||
return false, errDisembargoNonImport
|
||||
}
|
||||
targetPtr, err := capnp.TransformPtr(a.obj, transform)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
client := targetPtr.Interface().Client()
|
||||
qc, ok := client.(*queueClient)
|
||||
if !ok {
|
||||
// No need to embargo, disembargo immediately.
|
||||
return false, nil
|
||||
}
|
||||
if ic := isImport(qc.client); ic == nil || a.conn != ic.conn {
|
||||
return false, errDisembargoNonImport
|
||||
}
|
||||
qc.mu.Lock()
|
||||
if !qc.isPassthrough() {
|
||||
err = qc.pushEmbargoLocked(id, target)
|
||||
if err == nil {
|
||||
queued = true
|
||||
}
|
||||
}
|
||||
qc.mu.Unlock()
|
||||
return queued, err
|
||||
}
|
||||
|
||||
func (a *answer) pipelineClient(transform []capnp.PipelineOp) capnp.Client {
|
||||
return &localAnswerClient{a: a, transform: transform}
|
||||
}
|
||||
|
||||
// joinAnswer resolves an RPC answer by waiting on a generic answer.
|
||||
// The caller must not be holding onto a.conn.mu.
|
||||
func joinAnswer(a *answer, ca capnp.Answer) {
|
||||
s, err := ca.Struct()
|
||||
a.conn.mu.Lock()
|
||||
if err == nil {
|
||||
a.fulfill(s.ToPtr())
|
||||
} else {
|
||||
a.reject(err)
|
||||
}
|
||||
a.conn.mu.Unlock()
|
||||
}
|
||||
|
||||
// joinFulfiller resolves a fulfiller by waiting on a generic answer.
|
||||
func joinFulfiller(f *fulfiller.Fulfiller, ca capnp.Answer) {
|
||||
s, err := ca.Struct()
|
||||
if err != nil {
|
||||
f.Reject(err)
|
||||
} else {
|
||||
f.Fulfill(s)
|
||||
}
|
||||
}
|
||||
|
||||
type queueClient struct {
|
||||
client capnp.Client
|
||||
conn *Conn
|
||||
|
||||
mu sync.RWMutex
|
||||
q queue.Queue
|
||||
calls qcallList
|
||||
}
|
||||
|
||||
func newQueueClient(c *Conn, client capnp.Client, queue []qcall) *queueClient {
|
||||
qc := &queueClient{
|
||||
client: client,
|
||||
conn: c,
|
||||
calls: make(qcallList, callQueueSize),
|
||||
}
|
||||
qc.q.Init(qc.calls, copy(qc.calls, queue))
|
||||
go qc.flushQueue()
|
||||
return qc
|
||||
}
|
||||
|
||||
func (qc *queueClient) pushCallLocked(cl *capnp.Call) capnp.Answer {
|
||||
f := new(fulfiller.Fulfiller)
|
||||
cl, err := cl.Copy(nil)
|
||||
if err != nil {
|
||||
return capnp.ErrorAnswer(err)
|
||||
}
|
||||
i := qc.q.Push()
|
||||
if i == -1 {
|
||||
return capnp.ErrorAnswer(errQueueFull)
|
||||
}
|
||||
qc.calls[i] = qcall{call: cl, f: f}
|
||||
return f
|
||||
}
|
||||
|
||||
func (qc *queueClient) pushEmbargoLocked(id embargoID, tgt rpccapnp.MessageTarget) error {
|
||||
i := qc.q.Push()
|
||||
if i == -1 {
|
||||
return errQueueFull
|
||||
}
|
||||
qc.calls[i] = qcall{embargoID: id, embargoTarget: tgt}
|
||||
return nil
|
||||
}
|
||||
|
||||
// flushQueue is run in its own goroutine.
|
||||
func (qc *queueClient) flushQueue() {
|
||||
var c qcall
|
||||
qc.mu.RLock()
|
||||
if i := qc.q.Front(); i != -1 {
|
||||
c = qc.calls[i]
|
||||
}
|
||||
qc.mu.RUnlock()
|
||||
for c.which() != qcallInvalid {
|
||||
qc.handle(&c)
|
||||
|
||||
qc.mu.Lock()
|
||||
qc.q.Pop()
|
||||
if i := qc.q.Front(); i != -1 {
|
||||
c = qc.calls[i]
|
||||
} else {
|
||||
c = qcall{}
|
||||
}
|
||||
qc.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (qc *queueClient) handle(c *qcall) {
|
||||
switch c.which() {
|
||||
case qcallRemoteCall:
|
||||
answer := qc.client.Call(c.call)
|
||||
go joinAnswer(c.a, answer)
|
||||
case qcallLocalCall:
|
||||
answer := qc.client.Call(c.call)
|
||||
go joinFulfiller(c.f, answer)
|
||||
case qcallDisembargo:
|
||||
msg := newDisembargoMessage(nil, rpccapnp.Disembargo_context_Which_receiverLoopback, c.embargoID)
|
||||
d, _ := msg.Disembargo()
|
||||
d.SetTarget(c.embargoTarget)
|
||||
qc.conn.sendMessage(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (qc *queueClient) isPassthrough() bool {
|
||||
return qc.q.Len() == 0
|
||||
}
|
||||
|
||||
func (qc *queueClient) Call(cl *capnp.Call) capnp.Answer {
|
||||
// Fast path: queue is flushed.
|
||||
qc.mu.RLock()
|
||||
ok := qc.isPassthrough()
|
||||
qc.mu.RUnlock()
|
||||
if ok {
|
||||
return qc.client.Call(cl)
|
||||
}
|
||||
|
||||
// Add to queue.
|
||||
qc.mu.Lock()
|
||||
// Since we released the lock, check that the queue hasn't been flushed.
|
||||
if qc.isPassthrough() {
|
||||
qc.mu.Unlock()
|
||||
return qc.client.Call(cl)
|
||||
}
|
||||
ans := qc.pushCallLocked(cl)
|
||||
qc.mu.Unlock()
|
||||
return ans
|
||||
}
|
||||
|
||||
func (qc *queueClient) tryQueue(cl *capnp.Call) capnp.Answer {
|
||||
qc.mu.Lock()
|
||||
if qc.isPassthrough() {
|
||||
qc.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
ans := qc.pushCallLocked(cl)
|
||||
qc.mu.Unlock()
|
||||
return ans
|
||||
}
|
||||
|
||||
func (qc *queueClient) Close() error {
|
||||
qc.conn.mu.Lock()
|
||||
if err := qc.conn.startWork(); err != nil {
|
||||
qc.conn.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
rejErr := qc.rejectQueue()
|
||||
qc.conn.workers.Done()
|
||||
qc.conn.mu.Unlock()
|
||||
if err := qc.client.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return rejErr
|
||||
}
|
||||
|
||||
// rejectQueue drains the client's queue. It returns an error if the
|
||||
// connection was shut down while messages are sent. The caller must be
|
||||
// holding onto qc.conn.mu.
|
||||
func (qc *queueClient) rejectQueue() error {
|
||||
var firstErr error
|
||||
qc.mu.Lock()
|
||||
for ; qc.q.Len() > 0; qc.q.Pop() {
|
||||
c := qc.calls[qc.q.Front()]
|
||||
switch c.which() {
|
||||
case qcallRemoteCall:
|
||||
if err := c.a.reject(errQueueCallCancel); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
case qcallLocalCall:
|
||||
c.f.Reject(errQueueCallCancel)
|
||||
case qcallDisembargo:
|
||||
m := newDisembargoMessage(nil, rpccapnp.Disembargo_context_Which_receiverLoopback, c.embargoID)
|
||||
d, _ := m.Disembargo()
|
||||
d.SetTarget(c.embargoTarget)
|
||||
if err := qc.conn.sendMessage(m); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
qc.mu.Unlock()
|
||||
return firstErr
|
||||
}
|
||||
|
||||
// pcall is a queued pipeline call.
|
||||
type pcall struct {
|
||||
transform []capnp.PipelineOp
|
||||
qcall
|
||||
}
|
||||
|
||||
// qcall is a queued call.
|
||||
type qcall struct {
|
||||
// Calls
|
||||
a *answer // non-nil if remote call
|
||||
f *fulfiller.Fulfiller // non-nil if local call
|
||||
call *capnp.Call
|
||||
|
||||
// Disembargo
|
||||
embargoID embargoID
|
||||
embargoTarget rpccapnp.MessageTarget
|
||||
}
|
||||
|
||||
// Queued call types.
|
||||
const (
|
||||
qcallInvalid = iota
|
||||
qcallRemoteCall
|
||||
qcallLocalCall
|
||||
qcallDisembargo
|
||||
)
|
||||
|
||||
func (c *qcall) which() int {
|
||||
switch {
|
||||
case c.a != nil:
|
||||
return qcallRemoteCall
|
||||
case c.f != nil:
|
||||
return qcallLocalCall
|
||||
case c.embargoTarget.IsValid():
|
||||
return qcallDisembargo
|
||||
default:
|
||||
return qcallInvalid
|
||||
}
|
||||
}
|
||||
|
||||
type qcallList []qcall
|
||||
|
||||
func (ql qcallList) Len() int {
|
||||
return len(ql)
|
||||
}
|
||||
|
||||
func (ql qcallList) Clear(i int) {
|
||||
ql[i] = qcall{}
|
||||
}
|
||||
|
||||
// A localAnswerClient is used to provide a pipelined client of an answer.
|
||||
type localAnswerClient struct {
|
||||
a *answer
|
||||
transform []capnp.PipelineOp
|
||||
}
|
||||
|
||||
func (lac *localAnswerClient) Call(call *capnp.Call) capnp.Answer {
|
||||
lac.a.mu.Lock()
|
||||
if lac.a.done {
|
||||
obj, err := lac.a.obj, lac.a.err
|
||||
lac.a.mu.Unlock()
|
||||
return clientFromResolution(lac.transform, obj, err).Call(call)
|
||||
}
|
||||
f := new(fulfiller.Fulfiller)
|
||||
err := lac.a.queueCallLocked(call, pcall{
|
||||
transform: lac.transform,
|
||||
qcall: qcall{f: f},
|
||||
})
|
||||
lac.a.mu.Unlock()
|
||||
if err != nil {
|
||||
return capnp.ErrorAnswer(errQueueFull)
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
func (lac *localAnswerClient) Close() error {
|
||||
lac.a.mu.RLock()
|
||||
obj, err, done := lac.a.obj, lac.a.err, lac.a.done
|
||||
lac.a.mu.RUnlock()
|
||||
if !done {
|
||||
return nil
|
||||
}
|
||||
client := clientFromResolution(lac.transform, obj, err)
|
||||
return client.Close()
|
||||
}
|
||||
|
||||
var (
|
||||
errQueueFull = errors.New("rpc: pipeline queue full")
|
||||
errQueueCallCancel = errors.New("rpc: queued call canceled")
|
||||
|
||||
errDisembargoOngoingAnswer = errors.New("rpc: disembargo attempted on in-progress answer")
|
||||
errDisembargoNonImport = errors.New("rpc: disembargo attempted on non-import capability")
|
||||
errDisembargoMissingAnswer = errors.New("rpc: disembargo attempted on missing answer (finished too early?)")
|
||||
)
|
55
vendor/zombiezen.com/go/capnproto2/rpc/bench_test.go
generated
vendored
Normal file
55
vendor/zombiezen.com/go/capnproto2/rpc/bench_test.go
generated
vendored
Normal file
@@ -0,0 +1,55 @@
|
||||
package rpc_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"zombiezen.com/go/capnproto2"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
"zombiezen.com/go/capnproto2/rpc/internal/logtransport"
|
||||
"zombiezen.com/go/capnproto2/rpc/internal/pipetransport"
|
||||
"zombiezen.com/go/capnproto2/rpc/internal/testcapnp"
|
||||
)
|
||||
|
||||
func BenchmarkPingPong(b *testing.B) {
|
||||
p, q := pipetransport.New()
|
||||
if *logMessages {
|
||||
p = logtransport.New(nil, p)
|
||||
}
|
||||
log := testLogger{b}
|
||||
c := rpc.NewConn(p, rpc.ConnLog(log))
|
||||
d := rpc.NewConn(q, rpc.ConnLog(log), rpc.BootstrapFunc(bootstrapPingPong))
|
||||
defer d.Wait()
|
||||
defer c.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
client := testcapnp.PingPong{Client: c.Bootstrap(ctx)}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
promise := client.EchoNum(ctx, func(p testcapnp.PingPong_echoNum_Params) error {
|
||||
p.SetN(42)
|
||||
return nil
|
||||
})
|
||||
result, err := promise.Struct()
|
||||
if err != nil {
|
||||
b.Errorf("EchoNum(42) failed on iteration %d: %v", i, err)
|
||||
break
|
||||
}
|
||||
if result.N() != 42 {
|
||||
b.Errorf("EchoNum(42) = %d; want 42", result.N())
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func bootstrapPingPong(ctx context.Context) (capnp.Client, error) {
|
||||
return testcapnp.PingPong_ServerToClient(pingPongServer{}).Client, nil
|
||||
}
|
||||
|
||||
type pingPongServer struct{}
|
||||
|
||||
func (pingPongServer) EchoNum(call testcapnp.PingPong_echoNum) error {
|
||||
call.Results.SetN(call.Params.N())
|
||||
return nil
|
||||
}
|
52
vendor/zombiezen.com/go/capnproto2/rpc/cancel_test.go
generated
vendored
Normal file
52
vendor/zombiezen.com/go/capnproto2/rpc/cancel_test.go
generated
vendored
Normal file
@@ -0,0 +1,52 @@
|
||||
package rpc_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
"zombiezen.com/go/capnproto2/rpc/internal/logtransport"
|
||||
"zombiezen.com/go/capnproto2/rpc/internal/pipetransport"
|
||||
"zombiezen.com/go/capnproto2/rpc/internal/testcapnp"
|
||||
"zombiezen.com/go/capnproto2/server"
|
||||
)
|
||||
|
||||
func TestCancel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
log := testLogger{t}
|
||||
p, q := pipetransport.New()
|
||||
if *logMessages {
|
||||
p = logtransport.New(nil, p)
|
||||
}
|
||||
c := rpc.NewConn(p, rpc.ConnLog(log))
|
||||
notify := make(chan struct{})
|
||||
hanger := testcapnp.Hanger_ServerToClient(Hanger{notify: notify})
|
||||
d := rpc.NewConn(q, rpc.MainInterface(hanger.Client), rpc.ConnLog(log))
|
||||
defer d.Wait()
|
||||
defer c.Close()
|
||||
client := testcapnp.Hanger{Client: c.Bootstrap(ctx)}
|
||||
|
||||
subctx, subcancel := context.WithCancel(ctx)
|
||||
promise := client.Hang(subctx, nil)
|
||||
<-notify
|
||||
subcancel()
|
||||
_, err := promise.Struct()
|
||||
<-notify // test will deadlock if cancel not delivered
|
||||
|
||||
if err != context.Canceled {
|
||||
t.Errorf("promise.Get() error: %v; want %v", err, context.Canceled)
|
||||
}
|
||||
}
|
||||
|
||||
type Hanger struct {
|
||||
notify chan struct{}
|
||||
}
|
||||
|
||||
func (h Hanger) Hang(call testcapnp.Hanger_hang) error {
|
||||
server.Ack(call.Options)
|
||||
h.notify <- struct{}{}
|
||||
<-call.Ctx.Done()
|
||||
close(h.notify)
|
||||
return nil
|
||||
}
|
91
vendor/zombiezen.com/go/capnproto2/rpc/embargo_test.go
generated
vendored
Normal file
91
vendor/zombiezen.com/go/capnproto2/rpc/embargo_test.go
generated
vendored
Normal file
@@ -0,0 +1,91 @@
|
||||
package rpc_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"zombiezen.com/go/capnproto2"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
"zombiezen.com/go/capnproto2/rpc/internal/logtransport"
|
||||
"zombiezen.com/go/capnproto2/rpc/internal/pipetransport"
|
||||
"zombiezen.com/go/capnproto2/rpc/internal/testcapnp"
|
||||
)
|
||||
|
||||
func TestEmbargo(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
p, q := pipetransport.New()
|
||||
if *logMessages {
|
||||
p = logtransport.New(nil, p)
|
||||
}
|
||||
log := testLogger{t}
|
||||
c := rpc.NewConn(p, rpc.ConnLog(log))
|
||||
echoSrv := testcapnp.Echoer_ServerToClient(new(Echoer))
|
||||
d := rpc.NewConn(q, rpc.MainInterface(echoSrv.Client), rpc.ConnLog(log))
|
||||
defer d.Wait()
|
||||
defer c.Close()
|
||||
client := testcapnp.Echoer{Client: c.Bootstrap(ctx)}
|
||||
localCap := testcapnp.CallOrder_ServerToClient(new(CallOrder))
|
||||
|
||||
earlyCall := callseq(ctx, client.Client, 0)
|
||||
echo := client.Echo(ctx, func(p testcapnp.Echoer_echo_Params) error {
|
||||
return p.SetCap(localCap)
|
||||
})
|
||||
pipeline := echo.Cap()
|
||||
call0 := callseq(ctx, pipeline.Client, 0)
|
||||
call1 := callseq(ctx, pipeline.Client, 1)
|
||||
_, err := earlyCall.Struct()
|
||||
if err != nil {
|
||||
t.Errorf("earlyCall error: %v", err)
|
||||
}
|
||||
call2 := callseq(ctx, pipeline.Client, 2)
|
||||
_, err = echo.Struct()
|
||||
if err != nil {
|
||||
t.Errorf("echo.Get() error: %v", err)
|
||||
}
|
||||
call3 := callseq(ctx, pipeline.Client, 3)
|
||||
call4 := callseq(ctx, pipeline.Client, 4)
|
||||
call5 := callseq(ctx, pipeline.Client, 5)
|
||||
|
||||
check := func(promise testcapnp.CallOrder_getCallSequence_Results_Promise, n uint32) {
|
||||
r, err := promise.Struct()
|
||||
if err != nil {
|
||||
t.Errorf("call%d error: %v", n, err)
|
||||
}
|
||||
if r.N() != n {
|
||||
t.Errorf("call%d = %d; want %d", n, r.N(), n)
|
||||
}
|
||||
}
|
||||
check(call0, 0)
|
||||
check(call1, 1)
|
||||
check(call2, 2)
|
||||
check(call3, 3)
|
||||
check(call4, 4)
|
||||
check(call5, 5)
|
||||
}
|
||||
|
||||
func callseq(c context.Context, client capnp.Client, n uint32) testcapnp.CallOrder_getCallSequence_Results_Promise {
|
||||
return testcapnp.CallOrder{Client: client}.GetCallSequence(c, func(p testcapnp.CallOrder_getCallSequence_Params) error {
|
||||
p.SetExpected(n)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
type CallOrder struct {
|
||||
n uint32
|
||||
}
|
||||
|
||||
func (co *CallOrder) GetCallSequence(call testcapnp.CallOrder_getCallSequence) error {
|
||||
call.Results.SetN(co.n)
|
||||
co.n++
|
||||
return nil
|
||||
}
|
||||
|
||||
type Echoer struct {
|
||||
CallOrder
|
||||
}
|
||||
|
||||
func (*Echoer) Echo(call testcapnp.Echoer_echo) error {
|
||||
call.Results.SetCap(call.Params.Cap())
|
||||
return nil
|
||||
}
|
102
vendor/zombiezen.com/go/capnproto2/rpc/errors.go
generated
vendored
Normal file
102
vendor/zombiezen.com/go/capnproto2/rpc/errors.go
generated
vendored
Normal file
@@ -0,0 +1,102 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"zombiezen.com/go/capnproto2"
|
||||
rpccapnp "zombiezen.com/go/capnproto2/std/capnp/rpc"
|
||||
)
|
||||
|
||||
// An Exception is a Cap'n Proto RPC error.
|
||||
type Exception struct {
|
||||
rpccapnp.Exception
|
||||
}
|
||||
|
||||
// Error returns the exception's reason.
|
||||
func (e Exception) Error() string {
|
||||
r, err := e.Reason()
|
||||
if err != nil {
|
||||
return "rpc exception"
|
||||
}
|
||||
return "rpc exception: " + r
|
||||
}
|
||||
|
||||
// An Abort is a hang-up by a remote vat.
|
||||
type Abort Exception
|
||||
|
||||
func copyAbort(m rpccapnp.Message) (Abort, error) {
|
||||
ma, err := m.Abort()
|
||||
if err != nil {
|
||||
return Abort{}, err
|
||||
}
|
||||
msg, _, _ := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||
if err := msg.SetRootPtr(ma.ToPtr()); err != nil {
|
||||
return Abort{}, err
|
||||
}
|
||||
p, err := msg.RootPtr()
|
||||
if err != nil {
|
||||
return Abort{}, err
|
||||
}
|
||||
return Abort{rpccapnp.Exception{Struct: p.Struct()}}, nil
|
||||
}
|
||||
|
||||
// Error returns the exception's reason.
|
||||
func (a Abort) Error() string {
|
||||
r, err := a.Reason()
|
||||
if err != nil {
|
||||
return "rpc: aborted by remote"
|
||||
}
|
||||
return "rpc: aborted by remote: " + r
|
||||
}
|
||||
|
||||
// toException sets fields on exc to match err.
|
||||
func toException(exc rpccapnp.Exception, err error) {
|
||||
if ee, ok := err.(Exception); ok {
|
||||
// TODO(light): copy struct
|
||||
r, err := ee.Reason()
|
||||
if err == nil {
|
||||
exc.SetReason(r)
|
||||
}
|
||||
exc.SetType(ee.Type())
|
||||
return
|
||||
}
|
||||
|
||||
exc.SetReason(err.Error())
|
||||
exc.SetType(rpccapnp.Exception_Type_failed)
|
||||
}
|
||||
|
||||
// Errors
|
||||
var (
|
||||
ErrConnClosed = errors.New("rpc: connection closed")
|
||||
)
|
||||
|
||||
// Internal errors
|
||||
var (
|
||||
errQuestionReused = errors.New("rpc: question ID reused")
|
||||
errNoMainInterface = errors.New("rpc: no bootstrap interface")
|
||||
errBadTarget = errors.New("rpc: target not found")
|
||||
errShutdown = errors.New("rpc: shutdown")
|
||||
errUnimplemented = errors.New("rpc: remote used unimplemented protocol feature")
|
||||
)
|
||||
|
||||
type bootstrapError struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (e bootstrapError) Error() string {
|
||||
return "rpc bootstrap:" + e.err.Error()
|
||||
}
|
||||
|
||||
type questionError struct {
|
||||
id questionID
|
||||
method *capnp.Method // nil if this is bootstrap
|
||||
err error
|
||||
}
|
||||
|
||||
func (qe *questionError) Error() string {
|
||||
if qe.method == nil {
|
||||
return fmt.Sprintf("bootstrap call id=%d: %v", qe.id, qe.err)
|
||||
}
|
||||
return fmt.Sprintf("%v call id=%d: %v", qe.method, qe.id, qe.err)
|
||||
}
|
75
vendor/zombiezen.com/go/capnproto2/rpc/example_test.go
generated
vendored
Normal file
75
vendor/zombiezen.com/go/capnproto2/rpc/example_test.go
generated
vendored
Normal file
@@ -0,0 +1,75 @@
|
||||
package rpc_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
"zombiezen.com/go/capnproto2/rpc/internal/testcapnp"
|
||||
"zombiezen.com/go/capnproto2/server"
|
||||
)
|
||||
|
||||
func Example() {
|
||||
// Create an in-memory transport. In a real application, you would probably
|
||||
// use a net.TCPConn (for RPC) or an os.Pipe (for IPC).
|
||||
p1, p2 := net.Pipe()
|
||||
t1, t2 := rpc.StreamTransport(p1), rpc.StreamTransport(p2)
|
||||
|
||||
// Server-side
|
||||
srv := testcapnp.Adder_ServerToClient(AdderServer{})
|
||||
serverConn := rpc.NewConn(t1, rpc.MainInterface(srv.Client))
|
||||
defer serverConn.Wait()
|
||||
|
||||
// Client-side
|
||||
ctx := context.Background()
|
||||
clientConn := rpc.NewConn(t2)
|
||||
defer clientConn.Close()
|
||||
adderClient := testcapnp.Adder{Client: clientConn.Bootstrap(ctx)}
|
||||
// Every client call returns a promise. You can make multiple calls
|
||||
// concurrently.
|
||||
call1 := adderClient.Add(ctx, func(p testcapnp.Adder_add_Params) error {
|
||||
p.SetA(5)
|
||||
p.SetB(2)
|
||||
return nil
|
||||
})
|
||||
call2 := adderClient.Add(ctx, func(p testcapnp.Adder_add_Params) error {
|
||||
p.SetA(10)
|
||||
p.SetB(20)
|
||||
return nil
|
||||
})
|
||||
// Calling Struct() on a promise waits until it returns.
|
||||
result1, err := call1.Struct()
|
||||
if err != nil {
|
||||
fmt.Println("Add #1 failed:", err)
|
||||
return
|
||||
}
|
||||
result2, err := call2.Struct()
|
||||
if err != nil {
|
||||
fmt.Println("Add #2 failed:", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("Results:", result1.Result(), result2.Result())
|
||||
// Output:
|
||||
// Results: 7 30
|
||||
}
|
||||
|
||||
// An AdderServer is a local implementation of the Adder interface.
|
||||
type AdderServer struct{}
|
||||
|
||||
// Add implements a method
|
||||
func (AdderServer) Add(call testcapnp.Adder_add) error {
|
||||
// Acknowledging the call allows other calls to be made (it returns the Answer
|
||||
// to the caller).
|
||||
server.Ack(call.Options)
|
||||
|
||||
// Parameters are accessed with call.Params.
|
||||
a := call.Params.A()
|
||||
b := call.Params.B()
|
||||
|
||||
// A result struct is allocated for you at call.Results.
|
||||
call.Results.SetResult(a + b)
|
||||
|
||||
return nil
|
||||
}
|
15
vendor/zombiezen.com/go/capnproto2/rpc/internal/logtransport/BUILD.bazel
generated
vendored
Normal file
15
vendor/zombiezen.com/go/capnproto2/rpc/internal/logtransport/BUILD.bazel
generated
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
load("@io_bazel_rules_go//go:def.bzl", "go_library")
|
||||
|
||||
go_library(
|
||||
name = "go_default_library",
|
||||
srcs = ["logtransport.go"],
|
||||
importpath = "zombiezen.com/go/capnproto2/rpc/internal/logtransport",
|
||||
visibility = ["//rpc:__subpackages__"],
|
||||
deps = [
|
||||
"//encoding/text:go_default_library",
|
||||
"//rpc:go_default_library",
|
||||
"//rpc/internal/logutil:go_default_library",
|
||||
"//std/capnp/rpc:go_default_library",
|
||||
"@org_golang_x_net//context:go_default_library",
|
||||
],
|
||||
)
|
52
vendor/zombiezen.com/go/capnproto2/rpc/internal/logtransport/logtransport.go
generated
vendored
Normal file
52
vendor/zombiezen.com/go/capnproto2/rpc/internal/logtransport/logtransport.go
generated
vendored
Normal file
@@ -0,0 +1,52 @@
|
||||
// Package logtransport provides a transport that logs all of its messages.
|
||||
package logtransport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"log"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"zombiezen.com/go/capnproto2/encoding/text"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
"zombiezen.com/go/capnproto2/rpc/internal/logutil"
|
||||
rpccapnp "zombiezen.com/go/capnproto2/std/capnp/rpc"
|
||||
)
|
||||
|
||||
type transport struct {
|
||||
rpc.Transport
|
||||
l *log.Logger
|
||||
sendBuf bytes.Buffer
|
||||
recvBuf bytes.Buffer
|
||||
}
|
||||
|
||||
// New creates a new logger that proxies messages to and from t and
|
||||
// logs them to l. If l is nil, then the log package's default
|
||||
// logger is used.
|
||||
func New(l *log.Logger, t rpc.Transport) rpc.Transport {
|
||||
return &transport{Transport: t, l: l}
|
||||
}
|
||||
|
||||
func (t *transport) SendMessage(ctx context.Context, msg rpccapnp.Message) error {
|
||||
t.sendBuf.Reset()
|
||||
t.sendBuf.WriteString("<- ")
|
||||
formatMsg(&t.sendBuf, msg)
|
||||
logutil.Print(t.l, t.sendBuf.String())
|
||||
return t.Transport.SendMessage(ctx, msg)
|
||||
}
|
||||
|
||||
func (t *transport) RecvMessage(ctx context.Context) (rpccapnp.Message, error) {
|
||||
msg, err := t.Transport.RecvMessage(ctx)
|
||||
if err != nil {
|
||||
return msg, err
|
||||
}
|
||||
t.recvBuf.Reset()
|
||||
t.recvBuf.WriteString("-> ")
|
||||
formatMsg(&t.recvBuf, msg)
|
||||
logutil.Print(t.l, t.recvBuf.String())
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func formatMsg(w io.Writer, m rpccapnp.Message) {
|
||||
text.NewEncoder(w).Encode(0x91b79f1f808db032, m.Struct)
|
||||
}
|
8
vendor/zombiezen.com/go/capnproto2/rpc/internal/logutil/BUILD.bazel
generated
vendored
Normal file
8
vendor/zombiezen.com/go/capnproto2/rpc/internal/logutil/BUILD.bazel
generated
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
load("@io_bazel_rules_go//go:def.bzl", "go_library")
|
||||
|
||||
go_library(
|
||||
name = "go_default_library",
|
||||
srcs = ["logutil.go"],
|
||||
importpath = "zombiezen.com/go/capnproto2/rpc/internal/logutil",
|
||||
visibility = ["//rpc:__subpackages__"],
|
||||
)
|
36
vendor/zombiezen.com/go/capnproto2/rpc/internal/logutil/logutil.go
generated
vendored
Normal file
36
vendor/zombiezen.com/go/capnproto2/rpc/internal/logutil/logutil.go
generated
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
// Package logutil provides functions that can print to a logger.
|
||||
// Any function in this package that takes in a *log.Logger can be
|
||||
// passed nil to use the log package's default logger.
|
||||
package logutil
|
||||
|
||||
import "log"
|
||||
|
||||
// Print calls Print on a logger or the default logger.
|
||||
// Arguments are handled in the manner of fmt.Print.
|
||||
func Print(l *log.Logger, v ...interface{}) {
|
||||
if l == nil {
|
||||
log.Print(v...)
|
||||
} else {
|
||||
l.Print(v...)
|
||||
}
|
||||
}
|
||||
|
||||
// Printf calls Printf on a logger or the default logger.
|
||||
// Arguments are handled in the manner of fmt.Printf.
|
||||
func Printf(l *log.Logger, format string, v ...interface{}) {
|
||||
if l == nil {
|
||||
log.Printf(format, v...)
|
||||
} else {
|
||||
l.Printf(format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
// Println calls Println on a logger or the default logger.
|
||||
// Arguments are handled in the manner of fmt.Println.
|
||||
func Println(l *log.Logger, v ...interface{}) {
|
||||
if l == nil {
|
||||
log.Println(v...)
|
||||
} else {
|
||||
l.Println(v...)
|
||||
}
|
||||
}
|
14
vendor/zombiezen.com/go/capnproto2/rpc/internal/pipetransport/BUILD.bazel
generated
vendored
Normal file
14
vendor/zombiezen.com/go/capnproto2/rpc/internal/pipetransport/BUILD.bazel
generated
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
load("@io_bazel_rules_go//go:def.bzl", "go_library")
|
||||
|
||||
go_library(
|
||||
name = "go_default_library",
|
||||
srcs = ["pipetransport.go"],
|
||||
importpath = "zombiezen.com/go/capnproto2/rpc/internal/pipetransport",
|
||||
visibility = ["//rpc:__subpackages__"],
|
||||
deps = [
|
||||
"//:go_default_library",
|
||||
"//rpc:go_default_library",
|
||||
"//std/capnp/rpc:go_default_library",
|
||||
"@org_golang_x_net//context:go_default_library",
|
||||
],
|
||||
)
|
139
vendor/zombiezen.com/go/capnproto2/rpc/internal/pipetransport/pipetransport.go
generated
vendored
Normal file
139
vendor/zombiezen.com/go/capnproto2/rpc/internal/pipetransport/pipetransport.go
generated
vendored
Normal file
@@ -0,0 +1,139 @@
|
||||
// Package pipetransport provides in-memory implementations of rpc.Transport for testing.
|
||||
package pipetransport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"zombiezen.com/go/capnproto2"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
rpccapnp "zombiezen.com/go/capnproto2/std/capnp/rpc"
|
||||
)
|
||||
|
||||
type pipeTransport struct {
|
||||
r <-chan rpccapnp.Message
|
||||
w chan<- rpccapnp.Message
|
||||
finish chan struct{}
|
||||
otherFin chan struct{}
|
||||
|
||||
rbuf bytes.Buffer
|
||||
|
||||
mu sync.Mutex
|
||||
inflight int
|
||||
done bool
|
||||
}
|
||||
|
||||
// New creates a synchronous in-memory pipe transport.
|
||||
func New() (p, q rpc.Transport) {
|
||||
a, b := make(chan rpccapnp.Message), make(chan rpccapnp.Message)
|
||||
afin, bfin := make(chan struct{}), make(chan struct{})
|
||||
p = &pipeTransport{
|
||||
r: a,
|
||||
w: b,
|
||||
finish: afin,
|
||||
otherFin: bfin,
|
||||
}
|
||||
q = &pipeTransport{
|
||||
r: b,
|
||||
w: a,
|
||||
finish: bfin,
|
||||
otherFin: afin,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (p *pipeTransport) SendMessage(ctx context.Context, msg rpccapnp.Message) error {
|
||||
if !p.startSend() {
|
||||
return errClosed
|
||||
}
|
||||
defer p.finishSend()
|
||||
|
||||
buf, err := msg.Segment().Message().Marshal()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mm, err := capnp.Unmarshal(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
msg, err = rpccapnp.ReadRootMessage(mm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
select {
|
||||
case p.w <- msg:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-p.finish:
|
||||
return errClosed
|
||||
case <-p.otherFin:
|
||||
return errBrokenPipe
|
||||
}
|
||||
}
|
||||
|
||||
func (p *pipeTransport) startSend() bool {
|
||||
p.mu.Lock()
|
||||
ok := !p.done
|
||||
if ok {
|
||||
p.inflight++
|
||||
}
|
||||
p.mu.Unlock()
|
||||
return ok
|
||||
}
|
||||
|
||||
func (p *pipeTransport) finishSend() {
|
||||
p.mu.Lock()
|
||||
p.inflight--
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
func (p *pipeTransport) RecvMessage(ctx context.Context) (rpccapnp.Message, error) {
|
||||
// Scribble over shared buffer to test for race conditions.
|
||||
for b, i := p.rbuf.Bytes(), 0; i < len(b); i++ {
|
||||
b[i] = 0xff
|
||||
}
|
||||
p.rbuf.Reset()
|
||||
|
||||
select {
|
||||
case msg, ok := <-p.r:
|
||||
if !ok {
|
||||
return rpccapnp.Message{}, errBrokenPipe
|
||||
}
|
||||
if err := capnp.NewEncoder(&p.rbuf).Encode(msg.Segment().Message()); err != nil {
|
||||
return rpccapnp.Message{}, err
|
||||
}
|
||||
m, err := capnp.Unmarshal(p.rbuf.Bytes())
|
||||
if err != nil {
|
||||
return rpccapnp.Message{}, err
|
||||
}
|
||||
return rpccapnp.ReadRootMessage(m)
|
||||
case <-ctx.Done():
|
||||
return rpccapnp.Message{}, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *pipeTransport) Close() error {
|
||||
p.mu.Lock()
|
||||
done := p.done
|
||||
if !done {
|
||||
p.done = true
|
||||
close(p.finish)
|
||||
if p.inflight == 0 {
|
||||
close(p.w)
|
||||
}
|
||||
}
|
||||
p.mu.Unlock()
|
||||
if done {
|
||||
return errClosed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
errBrokenPipe = errors.New("pipetransport: broken pipe")
|
||||
errClosed = errors.New("pipetransport: write to broken pipe")
|
||||
)
|
16
vendor/zombiezen.com/go/capnproto2/rpc/internal/refcount/BUILD.bazel
generated
vendored
Normal file
16
vendor/zombiezen.com/go/capnproto2/rpc/internal/refcount/BUILD.bazel
generated
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
|
||||
|
||||
go_library(
|
||||
name = "go_default_library",
|
||||
srcs = ["refcount.go"],
|
||||
importpath = "zombiezen.com/go/capnproto2/rpc/internal/refcount",
|
||||
visibility = ["//rpc:__subpackages__"],
|
||||
deps = ["//:go_default_library"],
|
||||
)
|
||||
|
||||
go_test(
|
||||
name = "go_default_test",
|
||||
srcs = ["refcount_test.go"],
|
||||
embed = [":go_default_library"],
|
||||
deps = ["//:go_default_library"],
|
||||
)
|
116
vendor/zombiezen.com/go/capnproto2/rpc/internal/refcount/refcount.go
generated
vendored
Normal file
116
vendor/zombiezen.com/go/capnproto2/rpc/internal/refcount/refcount.go
generated
vendored
Normal file
@@ -0,0 +1,116 @@
|
||||
// Package refcount implements a reference-counting client.
|
||||
package refcount
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"zombiezen.com/go/capnproto2"
|
||||
)
|
||||
|
||||
// A RefCount will close its underlying client once all its references are closed.
|
||||
type RefCount struct {
|
||||
Client capnp.Client
|
||||
|
||||
mu sync.Mutex
|
||||
refs int
|
||||
}
|
||||
|
||||
// New creates a reference counter and the first client reference.
|
||||
func New(c capnp.Client) (rc *RefCount, ref1 capnp.Client) {
|
||||
if rr, ok := c.(*Ref); ok {
|
||||
return rr.rc, rr.rc.Ref()
|
||||
}
|
||||
rc = &RefCount{Client: c, refs: 1}
|
||||
ref1 = rc.newRef()
|
||||
return
|
||||
}
|
||||
|
||||
// Ref makes a new client reference.
|
||||
func (rc *RefCount) Ref() capnp.Client {
|
||||
rc.mu.Lock()
|
||||
if rc.refs <= 0 {
|
||||
rc.mu.Unlock()
|
||||
return capnp.ErrorClient(errZeroRef)
|
||||
}
|
||||
rc.refs++
|
||||
rc.mu.Unlock()
|
||||
return rc.newRef()
|
||||
}
|
||||
|
||||
func (rc *RefCount) newRef() *Ref {
|
||||
r := &Ref{rc: rc}
|
||||
runtime.SetFinalizer(r, (*Ref).Close)
|
||||
return r
|
||||
}
|
||||
|
||||
func (rc *RefCount) call(cl *capnp.Call) capnp.Answer {
|
||||
// We lock here so that we can prevent the client from being closed
|
||||
// while we start the call.
|
||||
rc.mu.Lock()
|
||||
if rc.refs <= 0 {
|
||||
rc.mu.Unlock()
|
||||
return capnp.ErrorAnswer(errClosed)
|
||||
}
|
||||
ans := rc.Client.Call(cl)
|
||||
rc.mu.Unlock()
|
||||
return ans
|
||||
}
|
||||
|
||||
// decref decreases the reference count by one, closing the Client if it reaches zero.
|
||||
func (rc *RefCount) decref() error {
|
||||
shouldClose := false
|
||||
|
||||
rc.mu.Lock()
|
||||
if rc.refs <= 0 {
|
||||
rc.mu.Unlock()
|
||||
return errClosed
|
||||
}
|
||||
rc.refs--
|
||||
if rc.refs == 0 {
|
||||
shouldClose = true
|
||||
}
|
||||
rc.mu.Unlock()
|
||||
|
||||
if shouldClose {
|
||||
return rc.Client.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
errZeroRef = errors.New("rpc: Ref() called on zeroed refcount")
|
||||
errClosed = errors.New("rpc: Close() called on closed client")
|
||||
)
|
||||
|
||||
// A Ref is a single reference to a client wrapped by RefCount.
|
||||
type Ref struct {
|
||||
rc *RefCount
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
// Call makes a call on the underlying client.
|
||||
func (r *Ref) Call(cl *capnp.Call) capnp.Answer {
|
||||
return r.rc.call(cl)
|
||||
}
|
||||
|
||||
// Client returns the underlying client.
|
||||
func (r *Ref) Client() capnp.Client {
|
||||
return r.rc.Client
|
||||
}
|
||||
|
||||
// Close decrements the reference count. Close will be called on
|
||||
// finalization (i.e. garbage collection).
|
||||
func (r *Ref) Close() error {
|
||||
var err error
|
||||
closed := false
|
||||
r.once.Do(func() {
|
||||
err = r.rc.decref()
|
||||
closed = true
|
||||
})
|
||||
if !closed {
|
||||
return errClosed
|
||||
}
|
||||
return err
|
||||
}
|
89
vendor/zombiezen.com/go/capnproto2/rpc/internal/refcount/refcount_test.go
generated
vendored
Normal file
89
vendor/zombiezen.com/go/capnproto2/rpc/internal/refcount/refcount_test.go
generated
vendored
Normal file
@@ -0,0 +1,89 @@
|
||||
package refcount
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"zombiezen.com/go/capnproto2"
|
||||
)
|
||||
|
||||
func TestSingleRefCloses(t *testing.T) {
|
||||
c := new(fakeClient)
|
||||
|
||||
_, ref := New(c)
|
||||
err := ref.Close()
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("ref.Close(): %v", err)
|
||||
}
|
||||
if c.closed != 1 {
|
||||
t.Errorf("client Close() called %d times; want 1 time", c.closed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseRefMultipleDecrefsOnce(t *testing.T) {
|
||||
c := new(fakeClient)
|
||||
|
||||
rc, ref1 := New(c)
|
||||
ref2 := rc.Ref()
|
||||
err1 := ref1.Close()
|
||||
err2 := ref1.Close()
|
||||
_ = ref2
|
||||
|
||||
if err1 != nil {
|
||||
t.Errorf("ref.Close() #1: %v", err1)
|
||||
}
|
||||
if err2 != errClosed {
|
||||
t.Errorf("ref.Close() #2: %v; want %v", err2, errClosed)
|
||||
}
|
||||
if c.closed != 0 {
|
||||
t.Errorf("client Close() called %d times; want 0 times", c.closed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClosingOneOfManyRefsDoesntClose(t *testing.T) {
|
||||
c := new(fakeClient)
|
||||
|
||||
rc, ref1 := New(c)
|
||||
ref2 := rc.Ref()
|
||||
err := ref1.Close()
|
||||
_ = ref2
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("ref1.Close(): %v", err)
|
||||
}
|
||||
if c.closed != 0 {
|
||||
t.Errorf("client Close() called %d times; want 0 times", c.closed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClosingAllRefsCloses(t *testing.T) {
|
||||
c := new(fakeClient)
|
||||
|
||||
rc, ref1 := New(c)
|
||||
ref2 := rc.Ref()
|
||||
err1 := ref1.Close()
|
||||
err2 := ref2.Close()
|
||||
|
||||
if err1 != nil {
|
||||
t.Errorf("ref1.Close(): %v", err1)
|
||||
}
|
||||
if err2 != nil {
|
||||
t.Errorf("ref2.Close(): %v", err2)
|
||||
}
|
||||
if c.closed != 1 {
|
||||
t.Errorf("client Close() called %d times; want 1 times", c.closed)
|
||||
}
|
||||
}
|
||||
|
||||
type fakeClient struct {
|
||||
closed int
|
||||
}
|
||||
|
||||
func (c *fakeClient) Call(cl *capnp.Call) capnp.Answer {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *fakeClient) Close() error {
|
||||
c.closed++
|
||||
return nil
|
||||
}
|
18
vendor/zombiezen.com/go/capnproto2/rpc/internal/testcapnp/BUILD.bazel
generated
vendored
Normal file
18
vendor/zombiezen.com/go/capnproto2/rpc/internal/testcapnp/BUILD.bazel
generated
vendored
Normal file
@@ -0,0 +1,18 @@
|
||||
load("@io_bazel_rules_go//go:def.bzl", "go_library")
|
||||
|
||||
go_library(
|
||||
name = "go_default_library",
|
||||
srcs = [
|
||||
"generate.go",
|
||||
"test.capnp.go",
|
||||
],
|
||||
importpath = "zombiezen.com/go/capnproto2/rpc/internal/testcapnp",
|
||||
visibility = ["//rpc:__subpackages__"],
|
||||
deps = [
|
||||
"//:go_default_library",
|
||||
"//encoding/text:go_default_library",
|
||||
"//schemas:go_default_library",
|
||||
"//server:go_default_library",
|
||||
"@org_golang_x_net//context:go_default_library",
|
||||
],
|
||||
)
|
3
vendor/zombiezen.com/go/capnproto2/rpc/internal/testcapnp/generate.go
generated
vendored
Normal file
3
vendor/zombiezen.com/go/capnproto2/rpc/internal/testcapnp/generate.go
generated
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
package testcapnp
|
||||
|
||||
//go:generate capnp compile -I ../../../std -ogo test.capnp
|
40
vendor/zombiezen.com/go/capnproto2/rpc/internal/testcapnp/test.capnp
generated
vendored
Normal file
40
vendor/zombiezen.com/go/capnproto2/rpc/internal/testcapnp/test.capnp
generated
vendored
Normal file
@@ -0,0 +1,40 @@
|
||||
# Test interfaces for RPC tests.
|
||||
|
||||
using Go = import "/go.capnp";
|
||||
|
||||
@0xef12a34b9807e19c;
|
||||
$Go.package("testcapnp");
|
||||
$Go.import("zombiezen.com/go/capnproto2/rpc/internal/testcapnp");
|
||||
|
||||
interface Handle {}
|
||||
|
||||
interface HandleFactory {
|
||||
newHandle @0 () -> (handle :Handle);
|
||||
}
|
||||
|
||||
interface Hanger {
|
||||
hang @0 () -> ();
|
||||
# Block until context is cancelled
|
||||
}
|
||||
|
||||
interface CallOrder {
|
||||
getCallSequence @0 (expected: UInt32) -> (n: UInt32);
|
||||
# First call returns 0, next returns 1, ...
|
||||
#
|
||||
# The input `expected` is ignored but useful for disambiguating debug logs.
|
||||
}
|
||||
|
||||
interface Echoer extends(CallOrder) {
|
||||
echo @0 (cap :CallOrder) -> (cap :CallOrder);
|
||||
# Just returns the input cap.
|
||||
}
|
||||
|
||||
interface PingPong {
|
||||
echoNum @0 (n :Int32) -> (n :Int32);
|
||||
}
|
||||
|
||||
# Example interfaces
|
||||
|
||||
interface Adder {
|
||||
add @0 (a :Int32, b :Int32) -> (result :Int32);
|
||||
}
|
1326
vendor/zombiezen.com/go/capnproto2/rpc/internal/testcapnp/test.capnp.go
generated
vendored
Normal file
1326
vendor/zombiezen.com/go/capnproto2/rpc/internal/testcapnp/test.capnp.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
347
vendor/zombiezen.com/go/capnproto2/rpc/introspect.go
generated
vendored
Normal file
347
vendor/zombiezen.com/go/capnproto2/rpc/introspect.go
generated
vendored
Normal file
@@ -0,0 +1,347 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"zombiezen.com/go/capnproto2"
|
||||
"zombiezen.com/go/capnproto2/internal/fulfiller"
|
||||
"zombiezen.com/go/capnproto2/rpc/internal/refcount"
|
||||
rpccapnp "zombiezen.com/go/capnproto2/std/capnp/rpc"
|
||||
)
|
||||
|
||||
// While the code below looks repetitive, resist the urge to refactor.
|
||||
// Each operation is distinct in assumptions it can make about
|
||||
// particular cases, and there isn't a convenient type signature that
|
||||
// fits all cases.
|
||||
|
||||
// lockedCall is used to make a call to an arbitrary client while
|
||||
// holding onto c.mu. Since the client could point back to c, naively
|
||||
// calling c.Call could deadlock.
|
||||
func (c *Conn) lockedCall(client capnp.Client, cl *capnp.Call) capnp.Answer {
|
||||
dig:
|
||||
for client := client; ; {
|
||||
switch curr := client.(type) {
|
||||
case *importClient:
|
||||
if curr.conn != c {
|
||||
// This doesn't use our conn's lock, so it is safe to call.
|
||||
return curr.Call(cl)
|
||||
}
|
||||
return curr.lockedCall(cl)
|
||||
case *fulfiller.EmbargoClient:
|
||||
if ans := curr.TryQueue(cl); ans != nil {
|
||||
return ans
|
||||
}
|
||||
client = curr.Client()
|
||||
case *refcount.Ref:
|
||||
client = curr.Client()
|
||||
case *embargoClient:
|
||||
if ans := curr.tryQueue(cl); ans != nil {
|
||||
return ans
|
||||
}
|
||||
client = curr.client
|
||||
case *queueClient:
|
||||
if ans := curr.tryQueue(cl); ans != nil {
|
||||
return ans
|
||||
}
|
||||
client = curr.client
|
||||
case *localAnswerClient:
|
||||
curr.a.mu.Lock()
|
||||
if curr.a.done {
|
||||
obj, err := curr.a.obj, curr.a.err
|
||||
curr.a.mu.Unlock()
|
||||
client = clientFromResolution(curr.transform, obj, err)
|
||||
} else {
|
||||
f := new(fulfiller.Fulfiller)
|
||||
err := curr.a.queueCallLocked(cl, pcall{
|
||||
transform: curr.transform,
|
||||
qcall: qcall{f: f},
|
||||
})
|
||||
curr.a.mu.Unlock()
|
||||
if err != nil {
|
||||
return capnp.ErrorAnswer(err)
|
||||
}
|
||||
return f
|
||||
}
|
||||
case *capnp.PipelineClient:
|
||||
p := (*capnp.Pipeline)(curr)
|
||||
ans := p.Answer()
|
||||
transform := p.Transform()
|
||||
if capnp.IsFixedAnswer(ans) {
|
||||
s, err := ans.Struct()
|
||||
client = clientFromResolution(transform, s.ToPtr(), err)
|
||||
continue
|
||||
}
|
||||
switch ans := ans.(type) {
|
||||
case *fulfiller.Fulfiller:
|
||||
ap := ans.Peek()
|
||||
if ap == nil {
|
||||
break dig
|
||||
}
|
||||
s, err := ap.Struct()
|
||||
client = clientFromResolution(transform, s.ToPtr(), err)
|
||||
case *question:
|
||||
if ans.conn != c {
|
||||
// This doesn't use our conn's lock, so it is safe to call.
|
||||
return ans.PipelineCall(transform, cl)
|
||||
}
|
||||
return ans.lockedPipelineCall(transform, cl)
|
||||
default:
|
||||
break dig
|
||||
}
|
||||
default:
|
||||
break dig
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(light): Add a CallOption that signals to bypass sync.
|
||||
// The above hack works in *most* cases.
|
||||
//
|
||||
// If your code is deadlocking here, you've hit the edge of the
|
||||
// compromise between these three goals:
|
||||
// 1) Package capnp is loosely coupled with package rpc
|
||||
// 2) Arbitrary implementations of Client may exist
|
||||
// 3) Local E-order must be preserved
|
||||
//
|
||||
// #3 is the one that creates a deadlock, since application code must
|
||||
// acquire the connection mutex to preserve order of delivery. You
|
||||
// can't really overcome this without breaking one of the first two
|
||||
// constraints.
|
||||
//
|
||||
// To avoid #2 as much as possible, implementing Client is discouraged
|
||||
// by several docs.
|
||||
return client.Call(cl)
|
||||
}
|
||||
|
||||
// descriptorForClient fills desc for client, adding it to the export
|
||||
// table if necessary. The caller must be holding onto c.mu.
|
||||
func (c *Conn) descriptorForClient(desc rpccapnp.CapDescriptor, client capnp.Client) error {
|
||||
dig:
|
||||
for client := client; ; {
|
||||
switch ct := client.(type) {
|
||||
case *importClient:
|
||||
if ct.conn != c {
|
||||
break dig
|
||||
}
|
||||
desc.SetReceiverHosted(uint32(ct.id))
|
||||
return nil
|
||||
case *fulfiller.EmbargoClient:
|
||||
client = ct.Client()
|
||||
if client == nil {
|
||||
break dig
|
||||
}
|
||||
case *refcount.Ref:
|
||||
client = ct.Client()
|
||||
case *embargoClient:
|
||||
ct.mu.RLock()
|
||||
ok := ct.isPassthrough()
|
||||
ct.mu.RUnlock()
|
||||
if !ok {
|
||||
break dig
|
||||
}
|
||||
client = ct.client
|
||||
case *queueClient:
|
||||
ct.mu.RLock()
|
||||
ok := ct.isPassthrough()
|
||||
ct.mu.RUnlock()
|
||||
if !ok {
|
||||
break dig
|
||||
}
|
||||
client = ct.client
|
||||
case *localAnswerClient:
|
||||
ct.a.mu.RLock()
|
||||
obj, err, done := ct.a.obj, ct.a.err, ct.a.done
|
||||
ct.a.mu.RUnlock()
|
||||
if !done {
|
||||
break dig
|
||||
}
|
||||
client = clientFromResolution(ct.transform, obj, err)
|
||||
case *capnp.PipelineClient:
|
||||
p := (*capnp.Pipeline)(ct)
|
||||
ans := p.Answer()
|
||||
transform := p.Transform()
|
||||
if capnp.IsFixedAnswer(ans) {
|
||||
s, err := ans.Struct()
|
||||
client = clientFromResolution(transform, s.ToPtr(), err)
|
||||
continue
|
||||
}
|
||||
switch ans := ans.(type) {
|
||||
case *fulfiller.Fulfiller:
|
||||
ap := ans.Peek()
|
||||
if ap == nil {
|
||||
break dig
|
||||
}
|
||||
s, err := ap.Struct()
|
||||
client = clientFromResolution(transform, s.ToPtr(), err)
|
||||
case *question:
|
||||
ans.mu.RLock()
|
||||
obj, err, state := ans.obj, ans.err, ans.state
|
||||
ans.mu.RUnlock()
|
||||
if state != questionInProgress {
|
||||
client = clientFromResolution(transform, obj, err)
|
||||
continue
|
||||
}
|
||||
if ans.conn != c {
|
||||
break dig
|
||||
}
|
||||
a, err := desc.NewReceiverAnswer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
a.SetQuestionId(uint32(ans.id))
|
||||
err = transformToPromisedAnswer(desc.Segment(), a, p.Transform())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
break dig
|
||||
}
|
||||
default:
|
||||
break dig
|
||||
}
|
||||
}
|
||||
|
||||
id := c.addExport(client)
|
||||
desc.SetSenderHosted(uint32(id))
|
||||
return nil
|
||||
}
|
||||
|
||||
// isSameClient reports whether c and d refer to the same capability.
|
||||
func isSameClient(c, d capnp.Client) bool {
|
||||
norm := func(client capnp.Client) capnp.Client {
|
||||
for {
|
||||
switch curr := client.(type) {
|
||||
case *fulfiller.EmbargoClient:
|
||||
client = curr.Client()
|
||||
if client == nil {
|
||||
return curr
|
||||
}
|
||||
case *refcount.Ref:
|
||||
client = curr.Client()
|
||||
case *embargoClient:
|
||||
curr.mu.RLock()
|
||||
ok := curr.isPassthrough()
|
||||
curr.mu.RUnlock()
|
||||
if !ok {
|
||||
return curr
|
||||
}
|
||||
client = curr.client
|
||||
case *queueClient:
|
||||
curr.mu.RLock()
|
||||
ok := curr.isPassthrough()
|
||||
curr.mu.RUnlock()
|
||||
if !ok {
|
||||
return curr
|
||||
}
|
||||
client = curr.client
|
||||
case *localAnswerClient:
|
||||
curr.a.mu.RLock()
|
||||
obj, err, done := curr.a.obj, curr.a.err, curr.a.done
|
||||
curr.a.mu.RUnlock()
|
||||
if !done {
|
||||
return curr
|
||||
}
|
||||
client = clientFromResolution(curr.transform, obj, err)
|
||||
case *capnp.PipelineClient:
|
||||
p := (*capnp.Pipeline)(curr)
|
||||
ans := p.Answer()
|
||||
if capnp.IsFixedAnswer(ans) {
|
||||
s, err := ans.Struct()
|
||||
client = clientFromResolution(p.Transform(), s.ToPtr(), err)
|
||||
continue
|
||||
}
|
||||
switch ans := ans.(type) {
|
||||
case *fulfiller.Fulfiller:
|
||||
ap := ans.Peek()
|
||||
if ap == nil {
|
||||
return curr
|
||||
}
|
||||
s, err := ap.Struct()
|
||||
client = clientFromResolution(p.Transform(), s.ToPtr(), err)
|
||||
case *question:
|
||||
ans.mu.RLock()
|
||||
obj, err, state := ans.obj, ans.err, ans.state
|
||||
ans.mu.RUnlock()
|
||||
if state != questionResolved {
|
||||
return curr
|
||||
}
|
||||
client = clientFromResolution(p.Transform(), obj, err)
|
||||
default:
|
||||
return curr
|
||||
}
|
||||
default:
|
||||
return curr
|
||||
}
|
||||
}
|
||||
}
|
||||
return norm(c) == norm(d)
|
||||
}
|
||||
|
||||
// isImport returns the underlying import if client represents an import
|
||||
// or nil otherwise.
|
||||
func isImport(client capnp.Client) *importClient {
|
||||
for {
|
||||
switch curr := client.(type) {
|
||||
case *importClient:
|
||||
return curr
|
||||
case *fulfiller.EmbargoClient:
|
||||
client = curr.Client()
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
case *refcount.Ref:
|
||||
client = curr.Client()
|
||||
case *embargoClient:
|
||||
curr.mu.RLock()
|
||||
ok := curr.isPassthrough()
|
||||
curr.mu.RUnlock()
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
client = curr.client
|
||||
case *queueClient:
|
||||
curr.mu.RLock()
|
||||
ok := curr.isPassthrough()
|
||||
curr.mu.RUnlock()
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
client = curr.client
|
||||
case *localAnswerClient:
|
||||
curr.a.mu.RLock()
|
||||
obj, err, done := curr.a.obj, curr.a.err, curr.a.done
|
||||
curr.a.mu.RUnlock()
|
||||
if !done {
|
||||
return nil
|
||||
}
|
||||
client = clientFromResolution(curr.transform, obj, err)
|
||||
case *capnp.PipelineClient:
|
||||
p := (*capnp.Pipeline)(curr)
|
||||
ans := p.Answer()
|
||||
if capnp.IsFixedAnswer(ans) {
|
||||
s, err := ans.Struct()
|
||||
client = clientFromResolution(p.Transform(), s.ToPtr(), err)
|
||||
continue
|
||||
}
|
||||
switch ans := ans.(type) {
|
||||
case *fulfiller.Fulfiller:
|
||||
ap := ans.Peek()
|
||||
if ap == nil {
|
||||
return nil
|
||||
}
|
||||
s, err := ap.Struct()
|
||||
client = clientFromResolution(p.Transform(), s.ToPtr(), err)
|
||||
case *question:
|
||||
ans.mu.RLock()
|
||||
obj, err, state := ans.obj, ans.err, ans.state
|
||||
ans.mu.RUnlock()
|
||||
if state != questionResolved {
|
||||
return nil
|
||||
}
|
||||
client = clientFromResolution(p.Transform(), obj, err)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
47
vendor/zombiezen.com/go/capnproto2/rpc/issue3_test.go
generated
vendored
Normal file
47
vendor/zombiezen.com/go/capnproto2/rpc/issue3_test.go
generated
vendored
Normal file
@@ -0,0 +1,47 @@
|
||||
package rpc_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
"zombiezen.com/go/capnproto2/rpc/internal/logtransport"
|
||||
"zombiezen.com/go/capnproto2/rpc/internal/pipetransport"
|
||||
"zombiezen.com/go/capnproto2/rpc/internal/testcapnp"
|
||||
)
|
||||
|
||||
func TestIssue3(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
p, q := pipetransport.New()
|
||||
if *logMessages {
|
||||
p = logtransport.New(nil, p)
|
||||
}
|
||||
log := testLogger{t}
|
||||
c := rpc.NewConn(p, rpc.ConnLog(log))
|
||||
echoSrv := testcapnp.Echoer_ServerToClient(new(SideEffectEchoer))
|
||||
d := rpc.NewConn(q, rpc.MainInterface(echoSrv.Client), rpc.ConnLog(log))
|
||||
defer d.Wait()
|
||||
defer c.Close()
|
||||
client := testcapnp.Echoer{Client: c.Bootstrap(ctx)}
|
||||
localCap := testcapnp.CallOrder_ServerToClient(new(CallOrder))
|
||||
echo := client.Echo(ctx, func(p testcapnp.Echoer_echo_Params) error {
|
||||
return p.SetCap(localCap)
|
||||
})
|
||||
|
||||
// This should not deadlock.
|
||||
_, err := echo.Struct()
|
||||
if err != nil {
|
||||
t.Error("Echo error:", err)
|
||||
}
|
||||
}
|
||||
|
||||
type SideEffectEchoer struct {
|
||||
CallOrder
|
||||
}
|
||||
|
||||
func (*SideEffectEchoer) Echo(call testcapnp.Echoer_echo) error {
|
||||
call.Params.Cap().GetCallSequence(call.Ctx, nil)
|
||||
call.Results.SetCap(call.Params.Cap())
|
||||
return nil
|
||||
}
|
49
vendor/zombiezen.com/go/capnproto2/rpc/log.go
generated
vendored
Normal file
49
vendor/zombiezen.com/go/capnproto2/rpc/log.go
generated
vendored
Normal file
@@ -0,0 +1,49 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// A Logger records diagnostic information and errors that are not
|
||||
// associated with a call. The arguments passed into a log call are
|
||||
// interpreted like fmt.Printf. They should not be held onto past the
|
||||
// call's return.
|
||||
type Logger interface {
|
||||
Infof(ctx context.Context, format string, args ...interface{})
|
||||
Errorf(ctx context.Context, format string, args ...interface{})
|
||||
}
|
||||
|
||||
type defaultLogger struct{}
|
||||
|
||||
func (defaultLogger) Infof(ctx context.Context, format string, args ...interface{}) {
|
||||
log.Printf("rpc: "+format, args...)
|
||||
}
|
||||
|
||||
func (defaultLogger) Errorf(ctx context.Context, format string, args ...interface{}) {
|
||||
log.Printf("rpc: "+format, args...)
|
||||
}
|
||||
|
||||
func (c *Conn) infof(format string, args ...interface{}) {
|
||||
if c.log == nil {
|
||||
return
|
||||
}
|
||||
c.log.Infof(c.bg, format, args...)
|
||||
}
|
||||
|
||||
func (c *Conn) errorf(format string, args ...interface{}) {
|
||||
if c.log == nil {
|
||||
return
|
||||
}
|
||||
c.log.Errorf(c.bg, format, args...)
|
||||
}
|
||||
|
||||
// ConnLog sets the connection's log to the given Logger, which may be
|
||||
// nil to disable logging. By default, logs are sent to the standard
|
||||
// log package.
|
||||
func ConnLog(log Logger) ConnOption {
|
||||
return ConnOption{func(c *connParams) {
|
||||
c.log = log
|
||||
}}
|
||||
}
|
60
vendor/zombiezen.com/go/capnproto2/rpc/promise_test.go
generated
vendored
Normal file
60
vendor/zombiezen.com/go/capnproto2/rpc/promise_test.go
generated
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
package rpc_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
"zombiezen.com/go/capnproto2/rpc/internal/logtransport"
|
||||
"zombiezen.com/go/capnproto2/rpc/internal/pipetransport"
|
||||
"zombiezen.com/go/capnproto2/rpc/internal/testcapnp"
|
||||
"zombiezen.com/go/capnproto2/server"
|
||||
)
|
||||
|
||||
func TestPromisedCapability(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
p, q := pipetransport.New()
|
||||
if *logMessages {
|
||||
p = logtransport.New(nil, p)
|
||||
}
|
||||
log := testLogger{t}
|
||||
c := rpc.NewConn(p, rpc.ConnLog(log))
|
||||
delay := make(chan struct{})
|
||||
echoSrv := testcapnp.Echoer_ServerToClient(&DelayEchoer{delay: delay})
|
||||
d := rpc.NewConn(q, rpc.MainInterface(echoSrv.Client), rpc.ConnLog(log))
|
||||
defer d.Wait()
|
||||
defer c.Close()
|
||||
client := testcapnp.Echoer{Client: c.Bootstrap(ctx)}
|
||||
|
||||
echo := client.Echo(ctx, func(p testcapnp.Echoer_echo_Params) error {
|
||||
return p.SetCap(testcapnp.CallOrder{Client: client.Client})
|
||||
})
|
||||
pipeline := echo.Cap()
|
||||
call0 := callseq(ctx, pipeline.Client, 0)
|
||||
call1 := callseq(ctx, pipeline.Client, 1)
|
||||
close(delay)
|
||||
|
||||
check := func(promise testcapnp.CallOrder_getCallSequence_Results_Promise, n uint32) {
|
||||
r, err := promise.Struct()
|
||||
if err != nil {
|
||||
t.Errorf("call%d error: %v", n, err)
|
||||
}
|
||||
if r.N() != n {
|
||||
t.Errorf("call%d = %d; want %d", n, r.N(), n)
|
||||
}
|
||||
}
|
||||
check(call0, 0)
|
||||
check(call1, 1)
|
||||
}
|
||||
|
||||
type DelayEchoer struct {
|
||||
Echoer
|
||||
delay chan struct{}
|
||||
}
|
||||
|
||||
func (de *DelayEchoer) Echo(call testcapnp.Echoer_echo) error {
|
||||
server.Ack(call.Options)
|
||||
<-de.delay
|
||||
return de.Echoer.Echo(call)
|
||||
}
|
442
vendor/zombiezen.com/go/capnproto2/rpc/question.go
generated
vendored
Normal file
442
vendor/zombiezen.com/go/capnproto2/rpc/question.go
generated
vendored
Normal file
@@ -0,0 +1,442 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"zombiezen.com/go/capnproto2"
|
||||
"zombiezen.com/go/capnproto2/internal/fulfiller"
|
||||
"zombiezen.com/go/capnproto2/internal/queue"
|
||||
rpccapnp "zombiezen.com/go/capnproto2/std/capnp/rpc"
|
||||
)
|
||||
|
||||
// newQuestion creates a new question with an unassigned ID.
|
||||
func (c *Conn) newQuestion(ctx context.Context, method *capnp.Method) *question {
|
||||
id := questionID(c.questionID.next())
|
||||
q := &question{
|
||||
ctx: ctx,
|
||||
conn: c,
|
||||
method: method,
|
||||
resolved: make(chan struct{}),
|
||||
id: id,
|
||||
}
|
||||
// TODO(light): populate paramCaps
|
||||
if int(id) == len(c.questions) {
|
||||
c.questions = append(c.questions, q)
|
||||
} else {
|
||||
c.questions[id] = q
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
func (c *Conn) findQuestion(id questionID) *question {
|
||||
if int(id) >= len(c.questions) {
|
||||
return nil
|
||||
}
|
||||
return c.questions[id]
|
||||
}
|
||||
|
||||
func (c *Conn) popQuestion(id questionID) *question {
|
||||
q := c.findQuestion(id)
|
||||
if q == nil {
|
||||
return nil
|
||||
}
|
||||
c.questions[id] = nil
|
||||
c.questionID.remove(uint32(id))
|
||||
return q
|
||||
}
|
||||
|
||||
type question struct {
|
||||
id questionID
|
||||
ctx context.Context
|
||||
conn *Conn
|
||||
method *capnp.Method // nil if this is bootstrap
|
||||
paramCaps []exportID
|
||||
resolved chan struct{}
|
||||
|
||||
// Protected by conn.mu
|
||||
derived [][]capnp.PipelineOp
|
||||
|
||||
// Fields below are protected by mu.
|
||||
mu sync.RWMutex
|
||||
obj capnp.Ptr
|
||||
err error
|
||||
state questionState
|
||||
}
|
||||
|
||||
type questionState uint8
|
||||
|
||||
// Question states
|
||||
const (
|
||||
questionInProgress questionState = iota
|
||||
questionResolved
|
||||
questionCanceled
|
||||
)
|
||||
|
||||
// start signals that the question has been sent.
|
||||
func (q *question) start() {
|
||||
go func() {
|
||||
select {
|
||||
case <-q.resolved:
|
||||
// Resolved naturally, nothing to do.
|
||||
case <-q.conn.bg.Done():
|
||||
case <-q.ctx.Done():
|
||||
select {
|
||||
case <-q.resolved:
|
||||
case <-q.conn.bg.Done():
|
||||
case <-q.conn.mu:
|
||||
if err := q.conn.startWork(); err != nil {
|
||||
// teardown calls cancel.
|
||||
q.conn.mu.Unlock()
|
||||
return
|
||||
}
|
||||
if q.cancel(q.ctx.Err()) {
|
||||
q.conn.sendMessage(newFinishMessage(nil, q.id, true /* release */))
|
||||
}
|
||||
q.conn.workers.Done()
|
||||
q.conn.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// fulfill is called to resolve a question successfully.
|
||||
// The caller must be holding onto q.conn.mu.
|
||||
func (q *question) fulfill(obj capnp.Ptr) {
|
||||
var ctab []capnp.Client
|
||||
if obj.IsValid() {
|
||||
ctab = obj.Segment().Message().CapTable
|
||||
}
|
||||
visited := make([]bool, len(ctab))
|
||||
for _, d := range q.derived {
|
||||
tgt, err := capnp.TransformPtr(obj, d)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
in := tgt.Interface()
|
||||
if !in.IsValid() {
|
||||
continue
|
||||
}
|
||||
if ic := isImport(in.Client()); ic != nil && ic.conn == q.conn {
|
||||
// Imported from remote vat. Don't need to disembargo.
|
||||
continue
|
||||
}
|
||||
cn := in.Capability()
|
||||
if visited[cn] {
|
||||
continue
|
||||
}
|
||||
visited[cn] = true
|
||||
id, e := q.conn.newEmbargo()
|
||||
ctab[cn] = newEmbargoClient(ctab[cn], e, q.conn.bg.Done())
|
||||
m := newDisembargoMessage(nil, rpccapnp.Disembargo_context_Which_senderLoopback, id)
|
||||
dis, _ := m.Disembargo()
|
||||
mt, _ := dis.NewTarget()
|
||||
pa, _ := mt.NewPromisedAnswer()
|
||||
pa.SetQuestionId(uint32(q.id))
|
||||
transformToPromisedAnswer(m.Segment(), pa, d)
|
||||
mt.SetPromisedAnswer(pa)
|
||||
|
||||
select {
|
||||
case q.conn.out <- m:
|
||||
case <-q.conn.bg.Done():
|
||||
// TODO(soon): perhaps just drop all embargoes in this case?
|
||||
}
|
||||
}
|
||||
|
||||
q.mu.Lock()
|
||||
if q.state != questionInProgress {
|
||||
panic("question.fulfill called more than once")
|
||||
}
|
||||
q.obj, q.state = obj, questionResolved
|
||||
close(q.resolved)
|
||||
q.mu.Unlock()
|
||||
}
|
||||
|
||||
// reject is called to resolve a question with failure.
|
||||
// The caller must be holding onto q.conn.mu.
|
||||
func (q *question) reject(err error) {
|
||||
if err == nil {
|
||||
panic("question.reject called with nil")
|
||||
}
|
||||
q.mu.Lock()
|
||||
if q.state != questionInProgress {
|
||||
panic("question.reject called more than once")
|
||||
}
|
||||
q.err = err
|
||||
q.state = questionResolved
|
||||
close(q.resolved)
|
||||
q.mu.Unlock()
|
||||
}
|
||||
|
||||
// cancel is called to resolve a question with cancellation.
|
||||
// The caller must be holding onto q.conn.mu.
|
||||
func (q *question) cancel(err error) bool {
|
||||
if err == nil {
|
||||
panic("question.cancel called with nil")
|
||||
}
|
||||
q.mu.Lock()
|
||||
canceled := q.state == questionInProgress
|
||||
if canceled {
|
||||
q.err = err
|
||||
q.state = questionCanceled
|
||||
close(q.resolved)
|
||||
}
|
||||
q.mu.Unlock()
|
||||
return canceled
|
||||
}
|
||||
|
||||
// addPromise records a returned capability as being used for a call.
|
||||
// This is needed for determining embargoes upon resolution. The
|
||||
// caller must be holding onto q.conn.mu.
|
||||
func (q *question) addPromise(transform []capnp.PipelineOp) {
|
||||
for _, d := range q.derived {
|
||||
if transformsEqual(transform, d) {
|
||||
return
|
||||
}
|
||||
}
|
||||
q.derived = append(q.derived, transform)
|
||||
}
|
||||
|
||||
func transformsEqual(t, u []capnp.PipelineOp) bool {
|
||||
if len(t) != len(u) {
|
||||
return false
|
||||
}
|
||||
for i := range t {
|
||||
if t[i].Field != u[i].Field {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (q *question) Struct() (capnp.Struct, error) {
|
||||
select {
|
||||
case <-q.resolved:
|
||||
case <-q.conn.bg.Done():
|
||||
return capnp.Struct{}, ErrConnClosed
|
||||
}
|
||||
q.mu.RLock()
|
||||
s, err := q.obj.Struct(), q.err
|
||||
q.mu.RUnlock()
|
||||
return s, err
|
||||
}
|
||||
|
||||
func (q *question) PipelineCall(transform []capnp.PipelineOp, ccall *capnp.Call) capnp.Answer {
|
||||
select {
|
||||
case <-q.conn.mu:
|
||||
if err := q.conn.startWork(); err != nil {
|
||||
q.conn.mu.Unlock()
|
||||
return capnp.ErrorAnswer(err)
|
||||
}
|
||||
case <-ccall.Ctx.Done():
|
||||
return capnp.ErrorAnswer(ccall.Ctx.Err())
|
||||
}
|
||||
ans := q.lockedPipelineCall(transform, ccall)
|
||||
q.conn.workers.Done()
|
||||
q.conn.mu.Unlock()
|
||||
return ans
|
||||
}
|
||||
|
||||
// lockedPipelineCall is equivalent to PipelineCall but assumes that the
|
||||
// caller is already holding onto q.conn.mu.
|
||||
func (q *question) lockedPipelineCall(transform []capnp.PipelineOp, ccall *capnp.Call) capnp.Answer {
|
||||
if q.conn.findQuestion(q.id) != q {
|
||||
// Question has been finished. The call should happen as if it is
|
||||
// back in application code.
|
||||
q.mu.RLock()
|
||||
obj, err, state := q.obj, q.err, q.state
|
||||
q.mu.RUnlock()
|
||||
if state == questionInProgress {
|
||||
panic("question popped but not done")
|
||||
}
|
||||
client := clientFromResolution(transform, obj, err)
|
||||
return q.conn.lockedCall(client, ccall)
|
||||
}
|
||||
|
||||
pipeq := q.conn.newQuestion(ccall.Ctx, &ccall.Method)
|
||||
msg := newMessage(nil)
|
||||
msgCall, _ := msg.NewCall()
|
||||
msgCall.SetQuestionId(uint32(pipeq.id))
|
||||
msgCall.SetInterfaceId(ccall.Method.InterfaceID)
|
||||
msgCall.SetMethodId(ccall.Method.MethodID)
|
||||
target, _ := msgCall.NewTarget()
|
||||
a, _ := target.NewPromisedAnswer()
|
||||
a.SetQuestionId(uint32(q.id))
|
||||
err := transformToPromisedAnswer(a.Segment(), a, transform)
|
||||
if err != nil {
|
||||
q.conn.popQuestion(pipeq.id)
|
||||
return capnp.ErrorAnswer(err)
|
||||
}
|
||||
payload, _ := msgCall.NewParams()
|
||||
if err := q.conn.fillParams(payload, ccall); err != nil {
|
||||
q.conn.popQuestion(q.id)
|
||||
return capnp.ErrorAnswer(err)
|
||||
}
|
||||
|
||||
select {
|
||||
case q.conn.out <- msg:
|
||||
case <-ccall.Ctx.Done():
|
||||
q.conn.popQuestion(pipeq.id)
|
||||
return capnp.ErrorAnswer(ccall.Ctx.Err())
|
||||
case <-q.conn.bg.Done():
|
||||
q.conn.popQuestion(pipeq.id)
|
||||
return capnp.ErrorAnswer(ErrConnClosed)
|
||||
}
|
||||
q.addPromise(transform)
|
||||
pipeq.start()
|
||||
return pipeq
|
||||
}
|
||||
|
||||
func (q *question) PipelineClose(transform []capnp.PipelineOp) error {
|
||||
<-q.resolved
|
||||
q.mu.RLock()
|
||||
obj, err := q.obj, q.err
|
||||
q.mu.RUnlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
x, err := capnp.TransformPtr(obj, transform)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c := x.Interface().Client()
|
||||
if c == nil {
|
||||
return capnp.ErrNullClient
|
||||
}
|
||||
return c.Close()
|
||||
}
|
||||
|
||||
// embargoClient is a client that waits until an embargo signal is
|
||||
// received to deliver calls.
|
||||
type embargoClient struct {
|
||||
cancel <-chan struct{}
|
||||
client capnp.Client
|
||||
embargo embargo
|
||||
|
||||
mu sync.RWMutex
|
||||
q queue.Queue
|
||||
calls ecallList
|
||||
}
|
||||
|
||||
func newEmbargoClient(client capnp.Client, e embargo, cancel <-chan struct{}) *embargoClient {
|
||||
ec := &embargoClient{
|
||||
client: client,
|
||||
embargo: e,
|
||||
cancel: cancel,
|
||||
calls: make(ecallList, callQueueSize),
|
||||
}
|
||||
ec.q.Init(ec.calls, 0)
|
||||
go ec.flushQueue()
|
||||
return ec
|
||||
}
|
||||
|
||||
func (ec *embargoClient) push(cl *capnp.Call) capnp.Answer {
|
||||
f := new(fulfiller.Fulfiller)
|
||||
cl, err := cl.Copy(nil)
|
||||
if err != nil {
|
||||
return capnp.ErrorAnswer(err)
|
||||
}
|
||||
i := ec.q.Push()
|
||||
if i == -1 {
|
||||
return capnp.ErrorAnswer(errQueueFull)
|
||||
}
|
||||
ec.calls[i] = ecall{cl, f}
|
||||
return f
|
||||
}
|
||||
|
||||
func (ec *embargoClient) Call(cl *capnp.Call) capnp.Answer {
|
||||
// Fast path: queue is flushed.
|
||||
ec.mu.RLock()
|
||||
ok := ec.isPassthrough()
|
||||
ec.mu.RUnlock()
|
||||
if ok {
|
||||
return ec.client.Call(cl)
|
||||
}
|
||||
|
||||
ec.mu.Lock()
|
||||
if ec.isPassthrough() {
|
||||
ec.mu.Unlock()
|
||||
return ec.client.Call(cl)
|
||||
}
|
||||
ans := ec.push(cl)
|
||||
ec.mu.Unlock()
|
||||
return ans
|
||||
}
|
||||
|
||||
func (ec *embargoClient) tryQueue(cl *capnp.Call) capnp.Answer {
|
||||
ec.mu.Lock()
|
||||
if ec.isPassthrough() {
|
||||
ec.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
ans := ec.push(cl)
|
||||
ec.mu.Unlock()
|
||||
return ans
|
||||
}
|
||||
|
||||
func (ec *embargoClient) isPassthrough() bool {
|
||||
select {
|
||||
case <-ec.embargo:
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return ec.q.Len() == 0
|
||||
}
|
||||
|
||||
func (ec *embargoClient) Close() error {
|
||||
ec.mu.Lock()
|
||||
for ; ec.q.Len() > 0; ec.q.Pop() {
|
||||
c := ec.calls[ec.q.Front()]
|
||||
c.f.Reject(errQueueCallCancel)
|
||||
}
|
||||
ec.mu.Unlock()
|
||||
return ec.client.Close()
|
||||
}
|
||||
|
||||
// flushQueue is run in its own goroutine.
|
||||
func (ec *embargoClient) flushQueue() {
|
||||
select {
|
||||
case <-ec.embargo:
|
||||
case <-ec.cancel:
|
||||
ec.mu.Lock()
|
||||
for ec.q.Len() > 0 {
|
||||
ec.q.Pop()
|
||||
}
|
||||
ec.mu.Unlock()
|
||||
return
|
||||
}
|
||||
var c ecall
|
||||
ec.mu.RLock()
|
||||
if i := ec.q.Front(); i != -1 {
|
||||
c = ec.calls[i]
|
||||
}
|
||||
ec.mu.RUnlock()
|
||||
for c.call != nil {
|
||||
ans := ec.client.Call(c.call)
|
||||
go joinFulfiller(c.f, ans)
|
||||
|
||||
ec.mu.Lock()
|
||||
ec.q.Pop()
|
||||
if i := ec.q.Front(); i != -1 {
|
||||
c = ec.calls[i]
|
||||
} else {
|
||||
c = ecall{}
|
||||
}
|
||||
ec.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
type ecall struct {
|
||||
call *capnp.Call
|
||||
f *fulfiller.Fulfiller
|
||||
}
|
||||
|
||||
type ecallList []ecall
|
||||
|
||||
func (el ecallList) Len() int {
|
||||
return len(el)
|
||||
}
|
||||
|
||||
func (el ecallList) Clear(i int) {
|
||||
el[i] = ecall{}
|
||||
}
|
145
vendor/zombiezen.com/go/capnproto2/rpc/release_test.go
generated
vendored
Normal file
145
vendor/zombiezen.com/go/capnproto2/rpc/release_test.go
generated
vendored
Normal file
@@ -0,0 +1,145 @@
|
||||
package rpc_test
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"zombiezen.com/go/capnproto2"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
"zombiezen.com/go/capnproto2/rpc/internal/logtransport"
|
||||
"zombiezen.com/go/capnproto2/rpc/internal/pipetransport"
|
||||
"zombiezen.com/go/capnproto2/rpc/internal/testcapnp"
|
||||
"zombiezen.com/go/capnproto2/server"
|
||||
)
|
||||
|
||||
func TestRelease(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
p, q := pipetransport.New()
|
||||
if *logMessages {
|
||||
p = logtransport.New(nil, p)
|
||||
}
|
||||
log := testLogger{t}
|
||||
c := rpc.NewConn(p, rpc.ConnLog(log))
|
||||
hf := new(HandleFactory)
|
||||
d := rpc.NewConn(q, rpc.MainInterface(testcapnp.HandleFactory_ServerToClient(hf).Client), rpc.ConnLog(log))
|
||||
defer d.Wait()
|
||||
defer c.Close()
|
||||
client := testcapnp.HandleFactory{Client: c.Bootstrap(ctx)}
|
||||
r, err := client.NewHandle(ctx, nil).Struct()
|
||||
if err != nil {
|
||||
t.Fatal("NewHandle:", err)
|
||||
}
|
||||
handle := r.Handle()
|
||||
if n := hf.numHandles(); n != 1 {
|
||||
t.Fatalf("numHandles = %d; want 1", n)
|
||||
}
|
||||
|
||||
if err := handle.Client.Close(); err != nil {
|
||||
t.Error("handle.Client.Close():", err)
|
||||
}
|
||||
flushConn(ctx, c)
|
||||
|
||||
if n := hf.numHandles(); n != 0 {
|
||||
t.Errorf("numHandles = %d; want 0", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReleaseAlias(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
p, q := pipetransport.New()
|
||||
if *logMessages {
|
||||
p = logtransport.New(nil, p)
|
||||
}
|
||||
log := testLogger{t}
|
||||
c := rpc.NewConn(p, rpc.ConnLog(log))
|
||||
hf := singletonHandleFactory()
|
||||
d := rpc.NewConn(q, rpc.MainInterface(testcapnp.HandleFactory_ServerToClient(hf).Client), rpc.ConnLog(log))
|
||||
defer d.Wait()
|
||||
defer c.Close()
|
||||
client := testcapnp.HandleFactory{Client: c.Bootstrap(ctx)}
|
||||
r1, err := client.NewHandle(ctx, nil).Struct()
|
||||
if err != nil {
|
||||
t.Fatal("NewHandle #1:", err)
|
||||
}
|
||||
handle1 := r1.Handle()
|
||||
r2, err := client.NewHandle(ctx, nil).Struct()
|
||||
if err != nil {
|
||||
t.Fatal("NewHandle #2:", err)
|
||||
}
|
||||
handle2 := r2.Handle()
|
||||
if n := hf.numHandles(); n != 1 {
|
||||
t.Fatalf("after creation, numHandles = %d; want 1", n)
|
||||
}
|
||||
|
||||
if err := handle1.Client.Close(); err != nil {
|
||||
t.Error("handle1.Client.Close():", err)
|
||||
}
|
||||
flushConn(ctx, c)
|
||||
if n := hf.numHandles(); n != 1 {
|
||||
t.Errorf("after handle1.Client.Close(), numHandles = %d; want 1", n)
|
||||
}
|
||||
if err := handle2.Client.Close(); err != nil {
|
||||
t.Error("handle2.Client.Close():", err)
|
||||
}
|
||||
flushConn(ctx, c)
|
||||
if n := hf.numHandles(); n != 0 {
|
||||
t.Errorf("after handle1.Close() and handle2.Close(), numHandles = %d; want 0", n)
|
||||
}
|
||||
}
|
||||
|
||||
func flushConn(ctx context.Context, c *rpc.Conn) {
|
||||
// discard result
|
||||
c.Bootstrap(ctx).Call(&capnp.Call{
|
||||
Ctx: ctx,
|
||||
Method: capnp.Method{InterfaceID: 0xdeadbeef, MethodID: 42},
|
||||
}).Struct()
|
||||
}
|
||||
|
||||
type Handle struct {
|
||||
f *HandleFactory
|
||||
}
|
||||
|
||||
func (h Handle) Close() error {
|
||||
h.f.mu.Lock()
|
||||
h.f.n--
|
||||
h.f.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
type HandleFactory struct {
|
||||
n int
|
||||
mu sync.Mutex
|
||||
singleton testcapnp.Handle
|
||||
}
|
||||
|
||||
func singletonHandleFactory() *HandleFactory {
|
||||
hf := new(HandleFactory)
|
||||
hf.singleton = testcapnp.Handle_ServerToClient(&Handle{f: hf})
|
||||
return hf
|
||||
}
|
||||
|
||||
func (hf *HandleFactory) NewHandle(call testcapnp.HandleFactory_newHandle) error {
|
||||
server.Ack(call.Options)
|
||||
if hf.singleton.Client == nil {
|
||||
hf.mu.Lock()
|
||||
hf.n++
|
||||
hf.mu.Unlock()
|
||||
call.Results.SetHandle(testcapnp.Handle_ServerToClient(&Handle{f: hf}))
|
||||
} else {
|
||||
hf.mu.Lock()
|
||||
hf.n = 1
|
||||
hf.mu.Unlock()
|
||||
call.Results.SetHandle(hf.singleton)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hf *HandleFactory) numHandles() int {
|
||||
hf.mu.Lock()
|
||||
n := hf.n
|
||||
hf.mu.Unlock()
|
||||
return n
|
||||
}
|
913
vendor/zombiezen.com/go/capnproto2/rpc/rpc.go
generated
vendored
Normal file
913
vendor/zombiezen.com/go/capnproto2/rpc/rpc.go
generated
vendored
Normal file
@@ -0,0 +1,913 @@
|
||||
// Package rpc implements the Cap'n Proto RPC protocol.
|
||||
package rpc // import "zombiezen.com/go/capnproto2/rpc"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"zombiezen.com/go/capnproto2"
|
||||
"zombiezen.com/go/capnproto2/rpc/internal/refcount"
|
||||
rpccapnp "zombiezen.com/go/capnproto2/std/capnp/rpc"
|
||||
)
|
||||
|
||||
// A Conn is a connection to another Cap'n Proto vat.
|
||||
// It is safe to use from multiple goroutines.
|
||||
type Conn struct {
|
||||
transport Transport
|
||||
log Logger
|
||||
mainFunc func(context.Context) (capnp.Client, error)
|
||||
mainCloser io.Closer
|
||||
death chan struct{} // closed after state is connDead
|
||||
|
||||
out chan rpccapnp.Message
|
||||
|
||||
bg context.Context
|
||||
bgCancel context.CancelFunc
|
||||
workers sync.WaitGroup
|
||||
|
||||
// Mutable state protected by stateMu
|
||||
// If you need to acquire both mu and stateMu, acquire mu first.
|
||||
stateMu sync.RWMutex
|
||||
state connState
|
||||
closeErr error
|
||||
|
||||
// Mutable state protected by mu
|
||||
mu chanMutex
|
||||
questions []*question
|
||||
questionID idgen
|
||||
exports []*export
|
||||
exportID idgen
|
||||
embargoes []chan<- struct{}
|
||||
embargoID idgen
|
||||
answers map[answerID]*answer
|
||||
imports map[importID]*impent
|
||||
}
|
||||
|
||||
type connParams struct {
|
||||
log Logger
|
||||
mainFunc func(context.Context) (capnp.Client, error)
|
||||
mainCloser io.Closer
|
||||
sendBufferSize int
|
||||
}
|
||||
|
||||
// A ConnOption is an option for opening a connection.
|
||||
type ConnOption struct {
|
||||
f func(*connParams)
|
||||
}
|
||||
|
||||
// MainInterface specifies that the connection should use client when
|
||||
// receiving bootstrap messages. By default, all bootstrap messages will
|
||||
// fail. The client will be closed when the connection is closed.
|
||||
func MainInterface(client capnp.Client) ConnOption {
|
||||
rc, ref1 := refcount.New(client)
|
||||
ref2 := rc.Ref()
|
||||
return ConnOption{func(c *connParams) {
|
||||
c.mainFunc = func(ctx context.Context) (capnp.Client, error) {
|
||||
return ref1, nil
|
||||
}
|
||||
c.mainCloser = ref2
|
||||
}}
|
||||
}
|
||||
|
||||
// BootstrapFunc specifies the function to call to create a capability
|
||||
// for handling bootstrap messages. This function should not make any
|
||||
// RPCs or block.
|
||||
func BootstrapFunc(f func(context.Context) (capnp.Client, error)) ConnOption {
|
||||
return ConnOption{func(c *connParams) {
|
||||
c.mainFunc = f
|
||||
}}
|
||||
}
|
||||
|
||||
// SendBufferSize sets the number of outgoing messages to buffer on the
|
||||
// connection. This is in addition to whatever buffering the connection's
|
||||
// transport performs.
|
||||
func SendBufferSize(numMsgs int) ConnOption {
|
||||
return ConnOption{func(c *connParams) {
|
||||
c.sendBufferSize = numMsgs
|
||||
}}
|
||||
}
|
||||
|
||||
// NewConn creates a new connection that communicates on c.
|
||||
// Closing the connection will cause c to be closed.
|
||||
func NewConn(t Transport, options ...ConnOption) *Conn {
|
||||
p := &connParams{
|
||||
log: defaultLogger{},
|
||||
sendBufferSize: 4,
|
||||
}
|
||||
for _, o := range options {
|
||||
o.f(p)
|
||||
}
|
||||
|
||||
conn := &Conn{
|
||||
transport: t,
|
||||
out: make(chan rpccapnp.Message, p.sendBufferSize),
|
||||
mainFunc: p.mainFunc,
|
||||
mainCloser: p.mainCloser,
|
||||
log: p.log,
|
||||
death: make(chan struct{}),
|
||||
mu: newChanMutex(),
|
||||
}
|
||||
conn.bg, conn.bgCancel = context.WithCancel(context.Background())
|
||||
conn.workers.Add(2)
|
||||
go conn.dispatchRecv()
|
||||
go conn.dispatchSend()
|
||||
return conn
|
||||
}
|
||||
|
||||
// Wait waits until the connection is closed or aborted by the remote vat.
|
||||
// Wait will always return an error, usually ErrConnClosed or of type Abort.
|
||||
func (c *Conn) Wait() error {
|
||||
<-c.Done()
|
||||
return c.Err()
|
||||
}
|
||||
|
||||
// Done is a channel that is closed once the connection is fully shut down.
|
||||
func (c *Conn) Done() <-chan struct{} {
|
||||
return c.death
|
||||
}
|
||||
|
||||
// Err returns the error that caused the connection to close.
|
||||
// Err returns nil before Done is closed.
|
||||
func (c *Conn) Err() error {
|
||||
c.stateMu.RLock()
|
||||
var err error
|
||||
if c.state != connDead {
|
||||
err = c.closeErr
|
||||
}
|
||||
c.stateMu.RUnlock()
|
||||
return err
|
||||
}
|
||||
|
||||
// Close closes the connection and the underlying transport.
|
||||
func (c *Conn) Close() error {
|
||||
c.stateMu.Lock()
|
||||
alive := c.state == connAlive
|
||||
if alive {
|
||||
c.bgCancel()
|
||||
c.closeErr = ErrConnClosed
|
||||
c.state = connDying
|
||||
}
|
||||
c.stateMu.Unlock()
|
||||
if !alive {
|
||||
return ErrConnClosed
|
||||
}
|
||||
c.teardown(newAbortMessage(nil, errShutdown))
|
||||
c.stateMu.RLock()
|
||||
err := c.closeErr
|
||||
c.stateMu.RUnlock()
|
||||
if err != ErrConnClosed {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// shutdown cancels the background context and sets closeErr to e.
|
||||
// No abort message will be sent on the transport. After shutdown
|
||||
// returns, the Conn will be in the dying or dead state. Calling
|
||||
// shutdown on a dying or dead Conn is a no-op.
|
||||
func (c *Conn) shutdown(e error) {
|
||||
c.stateMu.Lock()
|
||||
if c.state == connAlive {
|
||||
c.bgCancel()
|
||||
c.closeErr = e
|
||||
c.state = connDying
|
||||
go c.teardown(rpccapnp.Message{})
|
||||
}
|
||||
c.stateMu.Unlock()
|
||||
}
|
||||
|
||||
// abort cancels the background context, sets closeErr to e, and queues
|
||||
// an abort message to be sent on the transport before the Conn goes
|
||||
// into the dead state. After abort returns, the Conn will be in the
|
||||
// dying or dead state. Calling abort on a dying or dead Conn is a
|
||||
// no-op.
|
||||
func (c *Conn) abort(e error) {
|
||||
c.stateMu.Lock()
|
||||
if c.state == connAlive {
|
||||
c.bgCancel()
|
||||
c.closeErr = e
|
||||
c.state = connDying
|
||||
go c.teardown(newAbortMessage(nil, e))
|
||||
}
|
||||
c.stateMu.Unlock()
|
||||
}
|
||||
|
||||
// startWork adds a new worker if c is not dying or dead.
|
||||
// Otherwise, it returns the close error.
|
||||
// The caller is responsible for calling c.workers.Done().
|
||||
// The caller must not be holding onto c.stateMu.
|
||||
func (c *Conn) startWork() error {
|
||||
var err error
|
||||
c.stateMu.RLock()
|
||||
if c.state == connAlive {
|
||||
c.workers.Add(1)
|
||||
} else {
|
||||
err = c.closeErr
|
||||
}
|
||||
c.stateMu.RUnlock()
|
||||
return err
|
||||
}
|
||||
|
||||
// teardown moves the connection from the dying to the dead state.
|
||||
func (c *Conn) teardown(abort rpccapnp.Message) {
|
||||
c.workers.Wait()
|
||||
|
||||
c.mu.Lock()
|
||||
for _, q := range c.questions {
|
||||
if q != nil {
|
||||
q.cancel(ErrConnClosed)
|
||||
}
|
||||
}
|
||||
c.questions = nil
|
||||
exps := c.exports
|
||||
c.exports = nil
|
||||
c.embargoes = nil
|
||||
for _, a := range c.answers {
|
||||
a.cancel()
|
||||
}
|
||||
c.answers = nil
|
||||
c.imports = nil
|
||||
c.mainFunc = nil
|
||||
c.mu.Unlock()
|
||||
|
||||
if c.mainCloser != nil {
|
||||
if err := c.mainCloser.Close(); err != nil {
|
||||
c.errorf("closing main interface: %v", err)
|
||||
}
|
||||
c.mainCloser = nil
|
||||
}
|
||||
// Closing an export may try to lock the Conn, so run it outside
|
||||
// critical section.
|
||||
for id, e := range exps {
|
||||
if e == nil {
|
||||
continue
|
||||
}
|
||||
if err := e.client.Close(); err != nil {
|
||||
c.errorf("export %v close: %v", id, err)
|
||||
}
|
||||
}
|
||||
exps = nil
|
||||
|
||||
var werr error
|
||||
if abort.IsValid() {
|
||||
werr = c.transport.SendMessage(context.Background(), abort)
|
||||
}
|
||||
cerr := c.transport.Close()
|
||||
|
||||
c.stateMu.Lock()
|
||||
if c.closeErr == ErrConnClosed {
|
||||
if cerr != nil {
|
||||
c.closeErr = cerr
|
||||
} else if werr != nil {
|
||||
c.closeErr = werr
|
||||
}
|
||||
}
|
||||
c.state = connDead
|
||||
close(c.death)
|
||||
c.stateMu.Unlock()
|
||||
}
|
||||
|
||||
// Bootstrap returns the receiver's main interface.
|
||||
func (c *Conn) Bootstrap(ctx context.Context) capnp.Client {
|
||||
// TODO(light): Create a client that returns immediately.
|
||||
select {
|
||||
case <-c.mu:
|
||||
// Locked.
|
||||
defer c.mu.Unlock()
|
||||
if err := c.startWork(); err != nil {
|
||||
return capnp.ErrorClient(err)
|
||||
}
|
||||
defer c.workers.Done()
|
||||
case <-ctx.Done():
|
||||
return capnp.ErrorClient(ctx.Err())
|
||||
case <-c.bg.Done():
|
||||
return capnp.ErrorClient(ErrConnClosed)
|
||||
}
|
||||
|
||||
q := c.newQuestion(ctx, nil /* method */)
|
||||
msg := newMessage(nil)
|
||||
boot, _ := msg.NewBootstrap()
|
||||
boot.SetQuestionId(uint32(q.id))
|
||||
// The mutex must be held while sending so that call order is preserved.
|
||||
// Worst case, this blocks until a message is sent on the transport.
|
||||
// Common case, this just adds to the channel queue.
|
||||
select {
|
||||
case c.out <- msg:
|
||||
q.start()
|
||||
return capnp.NewPipeline(q).Client()
|
||||
case <-ctx.Done():
|
||||
c.popQuestion(q.id)
|
||||
return capnp.ErrorClient(ctx.Err())
|
||||
case <-c.bg.Done():
|
||||
c.popQuestion(q.id)
|
||||
return capnp.ErrorClient(ErrConnClosed)
|
||||
}
|
||||
}
|
||||
|
||||
// handleMessage is run from the receive goroutine to process a single
|
||||
// message. m cannot be held onto past the return of handleMessage, and
|
||||
// c.mu is not held at the start of handleMessage.
|
||||
func (c *Conn) handleMessage(m rpccapnp.Message) {
|
||||
switch m.Which() {
|
||||
case rpccapnp.Message_Which_unimplemented:
|
||||
// no-op for now to avoid feedback loop
|
||||
case rpccapnp.Message_Which_abort:
|
||||
a, err := copyAbort(m)
|
||||
if err != nil {
|
||||
c.errorf("decode abort: %v", err)
|
||||
// Keep going, since we're trying to abort anyway.
|
||||
}
|
||||
c.infof("abort: %v", a)
|
||||
c.shutdown(a)
|
||||
case rpccapnp.Message_Which_return:
|
||||
m = copyRPCMessage(m)
|
||||
c.mu.Lock()
|
||||
err := c.handleReturnMessage(m)
|
||||
c.mu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
c.errorf("handle return: %v", err)
|
||||
}
|
||||
case rpccapnp.Message_Which_finish:
|
||||
mfin, err := m.Finish()
|
||||
if err != nil {
|
||||
c.errorf("decode finish: %v", err)
|
||||
return
|
||||
}
|
||||
id := answerID(mfin.QuestionId())
|
||||
|
||||
c.mu.Lock()
|
||||
a := c.popAnswer(id)
|
||||
if a == nil {
|
||||
c.mu.Unlock()
|
||||
c.errorf("finish called for unknown answer %d", id)
|
||||
return
|
||||
}
|
||||
a.cancel()
|
||||
if mfin.ReleaseResultCaps() {
|
||||
for _, id := range a.resultCaps {
|
||||
c.releaseExport(id, 1)
|
||||
}
|
||||
}
|
||||
c.mu.Unlock()
|
||||
case rpccapnp.Message_Which_bootstrap:
|
||||
boot, err := m.Bootstrap()
|
||||
if err != nil {
|
||||
c.errorf("decode bootstrap: %v", err)
|
||||
return
|
||||
}
|
||||
id := answerID(boot.QuestionId())
|
||||
|
||||
c.mu.Lock()
|
||||
err = c.handleBootstrapMessage(id)
|
||||
c.mu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
c.errorf("handle bootstrap: %v", err)
|
||||
}
|
||||
case rpccapnp.Message_Which_call:
|
||||
m = copyRPCMessage(m)
|
||||
c.mu.Lock()
|
||||
err := c.handleCallMessage(m)
|
||||
c.mu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
c.errorf("handle call: %v", err)
|
||||
}
|
||||
case rpccapnp.Message_Which_release:
|
||||
rel, err := m.Release()
|
||||
if err != nil {
|
||||
c.errorf("decode release: %v", err)
|
||||
return
|
||||
}
|
||||
id := exportID(rel.Id())
|
||||
refs := int(rel.ReferenceCount())
|
||||
|
||||
c.mu.Lock()
|
||||
c.releaseExport(id, refs)
|
||||
c.mu.Unlock()
|
||||
case rpccapnp.Message_Which_disembargo:
|
||||
m = copyRPCMessage(m)
|
||||
c.mu.Lock()
|
||||
err := c.handleDisembargoMessage(m)
|
||||
c.mu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
// Any failure in a disembargo is a protocol violation.
|
||||
c.abort(err)
|
||||
}
|
||||
default:
|
||||
c.infof("received unimplemented message, which = %v", m.Which())
|
||||
um := newUnimplementedMessage(nil, m)
|
||||
c.sendMessage(um)
|
||||
}
|
||||
}
|
||||
|
||||
func newUnimplementedMessage(buf []byte, m rpccapnp.Message) rpccapnp.Message {
|
||||
n := newMessage(buf)
|
||||
n.SetUnimplemented(m)
|
||||
return n
|
||||
}
|
||||
|
||||
func (c *Conn) fillParams(payload rpccapnp.Payload, cl *capnp.Call) error {
|
||||
params, err := cl.PlaceParams(payload.Segment())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := payload.SetContent(params); err != nil {
|
||||
return err
|
||||
}
|
||||
ctab, err := c.makeCapTable(payload.Segment())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := payload.SetCapTable(ctab); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func transformToPromisedAnswer(s *capnp.Segment, answer rpccapnp.PromisedAnswer, transform []capnp.PipelineOp) error {
|
||||
opList, err := rpccapnp.NewPromisedAnswer_Op_List(s, int32(len(transform)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i, op := range transform {
|
||||
opList.At(i).SetGetPointerField(uint16(op.Field))
|
||||
}
|
||||
err = answer.SetTransform(opList)
|
||||
return err
|
||||
}
|
||||
|
||||
// handleReturnMessage is to handle a received return message.
|
||||
// The caller is holding onto c.mu.
|
||||
func (c *Conn) handleReturnMessage(m rpccapnp.Message) error {
|
||||
ret, err := m.Return()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
id := questionID(ret.AnswerId())
|
||||
q := c.popQuestion(id)
|
||||
if q == nil {
|
||||
return fmt.Errorf("received return for unknown question id=%d", id)
|
||||
}
|
||||
if ret.ReleaseParamCaps() {
|
||||
for _, id := range q.paramCaps {
|
||||
c.releaseExport(id, 1)
|
||||
}
|
||||
}
|
||||
q.mu.RLock()
|
||||
qstate := q.state
|
||||
q.mu.RUnlock()
|
||||
if qstate == questionCanceled {
|
||||
// We already sent the finish message.
|
||||
return nil
|
||||
}
|
||||
releaseResultCaps := true
|
||||
switch ret.Which() {
|
||||
case rpccapnp.Return_Which_results:
|
||||
releaseResultCaps = false
|
||||
results, err := ret.Results()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.populateMessageCapTable(results); err == errUnimplemented {
|
||||
um := newUnimplementedMessage(nil, m)
|
||||
c.sendMessage(um)
|
||||
return errUnimplemented
|
||||
} else if err != nil {
|
||||
c.abort(err)
|
||||
return err
|
||||
}
|
||||
content, err := results.ContentPtr()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
q.fulfill(content)
|
||||
case rpccapnp.Return_Which_exception:
|
||||
exc, err := ret.Exception()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e := error(Exception{exc})
|
||||
if q.method != nil {
|
||||
e = &capnp.MethodError{
|
||||
Method: q.method,
|
||||
Err: e,
|
||||
}
|
||||
} else {
|
||||
e = bootstrapError{e}
|
||||
}
|
||||
q.reject(e)
|
||||
case rpccapnp.Return_Which_canceled:
|
||||
err := &questionError{
|
||||
id: id,
|
||||
method: q.method,
|
||||
err: fmt.Errorf("receiver reported canceled"),
|
||||
}
|
||||
c.errorf("%v", err)
|
||||
q.reject(err)
|
||||
return nil
|
||||
default:
|
||||
um := newUnimplementedMessage(nil, m)
|
||||
c.sendMessage(um)
|
||||
return errUnimplemented
|
||||
}
|
||||
fin := newFinishMessage(nil, id, releaseResultCaps)
|
||||
c.sendMessage(fin)
|
||||
return nil
|
||||
}
|
||||
|
||||
func newFinishMessage(buf []byte, questionID questionID, release bool) rpccapnp.Message {
|
||||
m := newMessage(buf)
|
||||
f, _ := m.NewFinish()
|
||||
f.SetQuestionId(uint32(questionID))
|
||||
f.SetReleaseResultCaps(release)
|
||||
return m
|
||||
}
|
||||
|
||||
// populateMessageCapTable converts the descriptors in the payload into
|
||||
// clients and sets it on the message the payload is a part of.
|
||||
func (c *Conn) populateMessageCapTable(payload rpccapnp.Payload) error {
|
||||
msg := payload.Segment().Message()
|
||||
ctab, err := payload.CapTable()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i, n := 0, ctab.Len(); i < n; i++ {
|
||||
desc := ctab.At(i)
|
||||
switch desc.Which() {
|
||||
case rpccapnp.CapDescriptor_Which_none:
|
||||
msg.AddCap(nil)
|
||||
case rpccapnp.CapDescriptor_Which_senderHosted:
|
||||
id := importID(desc.SenderHosted())
|
||||
client := c.addImport(id)
|
||||
msg.AddCap(client)
|
||||
case rpccapnp.CapDescriptor_Which_senderPromise:
|
||||
// We do the same thing as senderHosted, above. @kentonv suggested this on
|
||||
// issue #2; this let's messages be delivered properly, although it's a bit
|
||||
// of a hack, and as Kenton describes, it has some disadvantages:
|
||||
//
|
||||
// > * Apps sometimes want to wait for promise resolution, and to find out if
|
||||
// > it resolved to an exception. You won't be able to provide that API. But,
|
||||
// > usually, it isn't needed.
|
||||
// > * If the promise resolves to a capability hosted on the receiver,
|
||||
// > messages sent to it will uselessly round-trip over the network
|
||||
// > rather than being delivered locally.
|
||||
id := importID(desc.SenderPromise())
|
||||
client := c.addImport(id)
|
||||
msg.AddCap(client)
|
||||
case rpccapnp.CapDescriptor_Which_receiverHosted:
|
||||
id := exportID(desc.ReceiverHosted())
|
||||
e := c.findExport(id)
|
||||
if e == nil {
|
||||
return fmt.Errorf("rpc: capability table references unknown export ID %d", id)
|
||||
}
|
||||
msg.AddCap(e.rc.Ref())
|
||||
case rpccapnp.CapDescriptor_Which_receiverAnswer:
|
||||
recvAns, err := desc.ReceiverAnswer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
id := answerID(recvAns.QuestionId())
|
||||
a := c.answers[id]
|
||||
if a == nil {
|
||||
return fmt.Errorf("rpc: capability table references unknown answer ID %d", id)
|
||||
}
|
||||
recvTransform, err := recvAns.Transform()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
transform := promisedAnswerOpsToTransform(recvTransform)
|
||||
msg.AddCap(a.pipelineClient(transform))
|
||||
default:
|
||||
c.errorf("unknown capability type %v", desc.Which())
|
||||
return errUnimplemented
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// makeCapTable converts the clients in the segment's message into capability descriptors.
|
||||
func (c *Conn) makeCapTable(s *capnp.Segment) (rpccapnp.CapDescriptor_List, error) {
|
||||
msgtab := s.Message().CapTable
|
||||
t, err := rpccapnp.NewCapDescriptor_List(s, int32(len(msgtab)))
|
||||
if err != nil {
|
||||
return rpccapnp.CapDescriptor_List{}, nil
|
||||
}
|
||||
for i, client := range msgtab {
|
||||
desc := t.At(i)
|
||||
if client == nil {
|
||||
desc.SetNone()
|
||||
continue
|
||||
}
|
||||
c.descriptorForClient(desc, client)
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// handleBootstrapMessage handles a received bootstrap message.
|
||||
// The caller holds onto c.mu.
|
||||
func (c *Conn) handleBootstrapMessage(id answerID) error {
|
||||
ctx, cancel := c.newContext()
|
||||
defer cancel()
|
||||
a := c.insertAnswer(id, cancel)
|
||||
if a == nil {
|
||||
// Question ID reused, error out.
|
||||
retmsg := newReturnMessage(nil, id)
|
||||
r, _ := retmsg.Return()
|
||||
setReturnException(r, errQuestionReused)
|
||||
return c.sendMessage(retmsg)
|
||||
}
|
||||
if c.mainFunc == nil {
|
||||
return a.reject(errNoMainInterface)
|
||||
}
|
||||
main, err := c.mainFunc(ctx)
|
||||
if err != nil {
|
||||
return a.reject(errNoMainInterface)
|
||||
}
|
||||
m := &capnp.Message{
|
||||
Arena: capnp.SingleSegment(make([]byte, 0)),
|
||||
CapTable: []capnp.Client{main},
|
||||
}
|
||||
s, _ := m.Segment(0)
|
||||
in := capnp.NewInterface(s, 0)
|
||||
return a.fulfill(in.ToPtr())
|
||||
}
|
||||
|
||||
// handleCallMessage handles a received call message. It mutates the
|
||||
// capability table of its parameter. The caller holds onto c.mu.
|
||||
func (c *Conn) handleCallMessage(m rpccapnp.Message) error {
|
||||
mcall, err := m.Call()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mt, err := mcall.Target()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if mt.Which() != rpccapnp.MessageTarget_Which_importedCap && mt.Which() != rpccapnp.MessageTarget_Which_promisedAnswer {
|
||||
um := newUnimplementedMessage(nil, m)
|
||||
return c.sendMessage(um)
|
||||
}
|
||||
mparams, err := mcall.Params()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.populateMessageCapTable(mparams); err == errUnimplemented {
|
||||
um := newUnimplementedMessage(nil, m)
|
||||
return c.sendMessage(um)
|
||||
} else if err != nil {
|
||||
c.abort(err)
|
||||
return err
|
||||
}
|
||||
ctx, cancel := c.newContext()
|
||||
id := answerID(mcall.QuestionId())
|
||||
a := c.insertAnswer(id, cancel)
|
||||
if a == nil {
|
||||
// Question ID reused, error out.
|
||||
c.abort(errQuestionReused)
|
||||
return errQuestionReused
|
||||
}
|
||||
meth := capnp.Method{
|
||||
InterfaceID: mcall.InterfaceId(),
|
||||
MethodID: mcall.MethodId(),
|
||||
}
|
||||
paramContent, err := mparams.ContentPtr()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cl := &capnp.Call{
|
||||
Ctx: ctx,
|
||||
Method: meth,
|
||||
Params: paramContent.Struct(),
|
||||
}
|
||||
if err := c.routeCallMessage(a, mt, cl); err != nil {
|
||||
return a.reject(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) routeCallMessage(result *answer, mt rpccapnp.MessageTarget, cl *capnp.Call) error {
|
||||
switch mt.Which() {
|
||||
case rpccapnp.MessageTarget_Which_importedCap:
|
||||
id := exportID(mt.ImportedCap())
|
||||
e := c.findExport(id)
|
||||
if e == nil {
|
||||
return errBadTarget
|
||||
}
|
||||
answer := c.lockedCall(e.client, cl)
|
||||
go joinAnswer(result, answer)
|
||||
case rpccapnp.MessageTarget_Which_promisedAnswer:
|
||||
mpromise, err := mt.PromisedAnswer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
id := answerID(mpromise.QuestionId())
|
||||
if id == result.id {
|
||||
// Grandfather paradox.
|
||||
return errBadTarget
|
||||
}
|
||||
pa := c.answers[id]
|
||||
if pa == nil {
|
||||
return errBadTarget
|
||||
}
|
||||
mtrans, err := mpromise.Transform()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
transform := promisedAnswerOpsToTransform(mtrans)
|
||||
pa.mu.Lock()
|
||||
if pa.done {
|
||||
obj, err := pa.obj, pa.err
|
||||
pa.mu.Unlock()
|
||||
client := clientFromResolution(transform, obj, err)
|
||||
answer := c.lockedCall(client, cl)
|
||||
go joinAnswer(result, answer)
|
||||
} else {
|
||||
err = pa.queueCallLocked(cl, pcall{transform: transform, qcall: qcall{a: result}})
|
||||
pa.mu.Unlock()
|
||||
}
|
||||
return err
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) handleDisembargoMessage(msg rpccapnp.Message) error {
|
||||
d, err := msg.Disembargo()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dtarget, err := d.Target()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch d.Context().Which() {
|
||||
case rpccapnp.Disembargo_context_Which_senderLoopback:
|
||||
id := embargoID(d.Context().SenderLoopback())
|
||||
if dtarget.Which() != rpccapnp.MessageTarget_Which_promisedAnswer {
|
||||
return errDisembargoNonImport
|
||||
}
|
||||
dpa, err := dtarget.PromisedAnswer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
aid := answerID(dpa.QuestionId())
|
||||
a := c.answers[aid]
|
||||
if a == nil {
|
||||
return errDisembargoMissingAnswer
|
||||
}
|
||||
dtrans, err := dpa.Transform()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
transform := promisedAnswerOpsToTransform(dtrans)
|
||||
queued, err := a.queueDisembargo(transform, id, dtarget)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !queued {
|
||||
// There's nothing to embargo; everything's been delivered.
|
||||
resp := newDisembargoMessage(nil, rpccapnp.Disembargo_context_Which_receiverLoopback, id)
|
||||
rd, _ := resp.Disembargo()
|
||||
if err := rd.SetTarget(dtarget); err != nil {
|
||||
return err
|
||||
}
|
||||
c.sendMessage(resp)
|
||||
}
|
||||
case rpccapnp.Disembargo_context_Which_receiverLoopback:
|
||||
id := embargoID(d.Context().ReceiverLoopback())
|
||||
c.disembargo(id)
|
||||
default:
|
||||
um := newUnimplementedMessage(nil, msg)
|
||||
c.sendMessage(um)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// newDisembargoMessage creates a disembargo message. Its target will be left blank.
|
||||
func newDisembargoMessage(buf []byte, which rpccapnp.Disembargo_context_Which, id embargoID) rpccapnp.Message {
|
||||
msg := newMessage(buf)
|
||||
d, _ := msg.NewDisembargo()
|
||||
switch which {
|
||||
case rpccapnp.Disembargo_context_Which_senderLoopback:
|
||||
d.Context().SetSenderLoopback(uint32(id))
|
||||
case rpccapnp.Disembargo_context_Which_receiverLoopback:
|
||||
d.Context().SetReceiverLoopback(uint32(id))
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
// newContext creates a new context for a local call.
|
||||
func (c *Conn) newContext() (context.Context, context.CancelFunc) {
|
||||
return context.WithCancel(c.bg)
|
||||
}
|
||||
|
||||
func promisedAnswerOpsToTransform(list rpccapnp.PromisedAnswer_Op_List) []capnp.PipelineOp {
|
||||
n := list.Len()
|
||||
transform := make([]capnp.PipelineOp, 0, n)
|
||||
for i := 0; i < n; i++ {
|
||||
op := list.At(i)
|
||||
switch op.Which() {
|
||||
case rpccapnp.PromisedAnswer_Op_Which_getPointerField:
|
||||
transform = append(transform, capnp.PipelineOp{
|
||||
Field: op.GetPointerField(),
|
||||
})
|
||||
case rpccapnp.PromisedAnswer_Op_Which_noop:
|
||||
// no-op
|
||||
}
|
||||
}
|
||||
return transform
|
||||
}
|
||||
|
||||
func newAbortMessage(buf []byte, err error) rpccapnp.Message {
|
||||
n := newMessage(buf)
|
||||
e, _ := n.NewAbort()
|
||||
toException(e, err)
|
||||
return n
|
||||
}
|
||||
|
||||
func newReturnMessage(buf []byte, id answerID) rpccapnp.Message {
|
||||
retmsg := newMessage(buf)
|
||||
ret, _ := retmsg.NewReturn()
|
||||
ret.SetAnswerId(uint32(id))
|
||||
ret.SetReleaseParamCaps(false)
|
||||
return retmsg
|
||||
}
|
||||
|
||||
func setReturnException(ret rpccapnp.Return, err error) rpccapnp.Exception {
|
||||
e, _ := rpccapnp.NewException(ret.Segment())
|
||||
toException(e, err)
|
||||
ret.SetException(e)
|
||||
return e
|
||||
}
|
||||
|
||||
// clientFromResolution retrieves a client from a resolved question or
|
||||
// answer by applying a transform.
|
||||
func clientFromResolution(transform []capnp.PipelineOp, obj capnp.Ptr, err error) capnp.Client {
|
||||
if err != nil {
|
||||
return capnp.ErrorClient(err)
|
||||
}
|
||||
out, err := capnp.TransformPtr(obj, transform)
|
||||
if err != nil {
|
||||
return capnp.ErrorClient(err)
|
||||
}
|
||||
c := out.Interface().Client()
|
||||
if c == nil {
|
||||
return capnp.ErrorClient(capnp.ErrNullClient)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func newMessage(buf []byte) rpccapnp.Message {
|
||||
_, s, err := capnp.NewMessage(capnp.SingleSegment(buf))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
m, err := rpccapnp.NewRootMessage(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// chanMutex is a mutex backed by a channel so that it can be used in a select.
|
||||
// A receive is a lock and a send is an unlock.
|
||||
type chanMutex chan struct{}
|
||||
|
||||
type connState int
|
||||
|
||||
const (
|
||||
connAlive connState = iota
|
||||
connDying
|
||||
connDead
|
||||
)
|
||||
|
||||
func newChanMutex() chanMutex {
|
||||
mu := make(chanMutex, 1)
|
||||
mu <- struct{}{}
|
||||
return mu
|
||||
}
|
||||
|
||||
func (mu chanMutex) Lock() {
|
||||
<-mu
|
||||
}
|
||||
|
||||
func (mu chanMutex) TryLock(ctx context.Context) error {
|
||||
select {
|
||||
case <-mu:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (mu chanMutex) Unlock() {
|
||||
mu <- struct{}{}
|
||||
}
|
629
vendor/zombiezen.com/go/capnproto2/rpc/rpc_test.go
generated
vendored
Normal file
629
vendor/zombiezen.com/go/capnproto2/rpc/rpc_test.go
generated
vendored
Normal file
@@ -0,0 +1,629 @@
|
||||
package rpc_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"zombiezen.com/go/capnproto2"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
"zombiezen.com/go/capnproto2/rpc/internal/logtransport"
|
||||
"zombiezen.com/go/capnproto2/rpc/internal/pipetransport"
|
||||
rpccapnp "zombiezen.com/go/capnproto2/std/capnp/rpc"
|
||||
)
|
||||
|
||||
const (
|
||||
interfaceID uint64 = 0xa7317bd7216570aa
|
||||
methodID uint16 = 9
|
||||
bootstrapExportID uint32 = 84
|
||||
)
|
||||
|
||||
var logMessages = flag.Bool("logmessages", false, "whether to log the transport in tests. Messages are always from client to server.")
|
||||
|
||||
type testLogger struct {
|
||||
t interface {
|
||||
Logf(format string, args ...interface{})
|
||||
}
|
||||
}
|
||||
|
||||
func (l testLogger) Infof(ctx context.Context, format string, args ...interface{}) {
|
||||
l.t.Logf("conn log: "+format, args...)
|
||||
}
|
||||
|
||||
func (l testLogger) Errorf(ctx context.Context, format string, args ...interface{}) {
|
||||
l.t.Logf("conn log: "+format, args...)
|
||||
}
|
||||
|
||||
func newUnpairedConn(t *testing.T, options ...rpc.ConnOption) (*rpc.Conn, rpc.Transport) {
|
||||
p, q := pipetransport.New()
|
||||
if *logMessages {
|
||||
p = logtransport.New(nil, p)
|
||||
}
|
||||
newopts := make([]rpc.ConnOption, len(options), len(options)+1)
|
||||
copy(newopts, options)
|
||||
newopts = append(newopts, rpc.ConnLog(testLogger{t}))
|
||||
c := rpc.NewConn(p, newopts...)
|
||||
return c, q
|
||||
}
|
||||
|
||||
func TestBootstrap(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
conn, p := newUnpairedConn(t)
|
||||
defer conn.Close()
|
||||
defer p.Close()
|
||||
|
||||
clientCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
readBootstrap(t, clientCtx, conn, p)
|
||||
}
|
||||
|
||||
func readBootstrap(t *testing.T, ctx context.Context, conn *rpc.Conn, p rpc.Transport) (client capnp.Client, questionID uint32) {
|
||||
clientCh := make(chan capnp.Client, 1)
|
||||
go func() {
|
||||
clientCh <- conn.Bootstrap(ctx)
|
||||
}()
|
||||
|
||||
msg, err := p.RecvMessage(ctx)
|
||||
if err != nil {
|
||||
t.Fatal("Read Bootstrap failed:", err)
|
||||
}
|
||||
if msg.Which() != rpccapnp.Message_Which_bootstrap {
|
||||
t.Fatalf("Received %v message from bootstrap, want Message_Which_bootstrap", msg.Which())
|
||||
}
|
||||
boot, err := msg.Bootstrap()
|
||||
if err != nil {
|
||||
t.Fatal("Read Bootstrap failed:", err)
|
||||
}
|
||||
questionID = boot.QuestionId()
|
||||
// If this deadlocks, then Bootstrap isn't using a promised client.
|
||||
client = <-clientCh
|
||||
if client == nil {
|
||||
t.Fatal("Bootstrap client is nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func TestBootstrapFulfilledSenderHosted(t *testing.T) {
|
||||
testBootstrapFulfilled(t, false)
|
||||
}
|
||||
|
||||
func TestBootstrapFulfilledSenderPromise(t *testing.T) {
|
||||
testBootstrapFulfilled(t, true)
|
||||
}
|
||||
|
||||
func testBootstrapFulfilled(t *testing.T, resultIsPromise bool) {
|
||||
ctx := context.Background()
|
||||
conn, p := newUnpairedConn(t)
|
||||
defer conn.Close()
|
||||
defer p.Close()
|
||||
|
||||
clientCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
bootstrapAndFulfill(t, clientCtx, conn, p, resultIsPromise)
|
||||
}
|
||||
|
||||
// Receive a Finish message for the given question ID.
|
||||
//
|
||||
// Immediately releases any capabilities in the message.
|
||||
//
|
||||
// An error is returned if any of the following occur:
|
||||
//
|
||||
// * An error occurs when reading the message
|
||||
// * The message is not of type `Finish`
|
||||
// * The message's question ID is not equal to `questionID`.
|
||||
//
|
||||
// Parameters:
|
||||
//
|
||||
// ctx: The context to be used when sending the message.
|
||||
// p: The rpc.Transport to send the message on.
|
||||
// questionID: The expected question ID.
|
||||
func recvFinish(ctx context.Context, p rpc.Transport, questionID uint32) error {
|
||||
if finish, err := p.RecvMessage(ctx); err != nil {
|
||||
return err
|
||||
} else if finish.Which() != rpccapnp.Message_Which_finish {
|
||||
return fmt.Errorf("message sent is %v; want Message_Which_finish", finish.Which())
|
||||
} else {
|
||||
f, err := finish.Finish()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if id := f.QuestionId(); id != questionID {
|
||||
return fmt.Errorf("finish question ID is %d; want %d", id, questionID)
|
||||
}
|
||||
if f.ReleaseResultCaps() {
|
||||
return fmt.Errorf("finish released bootstrap capability")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send a Return message with a single capability to a bootstrap interface in
|
||||
// its payload. Returns any error that occurs.
|
||||
//
|
||||
// Parameters:
|
||||
//
|
||||
// ctx: The context to be used when sending the message.
|
||||
// p: The rpc.Transport to send the message on.
|
||||
// answerId: The message's answerId.
|
||||
// isPromise: If this is true, the capability in the response will be of type
|
||||
// senderPromise, otherwise it will be of type senderHosted.
|
||||
func sendBootstrapReturn(ctx context.Context, p rpc.Transport, answerId uint32, isPromise bool) error {
|
||||
return sendMessage(ctx, p, func(msg rpccapnp.Message) error {
|
||||
ret, err := msg.NewReturn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ret.SetAnswerId(answerId)
|
||||
payload, err := ret.NewResults()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payload.SetContent(capnp.NewInterface(msg.Segment(), 0))
|
||||
capTable, err := payload.NewCapTable(1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if isPromise {
|
||||
capTable.At(0).SetSenderPromise(bootstrapExportID)
|
||||
} else {
|
||||
capTable.At(0).SetSenderHosted(bootstrapExportID)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func bootstrapAndFulfill(t *testing.T, ctx context.Context, conn *rpc.Conn, p rpc.Transport, resultIsPromise bool) capnp.Client {
|
||||
client, bootstrapID := readBootstrap(t, ctx, conn, p)
|
||||
if err := sendBootstrapReturn(ctx, p, bootstrapID, resultIsPromise); err != nil {
|
||||
t.Fatalf("sendBootstrapReturn: %v", err)
|
||||
}
|
||||
if err := recvFinish(ctx, p, bootstrapID); err != nil {
|
||||
t.Fatalf("recvFinish: %v", err)
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
func TestCallOnPromisedAnswer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
conn, p := newUnpairedConn(t)
|
||||
defer conn.Close()
|
||||
defer p.Close()
|
||||
client, bootstrapID := readBootstrap(t, ctx, conn, p)
|
||||
|
||||
readDone := startRecvMessage(p)
|
||||
client.Call(&capnp.Call{
|
||||
Ctx: ctx,
|
||||
Method: capnp.Method{
|
||||
InterfaceID: interfaceID,
|
||||
MethodID: methodID,
|
||||
},
|
||||
ParamsSize: capnp.ObjectSize{DataSize: 8},
|
||||
ParamsFunc: func(s capnp.Struct) error {
|
||||
s.SetUint64(0, 42)
|
||||
return nil
|
||||
},
|
||||
})
|
||||
read := <-readDone
|
||||
|
||||
if read.err != nil {
|
||||
t.Fatal("Reading failed:", read.err)
|
||||
}
|
||||
if read.msg.Which() != rpccapnp.Message_Which_call {
|
||||
t.Fatalf("Conn sent %v message, want Message_Which_call", read.msg.Which())
|
||||
}
|
||||
call, err := read.msg.Call()
|
||||
if err != nil {
|
||||
t.Fatal("call error:", err)
|
||||
}
|
||||
if target, err := call.Target(); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if target.Which() == rpccapnp.MessageTarget_Which_promisedAnswer {
|
||||
if pa, err := target.PromisedAnswer(); err != nil {
|
||||
t.Error("call.target.promisedAnswer error:", err)
|
||||
} else {
|
||||
if qid := pa.QuestionId(); qid != bootstrapID {
|
||||
t.Errorf("Target question ID = %d; want %d", qid, bootstrapID)
|
||||
}
|
||||
// TODO(light): allow no-ops
|
||||
if xform, err := pa.Transform(); err != nil {
|
||||
t.Error("call.target.promisedAnswer.transform error:", err)
|
||||
} else if xform.Len() != 0 {
|
||||
t.Error("Target transform is non-empty")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
t.Errorf("Target is %v, want MessageTarget_Which_promisedAnswer", target.Which())
|
||||
}
|
||||
if id := call.InterfaceId(); id != interfaceID {
|
||||
t.Errorf("Interface ID = %x; want %x", id, interfaceID)
|
||||
}
|
||||
if id := call.MethodId(); id != methodID {
|
||||
t.Errorf("Method ID = %d; want %d", id, methodID)
|
||||
}
|
||||
if params, err := call.Params(); err != nil {
|
||||
t.Error("call.params error:", err)
|
||||
} else {
|
||||
if content, err := params.Content(); err != nil {
|
||||
t.Error("call.params.content error:", err)
|
||||
} else if x := capnp.ToStruct(content).Uint64(0); x != 42 {
|
||||
t.Errorf("Params content value = %d; want %d", x, 42)
|
||||
}
|
||||
}
|
||||
sendResultsTo := call.SendResultsTo()
|
||||
if sendResultsTo.Which() != rpccapnp.Call_sendResultsTo_Which_caller {
|
||||
t.Errorf("Send results to %v; want caller", sendResultsTo.Which())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallOnExportId_BootstrapIsPromise(t *testing.T) {
|
||||
testCallOnExportId(t, true)
|
||||
}
|
||||
|
||||
func TestCallOnExportId_BootstrapIsHosted(t *testing.T) {
|
||||
testCallOnExportId(t, false)
|
||||
}
|
||||
|
||||
func testCallOnExportId(t *testing.T, bootstrapIsPromise bool) {
|
||||
ctx := context.Background()
|
||||
conn, p := newUnpairedConn(t)
|
||||
defer conn.Close()
|
||||
defer p.Close()
|
||||
client := bootstrapAndFulfill(t, ctx, conn, p, bootstrapIsPromise)
|
||||
|
||||
readDone := startRecvMessage(p)
|
||||
client.Call(&capnp.Call{
|
||||
Ctx: ctx,
|
||||
Method: capnp.Method{
|
||||
InterfaceID: interfaceID,
|
||||
MethodID: methodID,
|
||||
},
|
||||
ParamsSize: capnp.ObjectSize{DataSize: 8},
|
||||
ParamsFunc: func(s capnp.Struct) error {
|
||||
s.SetUint64(0, 42)
|
||||
return nil
|
||||
},
|
||||
})
|
||||
read := <-readDone
|
||||
|
||||
if read.err != nil {
|
||||
t.Fatal("Reading failed:", read.err)
|
||||
}
|
||||
call, err := read.msg.Call()
|
||||
if err != nil {
|
||||
t.Fatal("call error:", err)
|
||||
}
|
||||
if read.msg.Which() != rpccapnp.Message_Which_call {
|
||||
t.Fatalf("Conn sent %v message, want Message_Which_call", read.msg.Which())
|
||||
}
|
||||
if target, err := call.Target(); err != nil {
|
||||
t.Error("call.target error:", err)
|
||||
} else if target.Which() != rpccapnp.MessageTarget_Which_importedCap {
|
||||
t.Errorf("Target is %v, want MessageTarget_Which_importedCap", target.Which())
|
||||
} else {
|
||||
if id := target.ImportedCap(); id != bootstrapExportID {
|
||||
t.Errorf("Target imported cap = %d; want %d", id, bootstrapExportID)
|
||||
}
|
||||
}
|
||||
if id := call.InterfaceId(); id != interfaceID {
|
||||
t.Errorf("Interface ID = %x; want %x", id, interfaceID)
|
||||
}
|
||||
if id := call.MethodId(); id != methodID {
|
||||
t.Errorf("Method ID = %d; want %d", id, methodID)
|
||||
}
|
||||
if params, err := call.Params(); err != nil {
|
||||
t.Error("call.params error:", err)
|
||||
} else if content, err := params.Content(); err != nil {
|
||||
t.Error("call.params.content error:", err)
|
||||
} else if x := capnp.ToStruct(content).Uint64(0); x != 42 {
|
||||
t.Errorf("Params content value = %d; want %d", x, 42)
|
||||
}
|
||||
if sendResultsTo := call.SendResultsTo(); err != nil {
|
||||
t.Error("call.sendResultsTo error:", err)
|
||||
} else if sendResultsTo.Which() != rpccapnp.Call_sendResultsTo_Which_caller {
|
||||
t.Errorf("Send results to %v; want caller", sendResultsTo.Which())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMainInterface(t *testing.T) {
|
||||
main := mockClient()
|
||||
conn, p := newUnpairedConn(t, rpc.MainInterface(main))
|
||||
defer conn.Close()
|
||||
defer p.Close()
|
||||
|
||||
bootstrapRoundtrip(t, p)
|
||||
}
|
||||
|
||||
func bootstrapRoundtrip(t *testing.T, p rpc.Transport) (importID, questionID uint32) {
|
||||
questionID = 54
|
||||
err := sendMessage(context.TODO(), p, func(msg rpccapnp.Message) error {
|
||||
bootstrap, err := msg.NewBootstrap()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
bootstrap.SetQuestionId(questionID)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("Write Bootstrap failed:", err)
|
||||
}
|
||||
msg, err := p.RecvMessage(context.TODO())
|
||||
if err != nil {
|
||||
t.Fatal("Read Bootstrap response failed:", err)
|
||||
}
|
||||
|
||||
if msg.Which() != rpccapnp.Message_Which_return {
|
||||
t.Fatalf("Conn sent %v message, want Message_Which_return", msg.Which())
|
||||
}
|
||||
ret, err := msg.Return()
|
||||
if err != nil {
|
||||
t.Fatal("return error:", err)
|
||||
}
|
||||
if id := ret.AnswerId(); id != questionID {
|
||||
t.Fatalf("msg.Return().AnswerId() = %d; want %d", id, questionID)
|
||||
}
|
||||
if ret.Which() != rpccapnp.Return_Which_results {
|
||||
t.Fatalf("msg.Return().Which() = %v; want Return_Which_results", ret.Which())
|
||||
}
|
||||
payload, err := ret.Results()
|
||||
if err != nil {
|
||||
t.Fatal("return.results error:", err)
|
||||
}
|
||||
content, err := payload.ContentPtr()
|
||||
if err != nil {
|
||||
t.Fatal("return.results.content error:", err)
|
||||
}
|
||||
in := content.Interface()
|
||||
if !in.IsValid() {
|
||||
t.Fatalf("Result payload contains %v; want interface", content)
|
||||
}
|
||||
capIdx := int(in.Capability())
|
||||
capTable, err := payload.CapTable()
|
||||
if err != nil {
|
||||
t.Fatal("return.results.capTable error:", err)
|
||||
}
|
||||
if n := capTable.Len(); capIdx >= n {
|
||||
t.Fatalf("Payload capTable has size %d, but capability index = %d", n, capIdx)
|
||||
}
|
||||
if cw := capTable.At(capIdx).Which(); cw != rpccapnp.CapDescriptor_Which_senderHosted {
|
||||
t.Fatalf("Capability type is %d; want CapDescriptor_Which_senderHosted", cw)
|
||||
}
|
||||
return capTable.At(capIdx).SenderHosted(), questionID
|
||||
}
|
||||
|
||||
func TestReceiveCallOnPromisedAnswer(t *testing.T) {
|
||||
const questionID = 999
|
||||
called := false
|
||||
main := stubClient(func(ctx context.Context, params capnp.Struct) (capnp.Struct, error) {
|
||||
msg, s, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||
if err != nil {
|
||||
return capnp.Struct{}, err
|
||||
}
|
||||
result, err := capnp.NewStruct(s, capnp.ObjectSize{})
|
||||
if err != nil {
|
||||
return capnp.Struct{}, err
|
||||
}
|
||||
called = true
|
||||
if err := msg.SetRoot(result); err != nil {
|
||||
return capnp.Struct{}, err
|
||||
}
|
||||
return result, nil
|
||||
})
|
||||
conn, p := newUnpairedConn(t, rpc.MainInterface(main))
|
||||
defer conn.Close()
|
||||
defer p.Close()
|
||||
_, bootqID := bootstrapRoundtrip(t, p)
|
||||
|
||||
err := sendMessage(context.TODO(), p, func(msg rpccapnp.Message) error {
|
||||
call, err := msg.NewCall()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
call.SetQuestionId(questionID)
|
||||
call.SetInterfaceId(interfaceID)
|
||||
call.SetMethodId(methodID)
|
||||
target, err := call.NewTarget()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pa, err := target.NewPromisedAnswer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pa.SetQuestionId(bootqID)
|
||||
payload, err := call.NewParams()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
content, err := capnp.NewStruct(msg.Segment(), capnp.ObjectSize{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payload.SetContent(content)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("Call message failed:", err)
|
||||
}
|
||||
retmsg, err := p.RecvMessage(context.TODO())
|
||||
if err != nil {
|
||||
t.Fatal("Read Call return failed:", err)
|
||||
}
|
||||
|
||||
if !called {
|
||||
t.Error("interface not called")
|
||||
}
|
||||
if retmsg.Which() != rpccapnp.Message_Which_return {
|
||||
t.Fatalf("Return message is %v; want %v", retmsg.Which(), rpccapnp.Message_Which_return)
|
||||
}
|
||||
ret, err := retmsg.Return()
|
||||
if err != nil {
|
||||
t.Fatal("return error:", err)
|
||||
}
|
||||
if id := ret.AnswerId(); id != questionID {
|
||||
t.Errorf("Return.answerId = %d; want %d", id, questionID)
|
||||
}
|
||||
if ret.Which() == rpccapnp.Return_Which_results {
|
||||
// TODO(light)
|
||||
} else if ret.Which() == rpccapnp.Return_Which_exception {
|
||||
exc, _ := ret.Exception()
|
||||
reason, _ := exc.Reason()
|
||||
t.Error("Return.exception:", reason)
|
||||
} else {
|
||||
t.Errorf("Return.Which() = %v; want %v", ret.Which(), rpccapnp.Return_Which_results)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReceiveCallOnExport(t *testing.T) {
|
||||
const questionID = 999
|
||||
called := false
|
||||
main := stubClient(func(ctx context.Context, params capnp.Struct) (capnp.Struct, error) {
|
||||
msg, s, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||
if err != nil {
|
||||
return capnp.Struct{}, err
|
||||
}
|
||||
result, err := capnp.NewStruct(s, capnp.ObjectSize{})
|
||||
if err != nil {
|
||||
return capnp.Struct{}, err
|
||||
}
|
||||
called = true
|
||||
if err := msg.SetRoot(result); err != nil {
|
||||
return capnp.Struct{}, err
|
||||
}
|
||||
return result, nil
|
||||
})
|
||||
conn, p := newUnpairedConn(t, rpc.MainInterface(main))
|
||||
defer conn.Close()
|
||||
defer p.Close()
|
||||
importID := sendBootstrapAndFinish(t, p)
|
||||
|
||||
err := sendMessage(context.TODO(), p, func(msg rpccapnp.Message) error {
|
||||
call, err := msg.NewCall()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
call.SetQuestionId(questionID)
|
||||
call.SetInterfaceId(interfaceID)
|
||||
call.SetMethodId(methodID)
|
||||
target, err := call.NewTarget()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
target.SetImportedCap(importID)
|
||||
call.SetTarget(target)
|
||||
payload, err := call.NewParams()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
content, err := capnp.NewStruct(msg.Segment(), capnp.ObjectSize{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payload.SetContent(content)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("Call message failed:", err)
|
||||
}
|
||||
retmsg, err := p.RecvMessage(context.TODO())
|
||||
if err != nil {
|
||||
t.Fatal("Read Call return failed:", err)
|
||||
}
|
||||
|
||||
if !called {
|
||||
t.Error("interface not called")
|
||||
}
|
||||
if retmsg.Which() != rpccapnp.Message_Which_return {
|
||||
t.Fatalf("Return message is %v; want %v", retmsg.Which(), rpccapnp.Message_Which_return)
|
||||
}
|
||||
ret, err := retmsg.Return()
|
||||
if err != nil {
|
||||
t.Fatal("return error:", err)
|
||||
}
|
||||
if id := ret.AnswerId(); id != questionID {
|
||||
t.Errorf("Return.answerId = %d; want %d", id, questionID)
|
||||
}
|
||||
if ret.Which() == rpccapnp.Return_Which_results {
|
||||
// TODO(light)
|
||||
} else if ret.Which() == rpccapnp.Return_Which_exception {
|
||||
exc, _ := ret.Exception()
|
||||
reason, _ := exc.Reason()
|
||||
t.Error("Return.exception:", reason)
|
||||
} else {
|
||||
t.Errorf("Return.Which() = %v; want %v", ret.Which(), rpccapnp.Return_Which_results)
|
||||
}
|
||||
}
|
||||
|
||||
func sendBootstrapAndFinish(t *testing.T, p rpc.Transport) (importID uint32) {
|
||||
importID, questionID := bootstrapRoundtrip(t, p)
|
||||
err := sendMessage(context.TODO(), p, func(msg rpccapnp.Message) error {
|
||||
finish, err := msg.NewFinish()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
finish.SetQuestionId(questionID)
|
||||
finish.SetReleaseResultCaps(false)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("Write Bootstrap Finish failed:", err)
|
||||
}
|
||||
return importID
|
||||
}
|
||||
|
||||
func sendMessage(ctx context.Context, t rpc.Transport, f func(rpccapnp.Message) error) error {
|
||||
_, s, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||
m, err := rpccapnp.NewRootMessage(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := f(m); err != nil {
|
||||
return err
|
||||
}
|
||||
return t.SendMessage(ctx, m)
|
||||
}
|
||||
|
||||
func startRecvMessage(t rpc.Transport) <-chan asyncRecv {
|
||||
ch := make(chan asyncRecv, 1)
|
||||
go func() {
|
||||
msg, err := t.RecvMessage(context.TODO())
|
||||
ch <- asyncRecv{msg, err}
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
|
||||
type asyncRecv struct {
|
||||
msg rpccapnp.Message
|
||||
err error
|
||||
}
|
||||
|
||||
func mockClient() capnp.Client {
|
||||
return capnp.ErrorClient(errMockClient)
|
||||
}
|
||||
|
||||
type stubClient func(ctx context.Context, params capnp.Struct) (capnp.Struct, error)
|
||||
|
||||
func (stub stubClient) Call(call *capnp.Call) capnp.Answer {
|
||||
if call.Method.InterfaceID != interfaceID || call.Method.MethodID != methodID {
|
||||
return capnp.ErrorAnswer(errNotImplemented)
|
||||
}
|
||||
cc, err := call.PlaceParams(nil)
|
||||
if err != nil {
|
||||
return capnp.ErrorAnswer(err)
|
||||
}
|
||||
s, err := stub(call.Ctx, cc)
|
||||
if err != nil {
|
||||
return capnp.ErrorAnswer(err)
|
||||
}
|
||||
return capnp.ImmediateAnswer(s)
|
||||
}
|
||||
|
||||
func (stub stubClient) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
errMockClient = errors.New("rpc_test: mock client")
|
||||
errNotImplemented = errors.New("rpc_test: stub client method not implemented")
|
||||
)
|
255
vendor/zombiezen.com/go/capnproto2/rpc/tables.go
generated
vendored
Normal file
255
vendor/zombiezen.com/go/capnproto2/rpc/tables.go
generated
vendored
Normal file
@@ -0,0 +1,255 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"zombiezen.com/go/capnproto2"
|
||||
"zombiezen.com/go/capnproto2/rpc/internal/refcount"
|
||||
)
|
||||
|
||||
// Table IDs
|
||||
type (
|
||||
questionID uint32
|
||||
answerID uint32
|
||||
exportID uint32
|
||||
importID uint32
|
||||
embargoID uint32
|
||||
)
|
||||
|
||||
// impent is an entry in the import table.
|
||||
type impent struct {
|
||||
rc *refcount.RefCount
|
||||
refs int
|
||||
}
|
||||
|
||||
// addImport increases the counter of the times the import ID was sent to this vat.
|
||||
func (c *Conn) addImport(id importID) capnp.Client {
|
||||
if c.imports == nil {
|
||||
c.imports = make(map[importID]*impent)
|
||||
} else if ent := c.imports[id]; ent != nil {
|
||||
ent.refs++
|
||||
return ent.rc.Ref()
|
||||
}
|
||||
client := &importClient{
|
||||
id: id,
|
||||
conn: c,
|
||||
}
|
||||
rc, ref := refcount.New(client)
|
||||
c.imports[id] = &impent{rc: rc, refs: 1}
|
||||
return ref
|
||||
}
|
||||
|
||||
// popImport removes the import ID and returns the number of times the import ID was sent to this vat.
|
||||
func (c *Conn) popImport(id importID) (refs int) {
|
||||
if c.imports == nil {
|
||||
return 0
|
||||
}
|
||||
ent := c.imports[id]
|
||||
if ent == nil {
|
||||
return 0
|
||||
}
|
||||
refs = ent.refs
|
||||
delete(c.imports, id)
|
||||
return refs
|
||||
}
|
||||
|
||||
// An importClient implements capnp.Client for a remote capability.
|
||||
type importClient struct {
|
||||
id importID
|
||||
conn *Conn
|
||||
closed bool // protected by conn.mu
|
||||
}
|
||||
|
||||
func (ic *importClient) Call(cl *capnp.Call) capnp.Answer {
|
||||
select {
|
||||
case <-ic.conn.mu:
|
||||
if err := ic.conn.startWork(); err != nil {
|
||||
return capnp.ErrorAnswer(err)
|
||||
}
|
||||
case <-cl.Ctx.Done():
|
||||
return capnp.ErrorAnswer(cl.Ctx.Err())
|
||||
}
|
||||
ans := ic.lockedCall(cl)
|
||||
ic.conn.workers.Done()
|
||||
ic.conn.mu.Unlock()
|
||||
return ans
|
||||
}
|
||||
|
||||
// lockedCall is equivalent to Call but assumes that the caller is
|
||||
// already holding onto ic.conn.mu.
|
||||
func (ic *importClient) lockedCall(cl *capnp.Call) capnp.Answer {
|
||||
if ic.closed {
|
||||
return capnp.ErrorAnswer(errImportClosed)
|
||||
}
|
||||
|
||||
q := ic.conn.newQuestion(cl.Ctx, &cl.Method)
|
||||
msg := newMessage(nil)
|
||||
msgCall, _ := msg.NewCall()
|
||||
msgCall.SetQuestionId(uint32(q.id))
|
||||
msgCall.SetInterfaceId(cl.Method.InterfaceID)
|
||||
msgCall.SetMethodId(cl.Method.MethodID)
|
||||
target, _ := msgCall.NewTarget()
|
||||
target.SetImportedCap(uint32(ic.id))
|
||||
payload, _ := msgCall.NewParams()
|
||||
if err := ic.conn.fillParams(payload, cl); err != nil {
|
||||
ic.conn.popQuestion(q.id)
|
||||
return capnp.ErrorAnswer(err)
|
||||
}
|
||||
|
||||
select {
|
||||
case ic.conn.out <- msg:
|
||||
case <-cl.Ctx.Done():
|
||||
ic.conn.popQuestion(q.id)
|
||||
return capnp.ErrorAnswer(cl.Ctx.Err())
|
||||
case <-ic.conn.bg.Done():
|
||||
ic.conn.popQuestion(q.id)
|
||||
return capnp.ErrorAnswer(ErrConnClosed)
|
||||
}
|
||||
q.start()
|
||||
return q
|
||||
}
|
||||
|
||||
func (ic *importClient) Close() error {
|
||||
ic.conn.mu.Lock()
|
||||
if err := ic.conn.startWork(); err != nil {
|
||||
ic.conn.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
closed := ic.closed
|
||||
var i int
|
||||
if !closed {
|
||||
i = ic.conn.popImport(ic.id)
|
||||
ic.closed = true
|
||||
}
|
||||
ic.conn.workers.Done()
|
||||
ic.conn.mu.Unlock()
|
||||
|
||||
if closed {
|
||||
return errImportClosed
|
||||
}
|
||||
if i == 0 {
|
||||
return nil
|
||||
}
|
||||
msg := newMessage(nil)
|
||||
mr, err := msg.NewRelease()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mr.SetId(uint32(ic.id))
|
||||
mr.SetReferenceCount(uint32(i))
|
||||
select {
|
||||
case ic.conn.out <- msg:
|
||||
return nil
|
||||
case <-ic.conn.bg.Done():
|
||||
return ErrConnClosed
|
||||
}
|
||||
}
|
||||
|
||||
type export struct {
|
||||
id exportID
|
||||
rc *refcount.RefCount
|
||||
client capnp.Client
|
||||
wireRefs int
|
||||
}
|
||||
|
||||
func (c *Conn) findExport(id exportID) *export {
|
||||
if int(id) >= len(c.exports) {
|
||||
return nil
|
||||
}
|
||||
return c.exports[id]
|
||||
}
|
||||
|
||||
// addExport ensures that the client is present in the table, returning its ID.
|
||||
// If the client is already in the table, the previous ID is returned.
|
||||
func (c *Conn) addExport(client capnp.Client) exportID {
|
||||
for i, e := range c.exports {
|
||||
if e != nil && isSameClient(e.rc.Client, client) {
|
||||
e.wireRefs++
|
||||
return exportID(i)
|
||||
}
|
||||
}
|
||||
id := exportID(c.exportID.next())
|
||||
rc, client := refcount.New(client)
|
||||
export := &export{
|
||||
id: id,
|
||||
rc: rc,
|
||||
client: client,
|
||||
wireRefs: 1,
|
||||
}
|
||||
if int(id) == len(c.exports) {
|
||||
c.exports = append(c.exports, export)
|
||||
} else {
|
||||
c.exports[id] = export
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func (c *Conn) releaseExport(id exportID, refs int) {
|
||||
e := c.findExport(id)
|
||||
if e == nil {
|
||||
return
|
||||
}
|
||||
e.wireRefs -= refs
|
||||
if e.wireRefs > 0 {
|
||||
return
|
||||
}
|
||||
if e.wireRefs < 0 {
|
||||
c.errorf("warning: export %v has negative refcount (%d)", id, e.wireRefs)
|
||||
}
|
||||
if err := e.client.Close(); err != nil {
|
||||
c.errorf("export %v close: %v", id, err)
|
||||
}
|
||||
c.exports[id] = nil
|
||||
c.exportID.remove(uint32(id))
|
||||
}
|
||||
|
||||
type embargo <-chan struct{}
|
||||
|
||||
func (c *Conn) newEmbargo() (embargoID, embargo) {
|
||||
id := embargoID(c.embargoID.next())
|
||||
e := make(chan struct{})
|
||||
if int(id) == len(c.embargoes) {
|
||||
c.embargoes = append(c.embargoes, e)
|
||||
} else {
|
||||
c.embargoes[id] = e
|
||||
}
|
||||
return id, e
|
||||
}
|
||||
|
||||
func (c *Conn) disembargo(id embargoID) {
|
||||
if int(id) >= len(c.embargoes) {
|
||||
return
|
||||
}
|
||||
e := c.embargoes[id]
|
||||
if e == nil {
|
||||
return
|
||||
}
|
||||
close(e)
|
||||
c.embargoes[id] = nil
|
||||
c.embargoID.remove(uint32(id))
|
||||
}
|
||||
|
||||
// idgen returns a sequence of monotonically increasing IDs with
|
||||
// support for replacement. The zero value is a generator that
|
||||
// starts at zero.
|
||||
type idgen struct {
|
||||
i uint32
|
||||
free []uint32
|
||||
}
|
||||
|
||||
func (gen *idgen) next() uint32 {
|
||||
if n := len(gen.free); n > 0 {
|
||||
i := gen.free[n-1]
|
||||
gen.free = gen.free[:n-1]
|
||||
return i
|
||||
}
|
||||
i := gen.i
|
||||
gen.i++
|
||||
return i
|
||||
}
|
||||
|
||||
func (gen *idgen) remove(i uint32) {
|
||||
gen.free = append(gen.free, i)
|
||||
}
|
||||
|
||||
var errImportClosed = errors.New("rpc: call on closed import")
|
175
vendor/zombiezen.com/go/capnproto2/rpc/transport.go
generated
vendored
Normal file
175
vendor/zombiezen.com/go/capnproto2/rpc/transport.go
generated
vendored
Normal file
@@ -0,0 +1,175 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"zombiezen.com/go/capnproto2"
|
||||
rpccapnp "zombiezen.com/go/capnproto2/std/capnp/rpc"
|
||||
)
|
||||
|
||||
// Transport is the interface that abstracts sending and receiving
|
||||
// individual messages of the Cap'n Proto RPC protocol.
|
||||
type Transport interface {
|
||||
// SendMessage sends msg.
|
||||
SendMessage(ctx context.Context, msg rpccapnp.Message) error
|
||||
|
||||
// RecvMessage waits to receive a message and returns it.
|
||||
// Implementations may re-use buffers between calls, so the message is
|
||||
// only valid until the next call to RecvMessage.
|
||||
RecvMessage(ctx context.Context) (rpccapnp.Message, error)
|
||||
|
||||
// Close releases any resources associated with the transport.
|
||||
Close() error
|
||||
}
|
||||
|
||||
type streamTransport struct {
|
||||
rwc io.ReadWriteCloser
|
||||
deadline writeDeadlineSetter
|
||||
|
||||
enc *capnp.Encoder
|
||||
dec *capnp.Decoder
|
||||
wbuf bytes.Buffer
|
||||
}
|
||||
|
||||
// StreamTransport creates a transport that sends and receives messages
|
||||
// by serializing and deserializing unpacked Cap'n Proto messages.
|
||||
// Closing the transport will close the underlying ReadWriteCloser.
|
||||
func StreamTransport(rwc io.ReadWriteCloser) Transport {
|
||||
d, _ := rwc.(writeDeadlineSetter)
|
||||
s := &streamTransport{
|
||||
rwc: rwc,
|
||||
deadline: d,
|
||||
dec: capnp.NewDecoder(rwc),
|
||||
}
|
||||
s.wbuf.Grow(4096)
|
||||
s.enc = capnp.NewEncoder(&s.wbuf)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *streamTransport) SendMessage(ctx context.Context, msg rpccapnp.Message) error {
|
||||
s.wbuf.Reset()
|
||||
if err := s.enc.Encode(msg.Segment().Message()); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.deadline != nil {
|
||||
// TODO(light): log errors
|
||||
if d, ok := ctx.Deadline(); ok {
|
||||
s.deadline.SetWriteDeadline(d)
|
||||
} else {
|
||||
s.deadline.SetWriteDeadline(time.Time{})
|
||||
}
|
||||
}
|
||||
_, err := s.rwc.Write(s.wbuf.Bytes())
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *streamTransport) RecvMessage(ctx context.Context) (rpccapnp.Message, error) {
|
||||
var (
|
||||
msg *capnp.Message
|
||||
err error
|
||||
)
|
||||
read := make(chan struct{})
|
||||
go func() {
|
||||
msg, err = s.dec.Decode()
|
||||
close(read)
|
||||
}()
|
||||
select {
|
||||
case <-read:
|
||||
case <-ctx.Done():
|
||||
return rpccapnp.Message{}, ctx.Err()
|
||||
}
|
||||
if err != nil {
|
||||
return rpccapnp.Message{}, err
|
||||
}
|
||||
return rpccapnp.ReadRootMessage(msg)
|
||||
}
|
||||
|
||||
func (s *streamTransport) Close() error {
|
||||
return s.rwc.Close()
|
||||
}
|
||||
|
||||
type writeDeadlineSetter interface {
|
||||
SetWriteDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
// dispatchSend runs in its own goroutine and sends messages on a transport.
|
||||
func (c *Conn) dispatchSend() {
|
||||
defer c.workers.Done()
|
||||
for {
|
||||
select {
|
||||
case msg := <-c.out:
|
||||
err := c.transport.SendMessage(c.bg, msg)
|
||||
if err != nil {
|
||||
c.errorf("writing %v: %v", msg.Which(), err)
|
||||
}
|
||||
case <-c.bg.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendMessage enqueues a message to be sent or returns an error if the
|
||||
// connection is shut down before the message is queued. It is safe to
|
||||
// call from multiple goroutines and does not require holding c.mu.
|
||||
func (c *Conn) sendMessage(msg rpccapnp.Message) error {
|
||||
select {
|
||||
case c.out <- msg:
|
||||
return nil
|
||||
case <-c.bg.Done():
|
||||
return ErrConnClosed
|
||||
}
|
||||
}
|
||||
|
||||
// dispatchRecv runs in its own goroutine and receives messages from a transport.
|
||||
func (c *Conn) dispatchRecv() {
|
||||
defer c.workers.Done()
|
||||
for {
|
||||
msg, err := c.transport.RecvMessage(c.bg)
|
||||
if err == nil {
|
||||
c.handleMessage(msg)
|
||||
} else if isTemporaryError(err) {
|
||||
c.errorf("read temporary error: %v", err)
|
||||
} else {
|
||||
c.shutdown(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// copyMessage clones a Cap'n Proto buffer.
|
||||
func copyMessage(msg *capnp.Message) *capnp.Message {
|
||||
n := msg.NumSegments()
|
||||
segments := make([][]byte, n)
|
||||
for i := range segments {
|
||||
s, err := msg.Segment(capnp.SegmentID(i))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
segments[i] = make([]byte, len(s.Data()))
|
||||
copy(segments[i], s.Data())
|
||||
}
|
||||
return &capnp.Message{Arena: capnp.MultiSegment(segments)}
|
||||
}
|
||||
|
||||
// copyRPCMessage clones an RPC packet.
|
||||
func copyRPCMessage(m rpccapnp.Message) rpccapnp.Message {
|
||||
mm := copyMessage(m.Segment().Message())
|
||||
rpcMsg, err := rpccapnp.ReadRootMessage(mm)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return rpcMsg
|
||||
}
|
||||
|
||||
// isTemporaryError reports whether e has a Temporary() method that
|
||||
// returns true.
|
||||
func isTemporaryError(e error) bool {
|
||||
type temp interface {
|
||||
Temporary() bool
|
||||
}
|
||||
t, ok := e.(temp)
|
||||
return ok && t.Temporary()
|
||||
}
|
Reference in New Issue
Block a user