diff --git a/key-transparency/domain/src/main/kotlin/me/proton/core/keytransparency/domain/entity/VerifiedEpochData.kt b/key-transparency/domain/src/main/kotlin/me/proton/core/keytransparency/domain/entity/VerifiedEpochData.kt index b3b35015a..5b5bb0d02 100644 --- a/key-transparency/domain/src/main/kotlin/me/proton/core/keytransparency/domain/entity/VerifiedEpochData.kt +++ b/key-transparency/domain/src/main/kotlin/me/proton/core/keytransparency/domain/entity/VerifiedEpochData.kt @@ -28,7 +28,9 @@ internal data class VerifiedEpochData( @SerialName("EpochID") val epochId: EpochId, @SerialName("Revision") - val revision: Int + val revision: Int, + @SerialName("SKLCreationTime") + val sklCreationTime: Long ) { companion object { fun fromJson(json: String): VerifiedEpochData = json.deserialize() diff --git a/key-transparency/domain/src/main/kotlin/me/proton/core/keytransparency/domain/usecase/AuditUserAddress.kt b/key-transparency/domain/src/main/kotlin/me/proton/core/keytransparency/domain/usecase/AuditUserAddress.kt index 057db5eb2..2be181e46 100644 --- a/key-transparency/domain/src/main/kotlin/me/proton/core/keytransparency/domain/usecase/AuditUserAddress.kt +++ b/key-transparency/domain/src/main/kotlin/me/proton/core/keytransparency/domain/usecase/AuditUserAddress.kt @@ -103,7 +103,11 @@ internal class AuditUserAddress @Inject constructor( keyTransparencyCheck(revision == initialEpoch.revision) { "Revision has changed but no SKL were provided" } if (maxEpochId > initialEpoch.epochId) { // Update the verified epoch - uploadVerifiedEpoch(userId, userAddress.addressId, VerifiedEpochData(maxEpochId, initialEpoch.revision)) + uploadVerifiedEpoch( + userId, + userAddress.addressId, + VerifiedEpochData(maxEpochId, initialEpoch.revision, initialEpoch.sklCreationTime) + ) } } @@ -116,6 +120,7 @@ internal class AuditUserAddress @Inject constructor( ) { var previousVerifiedEpoch = initialEpoch var previousCertificateDate: Long? = null + var previousSKLCreationTime = initialEpoch.sklCreationTime newSKLs.forEachIndexed { index, newSKL -> val isLast = index == newSKLs.size - 1 val isFirst = index == 0 @@ -127,6 +132,12 @@ internal class AuditUserAddress @Inject constructor( val timestamp = if (newSKL.data != null) { verifySignedKeyListSignature(userAddress, newSKL) } else null + if (timestamp != null) { + keyTransparencyCheck( + timestamp >= previousSKLCreationTime + ) { "SKL Creation time must increase monotonically."} + previousSKLCreationTime = timestamp + } if (maxEpochId != null) { val (verifiedState, revision) = verifyMaxEpoch(userId, maxEpochId, userAddress, newSKL) val isRevisionConsistent = if (isFirst) { @@ -135,7 +146,11 @@ internal class AuditUserAddress @Inject constructor( revision == previousVerifiedEpoch.revision + 1 } keyTransparencyCheck(isRevisionConsistent) { "Revision chain is inconsistent" } - previousVerifiedEpoch = VerifiedEpochData(revision = revision, epochId = maxEpochId) + previousVerifiedEpoch = VerifiedEpochData( + maxEpochId, + revision, + timestamp ?: previousVerifiedEpoch.sklCreationTime + ) previousCertificateDate = (verifiedState as TimedState).notBefore } else { // the last SKL cannot be an obsolescence because the address is not disabled diff --git a/key-transparency/domain/src/main/kotlin/me/proton/core/keytransparency/domain/usecase/BootstrapInitialEpoch.kt b/key-transparency/domain/src/main/kotlin/me/proton/core/keytransparency/domain/usecase/BootstrapInitialEpoch.kt index db8c4ff2d..9d4db69ef 100644 --- a/key-transparency/domain/src/main/kotlin/me/proton/core/keytransparency/domain/usecase/BootstrapInitialEpoch.kt +++ b/key-transparency/domain/src/main/kotlin/me/proton/core/keytransparency/domain/usecase/BootstrapInitialEpoch.kt @@ -46,7 +46,10 @@ internal class BootstrapInitialEpoch @Inject constructor( keyTransparencyCheck(newSKLs.isNotEmpty()) { "Can't bootstrap, no SKL available" } val oldestSKL = newSKLs[0] if (oldestSKL.minEpochId == null) { - keyTransparencyCheck(oldestSKL.data == inputSKL.data) + keyTransparencyCheck( + oldestSKL.data == inputSKL.data && + oldestSKL.signature == inputSKL.signature + ) { "Input SKL did not equal the only SKL" } // The address is too recent to bootstrap the verified epoch keyTransparencyCheck(newSKLs.size == 1) { "New address had more SKLs than the current one" } val timestamp = verifySignedKeyListSignature(userAddress, inputSKL) @@ -70,6 +73,6 @@ internal class BootstrapInitialEpoch @Inject constructor( ) { "Oldest epoch is not in range" } } val bootstrappedRevision = keyTransparencyCheckNotNull(revision) { "Bootstrapped epoch had no revision" } - return VerifiedEpochData(minEpochId, bootstrappedRevision) + return VerifiedEpochData(minEpochId, bootstrappedRevision, 0) } } diff --git a/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/AuditUserAddressTest.kt b/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/AuditUserAddressTest.kt index 7cca49d30..8574564e8 100644 --- a/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/AuditUserAddressTest.kt +++ b/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/AuditUserAddressTest.kt @@ -267,9 +267,11 @@ class AuditUserAddressTest { every { maxEpochId } returns 101 } every { userAddress.signedKeyList } returns skl + val lastVerifiedTimestamp = 1000L val verifiedEpoch = mockk { every { revision } returns 1 every { epochId } returns 100 + every { sklCreationTime } returns lastVerifiedTimestamp } coEvery { fetchVerifiedEpoch(testUserId, userAddress) } returns verifiedEpoch coEvery { publicAddressRepository.getSKLsAfterEpoch(testUserId, 100, testEmail) } returns emptyList() @@ -283,7 +285,7 @@ class AuditUserAddressTest { val certificateTime = currentTime - Constants.KT_MAX_EPOCH_INTERVAL_SECONDS + 1000 val verifiedState = VerifiedState.Existent(certificateTime) coEvery { verifyProofInEpoch(testEmail, skl, epoch, proof) } returns verifiedState - val expectedVE = VerifiedEpochData(101, 1) + val expectedVE = VerifiedEpochData(101, 1, lastVerifiedTimestamp) coJustRun { uploadVerifiedEpoch(testUserId, testAddressId, expectedVE) } // when auditUserAddress(testUserId, userAddress).unwrap() @@ -303,9 +305,11 @@ class AuditUserAddressTest { every { maxEpochId } returns 100 } every { userAddress.signedKeyList } returns skl + val lastVerifiedTimestamp = 1000L val verifiedEpoch = mockk { every { revision } returns 1 every { epochId } returns 100 + every { sklCreationTime } returns lastVerifiedTimestamp } coEvery { fetchVerifiedEpoch(testUserId, userAddress) } returns verifiedEpoch coEvery { publicAddressRepository.getSKLsAfterEpoch(testUserId, 100, testEmail) } returns emptyList() @@ -344,9 +348,11 @@ class AuditUserAddressTest { every { maxEpochId } returns 100 } every { userAddress.signedKeyList } returns skl + val lastVerifiedTimestamp = 1000L val verifiedEpoch = mockk { every { revision } returns 1 every { epochId } returns 100 + every { sklCreationTime } returns lastVerifiedTimestamp } coEvery { fetchVerifiedEpoch(testUserId, userAddress) } returns verifiedEpoch val newSKLs = listOf( @@ -369,9 +375,11 @@ class AuditUserAddressTest { every { data } returns "data" } every { userAddress.signedKeyList } returns skl + val lastVerifiedTimestamp = 1000L val verifiedEpoch = mockk { every { revision } returns 1 every { epochId } returns 100 + every { sklCreationTime } returns lastVerifiedTimestamp } coEvery { fetchVerifiedEpoch(testUserId, userAddress) } returns verifiedEpoch val newSKL = mockk { @@ -399,9 +407,11 @@ class AuditUserAddressTest { every { data } returns "data" } every { userAddress.signedKeyList } returns skl + val lastVerifiedTimestamp = 1000L val verifiedEpoch = mockk { every { revision } returns 1 every { epochId } returns 100 + every { sklCreationTime } returns lastVerifiedTimestamp } coEvery { fetchVerifiedEpoch(testUserId, userAddress) } returns verifiedEpoch val newSKL = mockk { @@ -438,9 +448,11 @@ class AuditUserAddressTest { every { data } returns "data" } every { userAddress.signedKeyList } returns skl + val lastVerifiedTimestamp = 1000L val verifiedEpoch = mockk { every { revision } returns 1 every { epochId } returns 100 + every { sklCreationTime } returns lastVerifiedTimestamp } coEvery { fetchVerifiedEpoch(testUserId, userAddress) } returns verifiedEpoch val newSKL = mockk { @@ -479,9 +491,11 @@ class AuditUserAddressTest { every { data } returns "data" } every { userAddress.signedKeyList } returns skl + val lastVerifiedTimestamp = 1000L val verifiedEpoch = mockk { every { revision } returns 1 every { epochId } returns 100 + every { sklCreationTime } returns lastVerifiedTimestamp } coEvery { fetchVerifiedEpoch(testUserId, userAddress) } returns verifiedEpoch val newSKL = mockk { @@ -524,9 +538,11 @@ class AuditUserAddressTest { every { data } returns "data" } every { userAddress.signedKeyList } returns skl + val lastVerifiedTimestamp = 1000L val verifiedEpoch = mockk { every { revision } returns 1 every { epochId } returns 100 + every { sklCreationTime } returns lastVerifiedTimestamp } coEvery { fetchVerifiedEpoch(testUserId, userAddress) } returns verifiedEpoch val newSKL = mockk { @@ -573,9 +589,11 @@ class AuditUserAddressTest { every { data } returns "data" } every { userAddress.signedKeyList } returns skl + val lastVerifiedTimestamp = 1000L val verifiedEpoch = mockk { every { revision } returns 1 every { epochId } returns 100 + every { sklCreationTime } returns lastVerifiedTimestamp } coEvery { fetchVerifiedEpoch(testUserId, userAddress) } returns verifiedEpoch val newSKL = mockk { @@ -612,15 +630,17 @@ class AuditUserAddressTest { } @Test - fun `New changes - If a new skl is included, the verified epoch is updated`() = runTest { + fun `New changes - If the creation time is decreasing, audit fails`() = runTest { // given val skl = mockk { every { data } returns "data" } every { userAddress.signedKeyList } returns skl + val lastVerifiedTimestamp = 1000L val verifiedEpoch = mockk { every { revision } returns 1 every { epochId } returns 100 + every { sklCreationTime } returns lastVerifiedTimestamp } coEvery { fetchVerifiedEpoch(testUserId, userAddress) } returns verifiedEpoch val newSKL = mockk { @@ -629,7 +649,8 @@ class AuditUserAddressTest { } coEvery { publicAddressRepository.getSKLsAfterEpoch(testUserId, 100, testEmail) } returns listOf(newSKL) coEvery { buildInitialEpoch(verifiedEpoch, any(), testUserId, userAddress, skl) } returns verifiedEpoch - coEvery { verifySignedKeyListSignature(userAddress, newSKL) } returns 10_000 + val newSKLTimestamp = lastVerifiedTimestamp - 100 // SKL creation time is decreasing + coEvery { verifySignedKeyListSignature(userAddress, newSKL) } returns newSKLTimestamp val epoch = mockk() coEvery { keyTransparencyRepository.getEpoch(testUserId, 110) } returns epoch val proof = mockk { @@ -647,6 +668,61 @@ class AuditUserAddressTest { val expectedNewVerifiedEpoch = mockk { every { revision } returns 2 every { epochId } returns 110 + every { sklCreationTime } returns newSKLTimestamp + } + coJustRun { uploadVerifiedEpoch(testUserId, testAddressId, expectedNewVerifiedEpoch) } + // when + assertFailsWith { + auditUserAddress(testUserId, userAddress).unwrap() + } + // then + coVerify { + fetchVerifiedEpoch(testUserId, userAddress) + publicAddressRepository.getSKLsAfterEpoch(testUserId, 100, testEmail) + verifySignedKeyListSignature(userAddress, newSKL) + } + } + + @Test + fun `New changes - If a new skl is included, the verified epoch is updated`() = runTest { + // given + val skl = mockk { + every { data } returns "data" + } + every { userAddress.signedKeyList } returns skl + val lastVerifiedTimestamp = 1000L + val verifiedEpoch = mockk { + every { revision } returns 1 + every { epochId } returns 100 + every { sklCreationTime } returns lastVerifiedTimestamp + } + coEvery { fetchVerifiedEpoch(testUserId, userAddress) } returns verifiedEpoch + val newSKL = mockk { + every { data } returns "data" + every { maxEpochId } returns 110 + } + coEvery { publicAddressRepository.getSKLsAfterEpoch(testUserId, 100, testEmail) } returns listOf(newSKL) + coEvery { buildInitialEpoch(verifiedEpoch, any(), testUserId, userAddress, skl) } returns verifiedEpoch + val newSKLTimestamp = lastVerifiedTimestamp + 10 + coEvery { verifySignedKeyListSignature(userAddress, newSKL) } returns newSKLTimestamp + val epoch = mockk() + coEvery { keyTransparencyRepository.getEpoch(testUserId, 110) } returns epoch + val proof = mockk { + every { proof.revision } returns 2 + } + coEvery { keyTransparencyRepository.getProof(testUserId, 110, testEmail) } returns proof + val currentTime = 10_000L + coEvery { getCurrentTime() } returns currentTime + val certificateTime = currentTime - Constants.KT_MAX_EPOCH_INTERVAL_SECONDS + 1000 + val verifiedState = VerifiedState.Existent(certificateTime) + coEvery { + verifyProofInEpoch(testEmail, newSKL, epoch, proof) + } returns verifiedState + coJustRun { checkSignedKeyListMatch(userAddress, skl) } + val expectedNewVerifiedEpoch = mockk { + every { revision } returns 2 + every { epochId } returns 110 + every { sklCreationTime } returns newSKLTimestamp } coJustRun { uploadVerifiedEpoch(testUserId, testAddressId, expectedNewVerifiedEpoch) } // when @@ -671,9 +747,11 @@ class AuditUserAddressTest { every { data } returns "data" } every { userAddress.signedKeyList } returns skl + val signatureTimestamp = currentTime - Constants.KT_MAX_EPOCH_INTERVAL_SECONDS + 1000 val verifiedEpoch = mockk { every { revision } returns 1 every { epochId } returns 100 + every { sklCreationTime } returns signatureTimestamp - 1000 } coEvery { fetchVerifiedEpoch(testUserId, userAddress) } returns verifiedEpoch val newSKL = mockk { @@ -684,7 +762,6 @@ class AuditUserAddressTest { coEvery { buildInitialEpoch(verifiedEpoch, any(), testUserId, userAddress, skl) } returns verifiedEpoch val currentTime = 10_000L coEvery { getCurrentTime() } returns currentTime - val signatureTimestamp = currentTime - Constants.KT_MAX_EPOCH_INTERVAL_SECONDS + 1000 coEvery { verifySignedKeyListSignature(userAddress, newSKL) } returns signatureTimestamp coJustRun { checkSignedKeyListMatch(userAddress, skl) } // when diff --git a/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/BootstrapInitialEpochTest.kt b/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/BootstrapInitialEpochTest.kt index c6c833cd7..1d5e57c05 100644 --- a/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/BootstrapInitialEpochTest.kt +++ b/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/BootstrapInitialEpochTest.kt @@ -78,16 +78,17 @@ class BootstrapInitialEpochTest { val userAddress = mockk() val inputSKL = mockk { every { data } returns "data" + every { signature } returns "signature" } val newSKL = mockk { every { data } returns "data1" + every { signature } returns "signature" every { minEpochId } returns null } // when assertFailsWith { bootstrapInitialEpoch(userId, userAddress, inputSKL, listOf(newSKL)) } - // then } @Test @@ -96,9 +97,11 @@ class BootstrapInitialEpochTest { val userAddress = mockk() val inputSKL = mockk { every { data } returns "data" + every { signature } returns "signature" } val newSKL = mockk { every { data } returns "data" + every { signature } returns "signature" every { minEpochId } returns null } coEvery { @@ -120,9 +123,11 @@ class BootstrapInitialEpochTest { val userAddress = mockk() val inputSKL = mockk { every { data } returns "data" + every { signature } returns "signature" } val newSKL = mockk { every { data } returns "data" + every { signature } returns "signature" every { minEpochId } returns null } val signatureTimestamp = currentTime - Constants.KT_MAX_EPOCH_INTERVAL_SECONDS - 1000 @@ -143,9 +148,11 @@ class BootstrapInitialEpochTest { val userAddress = mockk() val inputSKL = mockk { every { data } returns "data" + every { signature } returns "signature" } val newSKL = mockk { every { data } returns "data" + every { signature } returns "signature" every { minEpochId } returns null } val signatureTimestamp = currentTime - Constants.KT_MAX_EPOCH_INTERVAL_SECONDS + 1000 @@ -167,9 +174,11 @@ class BootstrapInitialEpochTest { } val inputSKL = mockk { every { data } returns "data" + every { signature } returns "signature" } val newSKL = mockk { every { data } returns "data" + every { signature } returns "signature" every { minEpochId } returns 100 } val epoch = mockk() @@ -197,9 +206,11 @@ class BootstrapInitialEpochTest { } val inputSKL = mockk { every { data } returns "data" + every { signature } returns "signature" } val newSKL = mockk { every { data } returns "data" + every { signature } returns "signature" every { minEpochId } returns 100 } val epoch = mockk() @@ -227,9 +238,11 @@ class BootstrapInitialEpochTest { } val inputSKL = mockk { every { data } returns "data" + every { signature } returns "signature" } val newSKL = mockk { every { data } returns "data" + every { signature } returns "signature" every { minEpochId } returns 100 } val epoch = mockk() @@ -247,6 +260,7 @@ class BootstrapInitialEpochTest { assertNotNull(bootstrappedEpoch) assertEquals(100, bootstrappedEpoch.epochId) assertEquals(0, bootstrappedEpoch.revision) + assertEquals(0, bootstrappedEpoch.sklCreationTime) } @Test @@ -257,9 +271,11 @@ class BootstrapInitialEpochTest { } val inputSKL = mockk { every { data } returns "data" + every { signature } returns "signature" } val newSKL = mockk { every { data } returns "data" + every { signature } returns "signature" every { minEpochId } returns 100 } val epoch = mockk() @@ -277,5 +293,6 @@ class BootstrapInitialEpochTest { assertNotNull(bootstrappedEpoch) assertEquals(100, bootstrappedEpoch.epochId) assertEquals(1, bootstrappedEpoch.revision) + assertEquals(0, bootstrappedEpoch.sklCreationTime) } } diff --git a/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/FetchVerifiedEpochTest.kt b/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/FetchVerifiedEpochTest.kt index 019f60fcc..5d4666919 100644 --- a/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/FetchVerifiedEpochTest.kt +++ b/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/FetchVerifiedEpochTest.kt @@ -87,7 +87,8 @@ class FetchVerifiedEpochTest { // given val verifiedEpochData = VerifiedEpochData( 10, - 2 + 2, + 100 ) val verifiedEpoch = mockk { every { data } returns "data" diff --git a/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/UploadVerifiedEpochTest.kt b/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/UploadVerifiedEpochTest.kt index b633a0af5..d8af7078c 100644 --- a/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/UploadVerifiedEpochTest.kt +++ b/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/UploadVerifiedEpochTest.kt @@ -82,7 +82,8 @@ class UploadVerifiedEpochTest { val revision = 10 val inputVerifiedEpoch = VerifiedEpochData( epochID, - revision + revision, + 100 ) val serialized = inputVerifiedEpoch.toJson() val unlocked = "unlocked".toByteArray()