Skip to content

Datamap plot for plotting; HDBSCAN for clustering #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ The pipeline consists of several distinct blocks that can be customized and the
## Install
Install the following libraries to get started:
```bash
pip install scikit-learn umap-learn sentence_transformers faiss-cpu plotly matplotlib datasets
pip install scikit-learn umap-learn sentence_transformers faiss-cpu datamapplot datasets
```
Clone this repository and navigate to the folder:
```bash
Expand Down Expand Up @@ -100,4 +100,4 @@ You can also change how the clusters are labeled (multiple topics (default) vs s

## Examples

Check the `examples` folder for an example of clustering and topic labeling applied to the [AutoMathText](https://huggingface.co/datasets/math-ai/AutoMathText/) dataset, utilizing [Cosmopedia](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)'s web labeling approach.
Check the `examples` folder for an example of clustering and topic labeling applied to the [AutoMathText](https://huggingface.co/datasets/math-ai/AutoMathText/) dataset, utilizing [Cosmopedia](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)'s web labeling approach.
103 changes: 76 additions & 27 deletions src/text_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ def __init__(
embed_agg_strategy=None,
umap_components=2,
umap_metric="cosine",
dbscan_eps=0.08,
dbscan_min_samples=50,
dbscan_n_jobs=16,
min_cluster_size=100,
summary_create=True,
summary_model="mistralai/Mixtral-8x7B-Instruct-v0.1",
topic_mode="multiple_topics",
Expand All @@ -59,9 +57,7 @@ def __init__(
self.umap_components = umap_components
self.umap_metric = umap_metric

self.dbscan_eps = dbscan_eps
self.dbscan_min_samples = dbscan_min_samples
self.dbscan_n_jobs = dbscan_n_jobs
self.min_cluster_size = min_cluster_size

self.summary_create = summary_create
self.summary_model = summary_model
Expand Down Expand Up @@ -162,12 +158,12 @@ def project(self, embeddings):

def cluster(self, embeddings):
    """Cluster `embeddings` with HDBSCAN and return the fitted labels.

    The superseded DBSCAN call (and its log line) is gone: it read
    `self.dbscan_eps`, `self.dbscan_min_samples` and `self.dbscan_n_jobs`,
    which the constructor no longer sets.

    Args:
        embeddings: Points to cluster; passed unchanged to
            `fast_hdbscan.HDBSCAN(...).fit`.

    Returns:
        The `labels_` attribute of the fitted estimator.
        NOTE(review): other methods in this file treat negative labels as
        "unlabelled"/noise - confirm that matches fast_hdbscan's output.
    """
    print(f"Using HDBSCAN (min_cluster_size)=({self.min_cluster_size})")

    clusterer = fast_hdbscan.HDBSCAN(
        min_cluster_size=self.min_cluster_size,
        # NOTE(review): min_samples is hard-coded; only min_cluster_size is
        # configurable from the constructor.
        min_samples=10,
        cluster_selection_method="leaf",
    )
    clusterer.fit(embeddings)

    return clusterer.labels_
Expand Down Expand Up @@ -195,8 +191,8 @@ def summarize(self, texts, labels):
examples=examples, instruction=self.summary_instruction
)
response = client.text_generation(request)
if label == 0:
print(f"Request:\n{request}")
# if label == 0:
# print(f"Request:\n{request}")
cluster_summaries[label] = self._postprocess_response(response)
print(f"Number of clusters is {len(cluster_summaries)}")
return cluster_summaries
Expand Down Expand Up @@ -290,24 +286,77 @@ def load(self, folder):
y = np.mean([self.projections[doc, 1] for doc in self.label2docs[label]])
self.cluster_centers[label] = (x, y)

def show(self, interactive=False):
df = pd.DataFrame(
data={
"X": self.projections[:, 0],
"Y": self.projections[:, 1],
"labels": self.cluster_labels,
"content_display": [
textwrap.fill(txt[:1024], 64) for txt in self.texts
],
}
def show(self, plot_lib, **kwargs):
    """Render the clustered projections with the requested plotting backend.

    Args:
        plot_lib: Backend selector. One of "datamapplot", "plotly",
            "matplotlib" or its alias "mpl".
        **kwargs: Forwarded unchanged to the selected backend helper
            (`_show_dmp`, `_show_plotly` or `_show_mpl`).

    Returns:
        Whatever the selected backend helper returns. For "datamapplot" this
        is the plot/figure object built by `_show_dmp`; previously it was
        discarded, so callers could not display or save it.

    Raises:
        ValueError: If `plot_lib` is not a supported backend name.
    """
    if plot_lib == "datamapplot":
        # datamapplot works directly on the projections; no DataFrame needed.
        return self._show_dmp(**kwargs)

    if plot_lib == "plotly":
        render = self._show_plotly
    elif plot_lib in ("matplotlib", "mpl"):
        render = self._show_mpl
    else:
        raise ValueError(
            "plot_lib should be one of 'datamapplot', 'plotly' or 'matplotlib'"
            f" (alias 'mpl'), got {plot_lib!r}"
        )

    # Both DataFrame-based backends consume the same frame, so build it once.
    df = pd.DataFrame(
        data={
            "X": self.projections[:, 0],
            "Y": self.projections[:, 1],
            "labels": self.cluster_labels,
            # Truncate, then wrap, long documents so hover/annotation text
            # stays readable.
            "content_display": [
                textwrap.fill(txt[:1024], 64) for txt in self.texts
            ],
        }
    )
    return render(df, **kwargs)


def _show_dmp(self, interactive=False, title=None, sub_title=None, font=None, enable_search=True, **kwargs):
    """Plot the projections with datamapplot, labelled by cluster summary.

    Args:
        interactive: If true, build an interactive plot with per-point hover
            text; otherwise build a static figure.
        title: Plot title, forwarded to datamapplot.
        sub_title: Plot sub-title, forwarded to datamapplot.
        font: Font family name; "Poppins" is used when None.
        enable_search: Forwarded to the interactive plot only; the static
            plot call does not receive it.
        **kwargs: Extra keyword arguments forwarded to the datamapplot call.

    Returns:
        The object returned by `datamapplot.create_interactive_plot` when
        `interactive` is true, otherwise the figure returned by
        `datamapplot.create_plot`.
    """
    # NOTE(review): summaries may be absent when summarisation was skipped
    # (the constructor has a `summary_create` flag) - confirm. Fall back to a
    # generic name instead of failing on the lookup.
    summaries = getattr(self, "cluster_summaries", None) or {}
    label_vector = np.asarray(
        [
            summaries.get(label, f"Cluster {label}") if label >= 0 else "Unlabelled"
            for label in self.cluster_labels
        ],
        dtype=object,
    )

    font_family = "Poppins" if font is None else font

    if interactive:
        # Cap hover text at 1024 characters, marking truncation with "...".
        hover_text = [
            text[:1021] + "..." if len(text) > 1024 else text
            for text in self.texts
        ]
        return datamapplot.create_interactive_plot(
            self.projections,
            label_vector,
            hover_text=hover_text,
            title=title,
            sub_title=sub_title,
            font_family=font_family,
            enable_search=enable_search,
            **kwargs,
        )

    # The static API spells the font keyword `fontfamily` (no underscore).
    fig, _ax = datamapplot.create_plot(
        self.projections,
        label_vector,
        title=title,
        sub_title=sub_title,
        fontfamily=font_family,
        **kwargs,
    )
    return fig


def _show_mpl(self, df):
def _show_mpl(self, df, **kwargs):
fig, ax = plt.subplots(figsize=(12, 8), dpi=300)

df["color"] = df["labels"].apply(lambda x: "C0" if x==-1 else f"C{(x%9)+1}")
Expand Down Expand Up @@ -341,7 +390,7 @@ def _show_mpl(self, df):
t.set_bbox(dict(facecolor='white', alpha=0.9, linewidth=0, boxstyle='square,pad=0.1'))
ax.set_axis_off()

def _show_plotly(self, df):
def _show_plotly(self, df, **kwargs):
fig = px.scatter(
df,
x="X",
Expand Down