|
4 | 4 | [](https://www.python.org/downloads/) |
5 | 5 | [](https://opensource.org/licenses/Apache-2.0) |
6 | 6 |
|
7 | | -Production-ready Python SDK for FAIM (Foundation AI Models) - a high-performance time-series forecasting platform powered by foundation models. |
| 7 | +Production-ready Python SDK for FAIM (Foundation AI Models) - a unified platform for time-series forecasting and tabular inference powered by foundation models. |
8 | 8 |
|
9 | 9 | ## Features |
10 | 10 |
|
11 | | -- **🚀 Multiple Foundation Models**: FlowState, Amazon Chronos 2.0, TiRex |
| 11 | +- **🚀 Multiple Foundation Models**: |
| 12 | + - **Time-Series**: FlowState, Amazon Chronos 2.0, TiRex |
| 13 | + - **Tabular**: LimiX (classification & regression) |
12 | 14 | - **🔒 Type-Safe API**: Full type hints with Pydantic validation |
13 | 15 | - **⚡ High Performance**: Optimized Apache Arrow serialization with zero-copy operations |
14 | | -- **🎯 Probabilistic & Deterministic**: Point forecasts, quantiles, and samples |
| 16 | +- **🎯 Probabilistic & Deterministic**: Point forecasts, quantiles, samples, and class probabilities |
15 | 17 | - **🔄 Async Support**: Built-in async/await support for concurrent requests |
16 | 18 | - **📊 Rich Error Handling**: Machine-readable error codes with detailed diagnostics |
17 | 19 | - **🧪 Battle-Tested**: Production-ready with comprehensive error handling |
18 | 20 | - **📈 Evaluation Tools**: Built-in metrics (MSE, MASE, CRPS) and visualization utilities |
| 21 | +- **🔎 Retrieval-Augmented Inference**: Optional RAI for improved accuracy on small datasets |
19 | 22 |
|
20 | 23 | ## Installation |
21 | 24 |
|
@@ -171,6 +174,114 @@ response = client.forecast(request) |
171 | 174 | print(response.point.shape) # (batch_size, 24, features) |
172 | 175 | ``` |
173 | 176 |
|
| 177 | +## Tabular Inference with LimiX |
| 178 | + |
| 179 | +The SDK also supports **LimiX**, a foundation model for tabular classification and regression: |
| 180 | + |
| 181 | +```python |
| 182 | +from faim_sdk import TabularClient, LimiXPredictRequest |
| 183 | +import numpy as np |
| 184 | + |
| 185 | +# Initialize tabular client |
| 186 | +client = TabularClient(api_key="your-api-key") |
| 187 | + |
| 188 | +# Prepare tabular data (2D arrays) |
| 189 | +X_train = np.random.randn(100, 10).astype(np.float32) |
| 190 | +y_train = np.random.randint(0, 2, 100).astype(np.float32) |
| 191 | +X_test = np.random.randn(20, 10).astype(np.float32) |
| 192 | + |
| 193 | +# Create classification request |
| 194 | +request = LimiXPredictRequest( |
| 195 | + X_train=X_train, |
| 196 | + y_train=y_train, |
| 197 | + X_test=X_test, |
| 198 | + task_type="Classification", # or "Regression" |
| 199 | + use_retrieval=False # Set to True for retrieval-augmented inference |
| 200 | +) |
| 201 | + |
| 202 | +# Generate predictions |
| 203 | +response = client.predict(request) |
| 204 | +print(response.predictions.shape) # (20,) |
| 205 | +print(response.probabilities.shape) # (20, n_classes) - classification only |
| 206 | +``` |
| 207 | + |
| 208 | +### Classification Example |
| 209 | + |
| 210 | +```python |
| 211 | +from sklearn.datasets import load_breast_cancer |
| 212 | +from sklearn.model_selection import train_test_split |
| 213 | + |
| 214 | +# Load dataset |
| 215 | +X, y = load_breast_cancer(return_X_y=True) |
| 216 | +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42) |
| 217 | + |
| 218 | +# Convert to float32 |
| 219 | +X_train = X_train.astype(np.float32) |
| 220 | +X_test = X_test.astype(np.float32) |
| 221 | +y_train = y_train.astype(np.float32) |
| 222 | + |
| 223 | +# Create and send request (reuses `client`, `np`, and `LimiXPredictRequest` from the first example) |
| 224 | +request = LimiXPredictRequest( |
| 225 | + X_train=X_train, |
| 226 | + y_train=y_train, |
| 227 | + X_test=X_test, |
| 228 | + task_type="Classification" |
| 229 | +) |
| 230 | + |
| 231 | +response = client.predict(request) |
| 232 | + |
| 233 | +# Evaluate |
| 234 | +from sklearn.metrics import accuracy_score |
| 235 | +accuracy = accuracy_score(y_test, response.predictions.astype(int)) |
| 236 | +print(f"Accuracy: {accuracy:.4f}") |
| 237 | +``` |
| 238 | + |
| 239 | +### Regression Example |
| 240 | + |
| 241 | +```python |
| 242 | +from sklearn.datasets import fetch_california_housing |
| 243 | + |
| 244 | +# Load dataset |
| 245 | +house_data = fetch_california_housing() |
| 246 | +X, y = house_data.data, house_data.target |
| 247 | + |
| 248 | +# Split data (50/50 for demo) |
| 249 | +split_idx = len(X) // 2 |
| 250 | +X_train, X_test = X[:split_idx].astype(np.float32), X[split_idx:].astype(np.float32) |
| 251 | +y_train, y_test = y[:split_idx].astype(np.float32), y[split_idx:].astype(np.float32) |
| 252 | + |
| 253 | +# Create and send request (reuses `client`, `np`, and `LimiXPredictRequest` from the first example) |
| 254 | +request = LimiXPredictRequest( |
| 255 | + X_train=X_train, |
| 256 | + y_train=y_train, |
| 257 | + X_test=X_test, |
| 258 | + task_type="Regression" |
| 259 | +) |
| 260 | + |
| 261 | +response = client.predict(request) |
| 262 | + |
| 263 | +# Evaluate |
| 264 | +from sklearn.metrics import mean_squared_error |
| 265 | +rmse = np.sqrt(mean_squared_error(y_test, response.predictions)) |
| 266 | +print(f"RMSE: {rmse:.4f}") |
| 267 | +``` |
| 268 | + |
| 269 | +### Retrieval-Augmented Inference |
| 270 | + |
| 271 | +Retrieval-augmented inference (RAI) can improve accuracy on small datasets, at the cost of slower predictions. Enable it with `use_retrieval=True`: |
| 272 | + |
| 273 | +```python |
| 274 | +request = LimiXPredictRequest( |
| 275 | + X_train=X_train, |
| 276 | + y_train=y_train, |
| 277 | + X_test=X_test, |
| 278 | + task_type="Classification", |
| 279 | + use_retrieval=True # Enable RAI (slower; may improve accuracy on small datasets) |
| 280 | +) |
| 281 | + |
| 282 | +response = client.predict(request) |
| 283 | +``` |
| 284 | + |
174 | 285 | ## Response Format |
175 | 286 |
|
176 | 287 | All forecasts return a `ForecastResponse` object with predictions and metadata: |
|
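The evaluation steps in the tabular classification and regression examples in this diff reduce to two small formulas. As a minimal sketch using only the standard library, they could be written as follows. The helper names `accuracy` and `rmse` are illustrative and not part of the FAIM SDK, and the sample values are made up rather than real model output.

```python
import math


def accuracy(y_true, y_pred):
    """Fraction of predictions whose integer label matches the true label."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if len(y_true) == 0:
        raise ValueError("inputs must not be empty")
    # Predictions may arrive as floats (e.g. 1.0), so compare as integers,
    # mirroring the `.astype(int)` cast in the classification example.
    matches = sum(1 for t, p in zip(y_true, y_pred) if int(t) == int(p))
    return matches / len(y_true)


def rmse(y_true, y_pred):
    """Root mean squared error between targets and predictions."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if len(y_true) == 0:
        raise ValueError("inputs must not be empty")
    squared_errors = [(t - p) ** 2 for t, p in zip(y_true, y_pred)]
    return math.sqrt(sum(squared_errors) / len(squared_errors))


# Hypothetical values standing in for `y_test` and `response.predictions`.
labels = [0, 1, 1, 0]
predicted_labels = [0.0, 1.0, 0.0, 0.0]
print(f"Accuracy: {accuracy(labels, predicted_labels):.4f}")

targets = [1.0, 2.0, 3.0]
predicted_values = [1.0, 2.0, 5.0]
print(f"RMSE: {rmse(targets, predicted_values):.4f}")
```

In practice the scikit-learn functions used in the README examples are the better choice. The sketch only makes explicit what those calls compute.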