Skip to content

Commit a1d9398

Browse files
Added Surrogate Modeling enhancement module (#30)
* update the benchmark fim_lookup query within fimserv * update the date field and formatting of codes * added the evaluation module within FIMserv * added the surrogate modeling enhancement module * added Surrogate Model citation
1 parent b3c7672 commit a1d9398

15 files changed

Lines changed: 1932 additions & 1 deletion

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "fimserve"
3-
version = "0.1.87"
3+
version = "0.1.88"
44
description = "Framework which is developed with the purpose of quickly generating Flood Inundation Maps (FIM) for emergency response and risk assessment. It is developed under Surface Dynamics Modeling Lab (SDML)."
55
authors = [{ name = "Supath Dhital", email = "sdhital@crimson.ua.edu" }]
66
maintainers = [{ name = "Supath Dhital", email = "sdhital@crimson.ua.edu" }]
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Operational Enhanced FIM
2+
3+
## Overview
4+
5+
The **Operational Enhanced Flood Inundation Mapping (FIM)** framework provides a streamlined pipeline for real-time flood mapping and impact analysis. Its primary objective is to support emergency response operations by delivering timely, high-resolution flood information.
6+
7+
8+
## Workflow
9+
10+
The core components of the framework are illustrated in the figure below:
11+
12+
![Workflow](graphics/operationalFIM.jpg)
13+
14+
## Key Capabilities
15+
16+
- **Real-time flood map generation** using surrogate models.
17+
- **Automated processing pipeline** for rapid data acquisition, model execution, and visualization.
18+
- **Scalable architecture** adaptable to different regions and data sources.
19+
20+
## Exposure Analysis
21+
22+
In addition to mapping flooded areas, the framework includes modules for:
23+
24+
- **Population exposure estimation**: Identifies the number of people affected by the flood event. It uses gridded population data to analyze the exposed population, reporting the estimated exposed population count and visualizing its spatial distribution in a distinctive way.
25+
![Workflow](graphics/population_exposure.png)
26+
27+
- **Building exposure analysis**: Detects flooded structures using geospatial building footprints and flood maps.
28+
![Workflow](graphics/building_exposure.png)
29+
30+
These capabilities help quantify potential impacts and support decision-making in emergency situations.
31+
32+
## Application in Emergency Response
33+
34+
This framework is designed to:
35+
- Provide first responders with up-to-date flood information
36+
- Inform public warnings and evacuation strategies
37+
- Aid in post-event damage assessments and recovery planning
38+
39+
## Reference
40+
The trained Surrogate Model applied to enhance the FIMserv-derived low-fidelity results comes from the following research, which is currently in review.
41+
42+
**Preprint:** **Supath Dhital**, Sagy Cohen, Parvaneh Nikrou, et al.
43+
Enhancement of low-fidelity flood inundation mapping through surrogate modeling. *ESS Open Archive*, November 03, 2025.
44+
[https://doi.org/10.22541/essoar.176218121.12875584/v1](https://doi.org/10.22541/essoar.176218121.12875584/v1)
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
from SM_preprocess import *
2+
from surrogate_model import *
3+
from utlis import *
4+
from preprocessFIM import *
5+
6+
def load_model(model):
    """Download the trained surrogate-model checkpoint from S3 and load it.

    Args:
        model: An instantiated (untrained) torch model whose architecture
            matches the checkpoint (e.g. ``AttentionUNet(channel=8)``).

    Returns:
        tuple: ``(model, device)`` — the model with the checkpoint's
        ``state_dict`` loaded, moved to the selected device and set to eval
        mode, plus the ``torch.device`` that was chosen.
    """
    import io  # local import: only needed for the in-memory checkpoint buffer

    # Anonymous (public-read) access to the SDML S3 bucket.
    fs = s3fs.S3FileSystem(anon=True)
    bucket_path = "sdmlab/SM_dataset/trained_model/SM_trainedmodel.ckpt"

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the checkpoint straight from memory. The previous implementation
    # wrote it to a NamedTemporaryFile(delete=False) and never removed it,
    # leaking one checkpoint-sized temp file per call.
    with fs.open(bucket_path, 'rb') as s3file:
        buffer = io.BytesIO(s3file.read())
    checkpoint = torch.load(buffer, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])

    # Move model to device and freeze it for inference.
    model.to(device)
    model.eval()

    return model, device
27+
28+
#WEIGHTED AVERAGE for patch
def create_weight_map(M: int, N: int):
    """Build a Gaussian weight map for blending overlapping patches.

    Weights peak at 1.0 at the patch center ``(M // 2, N // 2)`` and decay
    with squared distance, using ``sigma = min(M, N) / 2``.

    Args:
        M (int): Patch height.
        N (int): Patch width.

    Returns:
        torch.Tensor: Float weight map of shape ``(1, 1, M, N)``, ready to be
        broadcast against a ``(1, C, M, N)`` prediction patch.
    """
    center_x, center_y = M // 2, N // 2
    # Vectorized distance grid — replaces the original O(M*N) Python loops
    # with a single NumPy expression that produces identical values.
    ii, jj = np.meshgrid(np.arange(M), np.arange(N), indexing='ij')
    dist_sq = (ii - center_x) ** 2 + (jj - center_y) ** 2
    sigma_sq = (min(M, N) / 2) ** 2
    weight_map = np.exp(-dist_sq / (2 * sigma_sq)).astype(np.float32)
    return torch.from_numpy(weight_map).float().unsqueeze(0).unsqueeze(0)  # (1, 1, M, N)
38+
39+
#If there is Stride
def predict_on_area(dataset, model, shape: torch.Tensor, M: int = 256, N: int = 256, stride: int = 128, device=None):
    """Run the surrogate model over a full area using overlapping patches.

    The area is zero-padded so (M, N) windows at the given stride tile it
    exactly; each patch is predicted, overlaps are blended with the Gaussian
    weight map from create_weight_map, and the blended result is thresholded
    to a binary flood map.

    Args:
        dataset: Object exposing `x_feature_index` (input channel indices),
            `y_feature_index` (target channel index list) and `lf_index`
            (low-fidelity channel index) into `shape`.
        model: Trained model mapping a (1, C, M, N) patch to a prediction
            broadcastable against the (1, 1, M, N) weight map.
        shape (torch.Tensor): Full-area input stack of shape (C, H, W).
        M (int): Patch height.
        N (int): Patch width.
        stride (int): Step between patch origins; stride < M gives overlap.
        device: Torch device used for the model and the accumulators.

    Returns:
        tuple: (final_prediction, y, lf), all on CPU and cropped back to the
        original (pre-padding) height/width — the binary enhanced map, the
        target channel, and the low-fidelity channel.
    """
    # Get row and col size
    shape_row = shape.size(1)
    shape_col = shape.size(2)

    # Pad if needed so that patch windows at `stride` cover the whole area.
    pad_h = (stride * ((shape_row - M) // stride + 1) + M - shape_row)
    pad_w = (stride * ((shape_col - N) // stride + 1) + N - shape_col)

    if pad_h > 0 or pad_w > 0:
        padding = (0, pad_w, 0, pad_h)  # pad right and bottom edges only
        shape = nn.functional.pad(shape, padding, mode='constant', value=0)

    # Update new shape after padding
    new_row = shape.size(1)
    new_col = shape.size(2)

    # Separate X and Y channel stacks by the dataset-provided indices.
    X = shape[dataset.x_feature_index]
    y = shape[dataset.y_feature_index]

    # Initialize weighted prediction sum and weight sum arrays
    weighted_prediction_sum = torch.zeros((1, new_row, new_col), device=device)
    weight_sum = torch.zeros((1, new_row, new_col), device=device)

    # Create the weight map (Gaussian, peaks at the patch center)
    weight_map = create_weight_map(M, N).to(device)

    # Loop over patches; overlapping regions accumulate weighted predictions.
    for start_i in range(0, new_row - M + 1, stride):
        for start_j in range(0, new_col - N + 1, stride):
            end_i = start_i + M
            end_j = start_j + N
            patch = X[:, start_i:end_i, start_j:end_j].unsqueeze(0).to(device)

            with torch.no_grad():
                patch_prediction_raw = model(patch)

            weighted_prediction = patch_prediction_raw * weight_map
            weighted_prediction_sum[:, start_i:end_i, start_j:end_j] += weighted_prediction.squeeze(0)
            weight_sum[:, start_i:end_i, start_j:end_j] += weight_map.squeeze(0)

    # Normalize by accumulated weights; epsilon guards division by zero in
    # padded regions no patch covered.
    epsilon = 1e-8
    final_prediction = weighted_prediction_sum / (weight_sum + epsilon)
    # Binarize: blended values above 0.01 count as flooded
    # (NOTE(review): threshold presumably chosen during training — confirm).
    final_prediction = (final_prediction > 0.01).float()

    # Crop back to original shape (before padding)
    final_prediction = final_prediction[:, :shape_row, :shape_col]
    y = y[:, :shape_row, :shape_col]
    lf = shape[[dataset.lf_index]][:, :shape_row, :shape_col]

    return final_prediction.cpu(), y.cpu(), lf.cpu()
92+
93+
#Save the tif file
def save_image(image: torch.Tensor, path: Path, reference_tif: str):
    """Save the image as a binary flood-map GeoTIFF.

    Georeferencing metadata (CRS, transform) is copied from `reference_tif`.
    After the initial write, the raster is masked with permanent water bodies
    (mask_with_PWB), re-binarized in place, and LZW-compressed — the order of
    these in-place steps matters.

    Args:
        image (torch.Tensor): The image to save
        path (Path): The path to save the image
        reference_tif (str): Path to a GeoTIFF whose metadata is reused
    """
    image_np = image.squeeze().cpu().numpy().astype('float32')
    # Clone the reference file's metadata, overriding size/band/dtype.
    with rasterio.open(reference_tif) as ref:
        meta = ref.meta.copy()
        meta.update({
            "driver": "GTiff",
            "height": image_np.shape[0],
            "width": image_np.shape[1],
            "count": 1,
            "dtype": 'float32'
        })

    with rasterio.open(path, 'w', **meta) as dst:
        dst.write(image_np, 1)
    # Mask out permanent water bodies in place (external helper).
    mask_with_PWB(path, path)

    # Re-binarize after masking: any positive value -> 1.
    # NOTE(review): uint8 data is written into a float32 band — rasterio casts
    # on write; confirm the float32 dtype is intended for a 0/1 mask.
    with rasterio.open(path, 'r+') as dst:
        data = dst.read(1)
        binary_data = np.where(data > 0, 1, 0).astype(np.uint8)
        dst.write(binary_data, 1)

    # Final LZW compression pass (external helper).
    compress_tif_lzw(path)
122+
123+
124+
#ENHANCE THE LOW-FIDELITY FLOOD MAP
def Predict_FM(huc_id, patch_size=(256, 256)):
    """Enhance every low-fidelity flood map of a HUC with the surrogate model.

    Loads the trained surrogate model and the HUC's static feature rasters,
    then for each low-fidelity (LF) map found under ./HUC{huc_id}_forcings/
    runs overlapped patch-wise inference and writes the enhanced binary map
    to ./Results/HUC{huc_id}/SMprediction_<lf_filename>.

    Args:
        huc_id: HUC watershed identifier used in file names and directories.
        patch_size (tuple): (M, N) patch size; stride is M // 2.

    Raises:
        FileNotFoundError: If any static feature raster is missing.
        ValueError: If the stacked input does not have the 8 channels the
            model expects.
    """
    data_dir = Path(f'./HUC{huc_id}_forcings/')
    model = AttentionUNet(channel=8)

    preprocessor = InferenceDataPreprocessor(data_dir=Path(data_dir), patch_size=patch_size, verbose=True)

    print("Loading model...")
    model, device = load_model(model)
    print("Model loaded.")

    # The static features are identical for every LF map of the HUC, so load
    # them ONCE here instead of re-reading all seven rasters per LF file.
    print(f"Loading static features for HUC {huc_id}...")
    static_stack = preprocessor.get_static_stack(huc_id)
    if static_stack is None:
        # get_static_stack returns None when a raster is missing; fail loudly
        # instead of crashing later inside torch.cat with an opaque error.
        raise FileNotFoundError(f"Missing static feature raster(s) for HUC {huc_id} in {data_dir}.")
    print(f"Static features loaded for {huc_id}.\n")

    # Channel-index interface expected by predict_on_area: all 8 channels are
    # model inputs, and the last channel (the LF map) doubles as target/LF.
    class _FeatureIndex:
        x_feature_index = list(range(8))
        y_feature_index = [7]
        lf_index = 7

    pred_dir = Path(f"./Results/HUC{huc_id}/")
    pred_dir.mkdir(parents=True, exist_ok=True)

    for lf_path in preprocessor.get_all_lf_maps(huc_id):
        lf_filename = lf_path.name
        print(f"Predicting for: {lf_filename}\n")

        lf_tensor = preprocessor.tif_to_tensor(lf_path, feature_name='low_fidelity')

        # Combine and validate the 8-channel stack (7 static + 1 LF).
        area_tensor = torch.cat([static_stack, lf_tensor], dim=0)
        if area_tensor.shape[0] != 8:
            raise ValueError(f"Expected 8 channels, got {area_tensor.shape[0]} — check missing static feature for HUC {huc_id}.")

        # Predict with half-patch overlap and save the result.
        print(f"Enhancing {lf_path}...")
        x, y, lf = predict_on_area(_FeatureIndex, model, area_tensor, M=patch_size[0], N=patch_size[1], stride=patch_size[0] // 2, device=device)

        pred_path = pred_dir / f"SMprediction_{lf_filename}"
        save_image(x, pred_path, lf_path)
        print(f"Enhancement completed for {lf_filename}.\n")
169+
170+
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import torch
2+
import rasterio
3+
import numpy as np
4+
from pathlib import Path
5+
import torch.nn as nn
6+
from scipy.stats import boxcox
7+
8+
# Normalization statistics for the static input rasters (fixed constants;
# presumably computed over the surrogate model's training data — TODO confirm
# provenance). How each entry is used by InferenceDataPreprocessor.tif_to_tensor:
#   - elevation, slope, flow_acc: z-score with 'mean'/'std' (slope and
#     flow_acc stats apply AFTER the Box-Cox transform).
#   - twi, curve_number, soil_moisture: min-max scaling with 'min'/'max';
#     their 'mean'/'std' entries are not used by tif_to_tensor.
# NOTE(review): flow_acc mean is -101.88, which cannot arise from the
# Box-Cox(lambda=0.5) transform applied at inference (its minimum is -2) —
# verify these stats match the training pipeline.
GLOBAL_STATS = {
    'elevation': {'mean': 651.62, 'std': 935.30},
    'slope': {'mean': 1.65, 'std': 2.50},
    'flow_acc': {'mean': -101.88, 'std': 65.45},
    'twi': {'mean': -4.23, 'std': 2.27, 'min': -11.01, 'max': 25.39},
    'curve_number': {'mean': 73.96, 'std': 13.29, 'min': 30.12, 'max': 100.00},
    'soil_moisture': {'mean': 28.06, 'std': 9.94, 'min': 0.00, 'max': 76.13},
}
16+
17+
# INFERENCE DATA PREPROCESSOR
class InferenceDataPreprocessor:
    """Load and normalize the input rasters for surrogate-model inference.

    Builds an 8-channel stack per HUC — seven static terrain/soil features
    plus one binarized low-fidelity (LF) flood map — normalizing each channel
    the same way as during training (z-score, Box-Cox + z-score, min-max
    scaling, or categorical reclassification).
    """

    # Static feature channels, in the fixed order the model was trained with.
    STATIC_FEATURES = ['curve_number', 'elevation', 'flow_acc', 'lulc', 'slope', 'soil_moisture', 'twi']
    # Substrings that identify low-fidelity flood-map files on disk.
    LF_KEYWORDS = ['hand']

    # Maps a feature name to the filename substring used to locate its .tif.
    FEATURE_FILENAME_MAP = {
        'curve_number': 'CN',
        'flow_acc': 'flowacc',
        'soil_moisture': 'SM',
        'elevation': 'elevation',
        'slope': 'slope',
        'twi': 'twi',
        'lulc': 'LULC'
    }

    def __init__(self, data_dir: Path, patch_size=(128, 128), global_stats=None, verbose=False):
        """Configure the preprocessor.

        Args:
            data_dir (Path): Directory containing the per-HUC .tif files.
            patch_size (tuple): (M, N) patch size used by patchify().
            global_stats (dict | None): Normalization statistics; defaults to
                the module-level GLOBAL_STATS when falsy.
            verbose (bool): Print progress in preprocess_all_lf_maps().
        """
        self.data_dir = Path(data_dir)
        self.M, self.N = patch_size
        self.verbose = verbose
        self.global_stats = global_stats if global_stats else GLOBAL_STATS

    def tif_to_tensor(self, path: Path, feature_name: str = None) -> torch.Tensor:
        """Read band 1 of a GeoTIFF and normalize it per ``feature_name``.

        Nodata pixels are replaced with 0 before normalization.

        Args:
            path (Path): GeoTIFF to read.
            feature_name (str | None): Which normalization branch to apply;
                None returns the raw (nodata-cleaned) tensor.

        Returns:
            torch.Tensor: Float32 tensor of shape (1, H, W).
        """
        with rasterio.open(path) as src:
            array = src.read(1).astype(np.float32)
            nodata_value = src.nodata
        if nodata_value is not None:
            array[array == nodata_value] = np.nan
        array = np.nan_to_num(array, nan=0.0)
        tensor = torch.tensor(array, dtype=torch.float32).unsqueeze(0)

        if feature_name == 'elevation':
            # Explicit key access (was `mean, std = ....values()`, which relied
            # on dict insertion order and broke for stats dicts carrying extra
            # keys such as 'min'/'max').
            stats = self.global_stats['elevation']
            tensor = (tensor - stats['mean']) / (stats['std'] + 1e-7)
        elif feature_name == 'slope':
            # Box-Cox first, then z-score with post-transform statistics.
            tensor = self.apply_boxcox(tensor)
            stats = self.global_stats['slope']
            tensor = (tensor - stats['mean']) / (stats['std'] + 1e-7)
        elif feature_name == 'flow_acc':
            tensor = self.apply_boxcox(tensor)
            stats = self.global_stats['flow_acc']
            tensor = (tensor - stats['mean']) / (stats['std'] + 1e-7)
        elif feature_name == 'lulc':
            # Categorical land-use codes re-indexed to the training order;
            # unknown codes map to 0.
            reclass_map = {1: 1, 2: 2, 4: 3, 3: 4, 8: 5, 6: 6, 7: 7, 5: 8, 9: 9}
            array = array.astype(np.int32)
            reclass_array = np.vectorize(lambda x: reclass_map.get(x, 0))(array).astype(np.float32)
            tensor = torch.tensor(reclass_array, dtype=torch.float32).unsqueeze(0)
        elif feature_name == 'low_fidelity':
            # LF flood maps are binarized: any positive value -> flooded (1).
            tensor = (tensor > 0).float()
        elif feature_name in self.global_stats:
            # Remaining features (twi, curve_number, soil_moisture): min-max.
            min_val = self.global_stats[feature_name]['min']
            max_val = self.global_stats[feature_name]['max']
            tensor = (tensor - min_val) / (max_val - min_val + 1e-7)

        return tensor

    def apply_boxcox(self, tensor: torch.Tensor, lmbda=0.5) -> torch.Tensor:
        """Apply a fixed-lambda Box-Cox transform elementwise, preserving shape.

        A small epsilon keeps values strictly positive, as Box-Cox requires.
        """
        tensor = tensor + 1e-6
        flat_np = tensor.flatten().numpy()
        transformed = boxcox(flat_np, lmbda=lmbda)
        return torch.tensor(transformed).reshape(tensor.shape)

    def patchify(self, data: torch.Tensor):
        """Split a (C, H, W) tensor into overlapping (M, N) patches.

        Pads the bottom/right edges with zeros so half-patch strides tile the
        array exactly, then unfolds into shape (num_patches, C, M, N).

        NOTE(review): the stride is M // 2 in BOTH directions; for non-square
        patch sizes the horizontal stride may be intended as N // 2 — confirm.
        """
        C, H, W = data.shape
        stride_h = stride_w = self.M // 2
        pad_h = (stride_h * ((H - self.M) // stride_h + 1) + self.M - H) % stride_h
        pad_w = (stride_w * ((W - self.N) // stride_w + 1) + self.N - W) % stride_w
        padded = nn.functional.pad(data, (0, int(pad_w), 0, int(pad_h)), mode='constant', value=0)
        patches = padded.unfold(1, self.M, stride_h).unfold(2, self.N, stride_w)
        patches = patches.permute(1, 2, 0, 3, 4).reshape(-1, C, self.M, self.N)
        return patches

    def get_static_stack(self, huc_id: str):
        """Load and stack all static features for a HUC into (7, H, W).

        Returns:
            torch.Tensor | None: The stacked features, or None (after printing
            a message) when any feature raster is missing.
        """
        tensors = []
        for feature in self.STATIC_FEATURES:
            search_key = self.FEATURE_FILENAME_MAP[feature]
            match = list(self.data_dir.glob(f"*{search_key}*{huc_id}*.tif"))
            if not match:
                print(f"Missing static feature: {feature} for {huc_id}")
                return None
            tensors.append(self.tif_to_tensor(match[0], feature_name=feature))
        return torch.cat(tensors, dim=0)

    def get_all_lf_maps(self, huc_id: str):
        """Return sorted paths of this HUC's low-fidelity flood-map .tifs."""
        return sorted([
            f for f in self.data_dir.glob(f"*{huc_id}*.tif")
            if any(k in f.name.lower() for k in self.LF_KEYWORDS)
        ])

    def preprocess_all_lf_maps(self, huc_id: str):
        """Patchify every LF map stacked with the static features.

        Returns:
            list: (filename, patches, path) tuples, one per LF map; empty
            when the static stack could not be built.
        """
        static_stack = self.get_static_stack(huc_id)
        if static_stack is None:
            return []

        lf_files = self.get_all_lf_maps(huc_id)
        results = []

        for lf_path in lf_files:
            lf_tensor = self.tif_to_tensor(lf_path, feature_name='low_fidelity')
            combined = torch.cat([static_stack, lf_tensor], dim=0)
            patches = self.patchify(combined)
            results.append((lf_path.name, patches, lf_path))
            if self.verbose:
                print(f"Processed {lf_path.name} with {patches.shape[0]} patches.")

        return results
123+

src/fimserve/enhancement_withSM/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)