Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
639 changes: 339 additions & 300 deletions .idea/workspace.xml

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ publishing {
maven(MavenPublication) {
groupId = 'de.dmi3y.behaiv'
artifactId = 'behaiv'
version = '0.3.0-alpha'
version = '0.3.7-alpha'


from components.java
Expand All @@ -40,13 +40,13 @@ repositories {

dependencies {
// This dependency is exported to consumers, that is to say found on their compile classpath.
implementation 'org.apache.commons:commons-math3:3.6.1'
implementation 'org.apache.commons:commons-lang3:3.9'

// This dependency is used internally, and not exposed to consumers on their own compile classpath.
implementation 'org.apache.commons:commons-lang3:3.9'
implementation 'com.google.guava:guava:28.0-jre'
implementation 'io.reactivex.rxjava3:rxjava:3.0.0-RC2'
implementation 'com.google.code.gson:gson:2.8.5'
compile group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version: '2.9.9'
compile group: 'com.fasterxml.jackson.core', name: 'jackson-core', version: '2.9.9'


//ejml (replacement of dl4j)
Expand Down
6 changes: 3 additions & 3 deletions src/main/java/de/dmi3y/behaiv/Behaiv.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
import de.dmi3y.behaiv.provider.ProviderCallback;
import de.dmi3y.behaiv.session.CaptureSession;
import de.dmi3y.behaiv.storage.BehaivStorage;
import de.dmi3y.behaiv.tools.Pair;
import io.reactivex.rxjava3.core.Observable;
import io.reactivex.rxjava3.subjects.ReplaySubject;
import org.apache.commons.math3.util.Pair;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -100,7 +100,7 @@ protected CaptureSession getCurrentSession() {
@Override
public void onFeaturesCaptured(List<Pair<Double, String>> features) {
if (kernel.readyToPredict() && predict) {
subject.onNext(kernel.predictOne(features.stream().map(Pair::getFirst).collect(Collectors.toCollection(ArrayList::new))));
subject.onNext(kernel.predictOne(features.stream().map(Pair::getKey).collect(Collectors.toCollection(ArrayList::new))));
}
}

Expand All @@ -111,7 +111,7 @@ public void stopCapturing(boolean discard) {
}
String label = currentSession.getLabel();
List<Pair<Double, String>> features = currentSession.getFeatures();
kernel.updateSingle(features.stream().map(Pair::getFirst).collect(Collectors.toCollection(ArrayList::new)), label);
kernel.updateSingle(features.stream().map(Pair::getKey).collect(Collectors.toCollection(ArrayList::new)), label);
if (kernel.readyToPredict()) {
kernel.fit();
}
Expand Down
23 changes: 15 additions & 8 deletions src/main/java/de/dmi3y/behaiv/kernel/Kernel.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package de.dmi3y.behaiv.kernel;

import com.google.gson.Gson;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import de.dmi3y.behaiv.storage.BehaivStorage;
import org.apache.commons.math3.util.Pair;
import de.dmi3y.behaiv.tools.Pair;

import java.io.BufferedReader;
import java.io.BufferedWriter;
Expand All @@ -15,9 +16,11 @@ public abstract class Kernel {

protected String id;
protected Long treshold = 10L;
protected ObjectMapper objectMapper;

public Kernel(String id) {
this.id = id;
objectMapper = new ObjectMapper();
}


Expand Down Expand Up @@ -50,24 +53,28 @@ public void update(ArrayList<Pair<ArrayList<Double>, String>> data) {
}

public void updateSingle(ArrayList<Double> features, String label) {
data.add(Pair.create(features, label));
data.add(new Pair<>(features, label));
}

public abstract String predictOne(ArrayList<Double> features);

public void save(BehaivStorage storage) throws IOException {
final Gson gson = new Gson();

try (final BufferedWriter writer = new BufferedWriter(new FileWriter(storage.getNetworkFile(id)))) {
writer.write(gson.toJson(data));
writer.write(objectMapper.writeValueAsString(data));
}
}

public void restore(BehaivStorage storage) throws IOException {
final Gson gson = new Gson();

final TypeReference<ArrayList<Pair<ArrayList<Double>, String>>> typeReference = new TypeReference<ArrayList<Pair<ArrayList<Double>, String>>>() {
};
try (final BufferedReader reader = new BufferedReader(new FileReader(storage.getNetworkFile(id)))) {
data = ((ArrayList<Pair<ArrayList<Double>, String>>) gson.fromJson(reader.readLine(), data.getClass()));
final String content = reader.readLine();
if (content == null || content.isEmpty()) {
data = new ArrayList<>();
} else {
data = objectMapper.readValue(content, typeReference);
}
}
}
}
29 changes: 17 additions & 12 deletions src/main/java/de/dmi3y/behaiv/kernel/LogisticRegressionKernel.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package de.dmi3y.behaiv.kernel;

import com.google.gson.Gson;
import com.fasterxml.jackson.core.type.TypeReference;
import de.dmi3y.behaiv.kernel.logistic.LogisticUtils;
import de.dmi3y.behaiv.storage.BehaivStorage;
import de.dmi3y.behaiv.tools.Pair;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.math3.util.Pair;
import org.ejml.simple.SimpleMatrix;

import java.io.BufferedReader;
Expand Down Expand Up @@ -41,19 +41,19 @@ public boolean isEmpty() {
@Override
public void fit(ArrayList<Pair<ArrayList<Double>, String>> data) {
this.data = data;
labels = this.data.stream().map(Pair::getSecond).distinct().collect(Collectors.toList());
labels = this.data.stream().map(Pair::getValue).distinct().collect(Collectors.toList());
if (readyToPredict()) {


//features
double[][] inputs = this.data.stream().map(Pair::getFirst).map(l -> l.toArray(new Double[0]))
double[][] inputs = this.data.stream().map(Pair::getKey).map(l -> l.toArray(new Double[0]))
.map(ArrayUtils::toPrimitive)
.toArray(double[][]::new);

//labels
double[][] labelArray = new double[data.size()][labels.size()];
for (int i = 0; i < data.size(); i++) {
int dummyPos = labels.indexOf(data.get(i).getSecond());
int dummyPos = labels.indexOf(data.get(i).getValue());
labelArray[i][dummyPos] = 1.0;
}

Expand Down Expand Up @@ -107,22 +107,21 @@ public String predictOne(ArrayList<Double> features) {

@Override
public void save(BehaivStorage storage) throws IOException {
if (theta == null && data == null) {
if (theta == null && (data == null || data.isEmpty())) {
throw new IOException("Not enough data to save, network data is empty");
}
if (labels == null) {
if (labels == null || labels.isEmpty()) {
String message;
message = "Kernel collected labels but failed to get data, couldn't save network.";
throw new IOException(message);
}
if (theta == null) {
super.save(storage);

} else {
theta.saveToFileBinary(storage.getNetworkFile(id).toString());
try (final BufferedWriter writer = new BufferedWriter(new FileWriter(storage.getNetworkMetadataFile(id)))) {
final Gson gson = new Gson();
writer.write(gson.toJson(labels));

writer.write(objectMapper.writeValueAsString(labels));
} catch (Exception e) {
e.printStackTrace();
}
Expand All @@ -139,8 +138,14 @@ public void restore(BehaivStorage storage) throws IOException {
}

try (final BufferedReader reader = new BufferedReader(new FileReader(storage.getNetworkMetadataFile(id)))) {
final Gson gson = new Gson();
labels = ((List<String>) gson.fromJson(reader.readLine(), labels.getClass()));
final TypeReference<ArrayList<String>> typeReference = new TypeReference<ArrayList<String>>() {
};
final String labelsData = reader.readLine();
if (labelsData == null) {
labels = new ArrayList<>();
} else {
labels = objectMapper.readValue(labelsData, typeReference);
}
} catch (IOException e) {
e.printStackTrace();
}
Expand Down
7 changes: 4 additions & 3 deletions src/main/java/de/dmi3y/behaiv/provider/DayTimeProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import io.reactivex.rxjava3.core.Single;

import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Calendar;
import java.util.Collections;
Expand Down Expand Up @@ -45,7 +46,7 @@ public List<String> availableFeatures() {
if (compound) {
return Collections.singletonList("day_time");
}
final List<String> strings = Arrays.asList("day_time_hours", "day_time_minutes");
final List<String> strings = new ArrayList<>(Arrays.asList("day_time_hours", "day_time_minutes"));
if (secondsEnabled) {
strings.add("day_time_seconds");
}
Expand All @@ -64,10 +65,10 @@ public Single<List<Double>> getFeature() {
if (compound) {
return Single.just(Collections.singletonList(hours * 60.0 + minutes));
}
final List<Double> listToSend = Arrays.asList(
final List<Double> listToSend = new ArrayList<>(Arrays.asList(
(double) hours,
(double) minutes
);
));
if (secondsEnabled) {
final int seconds = calendar.get(Calendar.SECOND);
listToSend.add((double) seconds);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package de.dmi3y.behaiv.provider;

import org.apache.commons.math3.util.Pair;
import de.dmi3y.behaiv.tools.Pair;

import java.util.List;

Expand Down
2 changes: 1 addition & 1 deletion src/main/java/de/dmi3y/behaiv/session/CaptureSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import de.dmi3y.behaiv.Behaiv;
import de.dmi3y.behaiv.provider.Provider;
import org.apache.commons.math3.util.Pair;
import de.dmi3y.behaiv.tools.Pair;

import java.util.ArrayList;
import java.util.InputMismatchException;
Expand Down
51 changes: 51 additions & 0 deletions src/main/java/de/dmi3y/behaiv/tools/Pair.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package de.dmi3y.behaiv.tools;

public class Pair<K, V> {
private K key;
private V value;

public Pair(K k, V v) {
this.key = k;
this.value = v;
}


public Pair(de.dmi3y.behaiv.tools.Pair<? extends K, ? extends V> entry) {
this(entry.getKey(), entry.getValue());
}

public Pair() {
}

public K getKey() {
return this.key;
}

public V getValue() {
return this.value;
}

public int hashCode() {
int result = this.key == null ? 0 : this.key.hashCode();
int h = this.value == null ? 0 : this.value.hashCode();
result = 37 * result + h ^ h >>> 16;
return result;
}

public String toString() {
return "[" + this.getKey() + ", " + this.getValue() + "]";
}

public static <K, V> Pair<K, V> create(K k, V v) {
return new Pair<K, V>(k, v);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Pair<?, ?> pair = (Pair<?, ?>) o;
return key.equals(pair.key) &&
value.equals(pair.value);
}
}
7 changes: 3 additions & 4 deletions src/test/java/de/dmi3y/behaiv/LibraryTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
package de.dmi3y.behaiv;

import de.dmi3y.behaiv.kernel.KernelTest;
import de.dmi3y.behaiv.kernel.LogisticRegressionKernel;
import de.dmi3y.behaiv.provider.TestProvider;
import io.reactivex.rxjava3.core.Observable;
import org.apache.commons.math3.util.Pair;
import de.dmi3y.behaiv.tools.Pair;
import org.junit.Before;
import org.junit.Test;

Expand Down Expand Up @@ -37,10 +36,10 @@ public void setUp() throws Exception {
@Test
public void behaivTest_basicTestFlow_predictsJob() throws Exception {
for (Pair<ArrayList<Double>, String> fToL : data) {
ArrayList<Double> features = fToL.getFirst();
ArrayList<Double> features = fToL.getKey();
timeProvider.next(new Double[]{features.get(0)});
positionProvider.next(new Double[]{features.get(1), features.get(2)});
capture(fToL.getSecond());
capture(fToL.getValue());
}

Observable<String> register = behaiv.subscribe();
Expand Down
6 changes: 3 additions & 3 deletions src/test/java/de/dmi3y/behaiv/kernel/KernelTest.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package de.dmi3y.behaiv.kernel;

import de.dmi3y.behaiv.storage.SimpleStorage;
import org.apache.commons.math3.util.Pair;
import de.dmi3y.behaiv.tools.Pair;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
Expand Down Expand Up @@ -39,8 +39,8 @@ public void setUp() throws Exception {
@Test
public void setTreshold() {
dummyKernel.setTreshold(1L);
dummyKernel.data.add(Pair.create(null, null));
dummyKernel.data.add(Pair.create(null, null));
dummyKernel.data.add(new Pair<>(null, null));
dummyKernel.data.add(new Pair<>(null, null));
boolean readyToPredict = dummyKernel.readyToPredict();
assertTrue(readyToPredict);
dummyKernel.setTreshold(10L);
Expand Down
Loading