Pass ZCIncrementalMerkleTree to wallet to prevent race conditions

This commit is contained in:
Jack Grigg
2016-08-31 02:00:11 +12:00
parent 769e031c1a
commit de42390f90
6 changed files with 29 additions and 42 deletions

View File

@@ -29,8 +29,8 @@ public:
void IncrementNoteWitnesses(const CBlockIndex* pindex,
const CBlock* pblock,
const CCoinsViewCache* pcoins) {
CWallet::IncrementNoteWitnesses(pindex, pblock, pcoins);
ZCIncrementalMerkleTree tree) {
CWallet::IncrementNoteWitnesses(pindex, pblock, tree);
}
void DecrementNoteWitnesses() {
CWallet::DecrementNoteWitnesses();
@@ -328,10 +328,8 @@ TEST(wallet_tests, cached_witnesses_empty_chain) {
CBlock block;
block.vtx.push_back(wtx);
MockCCoinsViewCache coins;
// Empty chain, so we shouldn't try to fetch an anchor
EXPECT_CALL(coins, GetAnchorAt(_, _)).Times(0);
wallet.IncrementNoteWitnesses(NULL, &block, &coins);
ZCIncrementalMerkleTree tree;
wallet.IncrementNoteWitnesses(NULL, &block, tree);
witnesses.clear();
wallet.GetNoteWitnesses(notes, witnesses, anchor);
EXPECT_TRUE((bool) witnesses[0]);
@@ -344,9 +342,9 @@ TEST(wallet_tests, cached_witnesses_empty_chain) {
TEST(wallet_tests, cached_witnesses_chain_tip) {
TestWallet wallet;
MockCCoinsViewCache coins;
uint256 anchor1;
CBlock block1;
ZCIncrementalMerkleTree tree;
auto sk = libzcash::SpendingKey::random();
wallet.AddSpendingKey(sk);
@@ -369,9 +367,7 @@ TEST(wallet_tests, cached_witnesses_chain_tip) {
// First block (case tested in _empty_chain)
block1.vtx.push_back(wtx);
EXPECT_CALL(coins, GetAnchorAt(_, _))
.Times(0);
wallet.IncrementNoteWitnesses(NULL, &block1, &coins);
wallet.IncrementNoteWitnesses(NULL, &block1, tree);
// Called to fetch anchor
wallet.GetNoteWitnesses(notes, witnesses, anchor1);
}
@@ -400,10 +396,8 @@ TEST(wallet_tests, cached_witnesses_chain_tip) {
CBlock block2;
block2.hashPrevBlock = block1.GetHash();
block2.vtx.push_back(wtx);
EXPECT_CALL(coins, GetAnchorAt(anchor1, _))
.Times(2)
.WillRepeatedly(Return(true));
wallet.IncrementNoteWitnesses(NULL, &block2, &coins);
ZCIncrementalMerkleTree tree2 {tree};
wallet.IncrementNoteWitnesses(NULL, &block2, tree2);
witnesses.clear();
wallet.GetNoteWitnesses(notes, witnesses, anchor2);
EXPECT_TRUE((bool) witnesses[0]);
@@ -419,7 +413,7 @@ TEST(wallet_tests, cached_witnesses_chain_tip) {
// Re-incrementing with the same block should give the same result
uint256 anchor4;
wallet.IncrementNoteWitnesses(NULL, &block2, &coins);
wallet.IncrementNoteWitnesses(NULL, &block2, tree);
witnesses.clear();
wallet.GetNoteWitnesses(notes, witnesses, anchor4);
EXPECT_TRUE((bool) witnesses[0]);

View File

@@ -336,10 +336,11 @@ bool CWallet::ChangeWalletPassphrase(const SecureString& strOldWalletPassphrase,
return false;
}
void CWallet::ChainTip(const CBlockIndex *pindex, const CBlock *pblock, bool added)
void CWallet::ChainTip(const CBlockIndex *pindex, const CBlock *pblock,
ZCIncrementalMerkleTree tree, bool added)
{
if (added) {
IncrementNoteWitnesses(pindex, pblock, pcoinsTip);
IncrementNoteWitnesses(pindex, pblock, tree);
} else {
DecrementNoteWitnesses();
}
@@ -594,7 +595,7 @@ void CWallet::AddToSpends(const uint256& wtxid)
void CWallet::IncrementNoteWitnesses(const CBlockIndex* pindex,
const CBlock* pblockIn,
const CCoinsViewCache* pcoins)
ZCIncrementalMerkleTree tree)
{
{
LOCK(cs_wallet);
@@ -618,23 +619,7 @@ void CWallet::IncrementNoteWitnesses(const CBlockIndex* pindex,
pblock = █
}
ZCIncrementalMerkleTree tree;
bool treeInitialised = false;
for (const CTransaction& tx : pblock->vtx) {
if (!treeInitialised && tx.vjoinsplit.size() > 0) {
LOCK(cs_main);
// vAnchorCache will only be empty at the beginning
if (vAnchorCache.size() && !pcoins->GetAnchorAt(vAnchorCache.front(), tree)) {
// This should not happen, because IncrementNoteWitnesses()
// is only called when the chain tip updates, and the
// anchors for the JoinSplits in that block should still be
// cached.
// TODO: Calculate the anchor from scratch?
throw std::runtime_error("CWallet::IncrementNoteWitnesses(): anchor not cached");
}
treeInitialised = true;
}
auto hash = tx.GetTxid();
bool txIsOurs = mapWallet.count(hash);
for (size_t i = 0; i < tx.vjoinsplit.size(); i++) {

View File

@@ -588,7 +588,7 @@ public:
protected:
void IncrementNoteWitnesses(const CBlockIndex* pindex,
const CBlock* pblock,
const CCoinsViewCache* pcoins);
ZCIncrementalMerkleTree tree);
void DecrementNoteWitnesses();
private:
@@ -810,7 +810,7 @@ public:
CAmount GetDebit(const CTransaction& tx, const isminefilter& filter) const;
CAmount GetCredit(const CTransaction& tx, const isminefilter& filter) const;
CAmount GetChange(const CTransaction& tx) const;
void ChainTip(const CBlockIndex *pindex, const CBlock *pblock, bool added);
void ChainTip(const CBlockIndex *pindex, const CBlock *pblock, ZCIncrementalMerkleTree tree, bool added);
void SetBestChain(const CBlockLocator& loc);
DBErrors LoadWallet(bool& fFirstRunRet);