Skip to content

Commit 949b878

Browse files
committed
add dask deepforest code
1 parent 2aa23ec commit 949b878

5 files changed

Lines changed: 259 additions & 1 deletion

File tree

README.md

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,27 @@
1-
# NeonTreeClassification
1+
# NeonTreeClassification
2+
3+
The National Ecological Observatory Network (NEON) offers a variety of data products, including airborne data from different forest sites. The airborne data include RGB orthophotos, LiDAR canopy height models (CHM), and 426-band hyperspectral imagery. All products are available at https://data.neonscience.org/data-products/. The following airborne data products are used in this repository:
4+
5+
rgb_data_product = 'DP3.30010.001'
6+
hsi_withbrdf_2022 = 'DP3.30006.002'
7+
lidar = 'DP3.30015.001' #CHM
8+
9+
# Workflow
10+
## 1. Download NEON data
11+
- Given the Northing, Easting, Year and Site, download the NEON data using the `download_neon_data.py` script. There are functions to download the RGB, HSI, and LiDAR data. The data is downloaded to a specified directory.
12+
### To do:
13+
- Merge this script into `neon_utils.py`
14+
- Look into using Google Earth Engine
15+
16+
## 2. Generate crowns using deepforest
17+
- The `deepforest_parallel.py` script uses the deepforest package to generate tree crowns from the RGB data. The script is parallelized on SLURM using Dask. It runs on a given list of RGB tiles and saves the predicted tree crowns for each tile as a CSV file.
18+
19+
# Citations
20+
21+
## NEON Airborne Data Products
22+
NEON (National Ecological Observatory Network). High-resolution orthorectified camera imagery mosaic (DP3.30010.001), RELEASE-2025. https://doi.org/10.48443/gdgn-3r69. Dataset accessed from https://data.neonscience.org/data-products/DP3.30010.001/RELEASE-2025 on April 3, 2025.
23+
24+
NEON (National Ecological Observatory Network). Spectrometer orthorectified surface bidirectional reflectance - mosaic (DP3.30006.002), provisional data. Dataset accessed from https://data.neonscience.org/data-products/DP3.30006.002 on April 3, 2025. Data archived at [your DOI].
25+
26+
NEON (National Ecological Observatory Network). Ecosystem structure (DP3.30015.001), RELEASE-2025. https://doi.org/10.48443/jqqd-1n30. Dataset accessed from https://data.neonscience.org/data-products/DP3.30015.001/RELEASE-2025 on April 3, 2025.
27+

SLURM/dask.sh

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#!/bin/bash
# SLURM batch script that launches the Dask driver process for the
# DeepForest crown-prediction run. Logs go to the --output path below.
#SBATCH --job-name=dask_master
#SBATCH --mail-type=END,FAIL
#SBATCH --mail-user=riteshchowdhry@ufl.edu
#SBATCH --account=azare
#SBATCH --partition=gpu
#SBATCH --output=/home/riteshchowdhry/logs/macrosystems/dask_deepforest/master_%j.out
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=2
#SBATCH --mem=50G
#SBATCH --time=48:00:00
#SBATCH --constraint=ai
#SBATCH --gpus=1

# Record when and where the job started.
date; hostname
module load conda
conda activate dfor_311
# Print the working directory; the python path below is resolved relative to it.
pwd

# NOTE(review): the Dask script added alongside this file is
# src/deepforest_parallel.py; confirm that dask_deepforest_slurm.py exists in
# the submission directory or update this path.
srun -u python dask_deepforest_slurm.py

date

src/deepforest_parallel.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import os
2+
import glob
3+
import time
4+
from dask_jobqueue import SLURMCluster
5+
from dask.distributed import Client, wait
6+
7+
8+
from deepforest import main
9+
10+
def run_deepforest(rgb_tile, save_path, patch_size=400, patch_overlap=0.25):
    """
    Predict tree crowns using a pretrained deepforest model given a rgb tile

    Args:
        rgb_tile: Path to RGB tile image
        save_path: Directory to save the predictions in (created if missing)
        patch_size: DeepForest divides the big tile into smaller patches, default is 400
        patch_overlap: Overlap between patches, default is 0.25
    returns:
        filename: Path to the CSV of saved predictions
    """
    # Create the output directory up front so a bad path fails before the
    # expensive prediction; exist_ok makes this safe when several workers
    # run this function at the same time.
    os.makedirs(save_path, exist_ok=True)

    model = main.deepforest()
    # Load a pretrained tree detection model from Hugging Face
    model.load_model(model_name="weecology/deepforest-tree", revision="main")
    predicted_raster = model.predict_tile(rgb_tile, patch_size=patch_size, patch_overlap=patch_overlap)

    # splitext strips only the final extension; splitting on the first '.'
    # would truncate names containing dots and could map two different tiles
    # onto the same CSV.
    tile_stem = os.path.splitext(os.path.basename(rgb_tile))[0]
    filename = os.path.join(save_path, 'predicted_boxes_' + tile_stem + '.csv')
    predicted_raster.to_csv(filename)
    print(f"\nPredicted boxes for {rgb_tile} saved to {filename}")
    return filename
29+
30+
if __name__ == "__main__":
    # Driver: start a Dask cluster of SLURM worker jobs, run run_deepforest on
    # every RGB tile in rgb_tiles_dir, report per-tile success/failure, and
    # always shut the client and cluster down at the end.

    rgb_tiles_dir = '/blue/azare/riteshchowdhry/Macrosystems/Data_files/unlabeled_data/HARV/RGB/Mosaic/'
    # Sorted so the submission order (and therefore the log) is reproducible.
    all_rgb_tiles = sorted(glob.glob(os.path.join(rgb_tiles_dir, '*.tif')))
    output_dir = "/blue/azare/riteshchowdhry/Macrosystems/Data_files/unlabeled_data/HARV/deepforest_crowns"
    df_patch_size = 400
    df_patch_overlap = 0.25
    num_workers = 40  # be careful with this, memory will be allocated for each worker, so if you have 40 workers and each worker has 35GB of memory, you will need 1.4TB of memory in total.

    # Fail before requesting any SLURM resources if there is nothing to do.
    if not all_rgb_tiles:
        raise SystemExit(f"No .tif tiles found in {rgb_tiles_dir}")

    os.makedirs(output_dir, exist_ok=True)

    # Extra #SBATCH directives added to every worker job script.
    slurm_args = [
        "--job-name=deepforest",
        "--account=azare",
        "--mail-type=END,FAIL",
        "--mail-user=riteshchowdhry@ufl.edu",
        "--output=/home/riteshchowdhry/logs/macrosystems/dask_deepforest/harv2022%j.out",
        "--partition=gpu",
        "--constraint=ai",
        "--gpus=1",
        "--time=24:00:00",
    ]

    cluster = SLURMCluster(
        cores=6,
        memory='35GB',
        processes=1,
        walltime='24:00:00',
        scheduler_options={"dashboard_address": ":8787"},
        job_extra_directives=slurm_args,
        local_directory='/home/riteshchowdhry/logs/macrosystems/dask_deepforest/',
        death_timeout=300,
    )

    dask_client = None
    try:
        print("Job script template:")
        print(cluster.job_script())

        print(f"Scaling cluster to {num_workers} workers")
        cluster.scale(num_workers)

        # Connect client to cluster
        print("Connecting client to cluster")
        dask_client = Client(cluster)
        print(f"Dashboard link: {dask_client.dashboard_link}")

        # Block until at least one worker has registered instead of sleeping
        # for a fixed time: SLURM queue delays are unpredictable.
        print("Waiting for workers to start...")
        dask_client.wait_for_workers(n_workers=1)

        # Keep each tile path next to its future so failures can be attributed.
        futures = []
        for i, rgb_tile in enumerate(all_rgb_tiles):
            print(f"Submitting task {i+1}/{len(all_rgb_tiles)}: {rgb_tile}")
            future = dask_client.submit(run_deepforest, rgb_tile, output_dir, df_patch_size, df_patch_overlap)
            futures.append((rgb_tile, future))

        # Wait for tasks with progress reporting
        print(f"Waiting for {len(futures)} tasks to complete...")
        completed = 0
        failed = 0
        for rgb_tile, future in futures:
            try:
                future.result()
            except Exception as e:
                # One bad tile must not abort the rest of the run.
                failed += 1
                print(f"Task failed for {rgb_tile} with error: {e}")
            else:
                completed += 1
                print(f"Task completed: {completed}/{len(futures)}")

        print(f"All tasks finished: {completed} completed, {failed} failed.")
    finally:
        # Close connections even on error so SLURM worker jobs are released.
        if dask_client is not None:
            print("Closing Dask client")
            dask_client.close()
        print("Closing cluster")
        cluster.close()
    print("Script completed.")

src/download_neon_data.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import argparse
2+
import os
3+
import site
4+
import rpy2
5+
import rpy2.robjects as robjects
6+
from rpy2.robjects.packages import importr
7+
from rpy2.rinterface_lib.callbacks import logger as rpy2_logger
8+
import logging
9+
rpy2_logger.setLevel(logging.ERROR)
10+
11+
base = importr('base')
12+
utils = importr('utils')
13+
stats = importr('stats')
14+
neonUtilities = importr('neonUtilities')
15+
16+
17+
# NEON data product ID for the RGB orthophoto mosaic. Defined at module level
# so the function also works when this file is imported; the previous code
# read a global that was only assigned when the file ran as a script.
RGB_DATA_PRODUCT_ID = 'DP3.30010.001'


def rgb_data_download(easting, northing, site, year, output_dir):
    """
    Download RGB data from NEON for a specific site and year. Pass a list of easting and northing coordinates to download all the tiles that span the area.

    Args:
        easting: List of easting coordinates
        northing: List of northing coordinates
        site: NEON site code (e.g., 'HARV')
        year: Year of data to download (e.g., '2022')
        output_dir: Directory to save the downloaded data
    """
    # Download the RGB data. check_size=False is passed through to
    # neonUtilities; presumably it skips the interactive size prompt — TODO confirm.
    neonUtilities.byTileAOP(
        dpID=RGB_DATA_PRODUCT_ID,
        site=site,
        year=year,
        check_size=False,
        easting=easting,
        northing=northing,
        include_provisional=True,
        savepath=output_dir,
    )
33+
34+
35+
# NEON data product ID for the BRDF-corrected hyperspectral reflectance
# mosaic. Defined at module level so the function also works when this file is
# imported; the previous code read a global that was only assigned when the
# file ran as a script.
HSI_WITH_BRDF_PRODUCT_ID = 'DP3.30006.002'


def hsi_withbrdf_data_download(easting, northing, site, year, output_dir):
    """
    Download HSI data from NEON for a specific site and year. Pass a list of easting and northing coordinates to download all the tiles that span the area.

    Args:
        easting: List of easting coordinates
        northing: List of northing coordinates
        site: NEON site code (e.g., 'HARV')
        year: Year of data to download (e.g., '2022')
        output_dir: Directory to save the downloaded data
    """
    # Download the HSI data with BRDF correction. check_size=False is passed
    # through to neonUtilities; presumably it skips the interactive size
    # prompt — TODO confirm.
    neonUtilities.byTileAOP(
        dpID=HSI_WITH_BRDF_PRODUCT_ID,
        site=site,
        year=year,
        check_size=False,
        easting=easting,
        northing=northing,
        include_provisional=True,
        savepath=output_dir,
    )
51+
52+
# NEON data product ID for the LiDAR canopy height model (CHM). Defined at
# module level so the function also works when this file is imported; the
# previous code read a global that was only assigned when the file ran as a
# script.
LIDAR_CHM_PRODUCT_ID = 'DP3.30015.001'


def lidar_chm_data_download(easting, northing, site, year, output_dir):
    """
    Download LiDAR CHM data from NEON for a specific site and year. Pass a list of easting and northing coordinates to download all the tiles that span the area.

    Args:
        easting: List of easting coordinates
        northing: List of northing coordinates
        site: NEON site code (e.g., 'HARV')
        year: Year of data to download (e.g., '2022')
        output_dir: Directory to save the downloaded data
    """
    # Download the LiDAR CHM data. check_size=False is passed through to
    # neonUtilities; presumably it skips the interactive size prompt — TODO confirm.
    neonUtilities.byTileAOP(
        dpID=LIDAR_CHM_PRODUCT_ID,
        site=site,
        year=year,
        check_size=False,
        easting=easting,
        northing=northing,
        include_provisional=True,
        savepath=output_dir,
    )
68+
69+
70+
if __name__ == "__main__":
    # Command-line entry point: parse the site, year and coordinate range,
    # build the list of tile coordinates, and download the HSI tiles.

    parser = argparse.ArgumentParser(description='Download NEON data for a specific site and year.')
    parser.add_argument('--site', type=str, required=True, help='NEON site code (e.g., HARV)')
    parser.add_argument('--year', type=str, required=True, help='Year of data to download (e.g., 2022)')
    parser.add_argument('--easting_start', type=int, required=True, help='Starting easting coordinate')
    parser.add_argument('--easting_end', type=int, required=True, help='Ending easting coordinate')
    parser.add_argument('--northing_start', type=int, required=True, help='Starting northing coordinate')
    parser.add_argument('--northing_end', type=int, required=True, help='Ending northing coordinate')
    parser.add_argument('--output_dir', type=str, required=True, help='Directory to save the downloaded data')
    args = parser.parse_args()
    # Named site_code so the `site` module imported at the top of the file is
    # not rebound.
    site_code = args.site
    year = args.year
    easting_start = args.easting_start
    easting_end = args.easting_end
    northing_start = args.northing_start
    northing_end = args.northing_end
    output_dir = args.output_dir

    if easting_end < easting_start or northing_end < northing_start:
        parser.error('end coordinates must be greater than or equal to start coordinates')

    # Product IDs are kept as module globals because the download functions
    # in this file may look them up by these names.
    rgb_data_product = 'DP3.30010.001'
    hsi_withbrdf_2022 = 'DP3.30006.002'
    lidar = 'DP3.30015.001'

    # to do: add a function to get the easting and northing coordinates from the site name

    # Build one (easting, northing) point per tile covering the requested
    # rectangle. NOTE(review): assumes NEON AOP mosaic tiles are 1000 m x 1000 m
    # and that byTileAOP takes paired easting/northing vectors — confirm
    # against the neonUtilities documentation.
    tile_size_m = 1000
    tile_points = [
        (tile_easting, tile_northing)
        for tile_easting in range(easting_start, easting_end + 1, tile_size_m)
        for tile_northing in range(northing_start, northing_end + 1, tile_size_m)
    ]
    # Explicit R integer vectors rather than relying on implicit conversion
    # of Python lists.
    easting = robjects.IntVector([point[0] for point in tile_points])
    northing = robjects.IntVector([point[1] for point in tile_points])

    # To download RGB tiles instead:
    # rgb_data_download(easting, northing, site_code, year, output_dir)

    # Save into the directory given on the command line; --output_dir is a
    # required argument and was previously ignored in favour of a fixed path.
    hsi_withbrdf_data_download(easting, northing, site_code, year, output_dir)

src/generate_crowns.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
import os
3+
import glob
34
import rasterio
45
import matplotlib.pyplot as plt
56

@@ -123,7 +124,12 @@ def run_deepforest(rgb_tile, save_path, patch_size=400, patch_overlap=0.25):
123124
predicted_raster.to_csv(filename)
124125
print(f"Predicted boxes for {rgb_tile} saved to {filename}")
125126
return filename
127+
126128
if __name__ == "__main__":
129+
130+
rgb_tiles_dir = '/blue/azare/riteshchowdhry/Macrosystems/Data_files/unlabeled_data/HARV/RGB/Mosaic/'
131+
all_rgb_tiles = glob.glob(os.path.join(rgb_tiles_dir, '*.tif'))
132+
127133
# given a bbox (from deepforest) in RGB coordinates
128134
rgb_file = '/blue/azare/riteshchowdhry/Macrosystems/Data_files/unlabeled_data/HARV/RGB/Mosaic/2022_HARV_7_734000_4709000_image.tif'
129135
hsi_file = '/blue/azare/riteshchowdhry/Macrosystems/Data_files/unlabeled_data/HARV/HSI/HSI_tif/2022_HARV_7_734000_4709000_image_hyperspectral.tif'

0 commit comments

Comments
 (0)