Skip to content

Commit 883ea9d

Browse files
committed
WIP: Initial clustering implementation for bubbles
1 parent c407dc0 commit 883ea9d

6 files changed

Lines changed: 128 additions & 150 deletions

File tree

dnainator-core/src/main/java/nl/tudelft/dnainator/graph/Graph.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,12 @@ public interface Graph extends AnnotationCollection {
6969

7070
/**
7171
* Return a list of nodes that belong to the same cluster as the given startId.
72-
* @param startNodes the start nodes
72+
* @param start the start nodes
7373
* @param end the maximum rank of the cluster
7474
* @param threshold the clustering threshold
7575
* @return a list representing the cluster
7676
*/
77-
Map<Integer, List<Cluster>> getAllClusters(List<String> startNodes, int end, int threshold);
77+
Map<Integer, List<Cluster>> getAllClusters(int start, int end, int threshold);
7878

7979
/**
8080
* Sets the interestingness strategy which calculates the interestingness when

dnainator-core/src/main/java/nl/tudelft/dnainator/graph/impl/Neo4jGraph.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,9 @@ public int getRankFromBasePair(int base) {
171171
}
172172

173173
@Override
174-
public Map<Integer, List<Cluster>> getAllClusters(List<String> startNodes,
174+
public Map<Integer, List<Cluster>> getAllClusters(int start,
175175
int end, int threshold) {
176-
return query(new AllClustersQuery(startNodes, end, threshold, is));
176+
return query(new AllClustersQuery(start, end, threshold, is));
177177
}
178178

179179
@Override

dnainator-core/src/main/java/nl/tudelft/dnainator/graph/impl/command/TopologicalPathExpander.java

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,22 +53,21 @@ private boolean hasUnprocessedIncoming(Node n) {
5353
public Iterable<Relationship> expand(Path path,
5454
BranchState<Object> noState) {
5555
Node from = path.endNode();
56-
5756
// Propagate all unclosed bubbles and the newly created ones.
5857
Set<Long> toPropagate = getSourcesToPropagate(from);
5958

6059
// For each unclosed bubble source, remove the current node from the endings and
6160
// add outgoing nodes to the ending nodes, thereby advancing the bubble endings.
6261
toPropagate.forEach(e -> advanceEnds(e, from));
63-
62+
// Store in this node the bubbles in which it is nested.
63+
storeOuterBubbles(from, toPropagate);
6464
// Create a new bubblesource, that will have its own bubble endings.
6565
createBubbleSource(from, toPropagate);
6666

6767
// Encode the unclosed propagated bubbles on the edges.
6868
from.getRelationships(RelTypes.NEXT, Direction.OUTGOING)
6969
.forEach(out -> propagateSourceIDs(toPropagate, out));
7070

71-
// Process all outgoing edges.
7271
List<Relationship> expand = new LinkedList<>();
7372
for (Relationship out : from.getRelationships(RelTypes.NEXT, Direction.OUTGOING)) {
7473
setNumStrainsThrough(out);
@@ -82,6 +81,13 @@ public Iterable<Relationship> expand(Path path,
8281
return expand;
8382
}
8483

84+
private void storeOuterBubbles(Node from, Set<Long> toPropagate) {
85+
// Set the source id of the bubbles to which this node belongs. Excludes its own
86+
// source id if it's a source.
87+
from.setProperty(BubbleProperties.BUBBLE_SOURCE_IDS.name(),
88+
toPropagate.stream().mapToLong(l -> l).toArray());
89+
}
90+
8591
private Set<Long> getSourcesToPropagate(Node from) {
8692
Iterable<Relationship> ins = from.getRelationships(RelTypes.NEXT, Direction.INCOMING);
8793

@@ -121,8 +127,6 @@ private void createBubbleSource(Node n, Set<Long> toPropagate) {
121127
}
122128

123129
private void propagateSourceIDs(Set<Long> propagatedUnique, Relationship out) {
124-
out.setProperty(BubbleProperties.BUBBLE_SOURCE_IDS.name(),
125-
propagatedUnique.stream().mapToLong(l -> l).toArray());
126130
relIDtoSourceIDs.put(out.getId(), propagatedUnique);
127131
}
128132

dnainator-core/src/main/java/nl/tudelft/dnainator/graph/impl/query/AllClustersQuery.java

Lines changed: 114 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -3,130 +3,172 @@
33
import nl.tudelft.dnainator.annotation.Annotation;
44
import nl.tudelft.dnainator.core.EnrichedSequenceNode;
55
import nl.tudelft.dnainator.core.impl.Cluster;
6+
import nl.tudelft.dnainator.graph.impl.Neo4jScoreContainer;
67
import nl.tudelft.dnainator.graph.impl.Neo4jSequenceNode;
78
import nl.tudelft.dnainator.graph.impl.NodeLabels;
89
import nl.tudelft.dnainator.graph.impl.RelTypes;
10+
import nl.tudelft.dnainator.graph.impl.properties.BubbleProperties;
911
import nl.tudelft.dnainator.graph.impl.properties.SequenceProperties;
1012
import nl.tudelft.dnainator.graph.interestingness.InterestingnessStrategy;
1113

14+
import org.neo4j.graphalgo.GraphAlgoFactory;
15+
import org.neo4j.graphalgo.PathFinder;
1216
import org.neo4j.graphdb.Direction;
1317
import org.neo4j.graphdb.GraphDatabaseService;
1418
import org.neo4j.graphdb.Node;
19+
import org.neo4j.graphdb.Path;
20+
import org.neo4j.graphdb.PathExpanders;
21+
import org.neo4j.graphdb.traversal.Evaluation;
1522
import org.neo4j.graphdb.traversal.TraversalDescription;
23+
import org.neo4j.helpers.collection.IteratorUtil;
1624

1725
import java.util.ArrayList;
18-
import java.util.HashMap;
26+
import java.util.Collection;
27+
import java.util.Collections;
1928
import java.util.HashSet;
20-
import java.util.LinkedList;
2129
import java.util.List;
2230
import java.util.Map;
23-
import java.util.PriorityQueue;
24-
import java.util.Queue;
2531
import java.util.Set;
2632
import java.util.stream.Collectors;
33+
import java.util.stream.Stream;
34+
import java.util.stream.StreamSupport;
2735

2836
/**
2937
* The {@link AllClustersQuery} creates {@link Cluster}s from all nodes,
30-
* starting at the startNodes, and ending when the maximum specified start rank is reached.
38+
* between the given ranks using the given threshold value.
3139
*/
3240
public class AllClustersQuery implements Query<Map<Integer, List<Cluster>>> {
33-
private Set<String> visited;
34-
private List<String> startNodes;
35-
private int threshold;
41+
private int minRank;
3642
private int maxRank;
43+
private int threshold;
3744
private InterestingnessStrategy is;
3845

3946
/**
4047
* Create a new {@link AllClustersQuery}, which will:.
41-
* - start clustering at the specified startNodes
42-
* - stop clustering when the end rank is reached
48+
* - get all clusters between the given ranks
4349
* - use the specified clustering threshold
44-
* @param startNodes the start nodes
50+
* @param minRank the minimum rank
4551
* @param maxRank the maximum rank
4652
* @param threshold the clustering threshold
4753
* @param is the interestingness strategy, which determines how the
4854
* interestingness score is calculated.
4955
*/
50-
public AllClustersQuery(List<String> startNodes, int maxRank, int threshold,
56+
public AllClustersQuery(int minRank, int maxRank, int threshold,
5157
InterestingnessStrategy is) {
52-
this.startNodes = startNodes;
58+
this.minRank = minRank;
5359
this.maxRank = maxRank;
5460
this.threshold = threshold;
55-
this.visited = new HashSet<>();
5661
this.is = is;
5762
}
5863

64+
private TraversalDescription untilMaxRank(GraphDatabaseService service) {
65+
return service.traversalDescription()
66+
.breadthFirst()
67+
.evaluator(path -> {
68+
if ((int) path.endNode().getProperty(SequenceProperties.RANK.name())
69+
<= maxRank) {
70+
return Evaluation.INCLUDE_AND_CONTINUE;
71+
} else {
72+
return Evaluation.EXCLUDE_AND_PRUNE;
73+
}
74+
})
75+
.relationships(RelTypes.NEXT, Direction.OUTGOING);
76+
}
77+
5978
@Override
6079
public Map<Integer, List<Cluster>> execute(GraphDatabaseService service) {
61-
Queue<Cluster> rootClusters = new PriorityQueue<>((e1, e2) ->
62-
e1.getStartRank() - e2.getStartRank()
63-
);
64-
Map<Integer, List<Cluster>> result = new HashMap<Integer, List<Cluster>>();
65-
66-
rootClusters.addAll(clustersFrom(service, startNodes));
67-
68-
// Find adjacent clusters as long as there are root clusters in the queue
69-
int minrank = rootClusters.stream().mapToInt(e -> e.getStartRank()).min().orElse(0);
70-
while (!rootClusters.isEmpty()) {
71-
Cluster c = rootClusters.poll();
72-
if (c.getStartRank() < minrank || c.getStartRank() > maxRank) {
73-
continue;
80+
Set<Long> bubbleSourcesToCluster = new HashSet<>();
81+
Set<Long> bubbleSourcesToKeepIntact = new HashSet<>();
82+
Iterable<Node> start = IteratorUtil.loop(service.findNodes(NodeLabels.NODE,
83+
SequenceProperties.RANK.name(), minRank));
84+
for (Node n : untilMaxRank(service).traverse(start).nodes()) {
85+
if (n.hasLabel(NodeLabels.BUBBLE_SOURCE)) {
86+
bubbleSourcesToCluster.add(n.getId());
87+
}
88+
int interestingness = is.compute(new Neo4jScoreContainer(n));
89+
if (interestingness > threshold) {
90+
for (long sourceID
91+
: (long[]) n.getProperty(BubbleProperties.BUBBLE_SOURCE_IDS.name())) {
92+
bubbleSourcesToKeepIntact.add(sourceID);
93+
bubbleSourcesToCluster.remove(sourceID);
94+
}
7495
}
75-
result.putIfAbsent(c.getStartRank(), new ArrayList<>());
76-
result.get(c.getStartRank()).add(c);
77-
78-
c.getNodes().forEach(sn -> {
79-
rootClusters.addAll(clustersFrom(service, sn.getOutgoing()));
80-
});
8196
}
97+
return cluster(service, bubbleSourcesToCluster, bubbleSourcesToKeepIntact);
98+
}
8299

83-
return result;
100+
private Map<Integer, List<Cluster>> cluster(GraphDatabaseService service,
101+
Set<Long> bubbleSourcesToCluster, Set<Long> bubbleSourcesToKeepIntact) {
102+
Map<Integer, List<Cluster>> bubblesClustered = bubbleSourcesToCluster.stream()
103+
.map(service::getNodeById)
104+
.map(source -> collapseBubble(service, source, getSinkFromSource(source)))
105+
.collect(Collectors.groupingBy(Cluster::getStartRank));
106+
Stream<Map<Integer, List<Cluster>>> singletonClusters = bubbleSourcesToKeepIntact.stream()
107+
.map(service::getNodeById)
108+
.map(source -> getSingletonClusters(service, source, getSinkFromSource(source)));
109+
return mergeMaps(Stream.concat(Stream.of(bubblesClustered), singletonClusters));
84110
}
85111

86-
private Queue<Cluster> clustersFrom(GraphDatabaseService service, List<String> startNodes) {
87-
Queue<Cluster> rootClusters = new LinkedList<Cluster>();
112+
private static Node getSinkFromSource(Node source) {
113+
return source.getSingleRelationship(RelTypes.BUBBLE_SOURCE_OF, Direction.OUTGOING)
114+
.getEndNode();
115+
}
88116

89-
for (String sn : startNodes) {
90-
// Continue if this startNode was consumed by another cluster
91-
if (visited.contains(sn)) {
92-
continue;
93-
}
117+
private Map<Integer, List<Cluster>> getSingletonClusters(GraphDatabaseService service,
118+
Node source, Node sink) {
119+
int sourceRank = (int) source.getProperty(SequenceProperties.RANK.name());
120+
int sinkRank = (int) sink.getProperty(SequenceProperties.RANK.name());
121+
PathFinder<Path> withinBubble = pathFinderBetweenRanks(sourceRank, sinkRank);
122+
return stream(withinBubble.findAllPaths(source, sink))
123+
.flatMap(path -> stream(path.nodes()))
124+
.distinct()
125+
.map(n -> createSingletonCluster(service, n))
126+
.collect(Collectors.groupingBy(Cluster::getStartRank));
127+
}
94128

95-
// Otherwise get the cluster starting from this startNode
96-
rootClusters.add(cluster(service, sn));
97-
}
129+
private Cluster createSingletonCluster(GraphDatabaseService service, Node n) {
130+
EnrichedSequenceNode sn = new Neo4jSequenceNode(service, n);
131+
return new Cluster((int) n.getProperty(SequenceProperties.RANK.name()),
132+
Collections.singletonList(sn), sn.getAnnotations());
133+
}
98134

99-
return rootClusters;
135+
private Cluster collapseBubble(GraphDatabaseService service, Node source, Node sink) {
136+
int sourceRank = (int) source.getProperty(SequenceProperties.RANK.name());
137+
int sinkRank = (int) sink.getProperty(SequenceProperties.RANK.name());
138+
int clusterRank = sourceRank + (sinkRank - sourceRank) / 2;
139+
PathFinder<Path> withinBubble = pathFinderBetweenRanks(sourceRank, sinkRank);
140+
// FIXME: don't collapse source and sink, keep those intact.
141+
List<EnrichedSequenceNode> nodes = stream(
142+
withinBubble.findAllPaths(source, sink))
143+
.flatMap(path -> stream(path.nodes()))
144+
.distinct()
145+
.map(n -> new Neo4jSequenceNode(service, n))
146+
.collect(Collectors.toList());
147+
List<Annotation> annotations = nodes.stream()
148+
.flatMap(e -> e.getAnnotations().stream())
149+
.collect(Collectors.toList());
150+
return new Cluster(clusterRank, nodes, annotations);
100151
}
101152

102-
private Cluster cluster(GraphDatabaseService service, String start) {
103-
Cluster cluster = null;
104-
Node startNode = service.findNode(NodeLabels.NODE, SequenceProperties.ID.name(), start);
105-
List<Node> result = new ArrayList<>();
153+
private PathFinder<Path> pathFinderBetweenRanks(int minRank, int maxRank) {
154+
return GraphAlgoFactory.allSimplePaths(
155+
PathExpanders.forTypeAndDirection(RelTypes.NEXT, Direction.OUTGOING),
156+
maxRank - minRank);
157+
}
106158

107-
// A depth first traversal traveling along both incoming and outgoing edges.
108-
TraversalDescription clusterDesc = service.traversalDescription()
109-
.depthFirst()
110-
.relationships(RelTypes.NEXT, Direction.BOTH)
111-
.evaluator(new ClusterEvaluator(threshold, visited, is));
112-
// Traverse the cluster starting from the startNode.
113-
int rankStart = (int) startNode.getProperty(SequenceProperties.RANK.name());
114-
for (Node end : clusterDesc.traverse(startNode).nodes()) {
115-
result.add(end);
159+
private Map<Integer, List<Cluster>> mergeMaps(Stream<Map<Integer, List<Cluster>>> concat) {
160+
return concat.map(Map::entrySet)
161+
.flatMap(Collection::stream)
162+
.collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue(), (left, right) -> {
163+
List<Cluster> newList = new ArrayList<>(left.size() + right.size());
164+
newList.addAll(right);
165+
newList.addAll(left);
166+
return left;
167+
}));
168+
}
116169

117-
// Update this cluster's start rank according to the lowest node rank.
118-
int endRank = (int) startNode.getProperty(SequenceProperties.RANK.name());
119-
if (endRank < rankStart) {
120-
rankStart = endRank;
121-
}
122-
}
123-
// Might want to internally pass nodes.
124-
List<EnrichedSequenceNode> retrieve = result.stream()
125-
.map(e -> new Neo4jSequenceNode(service, e))
126-
.collect(Collectors.toList());
127-
List<Annotation> annotations = retrieve.stream().flatMap(e -> e.getAnnotations().stream())
128-
.collect(Collectors.toList());
129-
cluster = new Cluster(rankStart, retrieve, annotations);
130-
return cluster;
170+
private static <T> Stream<T> stream(Iterable<T> in) {
171+
// Quick utility method, for converting iterables to streams.
172+
return StreamSupport.stream(in.spliterator(), false);
131173
}
132174
}

dnainator-core/src/main/java/nl/tudelft/dnainator/graph/impl/query/ClusterEvaluator.java

Lines changed: 0 additions & 66 deletions
This file was deleted.

dnainator-javafx/src/main/java/nl/tudelft/dnainator/javafx/drawables/strains/Strain.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,8 @@ public void loadChildren(Bounds bounds) {
8181
Range ranks = getRange(bounds);
8282

8383
System.out.println("load iteration: " + ranks.getX() + " -> " + ranks.getY());
84-
List<String> roots = graph.getRank(ranks.getX()).stream()
85-
.map(SequenceNode::getId).collect(Collectors.toList());
8684
List<Annotation> annotations = getSortedAnnotations(ranks);
87-
Map<Integer, List<Cluster>> result = graph.getAllClusters(roots, ranks.getY(),
85+
Map<Integer, List<Cluster>> result = graph.getAllClusters(ranks.getX(), ranks.getY(),
8886
(int) (bounds.getWidth() / CLUSTER_DIVIDER));
8987
clusters.clear();
9088
childContent.getChildren().clear();

0 commit comments

Comments
 (0)