Skip to content

Commit

Permalink
fix: include client_assertion in token hook payload
Browse files Browse the repository at this point in the history
  • Loading branch information
phooijenga committed Feb 20, 2025
1 parent 5d2ca41 commit c1d0eb1
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 1 deletion.
187 changes: 187 additions & 0 deletions oauth2/oauth2_jwt_bearer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -561,3 +561,190 @@ func TestJWTBearer(t *testing.T) {
t.Run("strategy=jwt", run("jwt"))
})
}

func TestJWTClientAssertion(t *testing.T) {
ctx := context.Background()

reg := testhelpers.NewMockedRegistry(t, &contextx.Default{})
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque")
_, admin := testhelpers.NewOAuth2Server(ctx, t, reg)

set, kid := uuid.NewString(), uuid.NewString()
keys, err := jwk.GenerateJWK(ctx, jose.RS256, kid, "sig")
require.NoError(t, err)
signer := jwk.NewDefaultJWTSigner(reg.Config(), reg, set)
signer.GetPrivateKey = func(ctx context.Context) (interface{}, error) {
return keys.Keys[0], nil
}

client := &hc.Client{
GrantTypes: []string{"client_credentials"},
Scope: "offline_access",
TokenEndpointAuthMethod: "private_key_jwt",
JSONWebKeys: &x.JoseJSONWebKeySet{
JSONWebKeySet: &jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{keys.Keys[0].Public()},
},
},
}
require.NoError(t, reg.ClientManager().CreateClient(ctx, client))

var newConf = func(client *hc.Client) *clientcredentials.Config {
return &clientcredentials.Config{
AuthStyle: goauth2.AuthStyleInParams,
TokenURL: reg.Config().OAuth2TokenURL(ctx).String(),
Scopes: strings.Split(client.Scope, " "),
EndpointParams: url.Values{
"client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
},
}
}
var getToken = func(t *testing.T, conf *clientcredentials.Config) (*goauth2.Token, error) {
return conf.Token(context.Background())
}

var inspectToken = func(t *testing.T, token *goauth2.Token, cl *hc.Client, strategy string, checkExtraClaims bool) {
introspection := testhelpers.IntrospectToken(t, &goauth2.Config{ClientID: cl.GetID(), ClientSecret: cl.Secret}, token.AccessToken, admin)

check := func(res gjson.Result) {
assert.EqualValues(t, cl.GetID(), res.Get("client_id").String(), "%s", res.Raw)
assert.EqualValues(t, cl.GetID(), res.Get("sub").String(), "%s", res.Raw)
assert.EqualValues(t, reg.Config().IssuerURL(ctx).String(), res.Get("iss").String(), "%s", res.Raw)

assert.EqualValues(t, res.Get("nbf").Int(), res.Get("iat").Int(), "%s", res.Raw)
assert.True(t, res.Get("exp").Int() >= res.Get("iat").Int()+int64(reg.Config().GetAccessTokenLifespan(ctx).Seconds()), "%s", res.Raw)

if checkExtraClaims {
require.True(t, res.Get("ext.hooked").Bool())
}
}

check(introspection)
assert.True(t, introspection.Get("active").Bool())
assert.EqualValues(t, "access_token", introspection.Get("token_use").String())
assert.EqualValues(t, "Bearer", introspection.Get("token_type").String())
assert.EqualValues(t, "offline_access", introspection.Get("scope").String(), "%s", introspection.Raw)

if strategy != "jwt" {
return
}

body, err := x.DecodeSegment(strings.Split(token.AccessToken, ".")[1])
require.NoError(t, err)
jwtClaims := gjson.ParseBytes(body)
assert.NotEmpty(t, jwtClaims.Get("jti").String())
assert.NotEmpty(t, jwtClaims.Get("iss").String())
assert.NotEmpty(t, jwtClaims.Get("client_id").String())
assert.EqualValues(t, "offline_access", introspection.Get("scope").String(), "%s", introspection.Raw)

header, err := x.DecodeSegment(strings.Split(token.AccessToken, ".")[0])
require.NoError(t, err)
jwtHeader := gjson.ParseBytes(header)
assert.NotEmpty(t, jwtHeader.Get("kid").String())
assert.EqualValues(t, "offline_access", introspection.Get("scope").String(), "%s", introspection.Raw)

check(jwtClaims)
}

var generateAssertion = func() (string, error) {
token, _, err := signer.Generate(ctx, jwt.MapClaims{
"jti": uuid.NewString(),
"iss": client.GetID(),
"sub": client.GetID(),
"aud": reg.Config().OAuth2TokenURL(ctx).String(),
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Add(-time.Minute).Unix(),
}, &jwt.Headers{Extra: map[string]interface{}{"kid": kid}})
return token, err
}

t.Run("case=unable to exchange invalid jwt", func(t *testing.T) {
conf := newConf(client)
conf.EndpointParams.Set("client_assertion", "not-a-jwt")
_, err := getToken(t, conf)
require.Error(t, err)
assert.Contains(t, err.Error(), "Unable to verify the integrity of the 'client_assertion' value.")
})

t.Run("case=should exchange for an access token", func(t *testing.T) {
run := func(strategy string) func(t *testing.T) {
return func(t *testing.T) {
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)

token, err := generateAssertion()
require.NoError(t, err)

conf := newConf(client)
conf.EndpointParams.Set("client_assertion", token)

result, err := getToken(t, conf)
require.NoError(t, err)

inspectToken(t, result, client, strategy, false)
}
}

t.Run("strategy=opaque", run("opaque"))
t.Run("strategy=jwt", run("jwt"))
})

t.Run("should call token hook if configured", func(t *testing.T) {
run := func(strategy string) func(t *testing.T) {
return func(t *testing.T) {
token, err := generateAssertion()
require.NoError(t, err)

hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8")

expectedGrantedScopes := []string{client.Scope}
expectedPayload := map[string][]string{
"grant_type": {"client_credentials"},
"client_assertion": {token},
"client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
"scope": {"offline_access"},
}

var hookReq hydraoauth2.TokenHookRequest
require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq))
require.NotEmpty(t, hookReq.Session)
require.Equal(t, hookReq.Session.Extra, map[string]interface{}{})
require.NotEmpty(t, hookReq.Request)
require.ElementsMatch(t, hookReq.Request.GrantedScopes, expectedGrantedScopes)
require.Equal(t, expectedPayload, hookReq.Request.Payload)

claims := map[string]interface{}{
"hooked": true,
}

hookResp := hydraoauth2.TokenHookResponse{
Session: flow.AcceptOAuth2ConsentRequestSession{
AccessToken: claims,
IDToken: claims,
},
}

w.WriteHeader(http.StatusOK)
require.NoError(t, json.NewEncoder(w).Encode(&hookResp))
}))
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

conf := newConf(client)
conf.EndpointParams.Set("client_assertion", token)

result, err := getToken(t, conf)
require.NoError(t, err)

inspectToken(t, result, client, strategy, true)
}
}

t.Run("strategy=opaque", run("opaque"))
t.Run("strategy=jwt", run("jwt"))
})
}
2 changes: 1 addition & 1 deletion oauth2/token_hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func TokenHook(reg interface {
GrantedScopes: requester.GetGrantedScopes(),
GrantedAudience: requester.GetGrantedAudience(),
GrantTypes: requester.GetGrantTypes(),
Payload: requester.Sanitize([]string{"assertion"}).GetRequestForm(),
Payload: requester.Sanitize([]string{"assertion", "client_assertion_type", "client_assertion"}).GetRequestForm(),
}

reqBody := TokenHookRequest{
Expand Down

0 comments on commit c1d0eb1

Please sign in to comment.