Skip to content

Commit a66a0f7

Browse files
authored
Merge pull request #954 from EnergySystemsModellingLab/topo_sort_cycles
Solve investment order for commodity graphs with cycles
2 parents 5bb5ec3 + 1b5c6ef commit a66a0f7

6 files changed

Lines changed: 910 additions & 785 deletions

File tree

src/graph.rs

Lines changed: 114 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,16 @@
22
use crate::commodity::{CommodityID, CommodityMap, CommodityType};
33
use crate::process::{ProcessID, ProcessMap};
44
use crate::region::RegionID;
5+
use crate::simulation::investment::InvestmentSet;
56
use crate::time_slice::{TimeSliceInfo, TimeSliceLevel, TimeSliceSelection};
67
use crate::units::{Dimensionless, Flow};
7-
use anyhow::{Context, Result, anyhow, ensure};
8+
use anyhow::{Context, Result, ensure};
89
use indexmap::IndexSet;
9-
use itertools::{Itertools, iproduct};
10+
use itertools::iproduct;
1011
use petgraph::Directed;
11-
use petgraph::algo::toposort;
12+
use petgraph::algo::{condensation, toposort};
1213
use petgraph::dot::Dot;
1314
use petgraph::graph::{EdgeReference, Graph};
14-
use petgraph::visit::EdgeFiltered;
1515
use std::collections::HashMap;
1616
use std::fmt::Display;
1717
use std::fs::File;
@@ -301,43 +301,64 @@ fn validate_commodities_graph(
301301
/// Performs topological sort on the commodity graph to get the ordering for investments
302302
///
303303
/// The returned Vec only includes SVD and SED commodities, with cyclic ones grouped into one set.
304-
fn topo_sort_commodities(
304+
fn solve_investment_order(
305305
graph: &CommoditiesGraph,
306306
commodities: &CommodityMap,
307-
) -> Result<Vec<CommodityID>> {
308-
// We only consider primary edges
309-
let primary_graph =
310-
EdgeFiltered::from_fn(graph, |edge| matches!(edge.weight(), GraphEdge::Primary(_)));
311-
312-
// Perform a topological sort on the graph
313-
let order = toposort(&primary_graph, None).map_err(|cycle| {
314-
let cycle_commodity = graph.node_weight(cycle.node_id()).unwrap().clone();
315-
anyhow!("Cycle detected in commodity graph for commodity {cycle_commodity}")
316-
})?;
317-
318-
// We return the order in reverse so that leaf-node commodities are solved first
319-
// We also filter to only include SVD and SED commodities
320-
let order = order
321-
.iter()
322-
.rev()
323-
.filter_map(|node_idx| {
307+
) -> Vec<InvestmentSet> {
308+
// Filter the graph to only include SVD/SED commodities and primary edges
309+
let graph_filtered = graph.filter_map(
310+
// Consider only SVD/SED commodities
311+
|_, node_weight| {
324312
// Get the commodity for the node
325-
let GraphNode::Commodity(commodity_id) = graph.node_weight(*node_idx).unwrap() else {
313+
let GraphNode::Commodity(commodity_id) = node_weight else {
326314
// Skip special nodes
327315
return None;
328316
};
329317
let commodity = &commodities[commodity_id];
330-
331-
// Only include SVD and SED commodities
332318
matches!(
333319
commodity.kind,
334320
CommodityType::ServiceDemand | CommodityType::SupplyEqualsDemand
335321
)
336-
.then(|| commodity_id.clone())
322+
.then_some(node_weight)
323+
},
324+
// Consider only primary edges
325+
|_, edge_weight| matches!(edge_weight, GraphEdge::Primary(_)).then_some(edge_weight),
326+
);
327+
328+
// Condense strongly connected components
329+
let condensed_graph = condensation(graph_filtered, true);
330+
331+
// Perform a topological sort on the condensed graph
332+
// We can safely unwrap because `toposort` will only return an error in case of cycles, which
333+
// should have been detected and compressed with `condensation`
334+
let order = toposort(&condensed_graph, None).unwrap();
335+
336+
// Create investment sets (reverse topological order)
337+
order
338+
.iter()
339+
.rev()
340+
.filter_map(|node_idx| {
341+
// Get set of commodity ID(s) for the node, referring back to `condensed_graph`
342+
let commodities: Vec<CommodityID> = condensed_graph
343+
.node_weight(*node_idx)
344+
.unwrap()
345+
.iter()
346+
.filter_map(|node| match node {
347+
GraphNode::Commodity(id) => Some(id.clone()),
348+
_ => None,
349+
})
350+
.collect();
351+
352+
// Create investment set
353+
// If a single commodity in the node this is `InvestmentSet::Single`, if multiple
354+
// commodities this is `InvestmentSet::Cycle`
355+
match commodities.as_slice() {
356+
[] => None,
357+
[only] => Some(InvestmentSet::Single(only.clone())),
358+
_ => Some(InvestmentSet::Cycle(commodities)),
359+
}
337360
})
338-
.collect();
339-
340-
Ok(order)
361+
.collect()
341362
}
342363

343364
/// Builds base commodity graphs for each region and year
@@ -347,65 +368,42 @@ pub fn build_commodity_graphs_for_model(
347368
processes: &ProcessMap,
348369
region_ids: &IndexSet<RegionID>,
349370
years: &[u32],
350-
) -> Result<HashMap<(RegionID, u32), CommoditiesGraph>> {
351-
let commodity_graphs: HashMap<(RegionID, u32), CommoditiesGraph> =
352-
iproduct!(region_ids, years.iter())
353-
.map(|(region_id, year)| {
354-
let graph = create_commodities_graph_for_region_year(processes, region_id, *year);
355-
((region_id.clone(), *year), graph)
356-
})
357-
.collect();
358-
359-
Ok(commodity_graphs)
371+
) -> HashMap<(RegionID, u32), CommoditiesGraph> {
372+
iproduct!(region_ids, years.iter())
373+
.map(|(region_id, year)| {
374+
let graph = create_commodities_graph_for_region_year(processes, region_id, *year);
375+
((region_id.clone(), *year), graph)
376+
})
377+
.collect()
360378
}
361379

362380
/// Validates commodity graphs for the entire model.
363381
///
364-
/// This function creates commodity flow graphs for each region/year combination in the model,
365-
/// validates the graph structure against commodity type rules, and determines the optimal
366-
/// investment order for commodities.
367-
///
368382
/// The validation process checks three time slice levels:
369383
/// - **Annual**: Validates annual-level commodities and processes
370384
/// - **Seasonal**: Validates seasonal-level commodities and processes for each season
371385
/// - **Day/Night**: Validates day/night-level commodities and processes for each time slice
372386
///
373387
/// # Arguments
374388
///
389+
/// * `commodity_graphs` - Commodity graphs for each region and year, as returned by `build_commodity_graphs_for_model`
375390
/// * `processes` - All processes in the model with their flows and activity limits
376391
/// * `commodities` - All commodities with their types and demand specifications
377392
/// * `region_ids` - Collection of regions to model
378393
/// * `years` - Years to analyse
379394
/// * `time_slice_info` - Time slice configuration (seasons, day/night periods)
380395
///
381-
/// # Returns
382-
///
383-
/// A map from `(region, year)` to the ordered list of commodities for investment decisions. The
384-
/// ordering ensures that leaf-node commodities (those with no outgoing edges) are solved first.
385-
///
386396
/// # Errors
387397
///
388398
/// Returns an error if:
389-
/// - Any commodity graph contains cycles
390399
/// - Commodity type rules are violated (e.g., SVD commodities being consumed)
391400
/// - Demand cannot be satisfied
392401
pub fn validate_commodity_graphs_for_model(
393402
commodity_graphs: &HashMap<(RegionID, u32), CommoditiesGraph>,
394403
processes: &ProcessMap,
395404
commodities: &CommodityMap,
396405
time_slice_info: &TimeSliceInfo,
397-
) -> Result<HashMap<(RegionID, u32), Vec<CommodityID>>> {
398-
// Determine commodity ordering for each region and year
399-
let commodity_order: HashMap<(RegionID, u32), Vec<CommodityID>> = commodity_graphs
400-
.iter()
401-
.map(|((region_id, year), graph)| -> Result<_> {
402-
let order = topo_sort_commodities(graph, commodities).with_context(|| {
403-
format!("Error validating commodity graph for {region_id} in {year}")
404-
})?;
405-
Ok(((region_id.clone(), *year), order))
406-
})
407-
.try_collect()?;
408-
406+
) -> Result<()> {
409407
// Validate graphs at all time slice levels (taking into account process availability and demand)
410408
for ((region_id, year), base_graph) in commodity_graphs {
411409
for ts_level in TimeSliceLevel::iter() {
@@ -428,9 +426,33 @@ pub fn validate_commodity_graphs_for_model(
428426
}
429427
}
430428
}
429+
Ok(())
430+
}
431431

432-
// If all the validation passes, return the commodity ordering
433-
Ok(commodity_order)
432+
/// Determine commodity ordering for each region and year
433+
///
434+
/// # Arguments
435+
///
436+
/// * `commodity_graphs` - Commodity graphs for each region and year, as returned by `build_commodity_graphs_for_model`
437+
/// * `commodities` - All commodities with their types and demand specifications
438+
///
439+
/// # Returns
440+
///
441+
/// A map from `(region, year)` to the ordered list of investment sets for investment decisions. The
442+
/// ordering ensures that leaf-node commodities (those with no outgoing edges) are solved first.
443+
pub fn solve_investment_order_for_model(
444+
commodity_graphs: &HashMap<(RegionID, u32), CommoditiesGraph>,
445+
commodities: &CommodityMap,
446+
) -> HashMap<(RegionID, u32), Vec<InvestmentSet>> {
447+
commodity_graphs
448+
.iter()
449+
.map(|((region_id, year), graph)| {
450+
(
451+
(region_id.clone(), *year),
452+
solve_investment_order(graph, commodities),
453+
)
454+
})
455+
.collect()
434456
}
435457

436458
/// Gets custom DOT attributes for edges in a commodity graph
@@ -473,7 +495,10 @@ mod tests {
473495
use std::rc::Rc;
474496

475497
#[rstest]
476-
fn test_topo_sort_linear_graph(sed_commodity: Commodity, svd_commodity: Commodity) {
498+
fn test_solve_investment_order_linear_graph(
499+
sed_commodity: Commodity,
500+
svd_commodity: Commodity,
501+
) {
477502
// Create a simple linear graph: A -> B -> C
478503
let mut graph = Graph::new();
479504

@@ -491,17 +516,18 @@ mod tests {
491516
commodities.insert("B".into(), Rc::new(sed_commodity));
492517
commodities.insert("C".into(), Rc::new(svd_commodity));
493518

494-
let result = topo_sort_commodities(&graph, &commodities).unwrap();
519+
let result = solve_investment_order(&graph, &commodities);
495520

496521
// Expected order: C, B, A (leaf nodes first)
522+
// No cycles, so all investment sets should be `Single`
497523
assert_eq!(result.len(), 3);
498-
assert_eq!(result[0], "C".into());
499-
assert_eq!(result[1], "B".into());
500-
assert_eq!(result[2], "A".into());
524+
assert_eq!(result[0], InvestmentSet::Single("C".into()));
525+
assert_eq!(result[1], InvestmentSet::Single("B".into()));
526+
assert_eq!(result[2], InvestmentSet::Single("A".into()));
501527
}
502528

503529
#[rstest]
504-
fn test_topo_sort_cyclic_graph(sed_commodity: Commodity) {
530+
fn test_solve_investment_order_cyclic_graph(sed_commodity: Commodity) {
505531
// Create a simple cyclic graph: A -> B -> A
506532
let mut graph = Graph::new();
507533

@@ -517,11 +543,14 @@ mod tests {
517543
commodities.insert("A".into(), Rc::new(sed_commodity.clone()));
518544
commodities.insert("B".into(), Rc::new(sed_commodity));
519545

520-
// This should return an error due to the cycle
521-
// The error message should flag commodity B
522-
// Note: A is also involved in the cycle, but B is flagged as it is encountered first
523-
let result = topo_sort_commodities(&graph, &commodities);
524-
assert_error!(result, "Cycle detected in commodity graph for commodity B");
546+
let result = solve_investment_order(&graph, &commodities);
547+
548+
// Should be a single `Cycle` investment set containing both commodities
549+
assert_eq!(result.len(), 1);
550+
assert_eq!(
551+
result[0],
552+
InvestmentSet::Cycle(vec!["A".into(), "B".into()])
553+
);
525554
}
526555

527556
#[rstest]
@@ -548,8 +577,7 @@ mod tests {
548577
graph.add_edge(node_c, node_d, GraphEdge::Demand);
549578

550579
// Validate the graph at Annual level
551-
let result = validate_commodities_graph(&graph, &commodities, TimeSliceLevel::Annual);
552-
assert!(result.is_ok());
580+
assert!(validate_commodities_graph(&graph, &commodities, TimeSliceLevel::Annual).is_ok());
553581
}
554582

555583
#[rstest]
@@ -574,8 +602,10 @@ mod tests {
574602
graph.add_edge(node_a, node_b, GraphEdge::Primary("process2".into()));
575603

576604
// Validate the graph at DayNight level
577-
let result = validate_commodities_graph(&graph, &commodities, TimeSliceLevel::DayNight);
578-
assert_error!(result, "SVD commodity A cannot be an input to a process");
605+
assert_error!(
606+
validate_commodities_graph(&graph, &commodities, TimeSliceLevel::DayNight),
607+
"SVD commodity A cannot be an input to a process"
608+
);
579609
}
580610

581611
#[rstest]
@@ -592,8 +622,10 @@ mod tests {
592622
graph.add_edge(node_a, node_b, GraphEdge::Demand);
593623

594624
// Validate the graph at DayNight level
595-
let result = validate_commodities_graph(&graph, &commodities, TimeSliceLevel::DayNight);
596-
assert_error!(result, "SVD commodity A is demanded but has no producers");
625+
assert_error!(
626+
validate_commodities_graph(&graph, &commodities, TimeSliceLevel::DayNight),
627+
"SVD commodity A is demanded but has no producers"
628+
);
597629
}
598630

599631
#[rstest]
@@ -611,9 +643,8 @@ mod tests {
611643
graph.add_edge(node_b, node_a, GraphEdge::Primary("process1".into()));
612644

613645
// Validate the graph at DayNight level
614-
let result = validate_commodities_graph(&graph, &commodities, TimeSliceLevel::DayNight);
615646
assert_error!(
616-
result,
647+
validate_commodities_graph(&graph, &commodities, TimeSliceLevel::DayNight),
617648
"SED commodity B may be consumed but has no producers"
618649
);
619650
}
@@ -639,9 +670,8 @@ mod tests {
639670
graph.add_edge(node_a, node_c, GraphEdge::Primary("process2".into()));
640671

641672
// Validate the graph at DayNight level
642-
let result = validate_commodities_graph(&graph, &commodities, TimeSliceLevel::DayNight);
643673
assert_error!(
644-
result,
674+
validate_commodities_graph(&graph, &commodities, TimeSliceLevel::DayNight),
645675
"OTH commodity A cannot have both producers and consumers"
646676
);
647677
}

src/input.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
//! Common routines for handling input data.
22
use crate::asset::AssetPool;
33
use crate::graph::{
4-
CommoditiesGraph, build_commodity_graphs_for_model, validate_commodity_graphs_for_model,
4+
CommoditiesGraph, build_commodity_graphs_for_model, solve_investment_order_for_model,
5+
validate_commodity_graphs_for_model,
56
};
67
use crate::id::{HasID, IDLike};
78
use crate::model::{Model, ModelParameters};
@@ -232,15 +233,17 @@ pub fn load_model<P: AsRef<Path>>(model_dir: P) -> Result<(Model, AssetPool)> {
232233
let assets = read_assets(model_dir.as_ref(), &agent_ids, &processes, &region_ids)?;
233234

234235
// Build and validate commodity graphs for all regions and years
235-
// This gives us the commodity order for each region/year which is passed to the model
236-
let commodity_graphs = build_commodity_graphs_for_model(&processes, &region_ids, years)?;
237-
let commodity_order = validate_commodity_graphs_for_model(
236+
let commodity_graphs = build_commodity_graphs_for_model(&processes, &region_ids, years);
237+
validate_commodity_graphs_for_model(
238238
&commodity_graphs,
239239
&processes,
240240
&commodities,
241241
&time_slice_info,
242242
)?;
243243

244+
// Solve investment order for each region/year
245+
let investment_order = solve_investment_order_for_model(&commodity_graphs, &commodities);
246+
244247
let model_path = model_dir
245248
.as_ref()
246249
.canonicalize()
@@ -253,7 +256,7 @@ pub fn load_model<P: AsRef<Path>>(model_dir: P) -> Result<(Model, AssetPool)> {
253256
processes,
254257
time_slice_info,
255258
regions,
256-
commodity_order,
259+
investment_order,
257260
};
258261
Ok((model, AssetPool::new(assets)))
259262
}
@@ -284,7 +287,7 @@ pub fn load_commodity_graphs<P: AsRef<Path>>(
284287
years,
285288
)?;
286289

287-
let commodity_graphs = build_commodity_graphs_for_model(&processes, &region_ids, years)?;
290+
let commodity_graphs = build_commodity_graphs_for_model(&processes, &region_ids, years);
288291
Ok(commodity_graphs)
289292
}
290293

0 commit comments

Comments
 (0)