Skip to content

Gh stronger detection classifiers #235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,5 @@ ENV/
# OS Files
.DS_Store

# vscode stuff
.vscode/
4 changes: 3 additions & 1 deletion sdmetrics/single_table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from sdmetrics.single_table.bayesian_network import BNLikelihood, BNLogLikelihood
from sdmetrics.single_table.detection.base import DetectionMetric
from sdmetrics.single_table.detection.sklearn import (
LogisticDetection, ScikitLearnClassifierDetectionMetric, SVCDetection)
GradientBoostingDetection, LogisticDetection, ScikitLearnClassifierDetectionMetric,
SVCDetection)
from sdmetrics.single_table.efficacy.base import MLEfficacyMetric
from sdmetrics.single_table.efficacy.binary import (
BinaryAdaBoostClassifier, BinaryDecisionTreeClassifier, BinaryEfficacyMetric,
Expand Down Expand Up @@ -47,6 +48,7 @@
'DetectionMetric',
'LogisticDetection',
'SVCDetection',
'GradientBoostingDetection',
'ScikitLearnClassifierDetectionMetric',
'MLEfficacyMetric',
'BinaryEfficacyMetric',
Expand Down
4 changes: 3 additions & 1 deletion sdmetrics/single_table/detection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Machine Learning Detection metrics for single table datasets."""

from sdmetrics.single_table.detection.sklearn import LogisticDetection, SVCDetection
from sdmetrics.single_table.detection.sklearn import (
GradientBoostingDetection, LogisticDetection, SVCDetection)

# Names exported by a star-import of this package.
__all__ = [
'GradientBoostingDetection',
'LogisticDetection',
'SVCDetection'
]
17 changes: 17 additions & 0 deletions sdmetrics/single_table/detection/sklearn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""scikit-learn based DetectionMetrics for single table datasets."""

from sklearn.ensemble import GradientBoostingClassifier
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
Expand Down Expand Up @@ -67,3 +68,19 @@ class SVCDetection(ScikitLearnClassifierDetectionMetric):
@staticmethod
def _get_classifier():
return SVC(probability=True, gamma='scale')


class GradientBoostingDetection(ScikitLearnClassifierDetectionMetric):
    """Detection metric that uses a scikit-learn GradientBoostingClassifier.

    A GradientBoostingClassifier is trained to distinguish the synthetic data
    from the real data, and is then evaluated using Cross Validation.

    The value reported by the metric is one minus the average ROC AUC score
    obtained across that evaluation.
    """

    name = 'GradientBoosting Detection'

    @staticmethod
    def _get_classifier():
        """Build the classifier used by this metric, with default arguments."""
        classifier = GradientBoostingClassifier()
        return classifier
4 changes: 3 additions & 1 deletion tests/integration/single_table/test_single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from sdmetrics.goal import Goal
from sdmetrics.single_table.base import SingleTableMetric
from sdmetrics.single_table.bayesian_network import BNLikelihood, BNLogLikelihood
from sdmetrics.single_table.detection import LogisticDetection, SVCDetection
from sdmetrics.single_table.detection import (
GradientBoostingDetection, LogisticDetection, SVCDetection)
from sdmetrics.single_table.multi_column_pairs import (
ContingencySimilarity, ContinuousKLDivergence, DiscreteKLDivergence)
from sdmetrics.single_table.multi_single_column import (
Expand All @@ -17,6 +18,7 @@
METRICS = [
CSTest,
KSComplement,
GradientBoostingDetection,
LogisticDetection,
SVCDetection,
ContinuousKLDivergence,
Expand Down