Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions ssh/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,11 @@ userAuthLoop:
candidate.user = s.user
candidate.pubKeyData = pubKeyData
candidate.perms, candidate.result = authConfig.PublicKeyCallback(s, pubKey)
var pse *PartialSuccessError
_, isPartialSuccessError := candidate.result.(*PartialSuccessError)
if !isPartialSuccessError {
isPartialSuccessError = errors.As(candidate.result, &pse)
}
if isPartialSuccessError && config.VerifiedPublicKeyCallback != nil {
return nil, errors.New("ssh: invalid library usage: PublicKeyCallback must not return partial success when VerifiedPublicKeyCallback is defined")
}
Expand All @@ -804,8 +808,12 @@ userAuthLoop:
if len(payload) > 0 {
return nil, parseError(msgUserAuthRequest)
}
_, isPartialSuccessError := candidate.result.(*PartialSuccessError)
if candidate.result == nil || isPartialSuccessError {
var pse2 *PartialSuccessError
_, isPartialSuccessError2 := candidate.result.(*PartialSuccessError)
if !isPartialSuccessError2 {
isPartialSuccessError2 = errors.As(candidate.result, &pse2)
}
if candidate.result == nil || isPartialSuccessError2 {
okMsg := userAuthPubKeyOkMsg{
Algo: algo,
PubKey: pubKeyData,
Expand Down Expand Up @@ -946,7 +954,8 @@ userAuthLoop:

var failureMsg userAuthFailureMsg

if partialSuccess, ok := authErr.(*PartialSuccessError); ok {
var partialSuccess *PartialSuccessError
if ok := errors.As(authErr, &partialSuccess); ok {
// Permissions are not preserved between authentication steps. To
// avoid confusion about the final state of the connection, we
// disallow returning non-nil Permissions combined with
Expand Down
65 changes: 65 additions & 0 deletions ssh/server_multi_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,3 +410,68 @@ func TestDynamicAuthCallbacks(t *testing.T) {
t.Fatal("server not returned partial success")
}
}

// TestPartialSuccessErrorWrappedInBannerError verifies that a PartialSuccessError
// wrapped inside a BannerError is correctly detected via errors.As, rather than
// silently treated as an authentication failure. Prior to the fix, the direct
// type assertion authErr.(*PartialSuccessError) would fail when authErr is a
// *BannerError, causing the partial-success state to be lost and authFailures
// to be incremented instead.
func TestPartialSuccessErrorWrappedInBannerError(t *testing.T) {
username := "testuser"
errPwdAuthFailed := errors.New("password auth failed")

// The PasswordCallback returns a BannerError wrapping a PartialSuccessError.
// This is valid API usage: the callback wants to both send a banner message
// AND signal partial success requiring a second factor.
serverConfig := &ServerConfig{
PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
if string(password) == clientPassword {
// First factor OK; wrap PartialSuccessError in a BannerError so
// a banner is also sent to the client.
return nil, &BannerError{
Err: &PartialSuccessError{
Next: ServerAuthCallbacks{
PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
if string(password) == clientPassword {
return nil, nil
}
return nil, errPwdAuthFailed
},
},
},
Message: "First factor accepted; please provide second factor.",
}
}
return nil, errPwdAuthFailed
},
}

clientConfig := &ClientConfig{
User: username,
Auth: []AuthMethod{
// Two password attempts: first triggers partial success, second completes login.
Password(clientPassword),
Password(clientPassword),
},
HostKeyCallback: InsecureIgnoreHostKey(),
BannerCallback: BannerDisplayStderr(),
}

serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig)
if err != nil {
t.Fatalf("client login error: %v (PartialSuccessError wrapped in BannerError was not detected)", err)
}

// Expected sequence:
// [0] ErrNoAuth (none method)
// [1] BannerError wrapping PartialSuccessError (first password)
// [2] nil (second password succeeds)
if len(serverAuthErrors) != 3 {
t.Fatalf("unexpected number of server auth errors: %d, errors: %+v", len(serverAuthErrors), serverAuthErrors)
}
var pse *PartialSuccessError
if !errors.As(serverAuthErrors[1], &pse) {
t.Fatalf("expected a PartialSuccessError (possibly wrapped) at index 1, got: %v", serverAuthErrors[1])
}
}