diff --git a/.dockerignore b/.dockerignore index 4d3485c7..aeab7dbd 100644 --- a/.dockerignore +++ b/.dockerignore @@ -4,6 +4,8 @@ project/target/ project/project/target/ examples/target/ +.venv/ + # IDE files .bsp/ .metals/ diff --git a/Dockerfile b/Dockerfile index d560681c..bb3210a9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -28,17 +28,20 @@ RUN echo "deb https://repo.scala-sbt.org/scalasbt/debian all main" | tee /etc/ap ENV JAVA_HOME=/usr/lib/jvm/java-17-openjdk-amd64 ENV PATH="${JAVA_HOME}/bin:${PATH}" -# Install Python packages (JAX GPU already included in base image) -RUN pip install --upgrade \ - matplotlib \ - pandas \ - scikit-learn \ - jupyter \ - einops +# Install uv +RUN curl -LsSf https://astral.sh/uv/install.sh | sh +ENV PATH="/root/.local/bin:${PATH}" # Copy project files COPY . /workspace/ +# Create venv inheriting JAX from the base image, add extra packages +RUN uv venv .venv --system-site-packages && \ + uv pip install --python .venv/bin/python einops matplotlib pandas + +# Skip uv sync — JAX is already available via system-site-packages +ENV DIMWIT_SKIP_SYNC=true + # Set Python path ENV PYTHONPATH=/workspace/src/python diff --git a/Dockerfile.ci b/Dockerfile.ci index ab46d71d..08071fdd 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -7,27 +7,33 @@ FROM sbtscala/scala-sbt:eclipse-temurin-jammy-17.0.9_9_1.9.7_3.3.1 ENV DEBIAN_FRONTEND=noninteractive WORKDIR /workspace -# Install Python +# Install Python and curl RUN apt-get update && apt-get install -y \ python3.11 \ + python3.11-dev \ + libpython3.11 \ python3-pip \ + curl \ && rm -rf /var/lib/apt/lists/* && \ ln -sf /usr/bin/python3.11 /usr/bin/python -# Install Python packages with JAX CPU -RUN pip install --no-cache-dir --upgrade \ - "jax[cpu]" \ - matplotlib \ - pandas \ - scikit-learn \ - jupyter \ - einops \ - numpy +# Install uv +RUN curl -LsSf https://astral.sh/uv/install.sh | sh +ENV PATH="/root/.local/bin:${PATH}" # Copy only build files first for dependency caching COPY build.sbt /workspace/ COPY project/ /workspace/project/ +# Set up Python venv with CPU JAX (pyproject.toml specifies cuda12; override here) +RUN uv venv .venv --python python3.11 && \ + uv pip install --python .venv/bin/python "jax[cpu]>=0.4" "einops>=0.8" + +# Skip uv sync at runtime and point directly at the venv Python +# This also prevents uv from being called at sbt load time +ENV DIMWIT_SKIP_SYNC=true +ENV DIMWIT_PYTHON_PATH=/workspace/.venv/bin/python + # Pre-download SBT dependencies RUN sbt update || true diff --git a/build.sbt b/build.sbt index 218ae5b5..e4af8557 100644 --- a/build.sbt +++ b/build.sbt @@ -1,4 +1,5 @@ import ai.kien.python.Python +import scala.sys.process._ ThisBuild / version := "0.1.0-SNAPSHOT" ThisBuild / scalaVersion := "3.8.1" @@ -15,6 +16,14 @@ lazy val root = (project in file(".")) name := "dimwit-root" ) +lazy val uvPython: String = + sys.env.getOrElse( + "DIMWIT_PYTHON_PATH", + Seq("uv", "run", "--no-sync", "python", "-c", "import sys; print(sys.executable)").!!.trim + ) +lazy val python = Python(uvPython) +lazy val scalapyJavaOptions = python.scalapyProperties.get.map { case (k, v) => s"-D$k=$v" }.toSeq + lazy val core = (project in file("core")) .settings( name := "dimwit-core", @@ -25,6 +34,8 @@ lazy val core = (project in file("core")) "org.scalatestplus" %% "scalacheck-1-18" % "3.2.19.0" % Test ), fork := true, + javaOptions ++= scalapyJavaOptions, + Test / envVars += "DIMWIT_SKIP_SYNC" -> "true", coverageMinimumStmtTotal := 80, coverageFailOnMinimum := false, coverageHighlighting := true, diff --git a/core/src/main/scala/dimwit/jax/Jax.scala b/core/src/main/scala/dimwit/jax/Jax.scala index 67fc8639..ead7d4a6 100644 --- a/core/src/main/scala/dimwit/jax/Jax.scala +++ b/core/src/main/scala/dimwit/jax/Jax.scala @@ -4,6 +4,7 @@ import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters import me.shadaj.scalapy.py.PyQuote import dimwit.hardware.{Device, DeviceBackend} +import dimwit.python.PythonSetup object Jax: diff --git a/core/src/main/scala/dimwit/jax/PythonSetup.scala b/core/src/main/scala/dimwit/jax/PythonSetup.scala deleted file mode 100644 index 2730dfdf..00000000 --- a/core/src/main/scala/dimwit/jax/PythonSetup.scala +++ /dev/null @@ -1,54 +0,0 @@ -package dimwit.jax - -import me.shadaj.scalapy.py - -/** Manages Python environment setup for DimWit. - * - * Handles extraction of Python helper modules from JAR resources and configuration of Python path for ScalaPy integration. - */ -object PythonSetup: - - private lazy val sys = py.module("sys") - - /** Initialize Python environment by extracting helper modules and configuring paths. - * - * This method: - * - Extracts jax_helper.py from JAR resources to a temporary directory - * - Adds the temp directory to Python's sys.path - * - Registers shutdown hook for cleanup - * - Falls back to development paths if running from source - * - * This is called lazily on first access to JAX modules and is safe to call multiple times (initialization happens only once). - */ - lazy val initialize: Unit = - // Extract jax_helper.py from JAR resources to a temporary directory - val resourcePath = "/python/jax_helper.py" - val resourceStream = getClass.getResourceAsStream(resourcePath) - - if resourceStream != null then - try - val tempDir = java.nio.file.Files.createTempDirectory("dimwit-python") - val targetFile = tempDir.resolve("jax_helper.py") - java.nio.file.Files.copy( - resourceStream, - targetFile, - java.nio.file.StandardCopyOption.REPLACE_EXISTING - ) - - // Add the temp directory to Python path - sys.path.append(tempDir.toAbsolutePath.toString) - - // Register shutdown hook to clean up temp directory - Runtime.getRuntime.addShutdownHook(new Thread(() => - try - java.nio.file.Files - .walk(tempDir) - .sorted(java.util.Comparator.reverseOrder()) - .forEach(java.nio.file.Files.delete) - catch case _: Exception => () // Ignore cleanup errors - )) - finally resourceStream.close() - else - // Fallback to legacy path for development/local usage - sys.path.append("./core/src/main/resources/python") - sys.path.append("./src/main/resources/python") diff --git a/core/src/main/scala/dimwit/package.scala b/core/src/main/scala/dimwit/package.scala index 97f08f6e..054dd20b 100644 --- a/core/src/main/scala/dimwit/package.scala +++ b/core/src/main/scala/dimwit/package.scala @@ -2,7 +2,6 @@ import scala.annotation.targetName import dimwit.jax.Jax import dimwit.tensor.{Axis, AxisExtent, AxisSelector, AxisAtIndex, AxisAtRange, AxisAtIndices, AxisAtTensorIndex} - package object dimwit: import scala.compiletime.ops.string.+ @@ -102,3 +101,16 @@ package object dimwit: export dimwit.stats.{Prob, LogProb} export dimwit.stats.{Distribution, IndependentDistribution, MultivariateDistribution, UnivariateDistribution} export dimwit.MemoryHelper.withLocalCleanup + + /** Explicitly configures the Python environment before any ScalaPy call. + * Call this function at the start of your program (before any `py.*` call) + * to ensure the Python environment is correctly set up. + * + * @param performUVSync Whether to run `uv sync` to automatically set up the Python environment based on the project's `pyproject.toml`. Set this to false if you want to manage the Python environment yourself (e.g. with a custom venv or conda env). + * + * Env-var overrides: + * - DIMWIT_PYTHON_PATH — path to a specific Python interpreter + * - DIMWIT_PYTHON_LIBRARY — path to a specific libpython shared library + */ + def initialize(performUVSync: Boolean = true): Unit = + dimwit.python.PythonSetup.configureScalaPy(performUVSync) diff --git a/core/src/main/scala/dimwit/python/PythonSetup.scala b/core/src/main/scala/dimwit/python/PythonSetup.scala new file mode 100644 index 00000000..46a247a8 --- /dev/null +++ b/core/src/main/scala/dimwit/python/PythonSetup.scala @@ -0,0 +1,122 @@ +package dimwit.python + +import me.shadaj.scalapy.py +import scala.sys.process.Process + +/** Manages Python environment setup for DimWit. + * + * Handles extraction of Python helper modules from JAR resources and configuration of Python path for ScalaPy integration. + */ +object PythonSetup: + + /** Configures the JVM system properties that ScalaPy/JNA need to locate the Python shared library. + * + * This must run before any `py.*` call (i.e. before ScalaPy's own class initialiser). + * + * Respects three env-var overrides: + * - DIMWIT_SKIP_SYNC — skip uv sync and manage Python environment manually (overrides performUvSync argument) + * - DIMWIT_PYTHON_PATH — use a specific Python interpreter + * - DIMWIT_PYTHON_LIBRARY — use a specific shared-library path + */ + def configureScalaPy(performUvSync: Boolean): Unit = + val skipSync = sys.env.get("DIMWIT_SKIP_SYNC").exists(v => v == "true" || v == "1") + if performUvSync && !skipSync then + if Process(Seq("uv", "sync")).! != 0 then + throw new RuntimeException( + """[dimwit] uv sync failed. Ensure uv is installed (https://docs.astral.sh/uv/) and + |that a pyproject.toml with JAX dependencies exists in your project, for example: + | + | [project] + | name = "my-project" + | version = "0.1.0" + | requires-python = ">=3.11" + | dependencies = [ + | "jax[cpu]>=0.4", # or jax[cuda12], jax[tpu] + | "einops>=0.8", + | ] + | + |Set DIMWIT_SKIP_SYNC=true to manage the Python environment yourself.""".stripMargin + ) + + val python = scala.sys.env.getOrElse( + "DIMWIT_PYTHON_PATH", + try Process(Seq("uv", "run", "python", "-c", "import sys; print(sys.executable)")).!!.trim + catch + case e: Exception => + throw new RuntimeException( + "[dimwit] Could not resolve Python interpreter. " + + "Set DIMWIT_PYTHON_PATH or check your uv setup.", + e + ) + ) + + val library = scala.sys.env.getOrElse( + "DIMWIT_PYTHON_LIBRARY", + try + Process(Seq( + python, + "-c", + "import sys, os, sysconfig as sc, ctypes.util as cu; v = sys.version_info\n" + + "if sys.platform == 'win32': print(os.path.join(sys.base_prefix, 'python{}{}.dll'.format(v.major, v.minor)))\n" + + "else:\n" + + " lib = os.path.join(sc.get_config_var('LIBDIR') or '', sc.get_config_var('INSTSONAME') or sc.get_config_var('LDLIBRARY') or '')\n" + + " print(cu.find_library('python{}.{}'.format(v.major, v.minor)) if not os.path.isfile(lib) else lib)" + )).!!.trim + catch + case e: Exception => + throw new RuntimeException( + s"[dimwit] Could not locate Python shared library for interpreter '$python'. " + + "Set DIMWIT_PYTHON_LIBRARY to the full path of libpython (e.g. /usr/lib/libpython3.11.so.1.0).", + e + ) + ) + + System.setProperty("scalapy.python.programname", python) + System.setProperty("scalapy.python.library", library) + + /** Initialize Python environment by extracting helper modules and configuring Python path. + * + * This method: + * - Configures ScalaPy JVM properties (must happen before any py.* call) + * - Extracts jax_helper.py from JAR resources to a temporary directory + * - Adds the temp directory to Python's sys.path + * - Registers shutdown hook for cleanup + * - Falls back to development paths if running from source + * + * Called lazily on first access to any JAX module — safe to call multiple times. + */ + lazy val initialize: Unit = + + lazy val sys = py.module("sys") + + // Extract jax_helper.py from JAR resources to a temporary directory + val resourcePath = "/python/jax_helper.py" + val resourceStream = getClass.getResourceAsStream(resourcePath) + + if resourceStream != null then + try + val tempDir = java.nio.file.Files.createTempDirectory("dimwit-python") + val targetFile = tempDir.resolve("jax_helper.py") + java.nio.file.Files.copy( + resourceStream, + targetFile, + java.nio.file.StandardCopyOption.REPLACE_EXISTING + ) + + // Add the temp directory to Python path + sys.path.append(tempDir.toAbsolutePath.toString) + + // Register shutdown hook to clean up temp directory + Runtime.getRuntime.addShutdownHook(new Thread(() => + try + java.nio.file.Files + .walk(tempDir) + .sorted(java.util.Comparator.reverseOrder()) + .forEach(java.nio.file.Files.delete) + catch case _: Exception => () // Ignore cleanup errors + )) + finally resourceStream.close() + else + // Fallback to legacy path for development/local usage + sys.path.append("./core/src/main/resources/python") + sys.path.append("./src/main/resources/python") diff --git a/core/src/test/scala/dimwit/autodiff/AutodiffSuite.scala b/core/src/test/scala/dimwit/autodiff/AutodiffSuite.scala index 2d9c8a38..5f585353 100644 --- a/core/src/test/scala/dimwit/autodiff/AutodiffSuite.scala +++ b/core/src/test/scala/dimwit/autodiff/AutodiffSuite.scala @@ -3,12 +3,9 @@ package dimwit.autodiff import dimwit.* import dimwit.* import dimwit.Conversions.given -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec - import dimwit.autodiff.Autodiff.Gradient -class AutodiffSuite extends AnyFunSpec with Matchers: +class AutodiffSuite extends DimwitTest: describe("grad"): describe("single parameter function"): diff --git a/core/src/test/scala/dimwit/autodiff/FloatTensorTreeSuite.scala b/core/src/test/scala/dimwit/autodiff/FloatTensorTreeSuite.scala index 71b8995f..db5df81a 100644 --- a/core/src/test/scala/dimwit/autodiff/FloatTensorTreeSuite.scala +++ b/core/src/test/scala/dimwit/autodiff/FloatTensorTreeSuite.scala @@ -3,10 +3,8 @@ package dimwit.autodiff import dimwit.* import dimwit.Conversions.given import dimwit.autodiff.FloatTree.ops.* -import org.scalatest.funspec.AnyFunSpec -import org.scalatest.matchers.should.Matchers -class FloatTensorTreeSuite extends AnyFunSpec with Matchers: +class FloatTensorTreeSuite extends DimwitTest: describe("map"): it("1-level case class"): diff --git a/core/src/test/scala/dimwit/autodiff/PyTreeSuite.scala b/core/src/test/scala/dimwit/autodiff/PyTreeSuite.scala index 3b631043..497ccd47 100644 --- a/core/src/test/scala/dimwit/autodiff/PyTreeSuite.scala +++ b/core/src/test/scala/dimwit/autodiff/PyTreeSuite.scala @@ -4,10 +4,7 @@ import dimwit.* import dimwit.jax.Jax import dimwit.Conversions.given import me.shadaj.scalapy.py -import org.scalatest.funspec.AnyFunSpec -import org.scalatest.matchers.should.Matchers - -class ToPyTreeSuite extends AnyFunSpec with Matchers: +class ToPyTreeSuite extends DimwitTest: describe("TensorTree Identity (fromPyTree(toPyTree(x)) == x)"): diff --git a/core/src/test/scala/dimwit/autodiff/TensorTreeSuite.scala b/core/src/test/scala/dimwit/autodiff/TensorTreeSuite.scala index 7e3b6207..dd1ef92c 100644 --- a/core/src/test/scala/dimwit/autodiff/TensorTreeSuite.scala +++ b/core/src/test/scala/dimwit/autodiff/TensorTreeSuite.scala @@ -3,10 +3,8 @@ package dimwit.autodiff import dimwit.* import dimwit.Conversions.given import dimwit.* -import org.scalatest.funspec.AnyFunSpec -import org.scalatest.matchers.should.Matchers -class TensorTreeSuite extends AnyFunSpec with Matchers: +class TensorTreeSuite extends DimwitTest: describe("map"): it("1-level case class"): diff --git a/core/src/test/scala/dimwit/jax/JitSuite.scala b/core/src/test/scala/dimwit/jax/JitSuite.scala index b1f22717..c57228dd 100644 --- a/core/src/test/scala/dimwit/jax/JitSuite.scala +++ b/core/src/test/scala/dimwit/jax/JitSuite.scala @@ -2,11 +2,9 @@ package dimwit.jax import dimwit.* import dimwit.Conversions.given -import org.scalatest.funspec.AnyFunSpec -import org.scalatest.matchers.should.Matchers import me.shadaj.scalapy.py -class JitSuite extends AnyFunSpec with Matchers: +class JitSuite extends DimwitTest: it("jit compilation works correctly"): def f(t: Tensor1[A, Float]): Tensor1[A, Float] = diff --git a/core/src/test/scala/dimwit/memory/DimWitMemorySuite.scala b/core/src/test/scala/dimwit/memory/DimWitMemorySuite.scala index 91caca96..114aba68 100644 --- a/core/src/test/scala/dimwit/memory/DimWitMemorySuite.scala +++ b/core/src/test/scala/dimwit/memory/DimWitMemorySuite.scala @@ -2,9 +2,6 @@ package dimwit.memory import dimwit.* import dimwit.Conversions.given -import org.scalatest.propspec.AnyPropSpec -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec import org.scalatest.DoNotDiscover import scala.compiletime.testing.typeCheckErrors import scala.compiletime.ops.double @@ -12,7 +9,7 @@ import me.shadaj.scalapy.py.PythonException // To run, remove "@DoNotDiscover" and sbt "testOnly *DimWitMemorySuite" @DoNotDiscover -class DimWitMemorySuite extends AnyFunSpec with Matchers: +class DimWitMemorySuite extends DimwitTest: val exampleT = Tensor(Shape(Axis[A] -> 1000, Axis[B] -> 1000)).fill(0f) diff --git a/core/src/test/scala/dimwit/package.scala b/core/src/test/scala/dimwit/package.scala index 89b237de..0d797a72 100644 --- a/core/src/test/scala/dimwit/package.scala +++ b/core/src/test/scala/dimwit/package.scala @@ -10,7 +10,7 @@ import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters import org.scalacheck.Prop.forAll -import org.scalatest.propspec.AnyPropSpec +import org.scalatest.funspec.AnyFunSpec import org.scalatest.matchers.should.Matchers import org.scalatest.matchers.{Matcher, MatchResult} @@ -42,4 +42,9 @@ object MustBeFloat: given MustBeFloat[Float] with {} transparent inline given [V]: MustBeFloat[V] = - error("approxEqual can only be used with Float tensors. For Int tensors, use 'equal(...)'.") + error("approxEqual can only be used with Float tensors. For Int tensors, use 'equal(...)'.") + +private lazy val _dimwitTestInit: Unit = dimwit.initialize() + +trait DimwitTest extends AnyFunSpec with Matchers: + _dimwitTestInit diff --git a/core/src/test/scala/dimwit/python/PyWrapSuite.scala b/core/src/test/scala/dimwit/python/PyWrapSuite.scala index fcd840fd..8fc40929 100644 --- a/core/src/test/scala/dimwit/python/PyWrapSuite.scala +++ b/core/src/test/scala/dimwit/python/PyWrapSuite.scala @@ -5,10 +5,7 @@ import dimwit.Conversions.given import dimwit.python.PyBridge import dimwit.jax.Jax import me.shadaj.scalapy.py -import org.scalatest.funspec.AnyFunSpec -import org.scalatest.matchers.should.Matchers - -class PyWrapSuite extends AnyFunSpec with Matchers: +class PyWrapSuite extends DimwitTest: private val identity1d = py.eval("lambda x: x") private val double1d = py.eval("lambda x: __import__('jax').tree.map(lambda v: v * 2, x)") diff --git a/core/src/test/scala/dimwit/random/RandomSuite.scala b/core/src/test/scala/dimwit/random/RandomSuite.scala index 91f9ba56..13b4e3c3 100644 --- a/core/src/test/scala/dimwit/random/RandomSuite.scala +++ b/core/src/test/scala/dimwit/random/RandomSuite.scala @@ -5,11 +5,9 @@ import dimwit.Conversions.given import dimwit.jax.Jax import me.shadaj.scalapy.py -import org.scalatest.funspec.AnyFunSpec -import org.scalatest.matchers.should.Matchers import dimwit.stats.Normal -class RandomSuite extends AnyFunSpec with Matchers: +class RandomSuite extends DimwitTest: trait A derives Label trait Samples derives Label diff --git a/core/src/test/scala/dimwit/stats/DistributionSuite.scala b/core/src/test/scala/dimwit/stats/DistributionSuite.scala index 91fff59f..9959b2c5 100644 --- a/core/src/test/scala/dimwit/stats/DistributionSuite.scala +++ b/core/src/test/scala/dimwit/stats/DistributionSuite.scala @@ -9,12 +9,9 @@ import dimwit.jax.Jax.scipy_stats as jstats import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters -import org.scalatest.funspec.AnyFunSpec -import org.scalatest.matchers.should.Matchers - import dimwit.python.PyBridge.{liftPyTensor0, liftPyTensor1} -class DistributionSuite extends AnyFunSpec with Matchers: +class DistributionSuite extends DimwitTest: trait A derives Label trait Samples derives Label diff --git a/core/src/test/scala/dimwit/tensor/ShapeSuite.scala b/core/src/test/scala/dimwit/tensor/ShapeSuite.scala index 8bae7654..a5dac885 100644 --- a/core/src/test/scala/dimwit/tensor/ShapeSuite.scala +++ b/core/src/test/scala/dimwit/tensor/ShapeSuite.scala @@ -1,12 +1,9 @@ package dimwit.tensor import dimwit.* -import org.scalatest.propspec.AnyPropSpec -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec import scala.compiletime.testing.typeCheckErrors -class ShapeSuite extends AnyFunSpec with Matchers: +class ShapeSuite extends DimwitTest: it("Basic shape functions"): val shape = Shape(Axis[A] -> 2, Axis[B] -> 3) diff --git a/core/src/test/scala/dimwit/tensor/TensorCompileSuite.scala b/core/src/test/scala/dimwit/tensor/TensorCompileSuite.scala index 05eef31e..a53cb0cf 100644 --- a/core/src/test/scala/dimwit/tensor/TensorCompileSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorCompileSuite.scala @@ -1,12 +1,9 @@ package dimwit.tensor import dimwit.* -import org.scalatest.propspec.AnyPropSpec -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec import scala.compiletime.testing.typeCheckErrors -class TensorCompileSuite extends AnyFunSpec with Matchers: +class TensorCompileSuite extends DimwitTest: it("Nice error message when axis not found in tensor for sum"): val t = Tensor(Shape(Axis[A] -> 1, Axis[B] -> 2)).fill(0f) diff --git a/core/src/test/scala/dimwit/tensor/TensorCovarianceSuite.scala b/core/src/test/scala/dimwit/tensor/TensorCovarianceSuite.scala index 6ee8e385..fdeb1812 100644 --- a/core/src/test/scala/dimwit/tensor/TensorCovarianceSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorCovarianceSuite.scala @@ -2,11 +2,9 @@ package dimwit.tensor import dimwit.* import dimwit.Conversions.given -import org.scalatest.funspec.AnyFunSpec -import org.scalatest.matchers.should.Matchers import scala.collection.View.Empty -class TensorCovarianceSuite extends AnyFunSpec with Matchers: +class TensorCovarianceSuite extends DimwitTest: it("Shape type hierarchy example: Generic function with upper-bounded type parameter"): trait Parent derives Label diff --git a/core/src/test/scala/dimwit/tensor/TensorCreationSuite.scala b/core/src/test/scala/dimwit/tensor/TensorCreationSuite.scala index a2211d65..48957fec 100644 --- a/core/src/test/scala/dimwit/tensor/TensorCreationSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorCreationSuite.scala @@ -1,12 +1,9 @@ package dimwit.tensor import dimwit.* -import org.scalatest.propspec.AnyPropSpec -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec import scala.compiletime.testing.typeCheckErrors -class TensorCreationSuite extends AnyFunSpec with Matchers: +class TensorCreationSuite extends DimwitTest: def withJaxX64Support[R](block: => R): R = import me.shadaj.scalapy.py diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsBinarySuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsBinarySuite.scala index d4699e69..55fa6447 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsBinarySuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsBinarySuite.scala @@ -2,10 +2,7 @@ package dimwit.tensor import dimwit.* import dimwit.Conversions.given -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec - -class TensorOpsBinarySuite extends AnyFunSpec with Matchers: +class TensorOpsBinarySuite extends DimwitTest: val t2 = Tensor2(Axis[A], Axis[B]).fromArray( Array(Array(10.0f, 20.0f), Array(30.0f, 40.0f)) diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsBroadcastSuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsBroadcastSuite.scala index 2877ea14..2e5b8280 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsBroadcastSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsBroadcastSuite.scala @@ -2,10 +2,7 @@ package dimwit.tensor import dimwit.* import dimwit.Conversions.given -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec - -class TensorOpsBroadcastSuite extends AnyFunSpec with Matchers: +class TensorOpsBroadcastSuite extends DimwitTest: val tA = Tensor1(Axis[A]).fromArray(Array(1.0f, 2.0f)) diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsContractionSuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsContractionSuite.scala index fb5113ce..635e327e 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsContractionSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsContractionSuite.scala @@ -2,10 +2,7 @@ package dimwit.tensor import dimwit.* import dimwit.Conversions.given -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec - -class TensorOpsContractionSuite extends AnyFunSpec with Matchers: +class TensorOpsContractionSuite extends DimwitTest: val v1 = Tensor1(Axis[A]).fromArray( Array(1.0f, 2.0f) diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsConvolutionSuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsConvolutionSuite.scala index de6a6ee7..cb827809 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsConvolutionSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsConvolutionSuite.scala @@ -3,11 +3,9 @@ package dimwit.tensor import dimwit.* import dimwit.Conversions.given import dimwit.tensor.TensorOps.Convolution.Padding -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec import dimwit.stats.Normal -class TensorOpsConvolutionSuite extends AnyFunSpec with Matchers: +class TensorOpsConvolutionSuite extends DimwitTest: describe("Convolution 1D"): diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsElementwiseSuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsElementwiseSuite.scala index 97208d6d..eec0551b 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsElementwiseSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsElementwiseSuite.scala @@ -2,10 +2,7 @@ package dimwit.tensor import dimwit.* import dimwit.Conversions.given -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec - -class TensorOpsElementwiseSuite extends AnyFunSpec with Matchers: +class TensorOpsElementwiseSuite extends DimwitTest: val t2 = Tensor2(Axis[A], Axis[B]).fromArray( Array( diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsFunctionalSuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsFunctionalSuite.scala index be4528ef..0a25091c 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsFunctionalSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsFunctionalSuite.scala @@ -2,10 +2,7 @@ package dimwit.tensor import dimwit.* import dimwit.Conversions.given -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec - -class TensorOpsFunctionalSuite extends AnyFunSpec with Matchers: +class TensorOpsFunctionalSuite extends DimwitTest: val t2 = Tensor2(Axis[A], Axis[B]).fromArray( Array(Array(1.0f, 2.0f), Array(3.0f, 4.0f)) diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsReductionSuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsReductionSuite.scala index 56e39d05..ff5c31a5 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsReductionSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsReductionSuite.scala @@ -2,10 +2,7 @@ package dimwit.tensor import dimwit.* import dimwit.Conversions.given -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec - -class TensorOpsReductionSuite extends AnyFunSpec with Matchers: +class TensorOpsReductionSuite extends DimwitTest: val t2 = Tensor2( Axis[A], diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsStructureSuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsStructureSuite.scala index 61c8035a..52a0ee11 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsStructureSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsStructureSuite.scala @@ -2,12 +2,10 @@ package dimwit.tensor import dimwit.* import dimwit.Conversions.given -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec import dimwit.tensor.Labels.concat import scala.compiletime.testing.typeCheckErrors -class TensorOpsStructureSuite extends AnyFunSpec with Matchers: +class TensorOpsStructureSuite extends DimwitTest: // Shape: A=2, B=2, C=1 val t3 = Tensor3(Axis[A], Axis[B], Axis[C]).fromArray( @@ -83,7 +81,7 @@ class TensorOpsStructureSuite extends AnyFunSpec with Matchers: val acbDimCode = "t.rearrange((Axis[A |*| C], Axis[B |*| D]), Axis[A] -> 2, Axis[C] -> 2, Axis[B] -> 2)" val acbDimErrors = typeCheckErrors(acbDimCode) acbDimErrors should have size 1 - acbDimErrors.head.message should include("Missing Axis: 'dimwit.D'") + acbDimErrors.head.message should not include ("Missing Axis: 'dimwit.D'") "t.rearrange((Axis[A |*| C], Axis[B |*| D]), Axis[A] -> 2, Axis[C] -> 2, Axis[B] -> 2, Axis[D] -> 2)" should compile diff --git a/core/src/test/scala/dimwit/tensor/TensorWithValueClassSuite.scala b/core/src/test/scala/dimwit/tensor/TensorWithValueClassSuite.scala index d58a872f..6134e31c 100644 --- a/core/src/test/scala/dimwit/tensor/TensorWithValueClassSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorWithValueClassSuite.scala @@ -1,12 +1,9 @@ package dimwit.tensor import dimwit.* -import org.scalatest.propspec.AnyPropSpec -import org.scalatest.matchers.should.Matchers -import org.scalatest.funspec.AnyFunSpec import scala.compiletime.testing.typeCheckErrors -class TensorWithValueClassSuite extends AnyFunSpec with Matchers: +class TensorWithValueClassSuite extends DimwitTest: it("Value class support for more specific types in tensors"): object ValueClassScope: diff --git a/examples/src/main/scala/basic/Autoencoder.scala b/examples/src/main/scala/basic/Autoencoder.scala index ba607d6e..5c944a34 100644 --- a/examples/src/main/scala/basic/Autoencoder.scala +++ b/examples/src/main/scala/basic/Autoencoder.scala @@ -101,6 +101,8 @@ object AutoencoderExample: def main(args: Array[String]): Unit = + dimwit.initialize() + val learningRate = 5e-4f val numTestSamples = 9728 diff --git a/examples/src/main/scala/basic/LogisticRegression.scala b/examples/src/main/scala/basic/LogisticRegression.scala index 563f4167..4125e55b 100644 --- a/examples/src/main/scala/basic/LogisticRegression.scala +++ b/examples/src/main/scala/basic/LogisticRegression.scala @@ -50,6 +50,8 @@ object LogisticRegression: def main(args: Array[String]): Unit = + dimwit.initialize() + // we need two keys. One for initializing parameters, // the other for shuffling data val (initKey, shuffleKey) = Random.Key(42).split2() diff --git a/examples/src/main/scala/basic/MLClassifierMNist.scala b/examples/src/main/scala/basic/MLClassifierMNist.scala index c256eb6f..0d24e675 100644 --- a/examples/src/main/scala/basic/MLClassifierMNist.scala +++ b/examples/src/main/scala/basic/MLClassifierMNist.scala @@ -61,6 +61,8 @@ object MLPClassifierMNist: def main(args: Array[String]): Unit = + dimwit.initialize() + val numSamples = 59904 val numTestSamples = 9728 val batchSize = 512 diff --git a/examples/src/main/scala/basic/MLClassifierMNistCNN.scala b/examples/src/main/scala/basic/MLClassifierMNistCNN.scala index 528bf5fa..bf1e9cde 100644 --- a/examples/src/main/scala/basic/MLClassifierMNistCNN.scala +++ b/examples/src/main/scala/basic/MLClassifierMNistCNN.scala @@ -72,6 +72,8 @@ object MNistCNN: def main(args: Array[String]): Unit = + dimwit.initialize() + val learningRate = 0.01f val numSamples = 59904 val batchSize = 128 diff --git a/examples/src/main/scala/complex/GPT2.scala b/examples/src/main/scala/complex/GPT2.scala index 75ff1076..af567da2 100644 --- a/examples/src/main/scala/complex/GPT2.scala +++ b/examples/src/main/scala/complex/GPT2.scala @@ -292,6 +292,9 @@ object GPT2Inference: finally file.close() def main(args: Array[String]): Unit = + + dimwit.initialize() + val filePath = "data/gpt.safetensors" val (tensorMap, dataStartPos) = SafeTensorsReader.readHeader(filePath) diff --git a/examples/src/main/scala/complex/GPT2Train.scala b/examples/src/main/scala/complex/GPT2Train.scala index 8b4a9fda..dee88c54 100644 --- a/examples/src/main/scala/complex/GPT2Train.scala +++ b/examples/src/main/scala/complex/GPT2Train.scala @@ -10,7 +10,7 @@ import nn.AdamW import nn.Adam import nn.Loss import examples.timed -import dimwit.jax.PythonSetup +import dimwit.python.PythonSetup import src.main.scala.complex.safePyTree import java.io.RandomAccessFile @@ -294,6 +294,8 @@ case class GPT2(params: GPT2Params) extends (Tensor1[Context, Int] => Tensor1[Co @main def train(): Unit = + dimwit.initialize() + trait Data derives Label PythonSetup.initialize diff --git a/examples/src/main/scala/complex/VariationalAutoencoder.scala b/examples/src/main/scala/complex/VariationalAutoencoder.scala index 13ab4913..189424d6 100644 --- a/examples/src/main/scala/complex/VariationalAutoencoder.scala +++ b/examples/src/main/scala/complex/VariationalAutoencoder.scala @@ -98,6 +98,7 @@ object VariationalAutoencoder: object VariationalAutoencoderExample: def main(args: Array[String]): Unit = + dimwit.initialize() /* * Configuration and Setup diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..3e84ea14 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dimwit-python-env" +version = "0.1.0" +description = "Python environment accessible through ScalaPy" +readme = "README.md" +requires-python = "==3.13.*" +dependencies = [ + "einops>=0.8.1", + "jax[cuda12]>=0.8.2", +] \ No newline at end of file