From 57744ebbe8f95ee65855f4611b93dd0f2c7b18ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=2E=20Efe=20=C3=87etin?= Date: Wed, 25 Dec 2024 14:53:14 +0300 Subject: [PATCH 01/20] :bug: bug: fix EnableSplittingOnParsers is not functional (#3231) * :bug: bug: fix EnableSplittingOnParsers is not functional * remove wrong testcase * add support for external xml decoders * improve test coverage * fix linter * update * add reset methods * improve test coverage * merge Form and MultipartForm methods * fix linter * split reset and putting steps * fix linter --- app.go | 10 +++ bind.go | 109 +++++++++++++++++++---- bind_test.go | 38 +++++--- binder/binder.go | 79 ++++++++++++++--- binder/binder_test.go | 28 ++++++ binder/cbor.go | 17 ++-- binder/cbor_test.go | 92 ++++++++++++++++++++ binder/cookie.go | 19 ++-- binder/cookie_test.go | 90 +++++++++++++++++++ binder/form.go | 32 +++++-- binder/form_test.go | 174 +++++++++++++++++++++++++++++++++++++ binder/header.go | 17 ++-- binder/header_test.go | 88 +++++++++++++++++++ binder/json.go | 17 ++-- binder/json_test.go | 69 +++++++++++++++ binder/mapping.go | 31 ++++--- binder/mapping_test.go | 80 +++++++++++++++++ binder/query.go | 19 ++-- binder/query_test.go | 87 +++++++++++++++++++ binder/resp_header.go | 17 ++-- binder/resp_header_test.go | 79 +++++++++++++++++ binder/uri.go | 11 ++- binder/uri_test.go | 77 ++++++++++++++++ binder/xml.go | 20 +++-- binder/xml_test.go | 135 ++++++++++++++++++++++++++++ ctx_test.go | 4 +- docs/api/bind.md | 46 ++-------- docs/api/fiber.md | 1 + docs/whats_new.md | 1 + redirect.go | 4 +- 30 files changed, 1339 insertions(+), 152 deletions(-) create mode 100644 binder/binder_test.go create mode 100644 binder/cbor_test.go create mode 100644 binder/cookie_test.go create mode 100644 binder/form_test.go create mode 100644 binder/header_test.go create mode 100644 binder/json_test.go create mode 100644 binder/query_test.go create mode 100644 binder/resp_header_test.go create mode 100644 binder/uri_test.go create mode 100644 binder/xml_test.go diff --git a/app.go b/app.go index 7f9193a1a1..5e5475b5f1 100644 --- a/app.go +++ b/app.go @@ -341,6 +341,13 @@ type Config struct { //nolint:govet // Aligning the struct fields is not necessa // Default: xml.Marshal XMLEncoder utils.XMLMarshal `json:"-"` + // XMLDecoder set by an external client of Fiber it will use the provided implementation of a + // XMLUnmarshal + // + // Allowing for flexibility in using another XML library for decoding + // Default: xml.Unmarshal + XMLDecoder utils.XMLUnmarshal `json:"-"` + // If you find yourself behind some sort of proxy, like a load balancer, // then certain header information may be sent to you using special X-Forwarded-* headers or the Forwarded header. // For example, the Host HTTP header is usually used to return the requested host. @@ -560,6 +567,9 @@ func New(config ...Config) *App { if app.config.XMLEncoder == nil { app.config.XMLEncoder = xml.Marshal } + if app.config.XMLDecoder == nil { + app.config.XMLDecoder = xml.Unmarshal + } if len(app.config.RequestMethods) == 0 { app.config.RequestMethods = DefaultMethods } diff --git a/bind.go b/bind.go index 5af83743a0..13d9d3675e 100644 --- a/bind.go +++ b/bind.go @@ -77,7 +77,16 @@ func (b *Bind) Custom(name string, dest any) error { // Header binds the request header strings into the struct, map[string]string and map[string][]string. func (b *Bind) Header(out any) error { - if err := b.returnErr(binder.HeaderBinder.Bind(b.ctx.Request(), out)); err != nil { + bind := binder.GetFromThePool[*binder.HeaderBinding](&binder.HeaderBinderPool) + bind.EnableSplitting = b.ctx.App().config.EnableSplittingOnParsers + + // Reset & put binder + defer func() { + bind.Reset() + binder.PutToThePool(&binder.HeaderBinderPool, bind) + }() + + if err := b.returnErr(bind.Bind(b.ctx.Request(), out)); err != nil { return err } @@ -86,7 +95,16 @@ func (b *Bind) Header(out any) error { // RespHeader binds the response header strings into the struct, map[string]string and map[string][]string. func (b *Bind) RespHeader(out any) error { - if err := b.returnErr(binder.RespHeaderBinder.Bind(b.ctx.Response(), out)); err != nil { + bind := binder.GetFromThePool[*binder.RespHeaderBinding](&binder.RespHeaderBinderPool) + bind.EnableSplitting = b.ctx.App().config.EnableSplittingOnParsers + + // Reset & put binder + defer func() { + bind.Reset() + binder.PutToThePool(&binder.RespHeaderBinderPool, bind) + }() + + if err := b.returnErr(bind.Bind(b.ctx.Response(), out)); err != nil { return err } @@ -96,7 +114,16 @@ func (b *Bind) RespHeader(out any) error { // Cookie binds the request cookie strings into the struct, map[string]string and map[string][]string. // NOTE: If your cookie is like key=val1,val2; they'll be binded as an slice if your map is map[string][]string. Else, it'll use last element of cookie. func (b *Bind) Cookie(out any) error { - if err := b.returnErr(binder.CookieBinder.Bind(b.ctx.RequestCtx(), out)); err != nil { + bind := binder.GetFromThePool[*binder.CookieBinding](&binder.CookieBinderPool) + bind.EnableSplitting = b.ctx.App().config.EnableSplittingOnParsers + + // Reset & put binder + defer func() { + bind.Reset() + binder.PutToThePool(&binder.CookieBinderPool, bind) + }() + + if err := b.returnErr(bind.Bind(&b.ctx.RequestCtx().Request, out)); err != nil { return err } @@ -105,7 +132,16 @@ func (b *Bind) Cookie(out any) error { // Query binds the query string into the struct, map[string]string and map[string][]string. func (b *Bind) Query(out any) error { - if err := b.returnErr(binder.QueryBinder.Bind(b.ctx.RequestCtx(), out)); err != nil { + bind := binder.GetFromThePool[*binder.QueryBinding](&binder.QueryBinderPool) + bind.EnableSplitting = b.ctx.App().config.EnableSplittingOnParsers + + // Reset & put binder + defer func() { + bind.Reset() + binder.PutToThePool(&binder.QueryBinderPool, bind) + }() + + if err := b.returnErr(bind.Bind(&b.ctx.RequestCtx().Request, out)); err != nil { return err } @@ -114,7 +150,16 @@ func (b *Bind) Query(out any) error { // JSON binds the body string into the struct. func (b *Bind) JSON(out any) error { - if err := b.returnErr(binder.JSONBinder.Bind(b.ctx.Body(), b.ctx.App().Config().JSONDecoder, out)); err != nil { + bind := binder.GetFromThePool[*binder.JSONBinding](&binder.JSONBinderPool) + bind.JSONDecoder = b.ctx.App().Config().JSONDecoder + + // Reset & put binder + defer func() { + bind.Reset() + binder.PutToThePool(&binder.JSONBinderPool, bind) + }() + + if err := b.returnErr(bind.Bind(b.ctx.Body(), out)); err != nil { return err } @@ -123,7 +168,16 @@ func (b *Bind) JSON(out any) error { // CBOR binds the body string into the struct. func (b *Bind) CBOR(out any) error { - if err := b.returnErr(binder.CBORBinder.Bind(b.ctx.Body(), b.ctx.App().Config().CBORDecoder, out)); err != nil { + bind := binder.GetFromThePool[*binder.CBORBinding](&binder.CBORBinderPool) + bind.CBORDecoder = b.ctx.App().Config().CBORDecoder + + // Reset & put binder + defer func() { + bind.Reset() + binder.PutToThePool(&binder.CBORBinderPool, bind) + }() + + if err := b.returnErr(bind.Bind(b.ctx.Body(), out)); err != nil { return err } return b.validateStruct(out) @@ -131,7 +185,16 @@ func (b *Bind) CBOR(out any) error { // XML binds the body string into the struct. func (b *Bind) XML(out any) error { - if err := b.returnErr(binder.XMLBinder.Bind(b.ctx.Body(), out)); err != nil { + bind := binder.GetFromThePool[*binder.XMLBinding](&binder.XMLBinderPool) + bind.XMLDecoder = b.ctx.App().config.XMLDecoder + + // Reset & put binder + defer func() { + bind.Reset() + binder.PutToThePool(&binder.XMLBinderPool, bind) + }() + + if err := b.returnErr(bind.Bind(b.ctx.Body(), out)); err != nil { return err } @@ -139,8 +202,20 @@ func (b *Bind) XML(out any) error { } // Form binds the form into the struct, map[string]string and map[string][]string. +// If Content-Type is "application/x-www-form-urlencoded" or "multipart/form-data", it will bind the form values. +// +// Binding multipart files is not supported yet. func (b *Bind) Form(out any) error { - if err := b.returnErr(binder.FormBinder.Bind(b.ctx.RequestCtx(), out)); err != nil { + bind := binder.GetFromThePool[*binder.FormBinding](&binder.FormBinderPool) + bind.EnableSplitting = b.ctx.App().config.EnableSplittingOnParsers + + // Reset & put binder + defer func() { + bind.Reset() + binder.PutToThePool(&binder.FormBinderPool, bind) + }() + + if err := b.returnErr(bind.Bind(&b.ctx.RequestCtx().Request, out)); err != nil { return err } @@ -149,16 +224,14 @@ func (b *Bind) Form(out any) error { // URI binds the route parameters into the struct, map[string]string and map[string][]string. func (b *Bind) URI(out any) error { - if err := b.returnErr(binder.URIBinder.Bind(b.ctx.Route().Params, b.ctx.Params, out)); err != nil { - return err - } + bind := binder.GetFromThePool[*binder.URIBinding](&binder.URIBinderPool) - return b.validateStruct(out) -} + // Reset & put binder + defer func() { + binder.PutToThePool(&binder.URIBinderPool, bind) + }() -// MultipartForm binds the multipart form into the struct, map[string]string and map[string][]string. -func (b *Bind) MultipartForm(out any) error { - if err := b.returnErr(binder.FormBinder.BindMultipart(b.ctx.RequestCtx(), out)); err != nil { + if err := b.returnErr(bind.Bind(b.ctx.Route().Params, b.ctx.Params, out)); err != nil { return err } @@ -193,10 +266,8 @@ func (b *Bind) Body(out any) error { return b.XML(out) case MIMEApplicationCBOR: return b.CBOR(out) - case MIMEApplicationForm: + case MIMEApplicationForm, MIMEMultipartForm: return b.Form(out) - case MIMEMultipartForm: - return b.MultipartForm(out) } // No suitable content type found diff --git a/bind_test.go b/bind_test.go index 55d2dd75e9..52c9004c61 100644 --- a/bind_test.go +++ b/bind_test.go @@ -32,7 +32,9 @@ func Test_returnErr(t *testing.T) { // go test -run Test_Bind_Query -v func Test_Bind_Query(t *testing.T) { t.Parallel() - app := New() + app := New(Config{ + EnableSplittingOnParsers: true, + }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) type Query struct { @@ -111,7 +113,9 @@ func Test_Bind_Query(t *testing.T) { func Test_Bind_Query_Map(t *testing.T) { t.Parallel() - app := New() + app := New(Config{ + EnableSplittingOnParsers: true, + }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetBody([]byte(``)) @@ -318,13 +322,13 @@ func Test_Bind_Header(t *testing.T) { c.Request().Header.Add("Hobby", "golang,fiber") q := new(Header) require.NoError(t, c.Bind().Header(q)) - require.Len(t, q.Hobby, 2) + require.Len(t, q.Hobby, 1) c.Request().Header.Del("hobby") c.Request().Header.Add("Hobby", "golang,fiber,go") q = new(Header) require.NoError(t, c.Bind().Header(q)) - require.Len(t, q.Hobby, 3) + require.Len(t, q.Hobby, 1) empty := new(Header) c.Request().Header.Del("hobby") @@ -357,7 +361,7 @@ func Test_Bind_Header(t *testing.T) { require.Equal(t, "go,fiber", h2.Hobby) require.True(t, h2.Bool) require.Equal(t, "Jane Doe", h2.Name) // check value get overwritten - require.Equal(t, []string{"milo", "coke", "pepsi"}, h2.FavouriteDrinks) + require.Equal(t, []string{"milo,coke,pepsi"}, h2.FavouriteDrinks) var nilSlice []string require.Equal(t, nilSlice, h2.Empty) require.Equal(t, []string{""}, h2.Alloc) @@ -386,13 +390,13 @@ func Test_Bind_Header_Map(t *testing.T) { c.Request().Header.Add("Hobby", "golang,fiber") q := make(map[string][]string, 0) require.NoError(t, c.Bind().Header(&q)) - require.Len(t, q["Hobby"], 2) + require.Len(t, q["Hobby"], 1) c.Request().Header.Del("hobby") c.Request().Header.Add("Hobby", "golang,fiber,go") q = make(map[string][]string, 0) require.NoError(t, c.Bind().Header(&q)) - require.Len(t, q["Hobby"], 3) + require.Len(t, q["Hobby"], 1) empty := make(map[string][]string, 0) c.Request().Header.Del("hobby") @@ -543,7 +547,9 @@ func Test_Bind_Header_Schema(t *testing.T) { // go test -run Test_Bind_Resp_Header -v func Test_Bind_RespHeader(t *testing.T) { t.Parallel() - app := New() + app := New(Config{ + EnableSplittingOnParsers: true, + }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) type Header struct { @@ -627,13 +633,13 @@ func Test_Bind_RespHeader_Map(t *testing.T) { c.Response().Header.Add("Hobby", "golang,fiber") q := make(map[string][]string, 0) require.NoError(t, c.Bind().RespHeader(&q)) - require.Len(t, q["Hobby"], 2) + require.Len(t, q["Hobby"], 1) c.Response().Header.Del("hobby") c.Response().Header.Add("Hobby", "golang,fiber,go") q = make(map[string][]string, 0) require.NoError(t, c.Bind().RespHeader(&q)) - require.Len(t, q["Hobby"], 3) + require.Len(t, q["Hobby"], 1) empty := make(map[string][]string, 0) c.Response().Header.Del("hobby") @@ -751,7 +757,9 @@ func Benchmark_Bind_Query_WithParseParam(b *testing.B) { func Benchmark_Bind_Query_Comma(b *testing.B) { var err error - app := New() + app := New(Config{ + EnableSplittingOnParsers: true, + }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) type Query struct { @@ -1341,7 +1349,9 @@ func Benchmark_Bind_URI_Map(b *testing.B) { func Test_Bind_Cookie(t *testing.T) { t.Parallel() - app := New() + app := New(Config{ + EnableSplittingOnParsers: true, + }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) type Cookie struct { @@ -1414,7 +1424,9 @@ func Test_Bind_Cookie(t *testing.T) { func Test_Bind_Cookie_Map(t *testing.T) { t.Parallel() - app := New() + app := New(Config{ + EnableSplittingOnParsers: true, + }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetBody([]byte(``)) diff --git a/binder/binder.go b/binder/binder.go index bb3fc2b394..06c7c926a5 100644 --- a/binder/binder.go +++ b/binder/binder.go @@ -2,6 +2,7 @@ package binder import ( "errors" + "sync" ) // Binder errors @@ -10,15 +11,69 @@ var ( ErrMapNotConvertable = errors.New("binder: map is not convertable to map[string]string or map[string][]string") ) -// Init default binders for Fiber -var ( - HeaderBinder = &headerBinding{} - RespHeaderBinder = &respHeaderBinding{} - CookieBinder = &cookieBinding{} - QueryBinder = &queryBinding{} - FormBinder = &formBinding{} - URIBinder = &uriBinding{} - XMLBinder = &xmlBinding{} - JSONBinder = &jsonBinding{} - CBORBinder = &cborBinding{} -) +var HeaderBinderPool = sync.Pool{ + New: func() any { + return &HeaderBinding{} + }, +} + +var RespHeaderBinderPool = sync.Pool{ + New: func() any { + return &RespHeaderBinding{} + }, +} + +var CookieBinderPool = sync.Pool{ + New: func() any { + return &CookieBinding{} + }, +} + +var QueryBinderPool = sync.Pool{ + New: func() any { + return &QueryBinding{} + }, +} + +var FormBinderPool = sync.Pool{ + New: func() any { + return &FormBinding{} + }, +} + +var URIBinderPool = sync.Pool{ + New: func() any { + return &URIBinding{} + }, +} + +var XMLBinderPool = sync.Pool{ + New: func() any { + return &XMLBinding{} + }, +} + +var JSONBinderPool = sync.Pool{ + New: func() any { + return &JSONBinding{} + }, +} + +var CBORBinderPool = sync.Pool{ + New: func() any { + return &CBORBinding{} + }, +} + +func GetFromThePool[T any](pool *sync.Pool) T { + binder, ok := pool.Get().(T) + if !ok { + panic(errors.New("failed to type-assert to T")) + } + + return binder +} + +func PutToThePool[T any](pool *sync.Pool, binder T) { + pool.Put(binder) +} diff --git a/binder/binder_test.go b/binder/binder_test.go new file mode 100644 index 0000000000..d078ed02c6 --- /dev/null +++ b/binder/binder_test.go @@ -0,0 +1,28 @@ +package binder + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_GetAndPutToThePool(t *testing.T) { + t.Parallel() + + // Panics in case we get from another pool + require.Panics(t, func() { + _ = GetFromThePool[*HeaderBinding](&CookieBinderPool) + }) + + // We get from the pool + binder := GetFromThePool[*HeaderBinding](&HeaderBinderPool) + PutToThePool(&HeaderBinderPool, binder) + + _ = GetFromThePool[*RespHeaderBinding](&RespHeaderBinderPool) + _ = GetFromThePool[*QueryBinding](&QueryBinderPool) + _ = GetFromThePool[*FormBinding](&FormBinderPool) + _ = GetFromThePool[*URIBinding](&URIBinderPool) + _ = GetFromThePool[*XMLBinding](&XMLBinderPool) + _ = GetFromThePool[*JSONBinding](&JSONBinderPool) + _ = GetFromThePool[*CBORBinding](&CBORBinderPool) +} diff --git a/binder/cbor.go b/binder/cbor.go index 6f47893531..8b1d0d4291 100644 --- a/binder/cbor.go +++ b/binder/cbor.go @@ -4,15 +4,22 @@ import ( "github.com/gofiber/utils/v2" ) -// cborBinding is the CBOR binder for CBOR request body. -type cborBinding struct{} +// CBORBinding is the CBOR binder for CBOR request body. +type CBORBinding struct { + CBORDecoder utils.CBORUnmarshal +} // Name returns the binding name. -func (*cborBinding) Name() string { +func (*CBORBinding) Name() string { return "cbor" } // Bind parses the request body as CBOR and returns the result. -func (*cborBinding) Bind(body []byte, cborDecoder utils.CBORUnmarshal, out any) error { - return cborDecoder(body, out) +func (b *CBORBinding) Bind(body []byte, out any) error { + return b.CBORDecoder(body, out) +} + +// Reset resets the CBORBinding binder. +func (b *CBORBinding) Reset() { + b.CBORDecoder = nil } diff --git a/binder/cbor_test.go b/binder/cbor_test.go new file mode 100644 index 0000000000..16c24cbbca --- /dev/null +++ b/binder/cbor_test.go @@ -0,0 +1,92 @@ +package binder + +import ( + "testing" + + "github.com/fxamacker/cbor/v2" + "github.com/stretchr/testify/require" +) + +func Test_CBORBinder_Bind(t *testing.T) { + t.Parallel() + + b := &CBORBinding{ + CBORDecoder: cbor.Unmarshal, + } + require.Equal(t, "cbor", b.Name()) + + type Post struct { + Title string `cbor:"title"` + } + + type User struct { + Name string `cbor:"name"` + Posts []Post `cbor:"posts"` + Names []string `cbor:"names"` + Age int `cbor:"age"` + } + var user User + + wantedUser := User{ + Name: "john", + Names: []string{ + "john", + "doe", + }, + Age: 42, + Posts: []Post{ + {Title: "post1"}, + {Title: "post2"}, + {Title: "post3"}, + }, + } + + body, err := cbor.Marshal(wantedUser) + require.NoError(t, err) + + err = b.Bind(body, &user) + + require.NoError(t, err) + require.Equal(t, "john", user.Name) + require.Equal(t, 42, user.Age) + require.Len(t, user.Posts, 3) + require.Equal(t, "post1", user.Posts[0].Title) + require.Equal(t, "post2", user.Posts[1].Title) + require.Equal(t, "post3", user.Posts[2].Title) + require.Contains(t, user.Names, "john") + require.Contains(t, user.Names, "doe") + + b.Reset() + require.Nil(t, b.CBORDecoder) +} + +func Benchmark_CBORBinder_Bind(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + + binder := &CBORBinding{ + CBORDecoder: cbor.Unmarshal, + } + + type User struct { + Name string `cbor:"name"` + Age int `cbor:"age"` + } + + var user User + wantedUser := User{ + Name: "john", + Age: 42, + } + + body, err := cbor.Marshal(wantedUser) + require.NoError(b, err) + + for i := 0; i < b.N; i++ { + err = binder.Bind(body, &user) + } + + require.NoError(b, err) + require.Equal(b, "john", user.Name) + require.Equal(b, 42, user.Age) +} diff --git a/binder/cookie.go b/binder/cookie.go index 62271c8e38..230794f45a 100644 --- a/binder/cookie.go +++ b/binder/cookie.go @@ -8,20 +8,22 @@ import ( "github.com/valyala/fasthttp" ) -// cookieBinding is the cookie binder for cookie request body. -type cookieBinding struct{} +// CookieBinding is the cookie binder for cookie request body. +type CookieBinding struct { + EnableSplitting bool +} // Name returns the binding name. -func (*cookieBinding) Name() string { +func (*CookieBinding) Name() string { return "cookie" } // Bind parses the request cookie and returns the result. -func (b *cookieBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error { +func (b *CookieBinding) Bind(req *fasthttp.Request, out any) error { data := make(map[string][]string) var err error - reqCtx.Request.Header.VisitAllCookie(func(key, val []byte) { + req.Header.VisitAllCookie(func(key, val []byte) { if err != nil { return } @@ -29,7 +31,7 @@ func (b *cookieBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error { k := utils.UnsafeString(key) v := utils.UnsafeString(val) - if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { + if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { values := strings.Split(v, ",") for i := 0; i < len(values); i++ { data[k] = append(data[k], values[i]) @@ -45,3 +47,8 @@ func (b *cookieBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error { return parse(b.Name(), out, data) } + +// Reset resets the CookieBinding binder. +func (b *CookieBinding) Reset() { + b.EnableSplitting = false +} diff --git a/binder/cookie_test.go b/binder/cookie_test.go new file mode 100644 index 0000000000..bca316c9fe --- /dev/null +++ b/binder/cookie_test.go @@ -0,0 +1,90 @@ +package binder + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" +) + +func Test_CookieBinder_Bind(t *testing.T) { + t.Parallel() + + b := &CookieBinding{ + EnableSplitting: true, + } + require.Equal(t, "cookie", b.Name()) + + type Post struct { + Title string `form:"title"` + } + + type User struct { + Name string `form:"name"` + Names []string `form:"names"` + Posts []Post `form:"posts"` + Age int `form:"age"` + } + var user User + + req := fasthttp.AcquireRequest() + + req.Header.SetCookie("name", "john") + req.Header.SetCookie("names", "john,doe") + req.Header.SetCookie("age", "42") + + t.Cleanup(func() { + fasthttp.ReleaseRequest(req) + }) + + err := b.Bind(req, &user) + + require.NoError(t, err) + require.Equal(t, "john", user.Name) + require.Equal(t, 42, user.Age) + require.Contains(t, user.Names, "john") + require.Contains(t, user.Names, "doe") + + b.Reset() + require.False(t, b.EnableSplitting) +} + +func Benchmark_CookieBinder_Bind(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + + binder := &CookieBinding{ + EnableSplitting: true, + } + + type User struct { + Name string `query:"name"` + Posts []string `query:"posts"` + Age int `query:"age"` + } + var user User + + req := fasthttp.AcquireRequest() + b.Cleanup(func() { + fasthttp.ReleaseRequest(req) + }) + + req.Header.SetCookie("name", "john") + req.Header.SetCookie("age", "42") + req.Header.SetCookie("posts", "post1,post2,post3") + + b.ResetTimer() + + var err error + for i := 0; i < b.N; i++ { + err = binder.Bind(req, &user) + } + + require.NoError(b, err) + require.Equal(b, "john", user.Name) + require.Equal(b, 42, user.Age) + require.Len(b, user.Posts, 3) + require.Contains(b, user.Posts, "post1") + require.Contains(b, user.Posts, "post2") + require.Contains(b, user.Posts, "post3") +} diff --git a/binder/form.go b/binder/form.go index e0f1acd302..7ab0b1b258 100644 --- a/binder/form.go +++ b/binder/form.go @@ -8,20 +8,29 @@ import ( "github.com/valyala/fasthttp" ) -// formBinding is the form binder for form request body. -type formBinding struct{} +const MIMEMultipartForm string = "multipart/form-data" + +// FormBinding is the form binder for form request body. +type FormBinding struct { + EnableSplitting bool +} // Name returns the binding name. -func (*formBinding) Name() string { +func (*FormBinding) Name() string { return "form" } // Bind parses the request body and returns the result. -func (b *formBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error { +func (b *FormBinding) Bind(req *fasthttp.Request, out any) error { data := make(map[string][]string) var err error - reqCtx.PostArgs().VisitAll(func(key, val []byte) { + // Handle multipart form + if FilterFlags(utils.UnsafeString(req.Header.ContentType())) == MIMEMultipartForm { + return b.bindMultipart(req, out) + } + + req.PostArgs().VisitAll(func(key, val []byte) { if err != nil { return } @@ -33,7 +42,7 @@ func (b *formBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error { k, err = parseParamSquareBrackets(k) } - if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { + if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { values := strings.Split(v, ",") for i := 0; i < len(values); i++ { data[k] = append(data[k], values[i]) @@ -50,12 +59,17 @@ func (b *formBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error { return parse(b.Name(), out, data) } -// BindMultipart parses the request body and returns the result. -func (b *formBinding) BindMultipart(reqCtx *fasthttp.RequestCtx, out any) error { - data, err := reqCtx.MultipartForm() +// bindMultipart parses the request body and returns the result. +func (b *FormBinding) bindMultipart(req *fasthttp.Request, out any) error { + data, err := req.MultipartForm() if err != nil { return err } return parse(b.Name(), out, data.Value) } + +// Reset resets the FormBinding binder. +func (b *FormBinding) Reset() { + b.EnableSplitting = false +} diff --git a/binder/form_test.go b/binder/form_test.go new file mode 100644 index 0000000000..c3c52c73fd --- /dev/null +++ b/binder/form_test.go @@ -0,0 +1,174 @@ +package binder + +import ( + "bytes" + "mime/multipart" + "testing" + + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" +) + +func Test_FormBinder_Bind(t *testing.T) { + t.Parallel() + + b := &FormBinding{ + EnableSplitting: true, + } + require.Equal(t, "form", b.Name()) + + type Post struct { + Title string `form:"title"` + } + + type User struct { + Name string `form:"name"` + Names []string `form:"names"` + Posts []Post `form:"posts"` + Age int `form:"age"` + } + var user User + + req := fasthttp.AcquireRequest() + req.SetBodyString("name=john&names=john,doe&age=42&posts[0][title]=post1&posts[1][title]=post2&posts[2][title]=post3") + req.Header.SetContentType("application/x-www-form-urlencoded") + + t.Cleanup(func() { + fasthttp.ReleaseRequest(req) + }) + + err := b.Bind(req, &user) + + require.NoError(t, err) + require.Equal(t, "john", user.Name) + require.Equal(t, 42, user.Age) + require.Len(t, user.Posts, 3) + require.Equal(t, "post1", user.Posts[0].Title) + require.Equal(t, "post2", user.Posts[1].Title) + require.Equal(t, "post3", user.Posts[2].Title) + require.Contains(t, user.Names, "john") + require.Contains(t, user.Names, "doe") + + b.Reset() + require.False(t, b.EnableSplitting) +} + +func Benchmark_FormBinder_Bind(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + + binder := &QueryBinding{ + EnableSplitting: true, + } + + type User struct { + Name string `query:"name"` + Posts []string `query:"posts"` + Age int `query:"age"` + } + var user User + + req := fasthttp.AcquireRequest() + req.URI().SetQueryString("name=john&age=42&posts=post1,post2,post3") + req.Header.SetContentType("application/x-www-form-urlencoded") + + b.ResetTimer() + + var err error + for i := 0; i < b.N; i++ { + err = binder.Bind(req, &user) + } + + require.NoError(b, err) + require.Equal(b, "john", user.Name) + require.Equal(b, 42, user.Age) + require.Len(b, user.Posts, 3) +} + +func Test_FormBinder_BindMultipart(t *testing.T) { + t.Parallel() + + b := &FormBinding{ + EnableSplitting: true, + } + require.Equal(t, "form", b.Name()) + + type User struct { + Name string `form:"name"` + Names []string `form:"names"` + Age int `form:"age"` + } + var user User + + req := fasthttp.AcquireRequest() + + buf := &bytes.Buffer{} + mw := multipart.NewWriter(buf) + + require.NoError(t, mw.WriteField("name", "john")) + require.NoError(t, mw.WriteField("names", "john")) + require.NoError(t, mw.WriteField("names", "doe")) + require.NoError(t, mw.WriteField("age", "42")) + require.NoError(t, mw.Close()) + + req.Header.SetContentType(mw.FormDataContentType()) + req.SetBody(buf.Bytes()) + + t.Cleanup(func() { + fasthttp.ReleaseRequest(req) + }) + + err := b.Bind(req, &user) + + require.NoError(t, err) + require.Equal(t, "john", user.Name) + require.Equal(t, 42, user.Age) + require.Contains(t, user.Names, "john") + require.Contains(t, user.Names, "doe") +} + +func Benchmark_FormBinder_BindMultipart(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + + binder := &FormBinding{ + EnableSplitting: true, + } + + type User struct { + Name string `form:"name"` + Posts []string `form:"posts"` + Age int `form:"age"` + } + var user User + + req := fasthttp.AcquireRequest() + b.Cleanup(func() { + fasthttp.ReleaseRequest(req) + }) + + buf := &bytes.Buffer{} + mw := multipart.NewWriter(buf) + + require.NoError(b, mw.WriteField("name", "john")) + require.NoError(b, mw.WriteField("age", "42")) + require.NoError(b, mw.WriteField("posts", "post1")) + require.NoError(b, mw.WriteField("posts", "post2")) + require.NoError(b, mw.WriteField("posts", "post3")) + require.NoError(b, mw.Close()) + + req.Header.SetContentType(mw.FormDataContentType()) + req.SetBody(buf.Bytes()) + + b.ResetTimer() + + var err error + for i := 0; i < b.N; i++ { + err = binder.Bind(req, &user) + } + + require.NoError(b, err) + require.Equal(b, "john", user.Name) + require.Equal(b, 42, user.Age) + require.Len(b, user.Posts, 3) +} diff --git a/binder/header.go b/binder/header.go index 258a0b2229..b04ce9add3 100644 --- a/binder/header.go +++ b/binder/header.go @@ -8,22 +8,24 @@ import ( "github.com/valyala/fasthttp" ) -// headerBinding is the header binder for header request body. -type headerBinding struct{} +// v is the header binder for header request body. +type HeaderBinding struct { + EnableSplitting bool +} // Name returns the binding name. -func (*headerBinding) Name() string { +func (*HeaderBinding) Name() string { return "header" } // Bind parses the request header and returns the result. -func (b *headerBinding) Bind(req *fasthttp.Request, out any) error { +func (b *HeaderBinding) Bind(req *fasthttp.Request, out any) error { data := make(map[string][]string) req.Header.VisitAll(func(key, val []byte) { k := utils.UnsafeString(key) v := utils.UnsafeString(val) - if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { + if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { values := strings.Split(v, ",") for i := 0; i < len(values); i++ { data[k] = append(data[k], values[i]) @@ -35,3 +37,8 @@ func (b *headerBinding) Bind(req *fasthttp.Request, out any) error { return parse(b.Name(), out, data) } + +// Reset resets the HeaderBinding binder. +func (b *HeaderBinding) Reset() { + b.EnableSplitting = false +} diff --git a/binder/header_test.go b/binder/header_test.go new file mode 100644 index 0000000000..bdef8680ac --- /dev/null +++ b/binder/header_test.go @@ -0,0 +1,88 @@ +package binder + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" +) + +func Test_HeaderBinder_Bind(t *testing.T) { + t.Parallel() + + b := &HeaderBinding{ + EnableSplitting: true, + } + require.Equal(t, "header", b.Name()) + + type User struct { + Name string `header:"Name"` + Names []string `header:"Names"` + Posts []string `header:"Posts"` + Age int `header:"Age"` + } + var user User + + req := fasthttp.AcquireRequest() + req.Header.Set("name", "john") + req.Header.Set("names", "john,doe") + req.Header.Set("age", "42") + req.Header.Set("posts", "post1,post2,post3") + + t.Cleanup(func() { + fasthttp.ReleaseRequest(req) + }) + + err := b.Bind(req, &user) + + require.NoError(t, err) + require.Equal(t, "john", user.Name) + require.Equal(t, 42, user.Age) + require.Len(t, user.Posts, 3) + require.Equal(t, "post1", user.Posts[0]) + require.Equal(t, "post2", user.Posts[1]) + require.Equal(t, "post3", user.Posts[2]) + require.Contains(t, user.Names, "john") + require.Contains(t, user.Names, "doe") + + b.Reset() + require.False(t, b.EnableSplitting) +} + +func Benchmark_HeaderBinder_Bind(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + + binder := &HeaderBinding{ + EnableSplitting: true, + } + + type User struct { + Name string `header:"Name"` + Posts []string `header:"Posts"` + Age int `header:"Age"` + } + var user User + + req := fasthttp.AcquireRequest() + b.Cleanup(func() { + fasthttp.ReleaseRequest(req) + }) + + req.Header.Set("name", "john") + req.Header.Set("age", "42") + req.Header.Set("posts", "post1,post2,post3") + + var err error + for i := 0; i < b.N; i++ { + err = binder.Bind(req, &user) + } + + require.NoError(b, err) + require.Equal(b, "john", user.Name) + require.Equal(b, 42, user.Age) + require.Len(b, user.Posts, 3) + require.Contains(b, user.Posts, "post1") + require.Contains(b, user.Posts, "post2") + require.Contains(b, user.Posts, "post3") +} diff --git a/binder/json.go b/binder/json.go index 7889aee8a2..a6a904b550 100644 --- a/binder/json.go +++ b/binder/json.go @@ -4,15 +4,22 @@ import ( "github.com/gofiber/utils/v2" ) -// jsonBinding is the JSON binder for JSON request body. -type jsonBinding struct{} +// JSONBinding is the JSON binder for JSON request body. +type JSONBinding struct { + JSONDecoder utils.JSONUnmarshal +} // Name returns the binding name. -func (*jsonBinding) Name() string { +func (*JSONBinding) Name() string { return "json" } // Bind parses the request body as JSON and returns the result. -func (*jsonBinding) Bind(body []byte, jsonDecoder utils.JSONUnmarshal, out any) error { - return jsonDecoder(body, out) +func (b *JSONBinding) Bind(body []byte, out any) error { + return b.JSONDecoder(body, out) +} + +// Reset resets the JSONBinding binder. +func (b *JSONBinding) Reset() { + b.JSONDecoder = nil } diff --git a/binder/json_test.go b/binder/json_test.go new file mode 100644 index 0000000000..00718fdf26 --- /dev/null +++ b/binder/json_test.go @@ -0,0 +1,69 @@ +package binder + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_JSON_Binding_Bind(t *testing.T) { + t.Parallel() + + b := &JSONBinding{ + JSONDecoder: json.Unmarshal, + } + require.Equal(t, "json", b.Name()) + + type Post struct { + Title string `json:"title"` + } + + type User struct { + Name string `json:"name"` + Posts []Post `json:"posts"` + Age int `json:"age"` + } + var user User + + err := b.Bind([]byte(`{"name":"john","age":42,"posts":[{"title":"post1"},{"title":"post2"},{"title":"post3"}]}`), &user) + require.NoError(t, err) + require.Equal(t, "john", user.Name) + require.Equal(t, 42, user.Age) + require.Len(t, user.Posts, 3) + require.Equal(t, "post1", user.Posts[0].Title) + require.Equal(t, "post2", user.Posts[1].Title) + require.Equal(t, "post3", user.Posts[2].Title) + + b.Reset() + require.Nil(t, b.JSONDecoder) +} + +func Benchmark_JSON_Binding_Bind(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + + binder := &JSONBinding{ + JSONDecoder: json.Unmarshal, + } + + type User struct { + Name string `json:"name"` + Posts []string `json:"posts"` + Age int `json:"age"` + } + + var user User + var err error + for i := 0; i < b.N; i++ { + err = binder.Bind([]byte(`{"name":"john","age":42,"posts":["post1","post2","post3"]}`), &user) + } + + require.NoError(b, err) + require.Equal(b, "john", user.Name) + require.Equal(b, 42, user.Age) + require.Len(b, user.Posts, 3) + require.Equal(b, "post1", user.Posts[0]) + require.Equal(b, "post2", user.Posts[1]) + require.Equal(b, "post3", user.Posts[2]) +} diff --git a/binder/mapping.go b/binder/mapping.go index 055345fe26..d8b692f7e4 100644 --- a/binder/mapping.go +++ b/binder/mapping.go @@ -32,7 +32,7 @@ var ( // decoderPoolMap helps to improve binders decoderPoolMap = map[string]*sync.Pool{} // tags is used to classify parser's pool - tags = []string{HeaderBinder.Name(), RespHeaderBinder.Name(), CookieBinder.Name(), QueryBinder.Name(), FormBinder.Name(), URIBinder.Name()} + tags = []string{"header", "respHeader", "cookie", "query", "form", "uri"} ) // SetParserDecoder allow globally change the option of form decoder, update decoderPool @@ -107,8 +107,9 @@ func parseToStruct(aliasTag string, out any, data map[string][]string) error { func parseToMap(ptr any, data map[string][]string) error { elem := reflect.TypeOf(ptr).Elem() - // map[string][]string - if elem.Kind() == reflect.Slice { + //nolint:exhaustive // it's not necessary to check all types + switch elem.Kind() { + case reflect.Slice: newMap, ok := ptr.(map[string][]string) if !ok { return ErrMapNotConvertable @@ -117,18 +118,20 @@ func parseToMap(ptr any, data map[string][]string) error { for k, v := range data { newMap[k] = v } + case reflect.String, reflect.Interface: + newMap, ok := ptr.(map[string]string) + if !ok { + return ErrMapNotConvertable + } - return nil - } - - // map[string]string - newMap, ok := ptr.(map[string]string) - if !ok { - return ErrMapNotConvertable - } + for k, v := range data { + if len(v) == 0 { + newMap[k] = "" + continue + } - for k, v := range data { - newMap[k] = v[len(v)-1] + newMap[k] = v[len(v)-1] + } } return nil @@ -223,7 +226,7 @@ func equalFieldType(out any, kind reflect.Kind, key string) bool { continue } // Get tag from field if exist - inputFieldName := typeField.Tag.Get(QueryBinder.Name()) + inputFieldName := typeField.Tag.Get("query") // Name of query binder if inputFieldName == "" { inputFieldName = typeField.Name } else { diff --git a/binder/mapping_test.go b/binder/mapping_test.go index e6fc8146f7..75cdc78305 100644 --- a/binder/mapping_test.go +++ b/binder/mapping_test.go @@ -29,6 +29,21 @@ func Test_EqualFieldType(t *testing.T) { require.True(t, equalFieldType(&user, reflect.String, "Address")) require.True(t, equalFieldType(&user, reflect.Int, "AGE")) require.True(t, equalFieldType(&user, reflect.Int, "age")) + + var user2 struct { + User struct { + Name string + Address string `query:"address"` + Age int `query:"AGE"` + } `query:"user"` + } + + require.True(t, equalFieldType(&user2, reflect.String, "user.name")) + require.True(t, equalFieldType(&user2, reflect.String, "user.Name")) + require.True(t, equalFieldType(&user2, reflect.String, "user.address")) + require.True(t, equalFieldType(&user2, reflect.String, "user.Address")) + require.True(t, equalFieldType(&user2, reflect.Int, "user.AGE")) + require.True(t, equalFieldType(&user2, reflect.Int, "user.age")) } func Test_ParseParamSquareBrackets(t *testing.T) { @@ -97,3 +112,68 @@ func Test_ParseParamSquareBrackets(t *testing.T) { }) } } + +func Test_parseToMap(t *testing.T) { + inputMap := map[string][]string{ + "key1": {"value1", "value2"}, + "key2": {"value3"}, + "key3": {"value4"}, + } + + // Test map[string]string + m := make(map[string]string) + err := parseToMap(m, inputMap) + require.NoError(t, err) + + require.Equal(t, "value2", m["key1"]) + require.Equal(t, "value3", m["key2"]) + require.Equal(t, "value4", m["key3"]) + + // Test map[string][]string + m2 := make(map[string][]string) + err = parseToMap(m2, inputMap) + require.NoError(t, err) + + require.Len(t, m2["key1"], 2) + require.Contains(t, m2["key1"], "value1") + require.Contains(t, m2["key1"], "value2") + require.Len(t, m2["key2"], 1) + require.Len(t, m2["key3"], 1) + + // Test map[string]any + m3 := make(map[string]any) + err = parseToMap(m3, inputMap) + require.ErrorIs(t, err, ErrMapNotConvertable) +} + +func Test_FilterFlags(t *testing.T) { + tests := []struct { + input string + expected string + }{ + { + input: "text/javascript; charset=utf-8", + expected: "text/javascript", + }, + { + input: "text/javascript", + expected: "text/javascript", + }, + + { + input: "text/javascript; charset=utf-8; foo=bar", + expected: "text/javascript", + }, + { + input: "text/javascript charset=utf-8", + expected: "text/javascript", + }, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := FilterFlags(tt.input) + require.Equal(t, tt.expected, result) + }) + } +} diff --git a/binder/query.go b/binder/query.go index 8f029d30c4..9ee500ba63 100644 --- a/binder/query.go +++ b/binder/query.go @@ -8,20 +8,22 @@ import ( "github.com/valyala/fasthttp" ) -// queryBinding is the query binder for query request body. -type queryBinding struct{} +// QueryBinding is the query binder for query request body. +type QueryBinding struct { + EnableSplitting bool +} // Name returns the binding name. -func (*queryBinding) Name() string { +func (*QueryBinding) Name() string { return "query" } // Bind parses the request query and returns the result. -func (b *queryBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error { +func (b *QueryBinding) Bind(reqCtx *fasthttp.Request, out any) error { data := make(map[string][]string) var err error - reqCtx.QueryArgs().VisitAll(func(key, val []byte) { + reqCtx.URI().QueryArgs().VisitAll(func(key, val []byte) { if err != nil { return } @@ -33,7 +35,7 @@ func (b *queryBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error { k, err = parseParamSquareBrackets(k) } - if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { + if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { values := strings.Split(v, ",") for i := 0; i < len(values); i++ { data[k] = append(data[k], values[i]) @@ -49,3 +51,8 @@ func (b *queryBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error { return parse(b.Name(), out, data) } + +// Reset resets the QueryBinding binder. +func (b *QueryBinding) Reset() { + b.EnableSplitting = false +} diff --git a/binder/query_test.go b/binder/query_test.go new file mode 100644 index 0000000000..0d457e5795 --- /dev/null +++ b/binder/query_test.go @@ -0,0 +1,87 @@ +package binder + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" +) + +func Test_QueryBinder_Bind(t *testing.T) { + t.Parallel() + + b := &QueryBinding{ + EnableSplitting: true, + } + require.Equal(t, "query", b.Name()) + + type Post struct { + Title string `query:"title"` + } + + type User struct { + Name string `query:"name"` + Names []string `query:"names"` + Posts []Post `query:"posts"` + Age int `query:"age"` + } + var user User + + req := fasthttp.AcquireRequest() + req.URI().SetQueryString("name=john&names=john,doe&age=42&posts[0][title]=post1&posts[1][title]=post2&posts[2][title]=post3") + + t.Cleanup(func() { + fasthttp.ReleaseRequest(req) + }) + + err := b.Bind(req, &user) + + require.NoError(t, err) + require.Equal(t, "john", user.Name) + require.Equal(t, 42, user.Age) + require.Len(t, user.Posts, 3) + require.Equal(t, "post1", user.Posts[0].Title) + require.Equal(t, "post2", user.Posts[1].Title) + require.Equal(t, "post3", user.Posts[2].Title) + require.Contains(t, user.Names, "john") + require.Contains(t, user.Names, "doe") + + b.Reset() + require.False(t, b.EnableSplitting) +} + +func Benchmark_QueryBinder_Bind(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + + binder := &QueryBinding{ + EnableSplitting: true, + } + + type User struct { + Name string `query:"name"` + Posts []string `query:"posts"` + Age int `query:"age"` + } + var user User + + req := fasthttp.AcquireRequest() + b.Cleanup(func() { + fasthttp.ReleaseRequest(req) + }) + + req.URI().SetQueryString("name=john&age=42&posts=post1,post2,post3") + + var err error + for i := 0; i < b.N; i++ { + err = binder.Bind(req, &user) + } + + require.NoError(b, err) + require.Equal(b, "john", user.Name) + require.Equal(b, 42, user.Age) + require.Len(b, user.Posts, 3) + require.Contains(b, user.Posts, "post1") + require.Contains(b, user.Posts, "post2") + require.Contains(b, user.Posts, "post3") +} diff --git a/binder/resp_header.go b/binder/resp_header.go index ef14255315..fc84d01402 100644 --- a/binder/resp_header.go +++ b/binder/resp_header.go @@ -8,22 +8,24 @@ import ( "github.com/valyala/fasthttp" ) -// respHeaderBinding is the respHeader binder for response header. -type respHeaderBinding struct{} +// RespHeaderBinding is the respHeader binder for response header. +type RespHeaderBinding struct { + EnableSplitting bool +} // Name returns the binding name. -func (*respHeaderBinding) Name() string { +func (*RespHeaderBinding) Name() string { return "respHeader" } // Bind parses the response header and returns the result. -func (b *respHeaderBinding) Bind(resp *fasthttp.Response, out any) error { +func (b *RespHeaderBinding) Bind(resp *fasthttp.Response, out any) error { data := make(map[string][]string) resp.Header.VisitAll(func(key, val []byte) { k := utils.UnsafeString(key) v := utils.UnsafeString(val) - if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { + if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { values := strings.Split(v, ",") for i := 0; i < len(values); i++ { data[k] = append(data[k], values[i]) @@ -35,3 +37,8 @@ func (b *respHeaderBinding) Bind(resp *fasthttp.Response, out any) error { return parse(b.Name(), out, data) } + +// Reset resets the RespHeaderBinding binder. +func (b *RespHeaderBinding) Reset() { + b.EnableSplitting = false +} diff --git a/binder/resp_header_test.go b/binder/resp_header_test.go new file mode 100644 index 0000000000..ff3b51f604 --- /dev/null +++ b/binder/resp_header_test.go @@ -0,0 +1,79 @@ +package binder + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" +) + +func Test_RespHeaderBinder_Bind(t *testing.T) { + t.Parallel() + + b := &RespHeaderBinding{ + EnableSplitting: true, + } + require.Equal(t, "respHeader", b.Name()) + + type User struct { + Name string `respHeader:"name"` + Posts []string `respHeader:"posts"` + Age int `respHeader:"age"` + } + var user User + + resp := fasthttp.AcquireResponse() + resp.Header.Set("name", "john") + resp.Header.Set("age", "42") + resp.Header.Set("posts", "post1,post2,post3") + + t.Cleanup(func() { + fasthttp.ReleaseResponse(resp) + }) + + err := b.Bind(resp, &user) + + require.NoError(t, err) + require.Equal(t, "john", user.Name) + require.Equal(t, 42, user.Age) + require.Equal(t, []string{"post1", "post2", "post3"}, user.Posts) + + b.Reset() + require.False(t, b.EnableSplitting) +} + +func Benchmark_RespHeaderBinder_Bind(b *testing.B) { + b.ReportAllocs() + + binder := &RespHeaderBinding{ + EnableSplitting: true, + } + + type User struct { + Name string `respHeader:"name"` + Posts []string `respHeader:"posts"` + Age int `respHeader:"age"` + } + var user User + + resp := fasthttp.AcquireResponse() + resp.Header.Set("name", "john") + resp.Header.Set("age", "42") + resp.Header.Set("posts", "post1,post2,post3") + + b.Cleanup(func() { + fasthttp.ReleaseResponse(resp) + }) + + b.ResetTimer() + + var err error + for i := 0; i < b.N; i++ { + err = binder.Bind(resp, &user) + } + + require.NoError(b, err) + require.Equal(b, "john", user.Name) + require.Equal(b, 42, user.Age) + require.Equal(b, []string{"post1", "post2", "post3"}, user.Posts) +} diff --git a/binder/uri.go b/binder/uri.go index b58d9d49c4..9b358c64b8 100644 --- a/binder/uri.go +++ b/binder/uri.go @@ -1,15 +1,15 @@ package binder // uriBinding is the URI binder for URI parameters. -type uriBinding struct{} +type URIBinding struct{} // Name returns the binding name. -func (*uriBinding) Name() string { +func (*URIBinding) Name() string { return "uri" } // Bind parses the URI parameters and returns the result. -func (b *uriBinding) Bind(params []string, paramsFunc func(key string, defaultValue ...string) string, out any) error { +func (b *URIBinding) Bind(params []string, paramsFunc func(key string, defaultValue ...string) string, out any) error { data := make(map[string][]string, len(params)) for _, param := range params { data[param] = append(data[param], paramsFunc(param)) @@ -17,3 +17,8 @@ func (b *uriBinding) Bind(params []string, paramsFunc func(key string, defaultVa return parse(b.Name(), out, data) } + +// Reset resets URIBinding binder. +func (*URIBinding) Reset() { + // Nothing to reset +} diff --git a/binder/uri_test.go b/binder/uri_test.go new file mode 100644 index 0000000000..8babdef962 --- /dev/null +++ b/binder/uri_test.go @@ -0,0 +1,77 @@ +package binder + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_URIBinding_Bind(t *testing.T) { + t.Parallel() + + b := &URIBinding{} + require.Equal(t, "uri", b.Name()) + + type User struct { + Name string `uri:"name"` + Posts []string `uri:"posts"` + Age int `uri:"age"` + } + var user User + + paramsKey := []string{"name", "age", "posts"} + paramsVals := []string{"john", "42", "post1,post2,post3"} + paramsFunc := func(key string, _ ...string) string { + for i, k := range paramsKey { + if k == key { + return paramsVals[i] + } + } + + return "" + } + + err := b.Bind(paramsKey, paramsFunc, &user) + require.NoError(t, err) + require.Equal(t, "john", user.Name) + require.Equal(t, 42, user.Age) + require.Equal(t, []string{"post1,post2,post3"}, user.Posts) + + b.Reset() +} + +func Benchmark_URIBinding_Bind(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + + binder := &URIBinding{} + + type User struct { + Name string `uri:"name"` + Posts []string `uri:"posts"` + Age int `uri:"age"` + } + var user User + + paramsKey := []string{"name", "age", "posts"} + paramsVals := []string{"john", "42", "post1,post2,post3"} + paramsFunc := func(key string, _ ...string) string { + for i, k := range paramsKey { + if k == key { + return paramsVals[i] + } + } + + return "" + } + + var err error + for i := 0; i < b.N; i++ { + err = binder.Bind(paramsKey, paramsFunc, &user) + } + + require.NoError(b, err) + require.Equal(b, "john", user.Name) + require.Equal(b, 42, user.Age) + require.Equal(b, []string{"post1,post2,post3"}, user.Posts) +} diff --git a/binder/xml.go b/binder/xml.go index 58da2b9b07..0c345a4236 100644 --- a/binder/xml.go +++ b/binder/xml.go @@ -1,23 +1,31 @@ package binder import ( - "encoding/xml" "fmt" + + "github.com/gofiber/utils/v2" ) -// xmlBinding is the XML binder for XML request body. -type xmlBinding struct{} +// XMLBinding is the XML binder for XML request body. +type XMLBinding struct { + XMLDecoder utils.XMLUnmarshal +} // Name returns the binding name. -func (*xmlBinding) Name() string { +func (*XMLBinding) Name() string { return "xml" } // Bind parses the request body as XML and returns the result. -func (*xmlBinding) Bind(body []byte, out any) error { - if err := xml.Unmarshal(body, out); err != nil { +func (b *XMLBinding) Bind(body []byte, out any) error { + if err := b.XMLDecoder(body, out); err != nil { return fmt.Errorf("failed to unmarshal xml: %w", err) } return nil } + +// Reset resets the XMLBinding binder. +func (b *XMLBinding) Reset() { + b.XMLDecoder = nil +} diff --git a/binder/xml_test.go b/binder/xml_test.go new file mode 100644 index 0000000000..879ccf0b78 --- /dev/null +++ b/binder/xml_test.go @@ -0,0 +1,135 @@ +package binder + +import ( + "encoding/xml" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_XMLBinding_Bind(t *testing.T) { + t.Parallel() + + b := &XMLBinding{ + XMLDecoder: xml.Unmarshal, + } + require.Equal(t, "xml", b.Name()) + + type Posts struct { + XMLName xml.Name `xml:"post"` + Title string `xml:"title"` + } + + type User struct { + Name string `xml:"name"` + Ignore string `xml:"-"` + Posts []Posts `xml:"posts>post"` + Age int `xml:"age"` + } + + user := new(User) + err := b.Bind([]byte(` + + john + 42 + ignore + + + post1 + + + post2 + + + + `), user) + require.NoError(t, err) + require.Equal(t, "john", user.Name) + require.Equal(t, 42, user.Age) + require.Empty(t, user.Ignore) + + require.Len(t, user.Posts, 2) + require.Equal(t, "post1", user.Posts[0].Title) + require.Equal(t, "post2", user.Posts[1].Title) + + b.Reset() + require.Nil(t, b.XMLDecoder) +} + +func Test_XMLBinding_Bind_error(t *testing.T) { + t.Parallel() + b := &XMLBinding{ + XMLDecoder: xml.Unmarshal, + } + + type User struct { + Name string `xml:"name"` + Age int `xml:"age"` + } + + user := new(User) + err := b.Bind([]byte(` + + john + 42 + unknown + post"` + Age int `xml:"age"` + } + + user := new(User) + data := []byte(` + + john + 42 + ignore + + + post1 + + + post2 + + + + `) + + b.StartTimer() + + var err error + for i := 0; i < b.N; i++ { + err = binder.Bind(data, user) + } + require.NoError(b, err) + + user = new(User) + err = binder.Bind(data, user) + require.NoError(b, err) + + require.Equal(b, "john", user.Name) + require.Equal(b, 42, user.Age) + + require.Len(b, user.Posts, 2) + require.Equal(b, "post1", user.Posts[0].Title) + require.Equal(b, "post2", user.Posts[1].Title) +} diff --git a/ctx_test.go b/ctx_test.go index d025c24413..88b617eb5b 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -1433,7 +1433,9 @@ func Benchmark_Ctx_Fresh_LastModified(b *testing.B) { func Test_Ctx_Binders(t *testing.T) { t.Parallel() // setup - app := New() + app := New(Config{ + EnableSplittingOnParsers: true, + }) type TestEmbeddedStruct struct { Names []string `query:"names"` diff --git a/docs/api/bind.md b/docs/api/bind.md index e91fed4273..d2b336310d 100644 --- a/docs/api/bind.md +++ b/docs/api/bind.md @@ -18,7 +18,6 @@ Make copies or use the [**`Immutable`**](./ctx.md) setting instead. [Read more.. - [Body](#body) - [Form](#form) - [JSON](#json) - - [MultipartForm](#multipartform) - [XML](#xml) - [CBOR](#cbor) - [Cookie](#cookie) @@ -83,7 +82,7 @@ curl -X POST -F name=john -F pass=doe http://localhost:3000 ### Form -Binds the request form body to a struct. +Binds the request or multipart form body data to a struct. It is important to specify the correct struct tag based on the content type to be parsed. For example, if you want to parse a form body with a field called `Pass`, you would use a struct field with `form:"pass"`. @@ -111,12 +110,16 @@ app.Post("/", func(c fiber.Ctx) error { }) ``` -Run tests with the following `curl` command: +Run tests with the following `curl` commands for both `application/x-www-form-urlencoded` and `multipart/form-data`: ```bash curl -X POST -H "Content-Type: application/x-www-form-urlencoded" --data "name=john&pass=doe" localhost:3000 ``` +```bash +curl -X POST -H "Content-Type: multipart/form-data" -F "name=john" -F "pass=doe" localhost:3000 +``` + ### JSON Binds the request JSON body to a struct. @@ -153,43 +156,6 @@ Run tests with the following `curl` command: curl -X POST -H "Content-Type: application/json" --data "{\"name\":\"john\",\"pass\":\"doe\"}" localhost:3000 ``` -### MultipartForm - -Binds the request multipart form body to a struct. - -It is important to specify the correct struct tag based on the content type to be parsed. For example, if you want to parse a multipart form body with a field called `Pass`, you would use a struct field with `form:"pass"`. - -```go title="Signature" -func (b *Bind) MultipartForm(out any) error -``` - -```go title="Example" -// Field names should start with an uppercase letter -type Person struct { - Name string `form:"name"` - Pass string `form:"pass"` -} - -app.Post("/", func(c fiber.Ctx) error { - p := new(Person) - - if err := c.Bind().MultipartForm(p); err != nil { - return err - } - - log.Println(p.Name) // john - log.Println(p.Pass) // doe - - // ... -}) -``` - -Run tests with the following `curl` command: - -```bash -curl -X POST -H "Content-Type: multipart/form-data" -F "name=john" -F "pass=doe" localhost:3000 -``` - ### XML Binds the request XML body to a struct. diff --git a/docs/api/fiber.md b/docs/api/fiber.md index 17cf3896b9..70320984da 100644 --- a/docs/api/fiber.md +++ b/docs/api/fiber.md @@ -83,6 +83,7 @@ app := fiber.New(fiber.Config{ | WriteBufferSize | `int` | Per-connection buffer size for responses' writing. | `4096` | | WriteTimeout | `time.Duration` | The maximum duration before timing out writes of the response. The default timeout is unlimited. | `nil` | | XMLEncoder | `utils.XMLMarshal` | Allowing for flexibility in using another XML library for encoding. | `xml.Marshal` | +| XMLDecoder | `utils.XMLUnmarshal` | Allowing for flexibility in using another XML library for decoding. | `xml.Unmarshal` | ## Server listening diff --git a/docs/whats_new.md b/docs/whats_new.md index eadc1afa4a..321df424d6 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -49,6 +49,7 @@ We have made several changes to the Fiber app, including: - `EnablePrintRoutes` - `ListenerNetwork` (previously `Network`) - **Trusted Proxy Configuration**: The `EnabledTrustedProxyCheck` has been moved to `app.Config.TrustProxy`, and `TrustedProxies` has been moved to `TrustProxyConfig.Proxies`. +- **XMLDecoder Config Property**: The `XMLDecoder` property has been added to allow usage of 3rd-party XML libraries in XML binder. ### New Methods diff --git a/redirect.go b/redirect.go index bc79314922..483272c7b5 100644 --- a/redirect.go +++ b/redirect.go @@ -146,10 +146,8 @@ func (r *Redirect) WithInput() *Redirect { oldInput := make(map[string]string) switch ctype { - case MIMEApplicationForm: + case MIMEApplicationForm, MIMEMultipartForm: _ = r.c.Bind().Form(oldInput) //nolint:errcheck // not needed - case MIMEMultipartForm: - _ = r.c.Bind().MultipartForm(oldInput) //nolint:errcheck // not needed default: _ = r.c.Bind().Query(oldInput) //nolint:errcheck // not needed } From 775e0a73f3fc0eba0940488e2accd3b87ab245ee Mon Sep 17 00:00:00 2001 From: Bulat Bagaviev <110637846+sunnyyssh@users.noreply.github.com> Date: Sat, 28 Dec 2024 16:29:31 +0300 Subject: [PATCH 02/20] =?UTF-8?q?=F0=9F=A9=B9=20Fix:=20Memory=20leak=20rem?= =?UTF-8?q?oval=20in=20the=20idempotency=20middleware=20(#3263)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🩹 Fix: Add key removal in MemoryLock * Fixed concurrent deletion. * Fix: idempotency middleware's MemoryLock * Add MemoryLock benchmarks. * Updated benchmarks: Add returning error handling * Renamed benchmark: RepeatedKeys --------- Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> --- middleware/idempotency/locker.go | 34 ++++++++++---- middleware/idempotency/locker_test.go | 66 +++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 9 deletions(-) diff --git a/middleware/idempotency/locker.go b/middleware/idempotency/locker.go index 2c3348b8f3..f24db382a5 100644 --- a/middleware/idempotency/locker.go +++ b/middleware/idempotency/locker.go @@ -10,42 +10,58 @@ type Locker interface { Unlock(key string) error } +type countedLock struct { + mu sync.Mutex + locked int +} + type MemoryLock struct { - keys map[string]*sync.Mutex + keys map[string]*countedLock mu sync.Mutex } func (l *MemoryLock) Lock(key string) error { l.mu.Lock() - mu, ok := l.keys[key] + lock, ok := l.keys[key] if !ok { - mu = new(sync.Mutex) - l.keys[key] = mu + lock = new(countedLock) + l.keys[key] = lock } + lock.locked++ l.mu.Unlock() - mu.Lock() + lock.mu.Lock() return nil } func (l *MemoryLock) Unlock(key string) error { l.mu.Lock() - mu, ok := l.keys[key] - l.mu.Unlock() + lock, ok := l.keys[key] if !ok { // This happens if we try to unlock an unknown key + l.mu.Unlock() return nil } + l.mu.Unlock() - mu.Unlock() + lock.mu.Unlock() + + l.mu.Lock() + lock.locked-- + if lock.locked <= 0 { + // This happens if countedLock is used to Lock and Unlock the same number of times + // So, we can delete the key to prevent memory leak + delete(l.keys, key) + } + l.mu.Unlock() return nil } func NewMemoryLock() *MemoryLock { return &MemoryLock{ - keys: make(map[string]*sync.Mutex), + keys: make(map[string]*countedLock), } } diff --git a/middleware/idempotency/locker_test.go b/middleware/idempotency/locker_test.go index 3b4a3ca78a..81da15d3bf 100644 --- a/middleware/idempotency/locker_test.go +++ b/middleware/idempotency/locker_test.go @@ -1,6 +1,8 @@ package idempotency_test import ( + "strconv" + "sync/atomic" "testing" "time" @@ -59,3 +61,67 @@ func Test_MemoryLock(t *testing.T) { require.NoError(t, err) } } + +func Benchmark_MemoryLock(b *testing.B) { + keys := make([]string, b.N) + for i := range keys { + keys[i] = strconv.Itoa(i) + } + + lock := idempotency.NewMemoryLock() + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + key := keys[i] + if err := lock.Lock(key); err != nil { + b.Fatal(err) + } + if err := lock.Unlock(key); err != nil { + b.Fatal(err) + } + } +} + +func Benchmark_MemoryLock_Parallel(b *testing.B) { + // In order to prevent using repeated keys I pre-allocate keys + keys := make([]string, 1_000_000) + for i := range keys { + keys[i] = strconv.Itoa(i) + } + + b.Run("UniqueKeys", func(b *testing.B) { + lock := idempotency.NewMemoryLock() + var keyI atomic.Int32 + b.RunParallel(func(p *testing.PB) { + for p.Next() { + i := int(keyI.Add(1)) % len(keys) + key := keys[i] + if err := lock.Lock(key); err != nil { + b.Fatal(err) + } + if err := lock.Unlock(key); err != nil { + b.Fatal(err) + } + } + }) + }) + + b.Run("RepeatedKeys", func(b *testing.B) { + lock := idempotency.NewMemoryLock() + var keyI atomic.Int32 + b.RunParallel(func(p *testing.PB) { + for p.Next() { + // Division by 3 ensures that index will be repreated exactly 3 times + i := int(keyI.Add(1)) / 3 % len(keys) + key := keys[i] + if err := lock.Lock(key); err != nil { + b.Fatal(err) + } + if err := lock.Unlock(key); err != nil { + b.Fatal(err) + } + } + }) + }) +} From 845a7f8b8e7e0a0794eb5fc5a05a0030fd730a46 Mon Sep 17 00:00:00 2001 From: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> Date: Sun, 29 Dec 2024 13:34:34 -0500 Subject: [PATCH 03/20] =?UTF-8?q?=F0=9F=A7=B9=20chore:=20Improve=20Perform?= =?UTF-8?q?ance=20of=20Fiber=20Router=20(#3261)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Initial improvements * Update test * Improve RemoveEscapeChar performance * Fix lint issues * Re-add comments * Add dedicated request handlers * Fix lint issues * Add test case for app.All with custom method * Add test for custom Ctx and Request Methods * Simplify test logic * Simplify test --- .github/workflows/linter.yml | 2 +- Makefile | 2 +- app.go | 16 +++- app_test.go | 39 ++++++-- binder/mapping.go | 4 +- ctx_interface_gen.go | 3 + ctx_test.go | 29 ++++++ path.go | 12 ++- router.go | 181 +++++++++++++++++++---------------- router_test.go | 23 +++++ 10 files changed, 209 insertions(+), 102 deletions(-) diff --git a/.github/workflows/linter.yml b/.github/workflows/linter.yml index bd2c0bce4c..beed212610 100644 --- a/.github/workflows/linter.yml +++ b/.github/workflows/linter.yml @@ -37,4 +37,4 @@ jobs: uses: golangci/golangci-lint-action@v6 with: # NOTE: Keep this in sync with the version from .golangci.yml - version: v1.62.0 + version: v1.62.2 diff --git a/Makefile b/Makefile index 4b348cd574..669b3fbee4 100644 --- a/Makefile +++ b/Makefile @@ -35,7 +35,7 @@ markdown: ## lint: 🚨 Run lint checks .PHONY: lint lint: - go run github.com/golangci/golangci-lint/cmd/golangci-lint@v1.62.0 run ./... + go run github.com/golangci/golangci-lint/cmd/golangci-lint@v1.62.2 run ./... ## test: 🚦 Execute all tests .PHONY: test diff --git a/app.go b/app.go index 5e5475b5f1..3810a8ec0c 100644 --- a/app.go +++ b/app.go @@ -616,6 +616,10 @@ func (app *App) handleTrustedProxy(ipAddress string) { // Note: It doesn't allow adding new methods, only customizing exist methods. func (app *App) NewCtxFunc(function func(app *App) CustomCtx) { app.newCtxFunc = function + + if app.server != nil { + app.server.Handler = app.customRequestHandler + } } // RegisterCustomConstraint allows to register custom constraint. @@ -868,7 +872,11 @@ func (app *App) Config() Config { func (app *App) Handler() fasthttp.RequestHandler { //revive:disable-line:confusing-naming // Having both a Handler() (uppercase) and a handler() (lowercase) is fine. TODO: Use nolint:revive directive instead. See https://github.com/golangci/golangci-lint/issues/3476 // prepare the server for the start app.startupProcess() - return app.requestHandler + + if app.newCtxFunc != nil { + return app.customRequestHandler + } + return app.defaultRequestHandler } // Stack returns the raw router stack. @@ -1057,7 +1065,11 @@ func (app *App) init() *App { } // fasthttp server settings - app.server.Handler = app.requestHandler + if app.newCtxFunc != nil { + app.server.Handler = app.customRequestHandler + } else { + app.server.Handler = app.defaultRequestHandler + } app.server.Name = app.config.ServerHeader app.server.Concurrency = app.config.Concurrency app.server.NoDefaultDate = app.config.DisableDefaultDate diff --git a/app_test.go b/app_test.go index a99796a2c1..8455ded86e 100644 --- a/app_test.go +++ b/app_test.go @@ -581,32 +581,51 @@ func Test_App_Use_StrictRouting(t *testing.T) { func Test_App_Add_Method_Test(t *testing.T) { t.Parallel() - defer func() { - if err := recover(); err != nil { - require.Equal(t, "add: invalid http method JANE\n", fmt.Sprintf("%v", err)) - } - }() methods := append(DefaultMethods, "JOHN") //nolint:gocritic // We want a new slice here app := New(Config{ RequestMethods: methods, }) - app.Add([]string{"JOHN"}, "/doe", testEmptyHandler) + app.Add([]string{"JOHN"}, "/john", testEmptyHandler) - resp, err := app.Test(httptest.NewRequest("JOHN", "/doe", nil)) + resp, err := app.Test(httptest.NewRequest("JOHN", "/john", nil)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") - resp, err = app.Test(httptest.NewRequest(MethodGet, "/doe", nil)) + resp, err = app.Test(httptest.NewRequest(MethodGet, "/john", nil)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusMethodNotAllowed, resp.StatusCode, "Status code") - resp, err = app.Test(httptest.NewRequest("UNKNOWN", "/doe", nil)) + resp, err = app.Test(httptest.NewRequest("UNKNOWN", "/john", nil)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusNotImplemented, resp.StatusCode, "Status code") - app.Add([]string{"JANE"}, "/doe", testEmptyHandler) + // Add a new method + require.Panics(t, func() { + app.Add([]string{"JANE"}, "/jane", testEmptyHandler) + }) +} + +func Test_App_All_Method_Test(t *testing.T) { + t.Parallel() + + methods := append(DefaultMethods, "JOHN") //nolint:gocritic // We want a new slice here + app := New(Config{ + RequestMethods: methods, + }) + + // Add a new method with All + app.All("/doe", testEmptyHandler) + + resp, err := app.Test(httptest.NewRequest("JOHN", "/doe", nil)) + require.NoError(t, err, "app.Test(req)") + require.Equal(t, StatusOK, resp.StatusCode, "Status code") + + // Add a new method + require.Panics(t, func() { + app.Add([]string{"JANE"}, "/jane", testEmptyHandler) + }) } // go test -run Test_App_GETOnly diff --git a/binder/mapping.go b/binder/mapping.go index d8b692f7e4..29b5550b10 100644 --- a/binder/mapping.go +++ b/binder/mapping.go @@ -107,8 +107,7 @@ func parseToStruct(aliasTag string, out any, data map[string][]string) error { func parseToMap(ptr any, data map[string][]string) error { elem := reflect.TypeOf(ptr).Elem() - //nolint:exhaustive // it's not necessary to check all types - switch elem.Kind() { + switch elem.Kind() { //nolint:exhaustive // it's not necessary to check all types case reflect.Slice: newMap, ok := ptr.(map[string][]string) if !ok { @@ -129,7 +128,6 @@ func parseToMap(ptr any, data map[string][]string) error { newMap[k] = "" continue } - newMap[k] = v[len(v)-1] } } diff --git a/ctx_interface_gen.go b/ctx_interface_gen.go index d7f8bbc615..cc48576efb 100644 --- a/ctx_interface_gen.go +++ b/ctx_interface_gen.go @@ -350,5 +350,8 @@ type Ctx interface { setIndexRoute(route int) setMatched(matched bool) setRoute(route *Route) + // Drop closes the underlying connection without sending any response headers or body. + // This can be useful for silently terminating client connections, such as in DDoS mitigation + // or when blocking access to sensitive endpoints. Drop() error } diff --git a/ctx_test.go b/ctx_test.go index 88b617eb5b..eb81876e37 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -127,6 +127,35 @@ func Test_Ctx_CustomCtx(t *testing.T) { require.Equal(t, "prefix_v3", string(body)) } +// go test -run Test_Ctx_CustomCtx +func Test_Ctx_CustomCtx_and_Method(t *testing.T) { + t.Parallel() + + // Create app with custom request methods + methods := append(DefaultMethods, "JOHN") //nolint:gocritic // We want a new slice here + app := New(Config{ + RequestMethods: methods, + }) + + // Create custom context + app.NewCtxFunc(func(app *App) CustomCtx { + return &customCtx{ + DefaultCtx: *NewDefaultCtx(app), + } + }) + + // Add route with custom method + app.Add([]string{"JOHN"}, "/doe", testEmptyHandler) + resp, err := app.Test(httptest.NewRequest("JOHN", "/doe", nil)) + require.NoError(t, err, "app.Test(req)") + require.Equal(t, StatusOK, resp.StatusCode, "Status code") + + // Add a new method + require.Panics(t, func() { + app.Add([]string{"JANE"}, "/jane", testEmptyHandler) + }) +} + // go test -run Test_Ctx_Accepts_EmptyAccept func Test_Ctx_Accepts_EmptyAccept(t *testing.T) { t.Parallel() diff --git a/path.go b/path.go index 00105d5cc0..282073ec04 100644 --- a/path.go +++ b/path.go @@ -620,10 +620,16 @@ func GetTrimmedParam(param string) string { // RemoveEscapeChar remove escape characters func RemoveEscapeChar(word string) string { - if strings.IndexByte(word, escapeChar) != -1 { - return strings.ReplaceAll(word, string(escapeChar), "") + b := []byte(word) + dst := 0 + for src := 0; src < len(b); src++ { + if b[src] == '\\' { + continue + } + b[dst] = b[src] + dst++ } - return word + return string(b[:dst]) } func getParamConstraintType(constraintPart string) TypeConstraint { diff --git a/router.go b/router.go index 2091cfc6cb..9612da170b 100644 --- a/router.go +++ b/router.go @@ -5,11 +5,11 @@ package fiber import ( + "bytes" "errors" "fmt" "html" "sort" - "strings" "sync/atomic" "github.com/gofiber/utils/v2" @@ -65,10 +65,12 @@ type Route struct { func (r *Route) match(detectionPath, path string, params *[maxParams]string) bool { // root detectionPath check - if r.root && detectionPath == "/" { + if r.root && len(detectionPath) == 1 && detectionPath[0] == '/' { return true - // '*' wildcard matches any detectionPath - } else if r.star { + } + + // '*' wildcard matches any detectionPath + if r.star { if len(path) > 1 { params[0] = path[1:] } else { @@ -76,24 +78,32 @@ func (r *Route) match(detectionPath, path string, params *[maxParams]string) boo } return true } - // Does this route have parameters + + // Does this route have parameters? if len(r.Params) > 0 { - // Match params - if match := r.routeParser.getMatch(detectionPath, path, params, r.use); match { - // Get params from the path detectionPath - return match + // Match params using precomputed routeParser + if r.routeParser.getMatch(detectionPath, path, params, r.use) { + return true } } - // Is this route a Middleware? + + // Middleware route? if r.use { - // Single slash will match or detectionPath prefix - if r.root || strings.HasPrefix(detectionPath, r.path) { + // Single slash or prefix match + plen := len(r.path) + if r.root { + // If r.root is '/', it matches everything starting at '/' + if len(detectionPath) > 0 && detectionPath[0] == '/' { + return true + } + } else if len(detectionPath) >= plen && detectionPath[:plen] == r.path { return true } - // Check for a simple detectionPath match - } else if len(r.path) == len(detectionPath) && r.path == detectionPath { + } else if len(r.path) == len(detectionPath) && detectionPath == r.path { + // Check exact match return true } + // No match return false } @@ -201,44 +211,63 @@ func (app *App) next(c *DefaultCtx) (bool, error) { return false, err } -func (app *App) requestHandler(rctx *fasthttp.RequestCtx) { - // Handler for default ctxs - var c CustomCtx - var ok bool - if app.newCtxFunc != nil { - c, ok = app.AcquireCtx(rctx).(CustomCtx) - if !ok { - panic(errors.New("requestHandler: failed to type-assert to CustomCtx")) - } - } else { - c, ok = app.AcquireCtx(rctx).(*DefaultCtx) - if !ok { - panic(errors.New("requestHandler: failed to type-assert to *DefaultCtx")) - } +func (app *App) defaultRequestHandler(rctx *fasthttp.RequestCtx) { + // Acquire DefaultCtx from the pool + ctx, ok := app.AcquireCtx(rctx).(*DefaultCtx) + if !ok { + panic(errors.New("requestHandler: failed to type-assert to *DefaultCtx")) } - defer app.ReleaseCtx(c) - // handle invalid http method directly - if app.methodInt(c.Method()) == -1 { - _ = c.SendStatus(StatusNotImplemented) //nolint:errcheck // Always return nil + defer app.ReleaseCtx(ctx) + + // Check if the HTTP method is valid + if ctx.methodINT == -1 { + _ = ctx.SendStatus(StatusNotImplemented) //nolint:errcheck // Always return nil return } - // check flash messages - if strings.Contains(utils.UnsafeString(c.Request().Header.RawHeaders()), FlashCookieName) { - c.Redirect().parseAndClearFlashMessages() + // Optional: Check flash messages + rawHeaders := ctx.Request().Header.RawHeaders() + if len(rawHeaders) > 0 && bytes.Contains(rawHeaders, []byte(FlashCookieName)) { + ctx.Redirect().parseAndClearFlashMessages() } - // Find match in stack - var err error - if app.newCtxFunc != nil { - _, err = app.nextCustom(c) - } else { - _, err = app.next(c.(*DefaultCtx)) //nolint:errcheck // It is fine to ignore the error here + // Attempt to match a route and execute the chain + _, err := app.next(ctx) + if err != nil { + if catch := ctx.App().ErrorHandler(ctx, err); catch != nil { + _ = ctx.SendStatus(StatusInternalServerError) //nolint:errcheck // Always return nil + } + // TODO: Do we need to return here? } +} + +func (app *App) customRequestHandler(rctx *fasthttp.RequestCtx) { + // Acquire CustomCtx from the pool + ctx, ok := app.AcquireCtx(rctx).(CustomCtx) + if !ok { + panic(errors.New("requestHandler: failed to type-assert to CustomCtx")) + } + + defer app.ReleaseCtx(ctx) + + // Check if the HTTP method is valid + if app.methodInt(ctx.Method()) == -1 { + _ = ctx.SendStatus(StatusNotImplemented) //nolint:errcheck // Always return nil + return + } + + // Optional: Check flash messages + rawHeaders := ctx.Request().Header.RawHeaders() + if len(rawHeaders) > 0 && bytes.Contains(rawHeaders, []byte(FlashCookieName)) { + ctx.Redirect().parseAndClearFlashMessages() + } + + // Attempt to match a route and execute the chain + _, err := app.nextCustom(ctx) if err != nil { - if catch := c.App().ErrorHandler(c, err); catch != nil { - _ = c.SendStatus(StatusInternalServerError) //nolint:errcheck // It is fine to ignore the error here + if catch := ctx.App().ErrorHandler(ctx, err); catch != nil { + _ = ctx.SendStatus(StatusInternalServerError) //nolint:errcheck // Always return nil } // TODO: Do we need to return here? } @@ -295,68 +324,56 @@ func (app *App) register(methods []string, pathRaw string, group *Group, handler handlers = append(handlers, handler) } + // Precompute path normalization ONCE + if pathRaw == "" { + pathRaw = "/" + } + if pathRaw[0] != '/' { + pathRaw = "/" + pathRaw + } + pathPretty := pathRaw + if !app.config.CaseSensitive { + pathPretty = utils.ToLower(pathPretty) + } + if !app.config.StrictRouting && len(pathPretty) > 1 { + pathPretty = utils.TrimRight(pathPretty, '/') + } + pathClean := RemoveEscapeChar(pathPretty) + + parsedRaw := parseRoute(pathRaw, app.customConstraints...) + parsedPretty := parseRoute(pathPretty, app.customConstraints...) + for _, method := range methods { - // Uppercase HTTP methods method = utils.ToUpper(method) - // Check if the HTTP method is valid unless it's USE if method != methodUse && app.methodInt(method) == -1 { panic(fmt.Sprintf("add: invalid http method %s\n", method)) } - // is mounted app + isMount := group != nil && group.app != app - // A route requires atleast one ctx handler if len(handlers) == 0 && !isMount { panic(fmt.Sprintf("missing handler/middleware in route: %s\n", pathRaw)) } - // Cannot have an empty path - if pathRaw == "" { - pathRaw = "/" - } - // Path always start with a '/' - if pathRaw[0] != '/' { - pathRaw = "/" + pathRaw - } - // Create a stripped path in case-sensitive / trailing slashes - pathPretty := pathRaw - // Case-sensitive routing, all to lowercase - if !app.config.CaseSensitive { - pathPretty = utils.ToLower(pathPretty) - } - // Strict routing, remove trailing slashes - if !app.config.StrictRouting && len(pathPretty) > 1 { - pathPretty = utils.TrimRight(pathPretty, '/') - } - // Is layer a middleware? + isUse := method == methodUse - // Is path a direct wildcard? - isStar := pathPretty == "/*" - // Is path a root slash? - isRoot := pathPretty == "/" - // Parse path parameters - parsedRaw := parseRoute(pathRaw, app.customConstraints...) - parsedPretty := parseRoute(pathPretty, app.customConstraints...) - - // Create route metadata without pointer + isStar := pathClean == "/*" + isRoot := pathClean == "/" + route := Route{ - // Router booleans use: isUse, mount: isMount, star: isStar, root: isRoot, - // Path data - path: RemoveEscapeChar(pathPretty), + path: pathClean, routeParser: parsedPretty, Params: parsedRaw.params, + group: group, - // Group data - group: group, - - // Public data Path: pathRaw, Method: method, Handlers: handlers, } + // Increment global handler count atomic.AddUint32(&app.handlersCount, uint32(len(handlers))) //nolint:gosec // Not a concern diff --git a/router_test.go b/router_test.go index 5509039c66..fe5b3429e0 100644 --- a/router_test.go +++ b/router_test.go @@ -591,6 +591,29 @@ func Benchmark_Router_Next_Default(b *testing.B) { } } +// go test -benchmem -run=^$ -bench ^Benchmark_Router_Next_Default_Parallel$ github.com/gofiber/fiber/v3 -count=1 +func Benchmark_Router_Next_Default_Parallel(b *testing.B) { + app := New() + app.Get("/", func(_ Ctx) error { + return nil + }) + + h := app.Handler() + + b.ReportAllocs() + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + fctx := &fasthttp.RequestCtx{} + fctx.Request.Header.SetMethod(MethodGet) + fctx.Request.SetRequestURI("/") + + for pb.Next() { + h(fctx) + } + }) +} + // go test -v ./... -run=^$ -bench=Benchmark_Route_Match -benchmem -count=4 func Benchmark_Route_Match(b *testing.B) { var match bool From 26e30c06724b21ec0d9348b45b09bfe51dcb38a5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 30 Dec 2024 14:18:54 +0100 Subject: [PATCH 04/20] build(deps): bump DavidAnson/markdownlint-cli2-action from 18 to 19 (#3266) Bumps [DavidAnson/markdownlint-cli2-action](https://github.com/davidanson/markdownlint-cli2-action) from 18 to 19. - [Release notes](https://github.com/davidanson/markdownlint-cli2-action/releases) - [Commits](https://github.com/davidanson/markdownlint-cli2-action/compare/v18...v19) --- updated-dependencies: - dependency-name: DavidAnson/markdownlint-cli2-action dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/markdown.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/markdown.yml b/.github/workflows/markdown.yml index a015149c22..cf3575a947 100644 --- a/.github/workflows/markdown.yml +++ b/.github/workflows/markdown.yml @@ -15,7 +15,7 @@ jobs: uses: actions/checkout@v4 - name: Run markdownlint-cli2 - uses: DavidAnson/markdownlint-cli2-action@v18 + uses: DavidAnson/markdownlint-cli2-action@v19 with: globs: | **/*.md From d0e767fc4798affb4776591775e6d5f9c2ccd23b Mon Sep 17 00:00:00 2001 From: AuroraTea <1352685369@qq.com> Date: Tue, 31 Dec 2024 22:58:07 +0800 Subject: [PATCH 05/20] =?UTF-8?q?=F0=9F=93=9A=20Doc:=20Optimize=20the=20me?= =?UTF-8?q?nu=20item=20text=20(#3267)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/middleware/session.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/middleware/session.md b/docs/middleware/session.md index ff73ff6094..b175ac9e3f 100644 --- a/docs/middleware/session.md +++ b/docs/middleware/session.md @@ -2,7 +2,7 @@ id: session --- -# Session Middleware for [Fiber](https://github.com/gofiber/fiber) +# Session The `session` middleware provides session management for Fiber applications, utilizing the [Storage](https://github.com/gofiber/storage) package for multi-database support via a unified interface. By default, session data is stored in memory, but custom storage options are easily configurable (see examples below). From ef04a8a99e0b7404b574e371dc2a341acd5a017d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=2E=20Efe=20=C3=87etin?= Date: Tue, 31 Dec 2024 18:34:28 +0300 Subject: [PATCH 06/20] :bug: bug: Fix square bracket notation in Multipart FormData (#3235) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * :bug: bug: add square bracket notation support to BindMultipart * Fix golangci-lint issues * Fixing undef variable * Fix more lint issues * test * update1 * improve coverage * fix linter * reduce code duplication * reduce code duplications in bindMultipart --------- Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> Co-authored-by: René --- bind_test.go | 103 ++++++++++++++++++++++++++++++++++++++++-- binder/cookie.go | 13 +----- binder/form.go | 29 +++++------- binder/form_test.go | 16 ++++++- binder/header.go | 22 ++++----- binder/mapping.go | 38 +++++++++++++++- binder/query.go | 17 +------ binder/resp_header.go | 23 +++++----- 8 files changed, 186 insertions(+), 75 deletions(-) diff --git a/bind_test.go b/bind_test.go index 52c9004c61..b01086e623 100644 --- a/bind_test.go +++ b/bind_test.go @@ -7,6 +7,7 @@ import ( "encoding/json" "errors" "fmt" + "mime/multipart" "net/http/httptest" "reflect" "testing" @@ -886,7 +887,8 @@ func Test_Bind_Body(t *testing.T) { reqBody := []byte(`{"name":"john"}`) type Demo struct { - Name string `json:"name" xml:"name" form:"name" query:"name"` + Name string `json:"name" xml:"name" form:"name" query:"name"` + Names []string `json:"names" xml:"names" form:"names" query:"names"` } // Helper function to test compressed bodies @@ -996,6 +998,48 @@ func Test_Bind_Body(t *testing.T) { Data []Demo `query:"data"` } + t.Run("MultipartCollectionQueryDotNotation", func(t *testing.T) { + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + c.Request().Reset() + + buf := &bytes.Buffer{} + writer := multipart.NewWriter(buf) + require.NoError(t, writer.WriteField("data.0.name", "john")) + require.NoError(t, writer.WriteField("data.1.name", "doe")) + require.NoError(t, writer.Close()) + + c.Request().Header.SetContentType(writer.FormDataContentType()) + c.Request().SetBody(buf.Bytes()) + c.Request().Header.SetContentLength(len(c.Body())) + + cq := new(CollectionQuery) + require.NoError(t, c.Bind().Body(cq)) + require.Len(t, cq.Data, 2) + require.Equal(t, "john", cq.Data[0].Name) + require.Equal(t, "doe", cq.Data[1].Name) + }) + + t.Run("MultipartCollectionQuerySquareBrackets", func(t *testing.T) { + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + c.Request().Reset() + + buf := &bytes.Buffer{} + writer := multipart.NewWriter(buf) + require.NoError(t, writer.WriteField("data[0][name]", "john")) + require.NoError(t, writer.WriteField("data[1][name]", "doe")) + require.NoError(t, writer.Close()) + + c.Request().Header.SetContentType(writer.FormDataContentType()) + c.Request().SetBody(buf.Bytes()) + c.Request().Header.SetContentLength(len(c.Body())) + + cq := new(CollectionQuery) + require.NoError(t, c.Bind().Body(cq)) + require.Len(t, cq.Data, 2) + require.Equal(t, "john", cq.Data[0].Name) + require.Equal(t, "doe", cq.Data[1].Name) + }) + t.Run("CollectionQuerySquareBrackets", func(t *testing.T) { c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Reset() @@ -1192,9 +1236,57 @@ func Benchmark_Bind_Body_MultipartForm(b *testing.B) { Name string `form:"name"` } - body := []byte("--b\r\nContent-Disposition: form-data; name=\"name\"\r\n\r\njohn\r\n--b--") + buf := &bytes.Buffer{} + writer := multipart.NewWriter(buf) + require.NoError(b, writer.WriteField("name", "john")) + require.NoError(b, writer.Close()) + body := buf.Bytes() + + c.Request().SetBody(body) + c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary=` + writer.Boundary()) + c.Request().Header.SetContentLength(len(body)) + d := new(Demo) + + b.ReportAllocs() + b.ResetTimer() + + for n := 0; n < b.N; n++ { + err = c.Bind().Body(d) + } + + require.NoError(b, err) + require.Equal(b, "john", d.Name) +} + +// go test -v -run=^$ -bench=Benchmark_Bind_Body_MultipartForm_Nested -benchmem -count=4 +func Benchmark_Bind_Body_MultipartForm_Nested(b *testing.B) { + var err error + + app := New() + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + + type Person struct { + Name string `form:"name"` + Age int `form:"age"` + } + + type Demo struct { + Name string `form:"name"` + Persons []Person `form:"persons"` + } + + buf := &bytes.Buffer{} + writer := multipart.NewWriter(buf) + require.NoError(b, writer.WriteField("name", "john")) + require.NoError(b, writer.WriteField("persons.0.name", "john")) + require.NoError(b, writer.WriteField("persons[0][age]", "10")) + require.NoError(b, writer.WriteField("persons[1][name]", "doe")) + require.NoError(b, writer.WriteField("persons.1.age", "20")) + require.NoError(b, writer.Close()) + body := buf.Bytes() + c.Request().SetBody(body) - c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary="b"`) + c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary=` + writer.Boundary()) c.Request().Header.SetContentLength(len(body)) d := new(Demo) @@ -1204,8 +1296,13 @@ func Benchmark_Bind_Body_MultipartForm(b *testing.B) { for n := 0; n < b.N; n++ { err = c.Bind().Body(d) } + require.NoError(b, err) require.Equal(b, "john", d.Name) + require.Equal(b, "john", d.Persons[0].Name) + require.Equal(b, 10, d.Persons[0].Age) + require.Equal(b, "doe", d.Persons[1].Name) + require.Equal(b, 20, d.Persons[1].Age) } // go test -v -run=^$ -bench=Benchmark_Bind_Body_Form_Map -benchmem -count=4 diff --git a/binder/cookie.go b/binder/cookie.go index 230794f45a..5b9ccf1ed3 100644 --- a/binder/cookie.go +++ b/binder/cookie.go @@ -1,9 +1,6 @@ package binder import ( - "reflect" - "strings" - "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) @@ -30,15 +27,7 @@ func (b *CookieBinding) Bind(req *fasthttp.Request, out any) error { k := utils.UnsafeString(key) v := utils.UnsafeString(val) - - if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { - values := strings.Split(v, ",") - for i := 0; i < len(values); i++ { - data[k] = append(data[k], values[i]) - } - } else { - data[k] = append(data[k], v) - } + err = formatBindData(out, data, k, v, b.EnableSplitting, false) }) if err != nil { diff --git a/binder/form.go b/binder/form.go index 7ab0b1b258..a8f5b85270 100644 --- a/binder/form.go +++ b/binder/form.go @@ -1,9 +1,6 @@ package binder import ( - "reflect" - "strings" - "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) @@ -37,19 +34,7 @@ func (b *FormBinding) Bind(req *fasthttp.Request, out any) error { k := utils.UnsafeString(key) v := utils.UnsafeString(val) - - if strings.Contains(k, "[") { - k, err = parseParamSquareBrackets(k) - } - - if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { - values := strings.Split(v, ",") - for i := 0; i < len(values); i++ { - data[k] = append(data[k], values[i]) - } - } else { - data[k] = append(data[k], v) - } + err = formatBindData(out, data, k, v, b.EnableSplitting, true) }) if err != nil { @@ -61,12 +46,20 @@ func (b *FormBinding) Bind(req *fasthttp.Request, out any) error { // bindMultipart parses the request body and returns the result. func (b *FormBinding) bindMultipart(req *fasthttp.Request, out any) error { - data, err := req.MultipartForm() + multipartForm, err := req.MultipartForm() if err != nil { return err } - return parse(b.Name(), out, data.Value) + data := make(map[string][]string) + for key, values := range multipartForm.Value { + err = formatBindData(out, data, key, values, b.EnableSplitting, true) + if err != nil { + return err + } + } + + return parse(b.Name(), out, data) } // Reset resets the FormBinding binder. diff --git a/binder/form_test.go b/binder/form_test.go index c3c52c73fd..55023cb30f 100644 --- a/binder/form_test.go +++ b/binder/form_test.go @@ -93,9 +93,14 @@ func Test_FormBinder_BindMultipart(t *testing.T) { } require.Equal(t, "form", b.Name()) + type Post struct { + Title string `form:"title"` + } + type User struct { Name string `form:"name"` Names []string `form:"names"` + Posts []Post `form:"posts"` Age int `form:"age"` } var user User @@ -106,9 +111,13 @@ func Test_FormBinder_BindMultipart(t *testing.T) { mw := multipart.NewWriter(buf) require.NoError(t, mw.WriteField("name", "john")) - require.NoError(t, mw.WriteField("names", "john")) + require.NoError(t, mw.WriteField("names", "john,eric")) require.NoError(t, mw.WriteField("names", "doe")) require.NoError(t, mw.WriteField("age", "42")) + require.NoError(t, mw.WriteField("posts[0][title]", "post1")) + require.NoError(t, mw.WriteField("posts[1][title]", "post2")) + require.NoError(t, mw.WriteField("posts[2][title]", "post3")) + require.NoError(t, mw.Close()) req.Header.SetContentType(mw.FormDataContentType()) @@ -125,6 +134,11 @@ func Test_FormBinder_BindMultipart(t *testing.T) { require.Equal(t, 42, user.Age) require.Contains(t, user.Names, "john") require.Contains(t, user.Names, "doe") + require.Contains(t, user.Names, "eric") + require.Len(t, user.Posts, 3) + require.Equal(t, "post1", user.Posts[0].Title) + require.Equal(t, "post2", user.Posts[1].Title) + require.Equal(t, "post3", user.Posts[2].Title) } func Benchmark_FormBinder_BindMultipart(b *testing.B) { diff --git a/binder/header.go b/binder/header.go index b04ce9add3..763be56795 100644 --- a/binder/header.go +++ b/binder/header.go @@ -1,9 +1,6 @@ package binder import ( - "reflect" - "strings" - "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) @@ -21,20 +18,21 @@ func (*HeaderBinding) Name() string { // Bind parses the request header and returns the result. func (b *HeaderBinding) Bind(req *fasthttp.Request, out any) error { data := make(map[string][]string) + var err error req.Header.VisitAll(func(key, val []byte) { + if err != nil { + return + } + k := utils.UnsafeString(key) v := utils.UnsafeString(val) - - if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { - values := strings.Split(v, ",") - for i := 0; i < len(values); i++ { - data[k] = append(data[k], values[i]) - } - } else { - data[k] = append(data[k], v) - } + err = formatBindData(out, data, k, v, b.EnableSplitting, false) }) + if err != nil { + return err + } + return parse(b.Name(), out, data) } diff --git a/binder/mapping.go b/binder/mapping.go index 29b5550b10..70cb9cbc2d 100644 --- a/binder/mapping.go +++ b/binder/mapping.go @@ -107,7 +107,7 @@ func parseToStruct(aliasTag string, out any, data map[string][]string) error { func parseToMap(ptr any, data map[string][]string) error { elem := reflect.TypeOf(ptr).Elem() - switch elem.Kind() { //nolint:exhaustive // it's not necessary to check all types + switch elem.Kind() { case reflect.Slice: newMap, ok := ptr.(map[string][]string) if !ok { @@ -130,6 +130,8 @@ func parseToMap(ptr any, data map[string][]string) error { } newMap[k] = v[len(v)-1] } + default: + return nil // it's not necessary to check all types } return nil @@ -247,3 +249,37 @@ func FilterFlags(content string) string { } return content } + +func formatBindData[T any](out any, data map[string][]string, key string, value T, enableSplitting, supportBracketNotation bool) error { //nolint:revive // it's okay + var err error + if supportBracketNotation && strings.Contains(key, "[") { + key, err = parseParamSquareBrackets(key) + if err != nil { + return err + } + } + + switch v := any(value).(type) { + case string: + assignBindData(out, data, key, v, enableSplitting) + case []string: + for _, val := range v { + assignBindData(out, data, key, val, enableSplitting) + } + default: + return fmt.Errorf("unsupported value type: %T", value) + } + + return err +} + +func assignBindData(out any, data map[string][]string, key, value string, enableSplitting bool) { //nolint:revive // it's okay + if enableSplitting && strings.Contains(value, ",") && equalFieldType(out, reflect.Slice, key) { + values := strings.Split(value, ",") + for i := 0; i < len(values); i++ { + data[key] = append(data[key], values[i]) + } + } else { + data[key] = append(data[key], value) + } +} diff --git a/binder/query.go b/binder/query.go index 9ee500ba63..d2ac309215 100644 --- a/binder/query.go +++ b/binder/query.go @@ -1,9 +1,6 @@ package binder import ( - "reflect" - "strings" - "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) @@ -30,19 +27,7 @@ func (b *QueryBinding) Bind(reqCtx *fasthttp.Request, out any) error { k := utils.UnsafeString(key) v := utils.UnsafeString(val) - - if strings.Contains(k, "[") { - k, err = parseParamSquareBrackets(k) - } - - if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { - values := strings.Split(v, ",") - for i := 0; i < len(values); i++ { - data[k] = append(data[k], values[i]) - } - } else { - data[k] = append(data[k], v) - } + err = formatBindData(out, data, k, v, b.EnableSplitting, true) }) if err != nil { diff --git a/binder/resp_header.go b/binder/resp_header.go index fc84d01402..cb29e99d6f 100644 --- a/binder/resp_header.go +++ b/binder/resp_header.go @@ -1,9 +1,6 @@ package binder import ( - "reflect" - "strings" - "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) @@ -21,20 +18,22 @@ func (*RespHeaderBinding) Name() string { // Bind parses the response header and returns the result. func (b *RespHeaderBinding) Bind(resp *fasthttp.Response, out any) error { data := make(map[string][]string) + var err error + resp.Header.VisitAll(func(key, val []byte) { + if err != nil { + return + } + k := utils.UnsafeString(key) v := utils.UnsafeString(val) - - if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { - values := strings.Split(v, ",") - for i := 0; i < len(values); i++ { - data[k] = append(data[k], values[i]) - } - } else { - data[k] = append(data[k], v) - } + err = formatBindData(out, data, k, v, b.EnableSplitting, false) }) + if err != nil { + return err + } + return parse(b.Name(), out, data) } From d5771a34dfbb0285ff7920823ba686fd4cbb19f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9?= Date: Tue, 31 Dec 2024 17:00:40 +0100 Subject: [PATCH 07/20] prepare release 3.0.0-beta.4 --- app.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app.go b/app.go index 3810a8ec0c..e96642aed2 100644 --- a/app.go +++ b/app.go @@ -32,7 +32,7 @@ import ( ) // Version of current fiber package -const Version = "3.0.0-beta.3" +const Version = "3.0.0-beta.4" // Handler defines a function to serve HTTP requests. type Handler = func(Ctx) error From 5355869d4dad5b1779e4ff2fabfe490725f03f2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=2E=20Efe=20=C3=87etin?= Date: Thu, 2 Jan 2025 10:42:25 +0300 Subject: [PATCH 08/20] :bug: bug: make Render bind parameter type any again (#3270) * :bug: bug: make Render bind parameter type any again * update docs --- ctx.go | 2 +- ctx_interface_gen.go | 5 ++--- docs/api/ctx.md | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/ctx.go b/ctx.go index 789a5315ca..9d73c08bd4 100644 --- a/ctx.go +++ b/ctx.go @@ -1369,7 +1369,7 @@ func (c *DefaultCtx) GetRouteURL(routeName string, params Map) (string, error) { // Render a template with data and sends a text/html response. // We support the following engines: https://github.com/gofiber/template -func (c *DefaultCtx) Render(name string, bind Map, layouts ...string) error { +func (c *DefaultCtx) Render(name string, bind any, layouts ...string) error { // Get new buffer from pool buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) diff --git a/ctx_interface_gen.go b/ctx_interface_gen.go index cc48576efb..cd93c48905 100644 --- a/ctx_interface_gen.go +++ b/ctx_interface_gen.go @@ -12,8 +12,7 @@ import ( "github.com/valyala/fasthttp" ) -// Ctx represents the Context which hold the HTTP request and response. -// It has methods for the request query string, parameters, body, HTTP headers and so on. +// Ctx represents the Context which hold the HTTP request and response.\nIt has methods for the request query string, parameters, body, HTTP headers and so on. type Ctx interface { // Accepts checks if the specified extensions or content types are acceptable. Accepts(offers ...string) string @@ -263,7 +262,7 @@ type Ctx interface { GetRouteURL(routeName string, params Map) (string, error) // Render a template with data and sends a text/html response. // We support the following engines: https://github.com/gofiber/template - Render(name string, bind Map, layouts ...string) error + Render(name string, bind any, layouts ...string) error renderExtensions(bind any) // Route returns the matched Route struct. Route() *Route diff --git a/docs/api/ctx.md b/docs/api/ctx.md index 85532a540d..b65532532a 100644 --- a/docs/api/ctx.md +++ b/docs/api/ctx.md @@ -1506,7 +1506,7 @@ app.Get("/teapot", func(c fiber.Ctx) error { Renders a view with data and sends a `text/html` response. By default, `Render` uses the default [**Go Template engine**](https://pkg.go.dev/html/template/). If you want to use another view engine, please take a look at our [**Template middleware**](https://docs.gofiber.io/template). ```go title="Signature" -func (c fiber.Ctx) Render(name string, bind Map, layouts ...string) error +func (c fiber.Ctx) Render(name string, bind any, layouts ...string) error ``` ## Request From ac82b0c413e8116908b2ab6d19a13123133f2871 Mon Sep 17 00:00:00 2001 From: Giovanni Rivera Date: Thu, 2 Jan 2025 23:36:58 -0800 Subject: [PATCH 09/20] =?UTF-8?q?=F0=9F=93=9A=20Doc:=20Fix=20static=20midd?= =?UTF-8?q?leware=20CacheDuration=20data=20type=20typo=20(#3273)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/middleware/static.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/middleware/static.md b/docs/middleware/static.md index a8c7ec6093..61a955f96e 100644 --- a/docs/middleware/static.md +++ b/docs/middleware/static.md @@ -154,7 +154,7 @@ To define static routes using `Get`, append the wildcard (`*`) operator at the e | Browse | `bool` | When set to true, enables directory browsing. | `false` | | Download | `bool` | When set to true, enables direct download. | `false` | | IndexNames | `[]string` | The names of the index files for serving a directory. | `[]string{"index.html"}` | -| CacheDuration | `string` | Expiration duration for inactive file handlers.

Use a negative time.Duration to disable it. | `10 * time.Second` | +| CacheDuration | `time.Duration` | Expiration duration for inactive file handlers.

Use a negative time.Duration to disable it. | `10 * time.Second` | | MaxAge | `int` | The value for the Cache-Control HTTP-header that is set on the file response. MaxAge is defined in seconds. | `0` | | ModifyResponse | `fiber.Handler` | ModifyResponse defines a function that allows you to alter the response. | `nil` | | NotFoundHandler | `fiber.Handler` | NotFoundHandler defines a function to handle when the path is not found. | `nil` | From a95ffd8eff3fd699d67ff87740876eaf6ef1c7c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9?= Date: Mon, 6 Jan 2025 08:36:06 +0100 Subject: [PATCH 10/20] fix doc examples for generic function --- docs/guide/utils.md | 16 ++++++++-------- docs/whats_new.md | 8 ++++---- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/guide/utils.md b/docs/guide/utils.md index 24ed2362ba..1f3fec1c69 100644 --- a/docs/guide/utils.md +++ b/docs/guide/utils.md @@ -45,9 +45,9 @@ func GetReqHeader[V any](c Ctx, key string, defaultValue ...V) V ```go title="Example" app.Get("/search", func(c fiber.Ctx) error { // curl -X GET http://example.com/search -H "X-Request-ID: 12345" -H "X-Request-Name: John" - GetReqHeader[int](c, "X-Request-ID") // => returns 12345 as integer. - GetReqHeader[string](c, "X-Request-Name") // => returns "John" as string. - GetReqHeader[string](c, "unknownParam", "default") // => returns "default" as string. + fiber.GetReqHeader[int](c, "X-Request-ID") // => returns 12345 as integer. + fiber.GetReqHeader[string](c, "X-Request-Name") // => returns "John" as string. + fiber.GetReqHeader[string](c, "unknownParam", "default") // => returns "default" as string. // ... }) ``` @@ -97,8 +97,8 @@ func Params[V any](c Ctx, key string, defaultValue ...V) V ```go title="Example" app.Get("/user/:user/:id", func(c fiber.Ctx) error { // http://example.com/user/john/25 - Params[int](c, "id") // => returns 25 as integer. - Params[int](c, "unknownParam", 99) // => returns the default 99 as integer. + fiber.Params[int](c, "id") // => returns 25 as integer. + fiber.Params[int](c, "unknownParam", 99) // => returns the default 99 as integer. // ... return c.SendString("Hello, " + fiber.Params[string](c, "user")) }) @@ -116,9 +116,9 @@ func Query[V any](c Ctx, key string, defaultValue ...V) V ```go title="Example" app.Get("/search", func(c fiber.Ctx) error { // http://example.com/search?name=john&age=25 - Query[string](c, "name") // => returns "john" - Query[int](c, "age") // => returns 25 as integer. - Query[string](c, "unknownParam", "default") // => returns "default" as string. + fiber.Query[string](c, "name") // => returns "john" + fiber.Query[int](c, "age") // => returns 25 as integer. + fiber.Query[string](c, "unknownParam", "default") // => returns "default" as string. // ... }) ``` diff --git a/docs/whats_new.md b/docs/whats_new.md index 321df424d6..7501c2b57e 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -497,7 +497,7 @@ func main() { app := fiber.New() app.Get("/convert", func(c fiber.Ctx) error { - value, err := Convert[string](c.Query("value"), strconv.Atoi, 0) + value, err := fiber.Convert[string](c.Query("value"), strconv.Atoi, 0) if err != nil { return c.Status(fiber.StatusBadRequest).SendString(err.Error()) } @@ -575,7 +575,7 @@ func main() { app := fiber.New() app.Get("/params/:id", func(c fiber.Ctx) error { - id := Params[int](c, "id", 0) + id := fiber.Params[int](c, "id", 0) return c.JSON(id) }) @@ -607,7 +607,7 @@ func main() { app := fiber.New() app.Get("/query", func(c fiber.Ctx) error { - age := Query[int](c, "age", 0) + age := fiber.Query[int](c, "age", 0) return c.JSON(age) }) @@ -640,7 +640,7 @@ func main() { app := fiber.New() app.Get("/header", func(c fiber.Ctx) error { - userAgent := GetReqHeader[string](c, "User-Agent", "Unknown") + userAgent := fiber.GetReqHeader[string](c, "User-Agent", "Unknown") return c.JSON(userAgent) }) From 86d72bbba8b5998bb4583f3bff110ede1391d795 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 7 Jan 2025 19:20:53 +0100 Subject: [PATCH 11/20] build(deps): bump golang.org/x/crypto from 0.31.0 to 0.32.0 (#3274) Bumps [golang.org/x/crypto](https://github.com/golang/crypto) from 0.31.0 to 0.32.0. - [Commits](https://github.com/golang/crypto/compare/v0.31.0...v0.32.0) --- updated-dependencies: - dependency-name: golang.org/x/crypto dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 4 ++-- go.sum | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index ebdc9080e8..00e9dcc945 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/tinylib/msgp v1.2.5 github.com/valyala/bytebufferpool v1.0.0 github.com/valyala/fasthttp v1.58.0 - golang.org/x/crypto v0.31.0 + golang.org/x/crypto v0.32.0 ) require ( @@ -25,7 +25,7 @@ require ( github.com/valyala/tcplisten v1.0.0 // indirect github.com/x448/float16 v0.8.4 // indirect golang.org/x/net v0.31.0 // indirect - golang.org/x/sys v0.28.0 // indirect + golang.org/x/sys v0.29.0 // indirect golang.org/x/text v0.21.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 5b8204d1ee..3c58c2c59a 100644 --- a/go.sum +++ b/go.sum @@ -35,14 +35,14 @@ github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= -golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= -golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= +golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo= golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= -golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= From bc37f209bfef66d8ff84dc7d13dc066107eb3d9c Mon Sep 17 00:00:00 2001 From: RW Date: Wed, 8 Jan 2025 08:19:20 +0100 Subject: [PATCH 12/20] refactor(timeout): unify and enhance timeout middleware (#3275) * feat(timeout): unify and enhance timeout middleware - Combine classic context-based timeout with a Goroutine + channel approach - Support custom error list without additional parameters - Return fiber.ErrRequestTimeout for timeouts or listed errors * feat(timeout): unify and enhance timeout middleware - Combine classic context-based timeout with a Goroutine + channel approach - Support custom error list without additional parameters - Return fiber.ErrRequestTimeout for timeouts or listed errors * refactor(timeout): remove goroutine-based logic and improve documentation - Switch to a synchronous approach to avoid data races with fasthttp context - Enhance error handling for deadline and custom errors - Update comments for clarity and maintainability * refactor(timeout): add more test cases and handle zero duration case * refactor(timeout): add more test cases and handle zero duration case * refactor(timeout): add more test cases and handle zero duration case --------- Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> --- middleware/timeout/timeout.go | 57 +++++++--- middleware/timeout/timeout_test.go | 166 ++++++++++++++++++----------- 2 files changed, 147 insertions(+), 76 deletions(-) diff --git a/middleware/timeout/timeout.go b/middleware/timeout/timeout.go index a88f2e90b1..127fff8723 100644 --- a/middleware/timeout/timeout.go +++ b/middleware/timeout/timeout.go @@ -8,23 +8,52 @@ import ( "github.com/gofiber/fiber/v3" ) -// New implementation of timeout middleware. Set custom errors(context.DeadlineExceeded vs) for get fiber.ErrRequestTimeout response. -func New(h fiber.Handler, t time.Duration, tErrs ...error) fiber.Handler { +// New enforces a timeout for each incoming request. If the timeout expires or +// any of the specified errors occur, fiber.ErrRequestTimeout is returned. +func New(h fiber.Handler, timeout time.Duration, tErrs ...error) fiber.Handler { return func(ctx fiber.Ctx) error { - timeoutContext, cancel := context.WithTimeout(ctx.Context(), t) + // If timeout <= 0, skip context.WithTimeout and run the handler as-is. + if timeout <= 0 { + return runHandler(ctx, h, tErrs) + } + + // Create a context with the specified timeout; any operation exceeding + // this deadline will be canceled automatically. + timeoutContext, cancel := context.WithTimeout(ctx.Context(), timeout) defer cancel() + + // Replace the default Fiber context with our timeout-bound context. ctx.SetContext(timeoutContext) - if err := h(ctx); err != nil { - if errors.Is(err, context.DeadlineExceeded) { - return fiber.ErrRequestTimeout - } - for i := range tErrs { - if errors.Is(err, tErrs[i]) { - return fiber.ErrRequestTimeout - } - } - return err + + // Run the handler and check for relevant errors. + err := runHandler(ctx, h, tErrs) + + // If the context actually timed out, return a timeout error. + if errors.Is(timeoutContext.Err(), context.DeadlineExceeded) { + return fiber.ErrRequestTimeout + } + return err + } +} + +// runHandler executes the handler and returns fiber.ErrRequestTimeout if it +// sees a deadline exceeded error or one of the custom "timeout-like" errors. +func runHandler(c fiber.Ctx, h fiber.Handler, tErrs []error) error { + // Execute the wrapped handler synchronously. + err := h(c) + // If the context has timed out, return a request timeout error. + if err != nil && (errors.Is(err, context.DeadlineExceeded) || isCustomError(err, tErrs)) { + return fiber.ErrRequestTimeout + } + return err +} + +// isCustomError checks whether err matches any error in errList using errors.Is. +func isCustomError(err error, errList []error) bool { + for _, e := range errList { + if errors.Is(err, e) { + return true } - return nil } + return false } diff --git a/middleware/timeout/timeout_test.go b/middleware/timeout/timeout_test.go index 2e1756184c..161296a71a 100644 --- a/middleware/timeout/timeout_test.go +++ b/middleware/timeout/timeout_test.go @@ -12,77 +12,119 @@ import ( "github.com/stretchr/testify/require" ) -// go test -run Test_WithContextTimeout -func Test_WithContextTimeout(t *testing.T) { - t.Parallel() - // fiber instance - app := fiber.New() - h := New(func(c fiber.Ctx) error { - sleepTime, err := time.ParseDuration(c.Params("sleepTime") + "ms") - require.NoError(t, err) - if err := sleepWithContext(c.Context(), sleepTime, context.DeadlineExceeded); err != nil { - return fmt.Errorf("%w: l2 wrap", fmt.Errorf("%w: l1 wrap ", err)) - } +var ( + // Custom error that we treat like a timeout when returned by the handler. + errCustomTimeout = errors.New("custom timeout error") + + // Some unrelated error that should NOT trigger a request timeout. + errUnrelated = errors.New("unmatched error") +) + +// sleepWithContext simulates a task that takes `d` time, but returns `te` if the context is canceled. +func sleepWithContext(ctx context.Context, d time.Duration, te error) error { + timer := time.NewTimer(d) + defer timer.Stop() // Clean up the timer + + select { + case <-ctx.Done(): + return te + case <-timer.C: return nil - }, 100*time.Millisecond) - app.Get("/test/:sleepTime", h) - testTimeout := func(timeoutStr string) { - resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil)) - require.NoError(t, err, "app.Test(req)") - require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code") - } - testSucces := func(timeoutStr string) { - resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil)) - require.NoError(t, err, "app.Test(req)") - require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code") } - testTimeout("300") - testTimeout("500") - testSucces("50") - testSucces("30") } -var ErrFooTimeOut = errors.New("foo context canceled") +// TestTimeout_Success tests a handler that completes within the allotted timeout. +func TestTimeout_Success(t *testing.T) { + t.Parallel() + app := fiber.New() + + // Our middleware wraps a handler that sleeps for 10ms, well under the 50ms limit. + app.Get("/fast", New(func(c fiber.Ctx) error { + // Simulate some work + if err := sleepWithContext(c.Context(), 10*time.Millisecond, context.DeadlineExceeded); err != nil { + return err + } + return c.SendString("OK") + }, 50*time.Millisecond)) + + req := httptest.NewRequest(fiber.MethodGet, "/fast", nil) + resp, err := app.Test(req) + require.NoError(t, err, "app.Test(req) should not fail") + require.Equal(t, fiber.StatusOK, resp.StatusCode, "Expected 200 OK for fast requests") +} -// go test -run Test_WithContextTimeoutWithCustomError -func Test_WithContextTimeoutWithCustomError(t *testing.T) { +// TestTimeout_Exceeded tests a handler that exceeds the provided timeout. +func TestTimeout_Exceeded(t *testing.T) { t.Parallel() - // fiber instance app := fiber.New() - h := New(func(c fiber.Ctx) error { - sleepTime, err := time.ParseDuration(c.Params("sleepTime") + "ms") - require.NoError(t, err) - if err := sleepWithContext(c.Context(), sleepTime, ErrFooTimeOut); err != nil { - return fmt.Errorf("%w: execution error", err) + + // This handler sleeps 200ms, exceeding the 100ms limit. + app.Get("/slow", New(func(c fiber.Ctx) error { + if err := sleepWithContext(c.Context(), 200*time.Millisecond, context.DeadlineExceeded); err != nil { + return err } - return nil - }, 100*time.Millisecond, ErrFooTimeOut) - app.Get("/test/:sleepTime", h) - testTimeout := func(timeoutStr string) { - resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil)) - require.NoError(t, err, "app.Test(req)") - require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code") - } - testSucces := func(timeoutStr string) { - resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil)) - require.NoError(t, err, "app.Test(req)") - require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code") - } - testTimeout("300") - testTimeout("500") - testSucces("50") - testSucces("30") + return c.SendString("Should never get here") + }, 100*time.Millisecond)) + + req := httptest.NewRequest(fiber.MethodGet, "/slow", nil) + resp, err := app.Test(req) + require.NoError(t, err, "app.Test(req) should not fail") + require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Expected 408 Request Timeout") } -func sleepWithContext(ctx context.Context, d time.Duration, te error) error { - timer := time.NewTimer(d) - select { - case <-ctx.Done(): - if !timer.Stop() { - <-timer.C +// TestTimeout_CustomError tests that returning a user-defined error is also treated as a timeout. +func TestTimeout_CustomError(t *testing.T) { + t.Parallel() + app := fiber.New() + + // This handler sleeps 50ms and returns errCustomTimeout if canceled. + app.Get("/custom", New(func(c fiber.Ctx) error { + // Sleep might time out, or might return early. If the context is canceled, + // we treat errCustomTimeout as a 'timeout-like' condition. + if err := sleepWithContext(c.Context(), 200*time.Millisecond, errCustomTimeout); err != nil { + return fmt.Errorf("wrapped: %w", err) } - return te - case <-timer.C: - } - return nil + return c.SendString("Should never get here") + }, 100*time.Millisecond, errCustomTimeout)) + + req := httptest.NewRequest(fiber.MethodGet, "/custom", nil) + resp, err := app.Test(req) + require.NoError(t, err, "app.Test(req) should not fail") + require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Expected 408 for custom timeout error") +} + +// TestTimeout_UnmatchedError checks that if the handler returns an error +// that is neither a deadline exceeded nor a custom 'timeout' error, it is +// propagated as a regular 500 (internal server error). +func TestTimeout_UnmatchedError(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Get("/unmatched", New(func(_ fiber.Ctx) error { + return errUnrelated // Not in the custom error list + }, 100*time.Millisecond, errCustomTimeout)) + + req := httptest.NewRequest(fiber.MethodGet, "/unmatched", nil) + resp, err := app.Test(req) + require.NoError(t, err, "app.Test(req) should not fail") + require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode, + "Expected 500 because the error is not recognized as a timeout error") +} + +// TestTimeout_ZeroDuration tests the edge case where the timeout is set to zero. +// Usually this means the request can never exceed a 'deadline' – effectively no timeout. +func TestTimeout_ZeroDuration(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Get("/zero", New(func(c fiber.Ctx) error { + // Sleep 50ms, but there's no real 'deadline' since zero-timeout. + time.Sleep(50 * time.Millisecond) + return c.SendString("No timeout used") + }, 0)) + + req := httptest.NewRequest(fiber.MethodGet, "/zero", nil) + resp, err := app.Test(req) + require.NoError(t, err, "app.Test(req) should not fail") + require.Equal(t, fiber.StatusOK, resp.StatusCode, "Expected 200 OK with zero timeout") } From 6c7473b842148a0a59bc2c0d511f330373fa3fde Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 10 Jan 2025 12:46:21 +0000 Subject: [PATCH 13/20] build(deps): bump github.com/mattn/go-colorable from 0.1.13 to 0.1.14 Bumps [github.com/mattn/go-colorable](https://github.com/mattn/go-colorable) from 0.1.13 to 0.1.14. - [Commits](https://github.com/mattn/go-colorable/compare/v0.1.13...v0.1.14) --- updated-dependencies: - dependency-name: github.com/mattn/go-colorable dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- go.mod | 2 +- go.sum | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index 00e9dcc945..8e898af058 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/gofiber/schema v1.2.0 github.com/gofiber/utils/v2 v2.0.0-beta.7 github.com/google/uuid v1.6.0 - github.com/mattn/go-colorable v0.1.13 + github.com/mattn/go-colorable v0.1.14 github.com/mattn/go-isatty v0.0.20 github.com/stretchr/testify v1.10.0 github.com/tinylib/msgp v1.2.5 diff --git a/go.sum b/go.sum index 3c58c2c59a..db980e3314 100644 --- a/go.sum +++ b/go.sum @@ -12,9 +12,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= -github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= -github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= -github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c h1:dAMKvw0MlJT1GshSTtih8C2gDs04w8dReiOGXrGLNoY= @@ -39,7 +38,6 @@ golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo= golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= From 4e5fea1d7a830a52cbe736d2da88660a2e959549 Mon Sep 17 00:00:00 2001 From: Giovanni Rivera Date: Mon, 13 Jan 2025 05:18:03 -0800 Subject: [PATCH 14/20] =?UTF-8?q?=F0=9F=A9=B9=20Fix:=20Fix=20app.Test()=20?= =?UTF-8?q?auto-failing=20when=20a=20connection=20is=20closed=20early=20(#?= =?UTF-8?q?3279)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ♻️ Refactor: Extract testConn err to variable * ♻️ Refactor: Extract ErrTestGotEmptyResponse from app.Test() * 🩹 Fix: Fix `app.Test()` auto-failing when testConn is closed * 🩹 Fix(app_test.go): Use tab for indent instead of spaces * 🩹 Fix(app_test.go): Fix to respect gofmt linter * ♻️ Refactor: Update Drop tests to verify error type --- app.go | 6 ++++-- app_test.go | 19 +++++++++++++++++-- ctx_test.go | 4 ++-- helpers.go | 4 +++- helpers_test.go | 2 +- 5 files changed, 27 insertions(+), 8 deletions(-) diff --git a/app.go b/app.go index e96642aed2..dca7efed00 100644 --- a/app.go +++ b/app.go @@ -941,6 +941,8 @@ func (app *App) Hooks() *Hooks { return app.hooks } +var ErrTestGotEmptyResponse = errors.New("test: got empty response") + // TestConfig is a struct holding Test settings type TestConfig struct { // Timeout defines the maximum duration a @@ -1022,7 +1024,7 @@ func (app *App) Test(req *http.Request, config ...TestConfig) (*http.Response, e } // Check for errors - if err != nil && !errors.Is(err, fasthttp.ErrGetOnly) { + if err != nil && !errors.Is(err, fasthttp.ErrGetOnly) && !errors.Is(err, errTestConnClosed) { return nil, err } @@ -1033,7 +1035,7 @@ func (app *App) Test(req *http.Request, config ...TestConfig) (*http.Response, e res, err := http.ReadResponse(buffer, req) if err != nil { if errors.Is(err, io.ErrUnexpectedEOF) { - return nil, errors.New("test: got empty response") + return nil, ErrTestGotEmptyResponse } return nil, fmt.Errorf("failed to read response: %w", err) } diff --git a/app_test.go b/app_test.go index 8455ded86e..1b2b7a40d9 100644 --- a/app_test.go +++ b/app_test.go @@ -1491,7 +1491,7 @@ func Test_App_Test_timeout(t *testing.T) { Timeout: 100 * time.Millisecond, FailOnTimeout: true, }) - require.Equal(t, os.ErrDeadlineExceeded, err) + require.ErrorIs(t, err, os.ErrDeadlineExceeded) } func Test_App_Test_timeout_empty_response(t *testing.T) { @@ -1507,7 +1507,22 @@ func Test_App_Test_timeout_empty_response(t *testing.T) { Timeout: 100 * time.Millisecond, FailOnTimeout: false, }) - require.Equal(t, errors.New("test: got empty response"), err) + require.ErrorIs(t, err, ErrTestGotEmptyResponse) +} + +func Test_App_Test_drop_empty_response(t *testing.T) { + t.Parallel() + + app := New() + app.Get("/", func(c Ctx) error { + return c.Drop() + }) + + _, err := app.Test(httptest.NewRequest(MethodGet, "/", nil), TestConfig{ + Timeout: 0, + FailOnTimeout: false, + }) + require.ErrorIs(t, err, ErrTestGotEmptyResponse) } func Test_App_SetTLSHandler(t *testing.T) { diff --git a/ctx_test.go b/ctx_test.go index eb81876e37..f094a2c494 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -5896,7 +5896,7 @@ func Test_Ctx_Drop(t *testing.T) { // Test the Drop method resp, err := app.Test(httptest.NewRequest(MethodGet, "/block-me", nil)) - require.Error(t, err) + require.ErrorIs(t, err, ErrTestGotEmptyResponse) require.Nil(t, resp) // Test the no-response handler @@ -5927,7 +5927,7 @@ func Test_Ctx_DropWithMiddleware(t *testing.T) { // Test the Drop method resp, err := app.Test(httptest.NewRequest(MethodGet, "/block-me", nil)) - require.Error(t, err) + require.ErrorIs(t, err, ErrTestGotEmptyResponse) require.Nil(t, resp) } diff --git a/helpers.go b/helpers.go index 526074032a..04a1da6907 100644 --- a/helpers.go +++ b/helpers.go @@ -612,6 +612,8 @@ func isNoCache(cacheControl string) bool { return true } +var errTestConnClosed = errors.New("testConn is closed") + type testConn struct { r bytes.Buffer w bytes.Buffer @@ -631,7 +633,7 @@ func (c *testConn) Write(b []byte) (int, error) { defer c.Unlock() if c.isClosed { - return 0, errors.New("testConn is closed") + return 0, errTestConnClosed } return c.w.Write(b) //nolint:wrapcheck // This must not be wrapped } diff --git a/helpers_test.go b/helpers_test.go index 28a5df2ae7..a3f631bbfd 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -548,7 +548,7 @@ func Test_Utils_TestConn_Closed_Write(t *testing.T) { // Close early, write should fail conn.Close() //nolint:errcheck, revive // It is fine to ignore the error here _, err = conn.Write([]byte("Response 2\n")) - require.Error(t, err) + require.ErrorIs(t, err, errTestConnClosed) res := make([]byte, 11) _, err = conn.w.Read(res) From a42ddc100e5739ce72e20533dfee634dcc70ef9e Mon Sep 17 00:00:00 2001 From: Giovanni Rivera Date: Thu, 16 Jan 2025 02:54:46 -0800 Subject: [PATCH 15/20] =?UTF-8?q?=F0=9F=94=A5=20feat:=20Add=20End()=20meth?= =?UTF-8?q?od=20to=20Ctx=20(#3280)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🔥 Feature(v3): Add End() method to Ctx * :art: Style(Ctx): Respect linter in tests * 🚨 Test(End): Add timeout test for c.End() * 📚 Doc: Update End() documentation examples to use 4 spaces * 🚨 Test: Update `c.End()` tests to use StatusOK --------- Co-authored-by: Giovanni Rivera Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> --- ctx.go | 17 ++++++++++ ctx_interface_gen.go | 5 ++- ctx_test.go | 77 ++++++++++++++++++++++++++++++++++++++++++++ docs/api/ctx.md | 48 +++++++++++++++++++++++++++ docs/whats_new.md | 36 +++++++++++++++++++++ 5 files changed, 182 insertions(+), 1 deletion(-) diff --git a/ctx.go b/ctx.go index 9d73c08bd4..3af6b600e7 100644 --- a/ctx.go +++ b/ctx.go @@ -1986,3 +1986,20 @@ func (c *DefaultCtx) Drop() error { //nolint:wrapcheck // error wrapping is avoided to keep the operation lightweight and focused on connection closure. return c.RequestCtx().Conn().Close() } + +// End immediately flushes the current response and closes the underlying connection. +func (c *DefaultCtx) End() error { + ctx := c.RequestCtx() + conn := ctx.Conn() + + bw := bufio.NewWriter(conn) + if err := ctx.Response.Write(bw); err != nil { + return err + } + + if err := bw.Flush(); err != nil { + return err //nolint:wrapcheck // unnecessary to wrap it + } + + return conn.Close() //nolint:wrapcheck // unnecessary to wrap it +} diff --git a/ctx_interface_gen.go b/ctx_interface_gen.go index cd93c48905..101068a269 100644 --- a/ctx_interface_gen.go +++ b/ctx_interface_gen.go @@ -12,7 +12,8 @@ import ( "github.com/valyala/fasthttp" ) -// Ctx represents the Context which hold the HTTP request and response.\nIt has methods for the request query string, parameters, body, HTTP headers and so on. +// Ctx represents the Context which hold the HTTP request and response. +// It has methods for the request query string, parameters, body, HTTP headers and so on. type Ctx interface { // Accepts checks if the specified extensions or content types are acceptable. Accepts(offers ...string) string @@ -353,4 +354,6 @@ type Ctx interface { // This can be useful for silently terminating client connections, such as in DDoS mitigation // or when blocking access to sensitive endpoints. Drop() error + // End immediately flushes the current response and closes the underlying connection. + End() error } diff --git a/ctx_test.go b/ctx_test.go index f094a2c494..af3088662c 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -5931,6 +5931,83 @@ func Test_Ctx_DropWithMiddleware(t *testing.T) { require.Nil(t, resp) } +// go test -run Test_Ctx_End +func Test_Ctx_End(t *testing.T) { + app := New() + + app.Get("/", func(c Ctx) error { + c.SendString("Hello, World!") //nolint:errcheck // unnecessary to check error + return c.End() + }) + + resp, err := app.Test(httptest.NewRequest(MethodGet, "/", nil)) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, StatusOK, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err, "io.ReadAll(resp.Body)") + require.Equal(t, "Hello, World!", string(body)) +} + +// go test -run Test_Ctx_End_after_timeout +func Test_Ctx_End_after_timeout(t *testing.T) { + app := New() + + // Early flushing handler + app.Get("/", func(c Ctx) error { + time.Sleep(2 * time.Second) + return c.End() + }) + + resp, err := app.Test(httptest.NewRequest(MethodGet, "/", nil)) + require.ErrorIs(t, err, os.ErrDeadlineExceeded) + require.Nil(t, resp) +} + +// go test -run Test_Ctx_End_with_drop_middleware +func Test_Ctx_End_with_drop_middleware(t *testing.T) { + app := New() + + // Middleware that will drop connections + // that persist after c.Next() + app.Use(func(c Ctx) error { + c.Next() //nolint:errcheck // unnecessary to check error + return c.Drop() + }) + + // Early flushing handler + app.Get("/", func(c Ctx) error { + c.SendStatus(StatusOK) //nolint:errcheck // unnecessary to check error + return c.End() + }) + + resp, err := app.Test(httptest.NewRequest(MethodGet, "/", nil)) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, StatusOK, resp.StatusCode) +} + +// go test -run Test_Ctx_End_after_drop +func Test_Ctx_End_after_drop(t *testing.T) { + app := New() + + // Middleware that ends the request + // after c.Next() + app.Use(func(c Ctx) error { + c.Next() //nolint:errcheck // unnecessary to check error + return c.End() + }) + + // Early flushing handler + app.Get("/", func(c Ctx) error { + return c.Drop() + }) + + resp, err := app.Test(httptest.NewRequest(MethodGet, "/", nil)) + require.ErrorIs(t, err, ErrTestGotEmptyResponse) + require.Nil(t, resp) +} + // go test -run Test_GenericParseTypeString func Test_GenericParseTypeString(t *testing.T) { t.Parallel() diff --git a/docs/api/ctx.md b/docs/api/ctx.md index b65532532a..fda9f37328 100644 --- a/docs/api/ctx.md +++ b/docs/api/ctx.md @@ -484,6 +484,54 @@ app.Get("/", func(c fiber.Ctx) error { }) ``` +## End + +End immediately flushes the current response and closes the underlying connection. + +```go title="Signature" +func (c fiber.Ctx) End() error +``` + +```go title="Example" +app.Get("/", func(c fiber.Ctx) error { + c.SendString("Hello World!") + return c.End() +}) +``` + +:::caution +Calling `c.End()` will disallow further writes to the underlying connection. +::: + +End can be used to stop a middleware from modifying a response of a handler/other middleware down the method chain +when they regain control after calling `c.Next()`. + +```go title="Example" +// Error Logging/Responding middleware +app.Use(func(c fiber.Ctx) error { + err := c.Next() + + // Log errors & write the error to the response + if err != nil { + log.Printf("Got error in middleware: %v", err) + return c.Writef("(got error %v)", err) + } + + // No errors occured + return nil +}) + +// Handler with simulated error +app.Get("/", func(c fiber.Ctx) error { + // Closes the connection instantly after writing from this handler + // and disallow further modification of its response + defer c.End() + + c.SendString("Hello, ... I forgot what comes next!") + return errors.New("some error") +}) +``` + ## Format Performs content-negotiation on the [Accept](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept) HTTP header. It uses [Accepts](ctx.md#accepts) to select a proper format from the supplied offers. A default handler can be provided by setting the `MediaType` to `"default"`. If no offers match and no default is provided, a 406 (Not Acceptable) response is sent. The Content-Type is automatically set when a handler is selected. diff --git a/docs/whats_new.md b/docs/whats_new.md index 7501c2b57e..0c4d749cd7 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -341,6 +341,7 @@ testConfig := fiber.TestConfig{ - **String**: Similar to Express.js, converts a value to a string. - **ViewBind**: Binds data to a view, replacing the old `Bind` method. - **CBOR**: Introducing [CBOR](https://cbor.io/) binary encoding format for both request & response body. CBOR is a binary data serialization format which is both compact and efficient, making it ideal for use in web applications. +- **End**: Similar to Express.js, immediately flushes the current response and closes the underlying connection. ### Removed Methods @@ -403,6 +404,41 @@ app.Get("/sse", func(c fiber.Ctx) { You can find more details about this feature in [/docs/api/ctx.md](./api/ctx.md). +### End + +In v3, we introduced a new method to match the Express.js API's `res.end()` method. + +```go +func (c Ctx) End() +``` + +With this method, you can: + +- Stop middleware from controlling the connection after a handler further up the method chain + by immediately flushing the current response and closing the connection. +- Use `return c.End()` as an alternative to `return nil` + +```go +app.Use(func (c fiber.Ctx) error { + err := c.Next() + if err != nil { + log.Println("Got error: %v", err) + return c.SendString(err.Error()) // Will be unsuccessful since the response ended below + } + return nil +}) + +app.Get("/hello", func (c fiber.Ctx) error { + query := c.Query("name", "") + if query == "" { + c.SendString("You don't have a name?") + c.End() // Closes the underlying connection + return errors.New("No name provided") + } + return c.SendString("Hello, " + query + "!") +}) +``` + --- ## 🌎 Client package From b31184e0b75b0af1fce38dce2e4e63ccc321c025 Mon Sep 17 00:00:00 2001 From: aliziyacevik Date: Sat, 18 Jan 2025 14:31:55 +0300 Subject: [PATCH 16/20] Doc: Added missing ctx.Drop() to whats_new.md --- docs/whats_new.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/whats_new.md b/docs/whats_new.md index 0c4d749cd7..81339d6a48 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -342,6 +342,7 @@ testConfig := fiber.TestConfig{ - **ViewBind**: Binds data to a view, replacing the old `Bind` method. - **CBOR**: Introducing [CBOR](https://cbor.io/) binary encoding format for both request & response body. CBOR is a binary data serialization format which is both compact and efficient, making it ideal for use in web applications. - **End**: Similar to Express.js, immediately flushes the current response and closes the underlying connection. +- **Drop**: Terminates the client connection silently without sending any HTTP headers or response body. This can be used for scenarios where you want to block certain requests without notifying the client, such as mitigating DDoS attacks or protecting sensitive endpoints from unauthorized access. ### Removed Methods From 1dedc8034cb85f9a2dd36493e4af5368f555ba1e Mon Sep 17 00:00:00 2001 From: Giovanni Rivera Date: Sat, 18 Jan 2025 16:53:12 -0800 Subject: [PATCH 17/20] =?UTF-8?q?=F0=9F=93=9A=20Docs:=20Add=20`c.Drop()`?= =?UTF-8?q?=20example=20to=20`whats=5Fnew.md`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add Drop section with an example in `whats_new.md` - Reorder `c.Drop()` and `c.End()` to match source code order in `whats_new.md` --- docs/whats_new.md | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/docs/whats_new.md b/docs/whats_new.md index 81339d6a48..3b3a43e94c 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -341,8 +341,8 @@ testConfig := fiber.TestConfig{ - **String**: Similar to Express.js, converts a value to a string. - **ViewBind**: Binds data to a view, replacing the old `Bind` method. - **CBOR**: Introducing [CBOR](https://cbor.io/) binary encoding format for both request & response body. CBOR is a binary data serialization format which is both compact and efficient, making it ideal for use in web applications. -- **End**: Similar to Express.js, immediately flushes the current response and closes the underlying connection. - **Drop**: Terminates the client connection silently without sending any HTTP headers or response body. This can be used for scenarios where you want to block certain requests without notifying the client, such as mitigating DDoS attacks or protecting sensitive endpoints from unauthorized access. +- **End**: Similar to Express.js, immediately flushes the current response and closes the underlying connection. ### Removed Methods @@ -405,6 +405,37 @@ app.Get("/sse", func(c fiber.Ctx) { You can find more details about this feature in [/docs/api/ctx.md](./api/ctx.md). +### Drop + +In v3, we introduced support to silently terminate requests through `Drop`. + +```go +func (c Ctx) Drop() +``` + +With this method, you can: + +- Block certain requests without notifying the client to mitigate DDoS attacks +- Protect sensitive endpoints from unauthorized access without leaking errors. + +:::caution +While this feature adds the ability to drop connections, it is still **highly recommended** to use additional +measures (such as **firewalls**, **proxies**, etc.) to further protect your server endpoints by blocking +malicious connections before the server establishes a connection. +::: + +```go +app.Get("/", func(c fiber.Ctx) error { + if c.IP() == "192.168.1.1" { + return c.Drop() + } + + return c.SendString("Hello World!") +}) +``` + +You can find more details about this feature in [/docs/api/ctx.md](./api/ctx.md). + ### End In v3, we introduced a new method to match the Express.js API's `res.end()` method. From 927e3b3266c1f560480619dcb793a9633ad44f5e Mon Sep 17 00:00:00 2001 From: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> Date: Sun, 19 Jan 2025 02:38:34 -0500 Subject: [PATCH 18/20] Update whats_new.md --- docs/whats_new.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/whats_new.md b/docs/whats_new.md index 3b3a43e94c..d9e857181a 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -426,11 +426,11 @@ malicious connections before the server establishes a connection. ```go app.Get("/", func(c fiber.Ctx) error { - if c.IP() == "192.168.1.1" { - return c.Drop() - } + if c.IP() == "192.168.1.1" { + return c.Drop() + } - return c.SendString("Hello World!") + return c.SendString("Hello World!") }) ``` From 8970f515dd41958d23ac39edb56c8dc72baa7f82 Mon Sep 17 00:00:00 2001 From: miyamo2 <79917704+miyamo2@users.noreply.github.com> Date: Mon, 20 Jan 2025 16:22:51 +0900 Subject: [PATCH 19/20] =?UTF-8?q?=F0=9F=90=9B=20fix:=20Align=20cache=20mid?= =?UTF-8?q?dleware=20with=20RFC7231=20(#3283)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🩹 Fix(v3;middleware/cache): don't cache if status code is not cacheable * allow 418 TeaPot * fix test * fix lint error * check cacheability with map * documentation * fix: markdown lint --------- Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> --- docs/middleware/cache.md | 25 +++++++++++ docs/whats_new.md | 3 +- middleware/cache/cache.go | 21 +++++++++ middleware/cache/cache_test.go | 81 ++++++++++++++++++++++++++++++++++ 4 files changed, 129 insertions(+), 1 deletion(-) diff --git a/docs/middleware/cache.md b/docs/middleware/cache.md index 0723c615dc..08c7ad5989 100644 --- a/docs/middleware/cache.md +++ b/docs/middleware/cache.md @@ -10,6 +10,31 @@ Request Directives
`Cache-Control: no-cache` will return the up-to-date response but still caches it. You will always get a `miss` cache status.
`Cache-Control: no-store` will refrain from caching. You will always get the up-to-date response. +Cacheable Status Codes
+ +This middleware caches responses with the following status codes according to RFC7231: + +- `200: OK` +- `203: Non-Authoritative Information` +- `204: No Content` +- `206: Partial Content` +- `300: Multiple Choices` +- `301: Moved Permanently` +- `404: Not Found` +- `405: Method Not Allowed` +- `410: Gone` +- `414: URI Too Long` +- `501: Not Implemented` + +Additionally, `418: I'm a teapot` is not originally cacheable but is cached by this middleware. +If the status code is other than these, you will always get an `unreachable` cache status. + +For more information about cacheable status codes or RFC7231, please refer to the following resources: + +- [Cacheable - MDN Web Docs](https://developer.mozilla.org/en-US/docs/Glossary/Cacheable) + +- [RFC7231 - Hypertext Transfer Protocol (HTTP/1.1): Semantics and Content](https://datatracker.ietf.org/doc/html/rfc7231) + ## Signatures ```go diff --git a/docs/whats_new.md b/docs/whats_new.md index d9e857181a..57e4bf0ba2 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -764,7 +764,8 @@ The adaptor middleware has been significantly optimized for performance and effi ### Cache -We are excited to introduce a new option in our caching middleware: Cache Invalidator. This feature provides greater control over cache management, allowing you to define a custom conditions for invalidating cache entries. +We are excited to introduce a new option in our caching middleware: Cache Invalidator. This feature provides greater control over cache management, allowing you to define a custom conditions for invalidating cache entries. +Additionally, the caching middleware has been optimized to avoid caching non-cacheable status codes, as defined by the [HTTP standards](https://datatracker.ietf.org/doc/html/rfc7231#section-6.1). This improvement enhances cache accuracy and reduces unnecessary cache storage usage. ### CORS diff --git a/middleware/cache/cache.go b/middleware/cache/cache.go index 5c832f0b96..723b5321e2 100644 --- a/middleware/cache/cache.go +++ b/middleware/cache/cache.go @@ -48,6 +48,21 @@ var ignoreHeaders = map[string]any{ "Content-Encoding": nil, // already stored explicitly by the cache manager } +var cacheableStatusCodes = map[int]bool{ + fiber.StatusOK: true, + fiber.StatusNonAuthoritativeInformation: true, + fiber.StatusNoContent: true, + fiber.StatusPartialContent: true, + fiber.StatusMultipleChoices: true, + fiber.StatusMovedPermanently: true, + fiber.StatusNotFound: true, + fiber.StatusMethodNotAllowed: true, + fiber.StatusGone: true, + fiber.StatusRequestURITooLong: true, + fiber.StatusTeapot: true, + fiber.StatusNotImplemented: true, +} + // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config @@ -170,6 +185,12 @@ func New(config ...Config) fiber.Handler { return err } + // Don't cache response if status code is not cacheable + if !cacheableStatusCodes[c.Response().StatusCode()] { + c.Set(cfg.CacheHeader, cacheUnreachable) + return nil + } + // lock entry back and unlock on finish mux.Lock() defer mux.Unlock() diff --git a/middleware/cache/cache_test.go b/middleware/cache/cache_test.go index 22ab0e2895..2193decb25 100644 --- a/middleware/cache/cache_test.go +++ b/middleware/cache/cache_test.go @@ -918,6 +918,87 @@ func Test_Cache_MaxBytesSizes(t *testing.T) { } } +func Test_Cache_UncacheableStatusCodes(t *testing.T) { + t.Parallel() + app := fiber.New() + app.Use(New()) + + app.Get("/:statusCode", func(c fiber.Ctx) error { + statusCode, err := strconv.Atoi(c.Params("statusCode")) + require.NoError(t, err) + return c.Status(statusCode).SendString("foo") + }) + + uncacheableStatusCodes := []int{ + // Informational responses + fiber.StatusContinue, + fiber.StatusSwitchingProtocols, + fiber.StatusProcessing, + fiber.StatusEarlyHints, + + // Successful responses + fiber.StatusCreated, + fiber.StatusAccepted, + fiber.StatusResetContent, + fiber.StatusMultiStatus, + fiber.StatusAlreadyReported, + fiber.StatusIMUsed, + + // Redirection responses + fiber.StatusFound, + fiber.StatusSeeOther, + fiber.StatusNotModified, + fiber.StatusUseProxy, + fiber.StatusSwitchProxy, + fiber.StatusTemporaryRedirect, + fiber.StatusPermanentRedirect, + + // Client error responses + fiber.StatusBadRequest, + fiber.StatusUnauthorized, + fiber.StatusPaymentRequired, + fiber.StatusForbidden, + fiber.StatusNotAcceptable, + fiber.StatusProxyAuthRequired, + fiber.StatusRequestTimeout, + fiber.StatusConflict, + fiber.StatusLengthRequired, + fiber.StatusPreconditionFailed, + fiber.StatusRequestEntityTooLarge, + fiber.StatusUnsupportedMediaType, + fiber.StatusRequestedRangeNotSatisfiable, + fiber.StatusExpectationFailed, + fiber.StatusMisdirectedRequest, + fiber.StatusUnprocessableEntity, + fiber.StatusLocked, + fiber.StatusFailedDependency, + fiber.StatusTooEarly, + fiber.StatusUpgradeRequired, + fiber.StatusPreconditionRequired, + fiber.StatusTooManyRequests, + fiber.StatusRequestHeaderFieldsTooLarge, + fiber.StatusUnavailableForLegalReasons, + + // Server error responses + fiber.StatusInternalServerError, + fiber.StatusBadGateway, + fiber.StatusServiceUnavailable, + fiber.StatusGatewayTimeout, + fiber.StatusHTTPVersionNotSupported, + fiber.StatusVariantAlsoNegotiates, + fiber.StatusInsufficientStorage, + fiber.StatusLoopDetected, + fiber.StatusNotExtended, + fiber.StatusNetworkAuthenticationRequired, + } + for _, v := range uncacheableStatusCodes { + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, fmt.Sprintf("/%d", v), nil)) + require.NoError(t, err) + require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache")) + require.Equal(t, v, resp.StatusCode) + } +} + // go test -v -run=^$ -bench=Benchmark_Cache -benchmem -count=4 func Benchmark_Cache(b *testing.B) { app := fiber.New() From 2eb6808e297651813b371ed7fabd325f703805a2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 24 Jan 2025 13:38:24 +0100 Subject: [PATCH 20/20] build(deps): bump codecov/codecov-action from 5.1.2 to 5.3.0 (#3292) Bumps [codecov/codecov-action](https://github.com/codecov/codecov-action) from 5.1.2 to 5.3.0. - [Release notes](https://github.com/codecov/codecov-action/releases) - [Changelog](https://github.com/codecov/codecov-action/blob/main/CHANGELOG.md) - [Commits](https://github.com/codecov/codecov-action/compare/v5.1.2...v5.3.0) --- updated-dependencies: - dependency-name: codecov/codecov-action dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 54143ff425..8c2cbb123e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -32,7 +32,7 @@ jobs: - name: Upload coverage reports to Codecov if: ${{ matrix.platform == 'ubuntu-latest' && matrix.go-version == '1.23.x' }} - uses: codecov/codecov-action@v5.1.2 + uses: codecov/codecov-action@v5.3.0 with: token: ${{ secrets.CODECOV_TOKEN }} file: ./coverage.txt