Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
/*
* Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package com.amplifyframework.auth.cognito

import android.content.Context
import android.util.Log
import androidx.test.core.app.ApplicationProvider
import androidx.test.ext.junit.runners.AndroidJUnit4
import com.amplifyframework.auth.cognito.options.AWSCognitoAuthSignInOptions
import com.amplifyframework.auth.cognito.options.AuthFlowType
import com.amplifyframework.auth.cognito.testutils.Credentials
import com.amplifyframework.auth.options.AuthSignInOptions
import com.amplifyframework.core.Amplify
import com.amplifyframework.core.InitializationStatus
import com.amplifyframework.hub.HubChannel
import com.amplifyframework.testutils.DeviceFarmTestBase
import com.amplifyframework.testutils.assertAwait
import com.amplifyframework.testutils.sync.SynchronousAuth
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit
import org.junit.Assert.assertEquals
import org.junit.Before
import org.junit.BeforeClass
import org.junit.Test
import org.junit.runner.RunWith

/**
* Integration tests for device key persistence across different auth flows.
*
* Validates that once a device is remembered, the same device key is reused
* across sign-out/sign-in cycles regardless of which auth flow is used
* (USER_PASSWORD_AUTH, USER_SRP_AUTH). The device count should always remain 1.
*
* Prerequisites:
* - Cognito User Pool with Device Tracking set to "Always Remember"
* - USER_PASSWORD_AUTH and USER_SRP_AUTH both enabled on the app client
* - Valid test credentials in credentials.json
*/
@RunWith(AndroidJUnit4::class)
class DeviceKeyPersistenceInstrumentationTest : DeviceFarmTestBase() {

companion object {
val auth = AWSCognitoAuthPlugin()
val syncAuth = SynchronousAuth.delegatingTo(auth)

@BeforeClass
@JvmStatic
fun setUp() {
try {
Amplify.addPlugin(auth)
Amplify.configure(ApplicationProvider.getApplicationContext())
val latch = CountDownLatch(1)
Amplify.Hub.subscribe(HubChannel.AUTH) { event ->
when (event.name) {
InitializationStatus.SUCCEEDED.toString(),
InitializationStatus.FAILED.toString() ->
latch.countDown()
}
}
latch.assertAwait(20, TimeUnit.SECONDS)
} catch (ex: Exception) {
Log.i("DeviceKeyPersistenceTest", "Error initializing", ex)
}
}
}

@Before
fun setup() {
signOut()
}

/**
* Sign in with USER_PASSWORD_AUTH, remember device (1 device).
* Sign out, sign in with USER_SRP_AUTH — device count should stay 1, same device key.
* Sign out, sign in with USER_PASSWORD_AUTH again — still 1, same key.
* Sign out, sign in with USER_SRP_AUTH again — still 1, same key.
*/
@Test
Comment on lines +85 to +91
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mattcreaser the test steps that are failing on main.

fun deviceKey_stays_consistent_across_alternating_auth_flows() {
val context = ApplicationProvider.getApplicationContext<Context>()
val (username, password) = Credentials.load(context)

// Step 1: Sign in with USER_PASSWORD_AUTH and remember device
signIn(username, password, AuthFlowType.USER_PASSWORD_AUTH)
syncAuth.rememberDevice()

val initialDevices = syncAuth.fetchDevices()
assertEquals("Should have exactly 1 device after initial sign-in", 1, initialDevices.size)
val originalDeviceId = initialDevices[0].id

// Step 2: Sign out, sign in with USER_SRP_AUTH — same device
signOut()
signIn(username, password, AuthFlowType.USER_SRP_AUTH)

var devices = syncAuth.fetchDevices()
assertEquals("Should still have 1 device after SRP sign-in", 1, devices.size)
assertEquals("Device ID should match after SRP sign-in", originalDeviceId, devices[0].id)

// Step 3: Sign out, sign in with USER_PASSWORD_AUTH — same device
signOut()
signIn(username, password, AuthFlowType.USER_PASSWORD_AUTH)

devices = syncAuth.fetchDevices()
assertEquals("Should still have 1 device after second PASSWORD sign-in", 1, devices.size)
assertEquals("Device ID should match after second PASSWORD sign-in", originalDeviceId, devices[0].id)

// Step 4: Sign out, sign in with USER_SRP_AUTH — same device
signOut()
signIn(username, password, AuthFlowType.USER_SRP_AUTH)

devices = syncAuth.fetchDevices()
assertEquals("Should still have 1 device after second SRP sign-in", 1, devices.size)
assertEquals("Device ID should match after second SRP sign-in", originalDeviceId, devices[0].id)

// Clean up
syncAuth.forgetDevice()
}

/**
* Same-flow baseline: sign in with USER_SRP_AUTH, remember device,
* sign out, sign in with USER_SRP_AUTH — device count stays 1.
*/
@Test
fun deviceKey_persists_across_same_flow_signIn_signOut() {
val context = ApplicationProvider.getApplicationContext<Context>()
val (username, password) = Credentials.load(context)

signIn(username, password, AuthFlowType.USER_SRP_AUTH)
syncAuth.rememberDevice()

val initialDevices = syncAuth.fetchDevices()
assertEquals("Should have exactly 1 device", 1, initialDevices.size)
val originalDeviceId = initialDevices[0].id

signOut()
signIn(username, password, AuthFlowType.USER_SRP_AUTH)

val devices = syncAuth.fetchDevices()
assertEquals("Should still have 1 device after re-sign-in", 1, devices.size)
assertEquals("Device ID should be the same", originalDeviceId, devices[0].id)

// Clean up
syncAuth.forgetDevice()
}

private fun signIn(username: String, password: String, flowType: AuthFlowType) {
val options: AuthSignInOptions = AWSCognitoAuthSignInOptions.builder()
.authFlowType(flowType)
.build()
syncAuth.signIn(username, password, options)
}

private fun signOut() {
try {
syncAuth.signOut()
} catch (e: Exception) {
// Ignore errors during sign-out
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ internal object MigrateAuthCognitoActions : MigrateAuthActions {
private const val KEY_PASSWORD = "PASSWORD"
private const val KEY_SECRET_HASH = "SECRET_HASH"
private const val KEY_USERID_FOR_SRP = "USER_ID_FOR_SRP"
private const val KEY_DEVICE_KEY = "DEVICE_KEY"
private const val KEY_ANSWER = "ANSWER"
private const val KEY_PREFERRED_CHALLENGE = "PREFERRED_CHALLENGE"

Expand All @@ -54,8 +55,9 @@ internal object MigrateAuthCognitoActions : MigrateAuthActions {
secretHash?.let { authParams[KEY_SECRET_HASH] = it }

val encodedContextData = getUserContextData(event.username)
val deviceMetadata = getDeviceMetadata(event.username)
deviceMetadata?.let { authParams[KEY_DEVICE_KEY] = it.deviceKey }
val pinpointEndpointId = getPinpointEndpointId()

if (event.respondToAuthChallenge?.session != null) {
authParams[KEY_ANSWER] = ChallengeNameType.Password.value

Expand All @@ -76,7 +78,8 @@ internal object MigrateAuthCognitoActions : MigrateAuthActions {
session = response.session,
challengeParameters = response.challengeParameters,
authenticationResult = response.authenticationResult,
signInMethod = SignInMethod.ApiBased(SignInMethod.ApiBased.AuthType.USER_AUTH)
signInMethod = SignInMethod.ApiBased(SignInMethod.ApiBased.AuthType.USER_AUTH),
inputUsername = event.username
)
} else {
if (event.authFlowType == AuthFlowType.USER_AUTH) {
Expand All @@ -103,7 +106,8 @@ internal object MigrateAuthCognitoActions : MigrateAuthActions {
session = response.session,
challengeParameters = response.challengeParameters,
authenticationResult = response.authenticationResult,
signInMethod = signInMethod
signInMethod = signInMethod,
inputUsername = event.username
)
}
} catch (e: Exception) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ internal object SRPCognitoActions : SRPActions {
challengeParams = response.challengeParameters,
session = response.session,
updatedDeviceMetadata = updatedDeviceMetadata,
metadata = event.metadata
metadata = event.metadata,
inputUsername = event.username
)
} else {
if (event.authFlowType == AuthFlowType.USER_AUTH) {
Expand All @@ -124,7 +125,8 @@ internal object SRPCognitoActions : SRPActions {
challengeParams = response.challengeParameters,
session = response.session,
updatedDeviceMetadata = updatedDeviceMetadata,
metadata = event.metadata
metadata = event.metadata,
inputUsername = event.username
)
}
} catch (e: Exception) {
Expand All @@ -143,7 +145,8 @@ internal object SRPCognitoActions : SRPActions {
challengeParams: Map<String, String>?,
session: String?,
updatedDeviceMetadata: DeviceMetadata.Metadata?,
metadata: Map<String, String>
metadata: Map<String, String>,
inputUsername: String? = null
): SRPEvent = when (challengeNameType) {
ChallengeNameType.PasswordVerifier -> {
challengeParams?.let { params ->
Expand All @@ -155,7 +158,8 @@ internal object SRPCognitoActions : SRPActions {
SRPEvent.EventType.RespondPasswordVerifier(
updatedChallengeParams,
metadata,
session
session,
inputUsername = inputUsername
)
)
} ?: throw ServiceException(
Expand Down Expand Up @@ -209,7 +213,8 @@ internal object SRPCognitoActions : SRPActions {
SRPEvent.EventType.RespondPasswordVerifier(
challengeParams,
event.metadata,
initiateAuthResponse.session
initiateAuthResponse.session,
inputUsername = event.username
)
)
} ?: throw ServiceException(
Expand All @@ -235,7 +240,8 @@ internal object SRPCognitoActions : SRPActions {
challengeParameters: Map<String, String>,
metadata: Map<String, String>,
session: String?,
signInMethod: SignInMethod
signInMethod: SignInMethod,
inputUsername: String?
) = Action<AuthEnvironment>("VerifyPasswordSRP") { id, dispatcher ->
logger.verbose("$id Starting execution")
val evt = try {
Expand Down Expand Up @@ -282,7 +288,8 @@ internal object SRPCognitoActions : SRPActions {
session = response.session,
challengeParameters = response.challengeParameters,
authenticationResult = response.authenticationResult,
signInMethod = signInMethod
signInMethod = signInMethod,
inputUsername = inputUsername
)
} else {
throw ServiceException(
Expand All @@ -306,7 +313,8 @@ internal object SRPCognitoActions : SRPActions {
challengeParams,
metadata,
session,
signInMethod
signInMethod,
inputUsername = inputUsername
)
)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ internal object SignInChallengeCognitoActions : SignInChallengeActions {
session = challenge.session,
challengeParameters = mapOf("MFAS_CAN_SETUP" to answer),
authenticationResult = null,
signInMethod = signInMethod
signInMethod = signInMethod,
inputUsername = challenge.inputUsername
)
logger.verbose("$id Sending event ${event.type}")
dispatcher.send(event)
Expand Down Expand Up @@ -111,7 +112,8 @@ internal object SignInChallengeCognitoActions : SignInChallengeActions {
session = response.session,
challengeParameters = response.challengeParameters,
authenticationResult = response.authenticationResult,
signInMethod = signInMethod
signInMethod = signInMethod,
inputUsername = challenge.inputUsername
)
} ?: CustomSignInEvent(
CustomSignInEvent.EventType.ThrowAuthError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,9 @@ internal object SignInCognitoActions : SignInActions {
) ?: throw ServiceException("Sign in failed", AmplifyException.TODO_RECOVERY_SUGGESTION)

val updatedDeviceMetadata = deviceMetadata.copy(deviceSecret = deviceVerifierMap["secret"])
val deviceMetadataUsername = event.signedInData.inputUsername ?: event.signedInData.username
credentialStoreClient.storeCredentials(
CredentialType.Device(event.signedInData.username),
CredentialType.Device(deviceMetadataUsername),
AmplifyCredential.DeviceData(updatedDeviceMetadata)
)

Expand Down Expand Up @@ -230,7 +231,8 @@ internal object SignInCognitoActions : SignInActions {
challengeParameters = response.challengeParameters,
authenticationResult = response.authenticationResult,
availableChallenges = response.availableChallenges?.map { it.value },
signInMethod = SignInMethod.ApiBased(SignInMethod.ApiBased.AuthType.USER_AUTH)
signInMethod = SignInMethod.ApiBased(SignInMethod.ApiBased.AuthType.USER_AUTH),
inputUsername = event.signInData.username
)
} else {
throw ServiceException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ internal object SignInCustomCognitoActions : CustomSignInActions {
session = initiateAuthResponse.session,
challengeParameters = initiateAuthResponse.challengeParameters,
authenticationResult = initiateAuthResponse.authenticationResult,
signInMethod = SignInMethod.ApiBased(SignInMethod.ApiBased.AuthType.CUSTOM_AUTH)
signInMethod = SignInMethod.ApiBased(SignInMethod.ApiBased.AuthType.CUSTOM_AUTH),
inputUsername = event.username
)
} else {
throw ServiceException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ internal object UserAuthSignInCognitoActions : UserAuthSignInActions {
availableChallenges = listOfChallenges,
authenticationResult = initiateAuthResponse.authenticationResult,
callingActivity = event.callingActivity,
signInMethod = SignInMethod.ApiBased(SignInMethod.ApiBased.AuthType.USER_AUTH)
signInMethod = SignInMethod.ApiBased(SignInMethod.ApiBased.AuthType.USER_AUTH),
inputUsername = event.username
)
} else if (isSupportedChallenge(initiateAuthResponse?.challengeName) &&
initiateAuthResponse?.challengeParameters != null &&
Expand All @@ -98,7 +99,8 @@ internal object UserAuthSignInCognitoActions : UserAuthSignInActions {
challengeParameters = initiateAuthResponse.challengeParameters,
authenticationResult = initiateAuthResponse.authenticationResult,
callingActivity = event.callingActivity,
signInMethod = SignInMethod.ApiBased(SignInMethod.ApiBased.AuthType.USER_AUTH)
signInMethod = SignInMethod.ApiBased(SignInMethod.ApiBased.AuthType.USER_AUTH),
inputUsername = event.username
)
} else {
throw ServiceException("Sign in failed", AmplifyException.TODO_RECOVERY_SUGGESTION)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ internal object WebAuthnSignInCognitoActions : WebAuthnSignInActions {
challengeParameters = response.challengeParameters,
authenticationResult = response.authenticationResult,
callingActivity = signInContext.callingActivity,
signInMethod = SignInMethod.ApiBased(SignInMethod.ApiBased.AuthType.USER_AUTH)
signInMethod = SignInMethod.ApiBased(SignInMethod.ApiBased.AuthType.USER_AUTH),
inputUsername = signInContext.username
)
}

Expand Down Expand Up @@ -104,7 +105,8 @@ internal object WebAuthnSignInCognitoActions : WebAuthnSignInActions {
challengeParameters = response.challengeParameters,
authenticationResult = response.authenticationResult,
callingActivity = signInContext.callingActivity,
signInMethod = SignInMethod.ApiBased(SignInMethod.ApiBased.AuthType.USER_AUTH)
signInMethod = SignInMethod.ApiBased(SignInMethod.ApiBased.AuthType.USER_AUTH),
inputUsername = signInContext.username
)
}

Expand Down
Loading
Loading