Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
uintx_impl.hpp
Go to the documentation of this file.
1// === AUDIT STATUS ===
2// internal: { status: not started, auditors: [], date: YYYY-MM-DD }
3// external_1: { status: not started, auditors: [], date: YYYY-MM-DD }
4// external_2: { status: not started, auditors: [], date: YYYY-MM-DD }
5// =====================
6
7#pragma once
8#include "./uintx.hpp"
10
11namespace bb::numeric {
12template <class base_uint>
14
15{
16 ASSERT(b != 0);
17 if (*this == 0) {
18 return { uintx(0), uintx(0) };
19 }
20 if (b == 1) {
21 return { *this, uintx(0) };
22 }
23 if (*this == b) {
24 return { uintx(1), uintx(0) };
25 }
26 if (b > *this) {
27 return { uintx(0), *this };
28 }
29
30 uintx quotient(0);
31 uintx remainder = *this;
32
33 uint64_t bit_difference = get_msb() - b.get_msb();
34
35 uintx divisor = b << bit_difference;
36 uintx accumulator = uintx(1) << bit_difference;
37
38 // if the divisor is bigger than the remainder, a and b have the same bit length
39 if (divisor > remainder) {
40 divisor >>= 1;
41 accumulator >>= 1;
42 }
43
44 // while the remainder is bigger than our original divisor, we can subtract multiples of b from the remainder,
45 // and add to the quotient
46 while (remainder >= b) {
47
48 // we've shunted 'divisor' up to have the same bit length as our remainder.
49 // If remainder >= divisor, then a is at least '1 << bit_difference' multiples of b
50 if (remainder >= divisor) {
51 remainder -= divisor;
52 // we can use OR here instead of +, as
53 // accumulator is always a nice power of two
54 quotient |= accumulator;
55 }
56 divisor >>= 1;
57 accumulator >>= 1;
58 }
59
60 return std::make_pair(quotient, remainder);
61}
62
72template <class base_uint> uintx<base_uint> uintx<base_uint>::unsafe_invmod(const uintx& modulus) const
73{
74
75 uintx t1 = 0;
76 uintx t2 = 1;
77 uintx r2 = (*this > modulus) ? *this % modulus : *this;
78 uintx r1 = modulus;
79 uintx q = 0;
80 while (r2 != 0) {
81 q = r1 / r2;
82 uintx temp_t1 = t1;
83 uintx temp_r1 = r1;
84 t1 = t2;
85 t2 = temp_t1 - q * t2;
86 r1 = r2;
87 r2 = temp_r1 - q * r2;
88 }
89
90 if (t1 > modulus) {
91 return modulus + t1;
92 }
93 return t1;
94}
95
104template <class base_uint> uintx<base_uint> uintx<base_uint>::invmod(const uintx& modulus) const
105{
106 ASSERT((*this) != 0);
107 if (modulus == 0) {
108 return 0;
109 }
110 if (modulus.get_msb() >= (2 * base_uint::length() - 1)) {
111 uintx<uintx<base_uint>> a_expanded(*this);
112 uintx<uintx<base_uint>> modulus_expanded(modulus);
113 return a_expanded.unsafe_invmod(modulus_expanded).lo;
114 }
115 return this->unsafe_invmod(modulus);
116}
117
118template <class base_uint> bool uintx<base_uint>::get_bit(const uint64_t bit_index) const
119{
120 if (bit_index >= base_uint::length()) {
121 return hi.get_bit(bit_index - base_uint::length());
122 }
123 return lo.get_bit(bit_index);
124}
125
126template <class base_uint> uintx<base_uint> uintx<base_uint>::operator-() const
127{
128 return uintx(0) - *this;
129}
130
131template <class base_uint> uintx<base_uint> uintx<base_uint>::operator*(const uintx& other) const
132{
133 const auto lolo = lo.mul_extended(other.lo);
134 const auto lohi = lo.mul_extended(other.hi);
135 const auto hilo = hi.mul_extended(other.lo);
136
137 base_uint top = lolo.second + hilo.first + lohi.first;
138 base_uint bottom = lolo.first;
139 return { bottom, top };
140}
141
142template <class base_uint>
144{
145 const auto lolo = lo.mul_extended(other.lo);
146 const auto lohi = lo.mul_extended(other.hi);
147 const auto hilo = hi.mul_extended(other.lo);
148 const auto hihi = hi.mul_extended(other.hi);
149
150 base_uint t0 = lolo.first;
151 base_uint t1 = lolo.second;
152 base_uint t2 = hilo.second;
153 base_uint t3 = hihi.second;
154 base_uint t2_carry(0);
155 base_uint t3_carry(0);
156 t1 += hilo.first;
157 t2_carry += (t1 < hilo.first ? base_uint(1) : base_uint(0));
158 t1 += lohi.first;
159 t2_carry += (t1 < lohi.first ? base_uint(1) : base_uint(0));
160 t2 += lohi.second;
161 t3_carry += (t2 < lohi.second ? base_uint(1) : base_uint(0));
162 t2 += hihi.first;
163 t3_carry += (t2 < hihi.first ? base_uint(1) : base_uint(0));
164 t2 += t2_carry;
165 t3_carry += (t2 < t2_carry ? base_uint(1) : base_uint(0));
166 t3 += t3_carry;
167 return { uintx(t0, t1), uintx(t2, t3) };
168}
169
170template <class base_uint> uintx<base_uint> uintx<base_uint>::operator/(const uintx& other) const
171
172{
173 return divmod(other).first;
174}
175
176template <class base_uint> uintx<base_uint> uintx<base_uint>::operator%(const uintx& other) const
177
178{
179 return divmod(other).second;
180}
182template <class base_uint> uintx<base_uint> uintx<base_uint>::operator^(const uintx& other) const
184 return { lo ^ other.lo, hi ^ other.hi };
186
187template <class base_uint> uintx<base_uint> uintx<base_uint>::operator|(const uintx& other) const
188{
189 return { lo | other.lo, hi | other.hi };
192template <class base_uint> uintx<base_uint> uintx<base_uint>::operator~() const
193{
194 return { ~lo, ~hi };
197template <class base_uint> bool uintx<base_uint>::operator==(const uintx& other) const
199 return ((lo == other.lo) && (hi == other.hi));
202template <class base_uint> bool uintx<base_uint>::operator!=(const uintx& other) const
203{
204 return !(*this == other);
205}
206
207template <class base_uint> bool uintx<base_uint>::operator!() const
208{
209 return *this == uintx(0ULL);
210}
211
212template <class base_uint> bool uintx<base_uint>::operator>(const uintx& other) const
213{
214 bool hi_gt = hi > other.hi;
215 bool lo_gt = lo > other.lo;
216
217 bool gt = (hi_gt) || (lo_gt && (hi == other.hi));
218 return gt;
219}
220
221template <class base_uint> bool uintx<base_uint>::operator>=(const uintx& other) const
222{
223 return (*this > other) || (*this == other);
224}
225
226template <class base_uint> bool uintx<base_uint>::operator<(const uintx& other) const
227{
228 return other > *this;
229}
230
231template <class base_uint> bool uintx<base_uint>::operator<=(const uintx& other) const
232{
233 return (*this < other) || (*this == other);
234}
235
237
238{
239 constexpr uint256_t BN254FQMODULUS256 =
240 uint256_t(0x3C208C16D87CFD47UL, 0x97816a916871ca8dUL, 0xb85045b68181585dUL, 0x30644e72e131a029UL);
241 constexpr uint256_t SECP256K1FQMODULUS256 =
242 uint256_t(0xFFFFFFFEFFFFFC2FULL, 0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL);
243 constexpr uint256_t SECP256R1FQMODULUS256 =
244 uint256_t(0xFFFFFFFFFFFFFFFFULL, 0x00000000FFFFFFFFULL, 0x0000000000000000ULL, 0xFFFFFFFF00000001ULL);
245
246 if (b == uintx(BN254FQMODULUS256)) {
247 return (*this).template barrett_reduction<BN254FQMODULUS256>();
248 }
249 if (b == uintx(SECP256K1FQMODULUS256)) {
250 return (*this).template barrett_reduction<SECP256K1FQMODULUS256>();
251 }
252 if (b == uintx(SECP256R1FQMODULUS256)) {
253 return (*this).template barrett_reduction<SECP256R1FQMODULUS256>();
254 }
255
256 return divmod_base(b);
257}
258
270template <class base_uint>
271template <base_uint modulus>
273{
274 // N.B. k could be modulus.get_msb() + 1 if we have strong bounds on the max value of (*self)
275 // (a smaller k would allow us to fit `redc_parameter` into `base_uint` and not `uintx`)
276 constexpr size_t k = base_uint::length() - 1;
277 // N.B. computation of redc_parameter requires division operation - if this cannot be precomputed (or amortized over
278 // multiple reductions over the same modulus), barrett_reduction is much slower than divmod
279 static const uintx redc_parameter = ((uintx(1) << (k * 2)).divmod_base(uintx(modulus))).first;
281 const auto x = *this;
282
283 // compute x * redc_parameter
284 const auto mul_result = x.mul_extended(redc_parameter);
285 constexpr size_t shift = 2 * k;
286
287 // compute (x * redc_parameter) >> 2k
288 // This is equivalent to (x * (2^{2k} / modulus) / 2^{2k})
289 // which approximates to x / modulus
290 const uintx downshifted_hi_bits = mul_result.second & ((uintx(1) << shift) - 1);
291 const uintx mul_hi_underflow = uintx(downshifted_hi_bits) << (length() - shift);
292 uintx quotient = (mul_result.first >> shift) | mul_hi_underflow;
293
294 // compute remainder by determining value of x - quotient * modulus
295 uintx qm_lo(0);
296 {
297 const auto lolo = quotient.lo.mul_extended(modulus);
298 const auto lohi = quotient.hi.mul_extended(modulus);
299 base_uint t0 = lolo.first;
300 base_uint t1 = lolo.second;
301 t1 = t1 + lohi.first;
302 qm_lo = uintx(t0, t1);
303 }
304 uintx remainder = x - qm_lo;
305
306 // because redc_parameter is an imperfect representation of 2^{2k} / n (might be too small),
307 // the computed quotient may be off by up to 4 (classic algorithm should be up to 1,
308 // TODO(https://github.com/AztecProtocol/barretenberg/issues/1051): investigate, why)
309 size_t i = 0;
310 while (remainder >= uintx(modulus)) {
311 BB_ASSERT_LT(i, 4U);
312 remainder = remainder - modulus;
313 quotient = quotient + 1;
314 i++;
315 }
316 return std::make_pair(quotient, remainder);
317}
318} // namespace bb::numeric
#define BB_ASSERT_LT(left, right,...)
Definition assert.hpp:115
#define ASSERT(expression,...)
Definition assert.hpp:49
uintx operator%(const uintx &other) const
std::pair< uintx, uintx > divmod(const uintx &b) const
bool operator!=(const uintx &other) const
bool operator<(const uintx &other) const
uintx operator*(const uintx &other) const
uintx unsafe_invmod(const uintx &modulus) const
uintx operator-() const
std::pair< uintx, uintx > divmod_base(const uintx &b) const
bool operator!() const
bool operator==(const uintx &other) const
bool operator>(const uintx &other) const
bool operator<=(const uintx &other) const
uintx operator~() const
bool get_bit(uint64_t bit_index) const
constexpr uint64_t get_msb() const
Definition uintx.hpp:69
std::pair< uintx, uintx > barrett_reduction() const
uintx operator|(const uintx &other) const
bool operator>=(const uintx &other) const
uintx operator/(const uintx &other) const
std::pair< uintx, uintx > mul_extended(const uintx &other) const
uintx invmod(const uintx &modulus) const
uintx operator^(const uintx &other) const
GreaterThan gt
FF b
uint8_t const size_t length
Definition data_store.hpp:9
constexpr T get_msb(const T in)
Definition get_msb.hpp:47
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13