Skip to content

Commit b218cbf

Browse files
author
MariusWiggert
committed
Added spatial and temporal interpolation functions to the simulation_utils.py
1 parent 6a33399 commit b218cbf

3 files changed

Lines changed: 143 additions & 2 deletions

File tree

ocean_navigation_simulator/utils/plotting_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,12 @@ def visualize_currents(time, grids_dict, u_data, v_data, vmin=0, vmax=None, alph
7070
if vmax is None:
7171
vmax = np.max(magnitude)
7272
if autoscale:
73-
plt.imshow(magnitude, extent=[grids_dict['x_grid'][0], grids_dict['x_grid'][-1],
73+
plt.imshow(np.flip(magnitude, axis=0), extent=[grids_dict['x_grid'][0], grids_dict['x_grid'][-1],
7474
grids_dict['y_grid'][0], grids_dict['y_grid'][-1]],
7575
aspect='auto',
7676
cmap='jet', vmin=vmin, vmax=vmax, alpha=alpha)
7777
else:
78-
plt.imshow(magnitude, extent=[grids_dict['x_grid'][0], grids_dict['x_grid'][-1],
78+
plt.imshow(np.flip(magnitude, axis=0), extent=[grids_dict['x_grid'][0], grids_dict['x_grid'][-1],
7979
grids_dict['y_grid'][0], grids_dict['y_grid'][-1]],
8080
cmap='jet', vmin=vmin, vmax=vmax, alpha=alpha)
8181

ocean_navigation_simulator/utils/simulation_utils.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55
import math
66
import warnings
7+
from scipy import interpolate
78

89

910
def convert_to_lat_lon_time_bounds(x_0, x_T, deg_around_x0_xT_box, temp_horizon_in_h):
@@ -215,6 +216,95 @@ def get_current_data_subset_from_daily_files(t_interval, lat_interval, lon_inter
215216
return grids_dict, full_u_data.filled(fill_value=0.), full_v_data.filled(fill_value=0.)
216217

217218

219+
# Functions to do interpolation of the current data
220+
# general interpolation function
221+
def spatio_temporal_interpolation(grids_dict, u_data, v_data,
222+
temp_res_in_h=None, spatial_shape=None, spatial_kind='cubic'):
223+
"""Spatio-temporal interpolation of the current data to a new temp_res_in_h and new spatial_shape.
224+
Inputs:
225+
- grids_dict containing at least x_grid', 'y_grid'
226+
- u_data, v_data [T, Y, X] matrix of the current data
227+
- temp_res_in_h desired temporal resolution of the data
228+
- spatial_shape Desired spatial shape as tuple or list e.g. (<# of y_points>, <# of x_points>)
229+
- spatial_kind which interpolation to use, options are 'cubic', 'linear'
230+
231+
Outputs:
232+
- grids_dict updated grids_dict
233+
- u_data_new, v_data_new defined just like input
234+
"""
235+
# copy dict object to not inadvertantly change the original
236+
new_grids_dict = grids_dict.copy()
237+
if temp_res_in_h is not None:
238+
new_grids_dict['t_grid'], u_data, v_data = temporal_interpolation(grids_dict, u_data, v_data, temp_res_in_h)
239+
if spatial_shape is not None:
240+
new_grids_dict['x_grid'], new_grids_dict['y_grid'], u_data, v_data = spatial_interpolation(grids_dict, u_data, v_data,
241+
target_shape=spatial_shape, kind=spatial_kind)
242+
return new_grids_dict, u_data, v_data
243+
244+
# spatial interpolation function
245+
def spatial_interpolation(grid_dict, u_data, v_data, target_shape, kind='cubic'):
246+
"""Doing spatial interpolation to a specific spatial shape e.g. (100, 100).
247+
Inputs:
248+
- grid_dict containing at least x_grid', 'y_grid'
249+
- u_data, v_data [T, Y, X] matrix of the current data
250+
- target_shape Shape as tuple or list e.g. (<# of y_points>, <# of x_points>)
251+
- kind which interpolation to use, options are 'cubic', 'linear'
252+
253+
Outputs:
254+
- x_grid_new, x_grid_new arrays of the new x and y grid
255+
- u_data_new, v_data_new defined just like input
256+
"""
257+
# Step 1: create the new x and y axis vectors
258+
x_grid_new = np.arange(target_shape[1]) * (grid_dict['x_grid'][-1] - grid_dict['x_grid'][0]) / (
259+
target_shape[1] - 1) + grid_dict['x_grid'][0]
260+
y_grid_new = np.arange(target_shape[0]) * (grid_dict['y_grid'][-1] - grid_dict['y_grid'][0]) / (
261+
target_shape[0] - 1) + grid_dict['y_grid'][0]
262+
263+
# create the arrays to fill in with the new-resolution data
264+
u_data_new = np.zeros(shape=(u_data.shape[0], target_shape[0], target_shape[1]))
265+
v_data_new = np.zeros(shape=(u_data.shape[0], target_shape[0], target_shape[1]))
266+
267+
# Step 2: iterate over the time axis to create the new u and v data
268+
for t_idx in range(u_data.shape[0]):
269+
# run spatial interpolation in 2D along the new axis
270+
u_data_new[t_idx, :, :] = interpolate.interp2d(grid_dict['x_grid'], grid_dict['y_grid'],
271+
u_data[t_idx, :, :], kind=kind)(x_grid_new, y_grid_new)
272+
v_data_new[t_idx, :, :] = interpolate.interp2d(grid_dict['x_grid'], grid_dict['y_grid'],
273+
v_data[t_idx, :, :], kind=kind)(x_grid_new, y_grid_new)
274+
275+
return x_grid_new, y_grid_new, u_data_new, v_data_new
276+
277+
# temporal interpolation function
278+
def temporal_interpolation(grids_dict, u_data, v_data, temp_res_in_h):
279+
"""Doing linear temporal interpolation of the u and v data for a specific resolution.
280+
Inputs:
281+
- grid_dict containing at least x_grid', 'y_grid'
282+
- u_data, v_data [T, Y, X] matrix of the current data
283+
- temp_res_in_h desired temporal resolution in hours
284+
285+
Outputs:
286+
- t_grid_new arrays of the new t grid
287+
- u_data_new, v_data_new defined just like input
288+
"""
289+
# check
290+
if temp_res_in_h <= 0:
291+
raise ValueError("Temporal resolution must be positive")
292+
t_span_in_h = (grids_dict['t_grid'][-1] - grids_dict['t_grid'][0]) / 3600
293+
# TODO: Think if we want to implement temporal aggregation
294+
# # Case 1: aggregation, we need to average over the values
295+
# if temp_res_in_h >= (time_span_in_s/(3600*grids_dict['t_grid'].shape[0])):
296+
# print("Not yet implemented")
297+
# Case 2: interpolation
298+
# get the integer of how many elements the new t_grid will have
299+
n_new_t_grid = int(t_span_in_h/temp_res_in_h)
300+
# get the new t_grid
301+
new_t_grid = grids_dict['t_grid'][0] + np.arange(n_new_t_grid + 1) * (t_span_in_h/n_new_t_grid) * 3600
302+
# perform the 1D interpolations
303+
new_u_data = interpolate.interp1d(grids_dict['t_grid'], u_data, axis=0, kind='linear')(new_t_grid)
304+
new_v_data = interpolate.interp1d(grids_dict['t_grid'], v_data, axis=0, kind='linear')(new_t_grid)
305+
# return new t_grid and values
306+
return new_t_grid, new_u_data, new_v_data
307+
218308
# Helper helper functions
219309
def get_abs_time_grid_from_hycom_file(f):
220310
"""Helper function to extract the t_grid in UTC POSIX time from a HYCOM File f."""
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import matplotlib.pyplot as plt
2+
from scipy import interpolate
3+
import numpy as np
4+
import sys
5+
from ocean_navigation_simulator.problem import Problem
6+
from ocean_navigation_simulator.utils import simulation_utils, plotting_utils
7+
from ocean_navigation_simulator import OceanNavSimulator
8+
from ocean_navigation_simulator.planners import HJReach2DPlanner
9+
import datetime
10+
import os
11+
import hj_reachability as hj
12+
import time
13+
#%%
14+
platform_config_dict = {'battery_cap': 400.0, 'u_max': 0.1, 'motor_efficiency': 1.0,
15+
'solar_panel_size': 0.5, 'solar_efficiency': 0.2, 'drag_factor': 675}
16+
17+
# Create the navigation problem
18+
t_0 = datetime.datetime(2021, 6, 1, 12, 10, 10, tzinfo=datetime.timezone.utc)
19+
x_0 = [-88.0, 25.0, 1] # lon, lat, battery
20+
# is on land so we can check the land-mask.
21+
# x_0 = [-88.0, 20.0, 1] # lon, lat, battery
22+
x_T = [-88.2, 26.3]
23+
hindcast_folder = "data/hindcast_test/"
24+
forecast_folder = "data/forecast_test/"
25+
forecast_delay_in_h = 0.
26+
plan_on_gt=True
27+
prob = Problem(x_0, x_T, t_0,
28+
platform_config_dict=platform_config_dict,
29+
hindcast_folder= hindcast_folder,
30+
forecast_folder=forecast_folder,
31+
plan_on_gt = plan_on_gt,
32+
forecast_delay_in_h=forecast_delay_in_h)
33+
34+
#%% Let's plot the field normally
35+
import datetime
36+
grids_dict, u_data, v_data = simulation_utils.get_current_data_subset(
37+
t_interval=[datetime.datetime(2021, 6, 1, 12, 0, tzinfo=datetime.timezone.utc),
38+
datetime.datetime(2021, 6, 2, 11, 0, tzinfo=datetime.timezone.utc)],
39+
lat_interval=[21, 22],
40+
lon_interval=[-87, -86],
41+
file_dicts=prob.hindcasts_dicts)
42+
43+
# viz
44+
plotting_utils.visualize_currents(time=datetime.datetime(2021, 6, 1, 12, 0, tzinfo=datetime.timezone.utc).timestamp()
45+
, grids_dict=grids_dict, u_data=u_data, v_data=v_data)
46+
#%% Check spatial resolution
47+
grid_dict_new, u_data_new, v_data_new = simulation_utils.spatio_temporal_interpolation(grids_dict, u_data, v_data,
48+
temp_res_in_h=0.5, spatial_shape=(50,50), spatial_kind='cubic')
49+
50+
plotting_utils.visualize_currents(time=datetime.datetime(2021, 6, 1, 12, 0, tzinfo=datetime.timezone.utc).timestamp()
51+
, grids_dict=grid_dict_new, u_data=u_data_new, v_data=v_data_new)

0 commit comments

Comments
 (0)