Skip to content

Commit e34e0d1

Browse files
committed
Implement rate limiting and input sanitization for API calls
1 parent b1306ad commit e34e0d1

2 files changed

Lines changed: 64 additions & 1 deletion

File tree

pydeepskylog/deepskylog_interface.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,41 @@
55
from pydeepskylog.exceptions import (
66
APIConnectionError, APITimeoutError, APIAuthenticationError, APIResponseError, InvalidParameterError
77
)
8+
import threading
9+
from functools import wraps
10+
from pydeepskylog.sanitization import sanitize_string
11+
812
DSL_API_BASE_URL: str = "https://test.deepskylog.org/api/" # Change this as needed
913

1014
# Simple in-memory cache: {url: (timestamp, data)}
1115
_DSL_API_CACHE: Dict[str, tuple[float, Any]] = {}
1216
_DSL_API_CACHE_TTL: int = 300 # seconds (5 minutes)
1317

18+
# Rate limiting parameters
19+
_DSL_API_RATE_LIMIT = 1 # max requests
20+
_DSL_API_RATE_PERIOD = 1.0 # seconds
21+
22+
_rate_lock = threading.Lock()
23+
_rate_timestamps = []
24+
25+
def rate_limited(max_calls: int, period: float):
26+
def decorator(func):
27+
@wraps(func)
28+
def wrapper(*args, **kwargs):
29+
with _rate_lock:
30+
now = time.time()
31+
# Remove timestamps outside the window
32+
while _rate_timestamps and _rate_timestamps[0] <= now - period:
33+
_rate_timestamps.pop(0)
34+
if len(_rate_timestamps) >= max_calls:
35+
sleep_time = period - (now - _rate_timestamps[0])
36+
if sleep_time > 0:
37+
time.sleep(sleep_time)
38+
_rate_timestamps.append(time.time())
39+
return func(*args, **kwargs)
40+
return wrapper
41+
return decorator
42+
1443
def dsl_instruments(username: str) -> Dict[str, Any]:
1544
"""
1645
Retrieve all defined astronomical instruments for a DeepskyLog user via the DeepskyLog API.
@@ -39,6 +68,7 @@ def dsl_instruments(username: str) -> Dict[str, Any]:
3968
>>> for inst_id, inst in instruments.items():
4069
... print(inst["name"], inst["diameter"])
4170
"""
71+
username = sanitize_string(username, max_length=64)
4272
return _dsl_api_call("instrument", username)
4373

4474
def dsl_eyepieces(username: str) -> Dict[str, Any]:
@@ -68,6 +98,7 @@ def dsl_eyepieces(username: str) -> Dict[str, Any]:
6898
>>> for ep_id, ep in eyepieces.items():
6999
... print(ep["name"], ep["focal_length_mm"])
70100
"""
101+
username = sanitize_string(username, max_length=64)
71102
return _dsl_api_call("eyepieces", username)
72103

73104
def dsl_lenses(username: str) -> Dict[str, Any]:
@@ -97,6 +128,7 @@ def dsl_lenses(username: str) -> Dict[str, Any]:
97128
>>> for lens_id, lens in lenses.items():
98129
... print(lens["name"], lens["focal_length_mm"])
99130
"""
131+
username = sanitize_string(username, max_length=64)
100132
return _dsl_api_call("lenses", username)
101133

102134
def dsl_filters(username: str) -> Dict[str, Any]:
@@ -126,6 +158,7 @@ def dsl_filters(username: str) -> Dict[str, Any]:
126158
>>> for filter_id, flt in filters.items():
127159
... print(flt["name"], flt["type"])
128160
"""
161+
username = sanitize_string(username, max_length=64)
129162
return _dsl_api_call("filters", username)
130163

131164

@@ -240,6 +273,7 @@ def convert_instrument_type_to_string(instrument_type: int) -> str:
240273

241274
return instrument_types[instrument_type]
242275

276+
@rate_limited(_DSL_API_RATE_LIMIT, _DSL_API_RATE_PERIOD)
243277
def _dsl_api_call(api_call: str, username: str) -> Dict[str, Any]:
244278
"""
245279
Make a GET request to the DeepskyLog API for a specific resource and user.
@@ -267,6 +301,7 @@ def _dsl_api_call(api_call: str, username: str) -> Dict[str, Any]:
267301
>>> data = _dsl_api_call\("instrument", "astro_user"\)
268302
>>> print\(data\)
269303
"""
304+
username = sanitize_string(username, max_length=64)
270305
api_url: str = f"{DSL_API_BASE_URL}{api_call}/{username}"
271306
now: float = time.time()
272307
logger: logging = logging.getLogger(__name__)
@@ -279,7 +314,7 @@ def _dsl_api_call(api_call: str, username: str) -> Dict[str, Any]:
279314
return data
280315

281316
try:
282-
response = requests.get(api_url, timeout=10)
317+
response = requests.get(api_url, timeout=10, verify=True)
283318
if response.status_code in (401, 403):
284319
logger.error(f"Authentication failed for user '{username}' (status {response.status_code})")
285320
raise APIAuthenticationError(f"Authentication failed for user '{username}' (status {response.status_code})")

pydeepskylog/sanitization.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import re
2+
from typing import Any
3+
4+
def sanitize_string(value: Any, max_length: int = 256, allow_unicode: bool = False) -> str:
5+
"""
6+
Sanitize a string input by removing leading/trailing whitespace, dangerous characters,
7+
and enforcing a maximum length. Optionally restrict to ASCII.
8+
9+
Args:
10+
value (Any): The input value to sanitize.
11+
max_length (int): Maximum allowed length of the string.
12+
allow_unicode (bool): If False, restricts to ASCII characters.
13+
14+
Returns:
15+
str: The sanitized string.
16+
17+
Raises:
18+
ValueError: If the input is not a string or exceeds allowed length after sanitization.
19+
"""
20+
if not isinstance(value, str):
21+
raise ValueError("Input must be a string")
22+
value = value.strip()
23+
if not allow_unicode:
24+
value = re.sub(r'[^\x20-\x7E]', '', value) # Remove non-ASCII
25+
value = re.sub(r'[<>;"\'`\\]', '', value) # Remove potentially dangerous characters
26+
if len(value) > max_length:
27+
raise ValueError(f"Input exceeds maximum length of {max_length}")
28+
return value

0 commit comments

Comments
 (0)