def _cmd_head() -> None:
    """
    Write the ID of the current Alembic head revision to stdout.

    The Alembic script directory is located relative to this file:
    ``<parent of this file's directory>/pyrit/memory/alembic``.
    """
    from pathlib import Path

    from alembic.config import Config
    from alembic.script import ScriptDirectory

    repo_root = Path(__file__).parent.parent
    alembic_dir = repo_root.joinpath("pyrit", "memory", "alembic")

    # A bare Config (no .ini file) is enough: only script_location is needed
    # to resolve the revision graph.
    alembic_config = Config()
    alembic_config.set_main_option("script_location", str(alembic_dir))

    scripts = ScriptDirectory.from_config(alembic_config)
    print(scripts.get_current_head())
# ANSI color codes used for terminal status output.
_GREEN = "\033[92m"
_RED = "\033[91m"
_YELLOW = "\033[93m"
_RESET = "\033[0m"


def _print_error(message: str) -> None:
    """Print an error message in red to stderr."""
    print(f"{_RED}ERROR: {message}{_RESET}", file=sys.stderr)


def _print_success(message: str) -> None:
    """Print a success message in green."""
    print(f"{_GREEN}{message}{_RESET}")


def _check_release_environment() -> list[str]:
    """
    Validate that the script is running in a proper release environment.

    Three independent checks are performed; each failing check appends one message:

    1. The current Git branch name starts with ``releases/``.
    2. ``git status --porcelain`` reports nothing under ``pyrit/memory/``.
    3. ``pyrit.__version__`` does not contain ``.dev``.

    Git is run with the parent of this script's directory as its working directory, so the
    branch lookup and the ``pyrit/memory/`` pathspec are resolved against the checkout this
    script lives in rather than against wherever the caller happens to be.

    Returns:
        list[str]: Human-readable issue messages. An empty list means all checks passed.
    """
    from pathlib import Path

    issues: list[str] = []
    # build_scripts/<this file> -> repository root
    repo_root = Path(__file__).resolve().parent.parent

    try:
        branch = subprocess.check_output(
            ["git", "rev-parse", "--abbrev-ref", "HEAD"],
            cwd=repo_root,
            text=True,
            stderr=subprocess.DEVNULL,
        ).strip()
    except (subprocess.CalledProcessError, OSError):
        # OSError covers a missing git executable (FileNotFoundError) and an unusable cwd.
        issues.append("Could not determine current Git branch.")
    else:
        if not branch.startswith("releases/"):
            issues.append(
                f"Not on a release branch (current: '{branch}'). "
                "Production migrations should run from 'releases/vX.Y.Z'."
            )

    try:
        porcelain = subprocess.check_output(
            ["git", "status", "--porcelain", "--", "pyrit/memory/"],
            cwd=repo_root,
            text=True,
            stderr=subprocess.DEVNULL,
        )
    except (subprocess.CalledProcessError, OSError):
        issues.append("Could not check Git working tree status.")
    else:
        # Keep each line's leading status column intact and indent every entry,
        # not just the first one.
        dirty_lines = [line for line in porcelain.splitlines() if line.strip()]
        if dirty_lines:
            listing = "\n".join(f"  {line}" for line in dirty_lines)
            issues.append(
                "Uncommitted changes detected in pyrit/memory/:\n"
                f"{listing}\n"
                "  Commit or stash changes before migrating production."
            )

    try:
        from pyrit import __version__
    except ImportError:
        issues.append("Could not determine PyRIT version.")
    else:
        if ".dev" in __version__:
            issues.append(
                f"PyRIT version is '{__version__}' (development). "
                "Production migrations should use a release version (no .dev suffix)."
            )

    return issues


def main() -> int:
    """
    Entry point for production schema migration.

    Order of operations:

    1. Parse CLI arguments.
    2. Run the release environment checks unless ``--skip-environment-check`` is given.
    3. Confirm the production connection string environment variable is set. This happens
       before prompting, so the operator is never asked to confirm a run that cannot start.
    4. Ask for interactive confirmation when stdin is a terminal.
    5. Construct ``AzureSQLMemory`` with ``skip_schema_migration=True`` and explicitly call
       ``_run_schema_migration()``.

    Returns:
        int: 0 when the migration completed, 1 on any failed check, abort, or migration error.
    """
    import argparse
    import os

    parser = argparse.ArgumentParser(
        description="Apply Alembic schema migrations to the production database.",
    )
    parser.add_argument(
        "--skip-environment-check",
        action="store_true",
        help="Skip release environment checks (branch, clean tree, version). Use only in CI with caution.",
    )
    args = parser.parse_args()

    # Safety rail: verify release environment
    if args.skip_environment_check:
        print(f"{_YELLOW}WARNING: Skipping release environment checks.{_RESET}")
    else:
        issues = _check_release_environment()
        if issues:
            _print_error("Release environment checks failed:")
            for issue in issues:
                _print_error(f"  - {issue}")
            _print_error("Fix the above issues or pass --skip-environment-check (CI only).")
            return 1

    from pyrit.memory import AzureSQLMemory

    # Validate configuration before asking the operator to confirm anything.
    prod_env_var = AzureSQLMemory.AZURE_SQL_DB_CONNECTION_STRING_PROD
    prod_conn = os.environ.get(prod_env_var)
    if not prod_conn:
        _print_error(f"Environment variable '{prod_env_var}' is not set.")
        return 1

    # Interactive confirmation (skipped when stdin is not a terminal, e.g. in a pipeline)
    if sys.stdin.isatty():
        print("About to migrate production database schema to head.")
        try:
            response = input("Type 'yes' to proceed: ")
        except EOFError:
            # Ctrl-D / closed stdin at the prompt counts as a refusal, not a crash.
            response = ""
        if response.strip().lower() != "yes":
            print("Aborted.")
            return 1

    # Construct AzureSQLMemory with skip_schema_migration=True so construction itself
    # does not migrate, then explicitly run the migration.
    try:
        memory = AzureSQLMemory(
            connection_string=prod_conn,
            skip_schema_migration=True,
        )
        print("Running schema migration...")
        memory._run_schema_migration()
    except Exception as e:
        # Top-level boundary: report any failure and exit non-zero.
        _print_error(f"Migration failed: {e}")
        return 1

    _print_success("Production schema migration completed and verified successfully.")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
internal package is current. -## 9. Publish to PyPI +## 9. Migrate Production Database Schema + +Apply any pending Alembic migrations to the production database. This is the **only** +sanctioned path for modifying the production schema — normal startup only validates, +never upgrades. + +**Run from the release branch with release dependencies.** This ensures the migration +files and model definitions match exactly what will be shipped to users. Running from +`main` or a dev environment could apply unreleased migrations that break prod. + +```bash +git checkout releases/vx.y.z +uv run python -c "import pyrit; print(pyrit.__version__)" # verify: x.y.z (no .dev0) +``` + +**Run the migration** (reads `AZURE_SQL_DB_CONNECTION_STRING_PROD` from `~/.pyrit/.env`): + +```bash +uv run python build_scripts/migrate_prod_memory_schema.py +``` + +The script validates the environment (release branch, no uncommitted changes under `pyrit/memory/`, no `.dev` version), +constructs an `AzureSQLMemory` pointed at prod, and runs `_run_schema_migration()` which +upgrades to head and verifies the schema matches models. Since you're on the release branch, +head is the release revision. + +**Verify prod is usable after migration.** This connects to the prod DB using the +check-only path and confirms compatibility: + +```bash +uv run python -c " +import os, dotenv +from pyrit.common.path import CONFIGURATION_DIRECTORY_PATH +dotenv.load_dotenv(CONFIGURATION_DIRECTORY_PATH / '.env', override=False, interpolate=True) +from pyrit.memory import AzureSQLMemory +AzureSQLMemory(connection_string=os.environ['AZURE_SQL_DB_CONNECTION_STRING_PROD']) +" +``` + +If it exits without error and logs no `Schema mismatch detected` warning, prod is ready. A warning does not block startup, but it means the schema still differs from the models — investigate before publishing. + +If no schema changes landed in this release, `_run_schema_migration` is a no-op. +Still run it as confirmation. + +**Rollback policy:** forward-fix only. Ship a new corrective migration rather than downgrading, +since `downgrade()` risks data loss. + + +## 10. 
Publish to PyPI Create an account on pypi.org if you don't have one yet. Ask one of the other maintainers to add you to the `pyrit` project on PyPI. @@ -221,7 +269,7 @@ If successful, it will print > View at: > https://pypi.org/project/pyrit/x.y.z/ -## 10. Update main +## 11. Update main After the release is on PyPI, make sure to create a PR for the `main` branch where the only changes are: @@ -233,7 +281,7 @@ where the only changes are: The PR should be made from your fork and should be a different branch than the releases branch you created earlier. This should be something like `x.y.z+1.dev0`. -## 11. Create GitHub Release +## 12. Create GitHub Release Finally, go to the [releases page](https://github.com/microsoft/PyRIT/releases), select "Draft a new release" and the "tag" for which you want to create the release notes. It should match the version that you just released diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 47770db028..ddc3c687c7 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -58,6 +58,9 @@ class AzureSQLMemory(MemoryInterface, metaclass=Singleton): AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL: str = "AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL" AZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN: str = "AZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN" + # Optional environment variable for production connection string to prevent accidental schema migrations on prod + AZURE_SQL_DB_CONNECTION_STRING_PROD: str = "AZURE_SQL_DB_CONNECTION_STRING_PROD" + def __init__( self, *, @@ -81,6 +84,9 @@ def __init__( verbose (bool): Whether to enable verbose logging for the database engine. Defaults to False. skip_schema_migration (bool): Whether to skip schema migration. Defaults to False. silent (bool): If True, suppresses schema migration console output. Defaults to False. + + Raises: + AutogenerateDiffsDetected: If connected to a production database and schema does not match models. 
def _check_schema_migration(self, *, silent: bool = False) -> None:
    """
    Check that the database schema agrees with the models, without modifying the database.

    Delegates to ``check_schema_migrations`` from ``pyrit.memory.migration``.

    Args:
        silent (bool): If True, suppresses Alembic console output. Defaults to False.

    Raises:
        RuntimeError: If the engine is not initialized.
        AutogenerateDiffsDetected: If the schema does not match the models.
    """
    from pyrit.memory.migration import check_schema_migrations as _verify_schema

    logger.info("Checking schema migration compatibility.")

    engine = self.engine
    if engine is None:
        raise RuntimeError("Engine must be initialized to check schema migrations.")

    _verify_schema(engine=engine, silent=silent)
def test_init_prod_connection_warns_on_schema_mismatch():
    """When connection matches prod and schema doesn't match, startup succeeds and a warning is logged."""
    from alembic.util.exc import AutogenerateDiffsDetected

    prod_conn = "Server=tcp:prod.database.windows.net;Database=prod_db;"
    # AzureSQLMemory is a Singleton: isolate this test from previously created instances.
    saved = Singleton._instances.copy()
    Singleton._instances.clear()
    try:
        with (
            patch("pyrit.memory.AzureSQLMemory._create_engine"),
            patch("pyrit.memory.AzureSQLMemory._create_auth_token"),
            patch("pyrit.memory.AzureSQLMemory._enable_azure_authorization"),
            patch.object(
                AzureSQLMemory,
                "_check_schema_migration",
                side_effect=AutogenerateDiffsDetected(
                    "diffs detected",
                    revision_context=MagicMock(),
                    diffs=[],
                ),
            ),
            # Capture the module-level logger so the warning itself can be asserted.
            patch("pyrit.memory.azure_sql_memory.logger") as mock_logger,
            patch.dict(
                "os.environ",
                {
                    AzureSQLMemory.AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL: "https://test.blob.core.windows.net/test",
                    AzureSQLMemory.AZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN: "valid_sas_token",
                    AzureSQLMemory.AZURE_SQL_DB_CONNECTION_STRING_PROD: prod_conn,
                },
            ),
        ):
            # Should NOT raise — AzureSQLMemory catches AutogenerateDiffsDetected and warns
            AzureSQLMemory(
                connection_string=prod_conn,
                results_container_url="https://test.blob.core.windows.net/test",
                results_sas_token="valid_sas_token",
            )

        # Not raising is only half the contract: the mismatch must also be reported.
        warning_messages = [str(call.args[0]) for call in mock_logger.warning.call_args_list if call.args]
        assert any("Schema mismatch detected" in message for message in warning_messages)
    finally:
        Singleton._instances.clear()
        Singleton._instances.update(saved)
test_init_allows_migration_when_connection_does_not_match_prod(): + """Migration proceeds normally when the connection string does not match the prod env var.""" + saved = Singleton._instances.copy() + Singleton._instances.clear() + try: + with ( + patch("pyrit.memory.AzureSQLMemory._create_engine"), + patch("pyrit.memory.AzureSQLMemory._create_auth_token"), + patch("pyrit.memory.AzureSQLMemory._enable_azure_authorization"), + patch.object(AzureSQLMemory, "_run_schema_migration") as mock_migration, + patch.dict( + "os.environ", + { + AzureSQLMemory.AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL: "https://test.blob.core.windows.net/test", + AzureSQLMemory.AZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN: "valid_sas_token", + AzureSQLMemory.AZURE_SQL_DB_CONNECTION_STRING_PROD: "Server=tcp:prod.database.windows.net;", + }, + ), + ): + AzureSQLMemory( + connection_string="Server=tcp:dev.database.windows.net;", + results_container_url="https://test.blob.core.windows.net/test", + results_sas_token="valid_sas_token", + ) + mock_migration.assert_called_once() + finally: + Singleton._instances.clear() + Singleton._instances.update(saved) + + +def test_init_allows_migration_when_prod_env_var_not_set(): + """Migration proceeds normally when AZURE_SQL_DB_CONNECTION_STRING_PROD is not set.""" + saved = Singleton._instances.copy() + Singleton._instances.clear() + try: + with ( + patch("pyrit.memory.AzureSQLMemory._create_engine"), + patch("pyrit.memory.AzureSQLMemory._create_auth_token"), + patch("pyrit.memory.AzureSQLMemory._enable_azure_authorization"), + patch.object(AzureSQLMemory, "_run_schema_migration") as mock_migration, + patch.dict( + "os.environ", + { + AzureSQLMemory.AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL: "https://test.blob.core.windows.net/test", + AzureSQLMemory.AZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN: "valid_sas_token", + }, + clear=False, + ), + ): + os.environ.pop(AzureSQLMemory.AZURE_SQL_DB_CONNECTION_STRING_PROD, None) + AzureSQLMemory( + 
def test_init_prod_with_skip_schema_migration_still_checks():
    """With skip_schema_migration=True against prod, the check runs and the migration does not."""
    prod_conn = "Server=tcp:prod.database.windows.net;Database=prod_db;"
    container_url = "https://test.blob.core.windows.net/test"
    sas_token = "valid_sas_token"
    env_overrides = {
        AzureSQLMemory.AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL: container_url,
        AzureSQLMemory.AZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN: sas_token,
        AzureSQLMemory.AZURE_SQL_DB_CONNECTION_STRING_PROD: prod_conn,
    }

    # Snapshot the Singleton registry so a fresh AzureSQLMemory is constructed here
    # and the previous instances are put back afterwards.
    previous_instances = Singleton._instances.copy()
    Singleton._instances.clear()
    try:
        with (
            patch("pyrit.memory.AzureSQLMemory._create_engine"),
            patch("pyrit.memory.AzureSQLMemory._create_auth_token"),
            patch("pyrit.memory.AzureSQLMemory._enable_azure_authorization"),
            patch.object(AzureSQLMemory, "_check_schema_migration") as check_mock,
            patch.object(AzureSQLMemory, "_run_schema_migration") as migrate_mock,
            patch.dict("os.environ", env_overrides),
        ):
            AzureSQLMemory(
                connection_string=prod_conn,
                results_container_url=container_url,
                results_sas_token=sas_token,
                skip_schema_migration=True,
            )
            migrate_mock.assert_not_called()
            check_mock.assert_called_once()
    finally:
        Singleton._instances.clear()
        Singleton._instances.update(previous_instances)
def test_memory_migrations_head_command(capsys):
    """_cmd_head from build_scripts/memory_migrations.py prints the current Alembic head revision."""
    import sys

    # build_scripts is not a package, so it is put on sys.path only for the import
    # and removed again so later tests do not see the extra entry.
    build_scripts_dir = str(Path(__file__).resolve().parents[3] / "build_scripts")
    sys.path.insert(0, build_scripts_dir)
    try:
        from memory_migrations import _cmd_head
    finally:
        if build_scripts_dir in sys.path:
            sys.path.remove(build_scripts_dir)

    _cmd_head()
    captured = capsys.readouterr()
    revision = captured.out.strip()
    # Should be a non-empty hex-ish string
    assert len(revision) > 0
    assert all(c in "0123456789abcdef" for c in revision)