|
46 | 46 | import org.numenta.nupic.datagen.ResourceLocator; |
47 | 47 | import org.numenta.nupic.encoders.MultiEncoder; |
48 | 48 | import org.numenta.nupic.network.sensor.FileSensor; |
| 49 | +import org.numenta.nupic.network.sensor.HTMSensor; |
| 50 | +import org.numenta.nupic.network.sensor.ObservableSensor; |
| 51 | +import org.numenta.nupic.network.sensor.Publisher; |
49 | 52 | import org.numenta.nupic.network.sensor.Sensor; |
50 | 53 | import org.numenta.nupic.network.sensor.SensorParams; |
51 | 54 | import org.numenta.nupic.network.sensor.SensorParams.Keys; |
@@ -636,4 +639,129 @@ public void testThreadedStartFlagging() { |
636 | 639 | } |
637 | 640 | } |
638 | 641 |
|
| 642 | + double anomaly = 1; |
| 643 | + boolean completed = false; |
| 644 | + @Test |
| 645 | + public void testObservableWithCoordinateEncoder() { |
| 646 | + Publisher manual = Publisher.builder() |
| 647 | + .addHeader("timestamp,consumption,location") |
| 648 | + .addHeader("datetime,float,geo") |
| 649 | + .addHeader("T,,").build(); |
| 650 | + |
| 651 | + Sensor<ObservableSensor<String[]>> sensor = Sensor.create( |
| 652 | + ObservableSensor::create, SensorParams.create(Keys::obs, "", manual)); |
| 653 | + |
| 654 | + Parameters p = NetworkTestHarness.getParameters().copy(); |
| 655 | + p = p.union(NetworkTestHarness.getGeospatialTestEncoderParams()); |
| 656 | + p.setParameterByKey(KEY.RANDOM, new MersenneTwister(42)); |
| 657 | + |
| 658 | + HTMSensor<ObservableSensor<String[]>> htmSensor = (HTMSensor<ObservableSensor<String[]>>)sensor; |
| 659 | + |
| 660 | + Network network = Network.create("test network", p) |
| 661 | + .add(Network.createRegion("r1") |
| 662 | + .add(Network.createLayer("1", p) |
| 663 | + .add(Anomaly.create()) |
| 664 | + .add(new TemporalMemory()) |
| 665 | + .add(new SpatialPooler()) |
| 666 | + .add(htmSensor))); |
| 667 | + |
| 668 | + network.start(); |
| 669 | + |
| 670 | + network.observe().subscribe(new Observer<Inference>() { |
| 671 | + @Override public void onCompleted() { |
| 672 | + assertEquals(0, anomaly, 0); |
| 673 | + completed = true; |
| 674 | + } |
| 675 | + @Override public void onError(Throwable e) { e.printStackTrace(); } |
| 676 | + @Override public void onNext(Inference output) { |
| 677 | + //System.out.println(output.getRecordNum() + ": input = " + Arrays.toString(output.getEncoding()));//output = " + Arrays.toString(output.getSDR()) + ", " + output.getAnomalyScore()); |
| 678 | + if(output.getAnomalyScore() < anomaly) { |
| 679 | + anomaly = output.getAnomalyScore(); |
| 680 | + System.out.println("anomaly = " + anomaly); |
| 681 | + } |
| 682 | + } |
| 683 | + }); |
| 684 | + |
| 685 | + int x = 0; |
| 686 | + for(int i = 0;i < 100;i++) { |
| 687 | + x = i % 10; |
| 688 | + manual.onNext("7/12/10 13:10,35.3,40.6457;-73.7" + x + "692;" + x); //5 = meters per second |
| 689 | + } |
| 690 | + |
| 691 | + manual.onComplete(); |
| 692 | + |
| 693 | + Layer<?> l = network.lookup("r1").lookup("1"); |
| 694 | + try { |
| 695 | + l.getLayerThread().join(); |
| 696 | + }catch(Exception e) { |
| 697 | + e.printStackTrace(); |
| 698 | + } |
| 699 | + |
| 700 | + assertTrue(completed); |
| 701 | + |
| 702 | + } |
| 703 | + |
| 704 | + String errorMessage = null; |
| 705 | + @Test |
| 706 | + public void testObservableWithCoordinateEncoder_NEGATIVE() { |
| 707 | + Publisher manual = Publisher.builder() |
| 708 | + .addHeader("timestamp,consumption,location") |
| 709 | + .addHeader("datetime,float,geo") |
| 710 | + .addHeader("T,,").build(); |
| 711 | + |
| 712 | + Sensor<ObservableSensor<String[]>> sensor = Sensor.create( |
| 713 | + ObservableSensor::create, SensorParams.create(Keys::obs, "", manual)); |
| 714 | + |
| 715 | + Parameters p = NetworkTestHarness.getParameters().copy(); |
| 716 | + p = p.union(NetworkTestHarness.getGeospatialTestEncoderParams()); |
| 717 | + p.setParameterByKey(KEY.RANDOM, new MersenneTwister(42)); |
| 718 | + |
| 719 | + HTMSensor<ObservableSensor<String[]>> htmSensor = (HTMSensor<ObservableSensor<String[]>>)sensor; |
| 720 | + |
| 721 | + Network network = Network.create("test network", p) |
| 722 | + .add(Network.createRegion("r1") |
| 723 | + .add(Network.createLayer("1", p) |
| 724 | + .alterParameter(KEY.AUTO_CLASSIFY, Boolean.TRUE) |
| 725 | + .add(Anomaly.create()) |
| 726 | + .add(new TemporalMemory()) |
| 727 | + .add(new SpatialPooler()) |
| 728 | + .add(htmSensor))); |
| 729 | + |
| 730 | + network.observe().subscribe(new Observer<Inference>() { |
| 731 | + @Override public void onCompleted() { |
| 732 | + //Should never happen here. |
| 733 | + assertEquals(0, anomaly, 0); |
| 734 | + completed = true; |
| 735 | + } |
| 736 | + @Override public void onError(Throwable e) { |
| 737 | + errorMessage = e.getMessage(); |
| 738 | + network.halt(); |
| 739 | + } |
| 740 | + @Override public void onNext(Inference output) {} |
| 741 | + }); |
| 742 | + |
| 743 | + network.start(); |
| 744 | + |
| 745 | + int x = 0; |
| 746 | + for(int i = 0;i < 100;i++) { |
| 747 | + x = i % 10; |
| 748 | + manual.onNext("7/12/10 13:10,35.3,40.6457;-73.7" + x + "692;" + x); //1st "x" is attempt to vary coords, 2nd "x" = meters per second |
| 749 | + } |
| 750 | + |
| 751 | + manual.onComplete(); |
| 752 | + |
| 753 | + Layer<?> l = network.lookup("r1").lookup("1"); |
| 754 | + try { |
| 755 | + l.getLayerThread().join(); |
| 756 | + }catch(Exception e) { |
| 757 | + assertEquals(InterruptedException.class, e.getClass()); |
| 758 | + } |
| 759 | + |
| 760 | + // Assert onNext condition never gets set |
| 761 | + assertFalse(completed); |
| 762 | + assertEquals("Cannot autoclassify with raw array input or " + |
| 763 | + "Coordinate based encoders... Remove auto classify setting.", errorMessage); |
| 764 | + } |
| 765 | + |
| 766 | + |
639 | 767 | } |
0 commit comments