Skip to content

Commit bee9be1

Browse files
committed
upgrade zstd-rs, simplify code
1 parent cbce68c commit bee9be1

4 files changed

Lines changed: 47 additions & 75 deletions

File tree

Cargo.lock

Lines changed: 9 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,8 +31,7 @@ required-features = ["benchmark"]
3131
crate-type = ["cdylib"]
3232

3333
[dependencies]
34-
zstd = {version = "0.6.0", git = "https://github.com/phiresky/zstd-rs", branch = "master"}
35-
zstd-safe = {version = "3.0.0", git = "https://github.com/phiresky/zstd-rs", branch = "master"}
34+
zstd = {version = "0.11.2", features = ["experimental"]}
3635
#zstd = {version = "0.5.3", path="../zstd-rs"}
3736
#zstd = {version = "=0.5.4"}
3837
anyhow = "1.0.44"

src/basic.rs

Lines changed: 31 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -51,16 +51,19 @@ pub(crate) fn zstd_compress_fn<'a>(
5151
} else {
5252
ctx.get(arg_is_compact).context("is_compact argument")?
5353
};
54+
let out = Vec::new();
55+
use zstd::stream::write::Encoder;
5456

55-
let dict = if ctx.len() <= arg_dict {
56-
None
57+
let encoder = if ctx.len() <= arg_dict {
58+
Encoder::new(out, level)
5759
} else {
5860
match ctx.get_raw(arg_dict) {
59-
ValueRef::Integer(-1) => None,
60-
ValueRef::Null => None,
61-
ValueRef::Blob(d) => Some(Arc::new(wrap_encoder_dict(d.to_vec(), level))),
62-
ValueRef::Integer(_) => Some(
63-
encoder_dict_from_ctx(ctx, arg_dict, level)
61+
ValueRef::Integer(-1) | ValueRef::Null => Encoder::new(out, level),
62+
ValueRef::Blob(d) => Encoder::with_dictionary(out, level, d),
63+
//Some(Arc::new(wrap_encoder_dict(d.to_vec(), level))),
64+
ValueRef::Integer(_) => Encoder::with_prepared_dictionary(
65+
out,
66+
&*encoder_dict_from_ctx(ctx, arg_dict, level)
6467
.context("loading dictionary from int")?,
6568
),
6669
other => anyhow::bail!(
@@ -69,33 +72,26 @@ pub(crate) fn zstd_compress_fn<'a>(
6972
),
7073
}
7174
};
72-
73-
let res = {
74-
let out = Vec::new();
75-
let mut encoder = match &dict {
76-
Some(dict) => zstd::stream::write::Encoder::with_prepared_dictionary(out, dict),
77-
None => zstd::stream::write::Encoder::new(out, level),
78-
}
79-
.context("creating zstd encoder")?;
80-
/* encoder
81-
.get_operation_mut()
82-
.context
83-
.set_pledged_src_size(input_value.len() as u64)
84-
.context("pledge")?;*/
85-
if compact {
86-
encoder
87-
.include_checksum(false)
88-
.context("disable checksums")?;
89-
encoder.include_contentsize(false).context("cs")?;
90-
encoder.include_dictid(false).context("did")?;
91-
encoder.include_magicbytes(false).context("did")?;
92-
}
75+
let mut encoder = encoder.context("creating zstd encoder")?;
76+
77+
/* encoder
78+
.get_operation_mut()
79+
.context
80+
.set_pledged_src_size(input_value.len() as u64)
81+
.context("pledge")?;*/
82+
if compact {
9383
encoder
94-
.write_all(input_value)
95-
.context("writing data to zstd encoder")?;
96-
encoder.finish().context("finishing zstd stream")?
97-
};
98-
drop(dict); // to make sure the dict is still in scope because of https://github.com/gyscos/zstd-rs/issues/55
84+
.include_checksum(false)
85+
.context("disable checksums")?;
86+
encoder.include_contentsize(false).context("cs")?;
87+
encoder.include_dictid(false).context("did")?;
88+
encoder.include_magicbytes(false).context("did")?;
89+
}
90+
encoder
91+
.write_all(input_value)
92+
.context("writing data to zstd encoder")?;
93+
let res = encoder.finish().context("finishing zstd stream")?;
94+
9995
Ok(ToSqlOutput::Owned(Value::Blob(res)))
10096
}
10197

@@ -135,9 +131,8 @@ pub(crate) fn zstd_decompress_fn<'a>(
135131
None
136132
} else {
137133
match ctx.get_raw(arg_dict) {
138-
ValueRef::Integer(-1) => None,
139-
ValueRef::Null => None,
140-
ValueRef::Blob(d) => Some(Arc::new(wrap_decoder_dict(d.to_vec()))),
134+
ValueRef::Integer(-1) | ValueRef::Null => None,
135+
ValueRef::Blob(d) => Some(Arc::new(DecoderDictionary::copy(d))),
141136
ValueRef::Integer(_) => {
142137
Some(decoder_dict_from_ctx(ctx, arg_dict).context("load dict")?)
143138
}

src/dict_management.rs

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,42 +5,18 @@ use std::time::Duration;
55

66
use zstd::dict::{DecoderDictionary, EncoderDictionary};
77

8-
type OwnedEncoderDict<'a> = owning_ref::OwningHandle<Vec<u8>, Box<EncoderDictionary<'a>>>;
9-
10-
// zstd-rs only exposes zstd_safe::create_cdict_by_reference, not zstd_safe::create_cdict
11-
// so we need to keep a reference to the vector ourselves
12-
// is there a better way?
13-
pub fn wrap_encoder_dict(dict_raw: Vec<u8>, level: i32) -> OwnedEncoderDict<'static> {
14-
owning_ref::OwningHandle::new_with_fn(dict_raw, |d| {
15-
Box::new(EncoderDictionary::new(
16-
unsafe { d.as_ref() }.unwrap(),
17-
level,
18-
))
19-
})
20-
}
21-
22-
type OwnedDecoderDict<'a> = owning_ref::OwningHandle<Vec<u8>, Box<DecoderDictionary<'a>>>;
23-
24-
// zstd-rs only exposes zstd_safe::create_cdict_by_reference, not zstd_safe::create_cdict
25-
// so we need to keep a reference to the vector ourselves
26-
// is there a better way?
27-
pub fn wrap_decoder_dict(dict_raw: Vec<u8>) -> OwnedDecoderDict<'static> {
28-
owning_ref::OwningHandle::new_with_fn(dict_raw, |d| {
29-
Box::new(DecoderDictionary::new(unsafe { &*d }))
30-
})
31-
}
328
// TODO: the rust interface currently requires a level when preparing a dictionary, but the zstd interface (ZSTD_CCtx_loadDictionary) does not.
339
// TODO: Using LruCache here isn't very smart
3410
pub fn encoder_dict_from_ctx<'a>(
3511
ctx: &'a Context,
3612
arg_index: usize,
3713
level: i32,
38-
) -> anyhow::Result<Arc<OwnedEncoderDict<'static>>> {
14+
) -> anyhow::Result<Arc<EncoderDictionary<'static>>> {
3915
use lru_time_cache::LruCache;
4016
// we cache the instantiated encoder dictionaries keyed by (DbConnection, dict_id, compression_level)
4117
// DbConnection would ideally be db.path() because it's the same for multiple connections to the same db, but that would be less robust (e.g. in-memory databases)
4218
lazy_static::lazy_static! {
43-
static ref DICTS: RwLock<LruCache<(usize, i32, i32), Arc<OwnedEncoderDict<'static>>>> = RwLock::new(LruCache::with_expiry_duration(Duration::from_secs(10)));
19+
static ref DICTS: RwLock<LruCache<(usize, i32, i32), Arc<EncoderDictionary<'static>>>> = RwLock::new(LruCache::with_expiry_duration(Duration::from_secs(10)));
4420
}
4521
let id: i32 = ctx.get(arg_index)?;
4622
let db = unsafe { ctx.get_connection()? }; // SAFETY: This might be unsafe depending on how the connection is used. See https://github.com/rusqlite/rusqlite/issues/643#issuecomment-640181213
@@ -63,7 +39,7 @@ pub fn encoder_dict_from_ctx<'a>(
6339
|r| r.get(0),
6440
)
6541
.with_context(|| format!("getting dict with id={} from _zstd_dicts", id))?;
66-
let dict = wrap_encoder_dict(dict_raw, level);
42+
let dict = EncoderDictionary::copy(&dict_raw, level);
6743
Arc::new(dict)
6844
}),
6945
lru_time_cache::Entry::Occupied(o) => o.into_mut(),
@@ -75,12 +51,12 @@ pub fn encoder_dict_from_ctx<'a>(
7551
pub fn decoder_dict_from_ctx<'a>(
7652
ctx: &'a Context,
7753
arg_index: usize,
78-
) -> anyhow::Result<Arc<OwnedDecoderDict<'static>>> {
54+
) -> anyhow::Result<Arc<DecoderDictionary<'static>>> {
7955
use lru_time_cache::LruCache;
8056
// we cache the instantiated decoder dictionaries keyed by (DbConnection, dict_id)
8157
// DbConnection would ideally be db.path() because it's the same for multiple connections to the same db, but that would be less robust (e.g. in-memory databases)
8258
lazy_static::lazy_static! {
83-
static ref DICTS: RwLock<LruCache<(usize, i32), Arc<OwnedDecoderDict<'static>>>> = RwLock::new(LruCache::with_expiry_duration(Duration::from_secs(10)));
59+
static ref DICTS: RwLock<LruCache<(usize, i32), Arc<DecoderDictionary<'static>>>> = RwLock::new(LruCache::with_expiry_duration(Duration::from_secs(10)));
8460
}
8561
let id: i32 = ctx.get(arg_index)?;
8662
let db = unsafe { ctx.get_connection()? }; // SAFETY: This might be unsafe depending on how the connection is used. See https://github.com/rusqlite/rusqlite/issues/643#issuecomment-640181213
@@ -101,7 +77,7 @@ pub fn decoder_dict_from_ctx<'a>(
10177
|r| r.get(0),
10278
)
10379
.with_context(|| format!("getting dict with id={} from _zstd_dicts", id))?;
104-
let dict = wrap_decoder_dict(dict_raw);
80+
let dict = DecoderDictionary::copy(&dict_raw);
10581
Arc::new(dict)
10682
}),
10783
lru_time_cache::Entry::Occupied(o) => o.into_mut(),

0 commit comments

Comments
 (0)