diff --git a/aircv/__init__.py b/aircv/__init__.py index e8e52ea..a88e065 100644 --- a/aircv/__init__.py +++ b/aircv/__init__.py @@ -125,8 +125,28 @@ def find_all_template(im_source, im_search, threshold=0.5, maxcnt=0, rgb=False, resbgr[i] = cv2.matchTemplate(i_bgr[i], s_bgr[i], method) res = resbgr[0]*weight[0] + resbgr[1]*weight[1] + resbgr[2]*weight[2] else: - s_gray = cv2.cvtColor(im_search, cv2.COLOR_BGR2GRAY) - i_gray = cv2.cvtColor(im_source, cv2.COLOR_BGR2GRAY) + channel = 1 if len(im_search.shape) == 2 else im_search.shape[2] + if channel == 1: + # if the image is gray, then keep it + s_gray = im_search + elif channel == 3: + # if it's colorful, then convert it to gray + s_gray = cv2.cvtColor(im_search, cv2.COLOR_BGR2GRAY) + elif channel == 4: + # if it's colorful with transparent channel, then convert it to gray + s_gray = cv2.cvtColor(im_search, cv2.COLOR_BGRA2GRAY) + + channel = 1 if len(im_source.shape) == 2 else im_source.shape[2] + if channel == 1: + # if the image is gray, then keep it + i_gray = im_source + elif channel == 3: + # if it's colorful, then convert it to gray + i_gray = cv2.cvtColor(im_source, cv2.COLOR_BGR2GRAY) + elif channel == 4: + # if it's colorful with transparent channel, then convert it to gray + i_gray = cv2.cvtColor(im_source, cv2.COLOR_BGRA2GRAY) + # 边界提取(来实现背景去除的功能) if bgremove: s_gray = cv2.Canny(s_gray, 100, 200) diff --git a/tests/example.py b/tests/example.py index efaf32a..78f1c05 100644 --- a/tests/example.py +++ b/tests/example.py @@ -7,29 +7,70 @@ """ sift """ +import cv2 +import numpy as np import aircv as ac def sift_test(): t1 = ac.imread("testdata/1s.png") t2 = ac.imread("testdata/2s.png") - print ac.sift_count(t1), ac.sift_count(t2) + print(ac.sift_count(t1), ac.sift_count(t2)) result = ac.find_sift(t1, t2, min_match_count=ac.sift_count(t1)*0.4) # after many tests, 0.4 may be best if result: - print 'Same' + print('Same') else: - print 'Not same' + print('Not same') -def tmpl_test(): - t1 = ac.imread("testdata/2s.png") - t2 = ac.imread("testdata/2t.png") +def tmpl_test(imgsrc, imgtgt): + t1 = ac.imread(imgsrc) + t2 = ac.imread(imgtgt) import time start = time.time() - print ac.find_all_template(t1, t2) - print 'Time used:', time.time() - start + print(ac.find_all_template(t1, t2)) + print('Time used:', time.time() - start) + + +def add_alpha(imgsrc, imgtgt): + # a tool to generate 4 channel image + img = cv2.imread(imgsrc) + + b_channel, g_channel, r_channel = cv2.split(img) + + alpha_channel = np.ones(b_channel.shape, dtype=b_channel.dtype) * 255 + # 最小值为0 + alpha_channel[:, :int(b_channel.shape[1] / 2)] = 100 + + img_BGRA = cv2.merge((b_channel, g_channel, r_channel, alpha_channel)) + cv2.imwrite(imgtgt, img_BGRA) + + +def to_grayscale(imgsrc, imgtgt): + # a tool to generate grayscale image + img = cv2.imread(imgsrc) + + channel = img.shape[2] + if channel == 1: + img_gray = imgsrc + elif channel == 3: + img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + elif channel == 4: + img_gray = cv2.cvtColor(img, cv2.COLOR_BGRA2GRAY) + + cv2.imwrite(imgtgt, img_gray) + if __name__ == '__main__': - sift_test() - tmpl_test() + # sift_test() + + # add_alpha('testdata/2s.png','testdata/2s_trans.png') + # to_grayscale('testdata/2s.png','testdata/2s_gray.png') + + tmpl_test("testdata/2s.png", "testdata/2t.png") + # test find template in 3 channel images + tmpl_test("testdata/2s_trans.png", "testdata/2t.png") + # test find template in 4 channel images + tmpl_test("testdata/2s_gray.png", "testdata/2t.png") + # test find template in 1 channel images diff --git a/tests/testdata/2s_gray.png b/tests/testdata/2s_gray.png new file mode 100644 index 0000000..1844003 Binary files /dev/null and b/tests/testdata/2s_gray.png differ diff --git a/tests/testdata/2s_trans.png b/tests/testdata/2s_trans.png new file mode 100644 index 0000000..7eb64a0 Binary files /dev/null and b/tests/testdata/2s_trans.png differ