Skip to content

Commit 2049052

Browse files
authored
Merge pull request #84 from k-wasniowski/sframe-api
sframe api refactor
2 parents e109627 + 17b67cf commit 2049052

11 files changed

Lines changed: 225 additions & 164 deletions

File tree

include/sframe/result.h

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,10 @@ class SFrameError
4848
const char* message_ = nullptr;
4949
};
5050

51-
// Helper to convert SFrameError to appropriate exception type
51+
#ifdef __cpp_exceptions
5252
void
53-
throw_on_error(const SFrameError& error);
53+
throw_sframe_error(const SFrameError& error);
54+
#endif
5455

5556
template<typename T>
5657
class Result
@@ -96,6 +97,17 @@ class Result
9697

9798
bool is_err() const { return std::holds_alternative<SFrameError>(data_); }
9899

100+
#ifdef __cpp_exceptions
101+
T unwrap()
102+
{
103+
if (std::holds_alternative<SFrameError>(data_)) {
104+
throw_sframe_error(std::get<SFrameError>(data_));
105+
}
106+
107+
return std::move(std::get<T>(data_));
108+
}
109+
#endif
110+
99111
private:
100112
std::variant<T, SFrameError> data_;
101113
};
@@ -135,24 +147,21 @@ class Result<void>
135147

136148
bool is_err() const { return error_.has_value(); }
137149

150+
#ifdef __cpp_exceptions
151+
void unwrap()
152+
{
153+
if (error_.has_value()) {
154+
throw_sframe_error(error_.value());
155+
}
156+
}
157+
#endif
158+
138159
private:
139160
std::optional<SFrameError> error_;
140161
};
141162

142163
} // namespace SFRAME_NAMESPACE
143164

144-
// Unwrap a Result<T>, throwing the corresponding exception on error.
145-
// Use in functions that have NOT yet been migrated away from exceptions.
146-
// Usage: const auto val = SFRAME_VALUE_OR_THROW(some_result_expr);
147-
#define SFRAME_VALUE_OR_THROW(expr) \
148-
([&]() { \
149-
auto _result = (expr); \
150-
if (_result.is_err()) { \
151-
SFRAME_NAMESPACE::throw_on_error(_result.error()); \
152-
} \
153-
return _result.value(); \
154-
}())
155-
156165
// Unwrap a Result<T> into `var`, propagating the error by early return.
157166
// Use in functions that already return Result<U>.
158167
// Usage: SFRAME_VALUE_OR_RETURN(val, some_result_expr);

include/sframe/sframe.h

Lines changed: 44 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
#include <sframe/result.h>
99
#include <sframe/vector.h>
1010

11+
#ifdef __cpp_exceptions
12+
#include <stdexcept>
13+
#endif
14+
1115
#include <namespace.h>
1216

1317
// These constants define the size of certain internal data structures if
@@ -28,6 +32,7 @@
2832

2933
namespace SFRAME_NAMESPACE {
3034

35+
#ifdef __cpp_exceptions
3136
struct crypto_error : std::runtime_error
3237
{
3338
crypto_error();
@@ -60,6 +65,7 @@ struct invalid_key_usage_error : std::runtime_error
6065
using parent = std::runtime_error;
6166
using parent::parent;
6267
};
68+
#endif
6369

6470
enum class CipherSuite : uint16_t
6571
{
@@ -111,15 +117,15 @@ class Context
111117
Context(CipherSuite suite);
112118
virtual ~Context();
113119

114-
void add_key(KeyID kid, KeyUsage usage, input_bytes key);
120+
Result<void> add_key(KeyID kid, KeyUsage usage, input_bytes key);
115121

116-
output_bytes protect(KeyID key_id,
117-
output_bytes ciphertext,
118-
input_bytes plaintext,
119-
input_bytes metadata);
120-
output_bytes unprotect(output_bytes plaintext,
121-
input_bytes ciphertext,
122-
input_bytes metadata);
122+
Result<output_bytes> protect(KeyID key_id,
123+
output_bytes ciphertext,
124+
input_bytes plaintext,
125+
input_bytes metadata);
126+
Result<output_bytes> unprotect(output_bytes plaintext,
127+
input_bytes ciphertext,
128+
input_bytes metadata);
123129

124130
static constexpr size_t max_overhead = 17 + 16;
125131
static constexpr size_t max_metadata_size = 512;
@@ -150,29 +156,30 @@ class MLSContext : protected Context
150156

151157
MLSContext(CipherSuite suite_in, size_t epoch_bits_in);
152158

153-
void add_epoch(EpochID epoch_id, input_bytes sframe_epoch_secret);
154-
void add_epoch(EpochID epoch_id,
155-
input_bytes sframe_epoch_secret,
156-
size_t sender_bits);
159+
Result<void> add_epoch(EpochID epoch_id, input_bytes sframe_epoch_secret);
160+
Result<void> add_epoch(EpochID epoch_id,
161+
input_bytes sframe_epoch_secret,
162+
size_t sender_bits);
157163
void purge_before(EpochID keeper);
158164

159-
output_bytes protect(EpochID epoch_id,
160-
SenderID sender_id,
161-
output_bytes ciphertext,
162-
input_bytes plaintext,
163-
input_bytes metadata);
164-
output_bytes protect(EpochID epoch_id,
165-
SenderID sender_id,
166-
ContextID context_id,
167-
output_bytes ciphertext,
168-
input_bytes plaintext,
169-
input_bytes metadata);
170-
171-
output_bytes unprotect(output_bytes plaintext,
172-
input_bytes ciphertext,
173-
input_bytes metadata);
165+
Result<output_bytes> protect(EpochID epoch_id,
166+
SenderID sender_id,
167+
output_bytes ciphertext,
168+
input_bytes plaintext,
169+
input_bytes metadata);
170+
Result<output_bytes> protect(EpochID epoch_id,
171+
SenderID sender_id,
172+
ContextID context_id,
173+
output_bytes ciphertext,
174+
input_bytes plaintext,
175+
input_bytes metadata);
176+
177+
Result<output_bytes> unprotect(output_bytes plaintext,
178+
input_bytes ciphertext,
179+
input_bytes metadata);
174180

175181
private:
182+
// NOLINTBEGIN(clang-analyzer-core.uninitialized.Assign)
176183
struct EpochKeys
177184
{
178185
static constexpr size_t max_secret_size = 64;
@@ -184,20 +191,22 @@ class MLSContext : protected Context
184191
uint64_t max_sender_id;
185192
uint64_t max_context_id;
186193

187-
EpochKeys(EpochID full_epoch_in,
188-
input_bytes sframe_epoch_secret_in,
189-
size_t epoch_bits,
190-
size_t sender_bits_in);
194+
EpochKeys() = default;
195+
static Result<EpochKeys> create(EpochID full_epoch_in,
196+
input_bytes sframe_epoch_secret_in,
197+
size_t epoch_bits,
198+
size_t sender_bits_in);
191199
Result<owned_bytes<max_secret_size>> base_key(CipherSuite suite,
192200
SenderID sender_id) const;
193201
};
202+
// NOLINTEND(clang-analyzer-core.uninitialized.Assign)
194203

195204
void purge_epoch(EpochID epoch_id);
196205

197-
KeyID form_key_id(EpochID epoch_id,
198-
SenderID sender_id,
199-
ContextID context_id) const;
200-
void ensure_key(KeyID key_id, KeyUsage usage);
206+
Result<KeyID> form_key_id(EpochID epoch_id,
207+
SenderID sender_id,
208+
ContextID context_id) const;
209+
Result<void> ensure_key(KeyID key_id, KeyUsage usage);
201210

202211
const size_t epoch_bits;
203212
const size_t epoch_mask;

src/crypto_boringssl.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@ namespace SFRAME_NAMESPACE {
1717
/// Convert between native identifiers / errors and OpenSSL ones
1818
///
1919

20+
#ifdef __cpp_exceptions
2021
crypto_error::crypto_error()
2122
: std::runtime_error(ERR_error_string(ERR_get_error(), nullptr))
2223
{
2324
}
25+
#endif
2426

2527
static Result<const EVP_MD*>
2628
openssl_digest_type(CipherSuite suite)

src/crypto_openssl11.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@ using scoped_hmac_ctx = std::unique_ptr<HMAC_CTX, decltype(&HMAC_CTX_free)>;
2323
/// Convert between native identifiers / errors and OpenSSL ones
2424
///
2525

26+
#ifdef __cpp_exceptions
2627
crypto_error::crypto_error()
2728
: std::runtime_error(ERR_error_string(ERR_get_error(), nullptr))
2829
{
2930
}
31+
#endif
3032

3133
static Result<const EVP_MD*>
3234
openssl_digest_type(CipherSuite suite)

src/crypto_openssl3.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@ namespace SFRAME_NAMESPACE {
1717
/// Convert between native identifiers / errors and OpenSSL ones
1818
///
1919

20+
#ifdef __cpp_exceptions
2021
crypto_error::crypto_error()
2122
: std::runtime_error(ERR_error_string(ERR_get_error(), nullptr))
2223
{
2324
}
25+
#endif
2426

2527
static Result<const EVP_CIPHER*>
2628
openssl_cipher(CipherSuite suite)

src/result.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,25 @@
1-
#include <sframe/result.h>
21
#include <sframe/sframe.h>
32

43
namespace SFRAME_NAMESPACE {
54

5+
#ifdef __cpp_exceptions
6+
unsupported_ciphersuite_error::unsupported_ciphersuite_error()
7+
: std::runtime_error("Unsupported ciphersuite")
8+
{
9+
}
10+
11+
authentication_error::authentication_error()
12+
: std::runtime_error("AEAD authentication failure")
13+
{
14+
}
15+
616
void
7-
throw_on_error(const SFrameError& error)
17+
throw_sframe_error(const SFrameError& error)
818
{
919
switch (error.type()) {
20+
case SFrameErrorType::internal_error:
21+
throw std::runtime_error(error.message() ? error.message()
22+
: "SFrame internal error");
1023
case SFrameErrorType::buffer_too_small_error:
1124
throw buffer_too_small_error(error.message());
1225
case SFrameErrorType::invalid_parameter_error:
@@ -19,9 +32,8 @@ throw_on_error(const SFrameError& error)
1932
throw authentication_error();
2033
case SFrameErrorType::invalid_key_usage_error:
2134
throw invalid_key_usage_error(error.message());
22-
default:
23-
throw std::runtime_error(error.message());
2435
}
2536
}
37+
#endif
2638

2739
} // namespace SFRAME_NAMESPACE

0 commit comments

Comments
 (0)