diff --git a/src/wallet/gtest/test_wallet.cpp b/src/wallet/gtest/test_wallet.cpp index 50da9ea7a..abc582faa 100644 --- a/src/wallet/gtest/test_wallet.cpp +++ b/src/wallet/gtest/test_wallet.cpp @@ -2014,7 +2014,39 @@ TEST(WalletTests, SproutNoteLocking) { EXPECT_TRUE(wallet.IsLockedNote(jsoutpt2)); // Test unlock all - wallet.UnlockAllNotes(); + wallet.UnlockAllSproutNotes(); EXPECT_FALSE(wallet.IsLockedNote(jsoutpt)); EXPECT_FALSE(wallet.IsLockedNote(jsoutpt2)); } + +TEST(WalletTests, SaplingNoteLocking) { + TestWallet wallet; + SaplingOutPoint sop1 {uint256(), 1}; + SaplingOutPoint sop2 {uint256(), 2}; + + // Test selective locking + wallet.LockNote(sop1); + EXPECT_TRUE(wallet.IsLockedNote(sop1)); + EXPECT_FALSE(wallet.IsLockedNote(sop2)); + + // Test selective unlocking + wallet.UnlockNote(sop1); + EXPECT_FALSE(wallet.IsLockedNote(sop1)); + + // Test multiple locking + wallet.LockNote(sop1); + wallet.LockNote(sop2); + EXPECT_TRUE(wallet.IsLockedNote(sop1)); + EXPECT_TRUE(wallet.IsLockedNote(sop2)); + + // Test list + auto v = wallet.ListLockedSaplingNotes(); + EXPECT_EQ(v.size(), 2); + EXPECT_TRUE(std::find(v.begin(), v.end(), sop1) != v.end()); + EXPECT_TRUE(std::find(v.begin(), v.end(), sop2) != v.end()); + + // Test unlock all + wallet.UnlockAllSaplingNotes(); + EXPECT_FALSE(wallet.IsLockedNote(sop1)); + EXPECT_FALSE(wallet.IsLockedNote(sop2)); +} diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp index e1ed88a24..d931f7384 100644 --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -3942,36 +3942,67 @@ void CWallet::ListLockedCoins(std::vector& vOutpts) void CWallet::LockNote(const JSOutPoint& output) { - AssertLockHeld(cs_wallet); // setLockedNotes - setLockedNotes.insert(output); + AssertLockHeld(cs_wallet); // setLockedSproutNotes + setLockedSproutNotes.insert(output); } void CWallet::UnlockNote(const JSOutPoint& output) { - AssertLockHeld(cs_wallet); // setLockedNotes - setLockedNotes.erase(output); + AssertLockHeld(cs_wallet); // setLockedSproutNotes + setLockedSproutNotes.erase(output); } -void CWallet::UnlockAllNotes() +void CWallet::UnlockAllSproutNotes() { - AssertLockHeld(cs_wallet); // setLockedNotes - setLockedNotes.clear(); + AssertLockHeld(cs_wallet); // setLockedSproutNotes + setLockedSproutNotes.clear(); } bool CWallet::IsLockedNote(const JSOutPoint& outpt) const { - AssertLockHeld(cs_wallet); // setLockedNotes + AssertLockHeld(cs_wallet); // setLockedSproutNotes - return (setLockedNotes.count(outpt) > 0); + return (setLockedSproutNotes.count(outpt) > 0); } -std::vector CWallet::ListLockedNotes() +std::vector CWallet::ListLockedSproutNotes() { - AssertLockHeld(cs_wallet); // setLockedNotes - std::vector vOutpts(setLockedNotes.begin(), setLockedNotes.end()); + AssertLockHeld(cs_wallet); // setLockedSproutNotes + std::vector vOutpts(setLockedSproutNotes.begin(), setLockedSproutNotes.end()); return vOutpts; } +void CWallet::LockNote(const SaplingOutPoint& output) +{ + AssertLockHeld(cs_wallet); + setLockedSaplingNotes.insert(output); +} + +void CWallet::UnlockNote(const SaplingOutPoint& output) +{ + AssertLockHeld(cs_wallet); + setLockedSaplingNotes.erase(output); +} + +void CWallet::UnlockAllSaplingNotes() +{ + AssertLockHeld(cs_wallet); + setLockedSaplingNotes.clear(); +} + +bool CWallet::IsLockedNote(const SaplingOutPoint& output) const +{ + AssertLockHeld(cs_wallet); + return (setLockedSaplingNotes.count(output) > 0); +} + +std::vector CWallet::ListLockedSaplingNotes() +{ + AssertLockHeld(cs_wallet); + std::vector vOutputs(setLockedSaplingNotes.begin(), setLockedSaplingNotes.end()); + return vOutputs; +} + /** @} */ // end of Actions class CAffectedKeysVisitor : public boost::static_visitor { diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h index 3c7d4dd46..dda1e4dfe 100644 --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -948,7 +948,8 @@ public: CPubKey vchDefaultKey; std::set setLockedCoins; - std::set setLockedNotes; + std::set setLockedSproutNotes; + std::set setLockedSaplingNotes; int64_t nTimeFirstKey; @@ -970,13 +971,17 @@ public: void UnlockAllCoins(); void ListLockedCoins(std::vector& vOutpts); - bool IsLockedNote(const JSOutPoint& outpt) const; void LockNote(const JSOutPoint& output); void UnlockNote(const JSOutPoint& output); - void UnlockAllNotes(); - std::vector ListLockedNotes(); + void UnlockAllSproutNotes(); + std::vector ListLockedSproutNotes(); + bool IsLockedNote(const SaplingOutPoint& output) const; + void LockNote(const SaplingOutPoint& output); + void UnlockNote(const SaplingOutPoint& output); + void UnlockAllSaplingNotes(); + std::vector ListLockedSaplingNotes(); /** * keystore implementation