Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ project/target/
project/project/target/
examples/target/

.venv/

# IDE files
.bsp/
.metals/
Expand Down
17 changes: 10 additions & 7 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,20 @@ RUN echo "deb https://repo.scala-sbt.org/scalasbt/debian all main" | tee /etc/ap
ENV JAVA_HOME=/usr/lib/jvm/java-17-openjdk-amd64
ENV PATH="${JAVA_HOME}/bin:${PATH}"

# Install Python packages (JAX GPU already included in base image)
RUN pip install --upgrade \
matplotlib \
pandas \
scikit-learn \
jupyter \
einops
# Install uv
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
ENV PATH="/root/.local/bin:${PATH}"

# Copy project files
COPY . /workspace/

# Create venv inheriting JAX from the base image, add extra packages
RUN uv venv .venv --system-site-packages && \
uv pip install --python .venv/bin/python einops matplotlib pandas

# Skip uv sync — JAX is already available via system-site-packages
ENV DIMWIT_SKIP_SYNC=true

# Set Python path
ENV PYTHONPATH=/workspace/src/python

Expand Down
26 changes: 16 additions & 10 deletions Dockerfile.ci
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,33 @@ FROM sbtscala/scala-sbt:eclipse-temurin-jammy-17.0.9_9_1.9.7_3.3.1
ENV DEBIAN_FRONTEND=noninteractive
WORKDIR /workspace

# Install Python
# Install Python and curl
RUN apt-get update && apt-get install -y \
python3.11 \
python3.11-dev \
libpython3.11 \
python3-pip \
curl \
&& rm -rf /var/lib/apt/lists/* && \
ln -sf /usr/bin/python3.11 /usr/bin/python

# Install Python packages with JAX CPU
RUN pip install --no-cache-dir --upgrade \
"jax[cpu]" \
matplotlib \
pandas \
scikit-learn \
jupyter \
einops \
numpy
# Install uv
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
ENV PATH="/root/.local/bin:${PATH}"

# Copy only build files first for dependency caching
COPY build.sbt /workspace/
COPY project/ /workspace/project/

# Set up Python venv with CPU JAX (pyproject.toml specifies cuda12; override here)
RUN uv venv .venv --python python3.11 && \
uv pip install --python .venv/bin/python "jax[cpu]>=0.4" "einops>=0.8"

# Skip uv sync at runtime and point directly at the venv Python
# This also prevents uv from being called at sbt load time
ENV DIMWIT_SKIP_SYNC=true
ENV DIMWIT_PYTHON_PATH=/workspace/.venv/bin/python

# Pre-download SBT dependencies
RUN sbt update || true

Expand Down
11 changes: 11 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ai.kien.python.Python
import scala.sys.process._

ThisBuild / version := "0.1.0-SNAPSHOT"
ThisBuild / scalaVersion := "3.8.1"
Expand All @@ -15,6 +16,14 @@ lazy val root = (project in file("."))
name := "dimwit-root"
)

lazy val uvPython: String =
sys.env.getOrElse(
"DIMWIT_PYTHON_PATH",
Seq("uv", "run", "--no-sync", "python", "-c", "import sys; print(sys.executable)").!!.trim
)
lazy val python = Python(uvPython)
lazy val scalapyJavaOptions = python.scalapyProperties.get.map { case (k, v) => s"-D$k=$v" }.toSeq

lazy val core = (project in file("core"))
.settings(
name := "dimwit-core",
Expand All @@ -25,6 +34,8 @@ lazy val core = (project in file("core"))
"org.scalatestplus" %% "scalacheck-1-18" % "3.2.19.0" % Test
),
fork := true,
javaOptions ++= scalapyJavaOptions,
Test / envVars += "DIMWIT_SKIP_SYNC" -> "true",
coverageMinimumStmtTotal := 80,
coverageFailOnMinimum := false,
coverageHighlighting := true,
Expand Down
1 change: 1 addition & 0 deletions core/src/main/scala/dimwit/jax/Jax.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import me.shadaj.scalapy.py
import me.shadaj.scalapy.py.SeqConverters
import me.shadaj.scalapy.py.PyQuote
import dimwit.hardware.{Device, DeviceBackend}
import dimwit.python.PythonSetup

object Jax:

Expand Down
54 changes: 0 additions & 54 deletions core/src/main/scala/dimwit/jax/PythonSetup.scala

This file was deleted.

14 changes: 13 additions & 1 deletion core/src/main/scala/dimwit/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ import scala.annotation.targetName

import dimwit.jax.Jax
import dimwit.tensor.{Axis, AxisExtent, AxisSelector, AxisAtIndex, AxisAtRange, AxisAtIndices, AxisAtTensorIndex}

package object dimwit:

import scala.compiletime.ops.string.+
Expand Down Expand Up @@ -102,3 +101,16 @@ package object dimwit:
export dimwit.stats.{Prob, LogProb}
export dimwit.stats.{Distribution, IndependentDistribution, MultivariateDistribution, UnivariateDistribution}
export dimwit.MemoryHelper.withLocalCleanup

/** Explicitly configures the Python environment before any ScalaPy call.
* Call this function at the start of your program (before any `py.*` call)
* to ensure the Python environment is correctly set up.
*
* @param performUVSync Whether to run `uv sync` to automatically set up the Python environment based on the project's `pyproject.toml`. Set this to false if you want to manage the Python environment yourself (e.g. with a custom venv or conda env).
*
* Env-var overrides:
* - DIMWIT_PYTHON_PATH — path to a specific Python interpreter
* - DIMWIT_PYTHON_LIBRARY — path to a specific libpython shared library
*/
def initialize(performUVSync: Boolean = true): Unit =
dimwit.python.PythonSetup.configureScalaPy(performUVSync)
122 changes: 122 additions & 0 deletions core/src/main/scala/dimwit/python/PythonSetup.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package dimwit.python

import me.shadaj.scalapy.py
import scala.sys.process.Process

/** Manages Python environment setup for DimWit.
*
* Handles extraction of Python helper modules from JAR resources and configuration of Python path for ScalaPy integration.
*/
object PythonSetup:

/** Configures the JVM system properties that ScalaPy/JNA need to locate the Python shared library.
*
* This must run before any `py.*` call (i.e. before ScalaPy's own class initialiser).
*
* Respects three env-var overrides:
* - DIMWIT_SKIP_SYNC — skip uv sync and manage Python environment manually (overrides performUvSync argument)
* - DIMWIT_PYTHON_PATH — use a specific Python interpreter
* - DIMWIT_PYTHON_LIBRARY — use a specific shared-library path
*/
def configureScalaPy(performUvSync: Boolean): Unit =
val skipSync = sys.env.get("DIMWIT_SKIP_SYNC").exists(v => v == "true" || v == "1")
if performUvSync && !skipSync then
if Process(Seq("uv", "sync")).! != 0 then
throw new RuntimeException(
"""[dimwit] uv sync failed. Ensure uv is installed (https://docs.astral.sh/uv/) and
|that a pyproject.toml with JAX dependencies exists in your project, for example:
|
| [project]
| name = "my-project"
| version = "0.1.0"
| requires-python = ">=3.11"
| dependencies = [
| "jax[cpu]>=0.4", # or jax[cuda12], jax[tpu]
| "einops>=0.8",
| ]
|
|Set DIMWIT_SKIP_SYNC=true to manage the Python environment yourself.""".stripMargin
)

val python = scala.sys.env.getOrElse(
"DIMWIT_PYTHON_PATH",
try Process(Seq("uv", "run", "python", "-c", "import sys; print(sys.executable)")).!!.trim
catch
case e: Exception =>
throw new RuntimeException(
"[dimwit] Could not resolve Python interpreter. " +
"Set DIMWIT_PYTHON_PATH or check your uv setup.",
e
)
)

val library = scala.sys.env.getOrElse(
"DIMWIT_PYTHON_LIBRARY",
try
Process(Seq(
python,
"-c",
"import sys, os, sysconfig as sc, ctypes.util as cu; v = sys.version_info\n" +
"if sys.platform == 'win32': print(os.path.join(sys.base_prefix, 'python{}{}.dll'.format(v.major, v.minor)))\n" +
"else:\n" +
" lib = os.path.join(sc.get_config_var('LIBDIR') or '', sc.get_config_var('INSTSONAME') or sc.get_config_var('LDLIBRARY') or '')\n" +
" print(cu.find_library('python{}.{}'.format(v.major, v.minor)) if not os.path.isfile(lib) else lib)"
)).!!.trim
catch
case e: Exception =>
throw new RuntimeException(
s"[dimwit] Could not locate Python shared library for interpreter '$python'. " +
"Set DIMWIT_PYTHON_LIBRARY to the full path of libpython (e.g. /usr/lib/libpython3.11.so.1.0).",
e
)
)

System.setProperty("scalapy.python.programname", python)
System.setProperty("scalapy.python.library", library)

/** Initialize Python environment by extracting helper modules and configuring Python path.
*
* This method:
* - Configures ScalaPy JVM properties (must happen before any py.* call)
* - Extracts jax_helper.py from JAR resources to a temporary directory
* - Adds the temp directory to Python's sys.path
* - Registers shutdown hook for cleanup
* - Falls back to development paths if running from source
*
* Called lazily on first access to any JAX module — safe to call multiple times.
*/
lazy val initialize: Unit =

lazy val sys = py.module("sys")

// Extract jax_helper.py from JAR resources to a temporary directory
val resourcePath = "/python/jax_helper.py"
val resourceStream = getClass.getResourceAsStream(resourcePath)

if resourceStream != null then
try
val tempDir = java.nio.file.Files.createTempDirectory("dimwit-python")
val targetFile = tempDir.resolve("jax_helper.py")
java.nio.file.Files.copy(
resourceStream,
targetFile,
java.nio.file.StandardCopyOption.REPLACE_EXISTING
)

// Add the temp directory to Python path
sys.path.append(tempDir.toAbsolutePath.toString)

// Register shutdown hook to clean up temp directory
Runtime.getRuntime.addShutdownHook(new Thread(() =>
try
java.nio.file.Files
.walk(tempDir)
.sorted(java.util.Comparator.reverseOrder())
.forEach(java.nio.file.Files.delete)
catch case _: Exception => () // Ignore cleanup errors
))
finally resourceStream.close()
else
// Fallback to legacy path for development/local usage
sys.path.append("./core/src/main/resources/python")
sys.path.append("./src/main/resources/python")
5 changes: 1 addition & 4 deletions core/src/test/scala/dimwit/autodiff/AutodiffSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@ package dimwit.autodiff
import dimwit.*
import dimwit.*
import dimwit.Conversions.given
import org.scalatest.matchers.should.Matchers
import org.scalatest.funspec.AnyFunSpec

import dimwit.autodiff.Autodiff.Gradient

class AutodiffSuite extends AnyFunSpec with Matchers:
class AutodiffSuite extends DimwitTest:

describe("grad"):
describe("single parameter function"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@ package dimwit.autodiff
import dimwit.*
import dimwit.Conversions.given
import dimwit.autodiff.FloatTree.ops.*
import org.scalatest.funspec.AnyFunSpec
import org.scalatest.matchers.should.Matchers

class FloatTensorTreeSuite extends AnyFunSpec with Matchers:
class FloatTensorTreeSuite extends DimwitTest:

describe("map"):
it("1-level case class"):
Expand Down
5 changes: 1 addition & 4 deletions core/src/test/scala/dimwit/autodiff/PyTreeSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@ import dimwit.*
import dimwit.jax.Jax
import dimwit.Conversions.given
import me.shadaj.scalapy.py
import org.scalatest.funspec.AnyFunSpec
import org.scalatest.matchers.should.Matchers

class ToPyTreeSuite extends AnyFunSpec with Matchers:
class ToPyTreeSuite extends DimwitTest:

describe("TensorTree Identity (fromPyTree(toPyTree(x)) == x)"):

Expand Down
4 changes: 1 addition & 3 deletions core/src/test/scala/dimwit/autodiff/TensorTreeSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@ package dimwit.autodiff
import dimwit.*
import dimwit.Conversions.given
import dimwit.*
import org.scalatest.funspec.AnyFunSpec
import org.scalatest.matchers.should.Matchers

class TensorTreeSuite extends AnyFunSpec with Matchers:
class TensorTreeSuite extends DimwitTest:

describe("map"):
it("1-level case class"):
Expand Down
4 changes: 1 addition & 3 deletions core/src/test/scala/dimwit/jax/JitSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@ package dimwit.jax

import dimwit.*
import dimwit.Conversions.given
import org.scalatest.funspec.AnyFunSpec
import org.scalatest.matchers.should.Matchers
import me.shadaj.scalapy.py

class JitSuite extends AnyFunSpec with Matchers:
class JitSuite extends DimwitTest:

it("jit compilation works correctly"):
def f(t: Tensor1[A, Float]): Tensor1[A, Float] =
Expand Down
Loading