Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions build_scripts/memory_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,20 @@ def _cmd_check() -> None:
tmp_path.unlink(missing_ok=True)


def _cmd_head() -> None:
"""Print the current Alembic head revision ID."""
from pathlib import Path

from alembic.config import Config
from alembic.script import ScriptDirectory

script_location = Path(__file__).parent.parent / "pyrit" / "memory" / "alembic"
config = Config()
config.set_main_option("script_location", str(script_location))
head = ScriptDirectory.from_config(config).get_current_head()
print(head)


def _build_parser() -> argparse.ArgumentParser:
"""Build the CLI argument parser."""
parser = argparse.ArgumentParser(
Expand All @@ -71,6 +85,8 @@ def _build_parser() -> argparse.ArgumentParser:

sub.add_parser("check", help="Verify all migrations apply cleanly and add up to the current memory models.")

sub.add_parser("head", help="Print the current Alembic head revision ID.")

return parser


Expand All @@ -82,6 +98,8 @@ def main() -> int:
_cmd_generate(message=args.message, force=args.force)
elif args.command == "check":
_cmd_check()
elif args.command == "head":
_cmd_head()

return 0

Expand Down
169 changes: 169 additions & 0 deletions build_scripts/migrate_prod_memory_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# Copyright (c) Microsoft Corporation.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if adding a 300 line script is a bit of overhead, with a lot of overlap with other migration code (i.e. the run_migrations functions there which upgrades a db given its engine.

I think this should be a thin wrapper that constructs an AzureSQLMemory with skip=true and the prod connection string, then calls its _run_schema_migration ... no?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed, slimmed it down!

# Licensed under the MIT license.

"""
Deliberate schema migration tool for production databases.

This script is the ONLY sanctioned way to apply Alembic migrations to a production
database. It is intended to be run during the release process (see
doc/contributing/10_release_process.md) or by a CD pipeline — never by normal
application startup.

It constructs an AzureSQLMemory with skip_schema_migration=True (bypassing the
runtime guard), then explicitly calls _run_schema_migration to upgrade to head.
The environment checks ensure this only runs from a release branch.

Safety rails:
- Validates the environment (release branch, clean working tree, no .dev version).
- Interactive confirmation when running in a terminal.
- Exits non-zero on any failure.

Usage:
python build_scripts/migrate_prod_memory_schema.py

The script reads the production connection string from
AZURE_SQL_DB_CONNECTION_STRING_PROD (loaded from ~/.pyrit/.env).
"""

import subprocess
import sys

import dotenv

from pyrit.common.path import CONFIGURATION_DIRECTORY_PATH

# Load .env files from ~/.pyrit/ (same files that initialize_pyrit_async loads)
# Use override=False so explicitly-set env vars take precedence over .env values
for _env_file in [CONFIGURATION_DIRECTORY_PATH / ".env", CONFIGURATION_DIRECTORY_PATH / ".env.local"]:
if _env_file.exists():
dotenv.load_dotenv(_env_file, override=False, interpolate=True)

# ANSI color codes
_GREEN = "\033[92m"
_RED = "\033[91m"
_YELLOW = "\033[93m"
_RESET = "\033[0m"


def _print_error(message: str) -> None:
"""Print an error message in red to stderr."""
print(f"{_RED}ERROR: {message}{_RESET}", file=sys.stderr)


def _print_success(message: str) -> None:
"""Print a success message in green."""
print(f"{_GREEN}{message}{_RESET}")


def _check_release_environment() -> list[str]:
"""
Validate that the script is running in a proper release environment.

Returns a list of warning/error messages. Empty list means all checks pass.
"""
issues: list[str] = []

try:
branch = subprocess.check_output(
["git", "rev-parse", "--abbrev-ref", "HEAD"],
text=True,
stderr=subprocess.DEVNULL,
).strip()
if not branch.startswith("releases/"):
issues.append(
f"Not on a release branch (current: '{branch}'). "
"Production migrations should run from 'releases/vX.Y.Z'."
)
except (subprocess.CalledProcessError, FileNotFoundError):
issues.append("Could not determine current Git branch.")

try:
dirty_files = subprocess.check_output(
["git", "status", "--porcelain", "--", "pyrit/memory/"],
text=True,
stderr=subprocess.DEVNULL,
).strip()
if dirty_files:
issues.append(
"Uncommitted changes detected in pyrit/memory/:\n"
f" {dirty_files}\n"
" Commit or stash changes before migrating production."
)
except (subprocess.CalledProcessError, FileNotFoundError):
issues.append("Could not check Git working tree status.")

try:
from pyrit import __version__

if ".dev" in __version__:
issues.append(
f"PyRIT version is '{__version__}' (development). "
"Production migrations should use a release version (no .dev suffix)."
)
except ImportError:
issues.append("Could not determine PyRIT version.")

return issues


def main() -> int:
"""Entry point for production schema migration."""
import argparse

parser = argparse.ArgumentParser(
description="Apply Alembic schema migrations to the production database.",
)
parser.add_argument(
"--skip-environment-check",
action="store_true",
help="Skip release environment checks (branch, clean tree, version). Use only in CI with caution.",
)
args = parser.parse_args()

# Safety rail: Verify release environment
if not args.skip_environment_check:
issues = _check_release_environment()
if issues:
_print_error("Release environment checks failed:")
for issue in issues:
_print_error(f" - {issue}")
_print_error("Fix the above issues or pass --skip-environment-check (CI only).")
return 1
else:
print(f"{_YELLOW}WARNING: Skipping release environment checks.{_RESET}")

# Interactive confirmation
if sys.stdin.isatty():
print("About to migrate production database schema to head.")
response = input("Type 'yes' to proceed: ")
if response.strip().lower() != "yes":
print("Aborted.")
return 1

# Construct AzureSQLMemory with skip_schema_migration=True to bypass the runtime guard,
# then explicitly run migration.
import os

from pyrit.memory import AzureSQLMemory

prod_conn = os.environ.get(AzureSQLMemory.AZURE_SQL_DB_CONNECTION_STRING_PROD)
if not prod_conn:
_print_error(f"Environment variable '{AzureSQLMemory.AZURE_SQL_DB_CONNECTION_STRING_PROD}' is not set.")
return 1

try:
memory = AzureSQLMemory(
connection_string=prod_conn,
skip_schema_migration=True,
)
print("Running schema migration...")
memory._run_schema_migration()
_print_success("Production schema migration completed and verified successfully.")
return 0
except Exception as e:
_print_error(f"Migration failed: {e}")
return 1


if __name__ == "__main__":
raise SystemExit(main())
54 changes: 51 additions & 3 deletions doc/contributing/10_release_process.md
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,55 @@ Note: You may need to build the package again if those changes modify any depend
Lastly, **Verify pyrit-internal is up to date.** Follow the instructions at [aka.ms/internal-release](https://aka.ms/internal-release) to ensure the internal package is current.


## 9. Publish to PyPI
## 9. Migrate Production Database Schema

Apply any pending Alembic migrations to the production database. This is the **only**
sanctioned path for modifying the production schema — normal startup only validates,
never upgrades.

**Run from the release branch with release dependencies.** This ensures the migration
files and model definitions match exactly what will be shipped to users. Running from
`main` or a dev environment could apply unreleased migrations that break prod.

```bash
git checkout releases/vx.y.z
uv run python -c "import pyrit; print(pyrit.__version__)" # verify: x.y.z (no .dev0)
```

**Run the migration** (reads `AZURE_SQL_DB_CONNECTION_STRING_PROD` from `~/.pyrit/.env`):

```bash
uv run python build_scripts/migrate_prod_memory_schema.py
```

The script validates the environment (release branch, clean tree, no `.dev` version),
constructs an `AzureSQLMemory` pointed at prod, and runs `_run_schema_migration()` which
upgrades to head and verifies the schema matches models. Since you're on the release branch,
head is the release revision.

**Verify prod is usable after migration.** This connects to the prod DB using the
check-only path and confirms compatibility:

```bash
uv run python -c "
import os, dotenv
from pyrit.common.path import CONFIGURATION_DIRECTORY_PATH
dotenv.load_dotenv(CONFIGURATION_DIRECTORY_PATH / '.env', override=False, interpolate=True)
from pyrit.memory import AzureSQLMemory
AzureSQLMemory(connection_string=os.environ['AZURE_SQL_DB_CONNECTION_STRING_PROD'])
"
```

If it exits without error (or only a schema mismatch warning), prod is ready.

If no schema changes landed in this release, `_run_schema_migration` is a no-op.
Still run it as confirmation.

**Rollback policy:** forward-fix only. Ship a new corrective migration rather than downgrading,
since `downgrade()` risks data loss.


## 10. Publish to PyPI

Create an account on pypi.org if you don't have one yet.
Ask one of the other maintainers to add you to the `pyrit` project on PyPI.
Expand All @@ -221,7 +269,7 @@ If successful, it will print
> View at:
> https://pypi.org/project/pyrit/x.y.z/

## 10. Update main
## 11. Update main

After the release is on PyPI, make sure to create a PR for the `main` branch
where the only changes are:
Expand All @@ -233,7 +281,7 @@ where the only changes are:
The PR should be made from your fork and should be a different branch than the releases branch you created earlier.
This should be something like `x.y.z+1.dev0`.

## 11. Create GitHub Release
## 12. Create GitHub Release

Finally, go to the [releases page](https://github.com/microsoft/PyRIT/releases), select "Draft a new release" and the "tag"
for which you want to create the release notes. It should match the version that you just released
Expand Down
32 changes: 31 additions & 1 deletion pyrit/memory/azure_sql_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ class AzureSQLMemory(MemoryInterface, metaclass=Singleton):
AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL: str = "AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL"
AZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN: str = "AZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN"

# Optional environment variable for production connection string to prevent accidental schema migrations on prod
AZURE_SQL_DB_CONNECTION_STRING_PROD: str = "AZURE_SQL_DB_CONNECTION_STRING_PROD"

def __init__(
self,
*,
Expand All @@ -81,6 +84,9 @@ def __init__(
verbose (bool): Whether to enable verbose logging for the database engine. Defaults to False.
skip_schema_migration (bool): Whether to skip schema migration. Defaults to False.
silent (bool): If True, suppresses schema migration console output. Defaults to False.

Raises:
AutogenerateDiffsDetected: If connected to a production database and schema does not match models.
"""
self._connection_string = default_values.get_required_value(
env_var_name=self.AZURE_SQL_DB_CONNECTION_STRING, passed_value=connection_string
Expand All @@ -107,8 +113,32 @@ def __init__(
self._enable_azure_authorization()

self.SessionFactory = sessionmaker(bind=self.engine)
if not skip_schema_migration:

prod_connection_string = default_values.get_non_required_value(
env_var_name=self.AZURE_SQL_DB_CONNECTION_STRING_PROD
)

is_prod = bool(prod_connection_string) and self._connection_string == prod_connection_string
should_migrate = not is_prod and not skip_schema_migration

if should_migrate:
# Non-production: run schema migration (upgrade + check).
self._run_schema_migration(silent=silent)
else:
# Production or skip_schema_migration=True: verify schema compatibility
# without modifying the database. Logs a warning on mismatch but does not
# block startup, so developers on newer code can still query data.
from alembic.util.exc import AutogenerateDiffsDetected, CommandError

try:
self._check_schema_migration(silent=silent)
Comment thread
jsong468 marked this conversation as resolved.
except (AutogenerateDiffsDetected, CommandError) as e:
logger.warning(
"Schema mismatch detected. "
"Your code models differ from the database schema. "
"This may cause errors if your code references columns or tables that don't exist. "
f"Schema was NOT modified. Details: {e}"
)

super().__init__()

Expand Down
18 changes: 18 additions & 0 deletions pyrit/memory/memory_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -1434,6 +1434,24 @@ def _run_schema_migration(self, *, silent: bool = False) -> None:
run_schema_migrations(engine=self.engine, silent=silent)
check_schema_migrations(engine=self.engine, silent=silent)

def _check_schema_migration(self, *, silent: bool = False) -> None:
"""
Verify that the current database schema matches the models without modifying the database.

Args:
silent (bool): If True, suppresses Alembic console output. Defaults to False.

Raises:
RuntimeError: If the engine is not initialized.
AutogenerateDiffsDetected: If the schema does not match the models.
"""
from pyrit.memory.migration import check_schema_migrations

logger.info("Checking schema migration compatibility.")
if self.engine is None:
raise RuntimeError("Engine must be initialized to check schema migrations.")
check_schema_migrations(engine=self.engine, silent=silent)

def reset_database(self) -> None:
"""
Drop and recreate all tables in the database.
Expand Down
Loading
Loading