Skip to content

Commit

Permalink
different copies for key and req objects in http send executor
Browse files Browse the repository at this point in the history
Signed-off-by: Rudrakh Panigrahi <[email protected]>
  • Loading branch information
rudrakhp committed May 26, 2024
1 parent 4bcaa77 commit 85d4af0
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 41 deletions.
86 changes: 47 additions & 39 deletions topdown/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,13 @@ func getHTTPResponse(bctx BuiltinContext, req ast.Object) (*ast.Term, error) {
if err != nil {
return nil, err
}
reqExecutor, err := newHTTPRequestExecutor(bctx, key)

reqExecutor, err := newHTTPRequestExecutor(bctx, req, key)
if err != nil {
return nil, err
}

// Check if cache already has a response for this query
// set headers to exclude cache_ignored_headers
resp, err := reqExecutor.CheckCache()
if err != nil {
return nil, err
Expand Down Expand Up @@ -206,33 +207,38 @@ func getHTTPResponse(bctx BuiltinContext, req ast.Object) (*ast.Term, error) {
// getKeyFromRequest returns a key to be used for caching HTTP responses
// deletes headers from request object mentioned in cache_ignored_headers
func getKeyFromRequest(req ast.Object) (ast.Object, error) {
var cacheIgnoredHeaders []string
var allHeaders map[string]interface{}
// new copy so headers in request object doesn't change
key := req.Copy()
cacheIgnoredHeadersTerm := req.Get(ast.StringTerm("cache_ignored_headers"))
allHeadersTerm := req.Get(ast.StringTerm("headers"))
if cacheIgnoredHeadersTerm != nil && allHeadersTerm != nil {
err := ast.As(cacheIgnoredHeadersTerm.Value, &cacheIgnoredHeaders)
if err != nil {
return nil, err
}
err = ast.As(allHeadersTerm.Value, &allHeaders)
if err != nil {
return nil, err
}
for _, header := range cacheIgnoredHeaders {
delete(allHeaders, header)
}
val, err := ast.InterfaceToValue(allHeaders)
if err != nil {
return nil, err
}
allHeadersTerm.Value = val
req.Insert(ast.StringTerm("headers"), allHeadersTerm)
// skip because no headers to delete
if cacheIgnoredHeadersTerm == nil || allHeadersTerm == nil {
// need to explicitly set cache_ignored_headers to null
// equivalent requests might have different sets of exclusion lists
key.Insert(ast.StringTerm("cache_ignored_headers"), ast.NullTerm())
return key, nil
}
if cacheIgnoredHeadersTerm != nil {
req.Insert(ast.StringTerm("cache_ignored_headers"), ast.NullTerm())
var cacheIgnoredHeaders []string
var allHeaders map[string]interface{}
err := ast.As(cacheIgnoredHeadersTerm.Value, &cacheIgnoredHeaders)
if err != nil {
return nil, err
}
return req, nil
err = ast.As(allHeadersTerm.Value, &allHeaders)
if err != nil {
return nil, err
}
for _, header := range cacheIgnoredHeaders {
delete(allHeaders, header)
}
val, err := ast.InterfaceToValue(allHeaders)
if err != nil {
return nil, err
}
key.Insert(ast.StringTerm("headers"), ast.NewTerm(val))
// remove cache_ignored_headers key
key.Insert(ast.StringTerm("cache_ignored_headers"), ast.NullTerm())
return key, nil
}

func init() {
Expand Down Expand Up @@ -766,13 +772,13 @@ func newHTTPSendCache() *httpSendCache {
}

func valueHash(v util.T) int {
return v.(ast.Value).Hash()
return ast.StringTerm(v.(ast.Value).String()).Hash()
}

func valueEq(a, b util.T) bool {
av := a.(ast.Value)
bv := b.(ast.Value)
return av.Compare(bv) == 0
return av.String() == bv.String()
}

func (cache *httpSendCache) get(k ast.Value) *httpSendCacheEntry {
Expand Down Expand Up @@ -1419,20 +1425,21 @@ type httpRequestExecutor interface {

// newHTTPRequestExecutor returns a new HTTP request executor that wraps either an inter-query or
// intra-query cache implementation
func newHTTPRequestExecutor(bctx BuiltinContext, key ast.Object) (httpRequestExecutor, error) {
useInterQueryCache, forceCacheParams, err := useInterQueryCache(key)
func newHTTPRequestExecutor(bctx BuiltinContext, req ast.Object, key ast.Object) (httpRequestExecutor, error) {
useInterQueryCache, forceCacheParams, err := useInterQueryCache(req)
if err != nil {
return nil, handleHTTPSendErr(bctx, err)
}

if useInterQueryCache && bctx.InterQueryBuiltinCache != nil {
return newInterQueryCache(bctx, key, forceCacheParams)
return newInterQueryCache(bctx, req, key, forceCacheParams)
}
return newIntraQueryCache(bctx, key)
return newIntraQueryCache(bctx, req, key)
}

type interQueryCache struct {
bctx BuiltinContext
req ast.Object
key ast.Object
httpReq *http.Request
httpClient *http.Client
Expand All @@ -1441,8 +1448,8 @@ type interQueryCache struct {
forceCacheParams *forceCacheParams
}

func newInterQueryCache(bctx BuiltinContext, key ast.Object, forceCacheParams *forceCacheParams) (*interQueryCache, error) {
return &interQueryCache{bctx: bctx, key: key, forceCacheParams: forceCacheParams}, nil
func newInterQueryCache(bctx BuiltinContext, req ast.Object, key ast.Object, forceCacheParams *forceCacheParams) (*interQueryCache, error) {
return &interQueryCache{bctx: bctx, req: req, key: key, forceCacheParams: forceCacheParams}, nil
}

// CheckCache checks the cache for the value of the key set on this object
Expand Down Expand Up @@ -1501,21 +1508,22 @@ func (c *interQueryCache) InsertErrorIntoCache(err error) {
// ExecuteHTTPRequest executes a HTTP request
func (c *interQueryCache) ExecuteHTTPRequest() (*http.Response, error) {
var err error
c.httpReq, c.httpClient, err = createHTTPRequest(c.bctx, c.key)
c.httpReq, c.httpClient, err = createHTTPRequest(c.bctx, c.req)
if err != nil {
return nil, handleHTTPSendErr(c.bctx, err)
}

return executeHTTPRequest(c.httpReq, c.httpClient, c.key)
return executeHTTPRequest(c.httpReq, c.httpClient, c.req)
}

type intraQueryCache struct {
bctx BuiltinContext
req ast.Object
key ast.Object
}

func newIntraQueryCache(bctx BuiltinContext, key ast.Object) (*intraQueryCache, error) {
return &intraQueryCache{bctx: bctx, key: key}, nil
func newIntraQueryCache(bctx BuiltinContext, req ast.Object, key ast.Object) (*intraQueryCache, error) {
return &intraQueryCache{bctx: bctx, req: req, key: key}, nil
}

// CheckCache checks the cache for the value of the key set on this object
Expand Down Expand Up @@ -1552,11 +1560,11 @@ func (c *intraQueryCache) InsertErrorIntoCache(err error) {

// ExecuteHTTPRequest executes a HTTP request
func (c *intraQueryCache) ExecuteHTTPRequest() (*http.Response, error) {
httpReq, httpClient, err := createHTTPRequest(c.bctx, c.key)
httpReq, httpClient, err := createHTTPRequest(c.bctx, c.req)
if err != nil {
return nil, handleHTTPSendErr(c.bctx, err)
}
return executeHTTPRequest(httpReq, httpClient, c.key)
return executeHTTPRequest(httpReq, httpClient, c.req)
}

func useInterQueryCache(req ast.Object) (bool, *forceCacheParams, error) {
Expand Down
24 changes: 22 additions & 2 deletions topdown/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1046,7 +1046,27 @@ func TestHTTPSendCaching(t *testing.T) {
note: "http.send GET different cache_ignored_headers but still cached (force_cache enabled)",
ruleTemplate: `p = x {
r1 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h1": "v1", "h2": "v2"}, "force_cache": true, "force_cache_duration_seconds": 300, "cache_ignored_headers": ["h2"]})
r2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h1": "v1", "h2": "v2"}, "force_cache": true, "force_cache_duration_seconds": 300, "cache_ignored_headers": ["h2", "h3"]}) # cached
r2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h1": "v1", "h2": "v2", "h3": "v3"}, "force_cache": true, "force_cache_duration_seconds": 300, "cache_ignored_headers": ["h2", "h3"]}) # cached
x = r1.body
}`,
response: `{"x": 1}`,
expectedReqCount: 1,
},
{
note: "http.send GET different cache_ignored_headers (one of them is nil) but still cached (force_cache enabled)",
ruleTemplate: `p = x {
r1 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h1": "v1"}, "force_cache": true, "force_cache_duration_seconds": 300})
r2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h1": "v1", "h2": "v2"}, "force_cache": true, "force_cache_duration_seconds": 300, "cache_ignored_headers": ["h2"]}) # cached
x = r1.body
}`,
response: `{"x": 1}`,
expectedReqCount: 1,
},
{
note: "http.send GET different cache_ignored_headers (one of them is empty) but still cached (force_cache enabled)",
ruleTemplate: `p = x {
r1 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h1": "v1"}, "force_cache": true, "force_cache_duration_seconds": 300, "cache_ignored_headers": []})
r2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h1": "v1", "h2": "v2"}, "force_cache": true, "force_cache_duration_seconds": 300, "cache_ignored_headers": ["h2"]}) # cached
x = r1.body
}`,
response: `{"x": 1}`,
Expand Down Expand Up @@ -2197,7 +2217,7 @@ func TestInterQueryCheckCacheError(t *testing.T) {
input := ast.MustParseTerm(`{"force_cache": true}`)
inputObj := input.Value.(ast.Object)

_, err := newHTTPRequestExecutor(BuiltinContext{Context: context.Background()}, inputObj)
_, err := newHTTPRequestExecutor(BuiltinContext{Context: context.Background()}, inputObj, inputObj)
if err == nil {
t.Fatal("expected error but got nil")
}
Expand Down

0 comments on commit 85d4af0

Please sign in to comment.