|
3 | 3 |
|
4 | 4 | import random |
5 | 5 | import numpy as np |
| 6 | +import pandas as pd |
6 | 7 | import matplotlib.pyplot as plt |
7 | 8 |
|
8 | 9 | cols = ["red", "green", "blue", "royalblue", "orange", "black"] |
@@ -45,8 +46,7 @@ def plot_CRM(observed_species, observed_resources, timepoints, csv_file=None): |
45 | 46 | # Use a different color index for resources (continuing from where |
46 | 47 | # species left off) |
47 | 48 | color_idx = observed_species.shape[1] + resource_idx |
48 | | - # Use modulo to cycle through colors if we have more entities than |
49 | | - # colors |
| 49 | + |
50 | 50 | color_idx = color_idx % len(cols) |
51 | 51 |
|
52 | 52 | label = f'Resource {resource_idx + 1}' |
@@ -109,6 +109,85 @@ def plot_CRM(observed_species, observed_resources, timepoints, csv_file=None): |
109 | 109 | return fig, ax |
110 | 110 |
|
111 | 111 |
|
| 112 | +def plot_CRM_with_intervals( |
| 113 | + observed_species, |
| 114 | + observed_resources, |
| 115 | + species_lower, |
| 116 | + species_upper, |
| 117 | + resource_lower, |
| 118 | + resource_upper, |
| 119 | + times, |
| 120 | + filename=None): |
| 121 | + fig, ax = plt.subplots(figsize=(12, 8)) |
| 122 | + |
| 123 | + # Plot median trajectories |
| 124 | + for i in range(observed_species.shape[1]): |
| 125 | + ax.plot(times, observed_species[:, i], |
| 126 | + label=f'Species {i+1}', linewidth=2) |
| 127 | + |
| 128 | + for i in range(observed_resources.shape[1]): |
| 129 | + ax.plot(times, |
| 130 | + observed_resources[:, |
| 131 | + i], |
| 132 | + label=f'Resource {i+1}', |
| 133 | + linewidth=2, |
| 134 | + linestyle='--') |
| 135 | + |
| 136 | + # Add confidence ribbons |
| 137 | + for i in range(observed_species.shape[1]): |
| 138 | + ax.fill_between(times, species_lower[:, i], species_upper[:, i], |
| 139 | + alpha=0.2, color=plt.cm.tab10(i)) |
| 140 | + |
| 141 | + for i in range(observed_resources.shape[1]): |
| 142 | + ax.fill_between(times, |
| 143 | + resource_lower[:, |
| 144 | + i], |
| 145 | + resource_upper[:, |
| 146 | + i], |
| 147 | + alpha=0.2, |
| 148 | + color=plt.cm.tab10(i + observed_species.shape[1])) |
| 149 | + |
| 150 | + if filename: |
| 151 | + true_data = pd.read_csv(filename) |
| 152 | + true_times = true_data['time'].values |
| 153 | + |
| 154 | + for i in range(observed_species.shape[1]): |
| 155 | + col_name = f'species_{i+1}' |
| 156 | + if col_name in true_data.columns: |
| 157 | + ax.scatter( |
| 158 | + true_times, |
| 159 | + true_data[col_name], |
| 160 | + marker='o', |
| 161 | + s=30, |
| 162 | + color=plt.cm.tab10(i), |
| 163 | + label=f'True {col_name}') |
| 164 | + |
| 165 | + for i in range(observed_resources.shape[1]): |
| 166 | + col_name = f'resource_{i+1}' |
| 167 | + if col_name in true_data.columns: |
| 168 | + ax.scatter( |
| 169 | + true_times, |
| 170 | + true_data[col_name], |
| 171 | + marker='s', |
| 172 | + s=30, |
| 173 | + color=plt.cm.tab10( |
| 174 | + i + observed_species.shape[1]), |
| 175 | + label=f'True {col_name}') |
| 176 | + |
| 177 | + ax.set_xlabel('Time', fontsize=14) |
| 178 | + ax.set_ylabel('Concentration', fontsize=14) |
| 179 | + ax.set_title( |
| 180 | + 'Consumer-Resource Model Dynamics with 95% Credible Intervals', |
| 181 | + fontsize=16) |
| 182 | + ax.legend(loc='best', fontsize=12) |
| 183 | + ax.grid(True, alpha=0.3) |
| 184 | + |
| 185 | + plt.tight_layout() |
| 186 | + if filename: |
| 187 | + plt.savefig(f"{filename.split('.')[0]}_with_intervals.png", dpi=300) |
| 188 | + plt.show() |
| 189 | + |
| 190 | + |
112 | 191 | def plot_gMLV(yobs, sobs, timepoints): |
113 | 192 | # fig, axs = plt.subplots(1, 2, layout='constrained') |
114 | 193 | fig, axs = plt.subplots(1, 2) |
|
0 commit comments