Skip to content

Commit 4603c70

Browse files
Eswcvladnanouh
authored and committed
Add script for EasyOCR model conversion to onnx
DEVSIX-9776
1 parent 630cc6e commit 4603c70

1 file changed

Lines changed: 113 additions & 0 deletions

File tree

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
#!/usr/bin/env python3
2+
import argparse
3+
import os.path
4+
5+
import easyocr
6+
from easyocr import config
7+
from easyocr.craft import CRAFT
8+
from easyocr.detection import copyStateDict
9+
10+
import torch
11+
12+
13+
# Detection model names.
detection_models = ('craft',)

# First-generation recognition model names (all carry the "_g1" suffix).
recognition_models_gen1 = (
    'arabic_g1', 'bengali_g1', 'cyrillic_g1', 'devanagari_g1',
    'japanese_g1', 'korean_g1', 'latin_g1',
    # FIXME: 'tamil_g1' causes issues during export, so it is left out
    'thai_g1', 'zh_sim_g1', 'zh_tra_g1',
)

# Second-generation recognition model names (all carry the "_g2" suffix).
recognition_models_gen2 = (
    'cyrillic_g2', 'english_g2', 'japanese_g2', 'kannada_g2',
    'korean_g2', 'latin_g2', 'telugu_g2', 'zh_sim_g2',
)

# Every recognition model to export: first generation, then second.
recognition_models = (*recognition_models_gen1, *recognition_models_gen2)
41+
42+
43+
# Detection model
44+
class TrimmedCRAFT(CRAFT):
    """CRAFT subclass whose forward pass returns a single tensor.

    The parent ``forward`` yields a pair; only its first element is kept,
    and its axes are reordered before it is returned.
    """

    def forward(self, x):
        """Run the parent network on ``x`` and return only its first output.

        The second element of the parent's result ("feature") is discarded.
        """
        primary, _feature = super().forward(x)
        # Transpose the result back to BCHW: axis 3 moves to position 1.
        bchw = primary.permute(0, 3, 1, 2)
        return bchw
50+
51+
52+
def get_detector(trained_model, device='cpu'):
    """Create a TrimmedCRAFT network loaded from a saved checkpoint.

    Args:
        trained_model: Path (or file-like object) handed to ``torch.load``.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The network with the checkpoint weights loaded, after in-place
        dynamic quantization (qint8) and a switch to evaluation mode.
    """
    detector = TrimmedCRAFT()

    # NOTE: weights_only=False lets torch.load unpickle arbitrary objects,
    # so only point this at checkpoint files from a trusted source.
    checkpoint = torch.load(
        trained_model,
        map_location=device,
        weights_only=False,
    )
    state_dict = copyStateDict(checkpoint)
    detector.load_state_dict(state_dict)

    # inplace=True modifies ``detector`` itself; the return value is not needed.
    torch.quantization.quantize_dynamic(
        detector,
        dtype=torch.qint8,
        inplace=True,
    )
    detector.eval()
    return detector
58+
59+
60+
def _onnx_path(model_dir, weights_filename):
    """Return the ``.onnx`` output path in ``model_dir`` for a weights file.

    The extension of ``weights_filename`` is replaced with ``.onnx``.
    """
    stem, _ = os.path.splitext(weights_filename)
    return os.path.join(model_dir, stem + '.onnx')


def _export_recognizer(model_dir, recognition_model):
    """Export one recognition model to ONNX inside ``model_dir``.

    Args:
        model_dir: Directory passed to ``easyocr.Reader`` as its model
            storage directory; the ``.onnx`` file is written there too.
        recognition_model: A name from ``recognition_models``; its suffix
            ("_g1" or otherwise) selects the generation table in
            ``config.recognition_models``.
    """
    gen = 'gen1' if recognition_model.endswith('_g1') else 'gen2'
    weights_filename = config.recognition_models[gen][recognition_model]['filename']
    reader = easyocr.Reader(
        lang_list=['en'],
        gpu=False,
        model_storage_directory=model_dir,
        recog_network=recognition_model,
        quantize=False,
    )
    # AdaptiveAvgPool2d cannot be exported to ONNX
    # Specifying a static one instead assuming imgH=64
    reader.recognizer.AdaptiveAvgPool = torch.nn.AvgPool2d((1, 3))
    dummy_input = (
        torch.randn(1, 1, 64, 512),
        torch.randn(1, 512),
    )
    torch.onnx.export(
        reader.recognizer,
        dummy_input,
        _onnx_path(model_dir, weights_filename),
        export_params=True,
        input_names=('input', 'text',),
        output_names=('preds',),
        dynamic_axes={
            "input": {0: 'batch_size', 3: 'width'},
            "text": {0: 'batch_size', 1: 'batch_max_length'},
        },
    )


def _export_detector(model_dir):
    """Export the CRAFT detection model to ONNX inside ``model_dir``.

    The weights file name is looked up in ``config.detection_models`` and
    the file is expected to be present in ``model_dir``.
    """
    weights_filename = config.detection_models['craft']['filename']
    dummy_input = (torch.randn(1, 3, 2560, 2560),)
    model = get_detector(os.path.join(model_dir, weights_filename))
    torch.onnx.export(
        model,
        dummy_input,
        _onnx_path(model_dir, weights_filename),
        export_params=True,
        input_names=('images',),
        output_names=('y',),
        dynamic_axes={
            "images": {0: 'batch_size', 2: 'height', 3: 'width'},
        },
    )


def main():
    """Convert the EasyOCR models found in a directory to ONNX.

    Reads one positional command-line argument, ``model_dir``, exports every
    model listed in ``recognition_models`` and then the CRAFT detector. Each
    ``.onnx`` file is written into ``model_dir`` under the name of its
    weights file with the extension replaced.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('model_dir', help='directory with EasyOCR models')
    model_dir = parser.parse_args().model_dir

    for recognition_model in recognition_models:
        print(f'Exporting {recognition_model}...')
        _export_recognizer(model_dir, recognition_model)

    print('Exporting CRAFT...')
    _export_detector(model_dir)


if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)