Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit 7afd85a

Browse files
yashlambapdxjohnny
authored andcommitted
model: scikit: Create models and configs dynamicly
1 parent d478dee commit 7afd85a

File tree

6 files changed

+534
-329
lines changed

6 files changed

+534
-329
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2727
- Documentation on how to use ML models on docs Models plugin page.
2828
- Mailing list info
2929
- Issue template for questions
30+
- Multiple Scikit Models with dynamic config
3031
### Changed
3132
- feature/codesec became it's own branch, binsec
3233
- BaseOrchestratorContext `run_operations` strict is default to true. With

model/scikit/dffml_model_scikit/sciLR.py

Lines changed: 0 additions & 241 deletions
This file was deleted.
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# SPDX-License-Identifier: MIT
2+
# Copyright (c) 2019 Intel Corporation
3+
"""
4+
Base class for Scikit models
5+
"""
6+
import os
7+
import json
8+
import hashlib
9+
from pathlib import Path
10+
from typing import AsyncIterator, Tuple, Any, NamedTuple
11+
12+
import joblib
13+
import numpy as np
14+
import pandas as pd
15+
16+
from dffml.repo import Repo
17+
from dffml.source.source import Sources
18+
from dffml.accuracy import Accuracy
19+
from dffml.model.model import ModelConfig, ModelContext, Model
20+
21+
22+
class ScikitConfig(ModelConfig, NamedTuple):
23+
directory: str
24+
predict: str
25+
26+
27+
class ScikitContext(ModelContext):
28+
def __init__(self, parent, features):
29+
super().__init__(parent, features)
30+
self.features = self.applicable_features(features)
31+
self._features_hash = self._feature_predict_hash()
32+
self.clf = None
33+
34+
@property
35+
def confidence(self):
36+
return self.parent.saved.get(self._features_hash, None)
37+
38+
@confidence.setter
39+
def confidence(self, confidence):
40+
self.parent.saved[self._features_hash] = confidence
41+
42+
def _feature_predict_hash(self):
43+
return hashlib.sha384(
44+
"".join(self.features + [self.parent.config.predict]).encode()
45+
).hexdigest()
46+
47+
def _filename(self):
48+
return os.path.join(
49+
self.parent.config.directory, self._features_hash + ".joblib"
50+
)
51+
52+
async def __aenter__(self):
53+
if os.path.isfile(self._filename()):
54+
self.clf = joblib.load(self._filename())
55+
else:
56+
self.clf = self.parent.SCIKIT_MODEL
57+
return self
58+
59+
async def __aexit__(self, exc_type, exc_value, traceback):
60+
joblib.dump(self.clf, self._filename())
61+
62+
async def train(self, sources: Sources):
63+
data = []
64+
async for repo in sources.with_features(self.features):
65+
feature_data = repo.features(
66+
self.features + [self.parent.config.predict]
67+
)
68+
data.append(feature_data)
69+
df = pd.DataFrame(data)
70+
xdata = np.array(df.drop([self.parent.config.predict], 1))
71+
ydata = np.array(df[self.parent.config.predict])
72+
self.logger.info("Number of input repos: {}".format(len(xdata)))
73+
self.clf.fit(xdata, ydata)
74+
joblib.dump(self.clf, self._filename())
75+
76+
async def accuracy(self, sources: Sources) -> Accuracy:
77+
data = []
78+
async for repo in sources.with_features(self.features):
79+
feature_data = repo.features(
80+
self.features + [self.parent.config.predict]
81+
)
82+
data.append(feature_data)
83+
df = pd.DataFrame(data)
84+
xdata = np.array(df.drop([self.parent.config.predict], 1))
85+
ydata = np.array(df[self.parent.config.predict])
86+
self.logger.debug("Number of input repos: {}".format(len(xdata)))
87+
self.confidence = self.clf.score(xdata, ydata)
88+
self.logger.debug("Model Accuracy: {}".format(self.confidence))
89+
return self.confidence
90+
91+
async def predict(
92+
self, repos: AsyncIterator[Repo]
93+
) -> AsyncIterator[Tuple[Repo, Any, float]]:
94+
if self.confidence is None:
95+
raise ValueError("Model Not Trained")
96+
async for repo in repos:
97+
feature_data = repo.features(self.features)
98+
df = pd.DataFrame(feature_data, index=[0])
99+
predict = np.array(df)
100+
self.logger.debug(
101+
"Predicted Value of {} for {}: {}".format(
102+
self.parent.config.predict,
103+
predict,
104+
self.clf.predict(predict),
105+
)
106+
)
107+
yield repo, self.clf.predict(predict)[0], self.confidence
108+
109+
110+
class Scikit(Model):
111+
def __init__(self, config) -> None:
112+
super().__init__(config)
113+
self.saved = {}
114+
115+
def _filename(self):
116+
return os.path.join(
117+
self.config.directory,
118+
hashlib.sha384(self.config.predict.encode()).hexdigest() + ".json",
119+
)
120+
121+
async def __aenter__(self) -> "Scikit":
122+
path = Path(self._filename())
123+
if path.is_file():
124+
self.saved = json.loads(path.read_text())
125+
return self
126+
127+
async def __aexit__(self, exc_type, exc_value, traceback):
128+
Path(self._filename()).write_text(json.dumps(self.saved))

0 commit comments

Comments
 (0)