Skip to content

Commit e46ad1a

Browse files
committed
added pytorch to requirements and deep_refinement with tests
1 parent 07543a3 commit e46ad1a

5 files changed

Lines changed: 350 additions & 0 deletions

File tree

openpiv/deep_refinement.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
import numpy as np
import scipy.ndimage
from scipy.interpolate import griddata


# Optional imports for Deep Learning backends (PyTorch/TensorFlow).
# HAS_TORCH lets the rest of the module degrade gracefully — DeepRefiner
# falls back to a dummy identity model when torch is not installed.
try:
    import torch
    import torch.nn.functional as F
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False
def warp_image(image, u, v):
    """
    Warp ``image`` according to the displacement field ``(u, v)``.

    Used to warp Image 2 back towards Image 1 based on the coarse
    estimation. Following Choi et al.: Warped(x, y) = Img2(x + u, y + v).

    Args:
        image (np.ndarray): Image 2 (grayscale or RGB).
        u (np.ndarray): Dense displacement field in x (one value per pixel).
        v (np.ndarray): Dense displacement field in y (one value per pixel).

    Returns:
        np.ndarray: The warped image, same shape as ``image``.
    """
    rows, cols = image.shape[:2]
    yy, xx = np.mgrid[0:rows, 0:cols]

    # Sampling coordinates: each output pixel pulls from (x + u, y + v).
    sample_y = yy + v
    sample_x = xx + u

    # scipy is used to avoid an OpenCV dependency (cv2.remap would be faster).
    def _remap(channel):
        # Bilinear sampling; edge values are clamped (mode='nearest').
        return scipy.ndimage.map_coordinates(
            channel, [sample_y, sample_x], order=1, mode='nearest'
        )

    if image.ndim == 2:
        return _remap(image)

    # Multi-channel (e.g. RGB): warp each channel independently.
    warped = np.zeros_like(image)
    for c in range(image.shape[2]):
        warped[..., c] = _remap(image[..., c])
    return warped
50+
51+
def upscale_flow(u_coarse, v_coarse, x_coarse, y_coarse, target_shape):
    """
    Upscales the sparse PIV grid (from correlation) to dense pixel resolution.

    Args:
        u_coarse, v_coarse: Velocity components from standard OpenPIV.
        x_coarse, y_coarse: Meshgrid coordinates of the coarse vectors.
        target_shape: (height, width) of the original image.

    Returns:
        u_dense, v_dense: Fields matching target_shape.
    """
    h, w = target_shape
    grid_y, grid_x = np.mgrid[0:h, 0:w]

    # Flatten source points
    points = np.column_stack((x_coarse.flatten(), y_coarse.flatten()))
    u_flat = np.asarray(u_coarse).flatten()
    v_flat = np.asarray(v_coarse).flatten()

    # Interpolate (linear is usually sufficient for the "coarse" step)
    u_dense = griddata(points, u_flat, (grid_x, grid_y), method='linear')
    v_dense = griddata(points, v_flat, (grid_x, grid_y), method='linear')

    # Linear interpolation yields NaN outside the convex hull of the coarse
    # grid — i.e. the image borders, which is common in PIV. Fill those
    # pixels with the nearest valid vector (nearest-neighbour extrapolation)
    # instead of zero, so the border flow follows the outermost measurements
    # rather than being artificially damped to rest.
    mask = np.isnan(u_dense)
    if np.any(mask):
        u_dense[mask] = griddata(
            points, u_flat, (grid_x[mask], grid_y[mask]), method='nearest'
        )
        v_dense[mask] = griddata(
            points, v_flat, (grid_x[mask], grid_y[mask]), method='nearest'
        )

    return u_dense, v_dense
80+
81+
class DeepRefiner:
    """
    Implements the Choi et al. refinement algorithm.
    Wrapper for an Optical Flow CNN (e.g., RAFT, FlowNet2, LiteFlowNet).
    """

    def __init__(self, model_path=None, device='cpu'):
        """
        Args:
            model_path (str): Path to trained weights (.pth, .onnx).
            device (str): 'cpu' or 'cuda'.
        """
        self.device = device
        self.model = self._load_model(model_path)

    def _load_model(self, path):
        """
        Placeholder for model loading logic.
        In a real implementation, this would load RAFT or FlowNet2.
        """
        if not HAS_TORCH:
            print("Warning: PyTorch not found. Using dummy identity model.")
            return None

        # Boilerplate: Load your specific architecture here
        # model = RAFT(args)
        # model.load_state_dict(torch.load(path))
        # return model.eval().to(self.device)
        return "Loaded_Model_Placeholder"

    def predict_residual(self, img1, img2_warped):
        """
        Uses the CNN to find the small 'residual' motion between Image 1
        and the Warped Image 2.

        Returns:
            (u_res, v_res): per-pixel residual flow, shape (h, w).
        """
        # 1. Preprocess images (Normalize to [0,1] or [-1,1], convert to Tensor)
        # 2. Feed to self.model
        # 3. Return residual flow numpy array

        # --- DUMMY IMPLEMENTATION FOR BOILERPLATE ---
        # Returns zero residual (no refinement). Use only the spatial
        # dimensions so RGB input (h, w, 3) works too, matching the
        # multi-channel support in warp_image.
        h, w = img1.shape[:2]
        return np.zeros((h, w)), np.zeros((h, w))

    def refine(self, image1, image2, u_piv, v_piv, x_piv, y_piv):
        """
        Main execution method for the Hybrid PIV+CNN approach.

        Args:
            image1, image2: Raw particle images.
            u_piv, v_piv: Result from openpiv.pyprocess.extended_search_area_piv
            x_piv, y_piv: Coordinates of the PIV grid.

        Returns:
            u_final, v_final: Dense, high-resolution velocity fields.
        """

        # Step 1: Upscale Coarse PIV to Pixel Resolution.
        # shape[:2] so an RGB image1 does not break upscale_flow's
        # (height, width) unpacking.
        print("Upscaling coarse PIV field...")
        u_dense, v_dense = upscale_flow(
            u_piv, v_piv, x_piv, y_piv, image1.shape[:2]
        )

        # Step 2: Warp Image 2 using the Coarse Flow
        # The warped image should now align closely with Image 1
        print("Warping Image 2...")
        image2_warped = warp_image(image2, u_dense, v_dense)

        # Step 3: CNN Inference on (Image 1, Warped Image 2)
        # The CNN only needs to find the small differences (residuals)
        print("Calculating residual flow with CNN...")
        u_res, v_res = self.predict_residual(image1, image2_warped)

        # Step 4: Combine Results
        # Total Flow = Coarse Flow + Residual Flow
        u_final = u_dense + u_res
        v_final = v_dense + v_res

        return u_final, v_final
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"id": "69106551",
7+
"metadata": {
8+
"vscode": {
9+
"languageId": "plaintext"
10+
}
11+
},
12+
"outputs": [],
13+
"source": [
14+
"import numpy as np\n",
15+
"from openpiv import tools, pyprocess, scaling, validation\n",
16+
"from openpiv.deep_refinement import DeepRefiner # The new module\n",
17+
"\n",
18+
"# 1. Load Images\n",
19+
"frame_a = tools.imread( 'exp1_001_a.bmp' )\n",
20+
"frame_b = tools.imread( 'exp1_001_b.bmp' )\n",
21+
"\n",
22+
"# 2. Standard OpenPIV (The \"Coarse\" Step)\n",
23+
"# We use standard WIDIM or simple Cross-Correlation\n",
24+
"u, v, sig2noise = pyprocess.extended_search_area_piv(\n",
25+
" frame_a.astype(np.int32), \n",
26+
" frame_b.astype(np.int32), \n",
27+
" window_size=24, \n",
28+
" overlap=12, \n",
29+
" dt=0.02, \n",
30+
" search_area_size=64, \n",
31+
" sig2noise_method='peak2peak'\n",
32+
")\n",
33+
"\n",
34+
"# Get the grid coordinates for the coarse vectors\n",
35+
"x, y = pyprocess.get_coordinates(\n",
36+
" image_size=frame_a.shape, \n",
37+
" search_area_size=64, \n",
38+
" overlap=12\n",
39+
")\n",
40+
"\n",
41+
"# 3. Filter Outliers (Important before feeding to CNN)\n",
42+
"u, v, mask = validation.sig2noise_val( u, v, sig2noise, threshold = 1.3 )\n",
43+
"u, v = validation.global_val( u, v, (-1000, 1000), (-1000, 1000) )\n",
44+
"u, v = tools.replace_outliers( u, v, method='localmean', max_iter=10, kernel_size=2)\n",
45+
"\n",
46+
"# 4. Apply CNN Refinement (The Choi et al. Step)\n",
47+
"# Initialize refiner (optionally load weights for RAFT/FlowNet2)\n",
48+
"refiner = DeepRefiner(model_path=\"weights/raft-sintel.pth\")\n",
49+
"\n",
50+
"# Get pixel-dense, super-resolved flow\n",
51+
"u_dense, v_dense = refiner.refine(frame_a, frame_b, u, v, x, y)\n",
52+
"\n",
53+
"# 5. Save/Plot\n",
54+
"tools.save(x, y, u, v, mask, 'openpiv_result.txt')\n",
55+
"# Save the dense field (custom function needed as it's pixel-wise, not grid-wise)\n",
56+
"np.savez('dense_result.npz', u=u_dense, v=v_dense)"
57+
]
58+
}
59+
],
60+
"metadata": {
61+
"language_info": {
62+
"name": "python"
63+
}
64+
},
65+
"nbformat": 4,
66+
"nbformat_minor": 5
67+
}
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import unittest
2+
import numpy as np
3+
from scipy import ndimage
4+
from openpiv import deep_refinement
5+
6+
class TestDeepRefinement(unittest.TestCase):
    def setUp(self):
        """Build a synthetic particle-image pair with a known displacement."""
        # Speckle pattern: seeded random noise smoothed by a Gaussian
        # stands in for tracer particles (Image A).
        np.random.seed(42)
        self.frame_a = ndimage.gaussian_filter(np.random.rand(64, 64), sigma=1.5)

        # Ground-truth motion: u = 3.5 px, v = -2.0 px.
        self.u_gt = 3.5
        self.v_gt = -2.0

        # Image B is Image A shifted by the ground truth (particles moved).
        self.frame_b = ndimage.shift(
            self.frame_a, shift=(self.v_gt, self.u_gt), order=3
        )

        # Simulated "coarse" PIV result: a uniform 3x3 grid holding the
        # integer approximation (u=3, v=-2) a cross-correlation would find.
        n_rows, n_cols = 3, 3
        self.u_piv = np.full((n_rows, n_cols), 3.0)
        self.v_piv = np.full((n_rows, n_cols), -2.0)

        # Window-centre coordinates for those vectors (simple linspace
        # is enough to exercise the upscaling).
        ys = np.linspace(16, 48, n_rows)
        xs = np.linspace(16, 48, n_cols)
        self.x_piv, self.y_piv = np.meshgrid(xs, ys)

    def test_upscale_flow(self):
        """The coarse grid interpolates to a dense pixel-resolution field."""
        shape = (64, 64)
        u_dense, v_dense = deep_refinement.upscale_flow(
            self.u_piv, self.v_piv, self.x_piv, self.y_piv, shape
        )

        # Dense output matches the image resolution.
        self.assertEqual(u_dense.shape, shape)

        # A uniform coarse field must remain uniform after upscaling.
        np.testing.assert_allclose(u_dense, 3.0, rtol=1e-5)
        np.testing.assert_allclose(v_dense, -2.0, rtol=1e-5)

    def test_image_warping_integrity(self):
        """
        Crucial test of the Choi et al. warping convention: warping
        Frame B back by the ground-truth flow must recover Frame A.
        """
        h, w = self.frame_a.shape
        u_field = np.full((h, w), self.u_gt)  # 3.5
        v_field = np.full((h, w), self.v_gt)  # -2.0

        warped_b = deep_refinement.warp_image(self.frame_b, u_field, v_field)

        # Ignore a border strip where shift/interpolation artefacts live.
        border = 5
        core_a = self.frame_a[border:-border, border:-border]
        core_w = warped_b[border:-border, border:-border]

        # Small tolerance allows for bilinear interpolation error.
        mae = np.mean(np.abs(core_a - core_w))
        self.assertLess(mae, 0.05, "Warping failed to align Frame B with Frame A")

    def test_pipeline_with_mock_model(self):
        """
        Full 'refine' pipeline with the CNN stubbed out.

        No trained weights ship with the repo, so predict_residual is
        replaced by a mock that returns exactly the residual (0.5, 0.0) a
        perfect CNN would find on top of the coarse (3.0, -2.0) estimate.
        """
        refiner = deep_refinement.DeepRefiner(model_path=None, device='cpu')

        def fake_residual(img1, img2_warped):
            h, w = img1.shape
            # Uniform residual: 0.5 in u, 0.0 in v.
            return np.full((h, w), 0.5), np.full((h, w), 0.0)

        # Inject the mock in place of the real CNN call.
        refiner.predict_residual = fake_residual

        u_final, v_final = refiner.refine(
            self.frame_a, self.frame_b,
            self.u_piv, self.v_piv,
            self.x_piv, self.y_piv,
        )

        # Coarse (3.0) + residual (0.5) = ground truth (3.5).
        # Sample the centre pixel to avoid boundary interpolation issues.
        center_u = u_final[32, 32]
        self.assertAlmostEqual(center_u, 3.5, places=3)
        print(f"Pipeline Test: Input=3.0, Residual=0.5, Result={center_u} (Expected 3.5)")
117+
118+
# Allow running this test module directly (python test_deep_refinement.py)
# in addition to discovery via pytest/unittest.
if __name__ == '__main__':
    unittest.main()

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ scipy = ">=1.11.0"
3838
natsort = ">=8.4.0"
3939
tqdm = ">=4.66.0"
4040
importlib_resources = ">=5.12.0"
41+
# Optional dependencies for Deep Refinement
42+
torch = ">=1.10.0"
43+
torchvision = ">=0.15.0"
4144

4245
[tool.poetry.dev-dependencies]
4346
pytest = "^7.4.3"

uv.lock

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)