85 const std::vector<MSM>& msms,
const uint32_t total_number_of_muls,
const size_t num_msm_rows)
111 const size_t num_rows_in_read_counts_table =
112 static_cast<size_t>(total_number_of_muls) *
113 (eccvm::POINT_TABLE_SIZE >> 1);
117 point_table_read_counts[0].reserve(num_rows_in_read_counts_table);
118 point_table_read_counts[1].reserve(num_rows_in_read_counts_table);
119 for (
size_t i = 0; i < num_rows_in_read_counts_table; ++i) {
120 point_table_read_counts[0].emplace_back(0);
121 point_table_read_counts[1].emplace_back(0);
124 const auto update_read_count = [&point_table_read_counts](
const size_t point_idx,
const int slice) {
133 const size_t row_index_offset = point_idx * 8;
134 const bool digit_is_negative =
slice < 0;
135 const auto relative_row_idx =
static_cast<size_t>((
slice + 15) / 2);
136 const size_t column_index = digit_is_negative ? 1 : 0;
138 if (digit_is_negative) {
139 point_table_read_counts[column_index][row_index_offset + relative_row_idx]++;
141 point_table_read_counts[column_index][row_index_offset + 15 - relative_row_idx]++;
146 std::vector<size_t> msm_row_counts;
147 msm_row_counts.reserve(msms.size() + 1);
148 msm_row_counts.push_back(1);
151 std::vector<size_t> pc_values;
152 pc_values.reserve(msms.size() + 1);
153 pc_values.push_back(total_number_of_muls);
154 for (
const auto& msm : msms) {
156 msm_row_counts.push_back(msm_row_counts.back() + num_rows_required);
157 pc_values.push_back(pc_values.back() - msm.size());
170 for (
size_t msm_idx = 0; msm_idx < msms.size(); ++msm_idx) {
172 auto pc =
static_cast<uint32_t
>(pc_values[msm_idx]);
173 const auto& msm = msms[msm_idx];
174 const size_t msm_size = msm.size();
175 const size_t num_rows_per_digit =
178 for (
size_t relative_row_idx = 0; relative_row_idx < num_rows_per_digit; ++relative_row_idx) {
179 const size_t num_points_in_row = (relative_row_idx + 1) *
ADDITIONS_PER_ROW > msm_size
183 for (
size_t relative_point_idx = 0; relative_point_idx <
ADDITIONS_PER_ROW; ++relative_point_idx) {
184 const size_t point_idx =
offset + relative_point_idx;
185 const bool add = num_points_in_row > relative_point_idx;
187 int slice = msm[point_idx].wnaf_digits[digit_idx];
189 update_read_count((total_number_of_muls - pc) + point_idx,
slice);
196 for (
size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) {
202 ++relative_point_idx) {
203 bool add = num_points_in_row > relative_point_idx;
204 const size_t point_idx =
offset + relative_point_idx;
211 int slice = msm[point_idx].wnaf_skew ? -1 : -15;
212 update_read_count((total_number_of_muls - pc) + point_idx,
slice);
230 const size_t num_point_adds_and_doubles =
231 (num_msm_rows - 2) * 4;
238 const size_t num_accumulators = num_msm_rows - 1;
242 static constexpr size_t NUM_POINTS_IN_ADDITION_RELATION = 3;
243 const size_t num_points_to_normalize =
244 (num_point_adds_and_doubles * NUM_POINTS_IN_ADDITION_RELATION) + num_accumulators;
245 std::vector<Element> points_to_normalize(num_points_to_normalize);
247 std::span<Element> p2_trace(&points_to_normalize[num_point_adds_and_doubles], num_point_adds_and_doubles);
248 std::span<Element> p3_trace(&points_to_normalize[num_point_adds_and_doubles * 2], num_point_adds_and_doubles);
252 std::vector<bool> is_double_or_add(num_point_adds_and_doubles);
254 std::span<Element> accumulator_trace(&points_to_normalize[num_point_adds_and_doubles * 3], num_accumulators);
258 accumulator_trace[0] = offset_generator;
263 for (
size_t msm_idx = 0; msm_idx < msms.size(); msm_idx++) {
264 Element accumulator = offset_generator;
265 const auto& msm = msms[msm_idx];
266 size_t msm_row_index = msm_row_counts[msm_idx];
267 const size_t msm_size = msm.size();
268 const size_t num_rows_per_digit =
275 (msm_row_counts[msm_idx] - 1) * 4;
282 const auto pc =
static_cast<uint32_t
>(pc_values[msm_idx]);
284 for (
size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) {
288 auto& row = msm_rows[msm_row_index];
290 row.msm_transition = (digit_idx == 0) && (row_idx == 0);
293 auto& add_state = row.add_state[point_idx];
294 add_state.add = num_points_in_row > point_idx;
295 int slice = add_state.add ? msm[
offset + point_idx].wnaf_digits[digit_idx] : 0;
305 add_state.slice = add_state.add ? (
slice + 15) / 2 : 0;
308 ? msm[
offset + point_idx].precomputed_table[
static_cast<size_t>(add_state.slice)]
313 accumulator = add_state.add ? (accumulator + add_state.point) :
Element(p1);
314 p1_trace[trace_index] = p1;
315 p2_trace[trace_index] = p2;
316 p3_trace[trace_index] = accumulator;
317 is_double_or_add[trace_index] =
false;
321 accumulator_trace[msm_row_index] = accumulator;
323 row.q_double =
false;
325 row.msm_round =
static_cast<uint32_t
>(digit_idx);
326 row.msm_size =
static_cast<uint32_t
>(msm_size);
327 row.msm_count =
static_cast<uint32_t
>(
offset);
337 auto& row = msm_rows[msm_row_index];
338 row.msm_transition =
false;
339 row.msm_round =
static_cast<uint32_t
>(digit_idx + 1);
340 row.msm_size =
static_cast<uint32_t
>(msm_size);
341 row.msm_count =
static_cast<uint32_t
>(0);
346 auto& add_state = row.add_state[point_idx];
347 add_state.add =
false;
349 add_state.point = { 0, 0 };
350 add_state.collision_inverse = 0;
352 p1_trace[trace_index] = accumulator;
353 p2_trace[trace_index] = accumulator;
354 accumulator = accumulator.dbl();
355 p3_trace[trace_index] = accumulator;
356 is_double_or_add[trace_index] =
true;
359 accumulator_trace[msm_row_index] = accumulator;
363 for (
size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) {
364 auto& row = msm_rows[msm_row_index];
370 row.msm_transition =
false;
371 Element acc_expected = accumulator;
373 auto& add_state = row.add_state[point_idx];
374 add_state.add = num_points_in_row > point_idx;
375 add_state.slice = add_state.add ? msm[
offset + point_idx].wnaf_skew ? 7 : 0 : 0;
379 ? msm[
offset + point_idx].precomputed_table[
static_cast<size_t>(add_state.slice)]
384 bool add_predicate = add_state.add ? msm[
offset + point_idx].wnaf_skew :
false;
385 auto p1 = accumulator;
386 accumulator = add_predicate ? accumulator + add_state.point : accumulator;
387 p1_trace[trace_index] = p1;
388 p2_trace[trace_index] = add_state.point;
389 p3_trace[trace_index] = accumulator;
390 is_double_or_add[trace_index] =
false;
394 row.q_double =
false;
396 row.msm_round =
static_cast<uint32_t
>(digit_idx + 1);
397 row.msm_size =
static_cast<uint32_t
>(msm_size);
398 row.msm_count =
static_cast<uint32_t
>(
offset);
400 accumulator_trace[msm_row_index] = accumulator;
409 Element::batch_normalize(&points_to_normalize[start], end - start);
413 std::vector<FF> inverse_trace(num_point_adds_and_doubles);
415 for (
size_t operation_idx = start; operation_idx < end; ++operation_idx) {
416 if (is_double_or_add[operation_idx]) {
417 inverse_trace[operation_idx] = (p1_trace[operation_idx].y + p1_trace[operation_idx].y);
419 inverse_trace[operation_idx] = (p2_trace[operation_idx].x - p1_trace[operation_idx].x);
428 for (
size_t msm_idx = 0; msm_idx < msms.size(); msm_idx++) {
429 const auto& msm = msms[msm_idx];
431 size_t msm_row_index = msm_row_counts[msm_idx];
433 size_t accumulator_index = msm_row_counts[msm_idx] - 1;
434 const size_t msm_size = msm.size();
435 const size_t num_rows_per_digit =
439 for (
size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) {
440 auto& row = msm_rows[msm_row_index];
444 const Element& normalized_accumulator = accumulator_trace[accumulator_index];
445 BB_ASSERT_EQ(normalized_accumulator.is_point_at_infinity(), 0);
446 row.accumulator_x = normalized_accumulator.x;
447 row.accumulator_y = normalized_accumulator.y;
449 auto& add_state = row.add_state[point_idx];
451 const auto& inverse = inverse_trace[trace_index];
452 const auto& p1 = p1_trace[trace_index];
453 const auto& p2 = p2_trace[trace_index];
454 add_state.collision_inverse = add_state.add ? inverse : 0;
455 add_state.lambda = add_state.add ? (p2.y - p1.y) * inverse : 0;
465 MSMRow& row = msm_rows[msm_row_index];
466 const Element& normalized_accumulator = accumulator_trace[accumulator_index];
467 const FF& acc_x = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.x;
468 const FF& acc_y = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.y;
472 auto& add_state = row.
add_state[point_idx];
473 add_state.collision_inverse = 0;
474 const FF& dx = p1_trace[trace_index].x;
475 const FF& inverse = inverse_trace[trace_index];
476 add_state.lambda = ((dx + dx + dx) * dx) * inverse;
486 for (
size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) {
487 MSMRow& row = msm_rows[msm_row_index];
488 const Element& normalized_accumulator = accumulator_trace[accumulator_index];
489 BB_ASSERT_EQ(normalized_accumulator.is_point_at_infinity(), 0);
494 auto& add_state = row.
add_state[point_idx];
495 bool add_predicate = add_state.add ? msm[
offset + point_idx].wnaf_skew :
false;
497 const auto& inverse = inverse_trace[trace_index];
498 const auto& p1 = p1_trace[trace_index];
499 const auto& p2 = p2_trace[trace_index];
500 add_state.collision_inverse = add_predicate ? inverse : 0;
501 add_state.lambda = add_predicate ? (p2.y - p1.y) * inverse : 0;
514 Element final_accumulator(accumulator_trace.back());
515 MSMRow& final_row = msm_rows.back();
516 final_row.
pc =
static_cast<uint32_t
>(pc_values.back());
518 final_row.
accumulator_x = final_accumulator.is_point_at_infinity() ? 0 : final_accumulator.x;
519 final_row.
accumulator_y = final_accumulator.is_point_at_infinity() ? 0 : final_accumulator.y;
522 final_row.
q_add =
false;
530 return { msm_rows, point_table_read_counts };