Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
sha256.cpp
Go to the documentation of this file.
2
3#include <algorithm>
4#include <array>
5#include <cstdint>
6#include <memory>
7#include <stdexcept>
8
10
11namespace bb::avm2::simulation {
12
namespace {

// SHA-256 round constants K[0..63] (FIPS 180-4, section 4.2.2): the first 32 bits of the fractional
// parts of the cube roots of the first 64 primes. These values mirror the table in
// barretenberg/cpp/src/barretenberg/crypto/sha256/sha256.cpp and must be kept identical to it.
constexpr std::array<uint32_t, 64> round_constants = {
    0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, // K[0..7]
    0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, // K[8..15]
    0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, // K[16..23]
    0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, // K[24..31]
    0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, // K[32..39]
    0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, // K[40..47]
    0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, // K[48..55]
    0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, // K[56..63]
};

} // namespace
28
29// Don't worry about any weird edge cases since we have fixed non-zero shifts
30MemoryValue Sha256::ror(const MemoryValue& x, uint32_t shift)
31{
32 auto val = x.as<uint32_t>();
33 // In a rotation, we decompose into a lhs and rhs (or hi and lo) part.
34 uint32_t lo = val & ((1U << shift) - 1);
35 uint32_t hi = val >> shift;
36 uint32_t result = lo << (32U - (shift & 31U)) | hi;
37
38 // Do this outside of an assert, in case this gets built without assert
39 bool lo_in_range = gt.gt(1UL << shift, lo); // Ensure the lower bits are in range
40 (void)lo_in_range; // To please GCC.
41 assert(lo_in_range && "Low Value in ROR out of range");
42 return MemoryValue::from<uint32_t>(result);
43}
44
45// Don't need to worry about edge cases with shifts since we know we only shift by 3 and 10 for sha256
46MemoryValue Sha256::shr(const MemoryValue& x, uint32_t shift)
47{
48 uint32_t input = x.as<uint32_t>();
49 // Get the lower shift bits
50 uint32_t lo = input & ((1UL << shift) - 1);
51 uint32_t hi = input >> shift;
52
53 // Do this outside of an assert, in case this gets built without assert
54 bool lo_in_range = gt.gt(1UL << shift, lo); // Ensure the lower bits are in range
55 (void)lo_in_range; // To please GCC.
56 assert(lo_in_range && "Low Value in SHR out of range");
57
58 return MemoryValue::from<uint32_t>(hi);
59}
60
61// This function is used to sum the values in the vector and return the result modulo 2^32.
63{
64 uint64_t sum = 0;
65 for (const auto& value : values) {
66 // This is safe, since we've already checked that the values are of tag U32
67 sum += value.as<uint32_t>();
68 }
69 uint32_t lo = static_cast<uint32_t>(sum);
70 uint32_t hi = sum >> 32;
71
72 // Do these outside of an assert, in case this gets built without assert
73 bool lo_in_range = gt.gt(1UL << 32, lo); // Ensure the lower bits are in range
74 bool hi_in_range = gt.gt(1UL << 32, hi); // Ensure the upper bits are in range
75 (void)lo_in_range; // To please GCC.
76 (void)hi_in_range; // To please GCC.
77 assert(lo_in_range && hi_in_range && "Sum in MODULO_SUM out of range");
78 return MemoryValue::from<uint32_t>(lo);
79}
80
82 MemoryAddress state_addr,
83 MemoryAddress input_addr,
84 MemoryAddress output_addr)
85{
86 uint32_t execution_clk = execution_id_manager.get_execution_id();
87 uint32_t space_id = memory.get_space_id();
88
89 // Default values are FF(0) as that is what the circuit would expect
91 state.fill(MemoryValue::from<FF>(0));
92
94 input.reserve(16);
95
96 // Check that the maximum addresss for the state, input, and output addresses are within the valid range.
97 // (1) Read the 8 element hash state from { state_addr, state_addr + 1, ..., state_addr + 7 }
98 // (2) Read the 16 element input from { input_addr, input_addr + 1, ..., input_addr + 15 }
99 // (3) Write the 8 element output to { output_addr, output_addr + 1, ..., output_addr + 7 }
100 bool state_addr_out_of_range = gt.gt(static_cast<uint64_t>(state_addr) + 7, AVM_HIGHEST_MEM_ADDRESS);
101 bool input_addr_out_of_range = gt.gt(static_cast<uint64_t>(input_addr) + 15, AVM_HIGHEST_MEM_ADDRESS);
102 bool output_addr_out_of_range = gt.gt(static_cast<uint64_t>(output_addr) + 7, AVM_HIGHEST_MEM_ADDRESS);
103
104 try {
105 if (state_addr_out_of_range || input_addr_out_of_range || output_addr_out_of_range) {
106 throw std::runtime_error("Memory address out of range for sha256 compression.");
107 }
108
109 // Read the hash state from memory. The state needs to be loaded atomically from memory (i.e. all 8 elements are
110 // read regardless of errors)
111 for (uint32_t i = 0; i < 8; ++i) {
112 state[i] = memory.get(state_addr + i);
113 }
114
115 // If any of the state values are not of tag U32, we throw an error.
116 if (std::ranges::any_of(state, [](const MemoryValue& val) { return val.get_tag() != MemoryTag::U32; })) {
117 throw std::runtime_error("Invalid tag for sha256 state values.");
118 }
119
120 // Load 16 elements representing the hash input from memory.
121 // Since the circuit loads this per row, we throw on the first error we find.
122 for (uint32_t i = 0; i < 16; ++i) {
123 input.emplace_back(memory.get(input_addr + i));
124 if (input[i].get_tag() != MemoryTag::U32) {
125 throw std::runtime_error("Invalid tag for sha256 input values.");
126 }
127 }
128
129 // Perform sha256 compression. Taken from `vm2/simulation/lib/sha256_compression.cpp` but using
130 // the bitwise operations and MemoryValues
132
133 // Fill first 16 words with the inputs
134 for (size_t i = 0; i < 16; ++i) {
135 w[i] = input[i];
136 }
137
138 // Extend the input data into the remaining 48 words
139 for (size_t i = 16; i < 64; ++i) {
140 MemoryValue s0 = bitwise.xor_op(bitwise.xor_op(ror(w[i - 15], 7), ror(w[i - 15], 18)), shr(w[i - 15], 3));
141 MemoryValue s1 = bitwise.xor_op(bitwise.xor_op(ror(w[i - 2], 17), ror(w[i - 2], 19)), shr(w[i - 2], 10));
142 // Could be explicit with an std::initializer_list<uint32_t> here, the array overload is more readable imo.
143 // std::spans are annoying to construct from literals
144 // (https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2447r2.html)
145 w[i] = modulo_sum({ { w[i - 16], w[i - 7], s0, s1 } });
146 }
147
148 // Initialize round variables with previous block output
149 MemoryValue a = state[0];
150 MemoryValue b = state[1];
151 MemoryValue c = state[2];
152 MemoryValue d = state[3];
153 MemoryValue e = state[4];
154 MemoryValue f = state[5];
155 MemoryValue g = state[6];
156 MemoryValue h = state[7];
157
158 // Apply SHA-256 compression function to the message schedule
159 for (size_t i = 0; i < 64; ++i) {
160 MemoryValue S1 = bitwise.xor_op(bitwise.xor_op(ror(e, 6U), ror(e, 11U)), ror(e, 25U));
161 MemoryValue ch = bitwise.xor_op(bitwise.and_op(e, f), bitwise.and_op(~e, g));
162 MemoryValue S0 = bitwise.xor_op(bitwise.xor_op(ror(a, 2U), ror(a, 13U)), ror(a, 22U));
163 MemoryValue maj =
164 bitwise.xor_op(bitwise.xor_op(bitwise.and_op(a, b), bitwise.and_op(a, c)), bitwise.and_op(b, c));
165
166 auto prev_h = h; // Need to store the previous h value before updating it so we can use it in the modulo sum
167 h = g;
168 g = f;
169 f = e;
170 // e = d + temp1;
171 e = modulo_sum({ { d, prev_h, S1, ch, MemoryValue::from<uint32_t>(round_constants[i]), w[i] } });
172 d = c;
173 c = b;
174 b = a;
175 // a = temp1 + temp2;
176 a = modulo_sum({ { prev_h, S1, ch, MemoryValue::from<uint32_t>(round_constants[i]), w[i], S0, maj } });
177 }
178
179 // Add into previous block output and return
181 modulo_sum({ { a, state[0] } }), modulo_sum({ { b, state[1] } }), modulo_sum({ { c, state[2] } }),
182 modulo_sum({ { d, state[3] } }), modulo_sum({ { e, state[4] } }), modulo_sum({ { f, state[5] } }),
183 modulo_sum({ { g, state[6] } }), modulo_sum({ { h, state[7] } }),
184 };
185
186 // Write the output back to memory.
187 for (uint32_t i = 0; i < 8; ++i) {
188 memory.set(output_addr + i, output[i]);
189 }
190
191 events.emit({ .execution_clk = execution_clk,
192 .space_id = space_id,
193 .state_addr = state_addr,
194 .input_addr = input_addr,
195 .output_addr = output_addr,
196 .state = state,
197 .input = input,
198 .output = output });
199 } catch (const std::exception& e) {
200 // If any error occurs, we emit an event with the error message.
202 output.fill(MemoryValue::from<FF>(0)); // Default output in case of error
203 events.emit({ .execution_clk = execution_clk,
204 .space_id = space_id,
205 .state_addr = state_addr,
206 .input_addr = input_addr,
207 .output_addr = output_addr,
208 .state = state,
209 .input = input,
210 .output = output });
211 throw; // Re-throw the exception after emitting the event
212 }
213}
214
215} // namespace bb::avm2::simulation
#define AVM_HIGHEST_MEM_ADDRESS
ValueTag get_tag() const
virtual uint32_t get_execution_id() const =0
MemoryValue ror(const MemoryValue &x, uint32_t shift)
Definition sha256.cpp:30
MemoryValue modulo_sum(std::span< const MemoryValue > values)
Definition sha256.cpp:62
EventEmitterInterface< Sha256CompressionEvent > & events
Definition sha256.hpp:51
void compression(MemoryInterface &memory, MemoryAddress state_addr, MemoryAddress input_addr, MemoryAddress output_addr) override
Definition sha256.cpp:81
ExecutionIdGetterInterface & execution_id_manager
Definition sha256.hpp:48
MemoryValue shr(const MemoryValue &x, uint32_t shift)
Definition sha256.cpp:46
FF a
FF b
constexpr uint32_t round_constants[64]
uint32_t MemoryAddress
Inner sum(Cont< Inner, Args... > const &in)
Definition container.hpp:70
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13