diff --git a/src/gtest/test_keystore.cpp b/src/gtest/test_keystore.cpp index 7358472e6..d767967ba 100644 --- a/src/gtest/test_keystore.cpp +++ b/src/gtest/test_keystore.cpp @@ -55,8 +55,12 @@ TEST(keystore_tests, sapling_keys) { EXPECT_EQ(in_viewing_key, in_viewing_key_2); // Check that the default address from primitives and from sk method are the same - auto default_addr = sk.default_address(); - auto default_addr_2 = in_viewing_key.address(default_d); + auto addrOpt = sk.default_address(); + EXPECT_TRUE(addrOpt); + auto default_addr = addrOpt.value(); + auto addrOpt2 = in_viewing_key.address(default_d); + EXPECT_TRUE(addrOpt2); + auto default_addr_2 = addrOpt2.value(); EXPECT_EQ(default_addr, default_addr_2); auto default_addr_3 = libzcash::SaplingPaymentAddress(default_d, default_pk_d); @@ -161,6 +165,42 @@ TEST(keystore_tests, StoreAndRetrieveViewingKey) { EXPECT_EQ(ZCNoteDecryption(sk.receiving_key()), decOut); } +// Sapling +TEST(keystore_tests, StoreAndRetrieveSaplingSpendingKey) { + CBasicKeyStore keyStore; + libzcash::SaplingSpendingKey skOut; + libzcash::SaplingFullViewingKey fvkOut; + libzcash::SaplingIncomingViewingKey ivkOut; + + auto sk = libzcash::SaplingSpendingKey::random(); + auto fvk = sk.full_viewing_key(); + auto ivk = fvk.in_viewing_key(); + auto addrOpt = sk.default_address(); + EXPECT_TRUE(addrOpt); + auto addr = addrOpt.value(); + + // Sanity-check: we can't get a key we haven't added + EXPECT_FALSE(keyStore.HaveSaplingSpendingKey(fvk)); + EXPECT_FALSE(keyStore.GetSaplingSpendingKey(fvk, skOut)); + // Sanity-check: we can't get a full viewing key we haven't added + EXPECT_FALSE(keyStore.HaveSaplingFullViewingKey(ivk)); + EXPECT_FALSE(keyStore.GetSaplingFullViewingKey(ivk, fvkOut)); + // Sanity-check: we can't get an incoming viewing key we haven't added + EXPECT_FALSE(keyStore.HaveSaplingIncomingViewingKey(addr)); + EXPECT_FALSE(keyStore.GetSaplingIncomingViewingKey(addr, ivkOut)); + + keyStore.AddSaplingSpendingKey(sk); + EXPECT_TRUE(keyStore.HaveSaplingSpendingKey(fvk)); + EXPECT_TRUE(keyStore.GetSaplingSpendingKey(fvk, skOut)); + EXPECT_TRUE(keyStore.HaveSaplingFullViewingKey(ivk)); + EXPECT_TRUE(keyStore.GetSaplingFullViewingKey(ivk, fvkOut)); + EXPECT_TRUE(keyStore.HaveSaplingIncomingViewingKey(addr)); + EXPECT_TRUE(keyStore.GetSaplingIncomingViewingKey(addr, ivkOut)); + EXPECT_EQ(sk, skOut); + EXPECT_EQ(fvk, fvkOut); + EXPECT_EQ(ivk, ivkOut); +} + #ifdef ENABLE_WALLET class TestCCryptoKeyStore : public CCryptoKeyStore { diff --git a/src/keystore.cpp b/src/keystore.cpp index 828cae677..d0e5b8ec6 100644 --- a/src/keystore.cpp +++ b/src/keystore.cpp @@ -106,12 +106,10 @@ bool CBasicKeyStore::AddSaplingSpendingKey(const libzcash::SaplingSpendingKey &s // Add addr -> SaplingIncomingViewing to SaplingIncomingViewingKeyMap auto ivk = fvk.in_viewing_key(); auto addrOpt = sk.default_address(); - if (addrOpt){ - auto addr = addrOpt.value(); - mapSaplingIncomingViewingKeys[addr] = ivk; - } else { - return false; - } + assert(addrOpt != boost::none); + auto addr = addrOpt.value(); + mapSaplingIncomingViewingKeys[addr] = ivk; + return true; } @@ -129,8 +127,9 @@ bool CBasicKeyStore::AddSaplingFullViewingKey(const libzcash::SaplingFullViewing LOCK(cs_SpendingKeyStore); auto ivk = fvk.in_viewing_key(); mapSaplingFullViewingKeys[ivk] = fvk; + //! TODO: Note decryptors for Sapling - // mapNoteDecryptors.insert(std::make_pair(address, ZCNoteDecryption(vk.sk_enc))); + return true; } diff --git a/src/wallet/gtest/test_wallet_zkeys.cpp b/src/wallet/gtest/test_wallet_zkeys.cpp index 9145648d2..efc06b962 100644 --- a/src/wallet/gtest/test_wallet_zkeys.cpp +++ b/src/wallet/gtest/test_wallet_zkeys.cpp @@ -7,42 +7,9 @@ #include -/** - * This test covers Sapling methods on CWallet - * GenerateNewSaplingZKey() - */ -TEST(wallet_zkeys_tests, store_and_load_sapling_zkeys) { - CWallet wallet; - - // wallet should be empty - // std::set addrs; - // wallet.GetSaplingPaymentAddresses(addrs); - // ASSERT_EQ(0, addrs.size()); - - // wallet should have one key - auto saplingAddr = wallet.GenerateNewSaplingZKey(); - // ASSERT_NE(boost::get(&address), nullptr); - // auto sapling_addr = boost::get(saplingAddr); - // wallet.GetSaplingPaymentAddresses(addrs); - // ASSERT_EQ(1, addrs.size()); - - auto sk = libzcash::SaplingSpendingKey::random(); - auto full_viewing_key = sk.full_viewing_key(); - ASSERT_TRUE(wallet.AddSaplingSpendingKey(sk)); - - // verify wallet has spending key for the address - ASSERT_TRUE(wallet.HaveSaplingSpendingKey(full_viewing_key)); - - // check key is the same - libzcash::SaplingSpendingKey keyOut; - wallet.GetSaplingSpendingKey(full_viewing_key, keyOut); - ASSERT_EQ(sk, keyOut); -} - /** * This test covers methods on CWallet * GenerateNewZKey() - * GenerateNewSaplingZKey() * AddZKey() * LoadZKey() * LoadZKeyMetadata() diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp index b1bb54dac..410aec564 100644 --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -100,40 +100,40 @@ libzcash::PaymentAddress CWallet::GenerateNewZKey() return addr; } -//! TODO: Should be Sapling address format, SaplingPaymentAddress // Generate a new Sapling spending key and return its public payment address SaplingPaymentAddress CWallet::GenerateNewSaplingZKey() { AssertLockHeld(cs_wallet); // mapZKeyMetadata - auto sk = SaplingSpendingKey::random(); - auto fvk = sk.full_viewing_key(); - auto addrOpt = sk.default_address(); - if (addrOpt){ - auto addr = addrOpt.value(); - // Check for collision, even though it is unlikely to ever occur - if (CCryptoKeyStore::HaveSaplingSpendingKey(fvk)) - throw std::runtime_error("CWallet::GenerateNewSaplingZKey(): Collision detected"); - - // Create new metadata - int64_t nCreationTime = GetTime(); - mapSaplingZKeyMetadata[addr] = CKeyMetadata(nCreationTime); - - if (!AddSaplingZKey(sk)) { - throw std::runtime_error("CWallet::GenerateNewSaplingZKey(): AddSaplingZKey failed"); - } - // return default sapling payment address. - return addr; - } else { - throw std::runtime_error("CWallet::GenerateNewSaplingZKey(): default_address() did not return address"); + + SaplingSpendingKey sk; + boost::optional addrOpt; + while (!addrOpt){ + sk = SaplingSpendingKey::random(); + addrOpt = sk.default_address(); } + + auto addr = addrOpt.value(); + auto fvk = sk.full_viewing_key(); + + // Check for collision, even though it is unlikely to ever occur + if (CCryptoKeyStore::HaveSaplingSpendingKey(fvk)) + throw std::runtime_error("CWallet::GenerateNewSaplingZKey(): Collision detected"); + + // Create new metadata + int64_t nCreationTime = GetTime(); + mapSaplingZKeyMetadata[addr] = CKeyMetadata(nCreationTime); + + if (!AddSaplingZKey(sk)) { + throw std::runtime_error("CWallet::GenerateNewSaplingZKey(): AddSaplingZKey failed"); + } + // return default sapling payment address. + return addr; } // Add spending key to keystore -// TODO: persist to disk bool CWallet::AddSaplingZKey(const libzcash::SaplingSpendingKey &sk) { AssertLockHeld(cs_wallet); // mapSaplingZKeyMetadata - auto addr = sk.default_address(); if (!CCryptoKeyStore::AddSaplingSpendingKey(sk)) { return false; @@ -144,11 +144,7 @@ bool CWallet::AddSaplingZKey(const libzcash::SaplingSpendingKey &sk) } // TODO: Persist to disk - // if (!IsCrypted()) { - // return CWalletDB(strWalletFile).WriteSaplingZKey(addr, - // sk, - // mapSaplingZKeyMetadata[addr]); - // } + return true; } diff --git a/src/wallet/walletdb.h b/src/wallet/walletdb.h index 318993a65..554f1672b 100644 --- a/src/wallet/walletdb.h +++ b/src/wallet/walletdb.h @@ -141,7 +141,7 @@ public: bool WriteViewingKey(const libzcash::SproutViewingKey &vk); bool EraseViewingKey(const libzcash::SproutViewingKey &vk); - + private: CWalletDB(const CWalletDB&); void operator=(const CWalletDB&);