diff --git a/config.go b/config.go index 097419635..40c6d3f78 100644 --- a/config.go +++ b/config.go @@ -16,7 +16,10 @@ package ttrpc -import "errors" +import ( + "context" + "errors" +) type serverConfig struct { handshaker Handshaker @@ -42,11 +45,40 @@ func WithServerHandshaker(handshaker Handshaker) ServerOpt { // WithUnaryServerInterceptor sets the provided interceptor on the server func WithUnaryServerInterceptor(i UnaryServerInterceptor) ServerOpt { + return func(c *serverConfig) error { + if c.interceptor == nil { + c.interceptor = i + } else { + WithChainUnaryServerInterceptor(i)(c) + } + return nil + } +} + +// WithChainUnaryServerInterceptor sets the provided chain of server interceptors +func WithChainUnaryServerInterceptor(interceptors ...UnaryServerInterceptor) ServerOpt { return func(c *serverConfig) error { if c.interceptor != nil { - return errors.New("only one interceptor allowed per server") + interceptors = append([]UnaryServerInterceptor{c.interceptor}, interceptors...) + } + c.interceptor = func( + ctx context.Context, + unmarshal Unmarshaler, + info *UnaryServerInfo, + method Method) (interface{}, error) { + return interceptors[0](ctx, unmarshal, info, + chainUnaryServerInterceptors(info, method, interceptors[1:])) } - c.interceptor = i return nil } } + +func chainUnaryServerInterceptors(info *UnaryServerInfo, method Method, interceptors []UnaryServerInterceptor) Method { + if len(interceptors) == 0 { + return method + } + return func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) { + return interceptors[0](ctx, unmarshal, info, + chainUnaryServerInterceptors(info, method, interceptors[1:])) + } +} diff --git a/interceptor_test.go b/interceptor_test.go index 9e72528c6..1d1ede258 100644 --- a/interceptor_test.go +++ b/interceptor_test.go @@ -278,3 +278,67 @@ func TestChainUnaryServerInterceptor(t *testing.T) { strings.Join(recorded, " "), strings.Join(expected, " ")) } } + +func TestImplicitChainUnaryServerInterceptor(t *testing.T) { + var ( + orderIdx = 0 + recorded = []string{} + intercept = func(idx int, tag string) UnaryServerInterceptor { + return func(ctx context.Context, unmarshal Unmarshaler, _ *UnaryServerInfo, method Method) (interface{}, error) { + if orderIdx != idx { + t.Fatalf("unexpected interceptor invocation order (%d != %d)", orderIdx, idx) + } + recorded = append(recorded, tag) + orderIdx++ + return method(ctx, unmarshal) + } + } + + ctx = context.Background() + server = mustServer(t)(NewServer( + WithUnaryServerInterceptor( + intercept(0, "seen it"), + ), + WithUnaryServerInterceptor( + intercept(1, "been"), + ), + WithUnaryServerInterceptor( + intercept(2, "there"), + ), + WithChainUnaryServerInterceptor( + intercept(3, "done"), + intercept(4, "that"), + ), + )) + expected = []string{ + "seen it", + "been", + "there", + "done", + "that", + } + testImpl = &testingServer{} + addr, listener = newTestListener(t) + client, cleanup = newTestClient(t, addr) + ) + + defer listener.Close() + defer cleanup() + + registerTestingService(server, testImpl) + + go server.Serve(ctx, listener) + defer server.Shutdown(ctx) + + tp := &internal.TestPayload{ + Foo: strings.Repeat("a", 16), + } + if err := client.Call(ctx, serviceName, "Test", tp, tp); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !reflect.DeepEqual(recorded, expected) { + t.Fatalf("unexpected ttrpc chained server unary interceptor order (%s != %s)", + strings.Join(recorded, " "), strings.Join(expected, " ")) + } +}