diff --git a/helper/forwarding/util.go b/helper/forwarding/util.go index de92639afbee..c1d5160618a0 100644 --- a/helper/forwarding/util.go +++ b/helper/forwarding/util.go @@ -4,9 +4,7 @@ import ( "bytes" "crypto/tls" "crypto/x509" - "errors" "io" - "io/ioutil" "net/http" "net/url" "os" @@ -60,19 +58,7 @@ func GenerateForwardedHTTPRequest(req *http.Request, addr string) (*http.Request func GenerateForwardedRequest(req *http.Request) (*Request, error) { var reader io.Reader = req.Body - ctx := req.Context() - maxRequestSize := ctx.Value("max_request_size") - if maxRequestSize != nil { - max, ok := maxRequestSize.(int64) - if !ok { - return nil, errors.New("could not parse max_request_size from request context") - } - if max > 0 { - reader = io.LimitReader(req.Body, max) - } - } - - body, err := ioutil.ReadAll(reader) + body, err := io.ReadAll(reader) if err != nil { return nil, err } diff --git a/http/handler.go b/http/handler.go index 3182dc352f42..b07b1254ddb5 100644 --- a/http/handler.go +++ b/http/handler.go @@ -226,12 +226,13 @@ func handler(props *vault.HandlerProperties) http.Handler { corsWrappedHandler := wrapCORSHandler(helpWrappedHandler, core) quotaWrappedHandler := rateLimitQuotaWrapping(corsWrappedHandler, core) genericWrappedHandler := genericWrapping(core, quotaWrappedHandler, props) + wrappedHandler := wrapMaxRequestSizeHandler(genericWrappedHandler, props) // Wrap the handler with PrintablePathCheckHandler to check for non-printable // characters in the request path. - printablePathCheckHandler := genericWrappedHandler + printablePathCheckHandler := wrappedHandler if !props.DisablePrintableCheck { - printablePathCheckHandler = cleanhttp.PrintablePathCheckHandler(genericWrappedHandler, nil) + printablePathCheckHandler = cleanhttp.PrintablePathCheckHandler(wrappedHandler, nil) } return printablePathCheckHandler @@ -310,18 +311,12 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler { // are performed. func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerProperties) http.Handler { var maxRequestDuration time.Duration - var maxRequestSize int64 if props.ListenerConfig != nil { maxRequestDuration = props.ListenerConfig.MaxRequestDuration - maxRequestSize = props.ListenerConfig.MaxRequestSize } if maxRequestDuration == 0 { maxRequestDuration = vault.DefaultMaxRequestDuration } - if maxRequestSize == 0 { - maxRequestSize = DefaultMaxRequestSize - } - // Swallow this error since we don't want to pollute the logs and we also don't want to // return an HTTP error here. This information is best effort. hostname, _ := os.Hostname() @@ -355,11 +350,6 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr } else { ctx, cancelFunc = context.WithTimeout(ctx, maxRequestDuration) } - // if maxRequestSize < 0, no need to set context value - // Add a size limiter if desired - if maxRequestSize > 0 { - ctx = context.WithValue(ctx, "max_request_size", maxRequestSize) - } ctx = context.WithValue(ctx, "original_request_path", r.URL.Path) r = r.WithContext(ctx) r = r.WithContext(namespace.ContextWithNamespace(r.Context(), namespace.RootNamespace)) @@ -703,25 +693,7 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter, // Limit the maximum number of bytes to MaxRequestSize to protect // against an indefinite amount of data being read. reader := r.Body - ctx := r.Context() - maxRequestSize := ctx.Value("max_request_size") - if maxRequestSize != nil { - max, ok := maxRequestSize.(int64) - if !ok { - return nil, errors.New("could not parse max_request_size from request context") - } - if max > 0 { - // MaxBytesReader won't do all the internal stuff it must unless it's - // given a ResponseWriter that implements the internal http interface - // requestTooLarger. So we let it have access to the underlying - // ResponseWriter. - inw := w - if myw, ok := inw.(logical.WrappingResponseWriter); ok { - inw = myw.Wrapped() - } - reader = http.MaxBytesReader(inw, r.Body, max) - } - } + var origBody io.ReadWriter if perfStandby { // Since we're checking PerfStandby here we key on origBody being nil @@ -743,16 +715,6 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter, // // A nil map will be returned if the format is empty or invalid. func parseFormRequest(r *http.Request) (map[string]interface{}, error) { - maxRequestSize := r.Context().Value("max_request_size") - if maxRequestSize != nil { - max, ok := maxRequestSize.(int64) - if !ok { - return nil, errors.New("could not parse max_request_size from request context") - } - if max > 0 { - r.Body = ioutil.NopCloser(io.LimitReader(r.Body, max)) - } - } if err := r.ParseForm(); err != nil { return nil, err } diff --git a/http/handler_test.go b/http/handler_test.go index 49565b41e235..a3287ce7750b 100644 --- a/http/handler_test.go +++ b/http/handler_test.go @@ -1,6 +1,7 @@ package http import ( + "bytes" "context" "crypto/tls" "encoding/json" @@ -11,6 +12,7 @@ import ( "net/textproto" "net/url" "reflect" + "runtime" "strings" "testing" @@ -18,9 +20,11 @@ import ( "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/versions" + "github.com/hashicorp/vault/internalshared/configutil" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault" + "github.com/stretchr/testify/require" ) func TestHandler_parseMFAHandler(t *testing.T) { @@ -884,3 +888,59 @@ func TestHandler_Parse_Form(t *testing.T) { t.Fatal(diff) } } + +// TestHandler_MaxRequestSize verifies that a request larger than the +// MaxRequestSize fails +func TestHandler_MaxRequestSize(t *testing.T) { + t.Parallel() + cluster := vault.NewTestCluster(t, &vault.CoreConfig{}, &vault.TestClusterOptions{ + DefaultHandlerProperties: vault.HandlerProperties{ + ListenerConfig: &configutil.Listener{ + MaxRequestSize: 1024, + }, + }, + HandlerFunc: Handler, + NumCores: 1, + }) + cluster.Start() + defer cluster.Cleanup() + + client := cluster.Cores[0].Client + _, err := client.KVv2("secret").Put(context.Background(), "foo", map[string]interface{}{ + "bar": strings.Repeat("a", 1025), + }) + + require.ErrorContains(t, err, "error parsing JSON") +} + +// TestHandler_MaxRequestSize_Memory sets the max request size to 1024 bytes, +// and creates a 1MB request. The test verifies that less than 1MB of memory is +// allocated when the request is sent. This test shouldn't be run in parallel, +// because it modifies GOMAXPROCS +func TestHandler_MaxRequestSize_Memory(t *testing.T) { + ln, addr := TestListener(t) + core, _, token := vault.TestCoreUnsealed(t) + TestServerWithListenerAndProperties(t, ln, addr, core, &vault.HandlerProperties{ + Core: core, + ListenerConfig: &configutil.Listener{ + Address: addr, + MaxRequestSize: 1024, + }, + }) + defer ln.Close() + + data := bytes.Repeat([]byte{0x1}, 1024*1024) + + req, err := http.NewRequest("POST", addr+"/v1/sys/unseal", bytes.NewReader(data)) + require.NoError(t, err) + req.Header.Set(consts.AuthHeaderName, token) + + client := cleanhttp.DefaultClient() + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) + var start, end runtime.MemStats + runtime.GC() + runtime.ReadMemStats(&start) + client.Do(req) + runtime.ReadMemStats(&end) + require.Less(t, end.TotalAlloc-start.TotalAlloc, uint64(1024*1024)) +} diff --git a/http/util.go b/http/util.go index b8430479c840..488ab175d25d 100644 --- a/http/util.go +++ b/http/util.go @@ -3,13 +3,13 @@ package http import ( "bytes" "context" - "errors" "fmt" - "io/ioutil" + "io" "net" "net/http" "strings" + "github.com/hashicorp/go-multierror" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/helper/namespace" @@ -35,6 +35,27 @@ var ( adjustResponse = func(core *vault.Core, w http.ResponseWriter, req *logical.Request) {} ) +func wrapMaxRequestSizeHandler(handler http.Handler, props *vault.HandlerProperties) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var maxRequestSize int64 + if props.ListenerConfig != nil { + maxRequestSize = props.ListenerConfig.MaxRequestSize + } + if maxRequestSize == 0 { + maxRequestSize = DefaultMaxRequestSize + } + ctx := r.Context() + originalBody := r.Body + if maxRequestSize > 0 { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + } + ctx = logical.CreateContextOriginalBody(ctx, originalBody) + r = r.WithContext(ctx) + + handler.ServeHTTP(w, r) + }) +} + func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ns, err := namespace.FromContext(r.Context()) @@ -53,14 +74,6 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler } mountPath := strings.TrimPrefix(core.MatchingMount(r.Context(), path), ns.Path) - // Clone body, so we do not close the request body reader - bodyBytes, err := ioutil.ReadAll(r.Body) - if err != nil { - respondError(w, http.StatusInternalServerError, errors.New("failed to read request body")) - return - } - r.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes)) - quotaReq := "as.Request{ Type: quotas.TypeRateLimit, Path: path, @@ -80,7 +93,18 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler // If any role-based quotas are enabled for this namespace/mount, just // do the role resolution once here. if requiresResolveRole { - role := core.DetermineRoleFromLoginRequestFromBytes(r.Context(), mountPath, bodyBytes) + buf := bytes.Buffer{} + teeReader := io.TeeReader(r.Body, &buf) + role := core.DetermineRoleFromLoginRequestFromReader(r.Context(), mountPath, teeReader) + + // Reset the body if it was read + if buf.Len() > 0 { + r.Body = io.NopCloser(&buf) + originalBody, ok := logical.ContextOriginalBodyValue(r.Context()) + if ok { + r = r.WithContext(logical.CreateContextOriginalBody(r.Context(), newMultiReaderCloser(&buf, originalBody))) + } + } // add an entry to the context to prevent recalculating request role unnecessarily r = r.WithContext(context.WithValue(r.Context(), logical.CtxKeyRequestRole{}, role)) quotaReq.Role = role @@ -139,3 +163,25 @@ func parseRemoteIPAddress(r *http.Request) string { return ip } + +type multiReaderCloser struct { + readers []io.Reader + io.Reader +} + +func newMultiReaderCloser(readers ...io.Reader) *multiReaderCloser { + return &multiReaderCloser{ + readers: readers, + Reader: io.MultiReader(readers...), + } +} + +func (m *multiReaderCloser) Close() error { + var err error + for _, r := range m.readers { + if c, ok := r.(io.Closer); ok { + err = multierror.Append(err, c.Close()) + } + } + return err +} diff --git a/sdk/logical/request.go b/sdk/logical/request.go index 0d20a341ecc1..1f2ca53182a8 100644 --- a/sdk/logical/request.go +++ b/sdk/logical/request.go @@ -3,6 +3,7 @@ package logical import ( "context" "fmt" + "io" "net/http" "strings" "time" @@ -398,3 +399,14 @@ type CtxKeyRequestRole struct{} func (c CtxKeyRequestRole) String() string { return "request-role" } + +type ctxKeyOriginalBody struct{} + +func ContextOriginalBodyValue(ctx context.Context) (io.ReadCloser, bool) { + value, ok := ctx.Value(ctxKeyOriginalBody{}).(io.ReadCloser) + return value, ok +} + +func CreateContextOriginalBody(parent context.Context, body io.ReadCloser) context.Context { + return context.WithValue(parent, ctxKeyOriginalBody{}, body) +} diff --git a/vault/core.go b/vault/core.go index eadd1ccb4d3a..357a0232e42a 100644 --- a/vault/core.go +++ b/vault/core.go @@ -3888,22 +3888,24 @@ func (c *Core) LoadNodeID() (string, error) { return hostname, nil } -// DetermineRoleFromLoginRequestFromBytes will determine the role that should be applied to a quota for a given -// login request, accepting a byte payload -func (c *Core) DetermineRoleFromLoginRequestFromBytes(ctx context.Context, mountPoint string, payload []byte) string { - data := make(map[string]interface{}) - err := jsonutil.DecodeJSON(payload, &data) - if err != nil { - // Cannot discern a role from a request we cannot parse +// DetermineRoleFromLoginRequest will determine the role that should be applied to a quota for a given +// login request +func (c *Core) DetermineRoleFromLoginRequest(ctx context.Context, mountPoint string, data map[string]interface{}) string { + c.authLock.RLock() + defer c.authLock.RUnlock() + matchingBackend := c.router.MatchingBackend(ctx, mountPoint) + if matchingBackend == nil || matchingBackend.Type() != logical.TypeCredential { + // Role based quotas do not apply to this request return "" } - - return c.DetermineRoleFromLoginRequest(ctx, mountPoint, data) + return c.doResolveRoleLocked(ctx, mountPoint, matchingBackend, data) } -// DetermineRoleFromLoginRequest will determine the role that should be applied to a quota for a given -// login request -func (c *Core) DetermineRoleFromLoginRequest(ctx context.Context, mountPoint string, data map[string]interface{}) string { +// DetermineRoleFromLoginRequestFromReader will determine the role that should +// be applied to a quota for a given login request. The reader will only be +// consumed if the matching backend for the mount point exists and is a secret +// backend +func (c *Core) DetermineRoleFromLoginRequestFromReader(ctx context.Context, mountPoint string, reader io.Reader) string { c.authLock.RLock() defer c.authLock.RUnlock() matchingBackend := c.router.MatchingBackend(ctx, mountPoint) @@ -3912,6 +3914,17 @@ func (c *Core) DetermineRoleFromLoginRequest(ctx context.Context, mountPoint str return "" } + data := make(map[string]interface{}) + err := jsonutil.DecodeJSONFromReader(reader, &data) + if err != nil { + return "" + } + return c.doResolveRoleLocked(ctx, mountPoint, matchingBackend, data) +} + +// doResolveRoleLocked does a login and resolve role request on the matching +// backend. Callers should have a read lock on c.authLock +func (c *Core) doResolveRoleLocked(ctx context.Context, mountPoint string, matchingBackend logical.Backend, data map[string]interface{}) string { resp, err := matchingBackend.HandleRequest(ctx, &logical.Request{ MountPoint: mountPoint, Path: "login", diff --git a/vault/logical_system_raft.go b/vault/logical_system_raft.go index 40ac0bbcc300..6ecc0cd7e29b 100644 --- a/vault/logical_system_raft.go +++ b/vault/logical_system_raft.go @@ -562,7 +562,8 @@ func (b *SystemBackend) handleStorageRaftSnapshotWrite(force bool) framework.Ope if !ok { return logical.ErrorResponse("raft storage is not in use"), logical.ErrInvalidRequest } - if req.HTTPRequest == nil || req.HTTPRequest.Body == nil { + body, ok := logical.ContextOriginalBodyValue(ctx) + if !ok { return nil, errors.New("no reader for request") } @@ -575,7 +576,7 @@ func (b *SystemBackend) handleStorageRaftSnapshotWrite(force bool) framework.Ope // don't have to hold the full snapshot in memory. We also want to do // the restore in two parts so we can restore the snapshot while the // stateLock is write locked. - snapFile, cleanup, metadata, err := raftStorage.WriteSnapshotToTemp(req.HTTPRequest.Body, access) + snapFile, cleanup, metadata, err := raftStorage.WriteSnapshotToTemp(body, access) switch { case err == nil: case strings.Contains(err.Error(), "failed to open the sealed hashes"): diff --git a/vault/request_handling.go b/vault/request_handling.go index 9e1df4b13b55..c057c3bbc19c 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -561,6 +561,11 @@ func (c *Core) switchedLockHandleRequest(httpCtx context.Context, req *logical.R if ok { ctx = context.WithValue(ctx, logical.CtxKeyRequestRole{}, requestRole) } + + body, ok := logical.ContextOriginalBodyValue(httpCtx) + if ok { + ctx = logical.CreateContextOriginalBody(ctx, body) + } resp, err = c.handleCancelableRequest(ctx, req) req.SetTokenEntry(nil) cancel()