Skip to content

Commit 4a7abce

Browse files
author
swinersha
committed
fix: speeds up find pairs by sampling M to the size of K
1 parent 088048e commit 4a7abce

2 files changed

Lines changed: 80 additions & 277 deletions

File tree

methods/matching/cluster_find_pairs_interactive.py

Lines changed: 48 additions & 255 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from sklearn.cluster import KMeans
88
import matplotlib.pyplot as plt
99
import geopandas as gpd
10+
from pyproj import Proj, transform
1011
import os
1112
import time
1213
import sys
@@ -108,14 +109,53 @@ def calculate_smd(group1, group2):
108109

109110
# Define the start year
110111
t0 = 2012 # READ THIS IN
111-
match_years = [t0-10, t0-5, t0]
112112

113113
# Read in the data
114114
boundary = gpd.read_file('/maps/aew85/projects/1201.geojson')
115115

116-
k_pixels = pd.read_parquet('/maps/tws36/tmf_pipe_out/1201/k_all.parquet')
116+
k_pixels = pd.read_parquet('/maps/tws36/tmf_pipe_out/1201/k.parquet')
117+
# k_pixels = pd.read_parquet('/maps/tws36/tmf_pipe_out/1201/k_all.parquet')
117118
m_pixels = pd.read_parquet('/maps/aew85/tmf_pipe_out/1201/matches.parquet')
118119

120+
121+
t0 = 2018
122+
boundary = gpd.read_file('/maps/aew85/projects/ona.geojson')
123+
k_pixels = pd.read_parquet('/maps/aew85/tmf_pipe_out/fastfp_test_ona/k.parquet')
124+
m_pixels = pd.read_parquet('/maps/aew85/tmf_pipe_out/fastfp_test_ona/matches.parquet')
125+
126+
if(m_pixels.shape[0] > (k_pixels.shape[0])):
127+
m_sub_size = int(k_pixels.shape[0]) # First down sample M as it is ~230 million points
128+
m_random_indices = np.random.choice(m_pixels.shape[0], size=m_sub_size, replace=False)
129+
m_pixels = m_pixels.iloc[m_random_indices]
130+
131+
# # Calculate the central coordinates (centroid)
132+
# central_lat = m_pixels['lat'].mean()
133+
# central_lon = m_pixels['lng'].mean()
134+
# aeqd_proj = f"+proj=aeqd +lat_0={central_lat} +lon_0={central_lon} +datum=WGS84"
135+
136+
# # Convert the DataFrame to a GeoDataFrame
137+
# m_gdf = gpd.GeoDataFrame(m_pixels, geometry=gpd.points_from_xy(m_pixels.lng, m_pixels.lat))
138+
# # Set the original CRS to WGS84 (EPSG:4326)
139+
# m_gdf.set_crs(epsg=4326, inplace=True)
140+
141+
# # Transform the GeoDataFrame to the AEQD projection
142+
# m_gdf_aeqd = m_gdf.to_crs(aeqd_proj)
143+
144+
# # Extract the transformed coordinates
145+
# gdf_aeqd['aeqd_x'] = gdf_aeqd.geometry.x
146+
# gdf_aeqd['aeqd_y'] = gdf_aeqd.geometry.y
147+
148+
# # Define the grid resolution in meters
149+
# grid_resolution_m = 5000 # 5 km
150+
151+
# # Calculate grid cell indices
152+
# gdf_aeqd['grid_x'] = (gdf_aeqd['aeqd_x'] // grid_resolution_m).astype(int)
153+
# gdf_aeqd['grid_y'] = (gdf_aeqd['aeqd_y'] // grid_resolution_m).astype(int)
154+
155+
# # Print the first few rows to verify
156+
# print(gdf_aeqd.head())
157+
158+
119159
# concat m and k
120160
km_pixels = pd.concat([k_pixels.assign(trt='trt', ID=range(0, len(k_pixels))),
121161
m_pixels.assign(trt='ctrl', ID=range(0, len(m_pixels)))], ignore_index=True)
@@ -124,8 +164,6 @@ def calculate_smd(group1, group2):
124164
exclude_columns = ['ID', 'x', 'y', 'lat', 'lng', 'country', 'ecoregion', 'trt']
125165
exclude_columns += [col for col in km_pixels.columns if col.startswith('luc')]
126166

127-
match_cats = ["ecoregion", "country", "cluster"] + ["luc_" + str(year) for year in match_years]
128-
129167
# Extract only the continuous columns
130168
continuous_columns = km_pixels.columns.difference(exclude_columns)
131169
km_pixels_selected = km_pixels[continuous_columns]
@@ -178,7 +216,7 @@ def calculate_smd(group1, group2):
178216
plt.ylabel('PCA Component 2')
179217
plt.legend()
180218
plt.show()
181-
plt.savefig('Figures/cluster_centres_faiss_1.png')
219+
plt.savefig('Figures/ona_cluster_centres_faiss_1.png')
182220
plt.close() # Close the plot to free up memory
183221

184222
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -202,11 +240,14 @@ def calculate_smd(group1, group2):
202240
fig.delaxes(axes[j])
203241

204242
plt.tight_layout()
205-
plt.savefig('Figures/cluster_faiss_1_facet.png')
243+
plt.savefig('Figures/Ona_cluster_faiss_1_facet.png')
206244
plt.close()
207245

208246
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
209247

248+
match_years = [t0-10, t0-5, t0]
249+
match_cats = ["ecoregion", "country", "cluster"] + ["luc_" + str(year) for year in match_years]
250+
210251
# Extract K and M pixels
211252
k_pixels = km_pixels.loc[km_pixels['trt'] == 'trt']
212253
m_pixels = km_pixels.loc[km_pixels['trt'] == 'ctrl']
@@ -216,7 +257,7 @@ def calculate_smd(group1, group2):
216257
m_pca = km_pca.loc[km_pixels['trt'] == 'ctrl'].to_numpy()
217258

218259
k_sub_size = int(k_pixels.shape[0]* K_SUB_PROPORTION)
219-
m_sub_size = int(m_pixels.shape[0] * M_SUB_PROPORTION)
260+
m_sub_size = int(m_pixels.shape[0] * 1)
220261

221262
# Define indexs for the samples from K and M
222263
k_random_indices = np.random.choice(k_pixels.shape[0], size=k_sub_size, replace=False)
@@ -330,252 +371,4 @@ def calculate_smd(group1, group2):
330371
smd_df = pd.DataFrame(smd_results, columns=['Variable', 'SMD', 'Mean_k_cat', 'Mean_m_cat', 'Pooled_std'])
331372
print(smd_df)
332373

333-
smd_results = []
334-
for column in columns_to_compare:
335-
smd, mean1, mean2, pooled_std = calculate_smd(k_pixels[column], m_pixels[column])
336-
smd_results.append((column, smd, mean1, mean2, pooled_std))
337374

338-
# Convert the results to a DataFrame for better readability
339-
smd_df = pd.DataFrame(smd_results, columns=['Variable', 'SMD', 'Mean_k_cat', 'Mean_m_cat', 'Pooled_std'])
340-
print(smd_df)
341-
342-
343-
344-
345-
346-
347-
348-
def find_match_iteration(
349-
k_parquet_filename: str,
350-
m_parquet_filename: str,
351-
start_year: int,
352-
output_folder: str,
353-
idx_and_seed: tuple[int, int]
354-
) -> None:
355-
logging.info("Find match iteration %d of %d", idx_and_seed[0] + 1, REPEAT_MATCH_FINDING)
356-
rng = np.random.default_rng(idx_and_seed[1])
357-
358-
logging.info("Loading K from %s", k_parquet_filename)
359-
k_pixels = pd.read_parquet(k_parquet_filename)
360-
361-
logging.info("Loading M from %s", m_parquet_filename)
362-
m_pixels = pd.read_parquet(m_parquet_filename)
363-
364-
# concat m and k
365-
km_pixels = pd.concat([k_pixels.assign(trt='trt', ID=range(0, len(k_pixels))),
366-
m_pixels.assign(trt='ctrl', ID=range(0, len(m_pixels)))], ignore_index=True)
367-
368-
# Find the continuous columns
369-
exclude_columns = ['ID', 'x', 'y', 'lat', 'lng', 'country', 'ecoregion', 'trt']
370-
exclude_columns += [col for col in km_pixels.columns if col.startswith('luc')]
371-
continuous_columns = km_pixels.columns.difference(exclude_columns)
372-
# Categorical columns
373-
match_cats = ["ecoregion", "country", "cluster"] + ["luc_" + str(year) for year in match_years]
374-
375-
logging.info("Starting PCA transformation of k and m union. km_pixels.shape: %a", {km_pixels.shape})
376-
# PCA transform and conversion to 32 bit ints for continuous only
377-
km_pca = to_pca_int32(km_pixels[continuous_columns])
378-
logging.info("Done PCA transformation")
379-
380-
# Extract K and M pixels - this might be unnecessary I just wanted to make sure
381-
# K and M were in the same order here and in the PCA transform
382-
k_pixels = km_pixels.loc[km_pixels['trt'] == 'trt']
383-
m_pixels = km_pixels.loc[km_pixels['trt'] == 'ctrl']
384-
# Extract K and M PCA transforms
385-
k_pca = km_pca.loc[km_pixels['trt'] == 'trt'].to_numpy()
386-
m_pca = km_pca.loc[km_pixels['trt'] == 'ctrl'].to_numpy()
387-
388-
# Sample from K and M
389-
k_sub_size = int(k_pixels.shape[0]* K_SUB_PROPORTION)
390-
m_sub_size = int(m_pixels.shape[0] * M_SUB_PROPORTION)
391-
# Define indexs for the samples from K and M
392-
k_random_indices = np.random.choice(k_pixels.shape[0], size=k_sub_size, replace=False)
393-
m_random_indices = np.random.choice(m_pixels.shape[0], size=m_sub_size, replace=False)
394-
# Take random samples from K and M pixels
395-
k_sub = k_pixels.iloc[k_random_indices]
396-
m_sub = m_pixels.iloc[m_random_indices]
397-
# Take corresponding random samples from the PCA transformed K and M
398-
k_sub_pca = k_pca[k_random_indices,:]
399-
m_sub_pca = m_pca[m_random_indices,:]
400-
401-
logging.info("Samples taken from K and M. k_sub.shape: %a; m_sub.shape: %a", {k_sub.shape, m_sub.shape})
402-
403-
# Identify the unique combinations of luc columns
404-
k_cat_combinations = k_sub[match_cats].drop_duplicates().sort_values(by=match_cats, ascending=[True] * len(match_cats))
405-
406-
pairs_list = []
407-
matchless_list = []
408-
409-
logging.info("Starting greedy matching... k_sub.shape: %s, m_sub.shape: %s",
410-
k_sub.shape, m_sub.shape)
411-
412-
start_time = time.time()
413-
for i in range(0, k_cat_combinations.shape[0]):
414-
# i = 6 # ith element of the unique combinations of the luc time series in k
415-
# for in range()
416-
k_cat_comb = k_cat_combinations.iloc[i]
417-
k_cat = k_sub[(k_sub[match_cats] == k_cat_comb).all(axis=1)]
418-
k_cat_pca = k_sub_pca[(k_sub[match_cats] == k_cat_comb).all(axis=1)]
419-
420-
# Find the subset in km_pixels that matches this combination
421-
m_cat = m_sub[(m_sub[match_cats] == k_cat_comb).all(axis=1)]
422-
m_cat_pca = m_sub_pca[(m_sub[match_cats] == k_cat_comb).all(axis=1)]
423-
424-
if VERBOSE:
425-
print('ksub_cat:' + str(k_cat.shape[0]))
426-
print('msub_cat:' + str(m_cat.shape[0]))
427-
428-
# If there is no suitable match for the pre-project luc time series
429-
# Then it may be preferable to just take the luc state at t0
430-
# m_luc_comb = m_pixels[(m_pixels[match_luc_years[1:3]] == K_luc_comb[1:3]).all(axis=1)]
431-
# m_luc_comb = m_pixels[(m_pixels[match_luc_years[2:3]] == K_luc_comb[2:3]).all(axis=1)]
432-
# For if there are no matches return nothing
433-
434-
if(m_cat.shape[0] < k_cat.shape[0] * 5):
435-
# print("M insufficient for matching. Set VERBOSE to True for more details.")
436-
# Append the matchless DataFrame to the list
437-
matchless_list.append(k_cat)
438-
continue
439-
440-
# Find the matches
441-
matches_index = loop_match(m_cat_pca, k_cat_pca)
442-
m_cat_matches = m_cat.iloc[matches_index]
443-
444-
# i = 0
445-
# matched = pd.concat([k_cat.iloc[i], m_cat.iloc[matches[i]]], axis=1, ignore_index=True)
446-
# matched.columns = ['trt', 'ctrl']
447-
# matched
448-
#Looks great!
449-
columns_to_compare = ['access', 'cpc0_d', 'cpc0_u', 'cpc10_d', 'cpc10_u', 'cpc5_d', 'cpc5_u', 'elevation', 'slope']
450-
# Calculate SMDs for the specified columns
451-
smd_results = []
452-
for column in columns_to_compare:
453-
smd, mean1, mean2, pooled_std = calculate_smd(k_cat[column], m_cat_matches[column])
454-
smd_results.append((column, smd, mean1, mean2, pooled_std))
455-
456-
# Convert the results to a DataFrame for better readability
457-
smd_df = pd.DataFrame(smd_results, columns=['Variable', 'SMD', 'Mean_k_cat', 'Mean_m_cat', 'Pooled_std'])
458-
459-
if VERBOSE:
460-
# Print the results
461-
print("categorical combination:")
462-
print(k_cat_comb)
463-
# Count how many items in 'column1' are not equal to the specified integer value
464-
print("LUC flips in K:")
465-
(k_cat['luc_2022'] != k_cat_comb['luc_' + str(t0)]).sum()
466-
print("LUC flips in matches:")
467-
(m_cat_matches['luc_2022'] != k_cat_comb['luc_' + str(t0)]).sum()
468-
print("Standardized Mean Differences:")
469-
print(smd_df)
470-
471-
# Join the pairs into one dataframe:
472-
k_cat = k_cat.reset_index(drop = True)
473-
m_cat_matches = m_cat_matches.reset_index(drop = True)
474-
pairs_df = pd.concat([k_cat.add_prefix('k_'), m_cat_matches.add_prefix('s_')], axis=1)
475-
476-
# Append the resulting DataFrame to the list
477-
pairs_list.append(pairs_df)
478-
479-
# Combine all the DataFrames in the list into a single DataFrame
480-
pairs = pd.concat(pairs_list, ignore_index=True)
481-
matchless = pd.concat(matchless_list, ignore_index=True)
482-
483-
logging.info("Finished greedy matching... pairs.shape: %s, matchless.shape: %s",
484-
pairs.shape, matchless.shape)
485-
486-
logging.info("Starting storing matches...")
487-
pairs.to_parquet(os.path.join(output_folder, f'{idx_and_seed[1]}.parquet'))
488-
matchless.to_parquet(os.path.join(output_folder, f'{idx_and_seed[1]}_matchless.parquet'))
489-
490-
logging.info("Finished find match iteration")
491-
492-
493-
def find_pairs(
494-
k_parquet_filename: str,
495-
m_parquet_filename: str,
496-
start_year: int,
497-
seed: int,
498-
output_folder: str,
499-
processes_count: int
500-
) -> None:
501-
logging.info("Starting find pairs")
502-
os.makedirs(output_folder, exist_ok=True)
503-
504-
rng = np.random.default_rng(seed)
505-
iteration_seeds = zip(range(REPEAT_MATCH_FINDING), rng.integers(0, 1000000, REPEAT_MATCH_FINDING))
506-
507-
with Pool(processes=processes_count) as pool:
508-
pool.map(
509-
partial(
510-
find_match_iteration,
511-
k_parquet_filename,
512-
m_parquet_filename,
513-
start_year,
514-
output_folder
515-
),
516-
iteration_seeds
517-
)
518-
519-
520-
def main():
521-
# If you use the default multiprocess model then you risk deadlocks when logging (which we
522-
# have hit). Spawn is the default on macOS, but not on Linux.
523-
set_start_method("spawn")
524-
525-
parser = argparse.ArgumentParser(description="Takes K and M and finds 100 sets of matches.")
526-
parser.add_argument(
527-
"--k",
528-
type=str,
529-
required=True,
530-
dest="k_filename",
531-
help="Parquet file containing pixels from K as generated by calculate_k.py"
532-
)
533-
parser.add_argument(
534-
"--m",
535-
type=str,
536-
required=True,
537-
dest="m_filename",
538-
help="Parquet file containing pixels from M as generated by build_m_table.py"
539-
)
540-
parser.add_argument(
541-
"--start_year",
542-
type=int,
543-
required=True,
544-
dest="start_year",
545-
help="Year project started."
546-
)
547-
parser.add_argument(
548-
"--seed",
549-
type=int,
550-
required=True,
551-
dest="seed",
552-
help="Random number seed, to ensure experiments are repeatable."
553-
)
554-
parser.add_argument(
555-
"--output",
556-
type=str,
557-
required=True,
558-
dest="output_directory_path",
559-
help="Directory into which output matches will be written. Will be created if it does not exist."
560-
)
561-
parser.add_argument(
562-
"-j",
563-
type=int,
564-
required=False,
565-
default=round(cpu_count() / 2),
566-
dest="processes_count",
567-
help="Number of concurrent threads to use."
568-
)
569-
args = parser.parse_args()
570-
571-
find_pairs(
572-
args.k_filename,
573-
args.m_filename,
574-
args.start_year,
575-
args.seed,
576-
args.output_directory_path,
577-
args.processes_count
578-
)
579-
580-
if __name__ == "__main__":
581-
main()

0 commit comments

Comments
 (0)