Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interactive DataMapPlot (#1853) and deprecate non-chat OpenAI models,… #2287

Merged
merged 2 commits into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ venv/
ENV/
env.bak/
venv.bak/
*.lock

# Artifacts
.idea
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ from bertopic.representation import OpenAI

# Fine-tune topic representations with GPT
client = openai.OpenAI(api_key="sk-...")
representation_model = OpenAI(client, model="gpt-3.5-turbo", chat=True)
representation_model = OpenAI(client, model="gpt-4o-mini", chat=True)
topic_model = BERTopic(representation_model=representation_model)
```

Expand Down
40 changes: 28 additions & 12 deletions bertopic/_bertopic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2532,38 +2532,51 @@ def visualize_documents(

def visualize_document_datamap(
self,
docs: List[str],
docs: List[str] = None,
topics: List[int] = None,
embeddings: np.ndarray = None,
reduced_embeddings: np.ndarray = None,
custom_labels: Union[bool, str] = False,
title: str = "Documents and Topics",
sub_title: Union[str, None] = None,
width: int = 1200,
height: int = 1200,
**datamap_kwds,
height: int = 750,
interactive: bool = False,
enable_search: bool = False,
topic_prefix: bool = False,
datamap_kwds: dict = {},
int_datamap_kwds: dict = {},
):
"""Visualize documents and their topics in 2D as a static plot for publication using
DataMapPlot. This works best if there are between 5 and 60 topics. It is therefore best
to use a sufficiently large `min_topic_size` or set `nr_topics` when building the model.

Arguments:
topic_model: A fitted BERTopic instance.
docs: The documents you used when calling either `fit` or `fit_transform`
docs: The documents you used when calling either `fit` or `fit_transform`.
topics: A selection of topics to visualize.
Not to be confused with the topics that you get from .fit_transform. For example, if you want to visualize only topics 1 through 5: topics = [1, 2, 3, 4, 5]. Documents not in these topics will be shown as noise points.
Not to be confused with the topics that you get from `.fit_transform`.
For example, if you want to visualize only topics 1 through 5:
`topics = [1, 2, 3, 4, 5]`. Documents not in these topics will be shown
as noise points.
embeddings: The embeddings of all documents in `docs`.
reduced_embeddings: The 2D reduced embeddings of all documents in `docs`.
custom_labels: If bool, whether to use custom topic labels that were defined using
`topic_model.set_topic_labels`.
If `str`, it uses labels from other aspects, e.g., "Aspect1".
`topic_model.set_topic_labels`.
If `str`, it uses labels from other aspects, e.g., "Aspect1".
title: Title of the plot.
sub_title: Sub-title of the plot.
width: The width of the figure.
height: The height of the figure.
**datamap_kwds: All further keyword args will be passed on to DataMapPlot's
`create_plot` function. See the DataMapPlot documentation
for more details.
interactive: Whether to create an interactive plot using DataMapPlot's `create_interactive_plot`.
enable_search: Whether to enable search in the interactive plot. Only works if `interactive=True`.
topic_prefix: Prefix to add to the topic number when displaying the topic name.
datamap_kwds: Keyword args be passed on to DataMapPlot's `create_plot` function
if you are not using the interactive version.
See the DataMapPlot documentation for more details.
int_datamap_kwds: Keyword args be passed on to DataMapPlot's `create_interactive_plot` function
if you are using the interactive version.
See the DataMapPlot documentation for more details.

Returns:
figure: A Matplotlib Figure object.
Expand Down Expand Up @@ -2610,7 +2623,6 @@ def visualize_document_datamap(
```
"""
check_is_fitted(self)
check_documents_type(docs)
return plotting.visualize_document_datamap(
self,
docs,
Expand All @@ -2622,7 +2634,11 @@ def visualize_document_datamap(
sub_title,
width,
height,
**datamap_kwds,
interactive,
enable_search,
topic_prefix,
datamap_kwds,
int_datamap_kwds,
)

def visualize_hierarchical_documents(
Expand Down
64 changes: 44 additions & 20 deletions bertopic/plotting/_datamap.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,27 @@ class Figure(object):

def visualize_document_datamap(
topic_model,
docs: List[str],
docs: List[str] = None,
topics: List[int] = None,
embeddings: np.ndarray = None,
reduced_embeddings: np.ndarray = None,
custom_labels: Union[bool, str] = False,
title: str = "Documents and Topics",
sub_title: Union[str, None] = None,
width: int = 1200,
height: int = 1200,
**datamap_kwds,
height: int = 750,
interactive: bool = False,
enable_search: bool = False,
topic_prefix: bool = False,
datamap_kwds: dict = {},
int_datamap_kwds: dict = {},
) -> Figure:
"""Visualize documents and their topics in 2D as a static plot for publication using
DataMapPlot.

Arguments:
topic_model: A fitted BERTopic instance.
docs: The documents you used when calling either `fit` or `fit_transform`
docs: The documents you used when calling either `fit` or `fit_transform`.
topics: A selection of topics to visualize.
Not to be confused with the topics that you get from `.fit_transform`.
For example, if you want to visualize only topics 1 through 5:
Expand All @@ -48,9 +52,15 @@ def visualize_document_datamap(
sub_title: Sub-title of the plot.
width: The width of the figure.
height: The height of the figure.
**datamap_kwds: All further keyword args will be passed on to DataMapPlot's
`create_plot` function. See the DataMapPlot documentation
for more details.
interactive: Whether to create an interactive plot using DataMapPlot's `create_interactive_plot`.
enable_search: Whether to enable search in the interactive plot. Only works if `interactive=True`.
topic_prefix: Prefix to add to the topic number when displaying the topic name.
datamap_kwds: Keyword args be passed on to DataMapPlot's `create_plot` function
if you are not using the interactive version.
See the DataMapPlot documentation for more details.
int_datamap_kwds: Keyword args be passed on to DataMapPlot's `create_interactive_plot` function
if you are using the interactive version.
See the DataMapPlot documentation for more details.

Returns:
figure: A Matplotlib Figure object.
Expand Down Expand Up @@ -127,10 +137,13 @@ def visualize_document_datamap(
elif topic_model.custom_labels_ is not None and custom_labels:
names = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in unique_topics]
else:
names = [
f"Topic-{topic}: " + " ".join([word for word, value in topic_model.get_topic(topic)][:3])
for topic in unique_topics
]
if topic_prefix:
names = [
f"Topic-{topic}: " + " ".join([word for word, value in topic_model.get_topic(topic)][:3])
for topic in unique_topics
]
else:
names = [" ".join([word for word, value in topic_model.get_topic(topic)][:3]) for topic in unique_topics]

topic_name_mapping = {topic_num: topic_name for topic_num, topic_name in zip(unique_topics, names)}
topic_name_mapping[-1] = "Unlabelled"
Expand All @@ -145,14 +158,25 @@ def visualize_document_datamap(
# Map in topic names and plot
named_topic_per_doc = pd.Series(topic_per_doc).map(topic_name_mapping).values

figure, axes = datamapplot.create_plot(
embeddings_2d,
named_topic_per_doc,
figsize=(width / 100, height / 100),
dpi=100,
title=title,
sub_title=sub_title,
**datamap_kwds,
)
if interactive:
figure = datamapplot.create_interactive_plot(
embeddings_2d,
named_topic_per_doc,
hover_text=docs,
enable_search=enable_search,
width=width,
height=height,
**int_datamap_kwds,
)
else:
figure, _ = datamapplot.create_plot(
embeddings_2d,
named_topic_per_doc,
figsize=(width / 100, height / 100),
dpi=100,
title=title,
sub_title=sub_title,
**datamap_kwds,
)

return figure
Loading