|
3 | 3 | import nl.tudelft.dnainator.annotation.Annotation; |
4 | 4 | import nl.tudelft.dnainator.core.EnrichedSequenceNode; |
5 | 5 | import nl.tudelft.dnainator.core.impl.Cluster; |
| 6 | +import nl.tudelft.dnainator.graph.impl.Neo4jScoreContainer; |
6 | 7 | import nl.tudelft.dnainator.graph.impl.Neo4jSequenceNode; |
7 | 8 | import nl.tudelft.dnainator.graph.impl.NodeLabels; |
8 | 9 | import nl.tudelft.dnainator.graph.impl.RelTypes; |
| 10 | +import nl.tudelft.dnainator.graph.impl.properties.BubbleProperties; |
9 | 11 | import nl.tudelft.dnainator.graph.impl.properties.SequenceProperties; |
10 | 12 | import nl.tudelft.dnainator.graph.interestingness.InterestingnessStrategy; |
11 | 13 |
|
| 14 | +import org.neo4j.graphalgo.GraphAlgoFactory; |
| 15 | +import org.neo4j.graphalgo.PathFinder; |
12 | 16 | import org.neo4j.graphdb.Direction; |
13 | 17 | import org.neo4j.graphdb.GraphDatabaseService; |
14 | 18 | import org.neo4j.graphdb.Node; |
| 19 | +import org.neo4j.graphdb.Path; |
| 20 | +import org.neo4j.graphdb.PathExpanders; |
| 21 | +import org.neo4j.graphdb.traversal.Evaluation; |
15 | 22 | import org.neo4j.graphdb.traversal.TraversalDescription; |
| 23 | +import org.neo4j.helpers.collection.IteratorUtil; |
16 | 24 |
|
17 | 25 | import java.util.ArrayList; |
18 | | -import java.util.HashMap; |
| 26 | +import java.util.Collection; |
| 27 | +import java.util.Collections; |
19 | 28 | import java.util.HashSet; |
20 | | -import java.util.LinkedList; |
21 | 29 | import java.util.List; |
22 | 30 | import java.util.Map; |
23 | | -import java.util.PriorityQueue; |
24 | | -import java.util.Queue; |
25 | 31 | import java.util.Set; |
26 | 32 | import java.util.stream.Collectors; |
| 33 | +import java.util.stream.Stream; |
| 34 | +import java.util.stream.StreamSupport; |
27 | 35 |
|
28 | 36 | /** |
29 | 37 | * The {@link AllClustersQuery} creates {@link Cluster}s from all nodes, |
30 | | - * starting at the startNodes, and ending when the maximum specified start rank is reached. |
| 38 | + * between the given ranks using the given threshold value. |
31 | 39 | */ |
32 | 40 | public class AllClustersQuery implements Query<Map<Integer, List<Cluster>>> { |
33 | | - private Set<String> visited; |
34 | | - private List<String> startNodes; |
35 | | - private int threshold; |
| 41 | + private int minRank; |
36 | 42 | private int maxRank; |
| 43 | + private int threshold; |
37 | 44 | private InterestingnessStrategy is; |
38 | 45 |
|
39 | 46 | /** |
40 | 47 | * Create a new {@link AllClustersQuery}, which will:. |
41 | | - * - start clustering at the specified startNodes |
42 | | - * - stop clustering when the end rank is reached |
| 48 | + * - get all clusters between the given ranks |
43 | 49 | * - use the specified clustering threshold |
44 | | - * @param startNodes the start nodes |
| 50 | + * @param minRank the minimum rank |
45 | 51 | * @param maxRank the maximum rank |
46 | 52 | * @param threshold the clustering threshold |
47 | 53 | * @param is the interestingness strategy, which determines how the |
48 | 54 | * interestingness score is calculated. |
49 | 55 | */ |
50 | | - public AllClustersQuery(List<String> startNodes, int maxRank, int threshold, |
| 56 | + public AllClustersQuery(int minRank, int maxRank, int threshold, |
51 | 57 | InterestingnessStrategy is) { |
52 | | - this.startNodes = startNodes; |
| 58 | + this.minRank = minRank; |
53 | 59 | this.maxRank = maxRank; |
54 | 60 | this.threshold = threshold; |
55 | | - this.visited = new HashSet<>(); |
56 | 61 | this.is = is; |
57 | 62 | } |
58 | 63 |
|
| 64 | + private TraversalDescription untilMaxRank(GraphDatabaseService service) { |
| 65 | + return service.traversalDescription() |
| 66 | + .breadthFirst() |
| 67 | + .evaluator(path -> { |
| 68 | + if ((int) path.endNode().getProperty(SequenceProperties.RANK.name()) |
| 69 | + <= maxRank) { |
| 70 | + return Evaluation.INCLUDE_AND_CONTINUE; |
| 71 | + } else { |
| 72 | + return Evaluation.EXCLUDE_AND_PRUNE; |
| 73 | + } |
| 74 | + }) |
| 75 | + .relationships(RelTypes.NEXT, Direction.OUTGOING); |
| 76 | + } |
| 77 | + |
59 | 78 | @Override |
60 | 79 | public Map<Integer, List<Cluster>> execute(GraphDatabaseService service) { |
61 | | - Queue<Cluster> rootClusters = new PriorityQueue<>((e1, e2) -> |
62 | | - e1.getStartRank() - e2.getStartRank() |
63 | | - ); |
64 | | - Map<Integer, List<Cluster>> result = new HashMap<Integer, List<Cluster>>(); |
65 | | - |
66 | | - rootClusters.addAll(clustersFrom(service, startNodes)); |
67 | | - |
68 | | - // Find adjacent clusters as long as there are root clusters in the queue |
69 | | - int minrank = rootClusters.stream().mapToInt(e -> e.getStartRank()).min().orElse(0); |
70 | | - while (!rootClusters.isEmpty()) { |
71 | | - Cluster c = rootClusters.poll(); |
72 | | - if (c.getStartRank() < minrank || c.getStartRank() > maxRank) { |
73 | | - continue; |
| 80 | + Set<Long> bubbleSourcesToCluster = new HashSet<>(); |
| 81 | + Set<Long> bubbleSourcesToKeepIntact = new HashSet<>(); |
| 82 | + Iterable<Node> start = IteratorUtil.loop(service.findNodes(NodeLabels.NODE, |
| 83 | + SequenceProperties.RANK.name(), minRank)); |
| 84 | + for (Node n : untilMaxRank(service).traverse(start).nodes()) { |
| 85 | + if (n.hasLabel(NodeLabels.BUBBLE_SOURCE)) { |
| 86 | + bubbleSourcesToCluster.add(n.getId()); |
| 87 | + } |
| 88 | + int interestingness = is.compute(new Neo4jScoreContainer(n)); |
| 89 | + if (interestingness > threshold) { |
| 90 | + for (long sourceID |
| 91 | + : (long[]) n.getProperty(BubbleProperties.BUBBLE_SOURCE_IDS.name())) { |
| 92 | + bubbleSourcesToKeepIntact.add(sourceID); |
| 93 | + bubbleSourcesToCluster.remove(sourceID); |
| 94 | + } |
74 | 95 | } |
75 | | - result.putIfAbsent(c.getStartRank(), new ArrayList<>()); |
76 | | - result.get(c.getStartRank()).add(c); |
77 | | - |
78 | | - c.getNodes().forEach(sn -> { |
79 | | - rootClusters.addAll(clustersFrom(service, sn.getOutgoing())); |
80 | | - }); |
81 | 96 | } |
| 97 | + return cluster(service, bubbleSourcesToCluster, bubbleSourcesToKeepIntact); |
| 98 | + } |
82 | 99 |
|
83 | | - return result; |
| 100 | + private Map<Integer, List<Cluster>> cluster(GraphDatabaseService service, |
| 101 | + Set<Long> bubbleSourcesToCluster, Set<Long> bubbleSourcesToKeepIntact) { |
| 102 | + Map<Integer, List<Cluster>> bubblesClustered = bubbleSourcesToCluster.stream() |
| 103 | + .map(service::getNodeById) |
| 104 | + .map(source -> collapseBubble(service, source, getSinkFromSource(source))) |
| 105 | + .collect(Collectors.groupingBy(Cluster::getStartRank)); |
| 106 | + Stream<Map<Integer, List<Cluster>>> singletonClusters = bubbleSourcesToKeepIntact.stream() |
| 107 | + .map(service::getNodeById) |
| 108 | + .map(source -> getSingletonClusters(service, source, getSinkFromSource(source))); |
| 109 | + return mergeMaps(Stream.concat(Stream.of(bubblesClustered), singletonClusters)); |
84 | 110 | } |
85 | 111 |
|
86 | | - private Queue<Cluster> clustersFrom(GraphDatabaseService service, List<String> startNodes) { |
87 | | - Queue<Cluster> rootClusters = new LinkedList<Cluster>(); |
| 112 | + private static Node getSinkFromSource(Node source) { |
| 113 | + return source.getSingleRelationship(RelTypes.BUBBLE_SOURCE_OF, Direction.OUTGOING) |
| 114 | + .getEndNode(); |
| 115 | + } |
88 | 116 |
|
89 | | - for (String sn : startNodes) { |
90 | | - // Continue if this startNode was consumed by another cluster |
91 | | - if (visited.contains(sn)) { |
92 | | - continue; |
93 | | - } |
| 117 | + private Map<Integer, List<Cluster>> getSingletonClusters(GraphDatabaseService service, |
| 118 | + Node source, Node sink) { |
| 119 | + int sourceRank = (int) source.getProperty(SequenceProperties.RANK.name()); |
| 120 | + int sinkRank = (int) sink.getProperty(SequenceProperties.RANK.name()); |
| 121 | + PathFinder<Path> withinBubble = pathFinderBetweenRanks(sourceRank, sinkRank); |
| 122 | + return stream(withinBubble.findAllPaths(source, sink)) |
| 123 | + .flatMap(path -> stream(path.nodes())) |
| 124 | + .distinct() |
| 125 | + .map(n -> createSingletonCluster(service, n)) |
| 126 | + .collect(Collectors.groupingBy(Cluster::getStartRank)); |
| 127 | + } |
94 | 128 |
|
95 | | - // Otherwise get the cluster starting from this startNode |
96 | | - rootClusters.add(cluster(service, sn)); |
97 | | - } |
| 129 | + private Cluster createSingletonCluster(GraphDatabaseService service, Node n) { |
| 130 | + EnrichedSequenceNode sn = new Neo4jSequenceNode(service, n); |
| 131 | + return new Cluster((int) n.getProperty(SequenceProperties.RANK.name()), |
| 132 | + Collections.singletonList(sn), sn.getAnnotations()); |
| 133 | + } |
98 | 134 |
|
99 | | - return rootClusters; |
| 135 | + private Cluster collapseBubble(GraphDatabaseService service, Node source, Node sink) { |
| 136 | + int sourceRank = (int) source.getProperty(SequenceProperties.RANK.name()); |
| 137 | + int sinkRank = (int) sink.getProperty(SequenceProperties.RANK.name()); |
| 138 | + int clusterRank = sourceRank + (sinkRank - sourceRank) / 2; |
| 139 | + PathFinder<Path> withinBubble = pathFinderBetweenRanks(sourceRank, sinkRank); |
| 140 | + // FIXME: don't collapse source and sink, keep those intact. |
| 141 | + List<EnrichedSequenceNode> nodes = stream( |
| 142 | + withinBubble.findAllPaths(source, sink)) |
| 143 | + .flatMap(path -> stream(path.nodes())) |
| 144 | + .distinct() |
| 145 | + .map(n -> new Neo4jSequenceNode(service, n)) |
| 146 | + .collect(Collectors.toList()); |
| 147 | + List<Annotation> annotations = nodes.stream() |
| 148 | + .flatMap(e -> e.getAnnotations().stream()) |
| 149 | + .collect(Collectors.toList()); |
| 150 | + return new Cluster(clusterRank, nodes, annotations); |
100 | 151 | } |
101 | 152 |
|
102 | | - private Cluster cluster(GraphDatabaseService service, String start) { |
103 | | - Cluster cluster = null; |
104 | | - Node startNode = service.findNode(NodeLabels.NODE, SequenceProperties.ID.name(), start); |
105 | | - List<Node> result = new ArrayList<>(); |
| 153 | + private PathFinder<Path> pathFinderBetweenRanks(int minRank, int maxRank) { |
| 154 | + return GraphAlgoFactory.allSimplePaths( |
| 155 | + PathExpanders.forTypeAndDirection(RelTypes.NEXT, Direction.OUTGOING), |
| 156 | + maxRank - minRank); |
| 157 | + } |
106 | 158 |
|
107 | | - // A depth first traversal traveling along both incoming and outgoing edges. |
108 | | - TraversalDescription clusterDesc = service.traversalDescription() |
109 | | - .depthFirst() |
110 | | - .relationships(RelTypes.NEXT, Direction.BOTH) |
111 | | - .evaluator(new ClusterEvaluator(threshold, visited, is)); |
112 | | - // Traverse the cluster starting from the startNode. |
113 | | - int rankStart = (int) startNode.getProperty(SequenceProperties.RANK.name()); |
114 | | - for (Node end : clusterDesc.traverse(startNode).nodes()) { |
115 | | - result.add(end); |
| 159 | + private Map<Integer, List<Cluster>> mergeMaps(Stream<Map<Integer, List<Cluster>>> concat) { |
| 160 | + return concat.map(Map::entrySet) |
| 161 | + .flatMap(Collection::stream) |
| 162 | + .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue(), (left, right) -> { |
| 163 | + List<Cluster> newList = new ArrayList<>(left.size() + right.size()); |
| 164 | + newList.addAll(right); |
| 165 | + newList.addAll(left); |
| 166 | + return left; |
| 167 | + })); |
| 168 | + } |
116 | 169 |
|
117 | | - // Update this cluster's start rank according to the lowest node rank. |
118 | | - int endRank = (int) startNode.getProperty(SequenceProperties.RANK.name()); |
119 | | - if (endRank < rankStart) { |
120 | | - rankStart = endRank; |
121 | | - } |
122 | | - } |
123 | | - // Might want to internally pass nodes. |
124 | | - List<EnrichedSequenceNode> retrieve = result.stream() |
125 | | - .map(e -> new Neo4jSequenceNode(service, e)) |
126 | | - .collect(Collectors.toList()); |
127 | | - List<Annotation> annotations = retrieve.stream().flatMap(e -> e.getAnnotations().stream()) |
128 | | - .collect(Collectors.toList()); |
129 | | - cluster = new Cluster(rankStart, retrieve, annotations); |
130 | | - return cluster; |
| 170 | + private static <T> Stream<T> stream(Iterable<T> in) { |
| 171 | + // Quick utility method, for converting iterables to streams. |
| 172 | + return StreamSupport.stream(in.spliterator(), false); |
131 | 173 | } |
132 | 174 | } |
0 commit comments