Skip to content

Commit 6d21836

Browse files
committed
Fix memory leak in error handling
1 parent c01d75f commit 6d21836

3 files changed

Lines changed: 38 additions & 17 deletions

File tree

fstwrapper/src/map.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,16 +119,14 @@ pub extern "C" fn fst_map_get(ctx: *mut Context,
119119
-> libc::uint64_t {
120120
let key = cstr_to_str(key);
121121
let ctx = mutref_from_ptr!(ctx);
122-
ctx.has_error = false;
122+
ctx.clear();
123123
match ref_from_ptr!(ptr).get(key) {
124124
Some(val) => val,
125125
None => {
126126
let msg = str_to_cstr(&format!("Key '{}' not in map.", key));
127127
ctx.has_error = true;
128-
ctx.error_type = str_to_cstr("KeyError");
129-
ctx.error_debug = msg;
128+
ctx.error_type = str_to_cstr("py::KeyError");
130129
ctx.error_display = msg;
131-
ctx.error_description = msg;
132130
return 0;
133131
}
134132
}
@@ -150,16 +148,14 @@ make_free_fn!(fst_mapvalues_free, *mut map::Values);
150148
#[no_mangle]
151149
pub extern "C" fn fst_mapvalues_next(ctx: *mut Context, ptr: *mut map::Values) -> libc::uint64_t {
152150
let ctx = mutref_from_ptr!(ctx);
153-
ctx.has_error = false;
151+
ctx.clear();
154152
match mutref_from_ptr!(ptr).next() {
155153
Some(val) => val,
156154
None => {
157155
let msg = str_to_cstr("No more values.");
158156
ctx.has_error = true;
159157
ctx.error_type = str_to_cstr("StopIteration");
160-
ctx.error_debug = msg;
161158
ctx.error_display = msg;
162-
ctx.error_description = msg;
163159
return 0;
164160
}
165161
}

fstwrapper/src/util.rs

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,28 @@ use fst::{Levenshtein,Regex};
1313
/// Exposes information about errors over the ABI
1414
pub struct Context {
1515
pub has_error: bool,
16-
pub error_type: *const libc::c_char,
17-
pub error_debug: *const libc::c_char,
18-
pub error_display: *const libc::c_char,
19-
pub error_description: *const libc::c_char,
16+
pub error_type: *mut libc::c_char,
17+
pub error_debug: *mut libc::c_char,
18+
pub error_display: *mut libc::c_char,
19+
pub error_description: *mut libc::c_char,
20+
}
21+
22+
impl Context {
23+
pub fn clear(&mut self) {
24+
self.has_error = false;
25+
if !self.error_type.is_null() {
26+
fst_string_free(self.error_type);
27+
}
28+
if !self.error_debug.is_null() {
29+
fst_string_free(self.error_debug);
30+
}
31+
if !self.error_display.is_null() {
32+
fst_string_free(self.error_display);
33+
}
34+
if !self.error_description.is_null() {
35+
fst_string_free(self.error_description);
36+
}
37+
}
2038
}
2139

2240

@@ -43,10 +61,10 @@ pub fn get_typename<T>(_: &T) -> &'static str {
4361
pub extern "C" fn fst_context_new() -> *mut Context {
4462
to_raw_ptr(Context {
4563
has_error: false,
46-
error_type: ptr::null(),
47-
error_description: ptr::null(),
48-
error_display: ptr::null(),
49-
error_debug: ptr::null(),
64+
error_type: ptr::null_mut(),
65+
error_description: ptr::null_mut(),
66+
error_display: ptr::null_mut(),
67+
error_debug: ptr::null_mut(),
5068
})
5169
}
5270
make_free_fn!(fst_context_free, *mut Context);

rust_fst/lib.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,18 +48,25 @@ def find_library():
4848
'fst::Error::Regex': RegexError,
4949
'fst::Error::Levenshtein': LevenshteinError,
5050
'fst::Error::Io': IoError,
51+
'py::KeyError': KeyError
5152
}
5253

5354

5455
def checked_call(fn, ctx, *args):
5556
res = fn(ctx, *args)
5657
if not ctx.has_error:
5758
return res
58-
msg = ffi.string(ctx.error_display).decode('utf8').replace('\n', ' ')
5959
type_str = ffi.string(ctx.error_type).decode('utf8')
60+
if ctx.error_display != ffi.NULL:
61+
msg = ffi.string(ctx.error_display).decode('utf8').replace('\n', ' ')
62+
else:
63+
msg = None
6064
err_type = EXCEPTION_MAP.get(type_str)
6165
if err_type is FstError:
62-
desc_str = ffi.string(ctx.error_description).decode('utf8')
66+
if ctx.error_description != ffi.NULL:
67+
desc_str = ffi.string(ctx.error_description).decode('utf8')
68+
else:
69+
desc_str = None
6370
enum_val = re.match(r'(\w+)\(.*?\)', desc_str, re.DOTALL).group(1)
6471
err_type = EXCEPTION_MAP.get("{}::{}".format(type_str, enum_val))
6572
if err_type is None:

0 commit comments

Comments
 (0)