diff --git a/src/consensus/consensus.h b/src/consensus/consensus.h index 8650c453a..1efaf99ea 100644 --- a/src/consensus/consensus.h +++ b/src/consensus/consensus.h @@ -10,10 +10,14 @@ static const int32_t MIN_BLOCK_VERSION = 4; /** The minimum allowed transaction version (network rule) */ static const int32_t SPROUT_MIN_TX_VERSION = 1; -/** The minimum allowed transaction version (network rule) */ +/** The minimum allowed Overwinter transaction version (network rule) */ static const int32_t OVERWINTER_MIN_TX_VERSION = 3; -/** The maximum allowed transaction version (network rule) */ +/** The maximum allowed Overwinter transaction version (network rule) */ static const int32_t OVERWINTER_MAX_TX_VERSION = 3; +/** The minimum allowed Sapling transaction version (network rule) */ +static const int32_t SAPLING_MIN_TX_VERSION = 4; +/** The maximum allowed Sapling transaction version (network rule) */ +static const int32_t SAPLING_MAX_TX_VERSION = 4; /** The maximum allowed size for a serialized block, in bytes (network rule) */ static const unsigned int MAX_BLOCK_SIZE = 2000000; /** The maximum allowed number of signature check operations in a block (network rule) */ diff --git a/src/gtest/test_checkblock.cpp b/src/gtest/test_checkblock.cpp index 53c0efcc7..225f8ef3b 100644 --- a/src/gtest/test_checkblock.cpp +++ b/src/gtest/test_checkblock.cpp @@ -110,6 +110,36 @@ TEST(ContextualCheckBlock, BadCoinbaseHeight) { EXPECT_TRUE(ContextualCheckBlock(block, state, &indexPrev)); } + +// Test that a block evaluated under Sprout rules cannot contain Sapling transactions. +// This test assumes that mainnet Overwinter activation is at least height 2. +TEST(ContextualCheckBlock, BlockSproutRulesRejectSaplingTx) { + SelectParams(CBaseChainParams::MAIN); + + CMutableTransaction mtx; + mtx.vin.resize(1); + mtx.vin[0].prevout.SetNull(); + mtx.vin[0].scriptSig = CScript() << 1 << OP_0; + mtx.vout.resize(1); + mtx.vout[0].scriptPubKey = CScript() << OP_TRUE; + mtx.vout[0].nValue = 0; + + mtx.fOverwintered = true; + mtx.nVersion = SAPLING_TX_VERSION; + mtx.nVersionGroupId = SAPLING_VERSION_GROUP_ID; + + CTransaction tx {mtx}; + CBlock block; + block.vtx.push_back(tx); + + MockCValidationState state; + CBlockIndex indexPrev {Params().GenesisBlock()}; + + EXPECT_CALL(state, DoS(100, false, REJECT_INVALID, "tx-overwinter-not-active", false)).Times(1); + EXPECT_FALSE(ContextualCheckBlock(block, state, &indexPrev)); +} + + // Test that a block evaluated under Sprout rules cannot contain Overwinter transactions. // This test assumes that mainnet Overwinter activation is at least height 2. TEST(ContextualCheckBlock, BlockSproutRulesRejectOverwinterTx) { @@ -170,6 +200,41 @@ TEST(ContextualCheckBlock, BlockSproutRulesAcceptSproutTx) { } +// Test that a block evaluated under Overwinter rules cannot contain Sapling transactions. +TEST(ContextualCheckBlock, BlockOverwinterRulesRejectSaplingTx) { + SelectParams(CBaseChainParams::REGTEST); + UpdateNetworkUpgradeParameters(Consensus::UPGRADE_OVERWINTER, 1); + + CMutableTransaction mtx; + mtx.vin.resize(1); + mtx.vin[0].prevout.SetNull(); + mtx.vin[0].scriptSig = CScript() << 1 << OP_0; + mtx.vout.resize(1); + mtx.vout[0].scriptPubKey = CScript() << OP_TRUE; + mtx.vout[0].nValue = 0; + mtx.vout.push_back(CTxOut( + GetBlockSubsidy(1, Params().GetConsensus())/5, + Params().GetFoundersRewardScriptAtHeight(1))); + + mtx.fOverwintered = true; + mtx.nVersion = SAPLING_TX_VERSION; + mtx.nVersionGroupId = SAPLING_VERSION_GROUP_ID; + + CTransaction tx {mtx}; + CBlock block; + block.vtx.push_back(tx); + + MockCValidationState state; + CBlockIndex indexPrev {Params().GenesisBlock()}; + + EXPECT_CALL(state, DoS(100, false, REJECT_INVALID, "bad-overwinter-tx-version-group-id", false)).Times(1); + EXPECT_FALSE(ContextualCheckBlock(block, state, &indexPrev)); + + // Revert to default + UpdateNetworkUpgradeParameters(Consensus::UPGRADE_OVERWINTER, Consensus::NetworkUpgrade::NO_ACTIVATION_HEIGHT); +} + + // Test block evaluated under Overwinter rules will accept Overwinter transactions TEST(ContextualCheckBlock, BlockOverwinterRulesAcceptOverwinterTx) { SelectParams(CBaseChainParams::REGTEST); @@ -202,7 +267,6 @@ TEST(ContextualCheckBlock, BlockOverwinterRulesAcceptOverwinterTx) { } - // Test block evaluated under Overwinter rules will reject Sprout transactions TEST(ContextualCheckBlock, BlockOverwinterRulesRejectSproutTx) { SelectParams(CBaseChainParams::REGTEST); @@ -230,4 +294,108 @@ TEST(ContextualCheckBlock, BlockOverwinterRulesRejectSproutTx) { // Revert to default UpdateNetworkUpgradeParameters(Consensus::UPGRADE_OVERWINTER, Consensus::NetworkUpgrade::NO_ACTIVATION_HEIGHT); -} \ No newline at end of file +} + + +// Test that a block evaluated under Sapling rules can contain Sapling transactions. +TEST(ContextualCheckBlock, BlockSaplingRulesAcceptSaplingTx) { + SelectParams(CBaseChainParams::REGTEST); + UpdateNetworkUpgradeParameters(Consensus::UPGRADE_OVERWINTER, 1); + UpdateNetworkUpgradeParameters(Consensus::UPGRADE_SAPLING, 1); + + CMutableTransaction mtx; + mtx.vin.resize(1); + mtx.vin[0].prevout.SetNull(); + mtx.vin[0].scriptSig = CScript() << 1 << OP_0; + mtx.vout.resize(1); + mtx.vout[0].scriptPubKey = CScript() << OP_TRUE; + mtx.vout[0].nValue = 0; + mtx.vout.push_back(CTxOut( + GetBlockSubsidy(1, Params().GetConsensus())/5, + Params().GetFoundersRewardScriptAtHeight(1))); + + mtx.fOverwintered = true; + mtx.nVersion = SAPLING_TX_VERSION; + mtx.nVersionGroupId = SAPLING_VERSION_GROUP_ID; + + CTransaction tx {mtx}; + CBlock block; + block.vtx.push_back(tx); + + MockCValidationState state; + CBlockIndex indexPrev {Params().GenesisBlock()}; + + EXPECT_TRUE(ContextualCheckBlock(block, state, &indexPrev)); + + // Revert to default + UpdateNetworkUpgradeParameters(Consensus::UPGRADE_OVERWINTER, Consensus::NetworkUpgrade::NO_ACTIVATION_HEIGHT); + UpdateNetworkUpgradeParameters(Consensus::UPGRADE_SAPLING, Consensus::NetworkUpgrade::NO_ACTIVATION_HEIGHT); +} + + +// Test block evaluated under Sapling rules cannot contain Overwinter transactions +TEST(ContextualCheckBlock, BlockSaplingRulesRejectOverwinterTx) { + SelectParams(CBaseChainParams::REGTEST); + UpdateNetworkUpgradeParameters(Consensus::UPGRADE_OVERWINTER, 1); + UpdateNetworkUpgradeParameters(Consensus::UPGRADE_SAPLING, 1); + + CMutableTransaction mtx; + mtx.vin.resize(1); + mtx.vin[0].prevout.SetNull(); + mtx.vin[0].scriptSig = CScript() << 1 << OP_0; + mtx.vout.resize(1); + mtx.vout[0].scriptPubKey = CScript() << OP_TRUE; + mtx.vout[0].nValue = 0; + mtx.vout.push_back(CTxOut( + GetBlockSubsidy(1, Params().GetConsensus())/5, + Params().GetFoundersRewardScriptAtHeight(1))); + mtx.fOverwintered = true; + mtx.nVersion = 3; + mtx.nVersionGroupId = OVERWINTER_VERSION_GROUP_ID; + + CTransaction tx {mtx}; + CBlock block; + block.vtx.push_back(tx); + MockCValidationState state; + CBlockIndex indexPrev {Params().GenesisBlock()}; + + EXPECT_CALL(state, DoS(100, false, REJECT_INVALID, "bad-sapling-tx-version-group-id", false)).Times(1); + EXPECT_FALSE(ContextualCheckBlock(block, state, &indexPrev)); + + // Revert to default + UpdateNetworkUpgradeParameters(Consensus::UPGRADE_OVERWINTER, Consensus::NetworkUpgrade::NO_ACTIVATION_HEIGHT); + UpdateNetworkUpgradeParameters(Consensus::UPGRADE_SAPLING, Consensus::NetworkUpgrade::NO_ACTIVATION_HEIGHT); +} + + + +// Test block evaluated under Sapling rules cannot contain Sprout transactions +TEST(ContextualCheckBlock, BlockSaplingRulesRejectSproutTx) { + SelectParams(CBaseChainParams::REGTEST); + UpdateNetworkUpgradeParameters(Consensus::UPGRADE_OVERWINTER, 1); + UpdateNetworkUpgradeParameters(Consensus::UPGRADE_SAPLING, 1); + + CMutableTransaction mtx; + mtx.vin.resize(1); + mtx.vin[0].prevout.SetNull(); + mtx.vin[0].scriptSig = CScript() << 1 << OP_0; + mtx.vout.resize(1); + mtx.vout[0].scriptPubKey = CScript() << OP_TRUE; + mtx.vout[0].nValue = 0; + + mtx.nVersion = 2; + + CTransaction tx {mtx}; + CBlock block; + block.vtx.push_back(tx); + + MockCValidationState state; + CBlockIndex indexPrev {Params().GenesisBlock()}; + + EXPECT_CALL(state, DoS(100, false, REJECT_INVALID, "tx-overwinter-active", false)).Times(1); + EXPECT_FALSE(ContextualCheckBlock(block, state, &indexPrev)); + + // Revert to default + UpdateNetworkUpgradeParameters(Consensus::UPGRADE_OVERWINTER, Consensus::NetworkUpgrade::NO_ACTIVATION_HEIGHT); + UpdateNetworkUpgradeParameters(Consensus::UPGRADE_SAPLING, Consensus::NetworkUpgrade::NO_ACTIVATION_HEIGHT); +} diff --git a/src/hash.h b/src/hash.h index 76634bd37..3e56f4a86 100644 --- a/src/hash.h +++ b/src/hash.h @@ -182,6 +182,9 @@ public: personal) == 0); } + int GetType() const { return nType; } + int GetVersion() const { return nVersion; } + CBLAKE2bWriter& write(const char *pch, size_t size) { crypto_generichash_blake2b_update(&state, (const unsigned char*)pch, size); return (*this); diff --git a/src/main.cpp b/src/main.cpp index f6d3f3992..07ed8c592 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -645,9 +645,16 @@ unsigned int LimitOrphanTxSize(unsigned int nMaxOrphans) EXCLUSIVE_LOCKS_REQUIRE bool IsStandardTx(const CTransaction& tx, string& reason, const int nHeight) { - bool isOverwinter = NetworkUpgradeActive(nHeight, Params().GetConsensus(), Consensus::UPGRADE_OVERWINTER); + bool overwinterActive = NetworkUpgradeActive(nHeight, Params().GetConsensus(), Consensus::UPGRADE_OVERWINTER); + bool saplingActive = NetworkUpgradeActive(nHeight, Params().GetConsensus(), Consensus::UPGRADE_SAPLING); - if (isOverwinter) { + if (saplingActive) { + // Sapling standard rules apply + if (tx.nVersion > CTransaction::SAPLING_MAX_CURRENT_VERSION || tx.nVersion < CTransaction::SAPLING_MIN_CURRENT_VERSION) { + reason = "sapling-version"; + return false; + } + } else if (overwinterActive) { // Overwinter standard rules apply if (tx.nVersion > CTransaction::OVERWINTER_MAX_CURRENT_VERSION || tx.nVersion < CTransaction::OVERWINTER_MIN_CURRENT_VERSION) { reason = "overwinter-version"; @@ -866,8 +873,9 @@ unsigned int GetP2SHSigOpCount(const CTransaction& tx, const CCoinsViewCache& in */ bool ContextualCheckTransaction(const CTransaction& tx, CValidationState &state, const int nHeight, const int dosLevel) { - bool isOverwinter = NetworkUpgradeActive(nHeight, Params().GetConsensus(), Consensus::UPGRADE_OVERWINTER); - bool isSprout = !isOverwinter; + bool overwinterActive = NetworkUpgradeActive(nHeight, Params().GetConsensus(), Consensus::UPGRADE_OVERWINTER); + bool saplingActive = NetworkUpgradeActive(nHeight, Params().GetConsensus(), Consensus::UPGRADE_SAPLING); + bool isSprout = !overwinterActive; // If Sprout rules apply, reject transactions which are intended for Overwinter and beyond if (isSprout && tx.fOverwintered) { @@ -875,20 +883,52 @@ bool ContextualCheckTransaction(const CTransaction& tx, CValidationState &state, REJECT_INVALID, "tx-overwinter-not-active"); } - // If Overwinter rules apply: - if (isOverwinter) { + if (saplingActive) { + // Reject transactions with valid version but missing overwintered flag + if (tx.nVersion >= SAPLING_MIN_TX_VERSION && !tx.fOverwintered) { + return state.DoS(dosLevel, error("ContextualCheckTransaction(): overwintered flag must be set"), + REJECT_INVALID, "tx-overwintered-flag-not-set"); + } + + // Reject transactions with non-Sapling version group ID + if (tx.fOverwintered && tx.nVersionGroupId != SAPLING_VERSION_GROUP_ID) { + return state.DoS(dosLevel, error("CheckTransaction(): invalid Sapling tx version"), + REJECT_INVALID, "bad-sapling-tx-version-group-id"); + } + + // Reject transactions with invalid version + if (tx.fOverwintered && tx.nVersion < SAPLING_MIN_TX_VERSION ) { + return state.DoS(100, error("CheckTransaction(): Sapling version too low"), + REJECT_INVALID, "bad-tx-sapling-version-too-low"); + } + + // Reject transactions with invalid version + if (tx.fOverwintered && tx.nVersion > SAPLING_MAX_TX_VERSION ) { + return state.DoS(100, error("CheckTransaction(): Sapling version too high"), + REJECT_INVALID, "bad-tx-sapling-version-too-high"); + } + } else if (overwinterActive) { // Reject transactions with valid version but missing overwinter flag if (tx.nVersion >= OVERWINTER_MIN_TX_VERSION && !tx.fOverwintered) { return state.DoS(dosLevel, error("ContextualCheckTransaction(): overwinter flag must be set"), REJECT_INVALID, "tx-overwinter-flag-not-set"); } + // Reject transactions with non-Overwinter version group ID + if (tx.fOverwintered && tx.nVersionGroupId != OVERWINTER_VERSION_GROUP_ID) { + return state.DoS(dosLevel, error("CheckTransaction(): invalid Overwinter tx version"), + REJECT_INVALID, "bad-overwinter-tx-version-group-id"); + } + // Reject transactions with invalid version if (tx.fOverwintered && tx.nVersion > OVERWINTER_MAX_TX_VERSION ) { return state.DoS(100, error("CheckTransaction(): overwinter version too high"), REJECT_INVALID, "bad-tx-overwinter-version-too-high"); } + } + // Rules that apply to Overwinter or later: + if (overwinterActive) { // Reject transactions intended for Sprout if (!tx.fOverwintered) { return state.DoS(dosLevel, error("ContextualCheckTransaction: overwinter is active"), @@ -988,7 +1028,8 @@ bool CheckTransactionWithoutProofVerification(const CTransaction& tx, CValidatio return state.DoS(100, error("CheckTransaction(): overwinter version too low"), REJECT_INVALID, "bad-tx-overwinter-version-too-low"); } - if (tx.nVersionGroupId != OVERWINTER_VERSION_GROUP_ID) { + if (tx.nVersionGroupId != OVERWINTER_VERSION_GROUP_ID && + tx.nVersionGroupId != SAPLING_VERSION_GROUP_ID) { return state.DoS(100, error("CheckTransaction(): unknown tx version group id"), REJECT_INVALID, "bad-tx-version-group-id"); } diff --git a/src/primitives/transaction.cpp b/src/primitives/transaction.cpp index 36d46377b..4f61b24ea 100644 --- a/src/primitives/transaction.cpp +++ b/src/primitives/transaction.cpp @@ -72,23 +72,50 @@ JSDescription JSDescription::Randomized( ); } +class SproutProofVerifier : public boost::static_visitor +{ + ZCJoinSplit& params; + libzcash::ProofVerifier& verifier; + const uint256& pubKeyHash; + const JSDescription& jsdesc; + +public: + SproutProofVerifier( + ZCJoinSplit& params, + libzcash::ProofVerifier& verifier, + const uint256& pubKeyHash, + const JSDescription& jsdesc + ) : params(params), jsdesc(jsdesc), verifier(verifier), pubKeyHash(pubKeyHash) {} + + bool operator()(const libzcash::ZCProof& proof) const + { + return params.verify( + proof, + verifier, + pubKeyHash, + jsdesc.randomSeed, + jsdesc.macs, + jsdesc.nullifiers, + jsdesc.commitments, + jsdesc.vpub_old, + jsdesc.vpub_new, + jsdesc.anchor + ); + } + + bool operator()(const libzcash::GrothProof& proof) const + { + return false; + } +}; + bool JSDescription::Verify( ZCJoinSplit& params, libzcash::ProofVerifier& verifier, const uint256& pubKeyHash ) const { - return params.verify( - proof, - verifier, - pubKeyHash, - randomSeed, - macs, - nullifiers, - commitments, - vpub_old, - vpub_new, - anchor - ); + auto pv = SproutProofVerifier(params, verifier, pubKeyHash, *this); + boost::apply_visitor(pv, proof); } uint256 JSDescription::h_sig(ZCJoinSplit& params, const uint256& pubKeyHash) const @@ -146,10 +173,12 @@ std::string CTxOut::ToString() const return strprintf("CTxOut(nValue=%d.%08d, scriptPubKey=%s)", nValue / COIN, nValue % COIN, HexStr(scriptPubKey).substr(0, 30)); } -CMutableTransaction::CMutableTransaction() : nVersion(CTransaction::SPROUT_MIN_CURRENT_VERSION), fOverwintered(false), nVersionGroupId(0), nExpiryHeight(0), nLockTime(0) {} +CMutableTransaction::CMutableTransaction() : nVersion(CTransaction::SPROUT_MIN_CURRENT_VERSION), fOverwintered(false), nVersionGroupId(0), nExpiryHeight(0), nLockTime(0), valueBalance(0) {} CMutableTransaction::CMutableTransaction(const CTransaction& tx) : nVersion(tx.nVersion), fOverwintered(tx.fOverwintered), nVersionGroupId(tx.nVersionGroupId), nExpiryHeight(tx.nExpiryHeight), vin(tx.vin), vout(tx.vout), nLockTime(tx.nLockTime), - vjoinsplit(tx.vjoinsplit), joinSplitPubKey(tx.joinSplitPubKey), joinSplitSig(tx.joinSplitSig) + valueBalance(tx.valueBalance), vShieldedSpend(tx.vShieldedSpend), vShieldedOutput(tx.vShieldedOutput), + vjoinsplit(tx.vjoinsplit), joinSplitPubKey(tx.joinSplitPubKey), joinSplitSig(tx.joinSplitSig), + bindingSig(tx.bindingSig) { } @@ -164,11 +193,13 @@ void CTransaction::UpdateHash() const *const_cast(&hash) = SerializeHash(*this); } -CTransaction::CTransaction() : nVersion(CTransaction::SPROUT_MIN_CURRENT_VERSION), fOverwintered(false), nVersionGroupId(0), nExpiryHeight(0), vin(), vout(), nLockTime(0), vjoinsplit(), joinSplitPubKey(), joinSplitSig() { } +CTransaction::CTransaction() : nVersion(CTransaction::SPROUT_MIN_CURRENT_VERSION), fOverwintered(false), nVersionGroupId(0), nExpiryHeight(0), vin(), vout(), nLockTime(0), valueBalance(0), vShieldedSpend(), vShieldedOutput(), vjoinsplit(), joinSplitPubKey(), joinSplitSig(), bindingSig() { } CTransaction::CTransaction(const CMutableTransaction &tx) : nVersion(tx.nVersion), fOverwintered(tx.fOverwintered), nVersionGroupId(tx.nVersionGroupId), nExpiryHeight(tx.nExpiryHeight), vin(tx.vin), vout(tx.vout), nLockTime(tx.nLockTime), - vjoinsplit(tx.vjoinsplit), joinSplitPubKey(tx.joinSplitPubKey), joinSplitSig(tx.joinSplitSig) + valueBalance(tx.valueBalance), vShieldedSpend(tx.vShieldedSpend), vShieldedOutput(tx.vShieldedOutput), + vjoinsplit(tx.vjoinsplit), joinSplitPubKey(tx.joinSplitPubKey), joinSplitSig(tx.joinSplitSig), + bindingSig(tx.bindingSig) { UpdateHash(); } @@ -179,13 +210,17 @@ CTransaction::CTransaction( const CMutableTransaction &tx, bool evilDeveloperFlag) : nVersion(tx.nVersion), fOverwintered(tx.fOverwintered), nVersionGroupId(tx.nVersionGroupId), nExpiryHeight(tx.nExpiryHeight), vin(tx.vin), vout(tx.vout), nLockTime(tx.nLockTime), - vjoinsplit(tx.vjoinsplit), joinSplitPubKey(tx.joinSplitPubKey), joinSplitSig(tx.joinSplitSig) + valueBalance(tx.valueBalance), vShieldedSpend(tx.vShieldedSpend), vShieldedOutput(tx.vShieldedOutput), + vjoinsplit(tx.vjoinsplit), joinSplitPubKey(tx.joinSplitPubKey), joinSplitSig(tx.joinSplitSig), + bindingSig(tx.bindingSig) { assert(evilDeveloperFlag); } CTransaction::CTransaction(CMutableTransaction &&tx) : nVersion(tx.nVersion), fOverwintered(tx.fOverwintered), nVersionGroupId(tx.nVersionGroupId), vin(std::move(tx.vin)), vout(std::move(tx.vout)), nLockTime(tx.nLockTime), nExpiryHeight(tx.nExpiryHeight), + valueBalance(tx.valueBalance), + vShieldedSpend(std::move(tx.vShieldedSpend)), vShieldedOutput(std::move(tx.vShieldedOutput)), vjoinsplit(std::move(tx.vjoinsplit)), joinSplitPubKey(std::move(tx.joinSplitPubKey)), joinSplitSig(std::move(tx.joinSplitSig)) { @@ -200,9 +235,13 @@ CTransaction& CTransaction::operator=(const CTransaction &tx) { *const_cast*>(&vout) = tx.vout; *const_cast(&nLockTime) = tx.nLockTime; *const_cast(&nExpiryHeight) = tx.nExpiryHeight; + *const_cast(&valueBalance) = tx.valueBalance; + *const_cast*>(&vShieldedSpend) = tx.vShieldedSpend; + *const_cast*>(&vShieldedOutput) = tx.vShieldedOutput; *const_cast*>(&vjoinsplit) = tx.vjoinsplit; *const_cast(&joinSplitPubKey) = tx.joinSplitPubKey; *const_cast(&joinSplitSig) = tx.joinSplitSig; + *const_cast(&bindingSig) = tx.bindingSig; *const_cast(&hash) = tx.hash; return *this; } @@ -279,6 +318,19 @@ std::string CTransaction::ToString() const vin.size(), vout.size(), nLockTime); + } else if (nVersion >= SAPLING_MIN_TX_VERSION) { + str += strprintf("CTransaction(hash=%s, ver=%d, fOverwintered=%d, nVersionGroupId=%08x, vin.size=%u, vout.size=%u, nLockTime=%u, nExpiryHeight=%u, valueBalance=%u, vShieldedSpend.size=%u, vShieldedOutput.size=%u)\n", + GetHash().ToString().substr(0,10), + nVersion, + fOverwintered, + nVersionGroupId, + vin.size(), + vout.size(), + nLockTime, + nExpiryHeight, + valueBalance, + vShieldedSpend.size(), + vShieldedOutput.size()); } else if (nVersion >= 3) { str += strprintf("CTransaction(hash=%s, ver=%d, fOverwintered=%d, nVersionGroupId=%08x, vin.size=%u, vout.size=%u, nLockTime=%u, nExpiryHeight=%u)\n", GetHash().ToString().substr(0,10), diff --git a/src/primitives/transaction.h b/src/primitives/transaction.h index 683b13543..fb15ffeeb 100644 --- a/src/primitives/transaction.h +++ b/src/primitives/transaction.h @@ -10,16 +10,97 @@ #include "random.h" #include "script/script.h" #include "serialize.h" +#include "streams.h" #include "uint256.h" #include "consensus/consensus.h" #include +#include #include "zcash/NoteEncryption.hpp" #include "zcash/Zcash.h" #include "zcash/JoinSplit.hpp" #include "zcash/Proof.hpp" +// Sapling transaction version +static const int32_t SAPLING_TX_VERSION = 4; +static_assert(SAPLING_TX_VERSION >= SAPLING_MIN_TX_VERSION, + "Sapling tx version must not be lower than minimum"); +static_assert(SAPLING_TX_VERSION <= SAPLING_MAX_TX_VERSION, + "Sapling tx version must not be higher than maximum"); + +static constexpr size_t GROTH_PROOF_SIZE = ( + 48 + // π_A + 96 + // π_B + 48); // π_C +static constexpr size_t SPEND_DESCRIPTION_SIZE = ( + 32 + // cv + 32 + // anchor + 32 + // nullifier + 32 + // rk + GROTH_PROOF_SIZE + + 64); // spendAuthSig +static constexpr size_t OUTPUT_DESCRIPTION_SIZE = ( + 32 + // cv + 32 + // cm + 32 + // ephemeralKey + 580 + // encCiphertext + 80 + // outCiphertext + GROTH_PROOF_SIZE); + +namespace libzcash { + typedef boost::array GrothProof; +} +typedef boost::array SpendDescription; +typedef boost::array OutputDescription; + +template +class SproutProofSerializer : public boost::static_visitor<> +{ + Stream& s; + bool useGroth; + +public: + SproutProofSerializer(Stream& s, bool useGroth) : s(s), useGroth(useGroth) {} + + void operator()(const libzcash::ZCProof& proof) const + { + if (useGroth) { + throw std::ios_base::failure("Invalid Sprout proof for transaction format (expected GrothProof, found PHGRProof)"); + } + ::Serialize(s, proof); + } + + void operator()(const libzcash::GrothProof& proof) const + { + if (!useGroth) { + throw std::ios_base::failure("Invalid Sprout proof for transaction format (expected PHGRProof, found GrothProof)"); + } + ::Serialize(s, proof); + } +}; + +template +inline void SerReadWriteSproutProof(Stream& s, const T& proof, bool useGroth, CSerActionSerialize ser_action) +{ + auto ps = SproutProofSerializer(s, useGroth); + boost::apply_visitor(ps, proof); +} + +template +inline void SerReadWriteSproutProof(Stream& s, T& proof, bool useGroth, CSerActionUnserialize ser_action) +{ + if (useGroth) { + libzcash::GrothProof grothProof; + ::Unserialize(s, grothProof); + proof = grothProof; + } else { + libzcash::ZCProof pghrProof; + ::Unserialize(s, pghrProof); + proof = pghrProof; + } +} + class JSDescription { public: @@ -66,7 +147,7 @@ public: // JoinSplit proof // This is a zk-SNARK which ensures that this JoinSplit is valid. - libzcash::ZCProof proof; + boost::variant proof; JSDescription(): vpub_old(0), vpub_new(0) { } @@ -110,6 +191,12 @@ public: template inline void SerializationOp(Stream& s, Operation ser_action) { + // nVersion is set by CTransaction and CMutableTransaction to + // (tx.fOverwintered << 31) | tx.nVersion + bool fOverwintered = s.GetVersion() >> 31; + int32_t txVersion = s.GetVersion() & 0x7FFFFFFF; + bool useGroth = fOverwintered && txVersion >= SAPLING_TX_VERSION; + READWRITE(vpub_old); READWRITE(vpub_new); READWRITE(anchor); @@ -118,7 +205,7 @@ public: READWRITE(ephemeralKey); READWRITE(randomSeed); READWRITE(macs); - READWRITE(proof); + ::SerReadWriteSproutProof(s, proof, useGroth, ser_action); READWRITE(ciphertexts); } @@ -308,6 +395,10 @@ public: static constexpr uint32_t OVERWINTER_VERSION_GROUP_ID = 0x03C48270; static_assert(OVERWINTER_VERSION_GROUP_ID != 0, "version group id must be non-zero as specified in ZIP 202"); +// Sapling version group id +static constexpr uint32_t SAPLING_VERSION_GROUP_ID = 0x892F2085; +static_assert(SAPLING_VERSION_GROUP_ID != 0, "version group id must be non-zero as specified in ZIP 202"); + struct CMutableTransaction; /** The basic transaction that is broadcasted on the network and contained in @@ -328,12 +419,15 @@ protected: public: typedef boost::array joinsplit_sig_t; + typedef boost::array binding_sig_t; // Transactions that include a list of JoinSplits are >= version 2. static const int32_t SPROUT_MIN_CURRENT_VERSION = 1; static const int32_t SPROUT_MAX_CURRENT_VERSION = 2; static const int32_t OVERWINTER_MIN_CURRENT_VERSION = 3; static const int32_t OVERWINTER_MAX_CURRENT_VERSION = 3; + static const int32_t SAPLING_MIN_CURRENT_VERSION = 4; + static const int32_t SAPLING_MAX_CURRENT_VERSION = 4; static_assert(SPROUT_MIN_CURRENT_VERSION >= SPROUT_MIN_TX_VERSION, "standard rule for tx version should be consistent with network rule"); @@ -345,6 +439,13 @@ public: OVERWINTER_MAX_CURRENT_VERSION >= OVERWINTER_MIN_CURRENT_VERSION), "standard rule for tx version should be consistent with network rule"); + static_assert(SAPLING_MIN_CURRENT_VERSION >= SAPLING_MIN_TX_VERSION, + "standard rule for tx version should be consistent with network rule"); + + static_assert( (SAPLING_MAX_CURRENT_VERSION <= SAPLING_MAX_TX_VERSION && + SAPLING_MAX_CURRENT_VERSION >= SAPLING_MIN_CURRENT_VERSION), + "standard rule for tx version should be consistent with network rule"); + // The local variables are made const to prevent unintended modification // without updating the cached hash value. However, CTransaction is not // actually immutable; deserialization and assignment are implemented, @@ -357,9 +458,13 @@ public: const std::vector vout; const uint32_t nLockTime; const uint32_t nExpiryHeight; + const CAmount valueBalance; + const std::vector vShieldedSpend; + const std::vector vShieldedOutput; const std::vector vjoinsplit; const uint256 joinSplitPubKey; const joinsplit_sig_t joinSplitSig = {{0}}; + const binding_sig_t bindingSig = {{0}}; /** Construct a CTransaction that qualifies as IsNull() */ CTransaction(); @@ -374,14 +479,14 @@ public: template inline void SerializationOp(Stream& s, Operation ser_action) { + uint32_t header; if (ser_action.ForRead()) { // When deserializing, unpack the 4 byte header to extract fOverwintered and nVersion. - uint32_t header; READWRITE(header); *const_cast(&fOverwintered) = header >> 31; *const_cast(&this->nVersion) = header & 0x7FFFFFFF; } else { - uint32_t header = GetHeader(); + header = GetHeader(); READWRITE(header); } if (fOverwintered) { @@ -391,23 +496,36 @@ public: bool isOverwinterV3 = fOverwintered && nVersionGroupId == OVERWINTER_VERSION_GROUP_ID && nVersion == 3; - if (fOverwintered && !isOverwinterV3) { + bool isSaplingV4 = + fOverwintered && + nVersionGroupId == SAPLING_VERSION_GROUP_ID && + nVersion == SAPLING_TX_VERSION; + if (fOverwintered && !(isOverwinterV3 || isSaplingV4)) { throw std::ios_base::failure("Unknown transaction format"); } READWRITE(*const_cast*>(&vin)); READWRITE(*const_cast*>(&vout)); READWRITE(*const_cast(&nLockTime)); - if (isOverwinterV3) { + if (isOverwinterV3 || isSaplingV4) { READWRITE(*const_cast(&nExpiryHeight)); } + if (isSaplingV4) { + READWRITE(*const_cast(&valueBalance)); + READWRITE(*const_cast*>(&vShieldedSpend)); + READWRITE(*const_cast*>(&vShieldedOutput)); + } if (nVersion >= 2) { - READWRITE(*const_cast*>(&vjoinsplit)); + auto os = WithVersion(&s, static_cast(header)); + ::SerReadWrite(os, *const_cast*>(&vjoinsplit), ser_action); if (vjoinsplit.size() > 0) { READWRITE(*const_cast(&joinSplitPubKey)); READWRITE(*const_cast(&joinSplitSig)); } } + if (isSaplingV4 && !(vShieldedSpend.empty() && vShieldedOutput.empty())) { + READWRITE(*const_cast(&bindingSig)); + } if (ser_action.ForRead()) UpdateHash(); } @@ -475,9 +593,13 @@ struct CMutableTransaction std::vector vout; uint32_t nLockTime; uint32_t nExpiryHeight; + CAmount valueBalance; + std::vector vShieldedSpend; + std::vector vShieldedOutput; std::vector vjoinsplit; uint256 joinSplitPubKey; CTransaction::joinsplit_sig_t joinSplitSig = {{0}}; + CTransaction::binding_sig_t bindingSig = {{0}}; CMutableTransaction(); CMutableTransaction(const CTransaction& tx); @@ -486,15 +608,15 @@ struct CMutableTransaction template inline void SerializationOp(Stream& s, Operation ser_action) { + uint32_t header; if (ser_action.ForRead()) { // When deserializing, unpack the 4 byte header to extract fOverwintered and nVersion. - uint32_t header; READWRITE(header); fOverwintered = header >> 31; this->nVersion = header & 0x7FFFFFFF; } else { // When serializing v1 and v2, the 4 byte header is nVersion - uint32_t header = this->nVersion; + header = this->nVersion; // When serializing Overwintered tx, the 4 byte header is the combination of fOverwintered and nVersion if (fOverwintered) { header |= 1 << 31; @@ -508,23 +630,36 @@ struct CMutableTransaction bool isOverwinterV3 = fOverwintered && nVersionGroupId == OVERWINTER_VERSION_GROUP_ID && nVersion == 3; - if (fOverwintered && !isOverwinterV3) { + bool isSaplingV4 = + fOverwintered && + nVersionGroupId == SAPLING_VERSION_GROUP_ID && + nVersion == SAPLING_TX_VERSION; + if (fOverwintered && !(isOverwinterV3 || isSaplingV4)) { throw std::ios_base::failure("Unknown transaction format"); } READWRITE(vin); READWRITE(vout); READWRITE(nLockTime); - if (isOverwinterV3) { + if (isOverwinterV3 || isSaplingV4) { READWRITE(nExpiryHeight); } + if (isSaplingV4) { + READWRITE(valueBalance); + READWRITE(vShieldedSpend); + READWRITE(vShieldedOutput); + } if (nVersion >= 2) { - READWRITE(vjoinsplit); + auto os = WithVersion(&s, static_cast(header)); + ::SerReadWrite(os, vjoinsplit, ser_action); if (vjoinsplit.size() > 0) { READWRITE(joinSplitPubKey); READWRITE(joinSplitSig); } } + if (isSaplingV4 && !(vShieldedSpend.empty() && vShieldedOutput.empty())) { + READWRITE(bindingSig); + } } template diff --git a/src/rpcrawtransaction.cpp b/src/rpcrawtransaction.cpp index d4307b2c5..7e03df07a 100644 --- a/src/rpcrawtransaction.cpp +++ b/src/rpcrawtransaction.cpp @@ -57,6 +57,7 @@ void ScriptPubKeyToJSON(const CScript& scriptPubKey, UniValue& out, bool fInclud UniValue TxJoinSplitToJSON(const CTransaction& tx) { + bool useGroth = tx.fOverwintered && tx.nVersion >= SAPLING_TX_VERSION; UniValue vjoinsplit(UniValue::VARR); for (unsigned int i = 0; i < tx.vjoinsplit.size(); i++) { const JSDescription& jsdescription = tx.vjoinsplit[i]; @@ -95,7 +96,8 @@ UniValue TxJoinSplitToJSON(const CTransaction& tx) { } CDataStream ssProof(SER_NETWORK, PROTOCOL_VERSION); - ssProof << jsdescription.proof; + auto ps = SproutProofSerializer(ssProof, useGroth); + boost::apply_visitor(ps, jsdescription.proof); joinsplit.push_back(Pair("proof", HexStr(ssProof.begin(), ssProof.end()))); { diff --git a/src/script/interpreter.cpp b/src/script/interpreter.cpp index e405b42f3..58534d854 100644 --- a/src/script/interpreter.cpp +++ b/src/script/interpreter.cpp @@ -1091,7 +1091,7 @@ uint256 GetOutputsHash(const CTransaction& txTo) { } uint256 GetJoinSplitsHash(const CTransaction& txTo) { - CBLAKE2bWriter ss(SER_GETHASH, 0, ZCASH_JOINSPLITS_HASH_PERSONALIZATION); + CBLAKE2bWriter ss(SER_GETHASH, static_cast(txTo.GetHeader()), ZCASH_JOINSPLITS_HASH_PERSONALIZATION); for (unsigned int n = 0; n < txTo.vjoinsplit.size(); n++) { ss << txTo.vjoinsplit[n]; } diff --git a/src/streams.h b/src/streams.h index 2fb6c7e01..44feed89e 100644 --- a/src/streams.h +++ b/src/streams.h @@ -22,6 +22,53 @@ #include #include +template +class OverrideStream +{ + Stream* stream; + + const int nType; + const int nVersion; + +public: + OverrideStream(Stream* stream_, int nType_, int nVersion_) : stream(stream_), nType(nType_), nVersion(nVersion_) {} + + template + OverrideStream& operator<<(const T& obj) + { + // Serialize to this stream + ::Serialize(*this, obj); + return (*this); + } + + template + OverrideStream& operator>>(T&& obj) + { + // Unserialize from this stream + ::Unserialize(*this, obj); + return (*this); + } + + void write(const char* pch, size_t nSize) + { + stream->write(pch, nSize); + } + + void read(char* pch, size_t nSize) + { + stream->read(pch, nSize); + } + + int GetVersion() const { return nVersion; } + int GetType() const { return nType; } +}; + +template +OverrideStream WithVersion(S* s, int nVersion) +{ + return OverrideStream(s, s->GetType(), nVersion); +} + /** Double ended buffer combining vector and stream-like interfaces. * * >> and << read and write unformatted data using the above serialization templates.