diff --git a/src/coins.cpp b/src/coins.cpp index 7f400ed44..5c01d8eba 100644 --- a/src/coins.cpp +++ b/src/coins.cpp @@ -43,6 +43,7 @@ bool CCoins::Spend(uint32_t nPos) return true; } bool CCoinsView::GetSproutAnchorAt(const uint256 &rt, ZCIncrementalMerkleTree &tree) const { return false; } +bool CCoinsView::GetSaplingAnchorAt(const uint256 &rt, ZCSaplingIncrementalMerkleTree &tree) const { return false; } bool CCoinsView::GetNullifier(const uint256 &nullifier, ShieldedType type) const { return false; } bool CCoinsView::GetCoins(const uint256 &txid, CCoins &coins) const { return false; } bool CCoinsView::HaveCoins(const uint256 &txid) const { return false; } @@ -53,6 +54,7 @@ bool CCoinsView::BatchWrite(CCoinsMap &mapCoins, const uint256 &hashSproutAnchor, const uint256 &hashSaplingAnchor, CAnchorsSproutMap &mapSproutAnchors, + CAnchorsSaplingMap &mapSaplingAnchors, CNullifiersMap &mapSproutNullifiers, CNullifiersMap &mapSaplingNullifiers) { return false; } bool CCoinsView::GetStats(CCoinsStats &stats) const { return false; } @@ -61,6 +63,7 @@ bool CCoinsView::GetStats(CCoinsStats &stats) const { return false; } CCoinsViewBacked::CCoinsViewBacked(CCoinsView *viewIn) : base(viewIn) { } bool CCoinsViewBacked::GetSproutAnchorAt(const uint256 &rt, ZCIncrementalMerkleTree &tree) const { return base->GetSproutAnchorAt(rt, tree); } +bool CCoinsViewBacked::GetSaplingAnchorAt(const uint256 &rt, ZCSaplingIncrementalMerkleTree &tree) const { return base->GetSaplingAnchorAt(rt, tree); } bool CCoinsViewBacked::GetNullifier(const uint256 &nullifier, ShieldedType type) const { return base->GetNullifier(nullifier, type); } bool CCoinsViewBacked::GetCoins(const uint256 &txid, CCoins &coins) const { return base->GetCoins(txid, coins); } bool CCoinsViewBacked::HaveCoins(const uint256 &txid) const { return base->HaveCoins(txid); } @@ -72,8 +75,9 @@ bool CCoinsViewBacked::BatchWrite(CCoinsMap &mapCoins, const uint256 &hashSproutAnchor, const uint256 &hashSaplingAnchor, CAnchorsSproutMap &mapSproutAnchors, + CAnchorsSaplingMap &mapSaplingAnchors, CNullifiersMap &mapSproutNullifiers, - CNullifiersMap &mapSaplingNullifiers) { return base->BatchWrite(mapCoins, hashBlock, hashSproutAnchor, hashSaplingAnchor, mapSproutAnchors, mapSproutNullifiers, mapSaplingNullifiers); } + CNullifiersMap &mapSaplingNullifiers) { return base->BatchWrite(mapCoins, hashBlock, hashSproutAnchor, hashSaplingAnchor, mapSproutAnchors, mapSaplingAnchors, mapSproutNullifiers, mapSaplingNullifiers); } bool CCoinsViewBacked::GetStats(CCoinsStats &stats) const { return base->GetStats(stats); } CCoinsKeyHasher::CCoinsKeyHasher() : salt(GetRandHash()) {} @@ -88,6 +92,7 @@ CCoinsViewCache::~CCoinsViewCache() size_t CCoinsViewCache::DynamicMemoryUsage() const { return memusage::DynamicUsage(cacheCoins) + memusage::DynamicUsage(cacheSproutAnchors) + + memusage::DynamicUsage(cacheSaplingAnchors) + memusage::DynamicUsage(cacheSproutNullifiers) + memusage::DynamicUsage(cacheSaplingNullifiers) + cachedCoinsUsage; @@ -135,6 +140,29 @@ bool CCoinsViewCache::GetSproutAnchorAt(const uint256 &rt, ZCIncrementalMerkleTr return true; } +bool CCoinsViewCache::GetSaplingAnchorAt(const uint256 &rt, ZCSaplingIncrementalMerkleTree &tree) const { + CAnchorsSaplingMap::const_iterator it = cacheSaplingAnchors.find(rt); + if (it != cacheSaplingAnchors.end()) { + if (it->second.entered) { + tree = it->second.tree; + return true; + } else { + return false; + } + } + + if (!base->GetSaplingAnchorAt(rt, tree)) { + return false; + } + + CAnchorsSaplingMap::iterator ret = cacheSaplingAnchors.insert(std::make_pair(rt, CAnchorsSaplingCacheEntry())).first; + ret->second.entered = true; + ret->second.tree = tree; + cachedCoinsUsage += ret->second.tree.DynamicMemoryUsage(); + + return true; +} + bool CCoinsViewCache::GetNullifier(const uint256 &nullifier, ShieldedType type) const { CNullifiersMap* cacheToUse; switch (type) { @@ -203,6 +231,24 @@ void CCoinsViewCache::PushSproutAnchor(const ZCIncrementalMerkleTree &tree) { ); } +template<> +void CCoinsViewCache::BringBestAnchorIntoCache( + const uint256 ¤tRoot, + ZCIncrementalMerkleTree &tree +) +{ + assert(GetSproutAnchorAt(currentRoot, tree)); +} + +template<> +void CCoinsViewCache::BringBestAnchorIntoCache( + const uint256 ¤tRoot, + ZCSaplingIncrementalMerkleTree &tree +) +{ + assert(GetSaplingAnchorAt(currentRoot, tree)); +} + template void CCoinsViewCache::AbstractPopAnchor( const uint256 &newrt, @@ -221,17 +267,7 @@ void CCoinsViewCache::AbstractPopAnchor( // so that its tree exists in memory. { Tree tree; - switch (type) { - case SPROUT: - assert(GetSproutAnchorAt(currentRoot, tree)); - break; - case SAPLING: - // TODO - assert(false); - break; - default: - throw std::runtime_error("Unknown shielded type " + type); - } + BringBestAnchorIntoCache(currentRoot, tree); } // Mark the anchor as unentered, removing it from view @@ -367,11 +403,45 @@ void BatchWriteNullifiers(CNullifiersMap &mapNullifiers, CNullifiersMap &cacheNu } } +template +void BatchWriteAnchors( + Map &mapAnchors, + Map &cacheAnchors, + size_t &cachedCoinsUsage +) +{ + for (MapIterator child_it = mapAnchors.begin(); child_it != mapAnchors.end();) + { + if (child_it->second.flags & MapEntry::DIRTY) { + MapIterator parent_it = cacheAnchors.find(child_it->first); + + if (parent_it == cacheAnchors.end()) { + MapEntry& entry = cacheAnchors[child_it->first]; + entry.entered = child_it->second.entered; + entry.tree = child_it->second.tree; + entry.flags = MapEntry::DIRTY; + + cachedCoinsUsage += entry.tree.DynamicMemoryUsage(); + } else { + if (parent_it->second.entered != child_it->second.entered) { + // The parent may have removed the entry. + parent_it->second.entered = child_it->second.entered; + parent_it->second.flags |= MapEntry::DIRTY; + } + } + } + + MapIterator itOld = child_it++; + mapAnchors.erase(itOld); + } +} + bool CCoinsViewCache::BatchWrite(CCoinsMap &mapCoins, const uint256 &hashBlockIn, const uint256 &hashSproutAnchorIn, const uint256 &hashSaplingAnchorIn, CAnchorsSproutMap &mapSproutAnchors, + CAnchorsSaplingMap &mapSaplingAnchors, CNullifiersMap &mapSproutNullifiers, CNullifiersMap &mapSaplingNullifiers) { assert(!hasModifier); @@ -410,30 +480,8 @@ bool CCoinsViewCache::BatchWrite(CCoinsMap &mapCoins, mapCoins.erase(itOld); } - for (CAnchorsSproutMap::iterator child_it = mapSproutAnchors.begin(); child_it != mapSproutAnchors.end();) - { - if (child_it->second.flags & CAnchorsSproutCacheEntry::DIRTY) { - CAnchorsSproutMap::iterator parent_it = cacheSproutAnchors.find(child_it->first); - - if (parent_it == cacheSproutAnchors.end()) { - CAnchorsSproutCacheEntry& entry = cacheSproutAnchors[child_it->first]; - entry.entered = child_it->second.entered; - entry.tree = child_it->second.tree; - entry.flags = CAnchorsSproutCacheEntry::DIRTY; - - cachedCoinsUsage += entry.tree.DynamicMemoryUsage(); - } else { - if (parent_it->second.entered != child_it->second.entered) { - // The parent may have removed the entry. - parent_it->second.entered = child_it->second.entered; - parent_it->second.flags |= CAnchorsSproutCacheEntry::DIRTY; - } - } - } - - CAnchorsSproutMap::iterator itOld = child_it++; - mapSproutAnchors.erase(itOld); - } + ::BatchWriteAnchors(mapSproutAnchors, cacheSproutAnchors, cachedCoinsUsage); + ::BatchWriteAnchors(mapSaplingAnchors, cacheSaplingAnchors, cachedCoinsUsage); ::BatchWriteNullifiers(mapSproutNullifiers, cacheSproutNullifiers); ::BatchWriteNullifiers(mapSaplingNullifiers, cacheSaplingNullifiers); @@ -445,9 +493,10 @@ bool CCoinsViewCache::BatchWrite(CCoinsMap &mapCoins, } bool CCoinsViewCache::Flush() { - bool fOk = base->BatchWrite(cacheCoins, hashBlock, hashSproutAnchor, hashSaplingAnchor, cacheSproutAnchors, cacheSproutNullifiers, cacheSaplingNullifiers); + bool fOk = base->BatchWrite(cacheCoins, hashBlock, hashSproutAnchor, hashSaplingAnchor, cacheSproutAnchors, cacheSaplingAnchors, cacheSproutNullifiers, cacheSaplingNullifiers); cacheCoins.clear(); cacheSproutAnchors.clear(); + cacheSaplingAnchors.clear(); cacheSproutNullifiers.clear(); cacheSaplingNullifiers.clear(); cachedCoinsUsage = 0; @@ -515,6 +564,8 @@ bool CCoinsViewCache::HaveJoinSplitRequirements(const CTransaction& tx) const return false; } + // TODO: Sapling anchor checks + return true; } diff --git a/src/coins.h b/src/coins.h index b8c41e5d7..11f758c34 100644 --- a/src/coins.h +++ b/src/coins.h @@ -286,6 +286,19 @@ struct CAnchorsSproutCacheEntry CAnchorsSproutCacheEntry() : entered(false), flags(0) {} }; +struct CAnchorsSaplingCacheEntry +{ + bool entered; // This will be false if the anchor is removed from the cache + ZCSaplingIncrementalMerkleTree tree; // The tree itself + unsigned char flags; + + enum Flags { + DIRTY = (1 << 0), // This cache entry is potentially different from the version in the parent view. + }; + + CAnchorsSaplingCacheEntry() : entered(false), flags(0) {} +}; + struct CNullifiersCacheEntry { bool entered; // If the nullifier is spent or not @@ -306,6 +319,7 @@ enum ShieldedType typedef boost::unordered_map CCoinsMap; typedef boost::unordered_map CAnchorsSproutMap; +typedef boost::unordered_map CAnchorsSaplingMap; typedef boost::unordered_map CNullifiersMap; struct CCoinsStats @@ -326,9 +340,12 @@ struct CCoinsStats class CCoinsView { public: - //! Retrieve the tree at a particular anchored root in the chain + //! Retrieve the tree (Sprout) at a particular anchored root in the chain virtual bool GetSproutAnchorAt(const uint256 &rt, ZCIncrementalMerkleTree &tree) const; + //! Retrieve the tree (Sapling) at a particular anchored root in the chain + virtual bool GetSaplingAnchorAt(const uint256 &rt, ZCSaplingIncrementalMerkleTree &tree) const; + //! Determine whether a nullifier is spent or not virtual bool GetNullifier(const uint256 &nullifier, ShieldedType type) const; @@ -352,6 +369,7 @@ public: const uint256 &hashSproutAnchor, const uint256 &hashSaplingAnchor, CAnchorsSproutMap &mapSproutAnchors, + CAnchorsSaplingMap &mapSaplingAnchors, CNullifiersMap &mapSproutNullifiers, CNullifiersMap &mapSaplingNullifiers); @@ -372,6 +390,7 @@ protected: public: CCoinsViewBacked(CCoinsView *viewIn); bool GetSproutAnchorAt(const uint256 &rt, ZCIncrementalMerkleTree &tree) const; + bool GetSaplingAnchorAt(const uint256 &rt, ZCSaplingIncrementalMerkleTree &tree) const; bool GetNullifier(const uint256 &nullifier, ShieldedType type) const; bool GetCoins(const uint256 &txid, CCoins &coins) const; bool HaveCoins(const uint256 &txid) const; @@ -383,6 +402,7 @@ public: const uint256 &hashSproutAnchor, const uint256 &hashSaplingAnchor, CAnchorsSproutMap &mapSproutAnchors, + CAnchorsSaplingMap &mapSaplingAnchors, CNullifiersMap &mapSproutNullifiers, CNullifiersMap &mapSaplingNullifiers); bool GetStats(CCoinsStats &stats) const; @@ -428,6 +448,7 @@ protected: mutable uint256 hashSproutAnchor; mutable uint256 hashSaplingAnchor; mutable CAnchorsSproutMap cacheSproutAnchors; + mutable CAnchorsSaplingMap cacheSaplingAnchors; mutable CNullifiersMap cacheSproutNullifiers; mutable CNullifiersMap cacheSaplingNullifiers; @@ -440,6 +461,7 @@ public: // Standard CCoinsView methods bool GetSproutAnchorAt(const uint256 &rt, ZCIncrementalMerkleTree &tree) const; + bool GetSaplingAnchorAt(const uint256 &rt, ZCSaplingIncrementalMerkleTree &tree) const; bool GetNullifier(const uint256 &nullifier, ShieldedType type) const; bool GetCoins(const uint256 &txid, CCoins &coins) const; bool HaveCoins(const uint256 &txid) const; @@ -451,6 +473,7 @@ public: const uint256 &hashSproutAnchor, const uint256 &hashSaplingAnchor, CAnchorsSproutMap &mapSproutAnchors, + CAnchorsSaplingMap &mapSaplingAnchors, CNullifiersMap &mapSproutNullifiers, CNullifiersMap &mapSaplingNullifiers); @@ -542,6 +565,13 @@ private: Cache &cacheAnchors, uint256 &hash ); + + //! Interface for bringing an anchor into the cache. + template + void BringBestAnchorIntoCache( + const uint256 ¤tRoot, + Tree &tree + ); }; #endif // BITCOIN_COINS_H diff --git a/src/gtest/test_mempool.cpp b/src/gtest/test_mempool.cpp index 1d5999de2..2056950b6 100644 --- a/src/gtest/test_mempool.cpp +++ b/src/gtest/test_mempool.cpp @@ -23,6 +23,10 @@ public: return false; } + bool GetSaplingAnchorAt(const uint256 &rt, ZCSaplingIncrementalMerkleTree &tree) const { + return false; + } + bool GetNullifier(const uint256 &nf, ShieldedType type) const { return false; } diff --git a/src/gtest/test_validation.cpp b/src/gtest/test_validation.cpp index b76c3e77d..2198f592c 100644 --- a/src/gtest/test_validation.cpp +++ b/src/gtest/test_validation.cpp @@ -25,6 +25,10 @@ public: return false; } + bool GetSaplingAnchorAt(const uint256 &rt, ZCSaplingIncrementalMerkleTree &tree) const { + return false; + } + bool GetNullifier(const uint256 &nf, ShieldedType type) const { return false; } diff --git a/src/txdb.cpp b/src/txdb.cpp index f996c3551..68424b226 100644 --- a/src/txdb.cpp +++ b/src/txdb.cpp @@ -18,6 +18,7 @@ using namespace std; static const char DB_SPROUT_ANCHOR = 'A'; +static const char DB_SAPLING_ANCHOR = 'X'; static const char DB_NULLIFIER = 's'; static const char DB_SAPLING_NULLIFIER = 'S'; static const char DB_COINS = 'c'; @@ -53,6 +54,18 @@ bool CCoinsViewDB::GetSproutAnchorAt(const uint256 &rt, ZCIncrementalMerkleTree return read; } +bool CCoinsViewDB::GetSaplingAnchorAt(const uint256 &rt, ZCSaplingIncrementalMerkleTree &tree) const { + if (rt == ZCSaplingIncrementalMerkleTree::empty_root()) { + ZCSaplingIncrementalMerkleTree new_tree; + tree = new_tree; + return true; + } + + bool read = db.Read(make_pair(DB_SAPLING_ANCHOR, rt), tree); + + return read; +} + bool CCoinsViewDB::GetNullifier(const uint256 &nf, ShieldedType type) const { bool spent = false; char dbChar; @@ -118,11 +131,29 @@ void BatchWriteNullifiers(CDBBatch& batch, CNullifiersMap& mapToUse, const char& } } +template +void BatchWriteAnchors(CDBBatch& batch, Map& mapToUse, const char& dbChar) +{ + for (MapIterator it = mapToUse.begin(); it != mapToUse.end();) { + if (it->second.flags & MapEntry::DIRTY) { + if (!it->second.entered) + batch.Erase(make_pair(dbChar, it->first)); + else { + batch.Write(make_pair(dbChar, it->first), it->second.tree); + } + // TODO: changed++? + } + MapIterator itOld = it++; + mapToUse.erase(itOld); + } +} + bool CCoinsViewDB::BatchWrite(CCoinsMap &mapCoins, const uint256 &hashBlock, const uint256 &hashSproutAnchor, const uint256 &hashSaplingAnchor, CAnchorsSproutMap &mapSproutAnchors, + CAnchorsSaplingMap &mapSaplingAnchors, CNullifiersMap &mapSproutNullifiers, CNullifiersMap &mapSaplingNullifiers) { CDBBatch batch(db); @@ -141,18 +172,8 @@ bool CCoinsViewDB::BatchWrite(CCoinsMap &mapCoins, mapCoins.erase(itOld); } - for (CAnchorsSproutMap::iterator it = mapSproutAnchors.begin(); it != mapSproutAnchors.end();) { - if (it->second.flags & CAnchorsSproutCacheEntry::DIRTY) { - if (!it->second.entered) - batch.Erase(make_pair(DB_SPROUT_ANCHOR, it->first)); - else { - batch.Write(make_pair(DB_SPROUT_ANCHOR, it->first), it->second.tree); - } - // TODO: changed++? - } - CAnchorsSproutMap::iterator itOld = it++; - mapSproutAnchors.erase(itOld); - } + ::BatchWriteAnchors(batch, mapSproutAnchors, DB_SPROUT_ANCHOR); + ::BatchWriteAnchors(batch, mapSaplingAnchors, DB_SAPLING_ANCHOR); ::BatchWriteNullifiers(batch, mapSproutNullifiers, DB_NULLIFIER); ::BatchWriteNullifiers(batch, mapSaplingNullifiers, DB_SAPLING_NULLIFIER); diff --git a/src/txdb.h b/src/txdb.h index d97f18303..e8d8ed899 100644 --- a/src/txdb.h +++ b/src/txdb.h @@ -36,6 +36,7 @@ public: CCoinsViewDB(size_t nCacheSize, bool fMemory = false, bool fWipe = false); bool GetSproutAnchorAt(const uint256 &rt, ZCIncrementalMerkleTree &tree) const; + bool GetSaplingAnchorAt(const uint256 &rt, ZCSaplingIncrementalMerkleTree &tree) const; bool GetNullifier(const uint256 &nf, ShieldedType type) const; bool GetCoins(const uint256 &txid, CCoins &coins) const; bool HaveCoins(const uint256 &txid) const; @@ -46,6 +47,7 @@ public: const uint256 &hashSproutAnchor, const uint256 &hashSaplingAnchor, CAnchorsSproutMap &mapSproutAnchors, + CAnchorsSaplingMap &mapSaplingAnchors, CNullifiersMap &mapSproutNullifiers, CNullifiersMap &mapSaplingNullifiers); bool GetStats(CCoinsStats &stats) const; diff --git a/src/txmempool.cpp b/src/txmempool.cpp index a7fc3d1f7..2a96ecac9 100644 --- a/src/txmempool.cpp +++ b/src/txmempool.cpp @@ -413,6 +413,7 @@ void CTxMemPool::check(const CCoinsViewCache *pcoins) const intermediates.insert(std::make_pair(tree.root(), tree)); } for (const SpendDescription &spendDescription : tx.vShieldedSpend) { + // TODO: anchor check assert(!pcoins->GetNullifier(spendDescription.nullifier, SAPLING)); } if (fDependsWait)