Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
named_union.hpp
Go to the documentation of this file.
1#pragma once
5#include <sstream>
6#define MSGPACK_NO_BOOST
7#include "msgpack/object_fwd.hpp"
8#include <concepts>
9#include <optional>
10#include <stdexcept>
11#include <string>
12#include <string_view>
13#include <type_traits>
14#include <variant>
15
16namespace bb {
17
// Satisfied when the expression T::MSGPACK_SCHEMA_NAME is well-formed and its type is
// convertible to std::string_view. NamedUnion constrains each of its alternative types
// with this concept so that every alternative carries a name usable as a serialization tag.
21template <typename T>
22concept HasMsgpackSchemaName = requires {
23 { T::MSGPACK_SCHEMA_NAME } -> std::convertible_to<std::string_view>;
24};
25
/**
 * Wrapper around a std::variant of Types... whose msgpack form is a two-element array
 * [type name, payload]. The type name is the MSGPACK_SCHEMA_NAME of the alternative that
 * is currently held, and it is used on unpack to decide which alternative to build.
 *
 * NOTE(review): this listing omits several source lines (the embedded line numbers skip
 * 35, 38, 44, 58-59, 72, 74-75, 124 and 131). VariantType, value_, CurrentType, obj and
 * index_opt are used below but declared on lines that are not shown here.
 */
33template <HasMsgpackSchemaName... Types> class NamedUnion {
34 public:
36
37 private:
39
40 // Helper to get index from type name
// Compares `name` against CurrentType::MSGPACK_SCHEMA_NAME for position I and recurses to
// I + 1 on a mismatch. Yields std::nullopt once I has reached sizeof...(Types), i.e. when
// no alternative matched. The first matching position wins.
// NOTE(review): CurrentType is declared on a line missing from this listing; presumably it
// is the I-th type of Types... - confirm against the real header.
41 template <size_t I = 0> static std::optional<size_t> get_index_from_name(std::string_view name)
42 {
43 if constexpr (I < sizeof...(Types)) {
45 if (name == CurrentType::MSGPACK_SCHEMA_NAME) {
46 return I;
47 }
48 return get_index_from_name<I + 1>(name);
49 }
50 return std::nullopt;
51 }
52
53 // Helper to construct variant by index
// Turns the runtime `index` into a compile-time position by recursing over I. At the
// matching position it calls o.convert(obj) and returns obj as a VariantType.
// NOTE(review): `obj` is declared on lines missing from this listing; presumably an
// instance of the I-th alternative - confirm.
// NOTE(review): when I runs past the last alternative the only statement executed is
// throw_or_abort(...), with no return after it. This is only well-defined if
// throw_or_abort never returns - confirm that it is [[noreturn]] or always throws/aborts.
54 template <size_t I = 0> static VariantType construct_by_index(size_t index, auto& o)
55 {
56 if constexpr (I < sizeof...(Types)) {
57 if (I == index) {
60 o.convert(obj);
61 return obj;
62 }
63 return construct_by_index<I + 1>(index, o);
64 }
65 throw_or_abort("Invalid variant index");
66 }
67
68 public:
69 NamedUnion() = default;
70
// NOTE(review): the signature and member initializer of this templated constructor sit on
// lines missing from this listing; only the template header, the lint suppression and the
// empty body are visible. The suppressed check suggests it takes a forwarding reference.
71 template <typename T>
73 // NOLINTNEXTLINE(bugprone-forwarding-reference-overload)
76 {}
77
78 // Conversion operator to get the underlying variant
// Both operators are implicit, so a NamedUnion can be passed where a VariantType reference
// is expected.
79 operator VariantType&() { return value_; }
80 operator const VariantType&() const { return value_; }
81
82 // Access the underlying variant
83 VariantType& get() { return value_; }
84 const VariantType& get() const { return value_; }
85
86 // Visit the variant
// Rvalue overload: hands value_ to std::visit as an rvalue, so the visitor may move out of
// the held alternative. Returns whatever std::visit returns.
87 template <typename Visitor> decltype(auto) visit(Visitor&& vis) &&
88 {
89 return std::visit(std::forward<Visitor>(vis), std::move(value_));
90 }
91
// Const-lvalue overload: visits value_ in place without modifying it. There is no
// non-const lvalue overload, so a non-const lvalue NamedUnion also resolves to this one.
92 template <typename Visitor> decltype(auto) visit(Visitor&& vis) const&
93 {
94 return std::visit(std::forward<Visitor>(vis), value_);
95 }
96
97 // Get the current type name
// Returns the MSGPACK_SCHEMA_NAME of whichever alternative value_ currently holds.
98 std::string_view get_type_name() const
99 {
100 return std::visit(
101 [](const auto& obj) -> std::string_view { return std::decay_t<decltype(obj)>::MSGPACK_SCHEMA_NAME; },
102 value_);
103 }
104
105 // Msgpack serialization
// Writes a 2-element array to `packer`: the active alternative's type name followed by the
// alternative itself.
106 void msgpack_pack(auto& packer) const
107 {
108 packer.pack_array(2);
109 // First pack the type name
110 std::string_view type_name = get_type_name();
111 packer.pack(type_name);
112
113 // Then pack the actual object
114 std::visit([&packer](const auto& obj) { packer.pack(obj); }, value_);
115 }
116
117 // Msgpack deserialization
// Expects `o` to be a 2-element array whose first element is a string naming one of the
// alternatives. The second element is converted into that alternative and stored in
// value_. Every validation failure is reported through throw_or_abort.
118 void msgpack_unpack(msgpack::object const& o)
119 {
// NOTE(review): MAX_OUTPUT_CHARS is not referenced on any visible line; presumably it is
// the size limit used when rendering `o` into the error message on the missing line 124.
120 constexpr size_t MAX_OUTPUT_CHARS = 100;
121 // access object assuming it is an array of size 2
122 if (o.type != msgpack::type::ARRAY || o.via.array.size != 2) {
// The right-hand operand of `+` below is on a line missing from this listing.
123 throw_or_abort("Expected an array of size 2 for NamedUnion deserialization, got " +
125 }
126 const auto& arr = o.via.array;
127 if (arr.ptr[0].type != msgpack::type::STR) {
128 throw_or_abort("Expected first element to be a string (type name) in NamedUnion deserialization");
129 }
// View directly over the msgpack string bytes; no copy is made here.
130 std::string_view type_name = std::string_view(arr.ptr[0].via.str.ptr, arr.ptr[0].via.str.size);
// NOTE(review): index_opt is declared on a line missing from this listing; presumably it
// holds the result of get_index_from_name(type_name) - confirm.
132 if (!index_opt.has_value()) {
133 throw_or_abort("Unknown type name in NamedUnion deserialization: " + std::string(type_name));
134 }
135 size_t index = index_opt.value();
136 // Now construct the variant using the index
137 value_ = construct_by_index(index, arr.ptr[1]);
138 }
139
140 // Msgpack schema
// Writes ["named_union", [[name, schema], ...]] to `packer`, with one [name, schema] pair
// per alternative in Types..., in declaration order.
141 void msgpack_schema(auto& packer) const
142 {
143 packer.pack_array(2);
144 packer.pack("named_union");
145 packer.pack_array(sizeof...(Types));
// Fold over the comma operator: the lambda is invoked once for each type in the pack.
146 (
147 [&packer]() {
148 packer.pack_array(2);
149 packer.pack(Types::MSGPACK_SCHEMA_NAME);
150 // Arbitrary mutable object: a value-initialized heap instance that exists only for the
// duration of this call, so pack_schema receives a non-const lvalue of the alternative.
// NOTE(review): std::make_unique is declared in <memory>, which is not among the include
// lines visible in this listing - confirm it is included directly.
151 packer.pack_schema(*std::make_unique<Types>());
152 }(),
153 ...); /* pack schemas of all template Args */
154 }
155};
156
157// Deduction guide: a NamedUnion initialized from a std::variant<Types...> deduces
// NamedUnion<Types...>, so the alternative list does not have to be spelled out again.
158template <typename... Types> NamedUnion(std::variant<Types...>) -> NamedUnion<Types...>;
159
160} // namespace bb
A wrapper around std::variant that provides msgpack serialization based on type names.
decltype(auto) visit(Visitor &&vis) &&
VariantType value_
void msgpack_schema(auto &packer) const
void msgpack_unpack(msgpack::object const &o)
void msgpack_pack(auto &packer) const
const VariantType & get() const
VariantType & get()
static VariantType construct_by_index(size_t index, auto &o)
static std::optional< size_t > get_index_from_name(std::string_view name)
std::variant< Types... > VariantType
std::string_view get_type_name() const
decltype(auto) visit(Visitor &&vis) const &
NamedUnion()=default
Concept to check if a type has a static MSGPACK_SCHEMA_NAME member convertible to std::string_view.
Entry point for Barretenberg command-line interface.
std::string msgpack_to_json(msgpack::object const &o, size_t max_chars)
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
void throw_or_abort(std::string const &err)