Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
batched_affine_addition.cpp
Go to the documentation of this file.
1// === AUDIT STATUS ===
2// internal: { status: not started, auditors: [], date: YYYY-MM-DD }
3// external_1: { status: not started, auditors: [], date: YYYY-MM-DD }
4// external_2: { status: not started, auditors: [], date: YYYY-MM-DD }
5// =====================
6
11#include <algorithm>
12#include <execution>
13#include <set>
14
15namespace bb {
16
17template <typename Curve>
19 const std::span<G1>& points, const std::vector<size_t>& sequence_counts)
20{
21 PROFILE_THIS_NAME("BatchedAffineAddition::add_in_place");
22 // Instantiate scratch space for point addition denominators and their calculation
23 std::vector<Fq> scratch_space_vector(points.size());
24 std::span<Fq> scratch_space(scratch_space_vector);
25
26 // Divide the work into groups of addition sequences to be reduced by each thread
27 auto [addition_sequences_, sequence_tags] = construct_thread_data(points, sequence_counts, scratch_space);
28 auto& addition_sequences = addition_sequences_;
29
30 const size_t num_threads = addition_sequences.size();
31 parallel_for(num_threads, [&](size_t thread_idx) { batched_affine_add_in_place(addition_sequences[thread_idx]); });
32
33 // Construct a vector of the reduced points, accounting for sequences that may have been split across threads
34 std::vector<G1> reduced_points;
35 size_t prev_tag = std::numeric_limits<size_t>::max();
36 for (auto [sequences, tags] : zip_view(addition_sequences, sequence_tags)) {
37 // Extract the first num-sequence-counts many points from each add sequence
38 for (size_t i = 0; i < sequences.sequence_counts.size(); ++i) {
39 if (tags[i] == prev_tag) {
40 reduced_points.back() = reduced_points.back() + sequences.points[i];
41 } else {
42 reduced_points.emplace_back(sequences.points[i]);
43 }
44 prev_tag = tags[i];
45 }
46 }
47
48 return reduced_points;
49}
50
51template <typename Curve>
53 const std::span<G1>& points, const std::vector<size_t>& sequence_counts, const std::span<Fq>& scratch_space)
54{
55 // Compute the endpoints of the sequences within the points array from the sequence counts
56 std::vector<size_t> sequence_endpoints;
57 size_t total_count = 0;
58 for (const auto& count : sequence_counts) {
59 total_count += count;
60 sequence_endpoints.emplace_back(total_count);
61 }
62
63 if (points.size() != total_count) {
64 throw_or_abort("Number of input points does not match sequence counts!");
65 }
66
67 // Determine the optimal number of threads for parallelization
68 const size_t MIN_POINTS_PER_THREAD = 1 << 14; // heuristic; anecdotally optimal for practical cases
69 const size_t total_num_points = points.size();
70 const size_t optimal_threads = total_num_points / MIN_POINTS_PER_THREAD;
71 const size_t num_threads = std::max(1UL, std::min(get_num_cpus(), optimal_threads));
72 // Distribute the work as evenly as possible across threads
73 const size_t base_thread_size = total_num_points / num_threads;
74 const size_t leftover_size = total_num_points % num_threads;
75 std::vector<size_t> thread_sizes(num_threads, base_thread_size);
76 for (size_t i = 0; i < leftover_size; ++i) {
77 thread_sizes[i]++;
78 }
79
80 // Construct the point spans for each thread according to the distribution determined above
81 std::vector<std::span<G1>> thread_points;
82 std::vector<std::span<Fq>> thread_scratch_space;
83 std::vector<size_t> thread_endpoints;
84 size_t point_index = 0;
85 for (auto size : thread_sizes) {
86 thread_points.push_back(points.subspan(point_index, size));
87 thread_scratch_space.push_back(scratch_space.subspan(point_index, size));
88 point_index += size;
89 thread_endpoints.emplace_back(point_index);
90 }
91
92 // Construct the union of the thread and sequence endpoints by combining, sorting, then removing duplicates. This is
93 // used to break the points into sequences for each thread while tracking tags so that sequences split across one of
94 // more threads can be properly reconstructed.
95 std::vector<size_t> all_endpoints;
96 all_endpoints.reserve(thread_endpoints.size() + sequence_endpoints.size());
97 all_endpoints.insert(all_endpoints.end(), thread_endpoints.begin(), thread_endpoints.end());
98 all_endpoints.insert(all_endpoints.end(), sequence_endpoints.begin(), sequence_endpoints.end());
99 std::sort(all_endpoints.begin(), all_endpoints.end());
100 auto last = std::unique(all_endpoints.begin(), all_endpoints.end());
101 all_endpoints.erase(last, all_endpoints.end());
102
103 // Construct sequence counts and tags for each thread using the set of all thread and sequence endpoints
104 size_t prev_endpoint = 0;
105 size_t thread_idx = 0;
106 size_t sequence_idx = 0;
107 std::vector<std::vector<size_t>> thread_sequence_counts(num_threads);
108 std::vector<std::vector<size_t>> thread_sequence_tags(num_threads);
109 for (auto& endpoint : all_endpoints) {
110 size_t chunk_size = endpoint - prev_endpoint;
111 thread_sequence_counts[thread_idx].emplace_back(chunk_size);
112 thread_sequence_tags[thread_idx].emplace_back(sequence_idx);
113 if (endpoint == thread_endpoints[thread_idx]) {
114 thread_idx++;
115 }
116 if (endpoint == sequence_endpoints[sequence_idx]) {
117 sequence_idx++;
118 }
119 prev_endpoint = endpoint;
120 }
121
122 if (thread_sequence_counts.size() != thread_points.size()) {
123 throw_or_abort("Mismatch in sequence count construction!");
124 }
125
126 // Construct the addition sequences for each thread
127 std::vector<AdditionSequences> addition_sequences;
128 for (size_t i = 0; i < num_threads; ++i) {
129 addition_sequences.push_back(
130 AdditionSequences{ thread_sequence_counts[i], thread_points[i], thread_scratch_space[i] });
131 }
132
133 return { addition_sequences, thread_sequence_tags };
134}
135
136template <typename Curve>
138 Curve>::batch_compute_point_addition_slope_inverses(const AdditionSequences& add_sequences)
139{
140 auto points = add_sequences.points;
141 auto sequence_counts = add_sequences.sequence_counts;
142
143 // Count the total number of point pairs to be added across all addition sequences
144 size_t total_num_pairs{ 0 };
145 for (auto& count : sequence_counts) {
146 total_num_pairs += count >> 1;
147 }
148
149 // Define scratch space for batched inverse computations and eventual storage of denominators
150 BB_ASSERT_GTE(add_sequences.scratch_space.size(), 2 * total_num_pairs);
151 std::span<Fq> denominators = add_sequences.scratch_space.subspan(0, total_num_pairs);
152 std::span<Fq> differences = add_sequences.scratch_space.subspan(total_num_pairs, 2 * total_num_pairs);
153
154 // Compute and store successive products of differences (x_2 - x_1)
155 Fq accumulator = 1;
156 size_t point_idx = 0;
157 size_t pair_idx = 0;
158 for (auto& count : sequence_counts) {
159 const auto num_pairs = count >> 1;
160 for (size_t j = 0; j < num_pairs; ++j) {
161 BB_ASSERT_LT(pair_idx, total_num_pairs);
162 const auto& x1 = points[point_idx++].x;
163 const auto& x2 = points[point_idx++].x;
164
165 // It is assumed that the input points are random and thus w/h/p do not share an x-coordinate
166 ASSERT(x1 != x2);
167
168 auto diff = x2 - x1;
169 differences[pair_idx] = diff;
170
171 // Store and update the running product of differences at each stage
172 denominators[pair_idx++] = accumulator;
173 accumulator *= diff;
174 }
175 // If number of points in the sequence is odd, we skip the last one since it has no pair
176 point_idx += (count & 0x01ULL);
177 }
178
179 // Invert the full product of differences
180 Fq inverse = accumulator.invert();
181
182 // Compute the individual point-pair addition denominators 1/(x2 - x1)
183 for (size_t i = 0; i < total_num_pairs; ++i) {
184 size_t idx = total_num_pairs - 1 - i;
185 denominators[idx] *= inverse;
186 inverse *= differences[idx];
187 }
188
189 return denominators;
190}
191
192template <typename Curve>
194{
195 const size_t num_points = add_sequences.points.size();
196 if (num_points == 0 || num_points == 1) { // nothing to do
197 return;
198 }
199
200 // Batch compute terms of the form 1/(x2 -x1) for each pair to be added in this round
201 std::span<Fq> denominators = batch_compute_point_addition_slope_inverses(add_sequences);
202
203 auto points = add_sequences.points;
204 auto sequence_counts = add_sequences.sequence_counts;
205
206 // Compute pairwise in-place additions for all sequences with more than 1 point
207 size_t point_idx = 0; // index for points to be summed
208 size_t result_point_idx = 0; // index for result points
209 size_t pair_idx = 0; // index into array of denominators for each pair
210 bool more_additions = false;
211 for (auto& count : sequence_counts) {
212 const auto num_pairs = count >> 1;
213 const bool overflow = static_cast<bool>(count & 0x01ULL);
214 // Compute the sum of all pairs in the sequence and store the result in the same points array
215 for (size_t j = 0; j < num_pairs; ++j) {
216 const auto& point_1 = points[point_idx++]; // first summand
217 const auto& point_2 = points[point_idx++]; // second summand
218 const auto& denominator = denominators[pair_idx++]; // denominator needed in add formula
219 auto& result = points[result_point_idx++]; // target for addition result
220
221 result = affine_add_with_denominator(point_1, point_2, denominator);
222 }
223 // If the sequence had an odd number of points, simply carry the unpaired point over to the next round
224 if (overflow) {
225 points[result_point_idx++] = points[point_idx++];
226 }
227
228 // Update the sequence counts in place for the next round
229 const uint32_t updated_sequence_count = static_cast<uint32_t>(num_pairs) + static_cast<uint32_t>(overflow);
230 count = updated_sequence_count;
231
232 // More additions are required if any sequence has not yet been reduced to a single point
233 more_additions = more_additions || updated_sequence_count > 1;
234 }
235
236 // Recursively perform pairwise additions until all sequences have been reduced to a single point
237 if (more_additions) {
238 const size_t updated_point_count = result_point_idx;
239 std::span<G1> updated_points(&points[0], updated_point_count);
240 return batched_affine_add_in_place(
241 AdditionSequences{ sequence_counts, updated_points, add_sequences.scratch_space });
242 }
243}
244
247} // namespace bb
#define BB_ASSERT_GTE(left, right,...)
Definition assert.hpp:101
#define BB_ASSERT_LT(left, right,...)
Definition assert.hpp:115
#define ASSERT(expression,...)
Definition assert.hpp:49
Class for handling fast batched affine addition of large sets of EC points.
static std::vector< G1 > add_in_place(const std::span< G1 > &points, const std::vector< size_t > &sequence_counts)
Given a set of points and sequence counts, perform addition to reduce each sequence to a single point.
static void batched_affine_add_in_place(AdditionSequences add_sequences)
Internal method for in-place summation of a single set of addition sequences.
static ThreadData construct_thread_data(const std::span< G1 > &points, const std::vector< size_t > &sequence_counts, const std::span< Fq > &scratch_space)
Construct the set of AdditionSequences to be handled by each thread.
Entry point for Barretenberg command-line interface.
size_t get_num_cpus()
Definition thread.hpp:12
void parallel_for(size_t num_iterations, const std::function< void(size_t)> &func)
Definition thread.cpp:72
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
#define PROFILE_THIS_NAME(name)
Definition op_count.hpp:16
void throw_or_abort(std::string const &err)