diff --git a/src/Makefile.gtest.include b/src/Makefile.gtest.include index d68c9a98a..d247e3e96 100644 --- a/src/Makefile.gtest.include +++ b/src/Makefile.gtest.include @@ -17,6 +17,7 @@ zcash_gtest_SOURCES = \ gtest/test_noteencryption.cpp \ gtest/test_merkletree.cpp \ gtest/test_pow.cpp \ + gtest/test_random.cpp \ gtest/test_rpc.cpp \ gtest/test_circuit.cpp \ gtest/test_txid.cpp \ diff --git a/src/gtest/test_random.cpp b/src/gtest/test_random.cpp new file mode 100644 index 000000000..e4a2f5c41 --- /dev/null +++ b/src/gtest/test_random.cpp @@ -0,0 +1,34 @@ +#include + +#include "random.h" + +int GenZero(int n) +{ + return 0; +} + +int GenMax(int n) +{ + return n-1; +} + +TEST(Random, MappedShuffle) { + std::vector a {8, 4, 6, 3, 5}; + std::vector m {0, 1, 2, 3, 4}; + + auto a1 = a; + auto m1 = m; + MappedShuffle(a1.begin(), m1.begin(), a1.size(), GenZero); + std::vector ea1 {4, 6, 3, 5, 8}; + std::vector em1 {1, 2, 3, 4, 0}; + EXPECT_EQ(ea1, a1); + EXPECT_EQ(em1, m1); + + auto a2 = a; + auto m2 = m; + MappedShuffle(a2.begin(), m2.begin(), a2.size(), GenMax); + std::vector ea2 {8, 4, 6, 3, 5}; + std::vector em2 {0, 1, 2, 3, 4}; + EXPECT_EQ(ea2, a2); + EXPECT_EQ(em2, m2); +} diff --git a/src/random.h b/src/random.h index 1a2d3e8ee..11be2ee37 100644 --- a/src/random.h +++ b/src/random.h @@ -8,6 +8,7 @@ #include "uint256.h" +#include #include /** @@ -24,6 +25,29 @@ uint64_t GetRand(uint64_t nMax); int GetRandInt(int nMax); uint256 GetRandHash(); +/** + * Rearranges the elements in the range [first,first+len) randomly, assuming + * that gen is a uniform random number generator. Follows the same algorithm as + * std::shuffle in C++11 (a Durstenfeld shuffle). + * + * The elements in the range [mapFirst,mapFirst+len) are rearranged according to + * the same permutation, enabling the permutation to be tracked by the caller. + * + * gen takes an integer n and produces a uniform random output in [0,n). + */ +template +void MappedShuffle(RandomAccessIterator first, + MapRandomAccessIterator mapFirst, + size_t len, + std::function gen) +{ + for (size_t i = len-1; i > 0; --i) { + auto r = gen(i+1); + std::swap(first[i], first[r]); + std::swap(mapFirst[i], mapFirst[r]); + } +} + /** * Seed insecure_rand using the random pool. * @param Deterministic Use a deterministic seed