diff --git a/src/zcash/Proof.cpp b/src/zcash/Proof.cpp index d4fd90b69..ab7c64483 100644 --- a/src/zcash/Proof.cpp +++ b/src/zcash/Proof.cpp @@ -20,43 +20,16 @@ BOOST_STATIC_ASSERT(sizeof(mp_limb_t) == 8); namespace libzcash { +// FE2IP as defined in the protocol spec and IEEE Std 1363a-2004. bigint<8> fq2_to_bigint(const curve_Fq2 &e) { auto modq = curve_Fq::field_char(); auto c0 = e.c0.as_bigint(); auto c1 = e.c1.as_bigint(); - // TODO: It should be possible to use libsnark's bigint - // to do this stuff. - - bigint<8> res; - // Multiply c1 by modq - mpn_mul(res.data, c1.data, 4, modq.data, 4); - // Add c0 - mpn_add(res.data, res.data, 8, c0.data, 4); - - return res; -} - -// Compares two bigints, returning 0 if equal, 1 if a > b, and -1 if a < b -template -int cmp_bigint(const bigint &a, const bigint &b) -{ - for (ssize_t i = LIMBS-1; i >= 0; i--) { - if (a.data[i] < b.data[i]) { - return -1; - } else if (a.data[i] > b.data[i]) { - return 1; - } - } - - return 0; -} - -// Returns whether a > b -bool cmp_fq2(const curve_Fq2 &a, const curve_Fq2 &b) -{ - return cmp_bigint(fq2_to_bigint(a), fq2_to_bigint(b)) > 0; + bigint<8> temp = c1 * modq; + temp += c0; + return temp; } // Writes a bigint in big endian @@ -87,7 +60,7 @@ bigint read_bigint(const base_blob<8 * LIMBS * sizeof(mp_limb_t)> &blob) template<> Fq::Fq(curve_Fq element) : data() { - write_bigint(data, element.as_bigint()); + write_bigint<4>(data, element.as_bigint()); } template<> @@ -97,9 +70,7 @@ curve_Fq Fq::to_libsnark_fq() const // Check that the integer is smaller than the modulus auto modq = curve_Fq::field_char(); - if (cmp_bigint(element_bigint, modq) != -1) { - throw std::logic_error("element is not in Fq"); - } + element_bigint.limit(modq, "element is not in Fq"); return curve_Fq(element_bigint); } @@ -107,33 +78,18 @@ curve_Fq Fq::to_libsnark_fq() const template<> Fq2::Fq2(curve_Fq2 element) : data() { - write_bigint(data, fq2_to_bigint(element)); + write_bigint<8>(data, fq2_to_bigint(element)); } template<> curve_Fq2 Fq2::to_libsnark_fq2() const { - auto modq = curve_Fq::field_char(); - auto combined = read_bigint<8>(data); - - // TODO: It should be possible to use libsnark's bigint - // to do this stuff. - + bigint<4> modq = curve_Fq::field_char(); + bigint<8> combined = read_bigint<8>(data); bigint<5> res; bigint<4> c0; - - mpn_tdiv_qr(res.data, c0.data, 0, combined.data, 8, modq.data, 4); - - if (res.data[4] != 0) { - throw std::logic_error("element is not in Fq2"); - } - - bigint<4> c1; - memcpy(c1.data, res.data, 4 * sizeof(mp_limb_t)); - - if (cmp_bigint(c1, modq) != -1) { - throw std::logic_error("element is not in Fq2"); - } + bigint<8>::div_qr(res, c0, combined, modq); + bigint<4> c1 = res.shorten(modq, "element is not in Fq2"); return curve_Fq2(curve_Fq(c0), curve_Fq(c1)); } @@ -183,7 +139,7 @@ CompressedG2::CompressedG2(curve_G2 point) point.to_affine_coordinates(); x = Fq2(point.X); - y_gt = cmp_fq2(point.Y, -(point.Y)); + y_gt = fq2_to_bigint(point.Y) > fq2_to_bigint(-(point.Y)); } template<> @@ -195,7 +151,7 @@ curve_G2 CompressedG2::to_libsnark_g2() const auto y_coordinate = ((x_coordinate.squared() * x_coordinate) + alt_bn128_twist_coeff_b).sqrt(); auto y_coordinate_neg = -y_coordinate; - if (cmp_fq2(y_coordinate, y_coordinate_neg) != y_gt) { + if ((fq2_to_bigint(y_coordinate) > fq2_to_bigint(y_coordinate_neg)) != y_gt) { y_coordinate = y_coordinate_neg; }