This document provides guidance for Claude when working on the Simplexity codebase, a JAX-based computational mechanics library for sequence prediction models.
Simplexity is a research-oriented machine learning library focused on computational mechanics perspectives of sequence prediction. The codebase uses JAX/Equinox for neural network implementations and follows functional programming patterns.
- Line Length: Maximum 120 characters
- Python Version: 3.12+
- Formatting: Use `ruff format` for automatic formatting
- Linting: Follow ruff rules defined in `pyproject.toml`
- Type Annotations: Always use type hints for function parameters and return values
- Docstrings: Follow Google style docstrings for all public functions and classes
- Import Order: Managed by ruff's isort rules (standard library, third-party, local)
- Type Safety: All code must pass `pyright` type checking in standard mode
- Testing: Write pytest tests for new functionality and maintain high coverage
- No Comments: Avoid inline comments; code should be self-documenting through clear naming and structure
- Functional Style: Prefer functional programming patterns, especially when working with JAX
- Files: Use snake_case for all Python files
- Classes: Use PascalCase for classes
- Functions/Methods: Use snake_case
- Constants: Use UPPER_SNAKE_CASE
- Type Variables: Use PascalCase (e.g., `State`, `Model`)
- Place tests in the `tests/` directory, mirroring the source structure
- Use pytest fixtures for common test setup
- Test files must start with `test_`
- Test functions must start with `test_`
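A minimal sketch of a module that follows the code style and naming conventions listed above (type hints, Google-style docstrings, pure functions, UPPER_SNAKE_CASE constants, PascalCase classes and type variables, snake_case functions). Every name in it (`MAX_SEQUENCE_LENGTH`, `State`, `Trajectory`, `normalize_distribution`, `truncate_trajectory`) is hypothetical and not part of Simplexity. The comments only label which convention each line demonstrates; under the no-comments rule they would be left out of real code.

```python
from dataclasses import dataclass
from typing import Generic, TypeVar

import jax
import jax.numpy as jnp

# Constant: UPPER_SNAKE_CASE
MAX_SEQUENCE_LENGTH = 128

# Type variable: PascalCase
State = TypeVar("State")


# Class: PascalCase; frozen so instances stay immutable in a functional style
@dataclass(frozen=True)
class Trajectory(Generic[State]):
    """A fixed sequence of states.

    Attributes:
        states: The states in visitation order.
    """

    states: tuple[State, ...]


# Function: snake_case, fully annotated, Google-style docstring, no side effects
def normalize_distribution(weights: jax.Array, axis: int = -1) -> jax.Array:
    """Normalizes non-negative weights into a probability distribution.

    Args:
        weights: Non-negative weights that do not sum to zero along `axis`.
        axis: Axis along which to normalize.

    Returns:
        An array with the same shape as `weights` that sums to one along `axis`.
    """
    return weights / jnp.sum(weights, axis=axis, keepdims=True)


def truncate_trajectory(trajectory: Trajectory[State], length: int = MAX_SEQUENCE_LENGTH) -> Trajectory[State]:
    """Returns a new trajectory holding at most `length` states.

    Args:
        trajectory: The trajectory to truncate; it is not modified.
        length: Maximum number of states to keep.

    Returns:
        A new `Trajectory` containing the first `length` states.
    """
    return Trajectory(states=trajectory.states[:length])
```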
```python
import chex
import pytest


@pytest.fixture
def model_fixture() -> Model:
    return build_model(...)


def test_functionality(model_fixture: Model):
    # Arrange
    input_data = ...
    expected = ...
    # Act
    result = model_fixture.process(input_data)
    # Assert
    chex.assert_trees_all_close(result, expected)
```

- Use `chex` for JAX array assertions
- Test with different random seeds
- Verify shape and dtype consistency
- New code: Minimum 80% coverage (strictly enforced in PRs via `diff-cover`)
- Existing code: Overall coverage is monitored but won't block updates (allowing gradual improvement)
- PRs: Only new or changed code must meet the 80% threshold (checked via `diff-cover`)
- Main branch: Overall coverage is tracked but not enforced (to allow updates while improving coverage)
- HTML coverage reports are generated in the `htmlcov/` directory
- Diff coverage reports show coverage of only the new or changed lines in a PR
- Goal: Gradually improve overall coverage while ensuring all new code meets standards
```
simplexity/
├── configs/                # Hydra configuration files
├── data_structures/        # Core data structures
├── evaluation/             # Model evaluation functions
├── generative_processes/   # Generative model implementations
├── logging/                # Logging utilities
├── persistence/            # Model checkpointing
├── predictive_models/      # Neural network models
├── training/               # Training loops and utilities
└── utils/                  # Helper functions
```
- Protocol Classes: Use `typing.Protocol` for defining interfaces
- Abstract Base Classes: Use `eqx.Module` with `@abstractmethod` for base classes
- Builder Pattern: Provide builder functions for complex object construction
- Separation of Concerns: Keep model definitions, training logic, and evaluation separate
- Pure Functions: Keep functions pure and side-effect free
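- Vectorization: Use `eqx.filter_vmap` for batch processing
- Random Keys: Always split PRNG keys with `jax.random.split` before use; never reuse a key for two random draws
- Tree Operations: Use JAX tree operations for nested structures
- JIT Compilation: Keep functions JIT-compatible (no Python side effects, no Python control flow that depends on traced array values)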
- Package Manager: Use `uv` for dependency management
- Dependencies: Add to `pyproject.toml` in the appropriate sections
- Optional Dependencies: Use extras for optional features (aws, cuda, dev, mac, pytorch)
- Version Pinning: Specify minimum versions with `>=`
Before submitting code, ensure it passes:
- `uv run --extra dev ruff check` - Linting
- `uv run --extra dev ruff format --check` - Formatting
- `uv run --extra dev --extra pytorch pyright` - Type checking
- `uv run --extra dev --extra pytorch pytest` - Tests
- `uv run --extra dev diff-cover coverage.xml --compare-branch=origin/main --fail-under=80` - New code coverage (must meet the 80% threshold)
```bash
# Install dependencies
uv sync --extra dev

# Run linting
uv run --extra dev ruff check

# Format code
uv run --extra dev ruff format

# Type check
uv run --extra dev --extra pytorch pyright

# Run tests (coverage is tracked but won't fail)
uv run --extra dev --extra pytorch pytest

# Check overall coverage threshold (optional - will fail if below 80%)
uv run --extra dev --extra pytorch pytest --cov-fail-under=80

# Check coverage for new code only (compares against main branch)
# This is what CI checks - new code must meet 80% threshold
# Note: Requires pytest to be run first to generate coverage.xml
uv run --extra dev --extra pytorch pytest --cov-fail-under=0
uv run --extra dev diff-cover coverage.xml --compare-branch=origin/main --fail-under=80

# View HTML coverage report (generated in htmlcov/)
# After running pytest, open htmlcov/index.html in a browser

# Train a model
uv run python simplexity/train_model.py
```

When reviewing or creating PRs:
- Ensure all CI checks pass
- Maintain or improve test coverage
- Follow existing patterns and conventions
- Keep changes focused and atomic
- Update relevant documentation
- Never commit credentials or API keys
- Use environment variables or config files for sensitive data
- Follow AWS best practices when using S3 persistence
- Validate all external inputs
- Profile JAX code for bottlenecks
- Prefer vectorized operations over loops
- Use appropriate data structures from `simplexity.data_structures`
- Consider memory usage with large models
- Leverage GPU acceleration when available
- Keep documentation concise and technical
- Focus on "why" rather than "what"
- Update README.md for user-facing changes
- Use type hints as primary documentation
- Provide usage examples in docstrings when helpful
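The performance and documentation guidelines above meet in the following sketch: a sequential belief update written with `jax.lax.scan` instead of a Python loop, documented with a Google-style docstring whose `Examples:` section gives a short usage example. `propagate_beliefs` and its transition-matrix convention (`transition_matrices[x, i, j]` is the probability of emitting symbol `x` while moving from state `i` to state `j`) are assumptions for illustration, not Simplexity's API.

```python
import jax
import jax.numpy as jnp


def propagate_beliefs(
    transition_matrices: jax.Array,
    initial_belief: jax.Array,
    observations: jax.Array,
) -> jax.Array:
    """Computes the belief state over hidden states after each observation.

    Assumes transition_matrices[x, i, j] is the probability of emitting symbol x
    while moving from state i to state j.

    Args:
        transition_matrices: Array of shape (num_symbols, num_states, num_states).
        initial_belief: Probability vector over states, shape (num_states,).
        observations: Integer symbol sequence, shape (sequence_length,).

    Returns:
        Array of shape (sequence_length, num_states) whose row t is the belief
        after observing observations[0], ..., observations[t].

    Examples:
        >>> matrices = jnp.full((2, 2, 2), 0.25)
        >>> beliefs = propagate_beliefs(matrices, jnp.array([0.5, 0.5]), jnp.array([0, 1, 1]))
        >>> beliefs.shape
        (3, 2)
    """

    def step(belief: jax.Array, observation: jax.Array) -> tuple[jax.Array, jax.Array]:
        # Bayesian update: apply the observed symbol's matrix, then renormalize to a distribution
        unnormalized = belief @ transition_matrices[observation]
        next_belief = unnormalized / jnp.sum(unnormalized)
        return next_belief, next_belief

    # lax.scan expresses the sequential recursion as one traced loop instead of a Python for-loop
    _, beliefs = jax.lax.scan(step, initial_belief, observations)
    return beliefs
```

`jax.lax.scan` fits here because each step depends on the previous belief; when steps are independent, `vmap` or plain broadcasting is usually the better choice. The `Examples:` section can be checked with Python's standard `doctest` module.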