diff --git a/src/gtest/test_keystore.cpp b/src/gtest/test_keystore.cpp index 23cc45aae..903a48839 100644 --- a/src/gtest/test_keystore.cpp +++ b/src/gtest/test_keystore.cpp @@ -46,6 +46,50 @@ TEST(keystore_tests, store_and_retrieve_note_decryptor) { EXPECT_EQ(ZCNoteDecryption(sk.receiving_key()), decOut); } +TEST(keystore_tests, StoreAndRetrieveViewingKey) { + CBasicKeyStore keyStore; + libzcash::ViewingKey vkOut; + libzcash::SpendingKey skOut; + ZCNoteDecryption decOut; + + auto sk = libzcash::SpendingKey::random(); + auto vk = sk.viewing_key(); + auto addr = sk.address(); + + // Sanity-check: we can't get a viewing key we haven't added + EXPECT_FALSE(keyStore.HaveViewingKey(addr)); + EXPECT_FALSE(keyStore.GetViewingKey(addr, vkOut)); + + // and we shouldn't have a spending key or decryptor either + EXPECT_FALSE(keyStore.HaveSpendingKey(addr)); + EXPECT_FALSE(keyStore.GetSpendingKey(addr, skOut)); + EXPECT_FALSE(keyStore.GetNoteDecryptor(addr, decOut)); + + keyStore.AddViewingKey(vk); + EXPECT_TRUE(keyStore.HaveViewingKey(addr)); + EXPECT_TRUE(keyStore.GetViewingKey(addr, vkOut)); + EXPECT_EQ(vk, vkOut); + + // We should still not have the spending key... + EXPECT_FALSE(keyStore.HaveSpendingKey(addr)); + EXPECT_FALSE(keyStore.GetSpendingKey(addr, skOut)); + + // ... but we should have a decryptor + EXPECT_TRUE(keyStore.GetNoteDecryptor(addr, decOut)); + EXPECT_EQ(ZCNoteDecryption(sk.receiving_key()), decOut); + + keyStore.RemoveViewingKey(vk); + EXPECT_FALSE(keyStore.HaveViewingKey(addr)); + EXPECT_FALSE(keyStore.GetViewingKey(addr, vkOut)); + EXPECT_FALSE(keyStore.HaveSpendingKey(addr)); + EXPECT_FALSE(keyStore.GetSpendingKey(addr, skOut)); + + // We still have a decryptor because those are cached in memory + // (and also we only remove viewing keys when adding a spending key) + EXPECT_TRUE(keyStore.GetNoteDecryptor(addr, decOut)); + EXPECT_EQ(ZCNoteDecryption(sk.receiving_key()), decOut); +} + #ifdef ENABLE_WALLET class TestCCryptoKeyStore : public CCryptoKeyStore { diff --git a/src/keystore.cpp b/src/keystore.cpp index 3c32ab583..323fe710c 100644 --- a/src/keystore.cpp +++ b/src/keystore.cpp @@ -92,3 +92,37 @@ bool CBasicKeyStore::AddSpendingKey(const libzcash::SpendingKey &sk) mapNoteDecryptors.insert(std::make_pair(address, ZCNoteDecryption(sk.receiving_key()))); return true; } + +bool CBasicKeyStore::AddViewingKey(const libzcash::ViewingKey &vk) +{ + LOCK(cs_SpendingKeyStore); + auto address = vk.address(); + mapViewingKeys[address] = vk; + mapNoteDecryptors.insert(std::make_pair(address, ZCNoteDecryption(vk.sk_enc))); + return true; +} + +bool CBasicKeyStore::RemoveViewingKey(const libzcash::ViewingKey &vk) +{ + LOCK(cs_SpendingKeyStore); + mapViewingKeys.erase(vk.address()); + return true; +} + +bool CBasicKeyStore::HaveViewingKey(const libzcash::PaymentAddress &address) const +{ + LOCK(cs_SpendingKeyStore); + return mapViewingKeys.count(address) > 0; +} + +bool CBasicKeyStore::GetViewingKey(const libzcash::PaymentAddress &address, + libzcash::ViewingKey &vkOut) const +{ + LOCK(cs_SpendingKeyStore); + ViewingKeyMap::const_iterator mi = mapViewingKeys.find(address); + if (mi != mapViewingKeys.end()) { + vkOut = mi->second; + return true; + } + return false; +} diff --git a/src/keystore.h b/src/keystore.h index 84595cfb0..0b548920b 100644 --- a/src/keystore.h +++ b/src/keystore.h @@ -55,12 +55,19 @@ public: virtual bool HaveSpendingKey(const libzcash::PaymentAddress &address) const =0; virtual bool GetSpendingKey(const libzcash::PaymentAddress &address, libzcash::SpendingKey& skOut) const =0; virtual void GetPaymentAddresses(std::set &setAddress) const =0; + + //! Support for viewing keys + virtual bool AddViewingKey(const libzcash::ViewingKey &vk) =0; + virtual bool RemoveViewingKey(const libzcash::ViewingKey &vk) =0; + virtual bool HaveViewingKey(const libzcash::PaymentAddress &address) const =0; + virtual bool GetViewingKey(const libzcash::PaymentAddress &address, libzcash::ViewingKey& vkOut) const =0; }; typedef std::map KeyMap; typedef std::map ScriptMap; typedef std::set WatchOnlySet; typedef std::map SpendingKeyMap; +typedef std::map ViewingKeyMap; typedef std::map NoteDecryptorMap; /** Basic key store, that keeps keys in an address->secret map */ @@ -71,6 +78,7 @@ protected: ScriptMap mapScripts; WatchOnlySet setWatchOnly; SpendingKeyMap mapSpendingKeys; + ViewingKeyMap mapViewingKeys; NoteDecryptorMap mapNoteDecryptors; public: @@ -168,6 +176,11 @@ public: } } } + + virtual bool AddViewingKey(const libzcash::ViewingKey &vk); + virtual bool RemoveViewingKey(const libzcash::ViewingKey &vk); + virtual bool HaveViewingKey(const libzcash::PaymentAddress &address) const; + virtual bool GetViewingKey(const libzcash::PaymentAddress &address, libzcash::ViewingKey& vkOut) const; }; typedef std::vector > CKeyingMaterial; diff --git a/src/zcash/Address.cpp b/src/zcash/Address.cpp index 75324de4f..baefeae4e 100644 --- a/src/zcash/Address.cpp +++ b/src/zcash/Address.cpp @@ -12,20 +12,28 @@ uint256 PaymentAddress::GetHash() const { return Hash(ss.begin(), ss.end()); } -uint256 ReceivingKey::pk_enc() { +uint256 ReceivingKey::pk_enc() const { return ZCNoteEncryption::generate_pubkey(*this); } +PaymentAddress ViewingKey::address() const { + return PaymentAddress(a_pk, sk_enc.pk_enc()); +} + ReceivingKey SpendingKey::receiving_key() const { return ReceivingKey(ZCNoteEncryption::generate_privkey(*this)); } +ViewingKey SpendingKey::viewing_key() const { + return ViewingKey(PRF_addr_a_pk(*this), receiving_key()); +} + SpendingKey SpendingKey::random() { return SpendingKey(random_uint252()); } PaymentAddress SpendingKey::address() const { - return PaymentAddress(PRF_addr_a_pk(*this), receiving_key().pk_enc()); + return viewing_key().address(); } } diff --git a/src/zcash/Address.hpp b/src/zcash/Address.hpp index e76973cb6..4287fee4f 100644 --- a/src/zcash/Address.hpp +++ b/src/zcash/Address.hpp @@ -40,9 +40,37 @@ public: class ReceivingKey : public uint256 { public: + ReceivingKey() { } ReceivingKey(uint256 sk_enc) : uint256(sk_enc) { } - uint256 pk_enc(); + uint256 pk_enc() const; +}; + +class ViewingKey { +public: + uint256 a_pk; + ReceivingKey sk_enc; + + ViewingKey() : a_pk(), sk_enc() { } + ViewingKey(uint256 a_pk, ReceivingKey sk_enc) : a_pk(a_pk), sk_enc(sk_enc) { } + + ADD_SERIALIZE_METHODS; + + template + inline void SerializationOp(Stream& s, Operation ser_action, int nType, int nVersion) { + READWRITE(a_pk); + READWRITE(sk_enc); + } + + PaymentAddress address() const; + + friend inline bool operator==(const ViewingKey& a, const ViewingKey& b) { + return a.a_pk == b.a_pk && a.sk_enc == b.sk_enc; + } + friend inline bool operator<(const ViewingKey& a, const ViewingKey& b) { + return (a.a_pk < b.a_pk || + (a.a_pk == b.a_pk && a.sk_enc < b.sk_enc)); + } }; class SpendingKey : public uint252 { @@ -53,6 +81,7 @@ public: static SpendingKey random(); ReceivingKey receiving_key() const; + ViewingKey viewing_key() const; PaymentAddress address() const; };