diff --git a/src/main.cpp b/src/main.cpp index 2851fbbbb..1a619e297 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -2687,7 +2687,8 @@ bool static DisconnectTip(CValidationState &state, bool fBare = false) { if (!ReadBlockFromDisk(block, pindexDelete)) return AbortNode(state, "Failed to read block"); // Apply the block atomically to the chain state. - uint256 anchorBeforeDisconnect = pcoinsTip->GetBestAnchor(SPROUT); + uint256 sproutAnchorBeforeDisconnect = pcoinsTip->GetBestAnchor(SPROUT); + uint256 saplingAnchorBeforeDisconnect = pcoinsTip->GetBestAnchor(SAPLING); int64_t nStart = GetTimeMicros(); { CCoinsViewCache view(pcoinsTip); @@ -2696,7 +2697,8 @@ bool static DisconnectTip(CValidationState &state, bool fBare = false) { assert(view.Flush()); } LogPrint("bench", "- Disconnect block: %.2fms\n", (GetTimeMicros() - nStart) * 0.001); - uint256 anchorAfterDisconnect = pcoinsTip->GetBestAnchor(SPROUT); + uint256 sproutAnchorAfterDisconnect = pcoinsTip->GetBestAnchor(SPROUT); + uint256 saplingAnchorAfterDisconnect = pcoinsTip->GetBestAnchor(SAPLING); // Write the chain state to disk, if necessary. if (!FlushStateToDisk(state, FLUSH_STATE_IF_NEEDED)) return false; @@ -2710,10 +2712,15 @@ bool static DisconnectTip(CValidationState &state, bool fBare = false) { if (tx.IsCoinBase() || !AcceptToMemoryPool(mempool, stateDummy, tx, false, NULL)) mempool.remove(tx, removed, true); } - if (anchorBeforeDisconnect != anchorAfterDisconnect) { + if (sproutAnchorBeforeDisconnect != sproutAnchorAfterDisconnect) { // The anchor may not change between block disconnects, // in which case we don't want to evict from the mempool yet! - mempool.removeWithAnchor(anchorBeforeDisconnect); + mempool.removeWithAnchor(sproutAnchorBeforeDisconnect, SPROUT); + } + if (saplingAnchorBeforeDisconnect != saplingAnchorAfterDisconnect) { + // The anchor may not change between block disconnects, + // in which case we don't want to evict from the mempool yet! + mempool.removeWithAnchor(saplingAnchorBeforeDisconnect, SAPLING); } } diff --git a/src/txmempool.cpp b/src/txmempool.cpp index a9a33602f..6debd55b8 100644 --- a/src/txmempool.cpp +++ b/src/txmempool.cpp @@ -206,7 +206,7 @@ void CTxMemPool::removeForReorg(const CCoinsViewCache *pcoins, unsigned int nMem } -void CTxMemPool::removeWithAnchor(const uint256 &invalidRoot) +void CTxMemPool::removeWithAnchor(const uint256 &invalidRoot, ShieldedType type) { // If a block is disconnected from the tip, and the root changed, // we must invalidate transactions from the mempool which spend @@ -217,11 +217,26 @@ void CTxMemPool::removeWithAnchor(const uint256 &invalidRoot) for (indexed_transaction_set::const_iterator it = mapTx.begin(); it != mapTx.end(); it++) { const CTransaction& tx = it->GetTx(); - BOOST_FOREACH(const JSDescription& joinsplit, tx.vjoinsplit) { - if (joinsplit.anchor == invalidRoot) { - transactionsToRemove.push_back(tx); - break; - } + switch (type) { + case SPROUT: + BOOST_FOREACH(const JSDescription& joinsplit, tx.vjoinsplit) { + if (joinsplit.anchor == invalidRoot) { + transactionsToRemove.push_back(tx); + break; + } + } + break; + case SAPLING: + BOOST_FOREACH(const SpendDescription& spendDescription, tx.vShieldedSpend) { + if (spendDescription.anchor == invalidRoot) { + transactionsToRemove.push_back(tx); + break; + } + } + break; + default: + throw runtime_error("Unknown shielded type " + type); + break; } } diff --git a/src/txmempool.h b/src/txmempool.h index f37636efb..ec8a8518a 100644 --- a/src/txmempool.h +++ b/src/txmempool.h @@ -169,7 +169,7 @@ public: bool addUnchecked(const uint256& hash, const CTxMemPoolEntry &entry, bool fCurrentEstimate = true); void remove(const CTransaction &tx, std::list& removed, bool fRecursive = false); - void removeWithAnchor(const uint256 &invalidRoot); + void removeWithAnchor(const uint256 &invalidRoot, ShieldedType type); void removeForReorg(const CCoinsViewCache *pcoins, unsigned int nMemPoolHeight, int flags); void removeConflicts(const CTransaction &tx, std::list& removed); void removeExpired(unsigned int nBlockHeight);