diff --git a/src/crypto/equihash.cpp b/src/crypto/equihash.cpp index 41718918a..ae6585fb0 100644 --- a/src/crypto/equihash.cpp +++ b/src/crypto/equihash.cpp @@ -48,6 +48,19 @@ int Equihash::InitialiseState(eh_HashState& base_state) personalization); } +eh_trunc TruncateIndex(eh_index i, unsigned int ilen) +{ + // Truncate to 8 bits + assert(sizeof(eh_trunc) == 1); + return (i >> (ilen - 8)) & 0xff; +} + +eh_index UntruncateIndex(eh_trunc t, eh_index r, unsigned int ilen) +{ + eh_index i{t}; + return (i << (ilen - 8)) | r; +} + StepRow::StepRow(unsigned int n, const eh_HashState& base_state, eh_index i) : hash {new unsigned char[n/8]}, len {n/8} @@ -152,6 +165,47 @@ bool DistinctIndices(const FullStepRow& a, const FullStepRow& b) return true; } +bool IsValidBranch(const FullStepRow& a, const unsigned int ilen, const eh_trunc t) +{ + return TruncateIndex(a.indices[0], ilen) == t; +} + +TruncatedStepRow::TruncatedStepRow(unsigned int n, const eh_HashState& base_state, eh_index i, unsigned int ilen) : + StepRow {n, base_state, i}, + indices {TruncateIndex(i, ilen)} +{ + assert(indices.size() == 1); +} + +TruncatedStepRow& TruncatedStepRow::operator=(const TruncatedStepRow& a) +{ + unsigned char* p = new unsigned char[a.len]; + std::copy(a.hash, a.hash+a.len, p); + delete[] hash; + hash = p; + len = a.len; + indices = a.indices; + return *this; +} + +TruncatedStepRow& TruncatedStepRow::operator^=(const TruncatedStepRow& a) +{ + if (a.len != len) { + throw std::invalid_argument("Hash length differs"); + } + if (a.indices.size() != indices.size()) { + throw std::invalid_argument("Number of indices differs"); + } + unsigned char* p = new unsigned char[len]; + for (int i = 0; i < len; i++) + p[i] = hash[i] ^ a.hash[i]; + delete[] hash; + hash = p; + indices.reserve(indices.size() + a.indices.size()); + indices.insert(indices.end(), a.indices.begin(), a.indices.end()); + return *this; +} + Equihash::Equihash(unsigned int n, unsigned int k) : n(n), k(k) { @@ -244,6 +298,207 @@ std::set> Equihash::BasicSolve(const eh_HashState& base_st return solns; } +void CollideBranches(std::vector& X, const unsigned int clen, const unsigned int ilen, const eh_trunc lt, const eh_trunc rt) +{ + int i = 0; + int posFree = 0; + std::vector Xc; + while (i < X.size() - 1) { + // 2b) Find next set of unordered pairs with collisions on the next n/(k+1) bits + int j = 1; + while (i+j < X.size() && + HasCollision(X[i], X[i+j], clen)) { + j++; + } + + // 2c) Calculate tuples (X_i ^ X_j, (i, j)) + for (int l = 0; l < j - 1; l++) { + for (int m = l + 1; m < j; m++) { + if (DistinctIndices(X[i+l], X[i+m])) { + if (IsValidBranch(X[i+l], ilen, lt) && IsValidBranch(X[i+m], ilen, rt)) { + Xc.push_back(X[i+l] ^ X[i+m]); + Xc.back().TrimHash(clen); + } else if (IsValidBranch(X[i+m], ilen, lt) && IsValidBranch(X[i+l], ilen, rt)) { + Xc.push_back(X[i+m] ^ X[i+l]); + Xc.back().TrimHash(clen); + } + } + } + } + + // 2d) Store tuples on the table in-place if possible + while (posFree < i+j && Xc.size() > 0) { + X[posFree++] = Xc.back(); + Xc.pop_back(); + } + + i += j; + } + + // 2e) Handle edge case where final table entry has no collision + while (posFree < X.size() && Xc.size() > 0) { + X[posFree++] = Xc.back(); + Xc.pop_back(); + } + + if (Xc.size() > 0) { + // 2f) Add overflow to end of table + X.insert(X.end(), Xc.begin(), Xc.end()); + } else if (posFree < X.size()) { + // 2g) Remove empty space at the end + X.erase(X.begin()+posFree, X.end()); + X.shrink_to_fit(); + } +} + +std::set> Equihash::OptimisedSolve(const eh_HashState& base_state) +{ + assert(CollisionBitLength() + 1 < 8*sizeof(eh_index)); + eh_index init_size { 1 << (CollisionBitLength() + 1) }; + + // First run the algorithm with truncated indices + + std::vector> partialSolns; + { + + // 1) Generate first list + LogPrint("pow", "Generating first list\n"); + std::vector Xt; + Xt.reserve(init_size); + for (eh_index i = 0; i < init_size; i++) { + Xt.emplace_back(n, base_state, i, CollisionBitLength() + 1); + } + + // 3) Repeat step 2 until 2n/(k+1) bits remain + for (int r = 1; r < k && Xt.size() > 0; r++) { + LogPrint("pow", "Round %d:\n", r); + // 2a) Sort the list + LogPrint("pow", "- Sorting list\n"); + std::sort(Xt.begin(), Xt.end()); + + LogPrint("pow", "- Finding collisions\n"); + int i = 0; + int posFree = 0; + std::vector Xc; + while (i < Xt.size() - 1) { + // 2b) Find next set of unordered pairs with collisions on the next n/(k+1) bits + int j = 1; + while (i+j < Xt.size() && + HasCollision(Xt[i], Xt[i+j], CollisionByteLength())) { + j++; + } + + // 2c) Calculate tuples (X_i ^ X_j, (i, j)) + for (int l = 0; l < j - 1; l++) { + for (int m = l + 1; m < j; m++) { + // We truncated, so don't check for distinct indices here + Xc.push_back(Xt[i+l] ^ Xt[i+m]); + Xc.back().TrimHash(CollisionByteLength()); + } + } + + // 2d) Store tuples on the table in-place if possible + while (posFree < i+j && Xc.size() > 0) { + Xt[posFree++] = Xc.back(); + Xc.pop_back(); + } + + i += j; + } + + // 2e) Handle edge case where final table entry has no collision + while (posFree < Xt.size() && Xc.size() > 0) { + Xt[posFree++] = Xc.back(); + Xc.pop_back(); + } + + if (Xc.size() > 0) { + // 2f) Add overflow to end of table + Xt.insert(Xt.end(), Xc.begin(), Xc.end()); + } else if (posFree < Xt.size()) { + // 2g) Remove empty space at the end + Xt.erase(Xt.begin()+posFree, Xt.end()); + Xt.shrink_to_fit(); + } + } + + // k+1) Find a collision on last 2n(k+1) bits + LogPrint("pow", "Final round:\n"); + if (Xt.size() > 1) { + LogPrint("pow", "- Sorting list\n"); + std::sort(Xt.begin(), Xt.end()); + LogPrint("pow", "- Finding collisions\n"); + for (int i = 0; i < Xt.size() - 1; i++) { + TruncatedStepRow res = Xt[i] ^ Xt[i+1]; + if (res.IsZero()) { + partialSolns.push_back(res.GetPartialSolution()); + } + } + } else + LogPrint("pow", "- List is empty\n"); + + } // Ensure Xt goes out of scope and is destroyed + + LogPrint("pow", "Found %d partial solutions\n", partialSolns.size()); + + // Now for each solution run the algorithm again to recreate the indices + LogPrint("pow", "Culling solutions\n"); + std::set> solns; + eh_index recreate_size { UntruncateIndex(1, 0, CollisionBitLength() + 1) }; + int invalidCount = 0; + for (std::vector partialSoln : partialSolns) { + // 1) Generate first list of possibilities + std::vector> X; + X.reserve(partialSoln.size()); + for (int i = 0; i < partialSoln.size(); i++) { + std::vector ic; + ic.reserve(recreate_size); + for (eh_index j = 0; j < recreate_size; j++) { + eh_index newIndex { UntruncateIndex(partialSoln[i], j, CollisionBitLength() + 1) }; + ic.emplace_back(n, base_state, newIndex); + } + X.push_back(ic); + } + + // 3) Repeat step 2 for each level of the tree + for (int r = 0; X.size() > 1; r++) { + std::vector> Xc; + Xc.reserve(X.size()/2); + + // 2a) For each pair of lists: + for (int v = 0; v < X.size(); v += 2) { + // 2b) Merge the lists + std::vector ic(X[v]); + ic.reserve(X[v].size() + X[v+1].size()); + ic.insert(ic.end(), X[v+1].begin(), X[v+1].end()); + std::sort(ic.begin(), ic.end()); + CollideBranches(ic, CollisionByteLength(), CollisionBitLength() + 1, partialSoln[(1< soln) { eh_index soln_size { 1u << k }; diff --git a/src/crypto/equihash.h b/src/crypto/equihash.h index 2603024ae..32ae011c5 100644 --- a/src/crypto/equihash.h +++ b/src/crypto/equihash.h @@ -17,6 +17,7 @@ typedef crypto_generichash_blake2b_state eh_HashState; typedef uint32_t eh_index; +typedef uint8_t eh_trunc; struct invalid_params { }; @@ -66,9 +67,33 @@ public: } friend bool DistinctIndices(const FullStepRow& a, const FullStepRow& b); + friend bool IsValidBranch(const FullStepRow& a, const unsigned int ilen, const eh_trunc t); }; bool DistinctIndices(const FullStepRow& a, const FullStepRow& b); +bool IsValidBranch(const FullStepRow& a, const unsigned int ilen, const eh_trunc t); + +class TruncatedStepRow : public StepRow +{ +private: + std::vector indices; + +public: + TruncatedStepRow(unsigned int n, const eh_HashState& base_state, eh_index i, unsigned int ilen); + ~TruncatedStepRow() { } + + TruncatedStepRow(const TruncatedStepRow& a) : StepRow {a}, indices(a.indices) { } + TruncatedStepRow& operator=(const TruncatedStepRow& a); + TruncatedStepRow& operator^=(const TruncatedStepRow& a); + + bool IndicesBefore(const TruncatedStepRow& a) { return indices[0] < a.indices[0]; } + std::vector GetPartialSolution() { return std::vector(indices); } + + friend inline const TruncatedStepRow operator^(const TruncatedStepRow& a, const TruncatedStepRow& b) { + if (a.indices[0] < b.indices[0]) { return TruncatedStepRow(a) ^= b; } + else { return TruncatedStepRow(b) ^= a; } + } +}; class Equihash { @@ -84,6 +109,7 @@ public: int InitialiseState(eh_HashState& base_state); std::set> BasicSolve(const eh_HashState& base_state); + std::set> OptimisedSolve(const eh_HashState& base_state); bool IsValidSolution(const eh_HashState& base_state, std::vector soln); }; diff --git a/src/test/equihash_tests.cpp b/src/test/equihash_tests.cpp index 2066af919..0d28c2771 100644 --- a/src/test/equihash_tests.cpp +++ b/src/test/equihash_tests.cpp @@ -56,6 +56,15 @@ void TestEquihashSolvers(unsigned int n, unsigned int k, const std::string &I, c PrintSolutions(strm, ret); BOOST_TEST_MESSAGE(strm.str()); BOOST_CHECK(ret == solns); + + // The optimised solver should have the exact same result + std::set> retOpt = eh.OptimisedSolve(state); + BOOST_TEST_MESSAGE("[Optimised] Number of solutions: " << retOpt.size()); + strm.str(""); + PrintSolutions(strm, retOpt); + BOOST_TEST_MESSAGE(strm.str()); + BOOST_CHECK(retOpt == solns); + BOOST_CHECK(retOpt == ret); } void TestEquihashValidator(unsigned int n, unsigned int k, const std::string &I, const arith_uint256 &nonce, std::vector soln, bool expected) {