Skip to content

Commit cfa2b5d

Browse files
committed
Curveball algorithm
1 parent d99e323 commit cfa2b5d

2 files changed

Lines changed: 223 additions & 6 deletions

File tree

nodes/src/main/java/org/nodes/models/USequenceEstimator.java

Lines changed: 188 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77
import static java.lang.Math.pow;
88
import static java.lang.Math.sqrt;
99
import static nl.peterbloem.kit.Functions.choose;
10+
import static nl.peterbloem.kit.Functions.concat;
1011
import static nl.peterbloem.kit.Functions.exp2;
1112
import static nl.peterbloem.kit.Functions.log2;
1213
import static nl.peterbloem.kit.Functions.log2Min;
1314
import static nl.peterbloem.kit.Functions.log2Sum;
15+
import static nl.peterbloem.kit.Functions.sampleInts;
1416
import static nl.peterbloem.kit.LogNum.fromDouble;
1517
import static nl.peterbloem.kit.Pair.first;
1618
import static nl.peterbloem.kit.Series.series;
@@ -24,23 +26,31 @@
2426
import java.util.Collections;
2527
import java.util.Comparator;
2628
import java.util.HashSet;
29+
import java.util.Iterator;
2730
import java.util.LinkedHashSet;
31+
import java.util.LinkedList;
2832
import java.util.List;
2933
import java.util.PriorityQueue;
3034
import java.util.Set;
3135
import java.util.Vector;
3236
import java.util.concurrent.Callable;
3337
import java.util.concurrent.ExecutorService;
3438

39+
import org.apache.commons.math3.distribution.BinomialDistribution;
3540
import org.apache.commons.math3.distribution.NormalDistribution;
3641
import org.apache.commons.math3.distribution.TDistribution;
3742
import org.nodes.DGraph;
3843
import org.nodes.DNode;
3944
import org.nodes.Graph;
45+
import org.nodes.LightUGraph;
46+
import org.nodes.Link;
4047
import org.nodes.MapUTGraph;
4148
import org.nodes.Node;
4249
import org.nodes.UGraph;
50+
import org.nodes.ULink;
51+
import org.nodes.UNode;
4352

53+
import nl.peterbloem.kit.AbstractGenerator;
4454
import nl.peterbloem.kit.Functions;
4555
import nl.peterbloem.kit.Generator;
4656
import nl.peterbloem.kit.Global;
@@ -82,7 +92,12 @@ public USequenceEstimator(Graph<?> data, int samples)
8292

8393
public USequenceEstimator(Graph<?> data)
8494
{
85-
95+
this(data, null);
96+
}
97+
98+
public USequenceEstimator(Graph<?> data, L label)
99+
{
100+
this.label = label;
86101
sequence = new ArrayList<Integer>(data.size());
87102

88103
for(Node<?> node : data.nodes())
@@ -97,8 +112,8 @@ public USequenceEstimator(List<Integer> sequence, int samples)
97112
for(int i : series(samples))
98113
{
99114
nonuniform();
100-
if(Functions.toc() > 10)
101-
System.out.println("\r " + logSamples.size() + " samples completed");
115+
if(Functions.toc() > 10 && i % (samples/100) == 0)
116+
Global.log().info(logSamples.size() + " samples completed");
102117
}
103118
}
104119

@@ -1085,12 +1100,179 @@ public String toString()
10851100
return graph + ", (c=" + c + ", sigma=" + sigma
10861101
+ ")";
10871102
}
1103+
}
1104+
1105+
/**
1106+
* Returns a generator for uniform samples.
1107+
*
1108+
* Uses the curveball algorithm for undirected simple graphs, see
1109+
* https://researchbank.rmit.edu.au/view/rmit:161573
1110+
* (chapter 4.2)
1111+
*
1112+
* @return
1113+
*/
1114+
public Generator<UGraph<L>> uniform(int mixingTime)
1115+
{
1116+
return new UniformGenerator(nonuniform().graph(), mixingTime);
1117+
}
1118+
1119+
private class UniformGenerator extends AbstractGenerator<UGraph<L>>
1120+
{
1121+
private List<Set<Integer>> adjacencies;
1122+
public int mixTime;
10881123

1124+
public UniformGenerator(UGraph<L> start, int mixingTime)
1125+
{
1126+
// * Extract an adjacency-list representation from the starting graph
1127+
adjacencies = adjacencies(start);
1128+
this.mixTime = mixingTime;
1129+
}
1130+
1131+
@Override
1132+
public UGraph<L> generate()
1133+
{
1134+
1135+
for(int i : series(mixTime))
1136+
{
1137+
step(adjacencies);
1138+
}
1139+
1140+
return graph();
1141+
}
10891142

1143+
/**
1144+
* Convert the adjacency lists to a graph
1145+
* @return
1146+
*/
1147+
private UGraph<L> graph()
1148+
{
1149+
// TODO: We can speed this up by filling the adjacency lists
1150+
// inside the LightUGraph directly
1151+
UGraph<L> graph = new LightUGraph<L>(adjacencies.size());
1152+
1153+
for(int i : series(adjacencies.size()))
1154+
graph.add(label);
1155+
1156+
for(int i : series(adjacencies.size()))
1157+
{
1158+
UNode<L> node = graph.get(i);
1159+
Set<Integer> indices = adjacencies.get(i);
1160+
1161+
for(int ind : indices)
1162+
if(ind > i)
1163+
node.connect(graph.get(ind));
1164+
}
1165+
1166+
return graph;
1167+
}
10901168
}
10911169

1092-
public Generator<UGraph<L>> uniform()
1170+
1171+
/**
1172+
* Perturbation score
1173+
* @param one
1174+
* @param two
1175+
* @return
1176+
*/
1177+
public static double perturbation(List<Set<Integer>> one, List<Set<Integer>> two)
10931178
{
1094-
return null;
1095-
}
1179+
int total = 0;
1180+
int m = 0;
1181+
1182+
for(int i : series(one.size()))
1183+
{
1184+
total += Functions.overlap(one.get(i), two.get(i));
1185+
m += one.get(i).size();
1186+
}
1187+
1188+
return 1.0 - (total/2)/(double)(m/2);
1189+
}
1190+
1191+
public static void step(List<Set<Integer>> adjacencies)
1192+
{
1193+
// * Randomly select two (distinct) sets
1194+
List<Integer> ind = sampleInts(2, adjacencies.size());
1195+
int oneInd = ind.get(0),
1196+
twoInd = ind.get(1);
1197+
1198+
Set<Integer> one = adjacencies.get(oneInd),
1199+
two = adjacencies.get(twoInd);
1200+
1201+
// * Filter out candidate swaps
1202+
List<Integer> oneCand = new ArrayList<Integer>(),
1203+
twoCand = new ArrayList<Integer>();
1204+
1205+
Iterator<Integer> itOne = one.iterator();
1206+
while(itOne.hasNext())
1207+
{
1208+
int index = itOne.next();
1209+
if(twoInd != index && ! two.contains(index))
1210+
{
1211+
oneCand.add(index);
1212+
itOne.remove();
1213+
}
1214+
}
1215+
1216+
Iterator<Integer> itTwo = two.iterator();
1217+
while(itTwo.hasNext())
1218+
{
1219+
int index = itTwo.next();
1220+
if(oneInd != index && ! one.contains(index))
1221+
{
1222+
twoCand.add(index);
1223+
itTwo.remove();
1224+
}
1225+
}
1226+
1227+
List<Integer> candidates = concat(oneCand, twoCand);
1228+
1229+
// - Remember the swaps
1230+
List<Integer> toOne = new ArrayList<Integer>(twoCand.size()); // came from two, went to one
1231+
List<Integer> toTwo = new ArrayList<Integer>(oneCand.size()); // came from one went to two
1232+
1233+
// * Add back randomly
1234+
Set<Integer> forOne = new LinkedHashSet<Integer>(sampleInts(oneCand.size(), candidates.size()));
1235+
1236+
for(int i : series(candidates.size()))
1237+
if(forOne.contains(i))
1238+
{
1239+
one.add(candidates.get(i));
1240+
if(i >= oneCand.size()) // swap, remember
1241+
toOne.add(candidates.get(i));
1242+
} else {
1243+
two.add(candidates.get(i));
1244+
if(i < oneCand.size()) // swap, remember
1245+
toTwo.add(candidates.get(i));
1246+
}
1247+
1248+
assert(toOne.size() == toTwo.size());
1249+
1250+
// * For each swap, perform the dual swap
1251+
for(int j : toOne)
1252+
{
1253+
adjacencies.get(j).remove(twoInd);
1254+
adjacencies.get(j).add(oneInd);
1255+
}
1256+
for(int k : toTwo)
1257+
{
1258+
adjacencies.get(k).remove(oneInd);
1259+
adjacencies.get(k).add(twoInd);
1260+
}
1261+
}
1262+
1263+
public static <L> List<Set<Integer>> adjacencies(UGraph<L> graph)
1264+
{
1265+
List<Set<Integer>> adjacencies = new ArrayList<Set<Integer>>(graph.size());
1266+
for(UNode<L> node : graph.nodes())
1267+
{
1268+
Set<Integer> set = new LinkedHashSet<Integer>();
1269+
1270+
for(UNode<L> neighbor : node.neighbors())
1271+
set.add(neighbor.index());
1272+
1273+
adjacencies.add(set);
1274+
}
1275+
1276+
return adjacencies;
1277+
}
10961278
}

nodes/src/test/java/org/nodes/models/USequenceModelTest.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import static java.lang.Math.sqrt;
44
import static java.util.Arrays.asList;
5+
import static nl.peterbloem.kit.Functions.dot;
56
import static nl.peterbloem.kit.Functions.exp2;
67
import static nl.peterbloem.kit.Functions.tic;
78
import static nl.peterbloem.kit.Functions.toc;
@@ -24,9 +25,12 @@
2425
import org.nodes.Graphs;
2526
import org.nodes.MapUTGraph;
2627
import org.nodes.Node;
28+
import org.nodes.UGraph;
2729
import org.nodes.random.RandomGraphs;
2830

31+
import nl.peterbloem.kit.FrequencyModel;
2932
import nl.peterbloem.kit.Functions;
33+
import nl.peterbloem.kit.Generator;
3034
import nl.peterbloem.kit.LogNum;
3135
import nl.peterbloem.kit.Pair;
3236
import nl.peterbloem.kit.Series;
@@ -451,5 +455,36 @@ public static LogNum l(double lMag, boolean pos)
451455
{
452456
return new LogNum(lMag, pos, 2.0);
453457
}
458+
459+
@Test
460+
public void testUniform()
461+
{
462+
final int SAMPLES = 6000;
463+
List<Integer> degrees = Arrays.asList(3,2,2,2,1);
464+
USequenceEstimator<String> model = new USequenceEstimator<String>(degrees, 10000);
465+
466+
// - the number of graphs
467+
double n = Math.round(Math.pow(2.0, model.logNormalMean()));
468+
469+
Generator<UGraph<String>> gen = model.uniform(1000);
470+
471+
FrequencyModel<UGraph<String>> fm = new FrequencyModel<UGraph<String>>();
472+
473+
474+
for(int i : series(SAMPLES))
475+
{
476+
fm.add(gen.generate());
477+
dot(i, SAMPLES);
478+
}
479+
480+
assertEquals(fm.distinct(), n, 0.00000000001);
481+
482+
for(UGraph<String> token : fm.tokens())
483+
{
484+
System.out.println(token + " " + fm.probability(token));
485+
assertEquals(1.0/n, fm.probability(token), 0.1);
486+
assertEquals(degrees, Graphs.degrees(token));
487+
}
488+
}
454489
}
455490

0 commit comments

Comments
 (0)