From 55bfcb1223a2299410bd82320ea90215754514cc Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 24 Mar 2026 18:51:12 +0530 Subject: [PATCH] feat: enable inline linear SCCs by default --- lib/ModelingToolkitTearing/src/reassemble.jl | 51 +++++++++++++++++++- 1 file changed, 49 insertions(+), 2 deletions(-) diff --git a/lib/ModelingToolkitTearing/src/reassemble.jl b/lib/ModelingToolkitTearing/src/reassemble.jl index a9c1b97..95d0de5 100644 --- a/lib/ModelingToolkitTearing/src/reassemble.jl +++ b/lib/ModelingToolkitTearing/src/reassemble.jl @@ -426,6 +426,7 @@ function generate_system_equations!(state::TearingState, neweqs::Vector{Equation end digraph = DiCMOBiGraph{false}(graph, var_eq_matching) + has_inline_linsols = false for (i, scc) in enumerate(var_sccs) # note that the `vscc <-> escc` relation is a set-to-set mapping, and not # point-to-point. @@ -447,6 +448,7 @@ function generate_system_equations!(state::TearingState, neweqs::Vector{Equation linsol_result = get_linear_scc_linsol(state, escc, vscc, neweqs, var_eq_matching, total_sub, analytical_linear_scc_limit, simplify) end if linsol_result isa Tuple{SymbolicT, BitVector, BitVector} + has_inline_linsols = true linsol, eqs_mask, vars_mask = linsol_result @assert length(eqs_mask) == length(escc) @assert length(vars_mask) == length(vscc) @@ -501,6 +503,11 @@ function generate_system_equations!(state::TearingState, neweqs::Vector{Equation codegen_equation!(eq_generator, neweqs[eq], eq, var; simplify) end + if has_inline_linsols + sys = SU.setmetadata(sys, MTKBase.SymbolicADDisallowed, "Inline linear SCCs are enabled in `mtkcompile`") + state.sys = sys + end + (; neweqs′, eq_ordering, var_ordering, solved_eqs, solved_vars) = eq_generator is_diff_eq = .!iszero.(var_ordering) @@ -1264,14 +1271,44 @@ $TYPEDFIELDS """ Whether SCCs which are linear systems of the associated variables should be handled using inline linear solves via `LinearSolve.jl`. By default, such - SCCs generate algebraic equations. + SCCs generate algebraic equations. Not supported for nonlinear systems - use + `SCCNonlinearProblem` instead. Also unsupported for discrete systems. The + default of `nothing` enables it for time-dependent non-discrete systems. """ - inline_linear_sccs::Bool = false + inline_linear_sccs::Union{Bool, Nothing} = nothing """ If `inline_linear_sccs == true`, this is the maximum size of a system of linear equations which is solved symbolically rather than using `LinearSolve.jl`. """ analytical_linear_scc_limit::Int = 2 + + function DefaultReassembleAlgorithm( + simplify, array_hack, inline_linear_sccs, analytical_linear_scc_limit + ) + if inline_linear_sccs && analytical_linear_scc_limit < 1 + throw(ArgumentError("`analytical_linear_scc_limit` cannot be less than 1.")) + end + return new(simplify, array_hack, inline_linear_sccs, analytical_linear_scc_limit) + end +end + +struct InlineLinearSCCsUnsupportedError <: Exception + reason::String +end + +function Base.showerror(io::IO, err::InlineLinearSCCsUnsupportedError) + print(io, """ + Inline linear SCCs are not supported for this system for the following reason: + + $(err.reason) + + This can be fixed by not passing `reassemble_alg` to `mtkcompile`. Alternatively, + ensure to pass `inline_linear_sccs = false` to the `DefaultReassembleAlgorithm` + provided to `reassemble_alg`. For example: + ```julia + mtkcompile(sys; reassemble_alg = ModelingToolkitTearing.DefaultReassembleAlgorithm(; inline_linear_sccs = false)) + ``` + """) end function (alg::DefaultReassembleAlgorithm)(state::TearingState, @@ -1306,6 +1343,16 @@ function (alg::DefaultReassembleAlgorithm)(state::TearingState, inline_linear_sccs = false end + # `nothing` enables this for supported systems. + inline_linear_sccs = something(inline_linear_sccs, iv isa SymbolicT && D isa Differential) + if inline_linear_sccs + if iv === nothing + throw(InlineLinearSCCsUnsupportedError("System is time-independent")) + elseif !(D isa Differential) + throw(InlineLinearSCCsUnsupportedError("System is discrete")) + end + end + extra_unknowns = state.fullvars[extra_eqs_vars[2]] if StateSelection.is_only_discrete(state.structure) var_sccs = add_additional_history!(