Skip to content

Commit

Permalink
allow additional login params to be supplied
Browse files Browse the repository at this point in the history
  • Loading branch information
dovholuknf committed Oct 4, 2024
1 parent d09dc6a commit f4becf1
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 92 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ build
zssh*json
.run
ziti-edge-tunnel*
my.env
*.env

18 changes: 10 additions & 8 deletions zsshlib/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ type SshFlags struct {
}

type OIDCFlags struct {
Mode bool
Issuer string
ClientID string
ClientSecret string
CallbackPort string
AsAscii bool
OIDCOnly bool
ControllerUrl string
Mode bool
Issuer string
ClientID string
ClientSecret string
CallbackPort string
AsAscii bool
OIDCOnly bool
ControllerUrl string
AdditionalLoginParams []string
}

type ScpFlags struct {
Expand Down Expand Up @@ -93,6 +94,7 @@ func (f *SshFlags) OIDCFlags(cmd *cobra.Command) {
cmd.Flags().BoolVarP(&f.OIDC.Mode, "oidc", "o", false, fmt.Sprintf("toggle OIDC mode. default: %t", defaults.OIDC.Enabled))
cmd.Flags().BoolVar(&f.OIDC.OIDCOnly, "oidcOnly", false, "toggle OIDC only mode. default: false")
cmd.Flags().StringVar(&f.OIDC.ControllerUrl, "controllerUrl", "", "the url of the controller to use. only used with --oidcOnly")
cmd.Flags().StringArrayVarP(&f.OIDC.AdditionalLoginParams, "additionalLoginParams", "l", []string{}, "Additional parameters to specify to the login. Can specify multiple times. Must be in the format of param=value")
}

func (f *SshFlags) AddCommonFlags(cmd *cobra.Command) {
Expand Down
121 changes: 117 additions & 4 deletions zsshlib/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,18 @@ package zsshlib

import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"time"

"github.com/google/uuid"
"github.com/sirupsen/logrus"
"github.com/zitadel/oidc/v2/pkg/client/rp"
"github.com/zitadel/oidc/v2/pkg/client/rp/cli"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
"golang.org/x/oauth2"
)

Expand All @@ -16,10 +25,11 @@ func OIDCFlow(initialContext context.Context, flags *SshFlags) (string, error) {
ClientSecret: flags.OIDC.ClientSecret,
RedirectURL: fmt.Sprintf("http://localhost:%v%v", flags.OIDC.CallbackPort, callbackPath),
},
CallbackPath: callbackPath,
CallbackPort: flags.OIDC.CallbackPort,
Issuer: flags.OIDC.Issuer,
Logf: log.Debugf,
CallbackPath: callbackPath,
CallbackPort: flags.OIDC.CallbackPort,
Issuer: flags.OIDC.Issuer,
Logf: log.Debugf,
AdditionalLoginParams: flags.OIDC.AdditionalLoginParams,
}
waitFor := 30 * time.Second
ctx, cancel := context.WithTimeout(initialContext, waitFor)
Expand All @@ -36,3 +46,106 @@ func OIDCFlow(initialContext context.Context, flags *SshFlags) (string, error) {

return token, nil
}

func zsshCodeFlow[C oidc.IDClaims](ctx context.Context, relyingParty rp.RelyingParty, config *OIDCConfig) *oidc.Tokens[C] {
codeflowCtx, codeflowCancel := context.WithCancel(ctx)
defer codeflowCancel()

tokenChan := make(chan *oidc.Tokens[C], 1)

callback := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp rp.RelyingParty) {
tokenChan <- tokens
msg := "<p><strong>Success!</strong></p>"
msg = msg + "<p>You are authenticated and can now return to the CLI.</p>"
w.Write([]byte(msg))
}

authHandlerWithQueryState := func(party rp.RelyingParty) http.HandlerFunc {
var urlParamOpts rp.URLParamOpt
for _, v := range config.AdditionalLoginParams {
parts := strings.Split(v, "=")
urlParamOpts = rp.WithURLParam(parts[0], parts[1])
}
return func(w http.ResponseWriter, r *http.Request) {
rp.AuthURLHandler(func() string {
return uuid.New().String()
}, party, urlParamOpts /*rp.WithURLParam("audience", "openziti2")*/)(w, r)
}
}

http.Handle("/login", authHandlerWithQueryState(relyingParty))
http.Handle(config.CallbackPath, rp.CodeExchangeHandler(callback, relyingParty))

httphelper.StartServer(codeflowCtx, ":"+config.CallbackPort)

cli.OpenBrowser("http://localhost:" + config.CallbackPort + "/login")

return <-tokenChan
}

// OIDCConfig represents a config for the OIDC auth flow.
type OIDCConfig struct {
// CallbackPath is the path of the callback handler.
CallbackPath string

// CallbackPort is the port of the callback handler.
CallbackPort string

// Issuer is the URL of the OpenID Connect provider.
Issuer string

// HashKey is used to authenticate values using HMAC.
HashKey []byte

// BlockKey is used to encrypt values using AES.
BlockKey []byte

// IDToken is the ID token returned by the OIDC provider.
IDToken string

// Logger function for debug.
Logf func(format string, args ...interface{})

// Additional params to add to the login request
AdditionalLoginParams []string

oauth2.Config
}

// GetToken starts a local HTTP server, opens the web browser to initiate the OIDC Discovery and
// Token Exchange flow, blocks until the user completes authentication and is redirected back, and returns
// the OIDC tokens.
func GetToken(ctx context.Context, config *OIDCConfig) (string, error) {
if err := config.validateAndSetDefaults(); err != nil {
return "", fmt.Errorf("invalid config: %w", err)
}

cookieHandler := httphelper.NewCookieHandler(config.HashKey, config.BlockKey, httphelper.WithUnsecure())

options := []rp.Option{
rp.WithCookieHandler(cookieHandler),
rp.WithVerifierOpts(rp.WithIssuedAtOffset(5 * time.Second)),
}
if config.ClientSecret == "" {
options = append(options, rp.WithPKCE(cookieHandler))
}

relyingParty, err := rp.NewRelyingPartyOIDC(config.Issuer, config.ClientID, config.ClientSecret, config.RedirectURL, config.Scopes, options...)
if err != nil {
logrus.Fatalf("error creating relyingParty %s", err.Error())
}

resultChan := make(chan *oidc.Tokens[*oidc.IDTokenClaims])

go func() {
tokens := zsshCodeFlow[*oidc.IDTokenClaims](ctx, relyingParty, config)
resultChan <- tokens
}()

select {
case tokens := <-resultChan:
return tokens.AccessToken, nil
case <-ctx.Done():
return "", errors.New("timeout: OIDC authentication took too long")
}
}
81 changes: 2 additions & 79 deletions zsshlib/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,8 @@ package zsshlib

import (
"bufio"
"context"
"encoding/base64"
"fmt"
"github.com/google/uuid"
"github.com/gorilla/securecookie"
"github.com/zitadel/oidc/v2/pkg/client/rp/cli"
"github.com/zitadel/oidc/v2/pkg/oidc"
"golang.org/x/crypto/ssh/knownhosts"
"io"
"net"
"os"
Expand All @@ -36,15 +30,13 @@ import (
"sync"
"time"

"github.com/gorilla/securecookie"
"github.com/openziti/sdk-golang/ziti"
"github.com/pkg/errors"
"github.com/pkg/sftp"
"github.com/zitadel/oidc/v2/pkg/client/rp"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"golang.org/x/oauth2"

"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
"golang.org/x/crypto/ssh/terminal"
)

Expand Down Expand Up @@ -140,75 +132,6 @@ func Dial(config *ssh.ClientConfig, conn net.Conn) (*ssh.Client, error) {
return ssh.NewClient(c, chans, reqs), nil
}

// OIDCConfig represents a config for the OIDC auth flow.
type OIDCConfig struct {
// CallbackPath is the path of the callback handler.
CallbackPath string

// CallbackPort is the port of the callback handler.
CallbackPort string

// Issuer is the URL of the OpenID Connect provider.
Issuer string

// HashKey is used to authenticate values using HMAC.
HashKey []byte

// BlockKey is used to encrypt values using AES.
BlockKey []byte

// IDToken is the ID token returned by the OIDC provider.
IDToken string

// Logger function for debug.
Logf func(format string, args ...interface{})

oauth2.Config
}

// GetToken starts a local HTTP server, opens the web browser to initiate the OIDC Discovery and
// Token Exchange flow, blocks until the user completes authentication and is redirected back, and returns
// the OIDC tokens.
func GetToken(ctx context.Context, config *OIDCConfig) (string, error) {
if err := config.validateAndSetDefaults(); err != nil {
return "", fmt.Errorf("invalid config: %w", err)
}

cookieHandler := httphelper.NewCookieHandler(config.HashKey, config.BlockKey, httphelper.WithUnsecure())

options := []rp.Option{
rp.WithCookieHandler(cookieHandler),
rp.WithVerifierOpts(rp.WithIssuedAtOffset(5 * time.Second)),
}
if config.ClientSecret == "" {
options = append(options, rp.WithPKCE(cookieHandler))
}

relyingParty, err := rp.NewRelyingPartyOIDC(config.Issuer, config.ClientID, config.ClientSecret, config.RedirectURL, config.Scopes, options...)
if err != nil {
logrus.Fatalf("error creating relyingParty %s", err.Error())
}

//ctx := context.Background()
state := func() string {
return uuid.New().String()
}

resultChan := make(chan *oidc.Tokens[*oidc.IDTokenClaims])

go func() {
tokens := cli.CodeFlow[*oidc.IDTokenClaims](ctx, relyingParty, config.CallbackPath, config.CallbackPort, state)
resultChan <- tokens
}()

select {
case tokens := <-resultChan:
return tokens.AccessToken, nil
case <-ctx.Done():
return "", errors.New("Timeout: OIDC authentication took too long")
}
}

// validateAndSetDefaults validates the config and sets default values.
func (c *OIDCConfig) validateAndSetDefaults() error {
if c.ClientID == "" {
Expand Down

0 comments on commit f4becf1

Please sign in to comment.