Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion processors/src/docx_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,22 @@ impl DocxProcessor {

impl FileProcessor for DocxProcessor {
/// Converts the DOCX file at `path` to Markdown and passes the result to
/// `self.markdown_processor`.
///
/// # Errors
///
/// Returns an error when the DOCX parser panics while opening the file (the panic
/// message is embedded in the error), or when `process_document` fails.
fn process_file(&self, path: impl AsRef<Path>) -> anyhow::Result<Document> {
    // `docx_parser::MarkdownDocument::from_file` is reported to `panic!` instead of
    // returning a `Result` when the file is missing, corrupt, or not a valid DOCX/ZIP
    // archive. Catch that panic here and convert it into an `anyhow` error so callers
    // receive a clean `Err(..)` instead of a panic unwinding into them.
    //
    // NOTE: `catch_unwind` only intercepts unwinding panics; it has no effect when the
    // crate is built with `panic = "abort"`.
    //
    // The closure takes an owned copy of the path so that it does not borrow from the
    // generic `impl AsRef<Path>` argument.
    let path = path.as_ref().to_owned();
    let docs = std::panic::catch_unwind(move || MarkdownDocument::from_file(&path)).map_err(
        |payload| {
            // Panic payloads are `Box<dyn Any + Send>`: `panic!` with format arguments
            // yields a `String`, a literal message yields a `&'static str`.
            let msg = if let Some(s) = payload.downcast_ref::<String>() {
                s.clone()
            } else if let Some(s) = payload.downcast_ref::<&str>() {
                s.to_string()
            } else {
                "unknown panic".to_string()
            };
            anyhow::anyhow!("docx_parser panicked while opening file: {}", msg)
        },
    )?;
    let markdown = docs.to_markdown(false);
    self.markdown_processor.process_document(&markdown)
}
Expand Down
41 changes: 28 additions & 13 deletions rust/src/chunkers/statistical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,33 @@ use text_splitter::{ChunkConfig, TextSplitter};
// use text_splitter::{ChunkConfig, TextSplitter};
use tokenizers::Tokenizer;

/// Returns the median of `data`, or `None` when there is no usable value.
///
/// Values that cannot be compared with themselves (for floats: NaN) are discarded
/// before sorting. Treating them as "equal to everything" would hand `sort_by` a
/// comparison that is not a total order, which the standard library documents as a
/// possible panic. If nothing remains after filtering (empty input or all NaN),
/// the result is `None`.
///
/// For an even number of values the two middle values are averaged using `T`'s own
/// `+` and `/`, so integer inputs use integer division.
fn median<T>(data: &[T]) -> Option<T>
where
    T: Copy + PartialOrd + std::ops::Add<Output = T> + std::ops::Div<Output = T> + From<u8>,
{
    // Keep only self-comparable values; `NaN.partial_cmp(&NaN)` is `None`.
    let mut sorted: Vec<T> = data
        .iter()
        .copied()
        .filter(|x| x.partial_cmp(x).is_some())
        .collect();
    if sorted.is_empty() {
        return None;
    }
    // The fallback is kept for exotic `PartialOrd` types; for primitive numbers every
    // remaining pair is comparable.
    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let mid = sorted.len() / 2;
    let result = if sorted.len() % 2 == 0 {
        (sorted[mid - 1] + sorted[mid]) / T::from(2u8)
    } else {
        sorted[mid]
    };
    Some(result)
}

/// Returns the population standard deviation of `data` (variance divided by `n`,
/// not `n - 1`), or `None` when fewer than two values are supplied.
fn std_dev(data: &[f32]) -> Option<f32> {
    // With zero or one sample there is no spread to measure.
    if data.len() < 2 {
        return None;
    }
    let n = data.len() as f32;
    let mean = data.iter().sum::<f32>() / n;
    let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
    Some(variance.sqrt())
}

pub struct StatisticalChunker {
Expand Down Expand Up @@ -255,7 +261,14 @@ impl StatisticalChunker {
raw_similarities
}

fn _find_optimal_threshold(&self, batch_splits: &[&str], similarities: &Vec<f32>) -> f32 {
fn _find_optimal_threshold(&self, batch_splits: &[&str], similarities: &[f32]) -> f32 {
// Guard: we need at least 2 similarity scores to compute median + std_dev.
// With 0 scores there are no chunk boundaries to find; return a neutral threshold.
// With 1 score there is no variance to measure; use that single score directly.
if similarities.len() < 2 {
return similarities.first().copied().unwrap_or(0.5);
}

let tokens = self
.tokenizer
.encode_batch(batch_splits.to_vec(), true)
Expand All @@ -274,8 +287,10 @@ impl StatisticalChunker {
.collect::<Vec<_>>();

// analyze the distribution of similarity scores to set initial bounds
let median_score = median(similarities);
let std_dev = std_dev(similarities);
// Both median() and std_dev() return Option; the len() >= 2 guard above
// ensures they always return Some(_) here.
let median_score = median(similarities).unwrap_or(0.5);
let std_dev = std_dev(similarities).unwrap_or(0.0);

// set initial bounds based on median and standard deviation
let mut low = f32::max(0.0, median_score - std_dev);
Expand All @@ -300,7 +315,7 @@ impl StatisticalChunker {
.map(|(start, end)| cumulative_token_counts[*end] - cumulative_token_counts[*start])
.collect();

median_tokens = median(&split_token_counts);
median_tokens = median(&split_token_counts).unwrap_or(0);

if self.min_split_tokens - self.split_token_tolerance <= median_tokens
&& median_tokens <= self.max_split_tokens + self.split_token_tolerance
Expand All @@ -315,7 +330,7 @@ impl StatisticalChunker {
}
calculated_threshold
}
fn _find_split_indices(&self, similarities: &Vec<f32>, threshold: f32) -> Vec<usize> {
fn _find_split_indices(&self, similarities: &[f32], threshold: f32) -> Vec<usize> {
let mut split_indices = Vec::new();
for (idx, score) in enumerate(similarities) {
if *score < threshold {
Expand Down
Loading