Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
msm_builder.hpp
Go to the documentation of this file.
1// === AUDIT STATUS ===
2// internal: { status: not started, auditors: [], date: YYYY-MM-DD }
3// external_1: { status: not started, auditors: [], date: YYYY-MM-DD }
4// external_2: { status: not started, auditors: [], date: YYYY-MM-DD }
5// =====================
6
7#pragma once
8
9#include <cstddef>
10
15
16namespace bb {
17
19 public:
// Group element type and its affine representation for the cycle group being processed by the VM.
22 using Element = typename CycleGroup::element;
23 using AffineElement = typename CycleGroup::affine_element;
// Circuit-layout constants shared with the rest of the ECCVM: each VM row performs up to 4 point
// additions, and each 128-bit scalar decomposes into 32 4-bit wNAF digits (plus a skew bit).
26 static constexpr size_t ADDITIONS_PER_ROW = bb::eccvm::ADDITIONS_PER_ROW;
27 static constexpr size_t NUM_WNAF_DIGITS_PER_SCALAR = bb::eccvm::NUM_WNAF_DIGITS_PER_SCALAR;
28
29 struct alignas(64) MSMRow {
30 uint32_t pc = 0; // counter over all half-length (128 bit) scalar muls used to compute the required MSMs
31 uint32_t msm_size = 0; // the number of points that will be scaled and summed
32 uint32_t msm_count = 0; // number of multiplications processed so far in current MSM round
33 uint32_t msm_round = 0; // current "round" of MSM, in {0, ..., 32 = `NUM_WNAF_DIGITS_PER_SCALAR`}. With the
34 // Straus algorithm, we proceed wNAF digit by wNAF digit, from left to right. (final
35 // round deals with the `skew` bit.)
36 bool msm_transition = false; // is 1 if the current row *starts* the processing of a different MSM, else 0.
37 bool q_add = false;
38 bool q_double = false;
39 bool q_skew = false;
40
41 // Each row in the MSM portion of the ECCVM can handle (up to) 4 point-additions.
42 // For each row in the VM we represent the point addition data via a size-4 array of
43 // AddState objects.
44 struct AddState {
45 bool add = false; // are we adding a point at this location in the VM?
46 // e.g if the MSM is of size-2 then the 3rd and 4th AddState objects will have this set
47 // to `false`.
48 int slice = 0; // wNAF slice value. This has values in {0, ..., 15} and corresponds to an odd number in the
49 // range {-15, -13, ..., 15} via the monotonic bijection.
50 AffineElement point{ 0, 0 }; // point being added into the accumulator. (This is of the form nP,
51 // where n is in {-15, -13, ..., 15}.)
52 FF lambda = 0; // when adding `point` into the accumulator via Affine point addition, the value of `lambda`
53 // (i.e., the slope of the line). (we need this as a witness in the circuit.)
54 FF collision_inverse = 0; // `collision_inverse` is used to validate we are not hitting point addition edge
55 // case exceptions, i.e., we want the VM proof to fail if we're doing a point
56 // addition where (x1 == x2). to do this, we simply provide an inverse to x1 - x2.
57 };
58 std::array<AddState, 4> add_state{ AddState{ false, 0, { 0, 0 }, 0, 0 },
59 AddState{ false, 0, { 0, 0 }, 0, 0 },
60 AddState{ false, 0, { 0, 0 }, 0, 0 },
61 AddState{ false, 0, { 0, 0 }, 0, 0 } };
62 // The accumulator here is, in general, the result of four EC additions: A + Q_1 + Q_2 + Q_3 + Q_4.
63 // We do not explicitly store the intermediate values A + Q_1, A + Q_1 + Q_2, and A + Q_1 + Q_2 + Q_3, although
64 // these values are implicitly used in the values of `AddState.lambda` and `AddState.collision_inverse`.
65
66 FF accumulator_x = 0; // `(accumulator_x, accumulator_y)` is the accumulator to which I potentially want to add
67 // the points in `add_state`.
68 FF accumulator_y = 0; // `(accumulator_x, accumulator_y)` is the accumulator to which I potentially want to add
69 // the points in `add_state`.
70 };
71
85 const std::vector<MSM>& msms, const uint32_t total_number_of_muls, const size_t num_msm_rows)
86 {
87 // To perform a scalar multiplication of a point P by a scalar x, we precompute a table of points
88 // -15P, -13P, ..., -3P, -P, P, 3P, ..., 15P
89 // When we perform a scalar multiplication, we decompose x into base-16 wNAF digits then look these precomputed
90 // values up with digit-by-digit. As we are performing lookups with the log-derivative argument, we have to
91 // record read counts. We record read counts in a table with the following structure:
92 // 1st write column = positive wNAF digits
93 // 2nd write column = negative wNAF digits
94 // the row number is a function of pc and wnaf digit:
95 // point_idx = total_number_of_muls - pc
96 // row = point_idx * rows_per_point_table + (some function of the slice value)
97 //
98 // Illustration:
99 // Block Structure Table structure:
100 // | 0 | 1 | | Block_{0} | <-- pc = total_number_of_muls
101 // | - | - | | Block_{1} | <-- pc = total_number_of_muls-(num muls in msm 0)
102 // 1 | # | # | -1 | ... | ...
103 // 3 | # | # | -3 | Block_{total_number_of_muls-1} | <-- pc = num muls in last msm
104 // 5 | # | # | -5
105 // 7 | # | # | -7
106 // 9 | # | # | -9
107 // 11 | # | # | -11
108 // 13 | # | # | -13
109 // 15 | # | # | -15
110
111 const size_t num_rows_in_read_counts_table =
112 static_cast<size_t>(total_number_of_muls) *
113 (eccvm::POINT_TABLE_SIZE >> 1); // `POINT_TABLE_SIZE` is 2ʷ, where in our case w = 4. As noted above, with
114 // respect to *read counts*, we are record looking up the positive and
115 // negative odd multiples of [P] in two separate columns, each of size 2ʷ⁻¹.
116 std::array<std::vector<size_t>, 2> point_table_read_counts;
117 point_table_read_counts[0].reserve(num_rows_in_read_counts_table);
118 point_table_read_counts[1].reserve(num_rows_in_read_counts_table);
119 for (size_t i = 0; i < num_rows_in_read_counts_table; ++i) {
120 point_table_read_counts[0].emplace_back(0);
121 point_table_read_counts[1].emplace_back(0);
122 }
123
124 const auto update_read_count = [&point_table_read_counts](const size_t point_idx, const int slice) {
133 const size_t row_index_offset = point_idx * 8;
134 const bool digit_is_negative = slice < 0;
135 const auto relative_row_idx = static_cast<size_t>((slice + 15) / 2);
136 const size_t column_index = digit_is_negative ? 1 : 0;
137
138 if (digit_is_negative) {
139 point_table_read_counts[column_index][row_index_offset + relative_row_idx]++;
140 } else {
141 point_table_read_counts[column_index][row_index_offset + 15 - relative_row_idx]++;
142 }
143 };
144
145 // compute which row index each multiscalar multiplication will start at.
146 std::vector<size_t> msm_row_counts;
147 msm_row_counts.reserve(msms.size() + 1);
148 msm_row_counts.push_back(1);
149 // compute the program counter (i.e. the index among all single scalar muls) that each multiscalar
150 // multiplication will start at.
151 std::vector<size_t> pc_values;
152 pc_values.reserve(msms.size() + 1);
153 pc_values.push_back(total_number_of_muls);
154 for (const auto& msm : msms) {
155 const size_t num_rows_required = EccvmRowTracker::num_eccvm_msm_rows(msm.size());
156 msm_row_counts.push_back(msm_row_counts.back() + num_rows_required);
157 pc_values.push_back(pc_values.back() - msm.size());
158 }
159 BB_ASSERT_EQ(pc_values.back(), 0U);
160
161 // compute the MSM rows
162
163 std::vector<MSMRow> msm_rows(num_msm_rows);
164 // start with empty row (shiftable polynomials must have 0 as first coefficient)
165 msm_rows[0] = (MSMRow{});
166 // compute "read counts" so that we can determine the number of times entries in our log-derivative lookup
167 // tables are called.
168 // Note: this part is single-threaded. The amount of compute is low, however, so this is likely not a big
169 // concern.
170 for (size_t msm_idx = 0; msm_idx < msms.size(); ++msm_idx) {
171 for (size_t digit_idx = 0; digit_idx < NUM_WNAF_DIGITS_PER_SCALAR; ++digit_idx) {
172 auto pc = static_cast<uint32_t>(pc_values[msm_idx]);
173 const auto& msm = msms[msm_idx];
174 const size_t msm_size = msm.size();
175 const size_t num_rows_per_digit =
176 (msm_size / ADDITIONS_PER_ROW) + ((msm_size % ADDITIONS_PER_ROW != 0) ? 1 : 0);
177
178 for (size_t relative_row_idx = 0; relative_row_idx < num_rows_per_digit; ++relative_row_idx) {
179 const size_t num_points_in_row = (relative_row_idx + 1) * ADDITIONS_PER_ROW > msm_size
180 ? (msm_size % ADDITIONS_PER_ROW)
182 const size_t offset = relative_row_idx * ADDITIONS_PER_ROW;
183 for (size_t relative_point_idx = 0; relative_point_idx < ADDITIONS_PER_ROW; ++relative_point_idx) {
184 const size_t point_idx = offset + relative_point_idx;
185 const bool add = num_points_in_row > relative_point_idx;
186 if (add) {
187 int slice = msm[point_idx].wnaf_digits[digit_idx];
188 // pc starts at total_number_of_muls and decreses non-uniformly to 0
189 update_read_count((total_number_of_muls - pc) + point_idx, slice);
190 }
191 }
192 }
193
194 // update the log-derivative read count for the lookup associated with WNAF skew
195 if (digit_idx == NUM_WNAF_DIGITS_PER_SCALAR - 1) {
196 for (size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) {
197 const size_t num_points_in_row = (row_idx + 1) * ADDITIONS_PER_ROW > msm_size
198 ? (msm_size % ADDITIONS_PER_ROW)
200 const size_t offset = row_idx * ADDITIONS_PER_ROW;
201 for (size_t relative_point_idx = 0; relative_point_idx < ADDITIONS_PER_ROW;
202 ++relative_point_idx) {
203 bool add = num_points_in_row > relative_point_idx;
204 const size_t point_idx = offset + relative_point_idx;
205 if (add) {
206 // pc starts at total_number_of_muls and decreases non-uniformly to 0
207 // -15 maps to the 1st point in the lookup table (array element 0)
208 // -1 maps to the point in the lookup table that corresponds to the negation of the
209 // original input point (i.e. the point we need to add into the accumulator if wnaf_skew
210 // is positive)
211 int slice = msm[point_idx].wnaf_skew ? -1 : -15;
212 update_read_count((total_number_of_muls - pc) + point_idx, slice);
213 }
214 }
215 }
216 }
217 }
218 }
219
220 // The execution trace data for the MSM columns requires knowledge of intermediate values from *affine* point
221 // addition. The naive solution to compute this data requires 2 field inversions per in-circuit group addition
222 // evaluation. This is bad! To avoid this, we split the witness computation algorithm into 3 steps.
223 // Step 1: compute the execution trace group operations in *projective* coordinates. (these will be stored in
224 // `p1_trace`, `p2_trace`, and `p3_trace`)
225 // Step 2: use batch inversion trick to convert all points into affine coordinates
226 // Step 3: populate the full execution trace, including the intermediate values from affine group
227 // operations
228 // This section sets up the data structures we need to store all intermediate ECC operations in projective form
229
230 const size_t num_point_adds_and_doubles =
231 (num_msm_rows - 2) * 4; // `num_msm_rows - 2` is the actual number of rows in the table required to compute
232 // the MSM; the msm table itself has a dummy row at the beginning and an extra row
233 // with the `x` and `y` coordinates of the accumulator at the end. (In general, the
234 // output of the accumulator from the computation at row `i` is present on row
235 // `i+1`. We multiply by 4 because each "row" of the VM processes 4 point-additions
236 // (and the fact that w = 4 means we must interleave with 4 doublings). This
237 // "corresponds" to the fact that `MSMROW.add_state` has 4 entries.
238 const size_t num_accumulators = num_msm_rows - 1; // for every row after the first row, we have an accumulator.
239 // In what follows, either p1 + p2 = p3, or p1.dbl() = p3
240 // We create 1 vector to store the entire point trace. We split into multiple containers using std::span
241 // (we want 1 vector object to more efficiently batch-normalize points)
242 static constexpr size_t NUM_POINTS_IN_ADDITION_RELATION = 3;
243 const size_t num_points_to_normalize =
244 (num_point_adds_and_doubles * NUM_POINTS_IN_ADDITION_RELATION) + num_accumulators;
245 std::vector<Element> points_to_normalize(num_points_to_normalize);
246 std::span<Element> p1_trace(&points_to_normalize[0], num_point_adds_and_doubles);
247 std::span<Element> p2_trace(&points_to_normalize[num_point_adds_and_doubles], num_point_adds_and_doubles);
248 std::span<Element> p3_trace(&points_to_normalize[num_point_adds_and_doubles * 2], num_point_adds_and_doubles);
249 // `is_double_or_add` records whether an entry in the p1/p2/p3 trace represents a point addition or
250 // doubling. if it is `true`, then we are doubling (i.e., the condition is that `p3 = p1.dbl()`), else we are
251 // adding (i.e., the condition is that `p3 = p1 + p2`).
252 std::vector<bool> is_double_or_add(num_point_adds_and_doubles);
253 // accumulator_trace tracks the value of the ECCVM accumulator for each row
254 std::span<Element> accumulator_trace(&points_to_normalize[num_point_adds_and_doubles * 3], num_accumulators);
255
256 // we start the accumulator at the offset generator point
257 constexpr auto offset_generator = get_precomputed_generators<g1, "ECCVM_OFFSET_GENERATOR", 1>()[0];
258 accumulator_trace[0] = offset_generator;
259
260 // TODO(https://github.com/AztecProtocol/barretenberg/issues/973): Reinstate multitreading?
261 // populate point trace, and the components of the MSM execution trace that do not relate to affine point
262 // operations
263 for (size_t msm_idx = 0; msm_idx < msms.size(); msm_idx++) {
264 Element accumulator = offset_generator; // for every MSM, we start with the same `offset_generator`
265 const auto& msm = msms[msm_idx]; // which MSM we are processing. This is of type `std::vector<ScalarMul>`.
266 size_t msm_row_index = msm_row_counts[msm_idx]; // the row where the given MSM starts
267 const size_t msm_size = msm.size();
268 const size_t num_rows_per_digit =
269 (msm_size / ADDITIONS_PER_ROW) +
270 (msm_size % ADDITIONS_PER_ROW !=
271 0); // the Straus algorithm proceeds by incrementing through the digit-slots and doing
272 // computations *across* the `ScalarMul`s that make up our MSM. Each digit-slot therefore
273 // contributes the *ceiling* of `msm_size`/`ADDITIONS_PER_ROW`.
274 size_t trace_index =
275 (msm_row_counts[msm_idx] - 1) * 4; // tracks the index in the traces of `p1`, `p2`, `p3`, and
276 // `accumulator_trace` that we are filling out
277
278 // for each digit-slot (`digit_idx`), and then for each row of the VM (which does `ADDITIONS_PER_ROW` point
279 // additions), we either enter in/process (`ADDITIONS_PER_ROW`) `AddState` objects, and then if necessary
280 // (i.e., if not at the last wNAF digit), process the four doublings.
281 for (size_t digit_idx = 0; digit_idx < NUM_WNAF_DIGITS_PER_SCALAR; ++digit_idx) {
282 const auto pc = static_cast<uint32_t>(pc_values[msm_idx]); // pc that our msm starts at
283
284 for (size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) {
285 const size_t num_points_in_row = (row_idx + 1) * ADDITIONS_PER_ROW > msm_size
286 ? (msm_size % ADDITIONS_PER_ROW)
288 auto& row = msm_rows[msm_row_index]; // actual `MSMRow` we will fill out in the body of this loop
289 const size_t offset = row_idx * ADDITIONS_PER_ROW;
290 row.msm_transition = (digit_idx == 0) && (row_idx == 0);
291 // each iteration of this loop process/enters in one of the `AddState` objects in `row.add_state`.
292 for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) {
293 auto& add_state = row.add_state[point_idx];
294 add_state.add = num_points_in_row > point_idx;
295 int slice = add_state.add ? msm[offset + point_idx].wnaf_digits[digit_idx] : 0;
296 // In the MSM columns in the ECCVM circuit, we can add up to 4 points per row.
297 // if `row.add_state[point_idx].add = 1`, this indicates that we want to add the
298 // `point_idx`'th point in the MSM columns into the MSM accumulator `add_state.slice` = A
299 // 4-bit WNAF slice of the scalar multiplier associated with the point we are adding (the
300 // specific slice chosen depends on the value of msm_round) (WNAF = our version of
301 // windowed-non-adjacent-form. Value range is `-15, -13,..., 15`)
302 // If `add_state.add = 1`, we want `add_state.slice` to be the *compressed*
303 // form of the WNAF slice value. (compressed = no gaps in the value range. i.e. -15,
304 // -13, ..., 15 maps to 0, ... , 15)
305 add_state.slice = add_state.add ? (slice + 15) / 2 : 0;
306 add_state.point =
307 add_state.add
308 ? msm[offset + point_idx].precomputed_table[static_cast<size_t>(add_state.slice)]
309 : AffineElement{ 0, 0 };
310
311 Element p1(accumulator);
312 Element p2(add_state.point);
313 accumulator = add_state.add ? (accumulator + add_state.point) : Element(p1);
314 p1_trace[trace_index] = p1;
315 p2_trace[trace_index] = p2;
316 p3_trace[trace_index] = accumulator;
317 is_double_or_add[trace_index] = false;
318 trace_index++;
319 }
320 // Now, `row.add_state` has been fully processed and we fill in the rest of the members of `row`.
321 accumulator_trace[msm_row_index] = accumulator;
322 row.q_add = true;
323 row.q_double = false;
324 row.q_skew = false;
325 row.msm_round = static_cast<uint32_t>(digit_idx);
326 row.msm_size = static_cast<uint32_t>(msm_size);
327 row.msm_count = static_cast<uint32_t>(offset);
328 row.pc = pc;
329 msm_row_index++;
330 }
331 // after processing each digit-slot, we now take care of doubling (as long as we are not at the last
332 // digit). We add an `MSMRow`, `row`, whose four `AddState` objects in `row.add_state`
333 // are null, but we also populate `p1_trace`, `p2_trace`, `p3_trace`, and `is_double_or_add` for four
334 // indices, corresponding to the w=4 doubling operations we need to perform. This embodies the numerical
335 // "coincidence" that `ADDITIONS_PER_ROW == NUM_WNAF_DIGIT_BITS`
336 if (digit_idx < NUM_WNAF_DIGITS_PER_SCALAR - 1) {
337 auto& row = msm_rows[msm_row_index];
338 row.msm_transition = false;
339 row.msm_round = static_cast<uint32_t>(digit_idx + 1);
340 row.msm_size = static_cast<uint32_t>(msm_size);
341 row.msm_count = static_cast<uint32_t>(0);
342 row.q_add = false;
343 row.q_double = true;
344 row.q_skew = false;
345 for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) {
346 auto& add_state = row.add_state[point_idx];
347 add_state.add = false;
348 add_state.slice = 0;
349 add_state.point = { 0, 0 };
350 add_state.collision_inverse = 0;
351
352 p1_trace[trace_index] = accumulator;
353 p2_trace[trace_index] = accumulator; // dummy
354 accumulator = accumulator.dbl();
355 p3_trace[trace_index] = accumulator;
356 is_double_or_add[trace_index] = true;
357 trace_index++;
358 }
359 accumulator_trace[msm_row_index] = accumulator;
360 msm_row_index++;
361 } else // process `wnaf_skew`, i.e., the skew digit.
362 {
363 for (size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) {
364 auto& row = msm_rows[msm_row_index];
365
366 const size_t num_points_in_row = (row_idx + 1) * ADDITIONS_PER_ROW > msm_size
367 ? msm_size % ADDITIONS_PER_ROW
369 const size_t offset = row_idx * ADDITIONS_PER_ROW;
370 row.msm_transition = false;
371 Element acc_expected = accumulator;
372 for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) {
373 auto& add_state = row.add_state[point_idx];
374 add_state.add = num_points_in_row > point_idx;
375 add_state.slice = add_state.add ? msm[offset + point_idx].wnaf_skew ? 7 : 0 : 0;
376
377 add_state.point =
378 add_state.add
379 ? msm[offset + point_idx].precomputed_table[static_cast<size_t>(add_state.slice)]
381 0, 0
382 }; // if the skew_bit is on, `slice == 7`. Then `precomputed_table[7] == -[P]`, as
383 // required for the skew logic.
384 bool add_predicate = add_state.add ? msm[offset + point_idx].wnaf_skew : false;
385 auto p1 = accumulator;
386 accumulator = add_predicate ? accumulator + add_state.point : accumulator;
387 p1_trace[trace_index] = p1;
388 p2_trace[trace_index] = add_state.point;
389 p3_trace[trace_index] = accumulator;
390 is_double_or_add[trace_index] = false;
391 trace_index++;
392 }
393 row.q_add = false;
394 row.q_double = false;
395 row.q_skew = true;
396 row.msm_round = static_cast<uint32_t>(digit_idx + 1);
397 row.msm_size = static_cast<uint32_t>(msm_size);
398 row.msm_count = static_cast<uint32_t>(offset);
399 row.pc = pc;
400 accumulator_trace[msm_row_index] = accumulator;
401 msm_row_index++;
402 }
403 }
404 }
405 }
406
407 // Normalize the points in the point trace
408 parallel_for_range(points_to_normalize.size(), [&](size_t start, size_t end) {
409 Element::batch_normalize(&points_to_normalize[start], end - start);
410 });
411
412 // inverse_trace is used to compute the value of the `collision_inverse` column in the ECCVM.
413 std::vector<FF> inverse_trace(num_point_adds_and_doubles);
414 parallel_for_range(num_point_adds_and_doubles, [&](size_t start, size_t end) {
415 for (size_t operation_idx = start; operation_idx < end; ++operation_idx) {
416 if (is_double_or_add[operation_idx]) {
417 inverse_trace[operation_idx] = (p1_trace[operation_idx].y + p1_trace[operation_idx].y);
418 } else {
419 inverse_trace[operation_idx] = (p2_trace[operation_idx].x - p1_trace[operation_idx].x);
420 }
421 }
422 FF::batch_invert(&inverse_trace[start], end - start);
423 });
424
425 // complete the computation of the ECCVM execution trace, by adding the affine intermediate point data
426 // i.e. row.accumulator_x, row.accumulator_y, row.add_state[0...3].collision_inverse,
427 // row.add_state[0...3].lambda
428 for (size_t msm_idx = 0; msm_idx < msms.size(); msm_idx++) {
429 const auto& msm = msms[msm_idx];
430 size_t trace_index = ((msm_row_counts[msm_idx] - 1) * ADDITIONS_PER_ROW);
431 size_t msm_row_index = msm_row_counts[msm_idx];
432 // 1st MSM row will have accumulator equal to the previous MSM output (or point at infinity for first MSM)
433 size_t accumulator_index = msm_row_counts[msm_idx] - 1;
434 const size_t msm_size = msm.size();
435 const size_t num_rows_per_digit =
436 (msm_size / ADDITIONS_PER_ROW) + ((msm_size % ADDITIONS_PER_ROW != 0) ? 1 : 0);
437
438 for (size_t digit_idx = 0; digit_idx < NUM_WNAF_DIGITS_PER_SCALAR; ++digit_idx) {
439 for (size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) {
440 auto& row = msm_rows[msm_row_index];
441 // note that we do not store the "intermediate accumulators" that are implicit *within* a row (i.e.,
442 // within a given `add_state` object). This is the reason why accumulator_index only increments once
443 // per `row_idx`.
444 const Element& normalized_accumulator = accumulator_trace[accumulator_index];
445 BB_ASSERT_EQ(normalized_accumulator.is_point_at_infinity(), 0);
446 row.accumulator_x = normalized_accumulator.x;
447 row.accumulator_y = normalized_accumulator.y;
448 for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) {
449 auto& add_state = row.add_state[point_idx];
450
451 const auto& inverse = inverse_trace[trace_index];
452 const auto& p1 = p1_trace[trace_index];
453 const auto& p2 = p2_trace[trace_index];
454 add_state.collision_inverse = add_state.add ? inverse : 0;
455 add_state.lambda = add_state.add ? (p2.y - p1.y) * inverse : 0;
456 trace_index++;
457 }
458 accumulator_index++;
459 msm_row_index++;
460 }
461
462 // if digit_idx < NUM_WNAF_DIGITS_PER_SCALAR - 1 we have to fill out our doubling row (which in fact
463 // amounts to 4 doublings)
464 if (digit_idx < NUM_WNAF_DIGITS_PER_SCALAR - 1) {
465 MSMRow& row = msm_rows[msm_row_index];
466 const Element& normalized_accumulator = accumulator_trace[accumulator_index];
467 const FF& acc_x = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.x;
468 const FF& acc_y = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.y;
469 row.accumulator_x = acc_x;
470 row.accumulator_y = acc_y;
471 for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) {
472 auto& add_state = row.add_state[point_idx];
473 add_state.collision_inverse = 0; // no notion of "different x values" for a point doubling
474 const FF& dx = p1_trace[trace_index].x;
475 const FF& inverse = inverse_trace[trace_index]; // here, 2y
476 add_state.lambda = ((dx + dx + dx) * dx) * inverse;
477 trace_index++;
478 }
479 accumulator_index++;
480 msm_row_index++;
481 } else // this row corresponds to performing point additions to handle WNAF skew
482 // i.e. iterate over all the points in the MSM - if for a given point, `wnaf_skew == 1`,
483 // subtract the original point from the accumulator. if `digit_idx == NUM_WNAF_DIGITS_PER_SCALAR
484 // - 1` we have finished executing our double-and-add algorithm.
485 {
486 for (size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) {
487 MSMRow& row = msm_rows[msm_row_index];
488 const Element& normalized_accumulator = accumulator_trace[accumulator_index];
489 BB_ASSERT_EQ(normalized_accumulator.is_point_at_infinity(), 0);
490 const size_t offset = row_idx * ADDITIONS_PER_ROW;
491 row.accumulator_x = normalized_accumulator.x;
492 row.accumulator_y = normalized_accumulator.y;
493 for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) {
494 auto& add_state = row.add_state[point_idx];
495 bool add_predicate = add_state.add ? msm[offset + point_idx].wnaf_skew : false;
496
497 const auto& inverse = inverse_trace[trace_index];
498 const auto& p1 = p1_trace[trace_index];
499 const auto& p2 = p2_trace[trace_index];
500 add_state.collision_inverse = add_predicate ? inverse : 0;
501 add_state.lambda = add_predicate ? (p2.y - p1.y) * inverse : 0;
502 trace_index++;
503 }
504 accumulator_index++;
505 msm_row_index++;
506 }
507 }
508 }
509 }
510
511 // populate the final row in the MSM execution trace.
512 // we always require 1 extra row at the end of the trace, because the x and y coordinates of the accumulator for
513 // row `i` are present at row `i+1`
514 Element final_accumulator(accumulator_trace.back());
515 MSMRow& final_row = msm_rows.back();
516 final_row.pc = static_cast<uint32_t>(pc_values.back());
517 final_row.msm_transition = true;
518 final_row.accumulator_x = final_accumulator.is_point_at_infinity() ? 0 : final_accumulator.x;
519 final_row.accumulator_y = final_accumulator.is_point_at_infinity() ? 0 : final_accumulator.y;
520 final_row.msm_size = 0;
521 final_row.msm_count = 0;
522 final_row.q_add = false;
523 final_row.q_double = false;
524 final_row.q_skew = false;
525 final_row.add_state = { typename MSMRow::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 },
526 typename MSMRow::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 },
527 typename MSMRow::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 },
528 typename MSMRow::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 } };
529
530 return { msm_rows, point_table_read_counts };
531 }
532};
533} // namespace bb
#define BB_ASSERT_EQ(actual, expected,...)
Definition assert.hpp:59
static constexpr size_t ADDITIONS_PER_ROW
static constexpr size_t NUM_WNAF_DIGITS_PER_SCALAR
static std::tuple< std::vector< MSMRow >, std::array< std::vector< size_t >, 2 > > compute_rows(const std::vector< MSM > &msms, const uint32_t total_number_of_muls, const size_t num_msm_rows)
Computes the row values for the Straus MSM columns of the ECCVM.
curve::BN254::Group CycleGroup
typename CycleGroup::affine_element AffineElement
bb::eccvm::MSM< CycleGroup > MSM
typename CycleGroup::element Element
static uint32_t num_eccvm_msm_rows(const size_t msm_size)
Get the number of rows in the 'msm' column section of the ECCVM associated with a single multiscalar ...
typename bb::g1 Group
Definition bn254.hpp:20
ssize_t offset
Definition engine.cpp:36
std::vector< ScalarMul< CycleGroup > > MSM
Entry point for Barretenberg command-line interface.
group< fq, fr, Bn254G1Params > g1
Definition g1.hpp:33
C slice(C const &container, size_t start)
Definition container.hpp:9
constexpr std::span< const typename Group::affine_element > get_precomputed_generators()
void parallel_for_range(size_t num_points, const std::function< void(size_t, size_t)> &func, size_t no_multhreading_if_less_or_equal)
Split a loop into several loops running in parallel.
Definition thread.cpp:102
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
std::array< AddState, 4 > add_state
static void batch_invert(std::span< field > coeffs) noexcept