diff --git a/src/gtest/test_noteencryption.cpp b/src/gtest/test_noteencryption.cpp index a674daf65..0ed6999f8 100644 --- a/src/gtest/test_noteencryption.cpp +++ b/src/gtest/test_noteencryption.cpp @@ -35,6 +35,11 @@ TEST(noteencryption, NotePlaintext) } SaplingNote note(addr, 39393); + auto cmu_opt = note.cm(); + if (!cmu_opt) { + FAIL(); + } + uint256 cmu = cmu_opt.get(); SaplingNotePlaintext pt(note, memo); auto res = pt.encrypt(addr.pk_d); @@ -48,11 +53,20 @@ TEST(noteencryption, NotePlaintext) auto encryptor = enc.second; auto epk = encryptor.get_epk(); - // Try to decrypt + // Try to decrypt with incorrect commitment + ASSERT_FALSE(SaplingNotePlaintext::decrypt( + ct, + ivk, + epk, + uint256() + )); + + // Try to decrypt with correct commitment auto foo = SaplingNotePlaintext::decrypt( ct, ivk, - epk + epk, + cmu ); if (!foo) { @@ -112,12 +126,24 @@ TEST(noteencryption, NotePlaintext) ASSERT_TRUE(decrypted_out_ct_unwrapped.pk_d == out_pt.pk_d); ASSERT_TRUE(decrypted_out_ct_unwrapped.esk == out_pt.esk); + // Test sender won't accept invalid commitments + ASSERT_FALSE( + SaplingNotePlaintext::decrypt( + ct, + epk, + decrypted_out_ct_unwrapped.esk, + decrypted_out_ct_unwrapped.pk_d, + uint256() + ) + ); + // Test sender can decrypt the note ciphertext. foo = SaplingNotePlaintext::decrypt( ct, epk, decrypted_out_ct_unwrapped.esk, - decrypted_out_ct_unwrapped.pk_d + decrypted_out_ct_unwrapped.pk_d, + cmu ); if (!foo) { diff --git a/src/zcash/Note.cpp b/src/zcash/Note.cpp index c6c72e297..ee8f7b641 100644 --- a/src/zcash/Note.cpp +++ b/src/zcash/Note.cpp @@ -187,7 +187,8 @@ boost::optional SaplingOutgoingPlaintext::decrypt( boost::optional SaplingNotePlaintext::decrypt( const SaplingEncCiphertext &ciphertext, const uint256 &ivk, - const uint256 &epk + const uint256 &epk, + const uint256 &cmu ) { auto pt = AttemptSaplingEncDecryption(ciphertext, ivk, epk); @@ -204,6 +205,27 @@ boost::optional SaplingNotePlaintext::decrypt( assert(ss.size() == 0); + uint256 pk_d; + if (!librustzcash_ivk_to_pkd(ivk.begin(), ret.d.data(), pk_d.begin())) { + return boost::none; + } + + uint256 cmu_expected; + if (!librustzcash_sapling_compute_cm( + ret.d.data(), + pk_d.begin(), + ret.value(), + ret.rcm.begin(), + cmu_expected.begin() + )) + { + return boost::none; + } + + if (cmu_expected != cmu) { + return boost::none; + } + return ret; } @@ -211,7 +233,8 @@ boost::optional SaplingNotePlaintext::decrypt( const SaplingEncCiphertext &ciphertext, const uint256 &epk, const uint256 &esk, - const uint256 &pk_d + const uint256 &pk_d, + const uint256 &cmu ) { auto pt = AttemptSaplingEncDecryption(ciphertext, epk, esk, pk_d); @@ -226,6 +249,22 @@ boost::optional SaplingNotePlaintext::decrypt( SaplingNotePlaintext ret; ss >> ret; + uint256 cmu_expected; + if (!librustzcash_sapling_compute_cm( + ret.d.data(), + pk_d.begin(), + ret.value(), + ret.rcm.begin(), + cmu_expected.begin() + )) + { + return boost::none; + } + + if (cmu_expected != cmu) { + return boost::none; + } + assert(ss.size() == 0); return ret; diff --git a/src/zcash/Note.hpp b/src/zcash/Note.hpp index f1b8e4323..7d3377306 100644 --- a/src/zcash/Note.hpp +++ b/src/zcash/Note.hpp @@ -130,14 +130,16 @@ public: static boost::optional decrypt( const SaplingEncCiphertext &ciphertext, const uint256 &ivk, - const uint256 &epk + const uint256 &epk, + const uint256 &cmu ); static boost::optional decrypt( const SaplingEncCiphertext &ciphertext, const uint256 &epk, const uint256 &esk, - const uint256 &pk_d + const uint256 &pk_d, + const uint256 &cmu ); boost::optional note(const SaplingIncomingViewingKey& ivk) const;