Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
ecc_msm_relation_impl.hpp
Go to the documentation of this file.
1// === AUDIT STATUS ===
2// internal: { status: not started, auditors: [], date: YYYY-MM-DD }
3// external_1: { status: not started, auditors: [], date: YYYY-MM-DD }
4// external_2: { status: not started, auditors: [], date: YYYY-MM-DD }
5// =====================
6
7#pragma once
10#include "ecc_msm_relation.hpp"
11
12namespace bb {
13
/**
 * @brief Accumulate the ECCVM MSM subrelations (evaluation of the Strauss multiscalar multiplication algorithm)
 *        for one row into `accumulator`.
 *
 * @details Each MSM row is an ADD row (`q_add` = msm_add), a SKEW row (`q_skew` = msm_skew) or a DOUBLE row
 * (`q_double` = msm_double). A row processes up to four points (x1..x4, y1..y4) using the per-step slopes
 * lambda1..lambda4 and the collision inverses msm_collision_x1..x4. Values read from `*_shift` entities are the
 * NEXT row's values.
 *
 * Subrelation indices written by this function:
 *  - 0, 1, 2    : ADD rows. The next accumulator equals the result of the chained additions, plus the slope checks.
 *  - 3, 4, 5    : SKEW rows. Same structure; point i is added only when slice_i == 7.
 *  - 6, 7, 8, 9 : x-coordinate collision checks (x_delta * inverse == 1 whenever a point is added, else 0).
 *  - 10, 11, 12 : DOUBLE rows. The next accumulator equals four chained doublings, plus the slope checks.
 *  - 13..16     : slice_i == 0 whenever add_i == 0.
 *  - 17         : pairwise products of q_add / q_skew / q_double vanish.
 *  - 18..23     : round-transition rules (round increments by 0 or 1; which row type may follow).
 *  - 24..27     : count / pc bookkeeping within rounds and across MSM transitions.
 *  - 28..31     : addition continuity (no skipped points within a row or across consecutive rows).
 *  - 32         : add1 == q_add + q_skew.
 *
 * @param accumulator    tuple of per-subrelation accumulators; indices 0..32 are written
 * @param in             the row (and shifted next-row) values of all entities
 * @param scaling_factor multiplier applied to every subrelation contribution (presumably the sumcheck
 *                       gate-separator factor — TODO confirm)
 */
template <typename FF>
template <typename ContainerOverSubrelations, typename AllEntities, typename Parameters>
void ECCVMMSMRelationImpl<FF>::accumulate(ContainerOverSubrelations& accumulator,
                                          const AllEntities& in,
                                          const Parameters& /*unused*/,
                                          const FF& scaling_factor)
{
    using Accumulator = typename std::tuple_element_t<0, ContainerOverSubrelations>;
    using View = typename Accumulator::View;

    // Current-row MSM columns; `_shift` views hold the corresponding NEXT-row values.
    const auto& x1 = View(in.msm_x1);
    const auto& y1 = View(in.msm_y1);
    const auto& x2 = View(in.msm_x2);
    const auto& y2 = View(in.msm_y2);
    const auto& x3 = View(in.msm_x3);
    const auto& y3 = View(in.msm_y3);
    const auto& x4 = View(in.msm_x4);
    const auto& y4 = View(in.msm_y4);
    const auto& collision_inverse1 = View(in.msm_collision_x1);
    const auto& collision_inverse2 = View(in.msm_collision_x2);
    const auto& collision_inverse3 = View(in.msm_collision_x3);
    const auto& collision_inverse4 = View(in.msm_collision_x4);
    const auto& lambda1 = View(in.msm_lambda1);
    const auto& lambda2 = View(in.msm_lambda2);
    const auto& lambda3 = View(in.msm_lambda3);
    const auto& lambda4 = View(in.msm_lambda4);
    const auto& lagrange_first = View(in.lagrange_first);
    const auto& add1 = View(in.msm_add1);
    const auto& add1_shift = View(in.msm_add1_shift);
    const auto& add2 = View(in.msm_add2);
    const auto& add3 = View(in.msm_add3);
    const auto& add4 = View(in.msm_add4);
    const auto& acc_x = View(in.msm_accumulator_x);
    const auto& acc_y = View(in.msm_accumulator_y);
    const auto& acc_x_shift = View(in.msm_accumulator_x_shift);
    const auto& acc_y_shift = View(in.msm_accumulator_y_shift);
    const auto& slice1 = View(in.msm_slice1);
    const auto& slice2 = View(in.msm_slice2);
    const auto& slice3 = View(in.msm_slice3);
    const auto& slice4 = View(in.msm_slice4);
    const auto& msm_transition = View(in.msm_transition);
    const auto& msm_transition_shift = View(in.msm_transition_shift);
    const auto& round = View(in.msm_round);
    const auto& round_shift = View(in.msm_round_shift);
    const auto& q_add = View(in.msm_add);
    const auto& q_add_shift = View(in.msm_add_shift);
    const auto& q_skew = View(in.msm_skew);
    const auto& q_skew_shift = View(in.msm_skew_shift);
    const auto& q_double = View(in.msm_double);
    const auto& q_double_shift = View(in.msm_double_shift);
    const auto& msm_size = View(in.msm_size_of_msm);
    // const auto& msm_size_shift = View(in.msm_size_of_msm_shift);
    const auto& pc = View(in.msm_pc);
    const auto& pc_shift = View(in.msm_pc_shift);
    const auto& count = View(in.msm_count);
    const auto& count_shift = View(in.msm_count_shift);
    // 1 - lagrange_first: 0 on the first row, 1 on every other row.
    auto is_not_first_row = (-lagrange_first + 1);

    /**
     * Conditionally add [B] = (xb, yb) to [A] = (xa, ya) using slope `lambda`.
     *  - selector = 1: returns [A] + [B] = (L^2 - xb - xa, L * (xa - x_out) - ya), and contributes
     *                  lambda * (xb - xa) - (yb - ya) to `relation`.
     *  - selector = 0: returns [A] unchanged and contributes lambda to `relation` (intended to force lambda = 0).
     * `collision_relation` receives selector * (xb - xa), the x-difference that must be invertible when the
     * addition is performed. The selector is assumed to be boolean; this lambda does not enforce it.
     */
    auto add = [&](auto& xb,
                   auto& yb,
                   auto& xa,
                   auto& ya,
                   auto& lambda,
                   auto& selector,
                   auto& relation,
                   auto& collision_relation) {
        // L * (1 - s) = 0
        // (combine) (L * (xb - xa - 1) - (yb - ya)) * s + L = 0
        relation += selector * (lambda * (xb - xa - 1) - (yb - ya)) + lambda;
        collision_relation += selector * (xb - xa);
        // x3 = L.L + (-xb - xa) * q + (1 - q) xa
        auto x_out = lambda.sqr() + (-xb - xa - xa) * selector + xa;

        // y3 = L . (xa - x3) - ya * q + (1 - q) ya
        auto y_out = lambda * (xa - x_out) + (-ya - ya) * selector + ya;
        return std::array<Accumulator, 2>{ x_out, y_out };
    };

    /**
     * First addition of an ADD row: computes [X] + [A], where [A] = (xa, ya) and [X] = (xb, yb) when selector = 0,
     * or [X] = the ECCVM offset generator when selector = 1 (the call site passes msm_transition as the selector).
     * Unlike `add`, the addition is always performed (there is no pass-through case). So `relation` always receives
     * lambda * (x - xa) - (y - ya), and `collision_relation` always receives xa - x.
     */
    auto first_add = [&](auto& xb,
                         auto& yb,
                         auto& xa,
                         auto& ya,
                         auto& lambda,
                         auto& selector,
                         auto& relation,
                         auto& collision_relation) {
        // N.B. this is brittle - should be curve agnostic but we don't propagate the curve parameter into relations!
        constexpr auto offset_generator = get_precomputed_generators<g1, "ECCVM_OFFSET_GENERATOR", 1>()[0];
        constexpr uint256_t oxu = offset_generator.x;
        constexpr uint256_t oyu = offset_generator.y;
        const Accumulator xo(oxu);
        const Accumulator yo(oyu);

        // Select the offset generator (selector = 1) or the incoming accumulator (selector = 0).
        auto x = xo * selector + xb * (-selector + 1);
        auto y = yo * selector + yb * (-selector + 1);
        relation += lambda * (x - xa) - (y - ya); // degree 3
        collision_relation += (xa - x);
        auto x_out = lambda * lambda + (-x - xa);
        auto y_out = lambda * (xa - x_out) - ya;
        return std::array<Accumulator, 2>{ x_out, y_out };
    };

    // ADD operations (if row represents ADD round, not SKEW or DOUBLE)
    // NOTE(review): the slope conditions of all four additions are summed into the single `add_relation`, which is
    // checked by subrelation 2. A vanishing sum does not by itself force each slope term to vanish, and the lambdas
    // are prover-supplied columns. Confirm that each per-addition check is enforced elsewhere, or that this batching
    // is sound.
    Accumulator add_relation(0);
    Accumulator x1_collision_relation(0);
    Accumulator x2_collision_relation(0);
    Accumulator x3_collision_relation(0);
    Accumulator x4_collision_relation(0);
    // If msm_transition = 1, we have started a new MSM. We need to treat the current value of [Acc] as the point at
    // infinity!
    auto [x_t1, y_t1] = first_add(acc_x, acc_y, x1, y1, lambda1, msm_transition, add_relation, x1_collision_relation);
    auto [x_t2, y_t2] = add(x2, y2, x_t1, y_t1, lambda2, add2, add_relation, x2_collision_relation);
    auto [x_t3, y_t3] = add(x3, y3, x_t2, y_t2, lambda3, add3, add_relation, x3_collision_relation);
    auto [x_t4, y_t4] = add(x4, y4, x_t3, y_t3, lambda4, add4, add_relation, x4_collision_relation);

    // Validate accumulator output matches ADD output if q_add = 1
    // (this is a degree-6 relation)
    std::get<0>(accumulator) += q_add * (acc_x_shift - x_t4) * scaling_factor;
    std::get<1>(accumulator) += q_add * (acc_y_shift - y_t4) * scaling_factor;
    std::get<2>(accumulator) += q_add * add_relation * scaling_factor;

    /**
     * Double (x, y) using slope `lambda`. Contributes lambda * 2y - 3x^2 to `relation`, which is the tangent-slope
     * condition for a short Weierstrass curve with a = 0. Returns (lambda^2 - 2x, lambda * (x - x_out) - y).
     */
    auto dbl = [&](auto& x, auto& y, auto& lambda, auto& relation) {
        auto two_x = x + x;
        relation += lambda * (y + y) - (two_x + x) * x;
        auto x_out = lambda.sqr() - two_x;
        auto y_out = lambda * (x - x_out) - y;
        return std::array<Accumulator, 2>{ x_out, y_out };
    };

    // DOUBLE rows: the next accumulator must equal four successive doublings of the current accumulator.
    // NOTE(review): as with `add_relation`, the four doubling slope conditions are summed into one subrelation (12).
    // Confirm that this batching is sound.
    Accumulator double_relation(0);
    auto [x_d1, y_d1] = dbl(acc_x, acc_y, lambda1, double_relation);
    auto [x_d2, y_d2] = dbl(x_d1, y_d1, lambda2, double_relation);
    auto [x_d3, y_d3] = dbl(x_d2, y_d2, lambda3, double_relation);
    auto [x_d4, y_d4] = dbl(x_d3, y_d3, lambda4, double_relation);
    std::get<10>(accumulator) += q_double * (acc_x_shift - x_d4) * scaling_factor;
    std::get<11>(accumulator) += q_double * (acc_y_shift - y_d4) * scaling_factor;
    std::get<12>(accumulator) += q_double * double_relation * scaling_factor;

    // SKEW rows: the add-point selector for point i is slice_i / 7, which equals 1 exactly when slice_i = 7.
    // This assumes slice_i ∈ {0, 7} on SKEW rows. That is not enforced by this relation; presumably it is
    // enforced by the wnaf lookups — TODO confirm.
    // NOTE(review): the four skew slope conditions are likewise summed into one subrelation (5). Confirm that this
    // batching is sound.
    Accumulator skew_relation(0);
    // Function-local static: computed once per FF instantiation. NOTE(review): this is never mutated, so it could
    // be declared `static const`.
    static FF inverse_seven = FF(7).invert();
    auto skew1_select = slice1 * inverse_seven;
    auto skew2_select = slice2 * inverse_seven;
    auto skew3_select = slice3 * inverse_seven;
    auto skew4_select = slice4 * inverse_seven;
    Accumulator x1_skew_collision_relation(0);
    Accumulator x2_skew_collision_relation(0);
    Accumulator x3_skew_collision_relation(0);
    Accumulator x4_skew_collision_relation(0);
    // add skew points iff row is a SKEW row AND slice = 7 (point_table[7] maps to -[P])
    // N.B. while it would be nice to have one `add` relation for both ADD and SKEW rounds,
    // this would increase degree of sumcheck identity vs evaluating them separately.
    // This is because, for add rounds, the result of adding [P1], [Acc] is [P1 + Acc] or [P1]
    // but for skew rounds, the result of adding [P1], [Acc] is [P1 + Acc] or [Acc]
    auto [x_s1, y_s1] = add(x1, y1, acc_x, acc_y, lambda1, skew1_select, skew_relation, x1_skew_collision_relation);
    auto [x_s2, y_s2] = add(x2, y2, x_s1, y_s1, lambda2, skew2_select, skew_relation, x2_skew_collision_relation);
    auto [x_s3, y_s3] = add(x3, y3, x_s2, y_s2, lambda3, skew3_select, skew_relation, x3_skew_collision_relation);
    auto [x_s4, y_s4] = add(x4, y4, x_s3, y_s3, lambda4, skew4_select, skew_relation, x4_skew_collision_relation);

    // Validate accumulator output matches SKEW output if q_skew = 1
    // (this is a degree-6 relation)
    std::get<3>(accumulator) += q_skew * (acc_x_shift - x_s4) * scaling_factor;
    std::get<4>(accumulator) += q_skew * (acc_y_shift - y_s4) * scaling_factor;
    std::get<5>(accumulator) += q_skew * skew_relation * scaling_factor;

    // Check x-coordinates do not collide if row is an ADD row or a SKEW row
    // if either q_add or q_skew = 1, an inverse should exist for each computed relation
    // Step 1: construct boolean selectors that describe whether we added a point at the current row
    const auto add_first_point = add1 * q_add + q_skew * skew1_select;
    const auto add_second_point = add2 * q_add + q_skew * skew2_select;
    const auto add_third_point = add3 * q_add + q_skew * skew3_select;
    const auto add_fourth_point = add4 * q_add + q_skew * skew4_select;
    // Step 2: construct the delta between x-coordinates for each point add (depending on if row is ADD or SKEW)
    const auto x1_delta = x1_skew_collision_relation * q_skew + x1_collision_relation * q_add;
    const auto x2_delta = x2_skew_collision_relation * q_skew + x2_collision_relation * q_add;
    const auto x3_delta = x3_skew_collision_relation * q_skew + x3_collision_relation * q_add;
    const auto x4_delta = x4_skew_collision_relation * q_skew + x4_collision_relation * q_add;
    // Step 3: x_delta * inverse - 1 = 0 if we performed a point addition (else x_delta * inverse = 0)
    std::get<6>(accumulator) += (x1_delta * collision_inverse1 - add_first_point) * scaling_factor;
    std::get<7>(accumulator) += (x2_delta * collision_inverse2 - add_second_point) * scaling_factor;
    std::get<8>(accumulator) += (x3_delta * collision_inverse3 - add_third_point) * scaling_factor;
    std::get<9>(accumulator) += (x4_delta * collision_inverse4 - add_fourth_point) * scaling_factor;

    // Validate that if q_add = 1 or q_skew = 1, add1 also is 1
    // TODO(@zac-williamson) Once we have a stable base to work off of, remove q_add1 and replace with q_msm_add +
    // q_msm_skew (issue #2222)
    std::get<32>(accumulator) += (add1 - q_add - q_skew) * scaling_factor;

    // If add_i = 0, slice_i = 0
    // When add_i = 0, force slice_i to ALSO be 0
    std::get<13>(accumulator) += (-add1 + 1) * slice1 * scaling_factor;
    std::get<14>(accumulator) += (-add2 + 1) * slice2 * scaling_factor;
    std::get<15>(accumulator) += (-add3 + 1) * slice3 * scaling_factor;
    std::get<16>(accumulator) += (-add4 + 1) * slice4 * scaling_factor;

    // only one of q_skew, q_double, q_add can be nonzero
    // (enforced via vanishing pairwise products; assumes the selectors are boolean)
    std::get<17>(accumulator) += (q_add * q_double + q_add * q_skew + q_double * q_skew) * scaling_factor;

    // We look up wnaf slices by mapping round + pc -> slice
    // We use an exact set membership check to validate that
    // wnafs written in wnaf_relation == wnafs read in msm relation
    // We use `add1/add2/add3/add4` to flag whether we are performing a wnaf read op
    // We can set these to be Prover-defined as the set membership check implicitly ensures that the correct reads
    // have occurred.
    // if msm_transition = 0, round_shift - round = 0 or 1
    const auto round_delta = round_shift - round;

    // ROUND TRANSITION LOGIC (when round does not change)
    // If msm_transition = 0 (next row) then round_delta = 0 or 1
    const auto round_transition = round_delta * (-msm_transition_shift + 1);
    std::get<18>(accumulator) += round_transition * (round_delta - 1) * scaling_factor;

    // ROUND TRANSITION LOGIC (when round DOES change)
    // round_transition describes whether we are transitioning between rounds of an MSM.
    // If round_transition = 1, the next row is either a DOUBLE (if round != 31) or a SKEW (if round == 31):
    //   round_transition * q_skew_shift * (round - 31) = 0          (moving into a SKEW row requires round == 31)
    //   round_transition * (q_skew_shift + q_double_shift - 1) = 0  (a round transition requires skew XOR double)
    // i.e. if round_transition = 1 and round != 31, q_double_shift = 1
    std::get<19>(accumulator) += round_transition * q_skew_shift * (round - 31) * scaling_factor;
    std::get<20>(accumulator) += round_transition * (q_skew_shift + q_double_shift - 1) * scaling_factor;

    // if the round advances (round_transition = 1), the next row cannot be neither a DOUBLE nor a SKEW row
    std::get<21>(accumulator) += round_transition * (-q_double_shift + 1) * (-q_skew_shift + 1) * scaling_factor;

    // a DOUBLE row cannot be followed by another DOUBLE row (if double, next double != 1)
    std::get<22>(accumulator) += q_double * q_double_shift * scaling_factor;

    // a DOUBLE row must be followed by an ADD row (if double, next add = 1)
    std::get<23>(accumulator) += q_double * (-q_add_shift + 1) * scaling_factor;

    // updating count
    // if msm_transition_shift = 0 and round_delta = 0, count_shift = count + add1 + add2 + add3 + add4
    // todo: we need this?
    std::get<24>(accumulator) += (-msm_transition_shift + 1) * (-round_delta + 1) *
                                 (count_shift - count - add1 - add2 - add3 - add4) * scaling_factor;

    // if the round advances within the same MSM (msm_transition_shift = 0, round_delta = 1), count_shift = 0
    // (not applied on the first row)
    std::get<25>(accumulator) +=
        is_not_first_row * (-msm_transition_shift + 1) * round_delta * count_shift * scaling_factor;

    // if msm_transition_shift = 1, count_shift = 0 (not applied on the first row)
    std::get<26>(accumulator) += is_not_first_row * msm_transition_shift * count_shift * scaling_factor;

    // if msm_transition_shift = 1, pc = pc_shift + msm_size (not applied on the first row)
    // `ecc_set_relation` ensures `msm_size` maps to `transcript.msm_count` for the current value of `pc`
    std::get<27>(accumulator) += is_not_first_row * msm_transition_shift * (msm_size + pc_shift - pc) * scaling_factor;

    // Addition continuity checks
    // We want to RULE OUT the following scenarios:
    // Case 1: add2 = 1, add1 = 0
    // Case 2: add3 = 1, add2 = 0
    // Case 3: add4 = 1, add3 = 0
    // These checks ensure that the current row does not skip points (for both ADD and SKEW ops)
    // This is part of a wider set of checks we use to ensure that all point data is used in the assigned
    // multiscalar multiplication operation.
    // (and not in a different MSM operation)
    std::get<28>(accumulator) += add2 * (-add1 + 1) * scaling_factor;
    std::get<29>(accumulator) += add3 * (-add2 + 1) * scaling_factor;
    std::get<30>(accumulator) += add4 * (-add3 + 1) * scaling_factor;

    // Final continuity check.
    // If an addition spans two rows, we need to make sure that the following scenario is RULED OUT:
    // add4 = 0 on the CURRENT row, add1 = 1 on the NEXT row
    // We must apply the above for the two cases:
    // Case 1: q_add = 1 on the CURRENT row, q_add = 1 on the NEXT row
    // Case 2: q_skew = 1 on the CURRENT row, q_skew = 1 on the NEXT row
    // (i.e. if q_skew = 1, q_add_shift = 1 this implies an MSM transition so we skip this continuity check)
    std::get<31>(accumulator) +=
        (q_add * q_add_shift + q_skew * q_skew_shift) * (-add4 + 1) * add1_shift * scaling_factor;

    // remaining checks (done in ecc_set_relation.hpp, ecc_lookup_relation.hpp)
    // when transition occurs, perform set membership lookup on (accumulator / pc / msm_size)
    // perform set membership lookups on add_i * (pc / round / slice_i)
    // perform lookups on (pc / slice_i / x / y)
}
436
437} // namespace bb
static void accumulate(ContainerOverSubrelations &accumulator, const AllEntities &in, const Parameters &, const FF &scaling_factor)
MSM relations that evaluate the Strauss multiscalar multiplication algorithm.
Entry point for Barretenberg command-line interface.
group< fq, fr, Bn254G1Params > g1
Definition g1.hpp:33
typename Flavor::FF FF
constexpr std::span< const typename Group::affine_element > get_precomputed_generators()
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
constexpr field invert() const noexcept