Skip to content

Commit

Permalink
Fixes most TODO (#82)
Browse files Browse the repository at this point in the history
  • Loading branch information
emcfarlane authored Oct 12, 2023
1 parent 19f28a5 commit 5517025
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 50 deletions.
12 changes: 3 additions & 9 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,11 +384,7 @@ func (o *operation) validate(mux *Mux, codecs codecMap) error {
}

if o.server.protocol.protocol() == ProtocolREST {
// REST always uses JSON.
// TODO: Allow non-JSON encodings with REST? Would require registering content-types with codecs.
// Would also require figuring out how to (un)marshal things other than messages when a body
// path indicates a non-message field (do-able with JSON, but maybe non-starter with proto?)
//
// REST always defaults to JSON.
// NB: This is fine to set even if a custom content-type is used via
// the use of google.api.HttpBody. The actual content-type and body
// data will be written via serverBodyPreparer implementation.
Expand All @@ -403,8 +399,8 @@ func (o *operation) validate(mux *Mux, codecs codecMap) error {
if _, supportsCompression := o.methodConf.compressorNames[reqMeta.compression]; supportsCompression {
o.server.reqCompression = o.client.reqCompression
}
// else: we'll just decompress and not recompress
// TODO: should we instead pick a supported compression scheme (if there is one)?
// If the server doesn't support the compression scheme, we'll just
// decompress and not recompress.
}

o.isValid = true // Successfully validated!
Expand Down Expand Up @@ -580,8 +576,6 @@ func (o *operation) resolveMethod(mux *Mux) error {
default:
methodConf := mux.methods[uriPath]
if methodConf == nil {
// TODO: if the service is known, but the method is not, we should send to the client
// a proper RPC error (encoded per protocol handler) with an Unimplemented code.
return errNotFound
}
o.restTarget = methodConf.httpRule
Expand Down
2 changes: 1 addition & 1 deletion protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ type serverEnvelopedProtocolHandler interface {
// The given codec represents the sub-format used to send
// the request to the server (which may be used to decode
// the error).
decodeEndFromMessage(*operation, io.Reader) (responseEnd, error)
decodeEndFromMessage(*operation, *bytes.Buffer) (responseEnd, error)
}

// requestLineBuilder is an optional interface implemented by
Expand Down
9 changes: 1 addition & 8 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -569,14 +569,7 @@ func (c connectStreamServerProtocol) encodeEnvelope(env envelope) envelopeBytes
return envBytes
}

func (c connectStreamServerProtocol) decodeEndFromMessage(op *operation, reader io.Reader) (responseEnd, error) {
// TODO: buffer size limit for headers/trailers; should use http.DefaultMaxHeaderBytes if not configured
buffer := op.bufferPool.Get()
defer op.bufferPool.Put(buffer)
_, err := buffer.ReadFrom(reader)
if err != nil {
return responseEnd{}, err
}
func (c connectStreamServerProtocol) decodeEndFromMessage(_ *operation, buffer *bytes.Buffer) (responseEnd, error) {
var streamEnd connectStreamEnd
if err := json.Unmarshal(buffer.Bytes(), &streamEnd); err != nil {
return responseEnd{}, err
Expand Down
13 changes: 3 additions & 10 deletions protocol_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func (g grpcServerProtocol) encodeEnvelope(env envelope) envelopeBytes {
return envBytes
}

func (g grpcServerProtocol) decodeEndFromMessage(_ *operation, _ io.Reader) (responseEnd, error) {
func (g grpcServerProtocol) decodeEndFromMessage(_ *operation, _ *bytes.Buffer) (responseEnd, error) {
return responseEnd{}, errors.New("gRPC protocol does not allow embedding result/trailers in body")
}

Expand Down Expand Up @@ -184,7 +184,7 @@ func (g grpcWebClientProtocol) encodeEnd(op *operation, end *responseEnd, writer
buffer := op.bufferPool.Get()
defer op.bufferPool.Put(buffer)
_ = trailers.Write(buffer)
// TODO: compress?
// TODO: Send envelope compressed if possible.
env := envelope{trailer: true, length: uint32(buffer.Len())}
envBytes := g.encodeEnvelope(env)
_, _ = writer.Write(envBytes[:])
Expand Down Expand Up @@ -254,14 +254,7 @@ func (g grpcWebServerProtocol) encodeEnvelope(env envelope) envelopeBytes {
return grpcServerProtocol{}.encodeEnvelope(env)
}

func (g grpcWebServerProtocol) decodeEndFromMessage(op *operation, reader io.Reader) (responseEnd, error) {
// TODO: buffer size limit for headers/trailers; should use http.DefaultMaxHeaderBytes if not configured
buffer := op.bufferPool.Get()
defer op.bufferPool.Put(buffer)
_, err := buffer.ReadFrom(reader)
if err != nil {
return responseEnd{}, err
}
func (g grpcWebServerProtocol) decodeEndFromMessage(_ *operation, buffer *bytes.Buffer) (responseEnd, error) {
headerLines := bytes.Split(buffer.Bytes(), []byte{'\r', '\n'})
trailers := make(http.Header, len(headerLines))
for i, headerLine := range headerLines {
Expand Down
48 changes: 31 additions & 17 deletions protocol_rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,15 @@ func (r restClientProtocol) acceptsStreamType(op *operation, streamType connect.
case connect.StreamTypeClient:
return restHTTPBodyRequest(op)
case connect.StreamTypeServer:
// TODO: support server streams even when body is not google.api.HttpBody
return restHTTPBodyResponse(op)
default:
return false
}
}

func (r restClientProtocol) endMustBeInHeaders() bool {
// TODO: when we support server streams over REST, this should return false when streaming
// TODO: when we support server streams over REST, this should return
// false when streaming
return true
}

Expand All @@ -82,24 +82,23 @@ func (r restClientProtocol) extractProtocolRequestHeaders(op *operation, headers
headers.Del("Content-Type")

if timeoutStr := headers.Get("X-Server-Timeout"); timeoutStr != "" {
timeout, err := strconv.ParseFloat(timeoutStr, 64)
timeout, err := restDecodeTimeout(timeoutStr)
if err != nil {
return requestMeta{}, err
}
reqMeta.timeout = time.Duration(timeout * float64(time.Second))
reqMeta.timeout = timeout
reqMeta.hasTimeout = true
}
return reqMeta, nil
}

func (r restClientProtocol) addProtocolResponseHeaders(meta responseMeta, headers http.Header) int {
isErr := meta.end != nil && meta.end.err != nil
// TODO: this formulation might only be valid when meta.codec is JSON; support other codecs.
// Headers are only set if they are not already set, specially to allow
// for google.api.HttpBody payloads.
// Only JSON is supported for now unless using google.api.HttpBody
// payloads which override the content-type.
if headers["Content-Type"] == nil {
headers["Content-Type"] = []string{"application/" + meta.codec}
}
// TODO: Content-Encoding to compress error, too?
if !isErr && meta.compression != "" {
headers["Content-Encoding"] = []string{meta.compression}
}
Expand All @@ -126,12 +125,9 @@ func (r restClientProtocol) encodeEnd(op *operation, end *responseEnd, writer io
stat := grpcStatusFromError(cerr)
bin, err := op.client.codec.MarshalAppend(nil, stat)
if err != nil {
// TODO: This is always uses JSON whereas above we use the given codec.
// If/when we support codecs for REST other than JSON, what should
// we do here?
bin = []byte(`{"code": 13, "message": ` + strconv.Quote("failed to marshal end error: "+err.Error()) + `}`)
// Hardcode the error to be a JSON-encoded gRPC status.
bin = []byte(`{"code":13,"message":"failed to marshal end error"}`)
}
// TODO: compress?
_, _ = writer.Write(bin)
return nil
}
Expand Down Expand Up @@ -272,9 +268,8 @@ func (r restServerProtocol) addProtocolRequestHeaders(meta requestMeta, headers
if len(meta.acceptCompression) != 0 {
headers["Accept-Encoding"] = []string{strings.Join(meta.acceptCompression, ", ")}
}
if meta.timeout != 0 {
// Encode timeout as a float in seconds.
value := strconv.FormatFloat(meta.timeout.Seconds(), 'E', -1, 64)
if meta.hasTimeout {
value := restEncodeTimeout(meta.timeout)
headers["X-Server-Timeout"] = []string{value}
}
}
Expand Down Expand Up @@ -393,14 +388,33 @@ func (r restServerProtocol) requestLine(op *operation, req proto.Message) (urlPa
urlPath = path
queryParams = query.Encode()
includeBody = op.restTarget.requestBodyFields != nil // can be len(0) if body is '*'
// TODO: Should this return an error if URL (path + query string) is greater than op.methodConf.maxGetURLSz?
return urlPath, queryParams, op.restTarget.method, includeBody, nil
}

func (r restServerProtocol) String() string {
return protocolNameREST
}

// Decode timeout as a float in seconds from X-Server-Timeout header.
func restDecodeTimeout(timeout string) (time.Duration, error) {
if timeout == "" {
return 0, nil
}
val, err := strconv.ParseFloat(timeout, 64)
if err != nil {
return 0, fmt.Errorf("invalid timeout %q: %w", timeout, err)
}
return time.Duration(val * float64(time.Second)), nil
}

// Encode timeout as a float in seconds for X-Server-Timeout header.
func restEncodeTimeout(timeout time.Duration) string {
if timeout == 0 {
return ""
}
return strconv.FormatFloat(timeout.Seconds(), 'f', -1, 64)
}

func restHTTPBodyRequest(op *operation) bool {
return restIsHTTPBody(op.methodConf.descriptor.Input(), op.restTarget.requestBodyFields)
}
Expand Down
1 change: 1 addition & 0 deletions vanguard_restxrpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ func TestMux_RESTxRPC(t *testing.T) {
for key, values := range input.meta {
req.Header[key] = values
}
req.Header["X-Server-Timeout"] = []string{"30"}
if isCompressed {
req.Header["Content-Encoding"] = []string{comp.Name()}
}
Expand Down
4 changes: 0 additions & 4 deletions vanguard_rpcxrpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,10 +379,6 @@ func TestMux_RPCxRPC(t *testing.T) {
},
},
},
// TODO: Add more tests -- more permutations to catch things like trailers-only responses in gRPC,
// empty client streams, empty server streams
// TODO: Exercise Connect GET for unary operations with Connect client
// TODO: Verify timeouts are propagated correctly
}
for _, opts := range testOpts {
opts := opts
Expand Down
56 changes: 55 additions & 1 deletion vanguard_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ import (
"google.golang.org/protobuf/types/known/emptypb"
)

const (
defaultTestTimeout = 30 * time.Second
)

func TestMux_BufferTooLargeFails(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -628,9 +632,12 @@ func TestMux_ConnectGetUsesPostIfRequestTooLarge(t *testing.T) {
connect.WithHTTPGetMaxURLSize(512, false),
connect.WithSendGzip(),
)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

req := connect.NewRequest(largeRequest)
req.Header().Set("Test", t.Name()) // must set this for interceptor to work
_, err := client.GetBook(context.Background(), req)
_, err := client.GetBook(ctx, req)
// No error means it made through above interceptor unscathed
// (so server handler got a POST).
require.NoError(t, err)
Expand Down Expand Up @@ -1611,6 +1618,9 @@ func TestRuleSelector(t *testing.T) {
}))

ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, defaultTestTimeout)
defer cancel()

req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/v1/selector/shelves/123/books/456", http.NoBody)
require.NoError(t, err)
req.Header.Set("Message", "hello")
Expand Down Expand Up @@ -1792,6 +1802,9 @@ func (i *testInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
ctx context.Context,
req connect.AnyRequest,
) (_ connect.AnyResponse, resultError error) {
if err := assertTestTimeoutEncoded(ctx); err != nil {
return nil, err
}
val := req.Header().Get("test")
if val == "" {
return next(ctx, req)
Expand Down Expand Up @@ -1874,6 +1887,9 @@ func (i *testInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc
ctx context.Context,
conn connect.StreamingHandlerConn,
) (resultError error) {
if err := assertTestTimeoutEncoded(ctx); err != nil {
return err
}
val := conn.RequestHeader().Get("test")
if val == "" {
return next(ctx, conn)
Expand Down Expand Up @@ -2057,6 +2073,19 @@ func (i *testInterceptor) restUnaryHandler(
http.Error(rsp, "invalid test header", http.StatusInternalServerError)
return
}
timeoutStr := req.Header.Get("X-Server-Timeout")
timeout, err := restDecodeTimeout(timeoutStr)
if err != nil {
http.Error(rsp, "invalid timeout header", http.StatusInternalServerError)
return
}
ctx, cancel := context.WithTimeout(req.Context(), timeout)
defer cancel()
if err := assertTestTimeoutEncoded(ctx); err != nil {
http.Error(rsp, err.Error(), http.StatusInternalServerError)
return
}
req = req.WithContext(ctx)
if err := handler(stream, rsp, req); err != nil {
stream.T.Error(err)
http.Error(rsp, err.Error(), http.StatusInternalServerError)
Expand Down Expand Up @@ -2178,6 +2207,8 @@ func outputFromUnary[Req, Resp any](
headers http.Header,
reqs []proto.Message,
) (http.Header, []proto.Message, http.Header, error) {
ctx, cancel := context.WithTimeout(ctx, defaultTestTimeout)
defer cancel()
if len(reqs) != 1 {
return nil, nil, nil, fmt.Errorf("unary method takes exactly 1 request but got %d", len(reqs))
}
Expand All @@ -2200,6 +2231,8 @@ func outputFromServerStream[Req, Resp any](
headers http.Header,
reqs []proto.Message,
) (http.Header, []proto.Message, http.Header, error) {
ctx, cancel := context.WithTimeout(ctx, defaultTestTimeout)
defer cancel()
if len(reqs) != 1 {
return nil, nil, nil, fmt.Errorf("unary method takes exactly 1 request but got %d", len(reqs))
}
Expand All @@ -2226,6 +2259,8 @@ func outputFromClientStream[Req, Resp any](
headers http.Header,
reqs []proto.Message,
) (http.Header, []proto.Message, http.Header, error) {
ctx, cancel := context.WithTimeout(ctx, defaultTestTimeout)
defer cancel()
str := method(ctx)
for k, v := range headers {
str.RequestHeader()[k] = v
Expand Down Expand Up @@ -2255,6 +2290,8 @@ func outputFromBidiStream[Req, Resp any](
headers http.Header,
reqs []proto.Message,
) (http.Header, []proto.Message, http.Header, error) {
ctx, cancel := context.WithTimeout(ctx, defaultTestTimeout)
defer cancel()
str := method(ctx)
defer func() {
_ = str.CloseResponse()
Expand Down Expand Up @@ -2563,3 +2600,20 @@ func newConnectError(code connect.Code, msg string) *connect.Error {
err.Meta()
return err
}

// assert a 30 second timeout has been set.
func assertTestTimeoutEncoded(ctx context.Context) error {
now := time.Now()
deadline, ok := ctx.Deadline()
if !ok {
return errors.New("context should have deadline")
}
if deadline.After(now.Add(defaultTestTimeout)) {
return errors.New("context deadline should be 30 seconds")
}
// Allow a little bit of slop.
if deadline.Before(now.Add(defaultTestTimeout - 5*time.Second)) {
return errors.New("context deadline should be at least 20 seconds")
}
return nil
}

0 comments on commit 5517025

Please sign in to comment.