From 04696da195f746dee324574e4e282f4a562c464b Mon Sep 17 00:00:00 2001 From: "M. Thiercelin" Date: Thu, 27 Apr 2023 15:21:17 +0200 Subject: [PATCH] fix(key-transparency): Enforce that the SKL creation time increases. We need to save the last creation time checked in the verified epoch. And we need to verify that the SKL creation times increase monotonically. --- .../domain/entity/VerifiedEpochData.kt | 4 +- .../domain/usecase/AuditUserAddress.kt | 19 ++++- .../domain/usecase/BootstrapInitialEpoch.kt | 7 +- .../domain/usecase/AuditUserAddressTest.kt | 85 ++++++++++++++++++- .../usecase/BootstrapInitialEpochTest.kt | 19 ++++- .../domain/usecase/FetchVerifiedEpochTest.kt | 3 +- .../domain/usecase/UploadVerifiedEpochTest.kt | 3 +- 7 files changed, 128 insertions(+), 12 deletions(-) diff --git a/key-transparency/domain/src/main/kotlin/me/proton/core/keytransparency/domain/entity/VerifiedEpochData.kt b/key-transparency/domain/src/main/kotlin/me/proton/core/keytransparency/domain/entity/VerifiedEpochData.kt index b3b35015a..5b5bb0d02 100644 --- a/key-transparency/domain/src/main/kotlin/me/proton/core/keytransparency/domain/entity/VerifiedEpochData.kt +++ b/key-transparency/domain/src/main/kotlin/me/proton/core/keytransparency/domain/entity/VerifiedEpochData.kt @@ -28,7 +28,9 @@ internal data class VerifiedEpochData( @SerialName("EpochID") val epochId: EpochId, @SerialName("Revision") - val revision: Int + val revision: Int, + @SerialName("SKLCreationTime") + val sklCreationTime: Long ) { companion object { fun fromJson(json: String): VerifiedEpochData = json.deserialize() diff --git a/key-transparency/domain/src/main/kotlin/me/proton/core/keytransparency/domain/usecase/AuditUserAddress.kt b/key-transparency/domain/src/main/kotlin/me/proton/core/keytransparency/domain/usecase/AuditUserAddress.kt index 057db5eb2..2be181e46 100644 --- a/key-transparency/domain/src/main/kotlin/me/proton/core/keytransparency/domain/usecase/AuditUserAddress.kt +++ b/key-transparency/domain/src/main/kotlin/me/proton/core/keytransparency/domain/usecase/AuditUserAddress.kt @@ -103,7 +103,11 @@ internal class AuditUserAddress @Inject constructor( keyTransparencyCheck(revision == initialEpoch.revision) { "Revision has changed but no SKL were provided" } if (maxEpochId > initialEpoch.epochId) { // Update the verified epoch - uploadVerifiedEpoch(userId, userAddress.addressId, VerifiedEpochData(maxEpochId, initialEpoch.revision)) + uploadVerifiedEpoch( + userId, + userAddress.addressId, + VerifiedEpochData(maxEpochId, initialEpoch.revision, initialEpoch.sklCreationTime) + ) } } @@ -116,6 +120,7 @@ internal class AuditUserAddress @Inject constructor( ) { var previousVerifiedEpoch = initialEpoch var previousCertificateDate: Long? = null + var previousSKLCreationTime = initialEpoch.sklCreationTime newSKLs.forEachIndexed { index, newSKL -> val isLast = index == newSKLs.size - 1 val isFirst = index == 0 @@ -127,6 +132,12 @@ internal class AuditUserAddress @Inject constructor( val timestamp = if (newSKL.data != null) { verifySignedKeyListSignature(userAddress, newSKL) } else null + if (timestamp != null) { + keyTransparencyCheck( + timestamp >= previousSKLCreationTime + ) { "SKL Creation time must increase monotonically."} + previousSKLCreationTime = timestamp + } if (maxEpochId != null) { val (verifiedState, revision) = verifyMaxEpoch(userId, maxEpochId, userAddress, newSKL) val isRevisionConsistent = if (isFirst) { @@ -135,7 +146,11 @@ internal class AuditUserAddress @Inject constructor( revision == previousVerifiedEpoch.revision + 1 } keyTransparencyCheck(isRevisionConsistent) { "Revision chain is inconsistent" } - previousVerifiedEpoch = VerifiedEpochData(revision = revision, epochId = maxEpochId) + previousVerifiedEpoch = VerifiedEpochData( + maxEpochId, + revision, + timestamp ?: previousVerifiedEpoch.sklCreationTime + ) previousCertificateDate = (verifiedState as TimedState).notBefore } else { // the last SKL cannot be an obsolescence because the address is not disabled diff --git a/key-transparency/domain/src/main/kotlin/me/proton/core/keytransparency/domain/usecase/BootstrapInitialEpoch.kt b/key-transparency/domain/src/main/kotlin/me/proton/core/keytransparency/domain/usecase/BootstrapInitialEpoch.kt index db8c4ff2d..9d4db69ef 100644 --- a/key-transparency/domain/src/main/kotlin/me/proton/core/keytransparency/domain/usecase/BootstrapInitialEpoch.kt +++ b/key-transparency/domain/src/main/kotlin/me/proton/core/keytransparency/domain/usecase/BootstrapInitialEpoch.kt @@ -46,7 +46,10 @@ internal class BootstrapInitialEpoch @Inject constructor( keyTransparencyCheck(newSKLs.isNotEmpty()) { "Can't bootstrap, no SKL available" } val oldestSKL = newSKLs[0] if (oldestSKL.minEpochId == null) { - keyTransparencyCheck(oldestSKL.data == inputSKL.data) + keyTransparencyCheck( + oldestSKL.data == inputSKL.data && + oldestSKL.signature == inputSKL.signature + ) { "Input SKL did not equal the only SKL" } // The address is too recent to bootstrap the verified epoch keyTransparencyCheck(newSKLs.size == 1) { "New address had more SKLs than the current one" } val timestamp = verifySignedKeyListSignature(userAddress, inputSKL) @@ -70,6 +73,6 @@ internal class BootstrapInitialEpoch @Inject constructor( ) { "Oldest epoch is not in range" } } val bootstrappedRevision = keyTransparencyCheckNotNull(revision) { "Bootstrapped epoch had no revision" } - return VerifiedEpochData(minEpochId, bootstrappedRevision) + return VerifiedEpochData(minEpochId, bootstrappedRevision, 0) } } diff --git a/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/AuditUserAddressTest.kt b/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/AuditUserAddressTest.kt index 7cca49d30..8574564e8 100644 --- a/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/AuditUserAddressTest.kt +++ b/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/AuditUserAddressTest.kt @@ -267,9 +267,11 @@ class AuditUserAddressTest { every { maxEpochId } returns 101 } every { userAddress.signedKeyList } returns skl + val lastVerifiedTimestamp = 1000L val verifiedEpoch = mockk { every { revision } returns 1 every { epochId } returns 100 + every { sklCreationTime } returns lastVerifiedTimestamp } coEvery { fetchVerifiedEpoch(testUserId, userAddress) } returns verifiedEpoch coEvery { publicAddressRepository.getSKLsAfterEpoch(testUserId, 100, testEmail) } returns emptyList() @@ -283,7 +285,7 @@ class AuditUserAddressTest { val certificateTime = currentTime - Constants.KT_MAX_EPOCH_INTERVAL_SECONDS + 1000 val verifiedState = VerifiedState.Existent(certificateTime) coEvery { verifyProofInEpoch(testEmail, skl, epoch, proof) } returns verifiedState - val expectedVE = VerifiedEpochData(101, 1) + val expectedVE = VerifiedEpochData(101, 1, lastVerifiedTimestamp) coJustRun { uploadVerifiedEpoch(testUserId, testAddressId, expectedVE) } // when auditUserAddress(testUserId, userAddress).unwrap() @@ -303,9 +305,11 @@ class AuditUserAddressTest { every { maxEpochId } returns 100 } every { userAddress.signedKeyList } returns skl + val lastVerifiedTimestamp = 1000L val verifiedEpoch = mockk { every { revision } returns 1 every { epochId } returns 100 + every { sklCreationTime } returns lastVerifiedTimestamp } coEvery { fetchVerifiedEpoch(testUserId, userAddress) } returns verifiedEpoch coEvery { publicAddressRepository.getSKLsAfterEpoch(testUserId, 100, testEmail) } returns emptyList() @@ -344,9 +348,11 @@ class AuditUserAddressTest { every { maxEpochId } returns 100 } every { userAddress.signedKeyList } returns skl + val lastVerifiedTimestamp = 1000L val verifiedEpoch = mockk { every { revision } returns 1 every { epochId } returns 100 + every { sklCreationTime } returns lastVerifiedTimestamp } coEvery { fetchVerifiedEpoch(testUserId, userAddress) } returns verifiedEpoch val newSKLs = listOf( @@ -369,9 +375,11 @@ class AuditUserAddressTest { every { data } returns "data" } every { userAddress.signedKeyList } returns skl + val lastVerifiedTimestamp = 1000L val verifiedEpoch = mockk { every { revision } returns 1 every { epochId } returns 100 + every { sklCreationTime } returns lastVerifiedTimestamp } coEvery { fetchVerifiedEpoch(testUserId, userAddress) } returns verifiedEpoch val newSKL = mockk { @@ -399,9 +407,11 @@ class AuditUserAddressTest { every { data } returns "data" } every { userAddress.signedKeyList } returns skl + val lastVerifiedTimestamp = 1000L val verifiedEpoch = mockk { every { revision } returns 1 every { epochId } returns 100 + every { sklCreationTime } returns lastVerifiedTimestamp } coEvery { fetchVerifiedEpoch(testUserId, userAddress) } returns verifiedEpoch val newSKL = mockk { @@ -438,9 +448,11 @@ class AuditUserAddressTest { every { data } returns "data" } every { userAddress.signedKeyList } returns skl + val lastVerifiedTimestamp = 1000L val verifiedEpoch = mockk { every { revision } returns 1 every { epochId } returns 100 + every { sklCreationTime } returns lastVerifiedTimestamp } coEvery { fetchVerifiedEpoch(testUserId, userAddress) } returns verifiedEpoch val newSKL = mockk { @@ -479,9 +491,11 @@ class AuditUserAddressTest { every { data } returns "data" } every { userAddress.signedKeyList } returns skl + val lastVerifiedTimestamp = 1000L val verifiedEpoch = mockk { every { revision } returns 1 every { epochId } returns 100 + every { sklCreationTime } returns lastVerifiedTimestamp } coEvery { fetchVerifiedEpoch(testUserId, userAddress) } returns verifiedEpoch val newSKL = mockk { @@ -524,9 +538,11 @@ class AuditUserAddressTest { every { data } returns "data" } every { userAddress.signedKeyList } returns skl + val lastVerifiedTimestamp = 1000L val verifiedEpoch = mockk { every { revision } returns 1 every { epochId } returns 100 + every { sklCreationTime } returns lastVerifiedTimestamp } coEvery { fetchVerifiedEpoch(testUserId, userAddress) } returns verifiedEpoch val newSKL = mockk { @@ -573,9 +589,11 @@ class AuditUserAddressTest { every { data } returns "data" } every { userAddress.signedKeyList } returns skl + val lastVerifiedTimestamp = 1000L val verifiedEpoch = mockk { every { revision } returns 1 every { epochId } returns 100 + every { sklCreationTime } returns lastVerifiedTimestamp } coEvery { fetchVerifiedEpoch(testUserId, userAddress) } returns verifiedEpoch val newSKL = mockk { @@ -612,15 +630,17 @@ class AuditUserAddressTest { } @Test - fun `New changes - If a new skl is included, the verified epoch is updated`() = runTest { + fun `New changes - If the creation time is decreasing, audit fails`() = runTest { // given val skl = mockk { every { data } returns "data" } every { userAddress.signedKeyList } returns skl + val lastVerifiedTimestamp = 1000L val verifiedEpoch = mockk { every { revision } returns 1 every { epochId } returns 100 + every { sklCreationTime } returns lastVerifiedTimestamp } coEvery { fetchVerifiedEpoch(testUserId, userAddress) } returns verifiedEpoch val newSKL = mockk { @@ -629,7 +649,8 @@ class AuditUserAddressTest { } coEvery { publicAddressRepository.getSKLsAfterEpoch(testUserId, 100, testEmail) } returns listOf(newSKL) coEvery { buildInitialEpoch(verifiedEpoch, any(), testUserId, userAddress, skl) } returns verifiedEpoch - coEvery { verifySignedKeyListSignature(userAddress, newSKL) } returns 10_000 + val newSKLTimestamp = lastVerifiedTimestamp - 100 // SKL creation time is decreasing + coEvery { verifySignedKeyListSignature(userAddress, newSKL) } returns newSKLTimestamp val epoch = mockk() coEvery { keyTransparencyRepository.getEpoch(testUserId, 110) } returns epoch val proof = mockk { @@ -647,6 +668,61 @@ class AuditUserAddressTest { val expectedNewVerifiedEpoch = mockk { every { revision } returns 2 every { epochId } returns 110 + every { sklCreationTime } returns newSKLTimestamp + } + coJustRun { uploadVerifiedEpoch(testUserId, testAddressId, expectedNewVerifiedEpoch) } + // when + assertFailsWith { + auditUserAddress(testUserId, userAddress).unwrap() + } + // then + coVerify { + fetchVerifiedEpoch(testUserId, userAddress) + publicAddressRepository.getSKLsAfterEpoch(testUserId, 100, testEmail) + verifySignedKeyListSignature(userAddress, newSKL) + } + } + + @Test + fun `New changes - If a new skl is included, the verified epoch is updated`() = runTest { + // given + val skl = mockk { + every { data } returns "data" + } + every { userAddress.signedKeyList } returns skl + val lastVerifiedTimestamp = 1000L + val verifiedEpoch = mockk { + every { revision } returns 1 + every { epochId } returns 100 + every { sklCreationTime } returns lastVerifiedTimestamp + } + coEvery { fetchVerifiedEpoch(testUserId, userAddress) } returns verifiedEpoch + val newSKL = mockk { + every { data } returns "data" + every { maxEpochId } returns 110 + } + coEvery { publicAddressRepository.getSKLsAfterEpoch(testUserId, 100, testEmail) } returns listOf(newSKL) + coEvery { buildInitialEpoch(verifiedEpoch, any(), testUserId, userAddress, skl) } returns verifiedEpoch + val newSKLTimestamp = lastVerifiedTimestamp + 10 + coEvery { verifySignedKeyListSignature(userAddress, newSKL) } returns newSKLTimestamp + val epoch = mockk() + coEvery { keyTransparencyRepository.getEpoch(testUserId, 110) } returns epoch + val proof = mockk { + every { proof.revision } returns 2 + } + coEvery { keyTransparencyRepository.getProof(testUserId, 110, testEmail) } returns proof + val currentTime = 10_000L + coEvery { getCurrentTime() } returns currentTime + val certificateTime = currentTime - Constants.KT_MAX_EPOCH_INTERVAL_SECONDS + 1000 + val verifiedState = VerifiedState.Existent(certificateTime) + coEvery { + verifyProofInEpoch(testEmail, newSKL, epoch, proof) + } returns verifiedState + coJustRun { checkSignedKeyListMatch(userAddress, skl) } + val expectedNewVerifiedEpoch = mockk { + every { revision } returns 2 + every { epochId } returns 110 + every { sklCreationTime } returns newSKLTimestamp } coJustRun { uploadVerifiedEpoch(testUserId, testAddressId, expectedNewVerifiedEpoch) } // when @@ -671,9 +747,11 @@ class AuditUserAddressTest { every { data } returns "data" } every { userAddress.signedKeyList } returns skl + val signatureTimestamp = currentTime - Constants.KT_MAX_EPOCH_INTERVAL_SECONDS + 1000 val verifiedEpoch = mockk { every { revision } returns 1 every { epochId } returns 100 + every { sklCreationTime } returns signatureTimestamp - 1000 } coEvery { fetchVerifiedEpoch(testUserId, userAddress) } returns verifiedEpoch val newSKL = mockk { @@ -684,7 +762,6 @@ class AuditUserAddressTest { coEvery { buildInitialEpoch(verifiedEpoch, any(), testUserId, userAddress, skl) } returns verifiedEpoch val currentTime = 10_000L coEvery { getCurrentTime() } returns currentTime - val signatureTimestamp = currentTime - Constants.KT_MAX_EPOCH_INTERVAL_SECONDS + 1000 coEvery { verifySignedKeyListSignature(userAddress, newSKL) } returns signatureTimestamp coJustRun { checkSignedKeyListMatch(userAddress, skl) } // when diff --git a/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/BootstrapInitialEpochTest.kt b/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/BootstrapInitialEpochTest.kt index c6c833cd7..1d5e57c05 100644 --- a/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/BootstrapInitialEpochTest.kt +++ b/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/BootstrapInitialEpochTest.kt @@ -78,16 +78,17 @@ class BootstrapInitialEpochTest { val userAddress = mockk() val inputSKL = mockk { every { data } returns "data" + every { signature } returns "signature" } val newSKL = mockk { every { data } returns "data1" + every { signature } returns "signature" every { minEpochId } returns null } // when assertFailsWith { bootstrapInitialEpoch(userId, userAddress, inputSKL, listOf(newSKL)) } - // then } @Test @@ -96,9 +97,11 @@ class BootstrapInitialEpochTest { val userAddress = mockk() val inputSKL = mockk { every { data } returns "data" + every { signature } returns "signature" } val newSKL = mockk { every { data } returns "data" + every { signature } returns "signature" every { minEpochId } returns null } coEvery { @@ -120,9 +123,11 @@ class BootstrapInitialEpochTest { val userAddress = mockk() val inputSKL = mockk { every { data } returns "data" + every { signature } returns "signature" } val newSKL = mockk { every { data } returns "data" + every { signature } returns "signature" every { minEpochId } returns null } val signatureTimestamp = currentTime - Constants.KT_MAX_EPOCH_INTERVAL_SECONDS - 1000 @@ -143,9 +148,11 @@ class BootstrapInitialEpochTest { val userAddress = mockk() val inputSKL = mockk { every { data } returns "data" + every { signature } returns "signature" } val newSKL = mockk { every { data } returns "data" + every { signature } returns "signature" every { minEpochId } returns null } val signatureTimestamp = currentTime - Constants.KT_MAX_EPOCH_INTERVAL_SECONDS + 1000 @@ -167,9 +174,11 @@ class BootstrapInitialEpochTest { } val inputSKL = mockk { every { data } returns "data" + every { signature } returns "signature" } val newSKL = mockk { every { data } returns "data" + every { signature } returns "signature" every { minEpochId } returns 100 } val epoch = mockk() @@ -197,9 +206,11 @@ class BootstrapInitialEpochTest { } val inputSKL = mockk { every { data } returns "data" + every { signature } returns "signature" } val newSKL = mockk { every { data } returns "data" + every { signature } returns "signature" every { minEpochId } returns 100 } val epoch = mockk() @@ -227,9 +238,11 @@ class BootstrapInitialEpochTest { } val inputSKL = mockk { every { data } returns "data" + every { signature } returns "signature" } val newSKL = mockk { every { data } returns "data" + every { signature } returns "signature" every { minEpochId } returns 100 } val epoch = mockk() @@ -247,6 +260,7 @@ class BootstrapInitialEpochTest { assertNotNull(bootstrappedEpoch) assertEquals(100, bootstrappedEpoch.epochId) assertEquals(0, bootstrappedEpoch.revision) + assertEquals(0, bootstrappedEpoch.sklCreationTime) } @Test @@ -257,9 +271,11 @@ class BootstrapInitialEpochTest { } val inputSKL = mockk { every { data } returns "data" + every { signature } returns "signature" } val newSKL = mockk { every { data } returns "data" + every { signature } returns "signature" every { minEpochId } returns 100 } val epoch = mockk() @@ -277,5 +293,6 @@ class BootstrapInitialEpochTest { assertNotNull(bootstrappedEpoch) assertEquals(100, bootstrappedEpoch.epochId) assertEquals(1, bootstrappedEpoch.revision) + assertEquals(0, bootstrappedEpoch.sklCreationTime) } } diff --git a/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/FetchVerifiedEpochTest.kt b/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/FetchVerifiedEpochTest.kt index 019f60fcc..5d4666919 100644 --- a/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/FetchVerifiedEpochTest.kt +++ b/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/FetchVerifiedEpochTest.kt @@ -87,7 +87,8 @@ class FetchVerifiedEpochTest { // given val verifiedEpochData = VerifiedEpochData( 10, - 2 + 2, + 100 ) val verifiedEpoch = mockk { every { data } returns "data" diff --git a/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/UploadVerifiedEpochTest.kt b/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/UploadVerifiedEpochTest.kt index b633a0af5..d8af7078c 100644 --- a/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/UploadVerifiedEpochTest.kt +++ b/key-transparency/domain/src/test/kotlin/me/proton/core/keytransparency/domain/usecase/UploadVerifiedEpochTest.kt @@ -82,7 +82,8 @@ class UploadVerifiedEpochTest { val revision = 10 val inputVerifiedEpoch = VerifiedEpochData( epochID, - revision + revision, + 100 ) val serialized = inputVerifiedEpoch.toJson() val unlocked = "unlocked".toByteArray()