diff --git a/README.md b/README.md index a9a8506..935eb49 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ The pipeline consists of several distinct blocks that can be customized and the ## Install Install the following libraries to get started: ```bash -pip install scikit-learn umap-learn sentence_transformers faiss-cpu plotly matplotlib datasets +pip install scikit-learn umap-learn sentence_transformers faiss-cpu datamapplot datasets ``` Clone this repository and navigate to the folder: ```bash @@ -100,4 +100,4 @@ You can also change how the clusters are labeled (multiple topics (default) vs s ## Examples -Check the `examples` folder for an example of clustering and topic labeling applied to the [AutoMathText](https://huggingface.co/datasets/math-ai/AutoMathText/) dataset, utilizing [Cosmopedia](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)'s web labeling approach. \ No newline at end of file +Check the `examples` folder for an example of clustering and topic labeling applied to the [AutoMathText](https://huggingface.co/datasets/math-ai/AutoMathText/) dataset, utilizing [Cosmopedia](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)'s web labeling approach. diff --git a/src/text_clustering.py b/src/text_clustering.py index 26e22df..b67fe3f 100644 --- a/src/text_clustering.py +++ b/src/text_clustering.py @@ -38,9 +38,7 @@ def __init__( embed_agg_strategy=None, umap_components=2, umap_metric="cosine", - dbscan_eps=0.08, - dbscan_min_samples=50, - dbscan_n_jobs=16, + min_cluster_size=100, summary_create=True, summary_model="mistralai/Mixtral-8x7B-Instruct-v0.1", topic_mode="multiple_topics", @@ -59,9 +57,7 @@ def __init__( self.umap_components = umap_components self.umap_metric = umap_metric - self.dbscan_eps = dbscan_eps - self.dbscan_min_samples = dbscan_min_samples - self.dbscan_n_jobs = dbscan_n_jobs + self.min_cluster_size = min_cluster_size self.summary_create = summary_create self.summary_model = summary_model @@ -162,12 +158,12 @@ def project(self, embeddings): def cluster(self, embeddings): print( - f"Using DBSCAN (eps, nim_samples)=({self.dbscan_eps,}, {self.dbscan_min_samples})" + f"Using HDBSCAN (min_cluster_size)=({self.min_cluster_size})" ) - clustering = DBSCAN( - eps=self.dbscan_eps, - min_samples=self.dbscan_min_samples, - n_jobs=self.dbscan_n_jobs, + clustering = fast_hdbscan.HDBSCAN( + min_cluster_size=self.min_cluster_size, + min_samples=10, + cluster_selection_method="leaf", ).fit(embeddings) return clustering.labels_ @@ -195,8 +191,8 @@ def summarize(self, texts, labels): examples=examples, instruction=self.summary_instruction ) response = client.text_generation(request) - if label == 0: - print(f"Request:\n{request}") + # if label == 0: + # print(f"Request:\n{request}") cluster_summaries[label] = self._postprocess_response(response) print(f"Number of clusters is {len(cluster_summaries)}") return cluster_summaries @@ -290,24 +286,77 @@ def load(self, folder): y = np.mean([self.projections[doc, 1] for doc in self.label2docs[label]]) self.cluster_centers[label] = (x, y) - def show(self, interactive=False): - df = pd.DataFrame( - data={ - "X": self.projections[:, 0], - "Y": self.projections[:, 1], - "labels": self.cluster_labels, - "content_display": [ - textwrap.fill(txt[:1024], 64) for txt in self.texts - ], - } + def show(self, plot_lib, **kwargs): + if plot_lib == "datamapplot": + self._show_dmp(**kwargs) + elif plot_lib == "plotly": + df = pd.DataFrame( + data={ + "X": self.projections[:, 0], + "Y": self.projections[:, 1], + "labels": self.cluster_labels, + "content_display": [ + textwrap.fill(txt[:1024], 64) for txt in self.texts + ], + } + ) + self._show_plotly(df, **kwargs) + elif plot_lib in ("matplotlib", "mpl"): + df = pd.DataFrame( + data={ + "X": self.projections[:, 0], + "Y": self.projections[:, 1], + "labels": self.cluster_labels, + "content_display": [ + textwrap.fill(txt[:1024], 64) for txt in self.texts + ], + } + ) + self._show_mpl(df, **kwargs) + else: + raise ValueError("plot_lib should be one of 'datamapplot', 'plotly' or 'matplotlib'") + + + def _show_dmp(self, interactive=False, title=None, sub_title=None, font=None, enable_search=True, **kwargs): + label_vector = np.asarray( + [self.cluster_summaries[x] if x >= 0 else "Unlabelled" for x in self.cluster_labels], + dtype=object ) + if font is None: + font_family = "Poppins" + else: + font_family = font + if interactive: - self._show_plotly(df) + hover_text = [ + text[:1021] + "..." if len(text) > 1024 else text + for text in self.texts + ] + plot = datamapplot.create_interactive_plot( + self.projections, + label_vector, + hover_text=hover_text, + title=title, + sub_title=sub_title, + font_family=font_family, + enable_search=enable_search, + **kwargs, + ) + return plot else: - self._show_mpl(df) + fig, ax = datamapplot.create_plot( + self.projections, + label_vector, + title=title, + sub_title=sub_title, + fontfamily=font_family, + **kwargs, + ) + return fig + - def _show_mpl(self, df): + def _show_mpl(self, df, **kwargs): fig, ax = plt.subplots(figsize=(12, 8), dpi=300) df["color"] = df["labels"].apply(lambda x: "C0" if x==-1 else f"C{(x%9)+1}") @@ -341,7 +390,7 @@ def _show_mpl(self, df): t.set_bbox(dict(facecolor='white', alpha=0.9, linewidth=0, boxstyle='square,pad=0.1')) ax.set_axis_off() - def _show_plotly(self, df): + def _show_plotly(self, df, **kwargs): fig = px.scatter( df, x="X",