Skip to content

Commit

Permalink
Interactive DataMapPlot (#1853) and deprecate non-chat OpenAI models (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
MaartenGr authored Feb 17, 2025
1 parent f3900ad commit 7d2aa5b
Show file tree
Hide file tree
Showing 12 changed files with 648 additions and 119 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ venv/
ENV/
env.bak/
venv.bak/
*.lock

# Artifacts
.idea
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ from bertopic.representation import OpenAI

# Fine-tune topic representations with GPT
client = openai.OpenAI(api_key="sk-...")
representation_model = OpenAI(client, model="gpt-3.5-turbo", chat=True)
representation_model = OpenAI(client, model="gpt-4o-mini", chat=True)
topic_model = BERTopic(representation_model=representation_model)
```

Expand Down
40 changes: 28 additions & 12 deletions bertopic/_bertopic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2532,38 +2532,51 @@ def visualize_documents(

def visualize_document_datamap(
self,
docs: List[str],
docs: List[str] = None,
topics: List[int] = None,
embeddings: np.ndarray = None,
reduced_embeddings: np.ndarray = None,
custom_labels: Union[bool, str] = False,
title: str = "Documents and Topics",
sub_title: Union[str, None] = None,
width: int = 1200,
height: int = 1200,
**datamap_kwds,
height: int = 750,
interactive: bool = False,
enable_search: bool = False,
topic_prefix: bool = False,
datamap_kwds: dict = {},
int_datamap_kwds: dict = {},
):
"""Visualize documents and their topics in 2D as a static plot for publication using
DataMapPlot. This works best if there are between 5 and 60 topics. It is therefore best
to use a sufficiently large `min_topic_size` or set `nr_topics` when building the model.
Arguments:
topic_model: A fitted BERTopic instance.
docs: The documents you used when calling either `fit` or `fit_transform`
docs: The documents you used when calling either `fit` or `fit_transform`.
topics: A selection of topics to visualize.
Not to be confused with the topics that you get from .fit_transform. For example, if you want to visualize only topics 1 through 5: topics = [1, 2, 3, 4, 5]. Documents not in these topics will be shown as noise points.
Not to be confused with the topics that you get from `.fit_transform`.
For example, if you want to visualize only topics 1 through 5:
`topics = [1, 2, 3, 4, 5]`. Documents not in these topics will be shown
as noise points.
embeddings: The embeddings of all documents in `docs`.
reduced_embeddings: The 2D reduced embeddings of all documents in `docs`.
custom_labels: If bool, whether to use custom topic labels that were defined using
`topic_model.set_topic_labels`.
If `str`, it uses labels from other aspects, e.g., "Aspect1".
`topic_model.set_topic_labels`.
If `str`, it uses labels from other aspects, e.g., "Aspect1".
title: Title of the plot.
sub_title: Sub-title of the plot.
width: The width of the figure.
height: The height of the figure.
**datamap_kwds: All further keyword args will be passed on to DataMapPlot's
`create_plot` function. See the DataMapPlot documentation
for more details.
interactive: Whether to create an interactive plot using DataMapPlot's `create_interactive_plot`.
enable_search: Whether to enable search in the interactive plot. Only works if `interactive=True`.
topic_prefix: Whether to prefix the default topic names with `Topic-{number}: ` when displaying them.
datamap_kwds: Keyword args to be passed on to DataMapPlot's `create_plot` function
if you are not using the interactive version.
See the DataMapPlot documentation for more details.
int_datamap_kwds: Keyword args to be passed on to DataMapPlot's `create_interactive_plot` function
if you are using the interactive version.
See the DataMapPlot documentation for more details.
Returns:
figure: A Matplotlib Figure object or, when `interactive=True`, the object returned by DataMapPlot's `create_interactive_plot`.
Expand Down Expand Up @@ -2610,7 +2623,6 @@ def visualize_document_datamap(
```
"""
check_is_fitted(self)
check_documents_type(docs)
return plotting.visualize_document_datamap(
self,
docs,
Expand All @@ -2622,7 +2634,11 @@ def visualize_document_datamap(
sub_title,
width,
height,
**datamap_kwds,
interactive,
enable_search,
topic_prefix,
datamap_kwds,
int_datamap_kwds,
)

def visualize_hierarchical_documents(
Expand Down
64 changes: 44 additions & 20 deletions bertopic/plotting/_datamap.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,27 @@ class Figure(object):

def visualize_document_datamap(
topic_model,
docs: List[str],
docs: List[str] = None,
topics: List[int] = None,
embeddings: np.ndarray = None,
reduced_embeddings: np.ndarray = None,
custom_labels: Union[bool, str] = False,
title: str = "Documents and Topics",
sub_title: Union[str, None] = None,
width: int = 1200,
height: int = 1200,
**datamap_kwds,
height: int = 750,
interactive: bool = False,
enable_search: bool = False,
topic_prefix: bool = False,
datamap_kwds: dict = {},
int_datamap_kwds: dict = {},
) -> Figure:
"""Visualize documents and their topics in 2D as a static plot for publication using
DataMapPlot.
Arguments:
topic_model: A fitted BERTopic instance.
docs: The documents you used when calling either `fit` or `fit_transform`
docs: The documents you used when calling either `fit` or `fit_transform`.
topics: A selection of topics to visualize.
Not to be confused with the topics that you get from `.fit_transform`.
For example, if you want to visualize only topics 1 through 5:
Expand All @@ -48,9 +52,15 @@ def visualize_document_datamap(
sub_title: Sub-title of the plot.
width: The width of the figure.
height: The height of the figure.
**datamap_kwds: All further keyword args will be passed on to DataMapPlot's
`create_plot` function. See the DataMapPlot documentation
for more details.
interactive: Whether to create an interactive plot using DataMapPlot's `create_interactive_plot`.
enable_search: Whether to enable search in the interactive plot. Only works if `interactive=True`.
topic_prefix: Whether to prefix the default topic names with `Topic-{number}: ` when displaying them.
datamap_kwds: Keyword args to be passed on to DataMapPlot's `create_plot` function
if you are not using the interactive version.
See the DataMapPlot documentation for more details.
int_datamap_kwds: Keyword args to be passed on to DataMapPlot's `create_interactive_plot` function
if you are using the interactive version.
See the DataMapPlot documentation for more details.
Returns:
figure: A Matplotlib Figure object or, when `interactive=True`, the object returned by DataMapPlot's `create_interactive_plot`.
Expand Down Expand Up @@ -127,10 +137,13 @@ def visualize_document_datamap(
elif topic_model.custom_labels_ is not None and custom_labels:
names = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in unique_topics]
else:
names = [
f"Topic-{topic}: " + " ".join([word for word, value in topic_model.get_topic(topic)][:3])
for topic in unique_topics
]
if topic_prefix:
names = [
f"Topic-{topic}: " + " ".join([word for word, value in topic_model.get_topic(topic)][:3])
for topic in unique_topics
]
else:
names = [" ".join([word for word, value in topic_model.get_topic(topic)][:3]) for topic in unique_topics]

topic_name_mapping = {topic_num: topic_name for topic_num, topic_name in zip(unique_topics, names)}
topic_name_mapping[-1] = "Unlabelled"
Expand All @@ -145,14 +158,25 @@ def visualize_document_datamap(
# Map in topic names and plot
named_topic_per_doc = pd.Series(topic_per_doc).map(topic_name_mapping).values

figure, axes = datamapplot.create_plot(
embeddings_2d,
named_topic_per_doc,
figsize=(width / 100, height / 100),
dpi=100,
title=title,
sub_title=sub_title,
**datamap_kwds,
)
if interactive:
figure = datamapplot.create_interactive_plot(
embeddings_2d,
named_topic_per_doc,
hover_text=docs,
enable_search=enable_search,
width=width,
height=height,
**int_datamap_kwds,
)
else:
figure, _ = datamapplot.create_plot(
embeddings_2d,
named_topic_per_doc,
figsize=(width / 100, height / 100),
dpi=100,
title=title,
sub_title=sub_title,
**datamap_kwds,
)

return figure
Loading

0 comments on commit 7d2aa5b

Please sign in to comment.