From b7e75b17af5159cb5cb3233a3944a36879b40ef1 Mon Sep 17 00:00:00 2001 From: Jack Grigg Date: Sun, 15 Apr 2018 08:53:40 -0600 Subject: [PATCH] Implement basic Sapling v4 transaction parser Details of Sapling datatypes will be filled in later; for now, they are treated as binary blobs. Includes code cherry-picked from upstream commit: 7030d9eb47254499bba14f1c00abc6bf493efd91 BIP144: Serialization, hashes, relay (sender side) --- src/hash.h | 3 + src/primitives/transaction.cpp | 86 +++++++++++++++---- src/primitives/transaction.h | 146 ++++++++++++++++++++++++++++++--- src/rpcrawtransaction.cpp | 4 +- src/script/interpreter.cpp | 2 +- src/streams.h | 47 +++++++++++ 6 files changed, 257 insertions(+), 31 deletions(-) diff --git a/src/hash.h b/src/hash.h index 76634bd37..3e56f4a86 100644 --- a/src/hash.h +++ b/src/hash.h @@ -182,6 +182,9 @@ public: personal) == 0); } + int GetType() const { return nType; } + int GetVersion() const { return nVersion; } + CBLAKE2bWriter& write(const char *pch, size_t size) { crypto_generichash_blake2b_update(&state, (const unsigned char*)pch, size); return (*this); diff --git a/src/primitives/transaction.cpp b/src/primitives/transaction.cpp index 36d46377b..4f61b24ea 100644 --- a/src/primitives/transaction.cpp +++ b/src/primitives/transaction.cpp @@ -72,23 +72,50 @@ JSDescription JSDescription::Randomized( ); } +class SproutProofVerifier : public boost::static_visitor +{ + ZCJoinSplit& params; + libzcash::ProofVerifier& verifier; + const uint256& pubKeyHash; + const JSDescription& jsdesc; + +public: + SproutProofVerifier( + ZCJoinSplit& params, + libzcash::ProofVerifier& verifier, + const uint256& pubKeyHash, + const JSDescription& jsdesc + ) : params(params), jsdesc(jsdesc), verifier(verifier), pubKeyHash(pubKeyHash) {} + + bool operator()(const libzcash::ZCProof& proof) const + { + return params.verify( + proof, + verifier, + pubKeyHash, + jsdesc.randomSeed, + jsdesc.macs, + jsdesc.nullifiers, + jsdesc.commitments, + jsdesc.vpub_old, + jsdesc.vpub_new, + jsdesc.anchor + ); + } + + bool operator()(const libzcash::GrothProof& proof) const + { + return false; + } +}; + bool JSDescription::Verify( ZCJoinSplit& params, libzcash::ProofVerifier& verifier, const uint256& pubKeyHash ) const { - return params.verify( - proof, - verifier, - pubKeyHash, - randomSeed, - macs, - nullifiers, - commitments, - vpub_old, - vpub_new, - anchor - ); + auto pv = SproutProofVerifier(params, verifier, pubKeyHash, *this); + boost::apply_visitor(pv, proof); } uint256 JSDescription::h_sig(ZCJoinSplit& params, const uint256& pubKeyHash) const @@ -146,10 +173,12 @@ std::string CTxOut::ToString() const return strprintf("CTxOut(nValue=%d.%08d, scriptPubKey=%s)", nValue / COIN, nValue % COIN, HexStr(scriptPubKey).substr(0, 30)); } -CMutableTransaction::CMutableTransaction() : nVersion(CTransaction::SPROUT_MIN_CURRENT_VERSION), fOverwintered(false), nVersionGroupId(0), nExpiryHeight(0), nLockTime(0) {} +CMutableTransaction::CMutableTransaction() : nVersion(CTransaction::SPROUT_MIN_CURRENT_VERSION), fOverwintered(false), nVersionGroupId(0), nExpiryHeight(0), nLockTime(0), valueBalance(0) {} CMutableTransaction::CMutableTransaction(const CTransaction& tx) : nVersion(tx.nVersion), fOverwintered(tx.fOverwintered), nVersionGroupId(tx.nVersionGroupId), nExpiryHeight(tx.nExpiryHeight), vin(tx.vin), vout(tx.vout), nLockTime(tx.nLockTime), - vjoinsplit(tx.vjoinsplit), joinSplitPubKey(tx.joinSplitPubKey), joinSplitSig(tx.joinSplitSig) + valueBalance(tx.valueBalance), vShieldedSpend(tx.vShieldedSpend), vShieldedOutput(tx.vShieldedOutput), + vjoinsplit(tx.vjoinsplit), joinSplitPubKey(tx.joinSplitPubKey), joinSplitSig(tx.joinSplitSig), + bindingSig(tx.bindingSig) { } @@ -164,11 +193,13 @@ void CTransaction::UpdateHash() const *const_cast(&hash) = SerializeHash(*this); } -CTransaction::CTransaction() : nVersion(CTransaction::SPROUT_MIN_CURRENT_VERSION), fOverwintered(false), nVersionGroupId(0), nExpiryHeight(0), vin(), vout(), nLockTime(0), vjoinsplit(), joinSplitPubKey(), joinSplitSig() { } +CTransaction::CTransaction() : nVersion(CTransaction::SPROUT_MIN_CURRENT_VERSION), fOverwintered(false), nVersionGroupId(0), nExpiryHeight(0), vin(), vout(), nLockTime(0), valueBalance(0), vShieldedSpend(), vShieldedOutput(), vjoinsplit(), joinSplitPubKey(), joinSplitSig(), bindingSig() { } CTransaction::CTransaction(const CMutableTransaction &tx) : nVersion(tx.nVersion), fOverwintered(tx.fOverwintered), nVersionGroupId(tx.nVersionGroupId), nExpiryHeight(tx.nExpiryHeight), vin(tx.vin), vout(tx.vout), nLockTime(tx.nLockTime), - vjoinsplit(tx.vjoinsplit), joinSplitPubKey(tx.joinSplitPubKey), joinSplitSig(tx.joinSplitSig) + valueBalance(tx.valueBalance), vShieldedSpend(tx.vShieldedSpend), vShieldedOutput(tx.vShieldedOutput), + vjoinsplit(tx.vjoinsplit), joinSplitPubKey(tx.joinSplitPubKey), joinSplitSig(tx.joinSplitSig), + bindingSig(tx.bindingSig) { UpdateHash(); } @@ -179,13 +210,17 @@ CTransaction::CTransaction( const CMutableTransaction &tx, bool evilDeveloperFlag) : nVersion(tx.nVersion), fOverwintered(tx.fOverwintered), nVersionGroupId(tx.nVersionGroupId), nExpiryHeight(tx.nExpiryHeight), vin(tx.vin), vout(tx.vout), nLockTime(tx.nLockTime), - vjoinsplit(tx.vjoinsplit), joinSplitPubKey(tx.joinSplitPubKey), joinSplitSig(tx.joinSplitSig) + valueBalance(tx.valueBalance), vShieldedSpend(tx.vShieldedSpend), vShieldedOutput(tx.vShieldedOutput), + vjoinsplit(tx.vjoinsplit), joinSplitPubKey(tx.joinSplitPubKey), joinSplitSig(tx.joinSplitSig), + bindingSig(tx.bindingSig) { assert(evilDeveloperFlag); } CTransaction::CTransaction(CMutableTransaction &&tx) : nVersion(tx.nVersion), fOverwintered(tx.fOverwintered), nVersionGroupId(tx.nVersionGroupId), vin(std::move(tx.vin)), vout(std::move(tx.vout)), nLockTime(tx.nLockTime), nExpiryHeight(tx.nExpiryHeight), + valueBalance(tx.valueBalance), + vShieldedSpend(std::move(tx.vShieldedSpend)), vShieldedOutput(std::move(tx.vShieldedOutput)), vjoinsplit(std::move(tx.vjoinsplit)), joinSplitPubKey(std::move(tx.joinSplitPubKey)), joinSplitSig(std::move(tx.joinSplitSig)) { @@ -200,9 +235,13 @@ CTransaction& CTransaction::operator=(const CTransaction &tx) { *const_cast*>(&vout) = tx.vout; *const_cast(&nLockTime) = tx.nLockTime; *const_cast(&nExpiryHeight) = tx.nExpiryHeight; + *const_cast(&valueBalance) = tx.valueBalance; + *const_cast*>(&vShieldedSpend) = tx.vShieldedSpend; + *const_cast*>(&vShieldedOutput) = tx.vShieldedOutput; *const_cast*>(&vjoinsplit) = tx.vjoinsplit; *const_cast(&joinSplitPubKey) = tx.joinSplitPubKey; *const_cast(&joinSplitSig) = tx.joinSplitSig; + *const_cast(&bindingSig) = tx.bindingSig; *const_cast(&hash) = tx.hash; return *this; } @@ -279,6 +318,19 @@ std::string CTransaction::ToString() const vin.size(), vout.size(), nLockTime); + } else if (nVersion >= SAPLING_MIN_TX_VERSION) { + str += strprintf("CTransaction(hash=%s, ver=%d, fOverwintered=%d, nVersionGroupId=%08x, vin.size=%u, vout.size=%u, nLockTime=%u, nExpiryHeight=%u, valueBalance=%u, vShieldedSpend.size=%u, vShieldedOutput.size=%u)\n", + GetHash().ToString().substr(0,10), + nVersion, + fOverwintered, + nVersionGroupId, + vin.size(), + vout.size(), + nLockTime, + nExpiryHeight, + valueBalance, + vShieldedSpend.size(), + vShieldedOutput.size()); } else if (nVersion >= 3) { str += strprintf("CTransaction(hash=%s, ver=%d, fOverwintered=%d, nVersionGroupId=%08x, vin.size=%u, vout.size=%u, nLockTime=%u, nExpiryHeight=%u)\n", GetHash().ToString().substr(0,10), diff --git a/src/primitives/transaction.h b/src/primitives/transaction.h index 1ec961e1c..4c8aaa4ac 100644 --- a/src/primitives/transaction.h +++ b/src/primitives/transaction.h @@ -10,16 +10,97 @@ #include "random.h" #include "script/script.h" #include "serialize.h" +#include "streams.h" #include "uint256.h" #include "consensus/consensus.h" #include +#include #include "zcash/NoteEncryption.hpp" #include "zcash/Zcash.h" #include "zcash/JoinSplit.hpp" #include "zcash/Proof.hpp" +// Sapling transaction version +static const int32_t SAPLING_TX_VERSION = 4; +static_assert(SAPLING_TX_VERSION >= SAPLING_MIN_TX_VERSION, + "Sapling tx version must not be lower than minimum"); +static_assert(SAPLING_TX_VERSION <= SAPLING_MAX_TX_VERSION, + "Sapling tx version must not be higher than maximum"); + +static constexpr size_t GROTH_PROOF_SIZE = ( + 48 + // π_A + 96 + // π_B + 48); // π_C +static constexpr size_t SPEND_DESCRIPTION_SIZE = ( + 32 + // cv + 32 + // anchor + 32 + // nullifier + 32 + // rk + GROTH_PROOF_SIZE + + 64); // spendAuthSig +static constexpr size_t OUTPUT_DESCRIPTION_SIZE = ( + 32 + // cv + 32 + // cm + 32 + // ephemeralKey + 580 + // encCiphertext + 80 + // outCiphertext + GROTH_PROOF_SIZE); + +namespace libzcash { + typedef boost::array GrothProof; +} +typedef boost::array SpendDescription; +typedef boost::array OutputDescription; + +template +class SproutProofSerializer : public boost::static_visitor<> +{ + Stream& s; + bool useGroth; + +public: + SproutProofSerializer(Stream& s, bool useGroth) : s(s), useGroth(useGroth) {} + + void operator()(const libzcash::ZCProof& proof) const + { + if (useGroth) { + throw std::ios_base::failure("Invalid Sprout proof for transaction format (expected GrothProof, found PHGRProof)"); + } + ::Serialize(s, proof); + } + + void operator()(const libzcash::GrothProof& proof) const + { + if (!useGroth) { + throw std::ios_base::failure("Invalid Sprout proof for transaction format (expected PHGRProof, found GrothProof)"); + } + ::Serialize(s, proof); + } +}; + +template +inline void SerReadWriteSproutProof(Stream& s, const T& proof, bool useGroth, CSerActionSerialize ser_action) +{ + auto ps = SproutProofSerializer(s, useGroth); + boost::apply_visitor(ps, proof); +} + +template +inline void SerReadWriteSproutProof(Stream& s, T& proof, bool useGroth, CSerActionUnserialize ser_action) +{ + if (useGroth) { + libzcash::GrothProof grothProof; + ::Unserialize(s, grothProof); + proof = grothProof; + } else { + libzcash::ZCProof pghrProof; + ::Unserialize(s, pghrProof); + proof = pghrProof; + } +} + class JSDescription { public: @@ -66,7 +147,7 @@ public: // JoinSplit proof // This is a zk-SNARK which ensures that this JoinSplit is valid. - libzcash::ZCProof proof; + boost::variant proof; JSDescription(): vpub_old(0), vpub_new(0) { } @@ -110,6 +191,12 @@ public: template inline void SerializationOp(Stream& s, Operation ser_action) { + // nVersion is set by CTransaction and CMutableTransaction to + // (tx.fOverwintered << 31) | tx.nVersion + bool fOverwintered = s.GetVersion() >> 31; + int32_t txVersion = s.GetVersion() & 0x7FFFFFFF; + bool useGroth = fOverwintered && txVersion >= SAPLING_TX_VERSION; + READWRITE(vpub_old); READWRITE(vpub_new); READWRITE(anchor); @@ -118,7 +205,7 @@ public: READWRITE(ephemeralKey); READWRITE(randomSeed); READWRITE(macs); - READWRITE(proof); + ::SerReadWriteSproutProof(s, proof, useGroth, ser_action); READWRITE(ciphertexts); } @@ -332,6 +419,7 @@ protected: public: typedef boost::array joinsplit_sig_t; + typedef boost::array binding_sig_t; // Transactions that include a list of JoinSplits are >= version 2. static const int32_t SPROUT_MIN_CURRENT_VERSION = 1; @@ -361,9 +449,13 @@ public: const std::vector vout; const uint32_t nLockTime; const uint32_t nExpiryHeight; + const CAmount valueBalance; + const std::vector vShieldedSpend; + const std::vector vShieldedOutput; const std::vector vjoinsplit; const uint256 joinSplitPubKey; const joinsplit_sig_t joinSplitSig = {{0}}; + const binding_sig_t bindingSig = {{0}}; /** Construct a CTransaction that qualifies as IsNull() */ CTransaction(); @@ -378,14 +470,14 @@ public: template inline void SerializationOp(Stream& s, Operation ser_action) { + uint32_t header; if (ser_action.ForRead()) { // When deserializing, unpack the 4 byte header to extract fOverwintered and nVersion. - uint32_t header; READWRITE(header); *const_cast(&fOverwintered) = header >> 31; *const_cast(&this->nVersion) = header & 0x7FFFFFFF; } else { - uint32_t header = GetHeader(); + header = GetHeader(); READWRITE(header); } if (fOverwintered) { @@ -395,23 +487,36 @@ public: bool isOverwinterV3 = fOverwintered && nVersionGroupId == OVERWINTER_VERSION_GROUP_ID && nVersion == 3; - if (fOverwintered && !isOverwinterV3) { + bool isSaplingV4 = + fOverwintered && + nVersionGroupId == SAPLING_VERSION_GROUP_ID && + nVersion == SAPLING_TX_VERSION; + if (fOverwintered && !(isOverwinterV3 || isSaplingV4)) { throw std::ios_base::failure("Unknown transaction format"); } READWRITE(*const_cast*>(&vin)); READWRITE(*const_cast*>(&vout)); READWRITE(*const_cast(&nLockTime)); - if (isOverwinterV3) { + if (isOverwinterV3 || isSaplingV4) { READWRITE(*const_cast(&nExpiryHeight)); } + if (isSaplingV4) { + READWRITE(*const_cast(&valueBalance)); + READWRITE(*const_cast*>(&vShieldedSpend)); + READWRITE(*const_cast*>(&vShieldedOutput)); + } if (nVersion >= 2) { - READWRITE(*const_cast*>(&vjoinsplit)); + auto os = WithVersion(&s, static_cast(header)); + ::SerReadWrite(os, *const_cast*>(&vjoinsplit), ser_action); if (vjoinsplit.size() > 0) { READWRITE(*const_cast(&joinSplitPubKey)); READWRITE(*const_cast(&joinSplitSig)); } } + if (isSaplingV4 && !(vShieldedSpend.empty() && vShieldedOutput.empty())) { + READWRITE(*const_cast(&bindingSig)); + } if (ser_action.ForRead()) UpdateHash(); } @@ -479,9 +584,13 @@ struct CMutableTransaction std::vector vout; uint32_t nLockTime; uint32_t nExpiryHeight; + CAmount valueBalance; + std::vector vShieldedSpend; + std::vector vShieldedOutput; std::vector vjoinsplit; uint256 joinSplitPubKey; CTransaction::joinsplit_sig_t joinSplitSig = {{0}}; + CTransaction::binding_sig_t bindingSig = {{0}}; CMutableTransaction(); CMutableTransaction(const CTransaction& tx); @@ -490,15 +599,15 @@ struct CMutableTransaction template inline void SerializationOp(Stream& s, Operation ser_action) { + uint32_t header; if (ser_action.ForRead()) { // When deserializing, unpack the 4 byte header to extract fOverwintered and nVersion. - uint32_t header; READWRITE(header); fOverwintered = header >> 31; this->nVersion = header & 0x7FFFFFFF; } else { // When serializing v1 and v2, the 4 byte header is nVersion - uint32_t header = this->nVersion; + header = this->nVersion; // When serializing Overwintered tx, the 4 byte header is the combination of fOverwintered and nVersion if (fOverwintered) { header |= 1 << 31; @@ -512,23 +621,36 @@ struct CMutableTransaction bool isOverwinterV3 = fOverwintered && nVersionGroupId == OVERWINTER_VERSION_GROUP_ID && nVersion == 3; - if (fOverwintered && !isOverwinterV3) { + bool isSaplingV4 = + fOverwintered && + nVersionGroupId == SAPLING_VERSION_GROUP_ID && + nVersion == SAPLING_TX_VERSION; + if (fOverwintered && !(isOverwinterV3 || isSaplingV4)) { throw std::ios_base::failure("Unknown transaction format"); } READWRITE(vin); READWRITE(vout); READWRITE(nLockTime); - if (isOverwinterV3) { + if (isOverwinterV3 || isSaplingV4) { READWRITE(nExpiryHeight); } + if (isSaplingV4) { + READWRITE(valueBalance); + READWRITE(vShieldedSpend); + READWRITE(vShieldedOutput); + } if (nVersion >= 2) { - READWRITE(vjoinsplit); + auto os = WithVersion(&s, static_cast(header)); + ::SerReadWrite(os, vjoinsplit, ser_action); if (vjoinsplit.size() > 0) { READWRITE(joinSplitPubKey); READWRITE(joinSplitSig); } } + if (isSaplingV4 && !(vShieldedSpend.empty() && vShieldedOutput.empty())) { + READWRITE(bindingSig); + } } template diff --git a/src/rpcrawtransaction.cpp b/src/rpcrawtransaction.cpp index d4307b2c5..7e03df07a 100644 --- a/src/rpcrawtransaction.cpp +++ b/src/rpcrawtransaction.cpp @@ -57,6 +57,7 @@ void ScriptPubKeyToJSON(const CScript& scriptPubKey, UniValue& out, bool fInclud UniValue TxJoinSplitToJSON(const CTransaction& tx) { + bool useGroth = tx.fOverwintered && tx.nVersion >= SAPLING_TX_VERSION; UniValue vjoinsplit(UniValue::VARR); for (unsigned int i = 0; i < tx.vjoinsplit.size(); i++) { const JSDescription& jsdescription = tx.vjoinsplit[i]; @@ -95,7 +96,8 @@ UniValue TxJoinSplitToJSON(const CTransaction& tx) { } CDataStream ssProof(SER_NETWORK, PROTOCOL_VERSION); - ssProof << jsdescription.proof; + auto ps = SproutProofSerializer(ssProof, useGroth); + boost::apply_visitor(ps, jsdescription.proof); joinsplit.push_back(Pair("proof", HexStr(ssProof.begin(), ssProof.end()))); { diff --git a/src/script/interpreter.cpp b/src/script/interpreter.cpp index e405b42f3..790ba1e73 100644 --- a/src/script/interpreter.cpp +++ b/src/script/interpreter.cpp @@ -1165,7 +1165,7 @@ uint256 SignatureHash( memcpy(personalization, "ZcashSigHash", 12); memcpy(personalization+12, &leConsensusBranchId, 4); - CBLAKE2bWriter ss(SER_GETHASH, 0, personalization); + CBLAKE2bWriter ss(SER_GETHASH, static_cast(txTo.GetHeader()), personalization); // Header ss << txTo.GetHeader(); // Version group ID diff --git a/src/streams.h b/src/streams.h index 2fb6c7e01..44feed89e 100644 --- a/src/streams.h +++ b/src/streams.h @@ -22,6 +22,53 @@ #include #include +template +class OverrideStream +{ + Stream* stream; + + const int nType; + const int nVersion; + +public: + OverrideStream(Stream* stream_, int nType_, int nVersion_) : stream(stream_), nType(nType_), nVersion(nVersion_) {} + + template + OverrideStream& operator<<(const T& obj) + { + // Serialize to this stream + ::Serialize(*this, obj); + return (*this); + } + + template + OverrideStream& operator>>(T&& obj) + { + // Unserialize from this stream + ::Unserialize(*this, obj); + return (*this); + } + + void write(const char* pch, size_t nSize) + { + stream->write(pch, nSize); + } + + void read(char* pch, size_t nSize) + { + stream->read(pch, nSize); + } + + int GetVersion() const { return nVersion; } + int GetType() const { return nType; } +}; + +template +OverrideStream WithVersion(S* s, int nVersion) +{ + return OverrideStream(s, s->GetType(), nVersion); +} + /** Double ended buffer combining vector and stream-like interfaces. * * >> and << read and write unformatted data using the above serialization templates.