Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit 633ee5e

Browse files
committed
adding tf regression model
1 parent e510002 commit 633ee5e

File tree

1 file changed

+278
-0
lines changed
  • model/tensorflow/dffml_model_tensorflow

1 file changed

+278
-0
lines changed
Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
"""
2+
Uses Tensorflow to create a generic DNN which learns on all of the features in a
3+
repo.
4+
"""
5+
import os
6+
import abc
7+
import pydoc
8+
import hashlib
9+
import inspect
10+
from dataclasses import dataclass
11+
from typing import List, Dict, Any, AsyncIterator, Tuple, Optional, Type
12+
13+
import numpy as np
14+
import tensorflow
15+
16+
from dffml.repo import Repo
17+
from dffml.feature import Feature, Features
18+
from dffml.source.source import Sources
19+
from dffml.model.model import ModelConfig, ModelContext, Model
20+
from dffml.accuracy import Accuracy
21+
from dffml.util.entrypoint import entry_point
22+
from dffml.base import BaseConfig
23+
from dffml.util.cli.arg import Arg
24+
25+
from dffml_model_tensorflow.dnnc import TensorflowModelContext
26+
27+
DEBUG =True
28+
@dataclass(init=True, eq=True)
29+
class DNNRegressionModelConfig:
30+
directory: str
31+
steps: int
32+
epochs: int
33+
hidden: List[int]
34+
label_name : str # feature_name holding target values
35+
36+
37+
38+
class DNNRegressionModelContext(TensorflowModelContext):
39+
"""
40+
Model using tensorflow to make predictions. Handels creation of feature
41+
columns for real valued, string, and list of real valued features.
42+
"""
43+
44+
def __init__(self, config, parent) -> None:
45+
super().__init__(config, parent)
46+
self.model_dir_path = self._model_dir_path()
47+
self.label_name=self.parent.config.label_name
48+
self.all_features=self.features+[self.label_name]
49+
50+
def _model_dir_path(self):
51+
"""
52+
Creates the path to the model dir by using the provided model dir and
53+
the sha384 hash of the concatenated feature names.
54+
"""
55+
if self.parent.config.directory is None:
56+
return None
57+
_to_hash = self.features + list(map(str, self.parent.config.hidden))
58+
model = hashlib.sha384("".join(_to_hash).encode("utf-8")).hexdigest()
59+
if not os.path.isdir(self.parent.config.directory):
60+
raise NotADirectoryError(
61+
"%s is not a directory" % (self.parent.config.directory)
62+
)
63+
return os.path.join(self.parent.config.directory, model)
64+
65+
@property
66+
def model(self):
67+
"""
68+
Generates or loads a model
69+
"""
70+
if self._model is not None:
71+
return self._model
72+
self.logger.debug(
73+
"Loading model ")
74+
75+
_head=tensorflow.contrib.estimator.regression_head()
76+
self._model = tensorflow.estimator.DNNEstimator(
77+
head=_head,
78+
feature_columns=list(self.feature_columns.values()),
79+
hidden_units=self.parent.config.hidden,
80+
model_dir=self.model_dir_path,
81+
)
82+
83+
return self._model
84+
85+
async def training_input_fn(self,sources: Sources,batch_size=20,shuffle=False,epochs=1,**kwargs,):
86+
"""
87+
Uses the numpy input function with data from repo features.
88+
"""
89+
self.logger.debug("Training on features: %r", self.features)
90+
x_cols: Dict[str, Any] = {feature: [] for feature in self.features}
91+
y_cols = []
92+
93+
async for repo in sources.with_features(self.all_features ):
94+
for feature, results in repo.features(self.features).items():
95+
96+
x_cols[feature].append(np.array(results))
97+
y_cols.append(repo.feature(self.label_name))
98+
99+
y_cols = np.array(y_cols)
100+
for feature in x_cols:
101+
x_cols[feature] = np.array(x_cols[feature])
102+
self.logger.info("------ Repo Data ------")
103+
self.logger.info("x_cols: %d", len(list(x_cols.values())[0]))
104+
self.logger.info("y_cols: %d", len(y_cols))
105+
self.logger.info("-----------------------")
106+
input_fn = tensorflow.estimator.inputs.numpy_input_fn(
107+
x_cols,
108+
y_cols,
109+
batch_size=batch_size,
110+
shuffle=shuffle,
111+
num_epochs=epochs,
112+
**kwargs,
113+
)
114+
return input_fn
115+
116+
async def evaluate_input_fn(self,sources: Sources,batch_size=20,shuffle=False,epochs=1,**kwargs,):
117+
"""
118+
Uses the numpy input function with data from repo features.
119+
"""
120+
x_cols: Dict[str, Any] = {feature: [] for feature in self.features}
121+
y_cols = []
122+
123+
async for repo in sources.with_features(self.all_features ):
124+
for feature, results in repo.features(self.features).items():
125+
x_cols[feature].append(np.array(results))
126+
y_cols.append(repo.feature(self.label_name))
127+
128+
y_cols = np.array(y_cols)
129+
for feature in x_cols:
130+
x_cols[feature] = np.array(x_cols[feature])
131+
self.logger.info("------ Repo Data ------")
132+
self.logger.info("x_cols: %d", len(list(x_cols.values())[0]))
133+
self.logger.info("y_cols: %d", len(y_cols))
134+
self.logger.info("-----------------------")
135+
input_fn = tensorflow.estimator.inputs.numpy_input_fn(
136+
x_cols,
137+
y_cols,
138+
batch_size=batch_size,
139+
shuffle=shuffle,
140+
num_epochs=epochs,
141+
**kwargs,
142+
)
143+
return input_fn
144+
145+
async def predict_input_fn(self, repos: AsyncIterator[Repo], **kwargs):
146+
"""
147+
Uses the numpy input function with data from repo features.
148+
"""
149+
x_cols: Dict[str, Any] = {feature: [] for feature in self.features}
150+
ret_repos = []
151+
async for repo in repos:
152+
if not repo.features(self.features):
153+
continue
154+
ret_repos.append(repo)
155+
for feature, results in repo.features(self.features).items():
156+
x_cols[feature].append(np.array(results))
157+
for feature in x_cols:
158+
x_cols[feature] = np.array(x_cols[feature])
159+
self.logger.info("------ Repo Data ------")
160+
self.logger.info("x_cols: %d", len(list(x_cols.values())[0]))
161+
self.logger.info("-----------------------")
162+
input_fn = tensorflow.estimator.inputs.numpy_input_fn(
163+
x_cols, shuffle=False, num_epochs=1, **kwargs
164+
)
165+
return input_fn, ret_repos
166+
167+
async def train(self, sources: Sources):
168+
"""
169+
Train on data submitted via classify.
170+
"""
171+
input_fn = await self.training_input_fn(
172+
sources,
173+
batch_size=20,
174+
shuffle=True,
175+
epochs=self.parent.config.epochs,
176+
)
177+
self.model.train(input_fn=input_fn, steps=self.parent.config.steps)
178+
179+
async def accuracy(self, sources: Sources) -> Accuracy:
180+
"""
181+
Evaluates the accuracy of our model after training using the input repos
182+
as test data.
183+
"""
184+
if not os.path.isdir(self.model_dir_path):
185+
raise NotADirectoryError("Model not trained")
186+
input_fn = await self.evaluate_input_fn(
187+
sources, batch_size=20, shuffle=False, epochs=1
188+
)
189+
metrics = self.model.evaluate(input_fn=input_fn)
190+
return Accuracy(1-metrics["loss"]) # 1 - mse
191+
192+
async def predict(self, repos: AsyncIterator[Repo]) -> AsyncIterator[Repo]:
193+
"""
194+
Uses trained data to make a prediction about the quality of a repo.
195+
"""
196+
197+
if not os.path.isdir(self.model_dir_path):
198+
raise NotADirectoryError("Model not trained")
199+
# Create the input function
200+
input_fn, predict_repo = await self.predict_input_fn(repos)
201+
# Makes predictions on
202+
predictions = self.model.predict(input_fn=input_fn)
203+
204+
for repo, pred_dict in zip(predict_repo, predictions):
205+
repo.predicted(float(pred_dict["predictions"]),float('nan')) # 0,arbitary number to match fun signature -> (value,confidence)
206+
207+
yield repo
208+
209+
210+
@entry_point("tfdnnr")
211+
class DNNRegressionModel(Model):
212+
213+
214+
CONTEXT = DNNRegressionModelContext
215+
216+
@classmethod
217+
def args(cls, args, *above) -> Dict[str, Arg]:
218+
cls.config_set(
219+
args,
220+
above,
221+
"directory",
222+
Arg(
223+
default=os.path.join(
224+
os.path.expanduser("~"), ".cache", "dffml", "tensorflow"
225+
),
226+
help="Directory where state should be saved",
227+
),
228+
)
229+
cls.config_set(
230+
args,
231+
above,
232+
"steps",
233+
Arg(
234+
type=int,
235+
default=3000,
236+
help="Number of steps to train the model",
237+
),
238+
)
239+
cls.config_set(
240+
args,
241+
above,
242+
"epochs",
243+
Arg(
244+
type=int,
245+
default=30,
246+
help="Number of iterations to pass over all repos in a source",
247+
),
248+
)
249+
cls.config_set(
250+
args,
251+
above,
252+
"hidden",
253+
Arg(
254+
type=int,
255+
nargs="+",
256+
default=[12, 40, 15],
257+
help="List length is the number of hidden layers in the network. Each entry in the list is the number of nodes in that hidden layer",
258+
),
259+
)
260+
cls.config_set(
261+
args,
262+
above,
263+
"label_name",
264+
Arg(help="Feature name holding truth value"),
265+
)
266+
267+
return args
268+
269+
@classmethod
270+
def config(cls, config, *above) -> BaseConfig:
271+
return DNNRegressionModelConfig(
272+
directory=cls.config_get(config, above, "directory"),
273+
steps=cls.config_get(config, above, "steps"),
274+
epochs=cls.config_get(config, above, "epochs"),
275+
hidden=cls.config_get(config, above, "hidden"),
276+
label_name=cls.config_get(config, above, "label_name"),
277+
278+
)

0 commit comments

Comments
 (0)