Skip to content

Commit 59d9f28

Browse files
authored
Merge pull request #147 from VirtualPlantLab/fix-regression-benchmark
Fixing #146: performance regression
2 parents f0e97f3 + 7595212 commit 59d9f28

9 files changed

Lines changed: 198 additions & 186 deletions

File tree

src/PlantSimEngine.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ include("component_models/RefVector.jl")
4343
# Simulation table (time-step table, from PlantMeteo):
4444
include("component_models/TimeStepTable.jl")
4545

46+
# Declaring the dependency graph
47+
include("dependencies/dependency_graph.jl")
48+
4649
# List of models:
4750
include("component_models/ModelList.jl")
4851
include("mtg/MultiScaleModel.jl")
@@ -53,8 +56,7 @@ include("component_models/get_status.jl")
5356
# Transform into a dataframe:
5457
include("dataframe.jl")
5558

56-
# Model dependencies:
57-
include("dependencies/dependency_graph.jl")
59+
# Computing model dependencies:
5860
include("dependencies/soft_dependencies.jl")
5961
include("dependencies/hard_dependencies.jl")
6062
include("dependencies/traversal.jl")

src/component_models/ModelList.jl

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,8 @@ julia> [typeof(models[i][1]) for i in keys(status(models))]
147147
struct ModelList{M<:NamedTuple,S}
148148
models::M
149149
status::S
150-
type_promotion::Union{Nothing, Dict}
150+
type_promotion::Union{Nothing,Dict}
151+
dependency_graph::DependencyGraph
151152
end
152153

153154
#=function ModelList(models::M, status::Status) where {M<:NamedTuple{names,T} where {names,T<:NTuple{N,<:AbstractModel} where {N}}}
@@ -160,7 +161,6 @@ function ModelList(
160161
status=nothing,
161162
type_promotion::Union{Nothing,Dict}=nothing,
162163
variables_check::Bool=true,
163-
nsteps=nothing,
164164
kwargs...
165165
)
166166

@@ -187,10 +187,13 @@ function ModelList(
187187
ts_kwargs = homogeneous_ts_kwargs(status)
188188
ts_kwargs = add_model_vars(ts_kwargs, mods, type_promotion)
189189

190+
191+
190192
model_list = ModelList(
191193
mods,
192194
ts_kwargs,
193-
type_promotion
195+
type_promotion,
196+
dep(; verbose=true, mods...)
194197
)
195198
variables_check && !is_initialized(model_list)
196199

@@ -219,8 +222,8 @@ function add_model_vars(x, models, type_promotion)
219222

220223
# If the user gave a status, we check if all the variables are already initialized:
221224
vars_in_x = status_keys(x)
222-
status_x =
223-
all([k in vars_in_x for k in keys(ref_vars)]) && return isa(x, Status) ? x : Status(x) # If so, we return the input
225+
status_x =
226+
all([k in vars_in_x for k in keys(ref_vars)]) && return isa(x, Status) ? x : Status(x) # If so, we return the input
224227

225228
# Else, we add the variables by making a new object (carefull, this is a copy so it takes more time):
226229

@@ -229,15 +232,15 @@ function add_model_vars(x, models, type_promotion)
229232

230233
# If the user gave an empty status, we initialize all variables to their default values:
231234
if x === nothing
232-
return Status(ref_vars)
235+
return Status(ref_vars)
233236
end
234-
237+
235238
if Tables.istable(x)
236239
# This situation only occurs if the user provided a table instead of a status
237240
# Meaning we have a status of vector values, all initialized up to a certain point
238241
# Unsure this is desirable, as that means run! does nothing or overwrites everything
239242
# Anyway, we wish to create a NamedTuple() of Vectors here
240-
x_full = (;zip(propertynames(x), Tables.columns(x))...)
243+
x_full = (; zip(propertynames(x), Tables.columns(x))...)
241244
x_full = merge(ref_vars, x_full)
242245

243246
else
@@ -286,7 +289,7 @@ PlantSimEngine.homogeneous_ts_kwargs((Tₗ=[25.0, 26.0], aPPFD=1000.0))
286289
function homogeneous_ts_kwargs(kwargs::NamedTuple{N,T}) where {N,T}
287290
length(kwargs) == 0 && return kwargs
288291
vars_vals = collect(Any, values(kwargs))
289-
292+
290293
vars_array = NamedTuple{keys(kwargs)}(j for j in vars_vals)
291294

292295
return vars_array
@@ -325,15 +328,17 @@ function Base.copy(m::T) where {T<:ModelList}
325328
ModelList(
326329
m.models,
327330
deepcopy(m.status),
328-
deepcopy(m.type_promotion)
331+
deepcopy(m.type_promotion),
332+
deepcopy(m.dependency_graph)
329333
)
330334
end
331335

332336
function Base.copy(m::T, status) where {T<:ModelList}
333337
ModelList(
334338
m.models,
335339
status,
336-
deepcopy(m.type_promotion)
340+
deepcopy(m.type_promotion),
341+
deepcopy(m.dependency_graph)
337342
)
338343
end
339344

@@ -465,7 +470,7 @@ function convert_vars!(mapped_vars::Dict{String,Dict{Symbol,Any}}, type_promotio
465470
end
466471

467472
function Base.show(io::IO, ::MIME"text/plain", t::ModelList)
468-
print(io, dep(t, verbose=false))
473+
print(io, dep(t))
469474
print(io, status(t))
470475
end
471476

src/component_models/Status.jl

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -135,26 +135,28 @@ end
135135

136136
# Returns a status with all vector variables replaced with their first value (ie a Status ready for simulation)
137137
# also returns a tuple of symbols corresponding to the vector variables
138-
function flatten_status(s::Status)
139-
status_values_flattened = NamedTuple()
140-
vector_variables = NamedTuple()
141-
142-
for (var, value) in zip(keys(s), s)
143-
if length(value) > 1
144-
vector_variables = (vector_variables..., var)
145-
status_values_flattened = (status_values_flattened..., value[1])
146-
else
147-
status_values_flattened = (status_values_flattened..., value)
148-
end
138+
function flatten_status(s::Status{T}) where {T}
139+
n_vars_several_values = findall(x -> length(x) > 1, s)
140+
if length(n_vars_several_values) == 0
141+
return s, n_vars_several_values
142+
else
143+
return Status{keys(s)}(first.(values(s))), n_vars_several_values
149144
end
150-
151-
return Status(; zip(keys(s), status_values_flattened)...), vector_variables
152145
end
153146

154-
# Update to the next timestep the variables that were passed in as vectors by the user
155-
function update_vector_variables(s::Status, sf::Status, vector_variables, i)
156-
for vec in vector_variables
157-
sf[vec] = s[vec][i]
147+
"""
148+
set_variables_at_timestep!(status_timestep::Status, user_status::Status, variables_to_update, timestep)
149+
150+
Update `status_timestep` to the current values at the `timestep` for all `variables_to_update` in the status provided by the user (`user_status`).
151+
The variables to update are given in `variables_to_update`, which is a vector of symbols.
152+
153+
`status_timestep` is a status representing a single time-step. `user_status` is the status provided that gives values for variables that are not computed by any model.
154+
It may give constant values or vectors of values, in which case the `timestep` is used to select the value to use for the current time step.
155+
156+
"""
157+
function set_variables_at_timestep!(status_timestep::Status, user_status::Status, variables_to_update, timestep)
158+
for vec in variables_to_update
159+
status_timestep[vec] = user_status[vec][timestep]
158160
end
159161
end
160162

src/dependencies/dependencies.jl

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
dep(::T, nsteps=1) where {T<:AbstractModel} = NamedTuple()
22

33
"""
4-
dep(m::ModelList, nsteps=1; verbose::Bool=true)
4+
dep(m::ModelList)
55
dep(mapping::Dict{String,T}; verbose=true)
6+
dep!(m::ModelList, nsteps=1)
67
78
Get the model dependency graph given a ModelList or a multiscale model mapping. If one graph is returned,
89
then all models are coupled. If several graphs are returned, then only the models inside each graph are coupled, and
@@ -34,7 +35,12 @@ to other scales if needed. Then we transform all these nodes into soft dependenc
3435
Then we traverse all these and we set nodes that need outputs from other nodes as inputs as children/parents.
3536
If a node has no dependency, it is set as a root node and pushed into a new Dict (independant_process_root). This Dict is the returned dependency graph. And
3637
it presents root nodes as independent starting points for the sub-graphs, which are the models that are coupled together. We can then traverse each of
37-
these graphs independently to retrieve the models that are coupled together, in the right order of execution.
38+
these graphs independently to r
39+
40+
# Notes
41+
42+
The difference between `dep(m::ModelList)` and `dep!(m::ModelList, nsteps)` is that the first one returns the dependency graph found in the model list, while the
43+
second one returns the dependency graph with the specified number of steps, modifying the simulation IDs of each node in the graph (`simulation_id=fill(0, nsteps)`).
3844
3945
# Examples
4046
@@ -75,8 +81,18 @@ function dep(nsteps=1; verbose::Bool=true, vars...)
7581
return deps
7682
end
7783

78-
function dep(m::ModelList, nsteps=1; verbose::Bool=true)
79-
dep(nsteps; verbose=verbose, m.models...)
84+
function dep(m::ModelList)
85+
m.dependency_graph
86+
end
87+
88+
function dep!(m::ModelList, nsteps=1)
89+
traverse_dependency_graph!(m.dependency_graph; visit_hard_dep=false) do node
90+
if length(node.simulation_id) != nsteps
91+
node.simulation_id = fill(0, nsteps)
92+
end
93+
end
94+
95+
return m.dependency_graph
8096
end
8197

8298

src/mtg/save_results.jl

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ julia> collect(keys(preallocated_vars["Leaf"]))
112112

113113
function pre_allocate_outputs(statuses, statuses_template, reverse_multiscale_mapping, vars_need_init, outs, nsteps; type_promotion=nothing, check=true)
114114
outs_ = Dict{String,Vector{Symbol}}()
115-
115+
116116
# default behaviour : track everything
117117
if isnothing(outs)
118118
for organ in keys(statuses)
@@ -130,7 +130,7 @@ function pre_allocate_outputs(statuses, statuses_template, reverse_multiscale_ma
130130
end
131131
end
132132

133-
len = Dict{String, Int}()
133+
len = Dict{String,Int}()
134134
for (organ, vals) in outs_
135135
len[organ] = length(outs_[organ])
136136
unique!(outs_[organ])
@@ -210,15 +210,15 @@ function pre_allocate_outputs(statuses, statuses_template, reverse_multiscale_ma
210210
node_type = only(node_type)
211211

212212
# I don't know if this function barrier is necessary
213-
preallocated_outputs = Dict{String, Vector}()
213+
preallocated_outputs = Dict{String,Vector}()
214214
complete_preallocation_from_types!(preallocated_outputs, nsteps, outs_, node_type, statuses_template)
215215
return preallocated_outputs
216216
end
217217

218218
function complete_preallocation_from_types!(preallocated_outputs, nsteps, outs_, node_type, statuses_template)
219219
types = Vector{DataType}()
220220
for organ in keys(outs_)
221-
221+
222222
outs_no_node = filter(x -> x != :node, outs_[organ])
223223

224224
#types = [typeof(status_from_template(statuses_template[organ])[var]) for var in outs[organ]]
@@ -230,14 +230,14 @@ function complete_preallocation_from_types!(preallocated_outputs, nsteps, outs_,
230230
symbols_tuple = (:timestep, :node, outs_no_node...,)
231231
# using node_type.parameters[1] is clunky, but covers both NodeMTG and AbstractNodeMTG types
232232
values_tuple = (1, MultiScaleTreeGraph.Node((node_type.parameters[1])("/", "Uninitialized", 0, 0),), values...,)
233-
233+
234234
# Dummy value to make accessing the type easier
235235
# (empty arrays don't have references to an instance, so their types can't be inspected and manipulated as easily)
236-
dummy_status = (;zip(symbols_tuple, values_tuple)...)
236+
dummy_status = (; zip(symbols_tuple, values_tuple)...)
237237
data = typeof(Status(dummy_status))[]
238238
resize!(data, nsteps)
239-
240-
for ii in 1:nsteps
239+
240+
for ii in 1:nsteps
241241
data[ii] = Status(dummy_status)
242242
end
243243
preallocated_outputs[organ] = data
@@ -278,22 +278,22 @@ function save_results!(object::GraphSimulation, i)
278278
# So there may be possible simplifications (maybe no need for a function barrier, perhaps the resizing could be made a one-liner...)
279279
# But this should work without causing visible performance regressions on XPalm
280280
len = length(outs[organ])
281-
if length(statuses[organ]) + index - 1 > len
281+
if length(statuses[organ]) + index - 1 > len
282282
min_required = max(length(statuses[organ]) + index - len, index)
283-
284-
extra_length = 2*min_required - len
283+
284+
extra_length = 2 * min_required - len
285285
data = eltype(outs[organ])[]
286286
resize!(data, extra_length)
287287
dummy_value = NamedTuple(outs[organ][1])
288288
# TODO set timestep to 0 for clarity ?
289-
289+
290290
# Using fill! caused Ref issues, so call a Status constructor here instead of passing a prebuilt value
291291
# This will avoid having all array entries point to the same ref but keep construction cost at a minimum
292292
for new_entry in 1:extra_length
293293
data[new_entry] = Status(dummy_value)
294294
end
295295

296-
outs[organ] = cat(outs[organ], data, dims=1)
296+
outs[organ] = cat(outs[organ], data, dims=1)
297297
#println("len : ", len, " statuses #", length(statuses[organ]), " index ", index)
298298
#println("min_required : ", min_required, " extra_length ", extra_length, " new len ", length(outs[organ]))
299299
end
@@ -316,15 +316,9 @@ function copy_tracked_outputs_into_vector!(outs_organ, i, statuses_organ, tracke
316316
return j
317317
end
318318

319-
320319
function pre_allocate_outputs(m::ModelList, outs, nsteps; type_promotion=nothing, check=true)
321-
322-
# NOTE : init_variables recreates a DependencyGraph, it's not great
323-
# TODO : copy ?
324-
out_vars_pre_type_promotion = merge(init_variables(m; verbose=false)...)
325-
326-
# bit hacky, could be cleaned up
327-
out_vars_all = convert_vars(out_vars_pre_type_promotion, m.type_promotion)
320+
st, = flatten_status(status(m))
321+
out_vars_all = convert_vars(st, type_promotion)
328322

329323
out_keys_requested = Symbol[]
330324
if !isnothing(outs)
@@ -337,9 +331,10 @@ function pre_allocate_outputs(m::ModelList, outs, nsteps; type_promotion=nothing
337331

338332
# default implicit behaviour, track everything
339333
if isempty(out_keys_requested)
340-
out_vars_requested = out_vars_all
334+
# We already have the status here, just repeating its value:
335+
out_vars_requested = NamedTuple(out_vars_all)
341336
else
342-
unexpected_outputs = setdiff(out_keys_requested, status_keys(status(m)))
337+
unexpected_outputs = setdiff(out_keys_requested, keys(st))
343338

344339
if !isempty(unexpected_outputs)
345340
e = string(
@@ -354,22 +349,21 @@ function pre_allocate_outputs(m::ModelList, outs, nsteps; type_promotion=nothing
354349
@info e
355350
[delete!(unexpected_outputs, i) for i in unexpected_outputs]
356351
end
357-
end
352+
end
358353

359354
out_defaults_requested = (out_vars_all[i] for i in out_keys_requested)
360-
out_vars_requested = (;zip(out_keys_requested, out_defaults_requested)...)
355+
out_vars_requested = (; zip(out_keys_requested, out_defaults_requested)...)
361356
end
362357

363-
outputs_timestep = fill(out_vars_requested, nsteps)
364-
return TimeStepTable([Status(i) for i in outputs_timestep])
358+
return TimeStepTable([Status(out_vars_requested) for i in Base.OneTo(nsteps)])
365359
end
366360

367361
function save_results!(status_flattened::Status, outputs, i)
368-
if length(outputs) == 0
369-
return
362+
if length(outputs) == 0
363+
return
370364
end
371365
outs = outputs[i]
372-
366+
373367
for var in keys(outs)
374368
outs[var] = status_flattened[var]
375369
end

src/processes/model_initialisation.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ mapping = Dict(
6565
to_initialize(mapping)
6666
```
6767
"""
68-
function to_initialize(m::ModelList; verbose::Bool=true)
69-
needed_variables = to_initialize(dep(m; verbose=verbose))
68+
function to_initialize(m::ModelList)
69+
needed_variables = to_initialize(dep(m))
7070
to_init = Dict{Symbol,Tuple}()
7171
for (process, vars) in needed_variables
7272
# default_values = needed_variables[:process1]
@@ -245,7 +245,7 @@ function init_variables(model::T; verbose::Bool=true) where {T<:AbstractModel}
245245
end
246246

247247
function init_variables(m::ModelList; verbose::Bool=true)
248-
init_variables(dep(m; verbose=verbose))
248+
init_variables(dep(m))
249249
end
250250

251251
function init_variables(m::DependencyGraph)
@@ -300,7 +300,7 @@ is_initialized(models)
300300
```
301301
"""
302302
function is_initialized(m::T; verbose=true) where {T<:ModelList}
303-
var_names = to_initialize(m; verbose=verbose)
303+
var_names = to_initialize(m)
304304

305305
if any([length(to_init) > 0 for (process, to_init) in pairs(var_names)])
306306
verbose && @info "Some variables must be initialized before simulation: $var_names (see `to_initialize()`)" maxlog = 1

0 commit comments

Comments
 (0)