|
| 1 | +#!/usr/bin/env python3 |
| 2 | + |
| 3 | +""" |
| 4 | +Plots the location of minima for each optimization run in 3D input space for |
| 5 | +the N best runs. |
| 6 | +
|
| 7 | +To be run with both the history file (*.npy) and the persis_info file (*.pickle) |
| 8 | +from a libEnsemble/APOSMM run present in the current directory. The most recent |
| 9 | +of each file type present will be used for the plot. |
| 10 | +
|
| 11 | +""" |
| 12 | + |
| 13 | +import numpy as np |
| 14 | +import matplotlib.pyplot as plt |
| 15 | +import glob |
| 16 | +import os |
| 17 | +import pickle |
| 18 | +import sys |
| 19 | + |
| 20 | +N = 6 # number of opt runs to show. |
| 21 | + |
| 22 | +x_name = 'x0' |
| 23 | +y_name = 'x1' |
| 24 | +z_name = 'x2' |
| 25 | + |
| 26 | +full_bounds = False # For entire input space enter bounds below |
| 27 | + |
| 28 | +if full_bounds: |
| 29 | + # Define parameter bounds |
| 30 | + mcr = 1e-2 |
| 31 | + x0_min, x0_max = mcr, 1.0 - mcr |
| 32 | + x1_min, x1_max = -20.0, 20.0 |
| 33 | + x2_min, x2_max = 1.0, 20.0 |
| 34 | + |
| 35 | +# Find the most recent .npy and pickle files |
| 36 | +try: |
| 37 | + H_file = max(glob.glob("*.npy"), key=os.path.getmtime) |
| 38 | + persis_info_file = max(glob.iglob('*.pickle'), key=os.path.getctime) |
| 39 | +except Exception: |
| 40 | + sys.exit("Need a *.npy and a *.pickle files in run dir. Exiting...") |
| 41 | + |
| 42 | +H = np.load(H_file) |
| 43 | + |
| 44 | +with open(persis_info_file, "rb") as f: |
| 45 | + index_sets = pickle.load(f)["run_order"] |
| 46 | + |
| 47 | +# Filter best N opt runs for clearer graph |
| 48 | +trimmed_index_sets = {key: indices[:-1] for key, indices in index_sets.items()} |
| 49 | +min_f_per_set = [(key, indices, H['f'][indices].min()) for key, indices in trimmed_index_sets.items() if len(indices) > 0] |
| 50 | +min_f_per_set_sorted = sorted(min_f_per_set, key=lambda x: x[2])[:N] |
| 51 | + |
| 52 | +# Plotting |
| 53 | +fig = plt.figure(figsize=(6, 6)) |
| 54 | +ax = fig.add_subplot(111, projection='3d') |
| 55 | + |
| 56 | +for key, indices, _ in min_f_per_set_sorted: |
| 57 | + min_f_index = indices[np.argmin(H['f'][indices])] |
| 58 | + |
| 59 | + # Extract the corresponding 3D x position from H |
| 60 | + try: |
| 61 | + x, y, z = H['x'][min_f_index] |
| 62 | + except ValueError: |
| 63 | + x = H[x_name][min_f_index] |
| 64 | + y = H[y_name][min_f_index] |
| 65 | + z = H[z_name][min_f_index] |
| 66 | + |
| 67 | + # Plot the 3D point |
| 68 | + ax.scatter(x, y, z, marker='o', s=50, label=f'Opt run {key}') |
| 69 | + |
| 70 | + # Draw a line from the point to the XY plane (z=0) |
| 71 | + ax.plot([x, x], [y, y], [0, z], color='grey', linestyle='--') |
| 72 | + |
| 73 | +if full_bounds: |
| 74 | + ax.set_xlim(x0_min, x0_max) |
| 75 | + ax.set_ylim(x1_min, x1_max) |
| 76 | + ax.set_zlim(x2_min, x2_max) |
| 77 | + |
| 78 | +# Label the plot |
| 79 | +ax.set_xlabel(x_name) |
| 80 | +ax.set_ylabel(y_name) |
| 81 | +ax.set_zlabel(z_name) |
| 82 | +ax.set_title('Locations of best points from each optimization run') |
| 83 | +ax.legend(bbox_to_anchor=(-0.1, 0.9), loc='upper left', borderaxespad=0) |
| 84 | +plt.savefig(f"location_min_best{N}.png") |
0 commit comments