diff --git a/src/base58.cpp b/src/base58.cpp index efce9186f..a6152e692 100644 --- a/src/base58.cpp +++ b/src/base58.cpp @@ -313,21 +313,19 @@ bool CBitcoinSecret::SetString(const std::string& strSecret) return SetString(strSecret.c_str()); } -const size_t serializedPaymentAddressSize = 64; - bool CZCPaymentAddress::Set(const libzcash::PaymentAddress& addr) { CDataStream ss(SER_NETWORK, PROTOCOL_VERSION); ss << addr; std::vector addrSerialized(ss.begin(), ss.end()); - assert(addrSerialized.size() == serializedPaymentAddressSize); - SetData(Params().Base58Prefix(CChainParams::ZCPAYMENT_ADDRRESS), &addrSerialized[0], serializedPaymentAddressSize); + assert(addrSerialized.size() == libzcash::SerializedPaymentAddressSize); + SetData(Params().Base58Prefix(CChainParams::ZCPAYMENT_ADDRRESS), &addrSerialized[0], libzcash::SerializedPaymentAddressSize); return true; } libzcash::PaymentAddress CZCPaymentAddress::Get() const { - if (vchData.size() != serializedPaymentAddressSize) { + if (vchData.size() != libzcash::SerializedPaymentAddressSize) { throw std::runtime_error( "payment address is invalid" ); @@ -347,21 +345,19 @@ libzcash::PaymentAddress CZCPaymentAddress::Get() const return ret; } -const size_t serializedSpendingKeySize = 32; - bool CZCSpendingKey::Set(const libzcash::SpendingKey& addr) { CDataStream ss(SER_NETWORK, PROTOCOL_VERSION); ss << addr; std::vector addrSerialized(ss.begin(), ss.end()); - assert(addrSerialized.size() == serializedSpendingKeySize); - SetData(Params().Base58Prefix(CChainParams::ZCSPENDING_KEY), &addrSerialized[0], serializedSpendingKeySize); + assert(addrSerialized.size() == libzcash::SerializedSpendingKeySize); + SetData(Params().Base58Prefix(CChainParams::ZCSPENDING_KEY), &addrSerialized[0], libzcash::SerializedSpendingKeySize); return true; } libzcash::SpendingKey CZCSpendingKey::Get() const { - if (vchData.size() != serializedSpendingKeySize) { + if (vchData.size() != libzcash::SerializedSpendingKeySize) { throw std::runtime_error( "spending key is invalid" ); diff --git a/src/gtest/test_keystore.cpp b/src/gtest/test_keystore.cpp index 11d967e89..26fde42d9 100644 --- a/src/gtest/test_keystore.cpp +++ b/src/gtest/test_keystore.cpp @@ -1,6 +1,8 @@ #include #include "keystore.h" +#include "random.h" +#include "wallet/crypter.h" #include "zcash/Address.hpp" TEST(keystore_tests, store_and_retrieve_spending_key) { @@ -41,3 +43,85 @@ TEST(keystore_tests, store_and_retrieve_note_decryptor) { EXPECT_TRUE(keyStore.GetNoteDecryptor(addr, decOut)); EXPECT_EQ(ZCNoteDecryption(sk.viewing_key()), decOut); } + +class TestCCryptoKeyStore : public CCryptoKeyStore +{ +public: + bool EncryptKeys(CKeyingMaterial& vMasterKeyIn) { return CCryptoKeyStore::EncryptKeys(vMasterKeyIn); } + bool Unlock(const CKeyingMaterial& vMasterKeyIn) { return CCryptoKeyStore::Unlock(vMasterKeyIn); } +}; + +TEST(keystore_tests, store_and_retrieve_spending_key_in_encrypted_store) { + TestCCryptoKeyStore keyStore; + uint256 r {GetRandHash()}; + CKeyingMaterial vMasterKey (r.begin(), r.end()); + libzcash::SpendingKey keyOut; + ZCNoteDecryption decOut; + std::set addrs; + + // 1) Test adding a key to an unencrypted key store, then encrypting it + auto sk = libzcash::SpendingKey::random(); + auto addr = sk.address(); + EXPECT_FALSE(keyStore.GetNoteDecryptor(addr, decOut)); + + keyStore.AddSpendingKey(sk); + ASSERT_TRUE(keyStore.HaveSpendingKey(addr)); + ASSERT_TRUE(keyStore.GetSpendingKey(addr, keyOut)); + ASSERT_EQ(sk, keyOut); + EXPECT_TRUE(keyStore.GetNoteDecryptor(addr, decOut)); + EXPECT_EQ(ZCNoteDecryption(sk.viewing_key()), decOut); + + ASSERT_TRUE(keyStore.EncryptKeys(vMasterKey)); + ASSERT_TRUE(keyStore.HaveSpendingKey(addr)); + ASSERT_FALSE(keyStore.GetSpendingKey(addr, keyOut)); + EXPECT_TRUE(keyStore.GetNoteDecryptor(addr, decOut)); + EXPECT_EQ(ZCNoteDecryption(sk.viewing_key()), decOut); + + // Unlocking with a random key should fail + uint256 r2 {GetRandHash()}; + CKeyingMaterial vRandomKey (r2.begin(), r2.end()); + EXPECT_FALSE(keyStore.Unlock(vRandomKey)); + + // Unlocking with a slightly-modified vMasterKey should fail + CKeyingMaterial vModifiedKey (r.begin(), r.end()); + vModifiedKey[0] += 1; + EXPECT_FALSE(keyStore.Unlock(vModifiedKey)); + + // Unlocking with vMasterKey should succeed + ASSERT_TRUE(keyStore.Unlock(vMasterKey)); + ASSERT_TRUE(keyStore.GetSpendingKey(addr, keyOut)); + ASSERT_EQ(sk, keyOut); + + keyStore.GetPaymentAddresses(addrs); + ASSERT_EQ(1, addrs.size()); + ASSERT_EQ(1, addrs.count(addr)); + + // 2) Test adding a spending key to an already-encrypted key store + auto sk2 = libzcash::SpendingKey::random(); + auto addr2 = sk2.address(); + EXPECT_FALSE(keyStore.GetNoteDecryptor(addr2, decOut)); + + keyStore.AddSpendingKey(sk2); + ASSERT_TRUE(keyStore.HaveSpendingKey(addr2)); + ASSERT_TRUE(keyStore.GetSpendingKey(addr2, keyOut)); + ASSERT_EQ(sk2, keyOut); + EXPECT_TRUE(keyStore.GetNoteDecryptor(addr2, decOut)); + EXPECT_EQ(ZCNoteDecryption(sk2.viewing_key()), decOut); + + ASSERT_TRUE(keyStore.Lock()); + ASSERT_TRUE(keyStore.HaveSpendingKey(addr2)); + ASSERT_FALSE(keyStore.GetSpendingKey(addr2, keyOut)); + EXPECT_TRUE(keyStore.GetNoteDecryptor(addr2, decOut)); + EXPECT_EQ(ZCNoteDecryption(sk2.viewing_key()), decOut); + + ASSERT_TRUE(keyStore.Unlock(vMasterKey)); + ASSERT_TRUE(keyStore.GetSpendingKey(addr2, keyOut)); + ASSERT_EQ(sk2, keyOut); + EXPECT_TRUE(keyStore.GetNoteDecryptor(addr2, decOut)); + EXPECT_EQ(ZCNoteDecryption(sk2.viewing_key()), decOut); + + keyStore.GetPaymentAddresses(addrs); + ASSERT_EQ(2, addrs.size()); + ASSERT_EQ(1, addrs.count(addr)); + ASSERT_EQ(1, addrs.count(addr2)); +} diff --git a/src/keystore.h b/src/keystore.h index aa3aefdf2..84595cfb0 100644 --- a/src/keystore.h +++ b/src/keystore.h @@ -172,5 +172,6 @@ public: typedef std::vector > CKeyingMaterial; typedef std::map > > CryptedKeyMap; +typedef std::map > CryptedSpendingKeyMap; #endif // BITCOIN_KEYSTORE_H diff --git a/src/streams.h b/src/streams.h index fa1e18def..787d8e297 100644 --- a/src/streams.h +++ b/src/streams.h @@ -27,54 +27,55 @@ * >> and << read and write unformatted data using the above serialization templates. * Fills with data in linear time; some stringstream implementations take N^2 time. */ -class CDataStream +template +class CBaseDataStream { protected: - typedef CSerializeData vector_type; + typedef SerializeType vector_type; vector_type vch; unsigned int nReadPos; public: int nType; int nVersion; - typedef vector_type::allocator_type allocator_type; - typedef vector_type::size_type size_type; - typedef vector_type::difference_type difference_type; - typedef vector_type::reference reference; - typedef vector_type::const_reference const_reference; - typedef vector_type::value_type value_type; - typedef vector_type::iterator iterator; - typedef vector_type::const_iterator const_iterator; - typedef vector_type::reverse_iterator reverse_iterator; + typedef typename vector_type::allocator_type allocator_type; + typedef typename vector_type::size_type size_type; + typedef typename vector_type::difference_type difference_type; + typedef typename vector_type::reference reference; + typedef typename vector_type::const_reference const_reference; + typedef typename vector_type::value_type value_type; + typedef typename vector_type::iterator iterator; + typedef typename vector_type::const_iterator const_iterator; + typedef typename vector_type::reverse_iterator reverse_iterator; - explicit CDataStream(int nTypeIn, int nVersionIn) + explicit CBaseDataStream(int nTypeIn, int nVersionIn) { Init(nTypeIn, nVersionIn); } - CDataStream(const_iterator pbegin, const_iterator pend, int nTypeIn, int nVersionIn) : vch(pbegin, pend) + CBaseDataStream(const_iterator pbegin, const_iterator pend, int nTypeIn, int nVersionIn) : vch(pbegin, pend) { Init(nTypeIn, nVersionIn); } #if !defined(_MSC_VER) || _MSC_VER >= 1300 - CDataStream(const char* pbegin, const char* pend, int nTypeIn, int nVersionIn) : vch(pbegin, pend) + CBaseDataStream(const char* pbegin, const char* pend, int nTypeIn, int nVersionIn) : vch(pbegin, pend) { Init(nTypeIn, nVersionIn); } #endif - CDataStream(const vector_type& vchIn, int nTypeIn, int nVersionIn) : vch(vchIn.begin(), vchIn.end()) + CBaseDataStream(const vector_type& vchIn, int nTypeIn, int nVersionIn) : vch(vchIn.begin(), vchIn.end()) { Init(nTypeIn, nVersionIn); } - CDataStream(const std::vector& vchIn, int nTypeIn, int nVersionIn) : vch(vchIn.begin(), vchIn.end()) + CBaseDataStream(const std::vector& vchIn, int nTypeIn, int nVersionIn) : vch(vchIn.begin(), vchIn.end()) { Init(nTypeIn, nVersionIn); } - CDataStream(const std::vector& vchIn, int nTypeIn, int nVersionIn) : vch(vchIn.begin(), vchIn.end()) + CBaseDataStream(const std::vector& vchIn, int nTypeIn, int nVersionIn) : vch(vchIn.begin(), vchIn.end()) { Init(nTypeIn, nVersionIn); } @@ -86,15 +87,15 @@ public: nVersion = nVersionIn; } - CDataStream& operator+=(const CDataStream& b) + CBaseDataStream& operator+=(const CBaseDataStream& b) { vch.insert(vch.end(), b.begin(), b.end()); return *this; } - friend CDataStream operator+(const CDataStream& a, const CDataStream& b) + friend CBaseDataStream operator+(const CBaseDataStream& a, const CBaseDataStream& b) { - CDataStream ret = a; + CBaseDataStream ret = a; ret += b; return (ret); } @@ -207,7 +208,7 @@ public: // Stream subset // bool eof() const { return size() == 0; } - CDataStream* rdbuf() { return this; } + CBaseDataStream* rdbuf() { return this; } int in_avail() { return size(); } void SetType(int n) { nType = n; } @@ -217,7 +218,7 @@ public: void ReadVersion() { *this >> nVersion; } void WriteVersion() { *this << nVersion; } - CDataStream& read(char* pch, size_t nSize) + CBaseDataStream& read(char* pch, size_t nSize) { // Read from the beginning of the buffer unsigned int nReadPosNext = nReadPos + nSize; @@ -225,7 +226,7 @@ public: { if (nReadPosNext > vch.size()) { - throw std::ios_base::failure("CDataStream::read(): end of data"); + throw std::ios_base::failure("CBaseDataStream::read(): end of data"); } memcpy(pch, &vch[nReadPos], nSize); nReadPos = 0; @@ -237,7 +238,7 @@ public: return (*this); } - CDataStream& ignore(int nSize) + CBaseDataStream& ignore(int nSize) { // Ignore from the beginning of the buffer assert(nSize >= 0); @@ -245,7 +246,7 @@ public: if (nReadPosNext >= vch.size()) { if (nReadPosNext > vch.size()) - throw std::ios_base::failure("CDataStream::ignore(): end of data"); + throw std::ios_base::failure("CBaseDataStream::ignore(): end of data"); nReadPos = 0; vch.clear(); return (*this); @@ -254,7 +255,7 @@ public: return (*this); } - CDataStream& write(const char* pch, size_t nSize) + CBaseDataStream& write(const char* pch, size_t nSize) { // Write to the end of the buffer vch.insert(vch.end(), pch, pch + nSize); @@ -277,7 +278,7 @@ public: } template - CDataStream& operator<<(const T& obj) + CBaseDataStream& operator<<(const T& obj) { // Serialize to this stream ::Serialize(*this, obj, nType, nVersion); @@ -285,7 +286,7 @@ public: } template - CDataStream& operator>>(T& obj) + CBaseDataStream& operator>>(T& obj) { // Unserialize from this stream ::Unserialize(*this, obj, nType, nVersion); @@ -298,6 +299,30 @@ public: } }; +class CDataStream : public CBaseDataStream +{ +public: + explicit CDataStream(int nTypeIn, int nVersionIn) : CBaseDataStream(nTypeIn, nVersionIn) { } + + CDataStream(const_iterator pbegin, const_iterator pend, int nTypeIn, int nVersionIn) : + CBaseDataStream(pbegin, pend, nTypeIn, nVersionIn) { } + +#if !defined(_MSC_VER) || _MSC_VER >= 1300 + CDataStream(const char* pbegin, const char* pend, int nTypeIn, int nVersionIn) : + CBaseDataStream(pbegin, pend, nTypeIn, nVersionIn) { } +#endif + + CDataStream(const vector_type& vchIn, int nTypeIn, int nVersionIn) : + CBaseDataStream(vchIn, nTypeIn, nVersionIn) { } + + CDataStream(const std::vector& vchIn, int nTypeIn, int nVersionIn) : + CBaseDataStream(vchIn, nTypeIn, nVersionIn) { } + + CDataStream(const std::vector& vchIn, int nTypeIn, int nVersionIn) : + CBaseDataStream(vchIn, nTypeIn, nVersionIn) { } + +}; + diff --git a/src/wallet/crypter.cpp b/src/wallet/crypter.cpp index 0b0fb562e..886492aaa 100644 --- a/src/wallet/crypter.cpp +++ b/src/wallet/crypter.cpp @@ -6,6 +6,7 @@ #include "script/script.h" #include "script/standard.h" +#include "streams.h" #include "util.h" #include @@ -135,12 +136,29 @@ static bool DecryptKey(const CKeyingMaterial& vMasterKey, const std::vector& vchCryptedSecret, + const libzcash::PaymentAddress& address, + libzcash::SpendingKey& sk) +{ + CKeyingMaterial vchSecret; + if(!DecryptSecret(vMasterKey, vchCryptedSecret, address.GetHash(), vchSecret)) + return false; + + if (vchSecret.size() != libzcash::SerializedSpendingKeySize) + return false; + + CSecureDataStream ss(vchSecret, SER_NETWORK, PROTOCOL_VERSION); + ss >> sk; + return sk.address() == address; +} + bool CCryptoKeyStore::SetCrypted() { LOCK(cs_KeyStore); if (fUseCrypto) return true; - if (!mapKeys.empty()) + if (!(mapKeys.empty() && mapSpendingKeys.empty())) return false; fUseCrypto = true; return true; @@ -184,6 +202,21 @@ bool CCryptoKeyStore::Unlock(const CKeyingMaterial& vMasterKeyIn) if (fDecryptionThoroughlyChecked) break; } + CryptedSpendingKeyMap::const_iterator skmi = mapCryptedSpendingKeys.begin(); + for (; skmi != mapCryptedSpendingKeys.end(); ++skmi) + { + const libzcash::PaymentAddress &address = (*skmi).first; + const std::vector &vchCryptedSecret = (*skmi).second; + libzcash::SpendingKey sk; + if (!DecryptSpendingKey(vMasterKeyIn, vchCryptedSecret, address, sk)) + { + keyFail = true; + break; + } + keyPass = true; + if (fDecryptionThoroughlyChecked) + break; + } if (keyPass && keyFail) { LogPrintf("The wallet is probably corrupted: Some keys decrypt but not all.\n"); @@ -267,10 +300,66 @@ bool CCryptoKeyStore::GetPubKey(const CKeyID &address, CPubKey& vchPubKeyOut) co return false; } +bool CCryptoKeyStore::AddSpendingKey(const libzcash::SpendingKey &sk) +{ + { + LOCK(cs_SpendingKeyStore); + if (!IsCrypted()) + return CBasicKeyStore::AddSpendingKey(sk); + + if (IsLocked()) + return false; + + std::vector vchCryptedSecret; + CSecureDataStream ss(SER_NETWORK, PROTOCOL_VERSION); + ss << sk; + CKeyingMaterial vchSecret(ss.begin(), ss.end()); + auto address = sk.address(); + if (!EncryptSecret(vMasterKey, vchSecret, address.GetHash(), vchCryptedSecret)) + return false; + + if (!AddCryptedSpendingKey(address, sk.viewing_key(), vchCryptedSecret)) + return false; + } + return true; +} + +bool CCryptoKeyStore::AddCryptedSpendingKey(const libzcash::PaymentAddress &address, + const libzcash::ViewingKey &vk, + const std::vector &vchCryptedSecret) +{ + { + LOCK(cs_SpendingKeyStore); + if (!SetCrypted()) + return false; + + mapCryptedSpendingKeys[address] = vchCryptedSecret; + mapNoteDecryptors.insert(std::make_pair(address, ZCNoteDecryption(vk))); + } + return true; +} + +bool CCryptoKeyStore::GetSpendingKey(const libzcash::PaymentAddress &address, libzcash::SpendingKey &skOut) const +{ + { + LOCK(cs_SpendingKeyStore); + if (!IsCrypted()) + return CBasicKeyStore::GetSpendingKey(address, skOut); + + CryptedSpendingKeyMap::const_iterator mi = mapCryptedSpendingKeys.find(address); + if (mi != mapCryptedSpendingKeys.end()) + { + const std::vector &vchCryptedSecret = (*mi).second; + return DecryptSpendingKey(vMasterKey, vchCryptedSecret, address, skOut); + } + } + return false; +} + bool CCryptoKeyStore::EncryptKeys(CKeyingMaterial& vMasterKeyIn) { { - LOCK(cs_KeyStore); + LOCK2(cs_KeyStore, cs_SpendingKeyStore); if (!mapCryptedKeys.empty() || IsCrypted()) return false; @@ -287,6 +376,20 @@ bool CCryptoKeyStore::EncryptKeys(CKeyingMaterial& vMasterKeyIn) return false; } mapKeys.clear(); + BOOST_FOREACH(SpendingKeyMap::value_type& mSpendingKey, mapSpendingKeys) + { + const libzcash::SpendingKey &sk = mSpendingKey.second; + CSecureDataStream ss(SER_NETWORK, PROTOCOL_VERSION); + ss << sk; + CKeyingMaterial vchSecret(ss.begin(), ss.end()); + libzcash::PaymentAddress address = sk.address(); + std::vector vchCryptedSecret; + if (!EncryptSecret(vMasterKeyIn, vchSecret, address.GetHash(), vchCryptedSecret)) + return false; + if (!AddCryptedSpendingKey(address, sk.viewing_key(), vchCryptedSecret)) + return false; + } + mapSpendingKeys.clear(); } return true; } diff --git a/src/wallet/crypter.h b/src/wallet/crypter.h index 70aeb7672..b310b77b0 100644 --- a/src/wallet/crypter.h +++ b/src/wallet/crypter.h @@ -7,7 +7,9 @@ #include "keystore.h" #include "serialize.h" +#include "streams.h" #include "support/allocators/secure.h" +#include "zcash/Address.hpp" class uint256; @@ -66,6 +68,18 @@ public: typedef std::vector > CKeyingMaterial; +class CSecureDataStream : public CBaseDataStream +{ +public: + explicit CSecureDataStream(int nTypeIn, int nVersionIn) : CBaseDataStream(nTypeIn, nVersionIn) { } + + CSecureDataStream(const_iterator pbegin, const_iterator pend, int nTypeIn, int nVersionIn) : + CBaseDataStream(pbegin, pend, nTypeIn, nVersionIn) { } + + CSecureDataStream(const vector_type& vchIn, int nTypeIn, int nVersionIn) : + CBaseDataStream(vchIn, nTypeIn, nVersionIn) { } +}; + /** Encryption/decryption context with key information */ class CCrypter { @@ -114,10 +128,11 @@ class CCryptoKeyStore : public CBasicKeyStore { private: CryptedKeyMap mapCryptedKeys; + CryptedSpendingKeyMap mapCryptedSpendingKeys; CKeyingMaterial vMasterKey; - //! if fUseCrypto is true, mapKeys must be empty + //! if fUseCrypto is true, mapKeys and mapSpendingKeys must be empty //! if fUseCrypto is false, vMasterKey must be empty bool fUseCrypto; @@ -185,6 +200,36 @@ public: mi++; } } + virtual bool AddCryptedSpendingKey(const libzcash::PaymentAddress &address, + const libzcash::ViewingKey &vk, + const std::vector &vchCryptedSecret); + bool AddSpendingKey(const libzcash::SpendingKey &sk); + bool HaveSpendingKey(const libzcash::PaymentAddress &address) const + { + { + LOCK(cs_KeyStore); + if (!IsCrypted()) + return CBasicKeyStore::HaveSpendingKey(address); + return mapCryptedSpendingKeys.count(address) > 0; + } + return false; + } + bool GetSpendingKey(const libzcash::PaymentAddress &address, libzcash::SpendingKey &skOut) const; + void GetPaymentAddresses(std::set &setAddress) const + { + if (!IsCrypted()) + { + CBasicKeyStore::GetPaymentAddresses(setAddress); + return; + } + setAddress.clear(); + CryptedSpendingKeyMap::const_iterator mi = mapCryptedSpendingKeys.begin(); + while (mi != mapCryptedSpendingKeys.end()) + { + setAddress.insert((*mi).first); + mi++; + } + } /** * Wallet status (encrypted, locked) changed. diff --git a/src/zcash/Address.cpp b/src/zcash/Address.cpp index 9bb32fb6c..3849b2ffc 100644 --- a/src/zcash/Address.cpp +++ b/src/zcash/Address.cpp @@ -1,9 +1,17 @@ #include "Address.hpp" #include "NoteEncryption.hpp" +#include "hash.h" #include "prf.h" +#include "streams.h" namespace libzcash { +uint256 PaymentAddress::GetHash() const { + CDataStream ss(SER_NETWORK, PROTOCOL_VERSION); + ss << *this; + return Hash(ss.begin(), ss.end()); +} + uint256 ViewingKey::pk_enc() { return ZCNoteEncryption::generate_pubkey(*this); } diff --git a/src/zcash/Address.hpp b/src/zcash/Address.hpp index 58caae772..efae2af22 100644 --- a/src/zcash/Address.hpp +++ b/src/zcash/Address.hpp @@ -7,6 +7,9 @@ namespace libzcash { +const size_t SerializedPaymentAddressSize = 64; +const size_t SerializedSpendingKeySize = 32; + class PaymentAddress { public: uint256 a_pk; @@ -23,6 +26,9 @@ public: READWRITE(pk_enc); } + //! Get the 256-bit SHA256d hash of this payment address. + uint256 GetHash() const; + friend inline bool operator==(const PaymentAddress& a, const PaymentAddress& b) { return a.a_pk == b.a_pk && a.pk_enc == b.pk_enc; }