Streamlit Component

The Python package also provides a Streamlit component to use Embedding Atlas in your Streamlit apps.

Installation

bash
pip install embedding-atlas

Example

python
from embedding_atlas.streamlit import embedding_atlas
from embedding_atlas.projection import compute_text_projection

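# df is assumed to be an existing pandas DataFrame with a "description" column.
# Compute the text embedding and its 2D projection; the component call below
# reads the results from the "projection_x", "projection_y", and "neighbors" columns.
compute_text_projection(
    df, text="description",
    x="projection_x", y="projection_y", neighbors="neighbors"
)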

# Create an Embedding Atlas component for a given data frame
value = embedding_atlas(
    df, text="description",
    x="projection_x", y="projection_y", neighbors="neighbors",
    show_table=True
)

The returned value is a dict whose predicate entry is a SQL expression (a string) describing the current selection in the component. You can use DuckDB to query the data frame with the predicate:

python
import duckdb
import streamlit as st

predicate = value.get("predicate")
if predicate is not None:
    # Query the data frame with the SQL predicate
    selection = duckdb.query_df(
        df, "dataframe", "SELECT * FROM dataframe WHERE " + predicate
    )
    # Show the selection
    st.dataframe(selection)

You can also use the component without a projection:

python
value = embedding_atlas(df)

Without x and y, the widget falls back to a mode that shows only the table and charts.

Reference

python
from embedding_atlas.streamlit import embedding_atlas

Below are the options and return value of the embedding_atlas function:

Create an Embedding Atlas widget in Streamlit.

Args:
data_frame:

The data frame to visualize.

x:

The name of the column with the X coordinates of the embedding.

y:

The name of the column with the Y coordinates of the embedding.

text:

The column name for the textual data.

neighbors:

The column name containing precomputed K-nearest neighbors for each point. Each value in the column should be a dictionary with the format: { "ids": [id1, id2, ...], "distances": [distance1, distance2, ...] }.

  • "ids" should be an array of row ids of the neighbors (if row_id is specified, match the value in row_id, otherwise use the row index), sorted by distance.

  • "distances" should contain the corresponding distances to each neighbor.
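The neighbors format described above can be illustrated with a small brute-force sketch. The helper name knn_neighbors and the sample points are hypothetical and not part of the package. The text above does not say whether a point should list itself as a neighbor, so this sketch assumes it should not. Ids are row indices, matching the case where row_id is not specified.

```python
import math

def knn_neighbors(points, k):
    """Brute-force k-nearest neighbors for a small list of (x, y) points.

    Returns one dict per point in the documented format,
    {"ids": [...], "distances": [...]}, sorted by increasing distance.
    """
    result = []
    for i, p in enumerate(points):
        # Distance from point i to every other point (the point itself is skipped).
        candidates = [
            (math.dist(p, q), j) for j, q in enumerate(points) if j != i
        ]
        candidates.sort()  # nearest first; ties are broken by row index
        nearest = candidates[:k]
        result.append({
            "ids": [j for _, j in nearest],
            "distances": [d for d, _ in nearest],
        })
    return result

points = [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0), (5.0, 5.0)]
neighbors = knn_neighbors(points, k=2)
# neighbors[0] is {"ids": [1, 2], "distances": [1.0, 2.0]}
```

With a pandas data frame that has one row per point, the list could then be stored as a column (for example df["neighbors"] = neighbors) and its name passed as the neighbors argument. For real data sets, the precomputed neighbors from compute_text_projection or a dedicated nearest-neighbor library would be a better fit than this quadratic loop.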

labels:

Labels for the embedding view. Set to the string "automatic" to generate labels automatically, or to "disabled" to turn off automatic labels. Automatic labels are generated by clustering the 2D density distribution and selecting representative keywords by TF-IDF ranking. You can also pass a list of labels. Each label must contain x and y coordinates and text for the label content. Optionally, a label may specify an integer level, which roughly controls the zoom level at which the label appears, and a priority. Higher-priority labels are more likely to appear when multiple labels overlap.
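As a rough illustration of a custom label list, the sketch below builds labels as plain dicts. It assumes the dict keys are spelled exactly as the field names in the description ("x", "y", "text", "level", "priority"). The helper make_label, the coordinates, and the label texts are made up for the example.

```python
def make_label(x, y, text, level=None, priority=None):
    """Build one label dict; level and priority are included only when given."""
    label = {"x": float(x), "y": float(y), "text": str(text)}
    if level is not None:
        label["level"] = int(level)  # described as an integer
    if priority is not None:
        label["priority"] = priority
    return label

labels = [
    make_label(-3.2, 1.5, "sports", level=0, priority=10),
    make_label(4.0, -0.8, "cooking", level=0, priority=5),
    make_label(4.6, -1.1, "baking", level=2),
]

# Assumed usage inside a Streamlit app:
# embedding_atlas(df, x="projection_x", y="projection_y", labels=labels)
```

Here "sports" has a higher priority than "cooking", so it would be the more likely of the two to stay visible if they overlapped.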

stop_words:

Stop words for automatic label generation.

point_size:

Override the default point size for the embedding view.

show_table:

Whether to display the data table when the widget opens.

show_charts:

Whether to display charts when the widget opens.

show_embedding:

Whether to display the embedding view when the widget opens.

key:

The key of the Streamlit widget.

Returns:

A dict with the following key:

  • predicate: the SQL predicate for the current selection in the widget.
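One way to consume the return value is to turn it into a complete SQL query, falling back to all rows when there is no selection. This is only a sketch: the helper name selection_query and the default table name are hypothetical, and it assumes the predicate is missing or None when nothing is selected.

```python
def selection_query(value, table="dataframe"):
    """Return a SELECT statement for the current selection.

    Selects every row when the value has no predicate.
    """
    predicate = (value or {}).get("predicate")
    if predicate is None:
        return f"SELECT * FROM {table}"
    # Parenthesize the predicate so it can be combined with other conditions.
    return f"SELECT * FROM {table} WHERE ({predicate})"

query = selection_query({"predicate": "projection_x > 0"})
# query is 'SELECT * FROM dataframe WHERE (projection_x > 0)'
```

The resulting string can be passed to duckdb.query_df(df, "dataframe", query) in the same way as in the example earlier on this page.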