Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
scalar_multiplication.test.cpp
Go to the documentation of this file.
9#include <filesystem>
10#include <gtest/gtest.h>
11
12using namespace bb;
13
// File-local test state.
// NOTE(review): original line 15 (the namespace body) is missing from this listing. The
// generated cross-reference lists `numeric::RNG & engine` and `get_randomness()`, so it
// presumably defines the `engine` RNG used by the tests below -- TODO confirm.
14namespace {
16} // namespace
17
// Typed gtest fixture shared by the tests below.
// SetUpTestSuite() builds `num_points` random generators and scalars once per curve type.
// naive_msm() is the reference MSM that the optimized implementation is compared against.
// NOTE(review): original lines 22-23 are missing from this listing. The generated
// cross-reference lists `ScalarField` (= typename Curve::ScalarField) and `AffineElement`
// (= typename Curve::AffineElement) aliases, which the members below rely on.
18template <class Curve> class ScalarMultiplicationTest : public ::testing::Test {
19 public:
20 using Group = typename Curve::Group;
21 using Element = typename Curve::Element;
24
 // Fixture size. Tests take sub-ranges of `generators` / `scalars`.
25 static constexpr size_t num_points = 201123;
26 static inline std::vector<AffineElement> generators{};
27 static inline std::vector<ScalarField> scalars{};
28
 // Reference MSM: returns sum_i input_points[i] * input_scalars[i] for i < input_scalars.size().
 // The index range is split into chunks of ceil(n / get_num_cpus()). Each parallel_for task
 // sums one chunk into its own accumulator, and the partial sums are then added serially.
 // NOTE(review): input_points.size() >= input_scalars.size() is assumed, not checked.
29 static AffineElement naive_msm(std::span<ScalarField> input_scalars, std::span<const AffineElement> input_points)
30 {
31 size_t total_points = input_scalars.size();
32 size_t num_threads = get_num_cpus();
33 std::vector<Element> expected_accs(num_threads);
34 size_t range_per_thread = (total_points + num_threads - 1) / num_threads;
35 parallel_for(num_threads, [&](size_t thread_idx) {
 // Each task covers the contiguous chunk [start, end), with `end` clamped to total_points.
 // Tasks whose chunk starts past the end contribute the point at infinity.
36 Element expected_thread_acc;
37 expected_thread_acc.self_set_infinity();
38 size_t start = thread_idx * range_per_thread;
39 size_t end = ((thread_idx + 1) * range_per_thread > total_points) ? total_points
40 : (thread_idx + 1) * range_per_thread;
41 bool skip = start >= total_points;
42 if (!skip) {
43 for (size_t i = start; i < end; ++i) {
44 expected_thread_acc += input_points[i] * input_scalars[i];
45 }
46 }
47 expected_accs[thread_idx] = expected_thread_acc;
48 });
49
50 Element expected_acc = Element();
51 expected_acc.self_set_infinity();
52 for (auto& acc : expected_accs) {
53 expected_acc += acc;
54 }
55 return AffineElement(expected_acc);
56 }
 // Runs once per typed suite. Fills `generators` with Group::one * (random scalar) and
 // `scalars` with random field elements, in parallel_for_range chunks.
 // NOTE(review): `engine` is shared by all parallel workers. Whether that RNG is safe for
 // concurrent use is not visible here -- confirm.
57 static void SetUpTestSuite()
58 {
59 generators.resize(num_points);
60 scalars.resize(num_points);
61 parallel_for_range(num_points, [&](size_t start, size_t end) {
62 for (size_t i = start; i < end; ++i) {
63 generators[i] = Group::one * Curve::ScalarField::random_element(&engine);
64 scalars[i] = Curve::ScalarField::random_element(&engine);
65 }
66 });
 // Sanity check: adjacent generators must have different x-coordinates.
 // This is presumably to rule out equal or opposite points, which affine addition cannot
 // handle directly -- confirm.
67 for (size_t i = 0; i < num_points - 1; ++i) {
68 ASSERT_EQ(generators[i].x == generators[i + 1].x, false);
69 }
70 };
71};
72
// Curve types the typed tests are instantiated over.
// NOTE(review): original line 74 is missing from this listing. It is presumably the
// TYPED_TEST_SUITE(ScalarMultiplicationTest, CurveTypes) registration -- confirm.
73using CurveTypes = ::testing::Types<bb::curve::BN254, bb::curve::Grumpkin>;
75
// Brings `Curve` (= TypeParam) and `ScalarField` into scope inside a TYPED_TEST body.
// NOTE: the backslash that ends the ScalarField line splices the commented-out line below
// into the macro, where it is discarded as a comment. The macro therefore does NOT define
// `AffineElement`; tests that need it declare it locally.
76#define SCALAR_MULTIPLICATION_TYPE_ALIASES \
77 using Curve = TypeParam; \
78 using ScalarField = typename Curve::ScalarField; \
79 // using AffineElement = typename Curve::AffineElement;
80
82{
// Checks MSM<Curve>::get_scalar_slice against a reference slice decomposition computed here.
// NOTE(review): the TYPED_TEST header (line 81) and line 83 are missing from this listing.
// Line 83 is presumably SCALAR_MULTIPLICATION_TYPE_ALIASES, which supplies `Curve` and
// `ScalarField`.
84 const size_t fr_size = 254;
85 const size_t slice_bits = 7;
// Round up so that a final, shorter slice covers the leftover bits (254 = 36 * 7 + 2).
86 size_t num_slices = (fr_size + slice_bits - 1) / slice_bits;
87 size_t last_slice_bits = fr_size - ((num_slices - 1) * slice_bits);
88
89 for (size_t x = 0; x < 100; ++x) {
90
91 uint256_t input_u256 = engine.get_random_uint256();
92 input_u256.data[3] = input_u256.data[3] & 0x3FFFFFFFFFFFFFFF; // 254 bits
 // Reduce into [0, modulus). Use `>=` rather than `>` so that an input equal to the modulus
 // is reduced too. Otherwise the field element built below would be reduced and would no
 // longer match input_u256 limb-for-limb.
93 while (input_u256 >= ScalarField::modulus) {
94 input_u256 -= ScalarField::modulus;
95 }
96 std::vector<uint32_t> slices(num_slices);
97
 // Reference slices are peeled off from the least-significant end. The lowest last_slice_bits
 // bits land in slices[num_slices - 1], and slices[0] holds the top slice_bits bits.
98 uint256_t acc = input_u256;
99 for (size_t i = 0; i < num_slices; ++i) {
100 size_t mask = ((1UL << slice_bits) - 1UL);
101 size_t shift = slice_bits;
102 if (i == 0) {
103 mask = ((1UL << last_slice_bits) - 1UL);
104 shift = last_slice_bits;
105 }
106 slices[num_slices - 1 - i] = static_cast<uint32_t>((acc & mask).data[0]);
107 acc = acc >> shift;
108 }
 // get_scalar_slice expects a scalar NOT in Montgomery form, so build the field element and
 // take it out of Montgomery form. The limb checks confirm it now equals the raw integer.
139 ScalarField input(input_u256);
140 input.self_from_montgomery_form();
141
142 ASSERT_EQ(input.data[0], input_u256.data[0]);
143 ASSERT_EQ(input.data[1], input_u256.data[1]);
144 ASSERT_EQ(input.data[2], input_u256.data[2]);
145 ASSERT_EQ(input.data[3], input_u256.data[3]);
146
147 for (size_t i = 0; i < num_slices; ++i) {
148
149 uint32_t result = scalar_multiplication::MSM<Curve>::get_scalar_slice(input, i, slice_bits);
150 EXPECT_EQ(result, slices[i]);
151 }
152 }
156}
157
159{
// Feeds a random point schedule through consume_point_schedule and checks every bucket against
// a directly computed per-bucket sum of generators.
// NOTE(review): original lines 158 (TYPED_TEST header), 174-175 and 177 are missing from this
// listing. They presumably declare `affine_data` and start the
// scalar_multiplication::MSM<Curve>::consume_point_schedule( call whose arguments follow -- confirm.
160 using Curve = TypeParam;
161 using AffineElement = typename Curve::AffineElement;
162 // NOTE(review): an old TODO asked for a size that is not a multiple of 10k; 30071 already is not.
163 const size_t total_points = 30071;
164 const size_t num_buckets = 128;
165
 // Each schedule entry packs the target bucket (0..127) into the low 32 bits and the point
 // index into the high 32 bits.
166 std::vector<uint64_t> input_point_schedule;
167 for (size_t i = 0; i < total_points; ++i) {
168
169 uint64_t bucket = static_cast<uint64_t>(engine.get_random_uint8()) & 0x7f;
170
171 uint64_t schedule = static_cast<uint64_t>(bucket) + (static_cast<uint64_t>(i) << 32);
172 input_point_schedule.push_back(schedule);
173 }
176 typename scalar_multiplication::MSM<Curve>::BucketAccumulators bucket_data(num_buckets);
178 input_point_schedule, TestFixture::generators, affine_data, bucket_data, 0, 0);
179
 // Expected bucket b = sum of generators[i] over all schedule entries i that target bucket b.
180 std::vector<typename Curve::Element> expected_buckets(num_buckets);
181 for (auto& e : expected_buckets) {
182 e.self_set_infinity();
183 }
184 // std::cout << "computing expected" << std::endl;
185 for (size_t i = 0; i < total_points; ++i) {
186 uint64_t bucket = input_point_schedule[i] & 0xFFFFFFFF;
187 EXPECT_LT(static_cast<size_t>(bucket), num_buckets);
188 expected_buckets[static_cast<size_t>(bucket)] += TestFixture::generators[i];
189 }
 // Non-empty buckets must match exactly. Empty buckets must be flagged as absent.
190 for (size_t i = 0; i < num_buckets; ++i) {
191 if (!expected_buckets[i].is_point_at_infinity()) {
192 AffineElement expected(expected_buckets[i]);
193 EXPECT_EQ(expected, bucket_data.buckets[i]);
194 } else {
195 EXPECT_FALSE(bucket_data.bucket_exists.get(i));
196 }
197 }
198}
199
200TYPED_TEST(ScalarMultiplicationTest, ConsumePointBatchAndAccumulate)
201{
// Feeds a random point schedule through the bucket pipeline. Checks the reduced `result` against
// sum_i generators[i] * bucket(i), computed directly in parallel chunks.
// NOTE(review): original lines 202, 218-219, 221 and 224 are missing from this listing.
// Line 202 is presumably SCALAR_MULTIPLICATION_TYPE_ALIASES, providing `Curve`/`ScalarField`.
// The others presumably declare `affine_data`, call consume_point_schedule, and produce `result`
// (likely via accumulate_buckets) -- confirm against the real file.
203 using Element = typename Curve::Element;
204 using AffineElement = typename Curve::AffineElement;
205
206 // NOTE(review): an old TODO asked for a size that is not a multiple of 10k; 30071 already is not.
207 const size_t total_points = 30071;
208 const size_t num_buckets = 128;
209
 // Each schedule entry packs the target bucket (0..127) into the low 32 bits and the point
 // index into the high 32 bits.
210 std::vector<uint64_t> input_point_schedule;
211 for (size_t i = 0; i < total_points; ++i) {
212
213 uint64_t bucket = static_cast<uint64_t>(engine.get_random_uint8()) & 0x7f;
214
215 uint64_t schedule = static_cast<uint64_t>(bucket) + (static_cast<uint64_t>(i) << 32);
216 input_point_schedule.push_back(schedule);
217 }
220 typename scalar_multiplication::MSM<Curve>::BucketAccumulators bucket_data(num_buckets);
222 input_point_schedule, TestFixture::generators, affine_data, bucket_data, 0, 0);
223
225
226 Element expected_acc = Element();
227 expected_acc.self_set_infinity();
228 size_t num_threads = get_num_cpus();
229 std::vector<Element> expected_accs(num_threads);
230 size_t range_per_thread = (total_points + num_threads - 1) / num_threads;
231 parallel_for(num_threads, [&](size_t thread_idx) {
232 Element expected_thread_acc;
233 expected_thread_acc.self_set_infinity();
234 size_t start = thread_idx * range_per_thread;
 // Clamp every chunk to total_points, not just the last one. With many threads,
 // (num_threads - 1) * range_per_thread can exceed total_points: with 250 threads the range is
 // 121, and chunk 248 would end at 30129. An unclamped non-final chunk would then read past the
 // end of input_point_schedule.
235 size_t end = ((thread_idx + 1) * range_per_thread > total_points) ? total_points : (thread_idx + 1) * range_per_thread;
236 bool skip = start >= total_points;
237 if (!skip) {
238 for (size_t i = start; i < end; ++i) {
 // Each point is weighted by its bucket index (the low 32 bits of its schedule entry).
239 ScalarField scalar = input_point_schedule[i] & 0xFFFFFFFF;
240 expected_thread_acc += TestFixture::generators[i] * scalar;
241 }
242 }
243 expected_accs[thread_idx] = expected_thread_acc;
244 });
245
246 for (size_t i = 0; i < num_threads; ++i) {
247 expected_acc += expected_accs[i];
248 }
249 AffineElement expected(expected_acc);
250 EXPECT_EQ(AffineElement(result), expected);
251}
252
253TYPED_TEST(ScalarMultiplicationTest, RadixSortCountZeroEntries)
254{
// Builds a random schedule (bucket index in the low 32 bits, point index in the high 32 bits).
// Checks the returned `result` against a direct count of entries whose bucket index is 0.
// NOTE(review): original line 266 is missing from this listing. Per the generated
// cross-reference, it presumably starts a call to
// scalar_multiplication::process_buckets_count_zero_entries(wnaf_entries, num_entries, num_bits)
// that defines `result`. Here num_bits = 7 matches the 0x7f bucket mask.
255 const size_t total_points = 30071;
256
257 std::vector<uint64_t> input_point_schedule;
258 for (size_t i = 0; i < total_points; ++i) {
259
260 uint64_t bucket = static_cast<uint64_t>(engine.get_random_uint8()) & 0x7f;
261
262 uint64_t schedule = static_cast<uint64_t>(bucket) + (static_cast<uint64_t>(i) << 32);
263 input_point_schedule.push_back(schedule);
264 }
265
267 &input_point_schedule[0], input_point_schedule.size(), 7);
 // The reference count is taken after the call. Counting zero buckets is order-independent,
 // so it stays valid even if the call reorders the entries in place.
268 size_t expected = 0;
269 for (size_t i = 0; i < total_points; ++i) {
270 expected += static_cast<size_t>((input_point_schedule[i] & 0xFFFFFFFF) == 0);
271 }
272 EXPECT_EQ(result, expected);
273}
274
275TYPED_TEST(ScalarMultiplicationTest, EvaluatePippengerRound)
276{
// Evaluates a single Pippenger round on two points whose scalars are non-zero only in that
// round's slice. Compares the result against the plain sum of generators[i] * scalars[i].
// NOTE(review): original lines 277, 288-289, 313, 321 and 323 are missing from this listing.
// They presumably provide:
// - the type aliases and `affine_data`;
// - the nonzero-index transform that fills `indices` and takes `scalars` out of Montgomery form;
// - the `msm_data` setup;
// - the evaluate_pippenger_round call that defines `result`.
// Confirm against the real file.
278 using AffineElement = typename Curve::AffineElement;
279 using Element = typename Curve::Element;
280
281 const size_t num_points = 2;
282 std::vector<ScalarField> scalars(num_points);
 // NOTE(review): the bit width comes from bb's `fr`, not from `ScalarField`. This assumes both
 // curves' scalar fields have the same width -- confirm.
283 constexpr size_t NUM_BITS_IN_FIELD = fr::modulus.get_msb() + 1;
284 const size_t normal_slice_size = 7; // stop hardcoding
285 const size_t num_buckets = 1 << normal_slice_size;
286
287 const size_t num_rounds = (NUM_BITS_IN_FIELD + normal_slice_size - 1) / normal_slice_size;
290 typename scalar_multiplication::MSM<Curve>::BucketAccumulators bucket_data(num_buckets);
291
 // Only the final round is exercised: the loop starts at num_rounds - 1 and runs once.
 // Per the hi_bit computation below, that round holds the least-significant slice.
292 for (size_t round_index = num_rounds - 1; round_index < num_rounds; round_index++) {
 // The final round covers the leftover NUM_BITS_IN_FIELD % normal_slice_size bits.
 // NOTE(review): this would be 0 if the width were a multiple of the slice size.
293 const size_t num_bits_in_slice =
294 (round_index == (num_rounds - 1)) ? (NUM_BITS_IN_FIELD % normal_slice_size) : normal_slice_size;
295 for (size_t i = 0; i < num_points; ++i) {
296
 // If hi_bit < normal_slice_size (as in the final round here), `hi_bit - normal_slice_size`
 // wraps around as size_t. The check below then resets lo_bit to 0.
297 size_t hi_bit = NUM_BITS_IN_FIELD - (round_index * normal_slice_size);
298 size_t lo_bit = hi_bit - normal_slice_size;
299 if (hi_bit < normal_slice_size) {
300 lo_bit = 0;
301 }
302 uint64_t slice = engine.get_random_uint64() & ((1 << num_bits_in_slice) - 1);
303 // at this point in the algo, scalars has been converted out of montgomery form
304 uint256_t scalar = uint256_t(slice) << lo_bit;
305 scalars[i].data[0] = scalar.data[0];
306 scalars[i].data[1] = scalar.data[1];
307 scalars[i].data[2] = scalar.data[2];
308 scalars[i].data[3] = scalar.data[3];
309 scalars[i].self_to_montgomery_form();
310 }
311
312 std::vector<uint32_t> indices;
314
315 Element previous_round_output;
316 previous_round_output.self_set_infinity();
 // `indices` is presumably filled on the missing line 313. Every index must refer to one of
 // the num_points inputs.
317 for (auto x : indices) {
318 ASSERT_LT(x, num_points);
319 }
 // NOTE(review): the bits_per_slice argument of the round call below is hardcoded as 7
 // (the same value as normal_slice_size).
320 std::vector<uint64_t> point_schedule(scalars.size());
322 scalars, TestFixture::generators, indices, point_schedule);
324 msm_data, round_index, affine_data, bucket_data, previous_round_output, 7);
325 Element expected;
326 expected.self_set_infinity();
 // The scalars were presumably converted out of Montgomery form in place by the missing
 // transform step, so convert them back before the group multiplication.
327 for (size_t i = 0; i < num_points; ++i) {
328 ScalarField baz = scalars[i].to_montgomery_form();
329 expected += (TestFixture::generators[i] * baz);
330 }
 // Doublings would align an earlier round's output with its bit position. The final round
 // needs none.
331 size_t num_doublings = NUM_BITS_IN_FIELD - (normal_slice_size * (round_index + 1));
332 if (round_index == num_rounds - 1) {
333 num_doublings = 0;
334 }
335 for (size_t i = 0; i < num_doublings; ++i) {
336 result.self_dbl();
337 }
338 EXPECT_EQ(AffineElement(result), AffineElement(expected));
339 }
340}
341
343{
// Full-size MSM over every fixture point (start index 0), compared against naive_msm.
// NOTE(review): original lines 342 (TYPED_TEST header) and 344 are missing from this listing.
// Line 344 is presumably the type-alias macro.
345 using AffineElement = typename Curve::AffineElement;
346
347 const size_t num_points = TestFixture::num_points;
348
349 std::span<ScalarField> scalars(&TestFixture::scalars[0], num_points);
350 AffineElement result =
351 scalar_multiplication::MSM<Curve>::msm(TestFixture::generators, PolynomialSpan<ScalarField>(0, scalars));
352
353 AffineElement expected = TestFixture::naive_msm(scalars, TestFixture::generators);
354 EXPECT_EQ(result, expected);
355}
356
358{
// Random batch: up to 255 MSMs, each of 0-399 points taken from consecutive, non-overlapping
// fixture ranges. Results are checked element-wise against naive_msm. Both zero-size MSMs and
// an empty batch (num_msms == 0) are possible.
// NOTE(review): original lines 357 (TYPED_TEST header), 359 and 365 are missing from this
// listing. They presumably bring `Curve`/`ScalarField` into scope and declare
// `batch_points_span` -- confirm.
360 using AffineElement = typename Curve::AffineElement;
361
362 const size_t num_msms = static_cast<size_t>(engine.get_random_uint8());
363 std::vector<AffineElement> expected(num_msms);
364
366 std::vector<std::span<ScalarField>> batch_scalars_spans;
367
368 size_t vector_offset = 0;
369 for (size_t k = 0; k < num_msms; ++k) {
370 const size_t num_points = static_cast<size_t>(engine.get_random_uint16()) % 400;
371
 // The worst case is 255 * 399 = 101745 < num_points, so this bound always holds.
372 ASSERT_LT(vector_offset + num_points, TestFixture::num_points);
373 std::span<ScalarField> batch_scalars(&TestFixture::scalars[vector_offset], num_points);
374 std::span<const AffineElement> batch_points(&TestFixture::generators[vector_offset], num_points);
375
376 vector_offset += num_points;
377 batch_points_span.push_back(batch_points);
378 batch_scalars_spans.push_back(batch_scalars);
379
380 expected[k] = TestFixture::naive_msm(batch_scalars_spans[k], batch_points_span[k]);
381 }
382
383 std::vector<AffineElement> result =
384 scalar_multiplication::MSM<Curve>::batch_multi_scalar_mul(batch_points_span, batch_scalars_spans);
385
386 EXPECT_EQ(result, expected);
387}
388
389TYPED_TEST(ScalarMultiplicationTest, BatchMultiScalarMulSparse)
390{
// Ten MSMs of 33 points each over disjoint fixture ranges. In each MSM only scalars 13..22 are
// non-zero, which exercises sparse inputs in batch_multi_scalar_mul.
// NOTE(review): original lines 391 and 399 are missing from this listing. They presumably bring
// `Curve`/`ScalarField` into scope and declare `batch_points_span` -- confirm.
392 using AffineElement = typename Curve::AffineElement;
393
394 const size_t num_msms = 10;
395 std::vector<AffineElement> expected(num_msms);
396
397 std::vector<std::vector<ScalarField>> batch_scalars(num_msms);
 // NOTE(review): batch_input_points is never used in the visible code.
398 std::vector<std::vector<AffineElement>> batch_input_points(num_msms);
400 std::vector<std::span<ScalarField>> batch_scalars_spans;
401
402 for (size_t k = 0; k < num_msms; ++k) {
403 const size_t num_points = 33;
404 auto& scalars = batch_scalars[k];
405
406 scalars.resize(num_points);
407
408 size_t fixture_offset = k * num_points;
409
410 std::span<AffineElement> batch_points(&TestFixture::generators[fixture_offset], num_points);
411 for (size_t i = 0; i < 13; ++i) {
412 scalars[i] = 0;
413 }
 // NOTE(review): the extra `+ 13` offset into the fixture scalars looks unintentional.
 // It is harmless, since any fixture scalar is random; the largest index read is
 // 9 * 33 + 22 + 13 = 332.
414 for (size_t i = 13; i < 23; ++i) {
415 scalars[i] = TestFixture::scalars[fixture_offset + i + 13];
416 }
417 for (size_t i = 23; i < num_points; ++i) {
418 scalars[i] = 0;
419 }
 // The span views batch_scalars[k]. That inner vector is not resized again, so the view stays valid.
420 batch_points_span.push_back(batch_points);
421 batch_scalars_spans.push_back(batch_scalars[k]);
422
423 expected[k] = TestFixture::naive_msm(batch_scalars[k], batch_points);
424 }
425
426 std::vector<AffineElement> result =
427 scalar_multiplication::MSM<Curve>::batch_multi_scalar_mul(batch_points_span, batch_scalars_spans);
428
429 EXPECT_EQ(result, expected);
430}
431
433{
// MSM with a non-zero PolynomialSpan start index. The expected value pairs scalars[j] with
// generators[start_index + j]. In other words, the start index offsets into the point array.
// NOTE(review): original lines 432 (TYPED_TEST header) and 434 are missing from this listing.
// Line 434 is presumably the type-alias macro.
435 using AffineElement = typename Curve::AffineElement;
436
437 const size_t start_index = 1234;
438 const size_t num_points = TestFixture::num_points - start_index;
439
440 PolynomialSpan<ScalarField> scalar_span =
441 PolynomialSpan<ScalarField>(start_index, std::span<ScalarField>(&TestFixture::scalars[0], num_points));
442 AffineElement result = scalar_multiplication::MSM<Curve>::msm(TestFixture::generators, scalar_span);
443
444 std::span<AffineElement> points(&TestFixture::generators[start_index], num_points);
445 AffineElement expected = TestFixture::naive_msm(scalar_span.span, points);
446 EXPECT_EQ(result, expected);
447}
448
450{
// All-zero scalars with a non-zero start index must yield the point at infinity.
// NOTE(review): original lines 449 (TYPED_TEST header) and 451 are missing from this listing.
// Line 451 is presumably the type-alias macro.
452 using AffineElement = typename Curve::AffineElement;
453
454 const size_t start_index = 1234;
455 const size_t num_points = TestFixture::num_points - start_index;
456 std::vector<ScalarField> scalars(num_points);
457
 // Zero the scalars explicitly. Whether the vector constructor zero-initializes ScalarField
 // is not visible here.
458 for (size_t i = 0; i < num_points; ++i) {
459 scalars[i] = 0;
460 }
461
462 PolynomialSpan<ScalarField> scalar_span = PolynomialSpan<ScalarField>(start_index, scalars);
463 AffineElement result = scalar_multiplication::MSM<Curve>::msm(TestFixture::generators, scalar_span);
464
465 EXPECT_EQ(result, Curve::Group::affine_point_at_infinity);
466}
467
469{
// An empty MSM (no points, no scalars) must yield the point at infinity.
// NOTE(review): original lines 468 (TYPED_TEST header), 470 and 476 are missing from this
// listing. They presumably provide the type-alias macro and build `scalar_span` from
// `scalars` -- confirm.
471 using AffineElement = typename Curve::AffineElement;
472
473 const size_t num_points = 0;
474 std::vector<ScalarField> scalars(num_points);
475 std::vector<AffineElement> input_points(num_points);
477 AffineElement result = scalar_multiplication::MSM<Curve>::msm(input_points, scalar_span);
478
479 EXPECT_EQ(result, Curve::Group::affine_point_at_infinity);
480}
481
482TEST(ScalarMultiplication, SmallInputsExplicit)
483{
// Fixed-vector check on Grumpkin. Uses two fixed points (x0, y0), (x1, y1) and fixed scalars
// s0, s1; msm() must equal points[0] * s0 + points[1] * s1.
// NOTE(review): original lines 493-494 and 496 are missing from this listing. They presumably
// build `points` from the coordinates below and `scalar_span` from `scalars` -- confirm.
484 uint256_t x0(0x68df84429941826a, 0xeb08934ed806781c, 0xc14b6a2e4f796a73, 0x08dc1a9a11a3c8db);
485 uint256_t y0(0x8ae5c31aa997f141, 0xe85f20c504f2c11b, 0x81a94193f3b1ce2b, 0x26f2c37372adb5b7);
486 uint256_t x1(0x80f5a592d919d32f, 0x1362652b984e51ca, 0xa0b26666f770c2a1, 0x142c6e1964e5c3c5);
487 uint256_t y1(0xb6c322ebb5ae4bc5, 0xf9fef6c7909c00f8, 0xb37ca1cc9af3b421, 0x1e331c7fa73d6a59);
488 uint256_t s0(0xe48bf12a24272e08, 0xf8dd0182577f3567, 0xec8fd222b8a6becb, 0x102d76b945612c9b);
489 uint256_t s1(0x098ae8d69f1e4e9e, 0xb5c8313c0f6040ed, 0xf78041e30cc46c44, 0x1d1e6e0c21892e13);
490
491 std::vector<grumpkin::fr> scalars{ s0, s1 };
492
495
497
498 auto result = scalar_multiplication::MSM<curve::Grumpkin>::msm(points, scalar_span);
499
500 grumpkin::g1::element expected = (points[0] * scalars[0]) + (points[1] * scalars[1]);
501
502 EXPECT_EQ(result, grumpkin::g1::affine_element(expected));
503}
BB_INLINE bool get(size_t index) const noexcept
Definition bitvector.hpp:36
typename Curve::ScalarField ScalarField
static std::vector< AffineElement > generators
static std::vector< ScalarField > scalars
typename Curve::AffineElement AffineElement
static AffineElement naive_msm(std::span< ScalarField > input_scalars, std::span< const AffineElement > input_points)
typename Group::element Element
Definition grumpkin.hpp:55
typename grumpkin::g1 Group
Definition grumpkin.hpp:54
typename Group::affine_element AffineElement
Definition grumpkin.hpp:56
element class. Implements ecc group arithmetic using Jacobian coordinates. See https://hyperelliptic....
Definition element.hpp:33
group_elements::affine_element< Fq, Fr, Params > affine_element
Definition group.hpp:42
virtual uint64_t get_random_uint64()=0
virtual uint8_t get_random_uint8()=0
virtual uint16_t get_random_uint16()=0
virtual uint256_t get_random_uint256()=0
constexpr uint64_t get_msb() const
static Element accumulate_buckets(BucketType &bucket_accumulators) noexcept
static uint32_t get_scalar_slice(const ScalarField &scalar, size_t round, size_t normal_slice_size) noexcept
Given a scalar that is NOT in Montgomery form, extract a slice_size-bit chunk.
static std::vector< AffineElement > batch_multi_scalar_mul(std::vector< std::span< const AffineElement > > &points, std::vector< std::span< ScalarField > > &scalars, bool handle_edge_cases=true) noexcept
Compute multiple multi-scalar multiplications.
static void consume_point_schedule(std::span< const uint64_t > point_schedule, std::span< const AffineElement > points, AffineAdditionData &affine_data, BucketAccumulators &bucket_data, size_t num_input_points_processed, size_t num_queued_affine_points) noexcept
Given a list of points and target buckets to add into, perform required group operations.
static Element evaluate_pippenger_round(MSMData &msm_data, const size_t round_index, AffineAdditionData &affine_data, BucketAccumulators &bucket_data, Element previous_round_output, const size_t bits_per_slice) noexcept
Evaluate a single Pippenger round where we use the affine trick.
static void transform_scalar_and_get_nonzero_scalar_indices(std::span< typename Curve::ScalarField > scalars, std::vector< uint32_t > &consolidated_indices) noexcept
Convert scalar out of Montgomery form. Populate consolidated_indices with nonzero scalar indices.
static AffineElement msm(std::span< const AffineElement > points, PolynomialSpan< const ScalarField > _scalars, bool handle_edge_cases=false) noexcept
Helper method to evaluate a single MSM. Internally calls batch_multi_scalar_mul
const std::vector< FF > data
#define SCALAR_MULTIPLICATION_TYPE_ALIASES
numeric::RNG & engine
RNG & get_randomness()
Definition engine.cpp:203
size_t process_buckets_count_zero_entries(uint64_t *wnaf_entries, const size_t num_entries, const uint32_t num_bits) noexcept
Entry point for Barretenberg command-line interface.
TYPED_TEST_SUITE(ShpleminiTest, TestSettings)
TEST(MegaCircuitBuilder, CopyConstructor)
size_t get_num_cpus()
Definition thread.hpp:12
C slice(C const &container, size_t start)
Definition container.hpp:9
::testing::Types< curve::BN254, curve::Grumpkin > CurveTypes
TYPED_TEST(ShpleminiTest, CorrectnessOfMultivariateClaimBatching)
void parallel_for(size_t num_iterations, const std::function< void(size_t)> &func)
Definition thread.cpp:72
void parallel_for_range(size_t num_points, const std::function< void(size_t, size_t)> &func, size_t no_multhreading_if_less_or_equal)
Split a loop into several loops running in parallel.
Definition thread.cpp:102
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
Curve::Element Element
std::span< Fr > span
static constexpr uint256_t modulus
Temp data structure, one created per thread!
Temp data structure, one created per thread!