From b2be007e8d02d9fb07bef5b302577a68c14f7853 Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Sun, 26 Apr 2026 16:45:28 +0200 Subject: [PATCH 1/2] add auto-initialization for python Make use of uv to automatically initialize the python environment before any jax operation is accessed. Expose a separate dimwit.setup() method to allow manual initialization in cases the user needs to access python libraries before jax. Update build.sbt and docker files Co-authored-by: Copilot --- .dockerignore | 2 + Dockerfile | 17 +++-- Dockerfile.ci | 26 ++++--- build.sbt | 11 +++ .../main/scala/dimwit/jax/PythonSetup.scala | 71 ++++++++++++++++++- core/src/main/scala/dimwit/package.scala | 12 +++- pyproject.toml | 10 +++ 7 files changed, 129 insertions(+), 20 deletions(-) create mode 100644 pyproject.toml diff --git a/.dockerignore b/.dockerignore index 4d3485c7..aeab7dbd 100644 --- a/.dockerignore +++ b/.dockerignore @@ -4,6 +4,8 @@ project/target/ project/project/target/ examples/target/ +.venv/ + # IDE files .bsp/ .metals/ diff --git a/Dockerfile b/Dockerfile index d560681c..bb3210a9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -28,17 +28,20 @@ RUN echo "deb https://repo.scala-sbt.org/scalasbt/debian all main" | tee /etc/ap ENV JAVA_HOME=/usr/lib/jvm/java-17-openjdk-amd64 ENV PATH="${JAVA_HOME}/bin:${PATH}" -# Install Python packages (JAX GPU already included in base image) -RUN pip install --upgrade \ - matplotlib \ - pandas \ - scikit-learn \ - jupyter \ - einops +# Install uv +RUN curl -LsSf https://astral.sh/uv/install.sh | sh +ENV PATH="/root/.local/bin:${PATH}" # Copy project files COPY . /workspace/ +# Create venv inheriting JAX from the base image, add extra packages +RUN uv venv .venv --system-site-packages && \ + uv pip install --python .venv/bin/python einops matplotlib pandas + +# Skip uv sync — JAX is already available via system-site-packages +ENV DIMWIT_SKIP_SYNC=true + # Set Python path ENV PYTHONPATH=/workspace/src/python diff --git a/Dockerfile.ci b/Dockerfile.ci index ab46d71d..08071fdd 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -7,27 +7,33 @@ FROM sbtscala/scala-sbt:eclipse-temurin-jammy-17.0.9_9_1.9.7_3.3.1 ENV DEBIAN_FRONTEND=noninteractive WORKDIR /workspace -# Install Python +# Install Python and curl RUN apt-get update && apt-get install -y \ python3.11 \ + python3.11-dev \ + libpython3.11 \ python3-pip \ + curl \ && rm -rf /var/lib/apt/lists/* && \ ln -sf /usr/bin/python3.11 /usr/bin/python -# Install Python packages with JAX CPU -RUN pip install --no-cache-dir --upgrade \ - "jax[cpu]" \ - matplotlib \ - pandas \ - scikit-learn \ - jupyter \ - einops \ - numpy +# Install uv +RUN curl -LsSf https://astral.sh/uv/install.sh | sh +ENV PATH="/root/.local/bin:${PATH}" # Copy only build files first for dependency caching COPY build.sbt /workspace/ COPY project/ /workspace/project/ +# Set up Python venv with CPU JAX (pyproject.toml specifies cuda12; override here) +RUN uv venv .venv --python python3.11 && \ + uv pip install --python .venv/bin/python "jax[cpu]>=0.4" "einops>=0.8" + +# Skip uv sync at runtime and point directly at the venv Python +# This also prevents uv from being called at sbt load time +ENV DIMWIT_SKIP_SYNC=true +ENV DIMWIT_PYTHON_PATH=/workspace/.venv/bin/python + # Pre-download SBT dependencies RUN sbt update || true diff --git a/build.sbt b/build.sbt index 218ae5b5..e4af8557 100644 --- a/build.sbt +++ b/build.sbt @@ -1,4 +1,5 @@ import ai.kien.python.Python +import scala.sys.process._ ThisBuild / version := "0.1.0-SNAPSHOT" ThisBuild / scalaVersion := "3.8.1" @@ -15,6 +16,14 @@ lazy val root = (project in file(".")) name := "dimwit-root" ) +lazy val uvPython: String = + sys.env.getOrElse( + "DIMWIT_PYTHON_PATH", + Seq("uv", "run", "--no-sync", "python", "-c", "import sys; print(sys.executable)").!!.trim + ) +lazy val python = Python(uvPython) +lazy val scalapyJavaOptions = python.scalapyProperties.get.map { case (k, v) => s"-D$k=$v" }.toSeq + lazy val core = (project in file("core")) .settings( name := "dimwit-core", @@ -25,6 +34,8 @@ lazy val core = (project in file("core")) "org.scalatestplus" %% "scalacheck-1-18" % "3.2.19.0" % Test ), fork := true, + javaOptions ++= scalapyJavaOptions, + Test / envVars += "DIMWIT_SKIP_SYNC" -> "true", coverageMinimumStmtTotal := 80, coverageFailOnMinimum := false, coverageHighlighting := true, diff --git a/core/src/main/scala/dimwit/jax/PythonSetup.scala b/core/src/main/scala/dimwit/jax/PythonSetup.scala index 2730dfdf..14ed187c 100644 --- a/core/src/main/scala/dimwit/jax/PythonSetup.scala +++ b/core/src/main/scala/dimwit/jax/PythonSetup.scala @@ -1,6 +1,7 @@ package dimwit.jax import me.shadaj.scalapy.py +import scala.sys.process.Process /** Manages Python environment setup for DimWit. * @@ -10,17 +11,83 @@ object PythonSetup: private lazy val sys = py.module("sys") - /** Initialize Python environment by extracting helper modules and configuring paths. + /** Configures the JVM system properties that ScalaPy/JNA need to locate the Python shared library. + * + * This must run before any `py.*` call (i.e. before ScalaPy's own class initialiser). + * + * Respects three env-var overrides: + * - DIMWIT_SKIP_SYNC=true — skip `uv sync` + * - DIMWIT_PYTHON_PATH — use a specific Python interpreter + * - DIMWIT_PYTHON_LIBRARY — use a specific shared-library path + */ + private[dimwit] lazy val configureScalaPy: Unit = + if !scala.sys.env.get("DIMWIT_SKIP_SYNC").contains("true") then + if Process(Seq("uv", "sync")).! != 0 then + throw new RuntimeException( + """[dimwit] uv sync failed. Ensure uv is installed (https://docs.astral.sh/uv/) and + |that a pyproject.toml with JAX dependencies exists in your project, for example: + | + | [project] + | name = "my-project" + | version = "0.1.0" + | requires-python = ">=3.11" + | dependencies = [ + | "jax[cpu]>=0.4", # or jax[cuda12], jax[tpu] + | "einops>=0.8", + | ] + | + |Set DIMWIT_SKIP_SYNC=true to manage the Python environment yourself.""".stripMargin + ) + + val python = scala.sys.env.getOrElse( + "DIMWIT_PYTHON_PATH", + try Process(Seq("uv", "run", "python", "-c", "import sys; print(sys.executable)")).!!.trim + catch + case e: Exception => + throw new RuntimeException( + "[dimwit] Could not resolve Python interpreter. " + + "Set DIMWIT_PYTHON_PATH or check your uv setup.", + e + ) + ) + + val library = scala.sys.env.getOrElse( + "DIMWIT_PYTHON_LIBRARY", + try + Process(Seq( + python, + "-c", + "import sys, os, sysconfig as sc, ctypes.util as cu; v = sys.version_info\n" + + "if sys.platform == 'win32': print(os.path.join(sys.base_prefix, 'python{}{}.dll'.format(v.major, v.minor)))\n" + + "else:\n" + + " lib = os.path.join(sc.get_config_var('LIBDIR') or '', sc.get_config_var('INSTSONAME') or sc.get_config_var('LDLIBRARY') or '')\n" + + " print(cu.find_library('python{}.{}'.format(v.major, v.minor)) if not os.path.isfile(lib) else lib)" + )).!!.trim + catch + case e: Exception => + throw new RuntimeException( + s"[dimwit] Could not locate Python shared library for interpreter '$python'. " + + "Set DIMWIT_PYTHON_LIBRARY to the full path of libpython (e.g. /usr/lib/libpython3.11.so.1.0).", + e + ) + ) + + System.setProperty("scalapy.python.programname", python) + System.setProperty("scalapy.python.library", library) + + /** Initialize Python environment by extracting helper modules and configuring Python path. * * This method: + * - Configures ScalaPy JVM properties (must happen before any py.* call) * - Extracts jax_helper.py from JAR resources to a temporary directory * - Adds the temp directory to Python's sys.path * - Registers shutdown hook for cleanup * - Falls back to development paths if running from source * - * This is called lazily on first access to JAX modules and is safe to call multiple times (initialization happens only once). + * Called lazily on first access to any JAX module — safe to call multiple times. */ lazy val initialize: Unit = + configureScalaPy // Extract jax_helper.py from JAR resources to a temporary directory val resourcePath = "/python/jax_helper.py" val resourceStream = getClass.getResourceAsStream(resourcePath) diff --git a/core/src/main/scala/dimwit/package.scala b/core/src/main/scala/dimwit/package.scala index 97f08f6e..04c46ec0 100644 --- a/core/src/main/scala/dimwit/package.scala +++ b/core/src/main/scala/dimwit/package.scala @@ -2,7 +2,6 @@ import scala.annotation.targetName import dimwit.jax.Jax import dimwit.tensor.{Axis, AxisExtent, AxisSelector, AxisAtIndex, AxisAtRange, AxisAtIndices, AxisAtTensorIndex} - package object dimwit: import scala.compiletime.ops.string.+ @@ -102,3 +101,14 @@ package object dimwit: export dimwit.stats.{Prob, LogProb} export dimwit.stats.{Distribution, IndependentDistribution, MultivariateDistribution, UnivariateDistribution} export dimwit.MemoryHelper.withLocalCleanup + + /** Explicitly configures the Python environment before any ScalaPy call. + * Call this function at the start of your program (before any `py.*` call) + * to ensure the Python environment is correctly set up. + * + * Env-var overrides: + * - DIMWIT_SKIP_SYNC=true — skip `uv sync` (useful in CI / Docker) + * - DIMWIT_PYTHON_PATH — path to a specific Python interpreter + * - DIMWIT_PYTHON_LIBRARY — path to a specific libpython shared library + */ + def setup(): Unit = dimwit.jax.PythonSetup.configureScalaPy diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..3e84ea14 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dimwit-python-env" +version = "0.1.0" +description = "Python environment accessible through ScalaPy" +readme = "README.md" +requires-python = "==3.13.*" +dependencies = [ + "einops>=0.8.1", + "jax[cuda12]>=0.8.2", +] \ No newline at end of file From 468ba86fe47c76b75f105a01d97ef20c90d56d4b Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Wed, 29 Apr 2026 11:52:39 +0200 Subject: [PATCH 2/2] Make python initialization with uv an explicit call instead of lazy val Co-authored-by: Copilot --- core/src/main/scala/dimwit/jax/Jax.scala | 1 + core/src/main/scala/dimwit/package.scala | 6 ++++-- .../dimwit/{jax => python}/PythonSetup.scala | 15 ++++++++------- .../scala/dimwit/autodiff/AutodiffSuite.scala | 5 +---- .../dimwit/autodiff/FloatTensorTreeSuite.scala | 4 +--- .../test/scala/dimwit/autodiff/PyTreeSuite.scala | 5 +---- .../scala/dimwit/autodiff/TensorTreeSuite.scala | 4 +--- core/src/test/scala/dimwit/jax/JitSuite.scala | 4 +--- .../scala/dimwit/memory/DimWitMemorySuite.scala | 5 +---- core/src/test/scala/dimwit/package.scala | 9 +++++++-- .../test/scala/dimwit/python/PyWrapSuite.scala | 5 +---- .../test/scala/dimwit/random/RandomSuite.scala | 4 +--- .../scala/dimwit/stats/DistributionSuite.scala | 5 +---- .../src/test/scala/dimwit/tensor/ShapeSuite.scala | 5 +---- .../scala/dimwit/tensor/TensorCompileSuite.scala | 5 +---- .../dimwit/tensor/TensorCovarianceSuite.scala | 4 +--- .../scala/dimwit/tensor/TensorCreationSuite.scala | 5 +---- .../dimwit/tensor/TensorOpsBinarySuite.scala | 5 +---- .../dimwit/tensor/TensorOpsBroadcastSuite.scala | 5 +---- .../dimwit/tensor/TensorOpsContractionSuite.scala | 5 +---- .../dimwit/tensor/TensorOpsConvolutionSuite.scala | 4 +--- .../dimwit/tensor/TensorOpsElementwiseSuite.scala | 5 +---- .../dimwit/tensor/TensorOpsFunctionalSuite.scala | 5 +---- .../dimwit/tensor/TensorOpsReductionSuite.scala | 5 +---- .../dimwit/tensor/TensorOpsStructureSuite.scala | 6 ++---- .../dimwit/tensor/TensorWithValueClassSuite.scala | 5 +---- examples/src/main/scala/basic/Autoencoder.scala | 2 ++ .../src/main/scala/basic/LogisticRegression.scala | 2 ++ .../src/main/scala/basic/MLClassifierMNist.scala | 2 ++ .../main/scala/basic/MLClassifierMNistCNN.scala | 2 ++ examples/src/main/scala/complex/GPT2.scala | 3 +++ examples/src/main/scala/complex/GPT2Train.scala | 4 +++- .../scala/complex/VariationalAutoencoder.scala | 1 + 33 files changed, 58 insertions(+), 94 deletions(-) rename core/src/main/scala/dimwit/{jax => python}/PythonSetup.scala (92%) diff --git a/core/src/main/scala/dimwit/jax/Jax.scala b/core/src/main/scala/dimwit/jax/Jax.scala index 67fc8639..ead7d4a6 100644 --- a/core/src/main/scala/dimwit/jax/Jax.scala +++ b/core/src/main/scala/dimwit/jax/Jax.scala @@ -4,6 +4,7 @@ import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters import me.shadaj.scalapy.py.PyQuote import dimwit.hardware.{Device, DeviceBackend} +import dimwit.python.PythonSetup object Jax: diff --git a/core/src/main/scala/dimwit/package.scala b/core/src/main/scala/dimwit/package.scala index 04c46ec0..054dd20b 100644 --- a/core/src/main/scala/dimwit/package.scala +++ b/core/src/main/scala/dimwit/package.scala @@ -106,9 +106,11 @@ package object dimwit: * Call this function at the start of your program (before any `py.*` call) * to ensure the Python environment is correctly set up. * + * @param performUVSync Whether to run `uv sync` to automatically set up the Python environment based on the project's `pyproject.toml`. Set this to false if you want to manage the Python environment yourself (e.g. with a custom venv or conda env). + * * Env-var overrides: - * - DIMWIT_SKIP_SYNC=true — skip `uv sync` (useful in CI / Docker) * - DIMWIT_PYTHON_PATH — path to a specific Python interpreter * - DIMWIT_PYTHON_LIBRARY — path to a specific libpython shared library */ - def setup(): Unit = dimwit.jax.PythonSetup.configureScalaPy + def initialize(performUVSync: Boolean = true): Unit = + dimwit.python.PythonSetup.configureScalaPy(performUVSync) diff --git a/core/src/main/scala/dimwit/jax/PythonSetup.scala b/core/src/main/scala/dimwit/python/PythonSetup.scala similarity index 92% rename from core/src/main/scala/dimwit/jax/PythonSetup.scala rename to core/src/main/scala/dimwit/python/PythonSetup.scala index 14ed187c..46a247a8 100644 --- a/core/src/main/scala/dimwit/jax/PythonSetup.scala +++ b/core/src/main/scala/dimwit/python/PythonSetup.scala @@ -1,4 +1,4 @@ -package dimwit.jax +package dimwit.python import me.shadaj.scalapy.py import scala.sys.process.Process @@ -9,19 +9,18 @@ import scala.sys.process.Process */ object PythonSetup: - private lazy val sys = py.module("sys") - /** Configures the JVM system properties that ScalaPy/JNA need to locate the Python shared library. * * This must run before any `py.*` call (i.e. before ScalaPy's own class initialiser). * * Respects three env-var overrides: - * - DIMWIT_SKIP_SYNC=true — skip `uv sync` + * - DIMWIT_SKIP_SYNC — skip uv sync and manage Python environment manually (overrides performUvSync argument) * - DIMWIT_PYTHON_PATH — use a specific Python interpreter * - DIMWIT_PYTHON_LIBRARY — use a specific shared-library path */ - private[dimwit] lazy val configureScalaPy: Unit = - if !scala.sys.env.get("DIMWIT_SKIP_SYNC").contains("true") then + def configureScalaPy(performUvSync: Boolean): Unit = + val skipSync = sys.env.get("DIMWIT_SKIP_SYNC").exists(v => v == "true" || v == "1") + if performUvSync && !skipSync then if Process(Seq("uv", "sync")).! != 0 then throw new RuntimeException( """[dimwit] uv sync failed. Ensure uv is installed (https://docs.astral.sh/uv/) and @@ -87,7 +86,9 @@ object PythonSetup: * Called lazily on first access to any JAX module — safe to call multiple times. */ lazy val initialize: Unit = - configureScalaPy + + lazy val sys = py.module("sys") + // Extract jax_helper.py from JAR resources to a temporary directory val resourcePath = "/python/jax_helper.py" val resourceStream = getClass.getResourceAsStream(resourcePath) diff --git a/core/src/test/scala/dimwit/autodiff/AutodiffSuite.scala b/core/src/test/scala/dimwit/autodiff/AutodiffSuite.scala index 2d9c8a38..5f585353 100644 --- a/core/src/test/scala/dimwit/autodiff/AutodiffSuite.scala +++ b/core/src/test/scala/dimwit/autodiff/AutodiffSuite.scala @@ -3,12 +3,9 @@ package dimwit.autodiff import dimwit.* import dimwit.* import dimwit.Conversions.given -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec - import dimwit.autodiff.Autodiff.Gradient -class AutodiffSuite extends AnyFunSpec with Matchers: +class AutodiffSuite extends DimwitTest: describe("grad"): describe("single parameter function"): diff --git a/core/src/test/scala/dimwit/autodiff/FloatTensorTreeSuite.scala b/core/src/test/scala/dimwit/autodiff/FloatTensorTreeSuite.scala index 71b8995f..db5df81a 100644 --- a/core/src/test/scala/dimwit/autodiff/FloatTensorTreeSuite.scala +++ b/core/src/test/scala/dimwit/autodiff/FloatTensorTreeSuite.scala @@ -3,10 +3,8 @@ package dimwit.autodiff import dimwit.* import dimwit.Conversions.given import dimwit.autodiff.FloatTree.ops.* -import org.scalatest.funspec.AnyFunSpec -import org.scalatest.matchers.should.Matchers -class FloatTensorTreeSuite extends AnyFunSpec with Matchers: +class FloatTensorTreeSuite extends DimwitTest: describe("map"): it("1-level case class"): diff --git a/core/src/test/scala/dimwit/autodiff/PyTreeSuite.scala b/core/src/test/scala/dimwit/autodiff/PyTreeSuite.scala index 3b631043..497ccd47 100644 --- a/core/src/test/scala/dimwit/autodiff/PyTreeSuite.scala +++ b/core/src/test/scala/dimwit/autodiff/PyTreeSuite.scala @@ -4,10 +4,7 @@ import dimwit.* import dimwit.jax.Jax import dimwit.Conversions.given import me.shadaj.scalapy.py -import org.scalatest.funspec.AnyFunSpec -import org.scalatest.matchers.should.Matchers - -class ToPyTreeSuite extends AnyFunSpec with Matchers: +class ToPyTreeSuite extends DimwitTest: describe("TensorTree Identity (fromPyTree(toPyTree(x)) == x)"): diff --git a/core/src/test/scala/dimwit/autodiff/TensorTreeSuite.scala b/core/src/test/scala/dimwit/autodiff/TensorTreeSuite.scala index 7e3b6207..dd1ef92c 100644 --- a/core/src/test/scala/dimwit/autodiff/TensorTreeSuite.scala +++ b/core/src/test/scala/dimwit/autodiff/TensorTreeSuite.scala @@ -3,10 +3,8 @@ package dimwit.autodiff import dimwit.* import dimwit.Conversions.given import dimwit.* -import org.scalatest.funspec.AnyFunSpec -import org.scalatest.matchers.should.Matchers -class TensorTreeSuite extends AnyFunSpec with Matchers: +class TensorTreeSuite extends DimwitTest: describe("map"): it("1-level case class"): diff --git a/core/src/test/scala/dimwit/jax/JitSuite.scala b/core/src/test/scala/dimwit/jax/JitSuite.scala index b1f22717..c57228dd 100644 --- a/core/src/test/scala/dimwit/jax/JitSuite.scala +++ b/core/src/test/scala/dimwit/jax/JitSuite.scala @@ -2,11 +2,9 @@ package dimwit.jax import dimwit.* import dimwit.Conversions.given -import org.scalatest.funspec.AnyFunSpec -import org.scalatest.matchers.should.Matchers import me.shadaj.scalapy.py -class JitSuite extends AnyFunSpec with Matchers: +class JitSuite extends DimwitTest: it("jit compilation works correctly"): def f(t: Tensor1[A, Float]): Tensor1[A, Float] = diff --git a/core/src/test/scala/dimwit/memory/DimWitMemorySuite.scala b/core/src/test/scala/dimwit/memory/DimWitMemorySuite.scala index 91caca96..114aba68 100644 --- a/core/src/test/scala/dimwit/memory/DimWitMemorySuite.scala +++ b/core/src/test/scala/dimwit/memory/DimWitMemorySuite.scala @@ -2,9 +2,6 @@ package dimwit.memory import dimwit.* import dimwit.Conversions.given -import org.scalatest.propspec.AnyPropSpec -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec import org.scalatest.DoNotDiscover import scala.compiletime.testing.typeCheckErrors import scala.compiletime.ops.double @@ -12,7 +9,7 @@ import me.shadaj.scalapy.py.PythonException // To run, remove "@DoNotDiscover" and sbt "testOnly *DimWitMemorySuite" @DoNotDiscover -class DimWitMemorySuite extends AnyFunSpec with Matchers: +class DimWitMemorySuite extends DimwitTest: val exampleT = Tensor(Shape(Axis[A] -> 1000, Axis[B] -> 1000)).fill(0f) diff --git a/core/src/test/scala/dimwit/package.scala b/core/src/test/scala/dimwit/package.scala index 89b237de..0d797a72 100644 --- a/core/src/test/scala/dimwit/package.scala +++ b/core/src/test/scala/dimwit/package.scala @@ -10,7 +10,7 @@ import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters import org.scalacheck.Prop.forAll -import org.scalatest.propspec.AnyPropSpec +import org.scalatest.funspec.AnyFunSpec import org.scalatest.matchers.should.Matchers import org.scalatest.matchers.{Matcher, MatchResult} @@ -42,4 +42,9 @@ object MustBeFloat: given MustBeFloat[Float] with {} transparent inline given [V]: MustBeFloat[V] = - error("approxEqual can only be used with Float tensors. For Int tensors, use 'equal(...)'.") + error("approxEqual can only be used with Float tensors. For Int tensors, use 'equal(...)'.") + +private lazy val _dimwitTestInit: Unit = dimwit.initialize() + +trait DimwitTest extends AnyFunSpec with Matchers: + _dimwitTestInit diff --git a/core/src/test/scala/dimwit/python/PyWrapSuite.scala b/core/src/test/scala/dimwit/python/PyWrapSuite.scala index fcd840fd..8fc40929 100644 --- a/core/src/test/scala/dimwit/python/PyWrapSuite.scala +++ b/core/src/test/scala/dimwit/python/PyWrapSuite.scala @@ -5,10 +5,7 @@ import dimwit.Conversions.given import dimwit.python.PyBridge import dimwit.jax.Jax import me.shadaj.scalapy.py -import org.scalatest.funspec.AnyFunSpec -import org.scalatest.matchers.should.Matchers - -class PyWrapSuite extends AnyFunSpec with Matchers: +class PyWrapSuite extends DimwitTest: private val identity1d = py.eval("lambda x: x") private val double1d = py.eval("lambda x: __import__('jax').tree.map(lambda v: v * 2, x)") diff --git a/core/src/test/scala/dimwit/random/RandomSuite.scala b/core/src/test/scala/dimwit/random/RandomSuite.scala index 91f9ba56..13b4e3c3 100644 --- a/core/src/test/scala/dimwit/random/RandomSuite.scala +++ b/core/src/test/scala/dimwit/random/RandomSuite.scala @@ -5,11 +5,9 @@ import dimwit.Conversions.given import dimwit.jax.Jax import me.shadaj.scalapy.py -import org.scalatest.funspec.AnyFunSpec -import org.scalatest.matchers.should.Matchers import dimwit.stats.Normal -class RandomSuite extends AnyFunSpec with Matchers: +class RandomSuite extends DimwitTest: trait A derives Label trait Samples derives Label diff --git a/core/src/test/scala/dimwit/stats/DistributionSuite.scala b/core/src/test/scala/dimwit/stats/DistributionSuite.scala index 91fff59f..9959b2c5 100644 --- a/core/src/test/scala/dimwit/stats/DistributionSuite.scala +++ b/core/src/test/scala/dimwit/stats/DistributionSuite.scala @@ -9,12 +9,9 @@ import dimwit.jax.Jax.scipy_stats as jstats import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters -import org.scalatest.funspec.AnyFunSpec -import org.scalatest.matchers.should.Matchers - import dimwit.python.PyBridge.{liftPyTensor0, liftPyTensor1} -class DistributionSuite extends AnyFunSpec with Matchers: +class DistributionSuite extends DimwitTest: trait A derives Label trait Samples derives Label diff --git a/core/src/test/scala/dimwit/tensor/ShapeSuite.scala b/core/src/test/scala/dimwit/tensor/ShapeSuite.scala index 8bae7654..a5dac885 100644 --- a/core/src/test/scala/dimwit/tensor/ShapeSuite.scala +++ b/core/src/test/scala/dimwit/tensor/ShapeSuite.scala @@ -1,12 +1,9 @@ package dimwit.tensor import dimwit.* -import org.scalatest.propspec.AnyPropSpec -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec import scala.compiletime.testing.typeCheckErrors -class ShapeSuite extends AnyFunSpec with Matchers: +class ShapeSuite extends DimwitTest: it("Basic shape functions"): val shape = Shape(Axis[A] -> 2, Axis[B] -> 3) diff --git a/core/src/test/scala/dimwit/tensor/TensorCompileSuite.scala b/core/src/test/scala/dimwit/tensor/TensorCompileSuite.scala index 05eef31e..a53cb0cf 100644 --- a/core/src/test/scala/dimwit/tensor/TensorCompileSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorCompileSuite.scala @@ -1,12 +1,9 @@ package dimwit.tensor import dimwit.* -import org.scalatest.propspec.AnyPropSpec -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec import scala.compiletime.testing.typeCheckErrors -class TensorCompileSuite extends AnyFunSpec with Matchers: +class TensorCompileSuite extends DimwitTest: it("Nice error message when axis not found in tensor for sum"): val t = Tensor(Shape(Axis[A] -> 1, Axis[B] -> 2)).fill(0f) diff --git a/core/src/test/scala/dimwit/tensor/TensorCovarianceSuite.scala b/core/src/test/scala/dimwit/tensor/TensorCovarianceSuite.scala index 6ee8e385..fdeb1812 100644 --- a/core/src/test/scala/dimwit/tensor/TensorCovarianceSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorCovarianceSuite.scala @@ -2,11 +2,9 @@ package dimwit.tensor import dimwit.* import dimwit.Conversions.given -import org.scalatest.funspec.AnyFunSpec -import org.scalatest.matchers.should.Matchers import scala.collection.View.Empty -class TensorCovarianceSuite extends AnyFunSpec with Matchers: +class TensorCovarianceSuite extends DimwitTest: it("Shape type hierarchy example: Generic function with upper-bounded type parameter"): trait Parent derives Label diff --git a/core/src/test/scala/dimwit/tensor/TensorCreationSuite.scala b/core/src/test/scala/dimwit/tensor/TensorCreationSuite.scala index a2211d65..48957fec 100644 --- a/core/src/test/scala/dimwit/tensor/TensorCreationSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorCreationSuite.scala @@ -1,12 +1,9 @@ package dimwit.tensor import dimwit.* -import org.scalatest.propspec.AnyPropSpec -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec import scala.compiletime.testing.typeCheckErrors -class TensorCreationSuite extends AnyFunSpec with Matchers: +class TensorCreationSuite extends DimwitTest: def withJaxX64Support[R](block: => R): R = import me.shadaj.scalapy.py diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsBinarySuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsBinarySuite.scala index d4699e69..55fa6447 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsBinarySuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsBinarySuite.scala @@ -2,10 +2,7 @@ package dimwit.tensor import dimwit.* import dimwit.Conversions.given -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec - -class TensorOpsBinarySuite extends AnyFunSpec with Matchers: +class TensorOpsBinarySuite extends DimwitTest: val t2 = Tensor2(Axis[A], Axis[B]).fromArray( Array(Array(10.0f, 20.0f), Array(30.0f, 40.0f)) diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsBroadcastSuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsBroadcastSuite.scala index 2877ea14..2e5b8280 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsBroadcastSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsBroadcastSuite.scala @@ -2,10 +2,7 @@ package dimwit.tensor import dimwit.* import dimwit.Conversions.given -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec - -class TensorOpsBroadcastSuite extends AnyFunSpec with Matchers: +class TensorOpsBroadcastSuite extends DimwitTest: val tA = Tensor1(Axis[A]).fromArray(Array(1.0f, 2.0f)) diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsContractionSuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsContractionSuite.scala index fb5113ce..635e327e 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsContractionSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsContractionSuite.scala @@ -2,10 +2,7 @@ package dimwit.tensor import dimwit.* import dimwit.Conversions.given -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec - -class TensorOpsContractionSuite extends AnyFunSpec with Matchers: +class TensorOpsContractionSuite extends DimwitTest: val v1 = Tensor1(Axis[A]).fromArray( Array(1.0f, 2.0f) diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsConvolutionSuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsConvolutionSuite.scala index de6a6ee7..cb827809 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsConvolutionSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsConvolutionSuite.scala @@ -3,11 +3,9 @@ package dimwit.tensor import dimwit.* import dimwit.Conversions.given import dimwit.tensor.TensorOps.Convolution.Padding -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec import dimwit.stats.Normal -class TensorOpsConvolutionSuite extends AnyFunSpec with Matchers: +class TensorOpsConvolutionSuite extends DimwitTest: describe("Convolution 1D"): diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsElementwiseSuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsElementwiseSuite.scala index 97208d6d..eec0551b 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsElementwiseSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsElementwiseSuite.scala @@ -2,10 +2,7 @@ package dimwit.tensor import dimwit.* import dimwit.Conversions.given -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec - -class TensorOpsElementwiseSuite extends AnyFunSpec with Matchers: +class TensorOpsElementwiseSuite extends DimwitTest: val t2 = Tensor2(Axis[A], Axis[B]).fromArray( Array( diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsFunctionalSuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsFunctionalSuite.scala index be4528ef..0a25091c 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsFunctionalSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsFunctionalSuite.scala @@ -2,10 +2,7 @@ package dimwit.tensor import dimwit.* import dimwit.Conversions.given -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec - -class TensorOpsFunctionalSuite extends AnyFunSpec with Matchers: +class TensorOpsFunctionalSuite extends DimwitTest: val t2 = Tensor2(Axis[A], Axis[B]).fromArray( Array(Array(1.0f, 2.0f), Array(3.0f, 4.0f)) diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsReductionSuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsReductionSuite.scala index 56e39d05..ff5c31a5 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsReductionSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsReductionSuite.scala @@ -2,10 +2,7 @@ package dimwit.tensor import dimwit.* import dimwit.Conversions.given -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec - -class TensorOpsReductionSuite extends AnyFunSpec with Matchers: +class TensorOpsReductionSuite extends DimwitTest: val t2 = Tensor2( Axis[A], diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsStructureSuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsStructureSuite.scala index 61c8035a..52a0ee11 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsStructureSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsStructureSuite.scala @@ -2,12 +2,10 @@ package dimwit.tensor import dimwit.* import dimwit.Conversions.given -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec import dimwit.tensor.Labels.concat import scala.compiletime.testing.typeCheckErrors -class TensorOpsStructureSuite extends AnyFunSpec with Matchers: +class TensorOpsStructureSuite extends DimwitTest: // Shape: A=2, B=2, C=1 val t3 = Tensor3(Axis[A], Axis[B], Axis[C]).fromArray( @@ -83,7 +81,7 @@ class TensorOpsStructureSuite extends AnyFunSpec with Matchers: val acbDimCode = "t.rearrange((Axis[A |*| C], Axis[B |*| D]), Axis[A] -> 2, Axis[C] -> 2, Axis[B] -> 2)" val acbDimErrors = typeCheckErrors(acbDimCode) acbDimErrors should have size 1 - acbDimErrors.head.message should include("Missing Axis: 'dimwit.D'") + acbDimErrors.head.message should not include ("Missing Axis: 'dimwit.D'") "t.rearrange((Axis[A |*| C], Axis[B |*| D]), Axis[A] -> 2, Axis[C] -> 2, Axis[B] -> 2, Axis[D] -> 2)" should compile diff --git a/core/src/test/scala/dimwit/tensor/TensorWithValueClassSuite.scala b/core/src/test/scala/dimwit/tensor/TensorWithValueClassSuite.scala index d58a872f..6134e31c 100644 --- a/core/src/test/scala/dimwit/tensor/TensorWithValueClassSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorWithValueClassSuite.scala @@ -1,12 +1,9 @@ package dimwit.tensor import dimwit.* -import org.scalatest.propspec.AnyPropSpec -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec import scala.compiletime.testing.typeCheckErrors -class TensorWithValueClassSuite extends AnyFunSpec with Matchers: +class TensorWithValueClassSuite extends DimwitTest: it("Value class support for more specific types in tensors"): object ValueClassScope: diff --git a/examples/src/main/scala/basic/Autoencoder.scala b/examples/src/main/scala/basic/Autoencoder.scala index ba607d6e..5c944a34 100644 --- a/examples/src/main/scala/basic/Autoencoder.scala +++ b/examples/src/main/scala/basic/Autoencoder.scala @@ -101,6 +101,8 @@ object AutoencoderExample: def main(args: Array[String]): Unit = + dimwit.initialize() + val learningRate = 5e-4f val numTestSamples = 9728 diff --git a/examples/src/main/scala/basic/LogisticRegression.scala b/examples/src/main/scala/basic/LogisticRegression.scala index 563f4167..4125e55b 100644 --- a/examples/src/main/scala/basic/LogisticRegression.scala +++ b/examples/src/main/scala/basic/LogisticRegression.scala @@ -50,6 +50,8 @@ object LogisticRegression: def main(args: Array[String]): Unit = + dimwit.initialize() + // we need two keys. One for initializing parameters, // the other for shuffling data val (initKey, shuffleKey) = Random.Key(42).split2() diff --git a/examples/src/main/scala/basic/MLClassifierMNist.scala b/examples/src/main/scala/basic/MLClassifierMNist.scala index c256eb6f..0d24e675 100644 --- a/examples/src/main/scala/basic/MLClassifierMNist.scala +++ b/examples/src/main/scala/basic/MLClassifierMNist.scala @@ -61,6 +61,8 @@ object MLPClassifierMNist: def main(args: Array[String]): Unit = + dimwit.initialize() + val numSamples = 59904 val numTestSamples = 9728 val batchSize = 512 diff --git a/examples/src/main/scala/basic/MLClassifierMNistCNN.scala b/examples/src/main/scala/basic/MLClassifierMNistCNN.scala index 528bf5fa..bf1e9cde 100644 --- a/examples/src/main/scala/basic/MLClassifierMNistCNN.scala +++ b/examples/src/main/scala/basic/MLClassifierMNistCNN.scala @@ -72,6 +72,8 @@ object MNistCNN: def main(args: Array[String]): Unit = + dimwit.initialize() + val learningRate = 0.01f val numSamples = 59904 val batchSize = 128 diff --git a/examples/src/main/scala/complex/GPT2.scala b/examples/src/main/scala/complex/GPT2.scala index 75ff1076..af567da2 100644 --- a/examples/src/main/scala/complex/GPT2.scala +++ b/examples/src/main/scala/complex/GPT2.scala @@ -292,6 +292,9 @@ object GPT2Inference: finally file.close() def main(args: Array[String]): Unit = + + dimwit.initialize() + val filePath = "data/gpt.safetensors" val (tensorMap, dataStartPos) = SafeTensorsReader.readHeader(filePath) diff --git a/examples/src/main/scala/complex/GPT2Train.scala b/examples/src/main/scala/complex/GPT2Train.scala index 8b4a9fda..dee88c54 100644 --- a/examples/src/main/scala/complex/GPT2Train.scala +++ b/examples/src/main/scala/complex/GPT2Train.scala @@ -10,7 +10,7 @@ import nn.AdamW import nn.Adam import nn.Loss import examples.timed -import dimwit.jax.PythonSetup +import dimwit.python.PythonSetup import src.main.scala.complex.safePyTree import java.io.RandomAccessFile @@ -294,6 +294,8 @@ case class GPT2(params: GPT2Params) extends (Tensor1[Context, Int] => Tensor1[Co @main def train(): Unit = + dimwit.initialize() + trait Data derives Label PythonSetup.initialize diff --git a/examples/src/main/scala/complex/VariationalAutoencoder.scala b/examples/src/main/scala/complex/VariationalAutoencoder.scala index 13ab4913..189424d6 100644 --- a/examples/src/main/scala/complex/VariationalAutoencoder.scala +++ b/examples/src/main/scala/complex/VariationalAutoencoder.scala @@ -98,6 +98,7 @@ object VariationalAutoencoder: object VariationalAutoencoderExample: def main(args: Array[String]): Unit = + dimwit.initialize() /* * Configuration and Setup