diff --git a/src/gtest/test_validation.cpp b/src/gtest/test_validation.cpp index 6f603eaec..59a466588 100644 --- a/src/gtest/test_validation.cpp +++ b/src/gtest/test_validation.cpp @@ -71,7 +71,8 @@ TEST(Validation, ContextualCheckInputsPassesWithCoinbase) { CCoinsViewCache view(&fakeDB); CValidationState state; - EXPECT_TRUE(ContextualCheckInputs(tx, state, view, false, 0, false, Params(CBaseChainParams::MAIN).GetConsensus())); + CachedHashes cachedHashes(tx); + EXPECT_TRUE(ContextualCheckInputs(tx, state, view, false, 0, false, cachedHashes, Params(CBaseChainParams::MAIN).GetConsensus())); } TEST(Validation, ReceivedBlockTransactions) { diff --git a/src/main.cpp b/src/main.cpp index b43bbd9ef..d5dbe1e3f 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1347,7 +1347,8 @@ bool AcceptToMemoryPool(CTxMemPool& pool, CValidationState &state, const CTransa // Check against previous transactions // This is done last to help prevent CPU exhaustion denial-of-service attacks. - if (!ContextualCheckInputs(tx, state, view, true, STANDARD_SCRIPT_VERIFY_FLAGS, true, Params().GetConsensus())) + CachedHashes cachedHashes(tx); + if (!ContextualCheckInputs(tx, state, view, true, STANDARD_SCRIPT_VERIFY_FLAGS, true, cachedHashes, Params().GetConsensus())) { return error("AcceptToMemoryPool: ConnectInputs failed %s", hash.ToString()); } @@ -1361,7 +1362,7 @@ bool AcceptToMemoryPool(CTxMemPool& pool, CValidationState &state, const CTransa // There is a similar check in CreateNewBlock() to prevent creating // invalid blocks, however allowing such transactions into the mempool // can be exploited as a DoS attack. - if (!ContextualCheckInputs(tx, state, view, true, MANDATORY_SCRIPT_VERIFY_FLAGS, true, Params().GetConsensus())) + if (!ContextualCheckInputs(tx, state, view, true, MANDATORY_SCRIPT_VERIFY_FLAGS, true, cachedHashes, Params().GetConsensus())) { return error("AcceptToMemoryPool: BUG! PLEASE REPORT THIS! ConnectInputs failed against MANDATORY but not STANDARD flags %s", hash.ToString()); } @@ -1726,7 +1727,7 @@ void UpdateCoins(const CTransaction& tx, CCoinsViewCache& inputs, int nHeight) bool CScriptCheck::operator()() { const CScript &scriptSig = ptxTo->vin[nIn].scriptSig; - if (!VerifyScript(scriptSig, scriptPubKey, nFlags, CachingTransactionSignatureChecker(ptxTo, nIn, amount, cacheStore), &error)) { + if (!VerifyScript(scriptSig, scriptPubKey, nFlags, CachingTransactionSignatureChecker(ptxTo, nIn, amount, cacheStore, *cachedHashes), &error)) { return ::error("CScriptCheck(): %s:%d VerifySignature failed: %s", ptxTo->GetHash().ToString(), nIn, ScriptErrorString(error)); } return true; @@ -1809,7 +1810,7 @@ bool CheckTxInputs(const CTransaction& tx, CValidationState& state, const CCoins } }// namespace Consensus -bool ContextualCheckInputs(const CTransaction& tx, CValidationState &state, const CCoinsViewCache &inputs, bool fScriptChecks, unsigned int flags, bool cacheStore, const Consensus::Params& consensusParams, std::vector *pvChecks) +bool ContextualCheckInputs(const CTransaction& tx, CValidationState &state, const CCoinsViewCache &inputs, bool fScriptChecks, unsigned int flags, bool cacheStore, CachedHashes& cachedHashes, const Consensus::Params& consensusParams, std::vector *pvChecks) { if (!tx.IsCoinBase()) { @@ -1834,7 +1835,7 @@ bool ContextualCheckInputs(const CTransaction& tx, CValidationState &state, cons assert(coins); // Verify signature - CScriptCheck check(*coins, tx, i, flags, cacheStore); + CScriptCheck check(*coins, tx, i, flags, cacheStore, &cachedHashes); if (pvChecks) { pvChecks->push_back(CScriptCheck()); check.swap(pvChecks->back()); @@ -1847,7 +1848,7 @@ bool ContextualCheckInputs(const CTransaction& tx, CValidationState &state, cons // avoid splitting the network between upgraded and // non-upgraded nodes. CScriptCheck check2(*coins, tx, i, - flags & ~STANDARD_NOT_MANDATORY_VERIFY_FLAGS, cacheStore); + flags & ~STANDARD_NOT_MANDATORY_VERIFY_FLAGS, cacheStore, &cachedHashes); if (check2()) return state.Invalid(false, REJECT_NONSTANDARD, strprintf("non-mandatory-script-verify-flag (%s)", ScriptErrorString(check.GetScriptError()))); } @@ -2237,6 +2238,8 @@ bool ConnectBlock(const CBlock& block, CValidationState& state, CBlockIndex* pin assert(tree.root() == old_tree_root); } + std::vector cachedHashes; + cachedHashes.reserve(block.vtx.size()); // Required so that pointers to individual CachedHashes don't get invalidated for (unsigned int i = 0; i < block.vtx.size(); i++) { const CTransaction &tx = block.vtx[i]; @@ -2265,11 +2268,16 @@ bool ConnectBlock(const CBlock& block, CValidationState& state, CBlockIndex* pin if (nSigOps > MAX_BLOCK_SIGOPS) return state.DoS(100, error("ConnectBlock(): too many sigops"), REJECT_INVALID, "bad-blk-sigops"); + } + cachedHashes.emplace_back(tx); + + if (!tx.IsCoinBase()) + { nFees += view.GetValueIn(tx)-tx.GetValueOut(); std::vector vChecks; - if (!ContextualCheckInputs(tx, state, view, fExpensiveChecks, flags, false, chainparams.GetConsensus(), nScriptCheckThreads ? &vChecks : NULL)) + if (!ContextualCheckInputs(tx, state, view, fExpensiveChecks, flags, false, cachedHashes[i], chainparams.GetConsensus(), nScriptCheckThreads ? &vChecks : NULL)) return false; control.Add(vChecks); } diff --git a/src/main.h b/src/main.h index df11c39e3..ae90c2903 100644 --- a/src/main.h +++ b/src/main.h @@ -45,6 +45,7 @@ class CInv; class CScriptCheck; class CValidationInterface; class CValidationState; +class CachedHashes; struct CNodeStateStats; @@ -330,7 +331,7 @@ unsigned int GetP2SHSigOpCount(const CTransaction& tx, const CCoinsViewCache& ma * instead of being performed inline. */ bool ContextualCheckInputs(const CTransaction& tx, CValidationState &state, const CCoinsViewCache &view, bool fScriptChecks, - unsigned int flags, bool cacheStore, const Consensus::Params& consensusParams, + unsigned int flags, bool cacheStore, CachedHashes& cachedHashes, const Consensus::Params& consensusParams, std::vector *pvChecks = NULL); /** Check a transaction contextually against a set of consensus rules */ @@ -390,12 +391,13 @@ private: unsigned int nFlags; bool cacheStore; ScriptError error; + CachedHashes *cachedHashes; public: CScriptCheck(): amount(0), ptxTo(0), nIn(0), nFlags(0), cacheStore(false), error(SCRIPT_ERR_UNKNOWN_ERROR) {} - CScriptCheck(const CCoins& txFromIn, const CTransaction& txToIn, unsigned int nInIn, unsigned int nFlagsIn, bool cacheIn) : + CScriptCheck(const CCoins& txFromIn, const CTransaction& txToIn, unsigned int nInIn, unsigned int nFlagsIn, bool cacheIn, CachedHashes* cachedHashesIn) : scriptPubKey(txFromIn.vout[txToIn.vin[nInIn].prevout.n].scriptPubKey), amount(txFromIn.vout[txToIn.vin[nInIn].prevout.n].nValue), - ptxTo(&txToIn), nIn(nInIn), nFlags(nFlagsIn), cacheStore(cacheIn), error(SCRIPT_ERR_UNKNOWN_ERROR) { } + ptxTo(&txToIn), nIn(nInIn), nFlags(nFlagsIn), cacheStore(cacheIn), error(SCRIPT_ERR_UNKNOWN_ERROR), cachedHashes(cachedHashesIn) { } bool operator()(); @@ -407,6 +409,7 @@ public: std::swap(nFlags, check.nFlags); std::swap(cacheStore, check.cacheStore); std::swap(error, check.error); + std::swap(cachedHashes, check.cachedHashes); } ScriptError GetScriptError() const { return error; } diff --git a/src/miner.cpp b/src/miner.cpp index 84a6edf6a..a59856c60 100644 --- a/src/miner.cpp +++ b/src/miner.cpp @@ -293,7 +293,8 @@ CBlockTemplate* CreateNewBlock(const CScript& scriptPubKeyIn) // policy here, but we still have to ensure that the block we // create only contains transactions that are valid in new blocks. CValidationState state; - if (!ContextualCheckInputs(tx, state, view, true, MANDATORY_SCRIPT_VERIFY_FLAGS, true, Params().GetConsensus())) + CachedHashes cachedHashes(tx); + if (!ContextualCheckInputs(tx, state, view, true, MANDATORY_SCRIPT_VERIFY_FLAGS, true, cachedHashes, Params().GetConsensus())) continue; UpdateCoins(tx, view, nHeight); diff --git a/src/script/interpreter.cpp b/src/script/interpreter.cpp index bf3313f5c..0405b8fa7 100644 --- a/src/script/interpreter.cpp +++ b/src/script/interpreter.cpp @@ -1050,9 +1050,40 @@ public: } }; +uint256 GetPrevoutHash(const CTransaction& txTo) { + CHashWriter ss(SER_GETHASH, 0); + for (unsigned int n = 0; n < txTo.vin.size(); n++) { + ss << txTo.vin[n].prevout; + } + return ss.GetHash(); +} + +uint256 GetSequenceHash(const CTransaction& txTo) { + CHashWriter ss(SER_GETHASH, 0); + for (unsigned int n = 0; n < txTo.vin.size(); n++) { + ss << txTo.vin[n].nSequence; + } + return ss.GetHash(); +} + +uint256 GetOutputsHash(const CTransaction& txTo) { + CHashWriter ss(SER_GETHASH, 0); + for (unsigned int n = 0; n < txTo.vout.size(); n++) { + ss << txTo.vout[n]; + } + return ss.GetHash(); +} + } // anon namespace -uint256 SignatureHash(const CScript& scriptCode, const CTransaction& txTo, unsigned int nIn, int nHashType, const CAmount& amount, SigVersion sigversion) +CachedHashes::CachedHashes(const CTransaction& txTo) +{ + hashPrevouts = GetPrevoutHash(txTo); + hashSequence = GetSequenceHash(txTo); + hashOutputs = GetOutputsHash(txTo); +} + +uint256 SignatureHash(const CScript& scriptCode, const CTransaction& txTo, unsigned int nIn, int nHashType, const CAmount& amount, SigVersion sigversion, const CachedHashes* cache) { if (nIn >= txTo.vin.size() && nIn != NOT_AN_INPUT) { // nIn out of range @@ -1065,27 +1096,16 @@ uint256 SignatureHash(const CScript& scriptCode, const CTransaction& txTo, unsig uint256 hashOutputs; if (!(nHashType & SIGHASH_ANYONECANPAY)) { - CHashWriter ss(SER_GETHASH, 0); - for (unsigned int n = 0; n < txTo.vin.size(); n++) { - ss << txTo.vin[n].prevout; - } - hashPrevouts = ss.GetHash(); // TODO: cache this value for all signatures in a transaction + hashPrevouts = cache ? cache->hashPrevouts : GetPrevoutHash(txTo); } if (!(nHashType & SIGHASH_ANYONECANPAY) && (nHashType & 0x1f) != SIGHASH_SINGLE && (nHashType & 0x1f) != SIGHASH_NONE) { - CHashWriter ss(SER_GETHASH, 0); - for (unsigned int n = 0; n < txTo.vin.size(); n++) { - ss << txTo.vin[n].nSequence; - } - hashSequence = ss.GetHash(); // TODO: cache this value for all signatures in a transaction + hashSequence = cache ? cache->hashSequence : GetSequenceHash(txTo); } + if ((nHashType & 0x1f) != SIGHASH_SINGLE && (nHashType & 0x1f) != SIGHASH_NONE) { - CHashWriter ss(SER_GETHASH, 0); - for (unsigned int n = 0; n < txTo.vout.size(); n++) { - ss << txTo.vout[n]; - } - hashOutputs = ss.GetHash(); // TODO: cache this value for all signatures in a transaction + hashOutputs = cache ? cache->hashOutputs : GetOutputsHash(txTo); } else if ((nHashType & 0x1f) == SIGHASH_SINGLE && nIn < txTo.vout.size()) { CHashWriter ss(SER_GETHASH, 0); ss << txTo.vout[nIn]; @@ -1154,7 +1174,7 @@ bool TransactionSignatureChecker::CheckSig(const vector& vchSigIn uint256 sighash; try { - sighash = SignatureHash(scriptCode, *txTo, nIn, nHashType, amount, sigversion); + sighash = SignatureHash(scriptCode, *txTo, nIn, nHashType, amount, sigversion, this->cachedHashes); } catch (logic_error ex) { return false; } diff --git a/src/script/interpreter.h b/src/script/interpreter.h index 103c94d1d..4b37a5c76 100644 --- a/src/script/interpreter.h +++ b/src/script/interpreter.h @@ -88,13 +88,20 @@ enum SCRIPT_VERIFY_CHECKLOCKTIMEVERIFY = (1U << 9), }; +struct CachedHashes +{ + uint256 hashPrevouts, hashSequence, hashOutputs; + + CachedHashes(const CTransaction& tx); +}; + enum SigVersion { SIGVERSION_BASE = 0, SIGVERSION_WITNESS_V0 = 1, }; -uint256 SignatureHash(const CScript &scriptCode, const CTransaction& txTo, unsigned int nIn, int nHashType, const CAmount& amount, SigVersion sigversion); +uint256 SignatureHash(const CScript &scriptCode, const CTransaction& txTo, unsigned int nIn, int nHashType, const CAmount& amount, SigVersion sigversion, const CachedHashes* cache = NULL); class BaseSignatureChecker { @@ -118,12 +125,14 @@ private: const CTransaction* txTo; unsigned int nIn; const CAmount amount; + const CachedHashes* cachedHashes; protected: virtual bool VerifySignature(const std::vector& vchSig, const CPubKey& vchPubKey, const uint256& sighash) const; public: - TransactionSignatureChecker(const CTransaction* txToIn, unsigned int nInIn, const CAmount& amountIn) : txTo(txToIn), nIn(nInIn), amount(amountIn) {} + TransactionSignatureChecker(const CTransaction* txToIn, unsigned int nInIn, const CAmount& amountIn) : txTo(txToIn), nIn(nInIn), amount(amountIn), cachedHashes(NULL) {} + TransactionSignatureChecker(const CTransaction* txToIn, unsigned int nInIn, const CAmount& amountIn, const CachedHashes& cachedHashesIn) : txTo(txToIn), nIn(nInIn), amount(amountIn), cachedHashes(&cachedHashesIn) {} bool CheckSig(const std::vector& scriptSig, const std::vector& vchPubKey, const CScript& scriptCode, SigVersion sigversion) const; bool CheckLockTime(const CScriptNum& nLockTime) const; }; diff --git a/src/script/sigcache.h b/src/script/sigcache.h index 0dcbf1c59..7181841ed 100644 --- a/src/script/sigcache.h +++ b/src/script/sigcache.h @@ -18,7 +18,7 @@ private: bool store; public: - CachingTransactionSignatureChecker(const CTransaction* txToIn, unsigned int nInIn, const CAmount& amount, bool storeIn) : TransactionSignatureChecker(txToIn, nInIn, amount), store(storeIn) {} + CachingTransactionSignatureChecker(const CTransaction* txToIn, unsigned int nInIn, const CAmount& amount, bool storeIn, CachedHashes& cachedHashesIn) : TransactionSignatureChecker(txToIn, nInIn, amount, cachedHashesIn), store(storeIn) {} bool VerifySignature(const std::vector& vchSig, const CPubKey& vchPubKey, const uint256& sighash) const; }; diff --git a/src/script/zcashconsensus.cpp b/src/script/zcashconsensus.cpp index 80e92300f..eec42a00b 100644 --- a/src/script/zcashconsensus.cpp +++ b/src/script/zcashconsensus.cpp @@ -84,9 +84,9 @@ int zcashconsensus_verify_script(const unsigned char *scriptPubKey, unsigned int // Regardless of the verification result, the tx did not error. set_error(err, zcashconsensus_ERR_OK); - + CachedHashes cachedHashes(tx); CAmount am(0); - return VerifyScript(tx.vin[nIn].scriptSig, CScript(scriptPubKey, scriptPubKey + scriptPubKeyLen), flags, TransactionSignatureChecker(&tx, nIn, am), NULL); + return VerifyScript(tx.vin[nIn].scriptSig, CScript(scriptPubKey, scriptPubKey + scriptPubKeyLen), flags, TransactionSignatureChecker(&tx, nIn, am, cachedHashes), NULL); } catch (const std::exception&) { return set_error(err, zcashconsensus_ERR_TX_DESERIALIZE); // Error deserializing } diff --git a/src/test/script_P2SH_tests.cpp b/src/test/script_P2SH_tests.cpp index 3b99fc740..788ea44aa 100644 --- a/src/test/script_P2SH_tests.cpp +++ b/src/test/script_P2SH_tests.cpp @@ -111,18 +111,20 @@ BOOST_AUTO_TEST_CASE(sign) } // All of the above should be OK, and the txTos have valid signatures // Check to make sure signature verification fails if we use the wrong ScriptSig: - for (int i = 0; i < 8; i++) + for (int i = 0; i < 8; i++) { + CachedHashes cachedHashes(txTo[i]); for (int j = 0; j < 8; j++) { CScript sigSave = txTo[i].vin[0].scriptSig; txTo[i].vin[0].scriptSig = txTo[j].vin[0].scriptSig; - bool sigOK = CScriptCheck(CCoins(txFrom, 0), txTo[i], 0, SCRIPT_VERIFY_P2SH | SCRIPT_VERIFY_STRICTENC, false)(); + bool sigOK = CScriptCheck(CCoins(txFrom, 0), txTo[i], 0, SCRIPT_VERIFY_P2SH | SCRIPT_VERIFY_STRICTENC, false, &cachedHashes)(); if (i == j) BOOST_CHECK_MESSAGE(sigOK, strprintf("VerifySignature %d %d", i, j)); else BOOST_CHECK_MESSAGE(!sigOK, strprintf("VerifySignature %d %d", i, j)); txTo[i].vin[0].scriptSig = sigSave; } + } } BOOST_AUTO_TEST_CASE(norecurse) diff --git a/src/test/transaction_tests.cpp b/src/test/transaction_tests.cpp index 9adfd6e28..fbbd8a24e 100644 --- a/src/test/transaction_tests.cpp +++ b/src/test/transaction_tests.cpp @@ -147,6 +147,7 @@ BOOST_AUTO_TEST_CASE(tx_valid) BOOST_CHECK_MESSAGE(CheckTransaction(tx, state, verifier), strTest + comment); BOOST_CHECK_MESSAGE(state.IsValid(), comment); + CachedHashes cachedHashes(tx); for (unsigned int i = 0; i < tx.vin.size(); i++) { if (!mapprevOutScriptPubKeys.count(tx.vin[i].prevout)) @@ -158,7 +159,7 @@ BOOST_AUTO_TEST_CASE(tx_valid) CAmount amount = 0; unsigned int verify_flags = ParseScriptFlags(test[2].get_str()); BOOST_CHECK_MESSAGE(VerifyScript(tx.vin[i].scriptSig, mapprevOutScriptPubKeys[tx.vin[i].prevout], - verify_flags, TransactionSignatureChecker(&tx, i, amount), &err), + verify_flags, TransactionSignatureChecker(&tx, i, amount, cachedHashes), &err), strTest + comment); BOOST_CHECK_MESSAGE(err == SCRIPT_ERR_OK, ScriptErrorString(err) + comment); } @@ -231,6 +232,7 @@ BOOST_AUTO_TEST_CASE(tx_invalid) CValidationState state; fValid = CheckTransaction(tx, state, verifier) && state.IsValid(); + CachedHashes cachedHashes(tx); for (unsigned int i = 0; i < tx.vin.size() && fValid; i++) { if (!mapprevOutScriptPubKeys.count(tx.vin[i].prevout)) @@ -242,7 +244,7 @@ BOOST_AUTO_TEST_CASE(tx_invalid) unsigned int verify_flags = ParseScriptFlags(test[2].get_str()); CAmount amount = 0; fValid = VerifyScript(tx.vin[i].scriptSig, mapprevOutScriptPubKeys[tx.vin[i].prevout], - verify_flags, TransactionSignatureChecker(&tx, i, amount), &err); + verify_flags, TransactionSignatureChecker(&tx, i, amount, cachedHashes), &err); } BOOST_CHECK_MESSAGE(!fValid, strTest + comment); BOOST_CHECK_MESSAGE(err != SCRIPT_ERR_OK, ScriptErrorString(err) + comment);