@@ -22,10 +22,6 @@ This file is part of the iText (R) project.
2222 */
2323package com .itextpdf .pdfocr .onnxtr ;
2424
25- import com .itextpdf .pdfocr .exceptions .PdfOcrException ;
26- import com .itextpdf .pdfocr .onnxtr .util .BatchProcessingGenerator ;
27- import com .itextpdf .pdfocr .onnxtr .util .Batching ;
28-
2925import ai .onnxruntime .NodeInfo ;
3026import ai .onnxruntime .OnnxJavaType ;
3127import ai .onnxruntime .OnnxTensor ;
@@ -39,6 +35,12 @@ This file is part of the iText (R) project.
3935import ai .onnxruntime .OrtSession .SessionOptions .OptLevel ;
4036import ai .onnxruntime .TensorInfo ;
4137import ai .onnxruntime .ValueInfo ;
38+ import com .itextpdf .commons .utils .MessageFormatUtil ;
39+ import com .itextpdf .pdfocr .exceptions .PdfOcrException ;
40+ import com .itextpdf .pdfocr .onnxtr .exceptions .PdfOcrOnnxTrExceptionMessageConstant ;
41+ import com .itextpdf .pdfocr .onnxtr .util .BatchProcessingGenerator ;
42+ import com .itextpdf .pdfocr .onnxtr .util .Batching ;
43+
4244import java .nio .FloatBuffer ;
4345import java .util .Arrays ;
4446import java .util .Collection ;
@@ -76,6 +78,21 @@ public abstract class AbstractOnnxPredictor<T, R> implements IPredictor<T, R> {
7678 */
7779 private final String inputName ;
7880
81+ /**
82+ * Close status of the predictor.
83+ */
84+ private boolean closed = false ;
85+
86+ static {
87+ try {
88+ // OnnxRuntime.init() is used under the hood.
89+ new OrtSession .SessionOptions ().close ();
90+ } catch (RuntimeException | UnsatisfiedLinkError e ) {
91+ DependencyLoadChecker .processException (e );
92+ throw e ;
93+ }
94+ }
95+
7996 /**
8097 * Creates a new abstract predictor.
8198 *
@@ -93,20 +110,21 @@ protected AbstractOnnxPredictor(String modelPath, OnnxInputProperties inputPrope
93110 try {
94111 this .sessionOptions = createDefaultSessionOptions ();
95112 } catch (OrtException e ) {
96- throw new PdfOcrException ("Failed to init ONNX Runtime session options" , e );
113+ throw new PdfOcrException (PdfOcrOnnxTrExceptionMessageConstant . FAILED_TO_INIT_SESSION_OPTIONS , e );
97114 }
98115
99116 try {
100117 this .session = OrtEnvironment .getEnvironment ().createSession (modelPath , sessionOptions );
101118 } catch (Exception e ) {
102119 this .sessionOptions .close ();
103- throw new PdfOcrException ("Failed to init ONNX Runtime session" , e );
120+ throw new PdfOcrException (PdfOcrOnnxTrExceptionMessageConstant . FAILED_TO_INIT_ONNX_RUNTIME_SESSION , e );
104121 }
105122
106123 try {
107124 this .inputName = validateModel (this .session , inputProperties , outputShape );
108125 } catch (Exception e ) {
109- final PdfOcrException userException = new PdfOcrException ("ONNX Runtime model did not pass validation" , e );
126+ final PdfOcrException userException = new PdfOcrException (
127+ PdfOcrOnnxTrExceptionMessageConstant .MODEL_DID_NOT_PASS_VALIDATION , e );
110128 try {
111129 this .session .close ();
112130 } catch (OrtException closeException ) {
@@ -123,23 +141,28 @@ public Iterator<R> predict(Iterator<T> inputs) {
123141 Batching .wrap (inputs , inputProperties .getBatchSize ()),
124142 (List <T > batch ) -> {
125143 try (final OnnxTensor inputTensor = createTensor (toInputBuffer (batch ));
126- final Result outputTensor = session .run (Collections .singletonMap (inputName , inputTensor ))) {
144+ final Result outputTensor = session .run (Collections .singletonMap (inputName , inputTensor ))) {
127145 return fromOutputBuffer (batch , parseModelOutput (outputTensor ));
128146 } catch (OrtException e ) {
129- throw new PdfOcrException ("ONNX Runtime operation failed" , e );
147+ throw new PdfOcrException (
148+ PdfOcrOnnxTrExceptionMessageConstant .ONNX_RUNTIME_OPERATION_FAILED , e );
130149 }
131150 }
132151 );
133152 }
134153
135154 @ Override
136155 public void close () {
156+ if (closed ) {
157+ return ;
158+ }
137159 try {
138160 session .close ();
139161 sessionOptions .close ();
140162 } catch (OrtException e ) {
141- throw new PdfOcrException ("Failed to close an ONNX Runtime session" , e );
163+ throw new PdfOcrException (PdfOcrOnnxTrExceptionMessageConstant . FAILED_TO_CLOSE_ONNX_RUNTIME_SESSION , e );
142164 }
165+ closed = true ;
143166 }
144167
145168 /**
@@ -205,51 +228,47 @@ private static String validateModel(OrtSession session, OnnxInputProperties prop
205228 private static String validateModelInput (OrtSession session , OnnxInputProperties properties ) throws OrtException {
206229 final Collection <NodeInfo > inputInfo = session .getInputInfo ().values ();
207230 if (inputInfo .size () != 1 ) {
208- throw new IllegalArgumentException (
209- "Expected 1 input, but got " + inputInfo .size () + " instead"
210- );
231+ throw new IllegalArgumentException (MessageFormatUtil .format (
232+ PdfOcrOnnxTrExceptionMessageConstant .UNEXPECTED_INPUT_SIZE , inputInfo .size ()));
211233 }
212234 final NodeInfo inputNodeInfo = inputInfo .iterator ().next ();
213235 final ValueInfo inputNodeValueInfo = inputNodeInfo .getInfo ();
214236 if (!(inputNodeValueInfo instanceof TensorInfo )) {
215- throw new IllegalArgumentException ("Unexpected input type, expected float32 tensor" );
237+ throw new IllegalArgumentException (PdfOcrOnnxTrExceptionMessageConstant . UNEXPECTED_INPUT_TYPE );
216238 }
217239 final TensorInfo inputTensorInfo = (TensorInfo ) inputNodeValueInfo ;
218240 if (inputTensorInfo .type != OnnxJavaType .FLOAT ) {
219- throw new IllegalArgumentException ("Unexpected input type, expected float32 tensor" );
241+ throw new IllegalArgumentException (PdfOcrOnnxTrExceptionMessageConstant . UNEXPECTED_INPUT_TYPE );
220242 }
221243 final long [] inputShape = inputTensorInfo .getShape ();
222244 if (isShapeIncompatible (properties .getShape (), inputShape )) {
223- throw new IllegalArgumentException (
224- "Expected " + Arrays .toString (properties .getShape ()) + " input shape, "
225- + "but got " + Arrays .toString (inputShape ) + " instead"
226- );
245+ throw new IllegalArgumentException (MessageFormatUtil .format (
246+ PdfOcrOnnxTrExceptionMessageConstant .UNEXPECTED_INPUT_SHAPE , Arrays .toString (properties .getShape ()),
247+ Arrays .toString (inputShape )));
227248 }
228249 return inputNodeInfo .getName ();
229250 }
230251
231252 private static void validateModelOutput (OrtSession session , long [] expectedOutputShape ) throws OrtException {
232253 final Collection <NodeInfo > outputInfo = session .getOutputInfo ().values ();
233254 if (outputInfo .size () != 1 ) {
234- throw new IllegalArgumentException (
235- "Expected 1 output, but got " + outputInfo .size () + " instead"
236- );
255+ throw new IllegalArgumentException (MessageFormatUtil .format (
256+ PdfOcrOnnxTrExceptionMessageConstant .UNEXPECTED_OUTPUT_SIZE , outputInfo .size ()));
237257 }
238258 final NodeInfo outputNodeInfo = outputInfo .iterator ().next ();
239259 final ValueInfo outputNodeValueInfo = outputNodeInfo .getInfo ();
240260 if (!(outputNodeValueInfo instanceof TensorInfo )) {
241- throw new IllegalArgumentException ("Unexpected output type, expected float32 tensor" );
261+ throw new IllegalArgumentException (PdfOcrOnnxTrExceptionMessageConstant . UNEXPECTED_OUTPUT_TYPE );
242262 }
243263 final TensorInfo outputTensorInfo = (TensorInfo ) outputNodeValueInfo ;
244264 if (outputTensorInfo .type != OnnxJavaType .FLOAT ) {
245- throw new IllegalArgumentException ("Unexpected output type, expected float32 tensor" );
265+ throw new IllegalArgumentException (PdfOcrOnnxTrExceptionMessageConstant . UNEXPECTED_OUTPUT_TYPE );
246266 }
247267 final long [] actualOutputShape = outputTensorInfo .getShape ();
248268 if (isShapeIncompatible (expectedOutputShape , actualOutputShape )) {
249- throw new IllegalArgumentException (
250- "Expected " + Arrays .toString (expectedOutputShape ) + " output shape, "
251- + "but got " + Arrays .toString (actualOutputShape ) + " instead"
252- );
269+ throw new IllegalArgumentException (MessageFormatUtil .format (
270+ PdfOcrOnnxTrExceptionMessageConstant .UNEXPECTED_OUTPUT_SHAPE , Arrays .toString (expectedOutputShape ),
271+ Arrays .toString (actualOutputShape )));
253272 }
254273 }
255274
0 commit comments