diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 5337420d..5b7f569c 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -631,7 +631,7 @@ def render_images( def render_labels( self, element: str | None = None, - color: str | None = None, + color: ColorLike | None = None, *, groups: list[str] | str | None = None, contour_px: int | None = 3, @@ -640,7 +640,7 @@ def render_labels( norm: Normalize | None = None, na_color: ColorLike | None = "default", outline_alpha: float | int = 0.0, - fill_alpha: float | int = 0.4, + fill_alpha: float | int | None = None, scale: str | None = None, colorbar: bool | str | None = "auto", colorbar_params: dict[str, object] | None = None, @@ -662,11 +662,13 @@ def render_labels( element : str | None The name of the labels element to render. If `None`, all label elements in the `SpatialData` object will be used and all parameters will be broadcasted if possible. - color : str | None - Can either be string representing a color-like or key in :attr:`sdata.table.obs` or in the index of - :attr:`sdata.table.var`. The latter can be used to color by categorical or continuous variables. If the - color column is found in multiple locations, please provide the table_name to be used for the element if you - would like a specific table to be used. By default one table will automatically be choosen. + color : ColorLike | None + Can either be color-like (name of a color as string, e.g. "red", hex representation, e.g. "#000000" or + "#000000ff", or an RGB(A) array as a tuple or list containing 3-4 floats within [0, 1]. If an alpha value + is indicated, the value of `fill_alpha` takes precedence if given) or a string representing a key in + :attr:`sdata.table.obs` or in the index of :attr:`sdata.table.var`. The latter can be used to color by + categorical or continuous variables. If the color column is found in multiple locations, please provide the + table_name to be used for the element if you would like a specific table to be used. groups : list[str] | str | None When using `color` and the key represents discrete labels, `groups` can be used to show only a subset of them. Other values are set to NA. The list can contain multiple discrete labels to be visualized. @@ -687,8 +689,9 @@ def render_labels( won't be shown. outline_alpha : float | int, default 0.0 Alpha value for the outline of the labels. Invisible by default. - fill_alpha : float | int, default 0.4 - Alpha value for the fill of the labels. + fill_alpha : float | int | None, optional + Alpha value for the fill of the labels. By default, it is set to 0.4 or, if a color is given that implies + an alpha, that value is used for `fill_alpha`. scale : str | None Influences the resolution of the rendering. Possibilities for setting this parameter: 1) None (default). The image is rasterized to fit the canvas size. For multiscale images, the best scale @@ -749,6 +752,7 @@ def render_labels( sdata.plotting_tree[f"{n_steps + 1}_render_labels"] = LabelsRenderParams( element=element, color=param_values["color"], + col_for_color=param_values["col_for_color"], groups=param_values["groups"], contour_px=param_values["contour_px"], cmap_params=cmap_params, @@ -1130,14 +1134,13 @@ def _draw_colorbar( if wanted_labels_on_this_cs: table = params_copy.table_name - if table is not None: - assert isinstance(params_copy.color, str) - colors = sc.get.obs_df(sdata[table], [params_copy.color]) - if isinstance(colors[params_copy.color].dtype, pd.CategoricalDtype): + if table is not None and params_copy.col_for_color is not None: + colors = sc.get.obs_df(sdata[table], [params_copy.col_for_color]) + if isinstance(colors[params_copy.col_for_color].dtype, pd.CategoricalDtype): _maybe_set_colors( source=sdata[table], target=sdata[table], - key=params_copy.color, + key=params_copy.col_for_color, palette=params_copy.palette, ) diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 5334e287..6cc748bc 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -1265,7 +1265,7 @@ def _render_labels( table_name = render_params.table_name table_layer = render_params.table_layer palette = render_params.palette - color = render_params.color + col_for_color = render_params.col_for_color groups = render_params.groups scale = render_params.scale @@ -1314,23 +1314,29 @@ def _render_labels( _, trans_data = _prepare_transformation(label, coordinate_system, ax) + na_color = ( + render_params.color + if col_for_color is None and render_params.color is not None + else render_params.cmap_params.na_color + ) color_source_vector, color_vector, categorical = _set_color_source_vec( sdata=sdata_filt, element=label, element_name=element, - value_to_plot=color, + value_to_plot=col_for_color, groups=groups, palette=palette, - na_color=render_params.cmap_params.na_color, + na_color=na_color, cmap_params=render_params.cmap_params, table_name=table_name, table_layer=table_layer, + render_type="labels", coordinate_system=coordinate_system, ) # rasterize could have removed labels from label # only problematic if color is specified - if rasterize and color is not None: + if rasterize and col_for_color is not None: labels_in_rasterized_image = np.unique(label.values) mask = np.isin(instance_id, labels_in_rasterized_image) instance_id = instance_id[mask] @@ -1351,7 +1357,7 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float) cmap_params=render_params.cmap_params, seg_erosionpx=seg_erosionpx, seg_boundaries=seg_boundaries, - na_color=render_params.cmap_params.na_color, + na_color=na_color, ) _cax = ax.imshow( @@ -1408,7 +1414,7 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float) colorbar_requested = _should_request_colorbar( render_params.colorbar, has_mappable=cax is not None, - is_continuous=color is not None and color_source_vector is None and not categorical, + is_continuous=col_for_color is not None and color_source_vector is None and not categorical, ) _ = _decorate_axs( @@ -1416,7 +1422,7 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float) cax=cax, fig_params=fig_params, adata=table, - value_to_plot=color, + value_to_plot=col_for_color, color_source_vector=color_source_vector, color_vector=color_vector, palette=palette, @@ -1432,7 +1438,7 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float) colorbar_requests=colorbar_requests, colorbar_label=_resolve_colorbar_label( render_params.colorbar_params, - color if isinstance(color, str) else None, + col_for_color if isinstance(col_for_color, str) else None, ), scalebar_dx=scalebar_params.scalebar_dx, scalebar_units=scalebar_params.scalebar_units, diff --git a/src/spatialdata_plot/pl/render_params.py b/src/spatialdata_plot/pl/render_params.py index 4936468f..a108e131 100644 --- a/src/spatialdata_plot/pl/render_params.py +++ b/src/spatialdata_plot/pl/render_params.py @@ -278,7 +278,8 @@ class LabelsRenderParams: cmap_params: CmapParams element: str - color: str | None = None + color: Color | None = None + col_for_color: str | None = None groups: str | list[str] | None = None contour_px: int | None = None outline: bool = False diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 6c5aaddf..aa771241 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -94,6 +94,20 @@ # once https://github.com/scverse/spatialdata/pull/689/ is in a release ColorLike = tuple[float, ...] | list[float] | str +_GROUPS_IGNORED_WARNING = "Parameter 'groups' is ignored when 'color' is a literal color, not a column name." + + +def _gate_palette_and_groups( + element_params: dict[str, Any], + param_dict: dict[str, Any], +) -> None: + """Set palette/groups on element_params only when col_for_color is present, else warn.""" + has_col = element_params.get("col_for_color") is not None + element_params["palette"] = param_dict["palette"] if has_col else None + if not has_col and param_dict["groups"] is not None: + logger.warning(_GROUPS_IGNORED_WARNING) + element_params["groups"] = param_dict["groups"] if has_col else None + def _extract_scalar_value(value: Any, default: float = 0.0) -> float: """ @@ -981,7 +995,7 @@ def _set_color_source_vec( alpha: float = 1.0, table_name: str | None = None, table_layer: str | None = None, - render_type: Literal["points"] | None = None, + render_type: Literal["points", "labels"] | None = None, coordinate_system: str | None = None, ) -> tuple[ArrayLike | pd.Series | None, ArrayLike, bool]: if value_to_plot is None and element is not None: @@ -1454,7 +1468,7 @@ def _get_categorical_color_mapping( alpha: float = 1, groups: list[str] | str | None = None, palette: list[str] | str | None = None, - render_type: Literal["points"] | None = None, + render_type: Literal["points", "labels"] | None = None, ) -> Mapping[str, str]: if not isinstance(color_source_vector, Categorical): raise TypeError(f"Expected `categories` to be a `Categorical`, but got {type(color_source_vector).__name__}") @@ -2145,7 +2159,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st }: if not isinstance(color, str | tuple | list): raise TypeError("Parameter 'color' must be a string or a tuple/list of floats.") - if element_type in {"shapes", "points"}: + if element_type in {"shapes", "points", "labels"}: if _is_color_like(color): logger.info("Value for parameter 'color' appears to be a color, using it as such.") param_dict["col_for_color"] = None @@ -2153,7 +2167,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st if param_dict["color"].alpha_is_user_defined(): if element_type == "points" and param_dict.get("alpha") is None: param_dict["alpha"] = param_dict["color"].get_alpha_as_float() - elif element_type == "shapes" and param_dict.get("fill_alpha") is None: + elif element_type in {"shapes", "labels"} and param_dict.get("fill_alpha") is None: param_dict["fill_alpha"] = param_dict["color"].get_alpha_as_float() else: logger.info( @@ -2165,7 +2179,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st param_dict["color"] = None else: raise ValueError(f"{color} is not a valid RGB(A) array and therefore can't be used as 'color' value.") - elif "color" in param_dict and element_type != "labels": + elif "color" in param_dict and element_type != "images": param_dict["col_for_color"] = None outline_width = param_dict.get("outline_width") @@ -2256,6 +2270,9 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st elif element_type == "shapes": # set default fill_alpha for shapes if not given by user explicitly or implicitly (as part of color) param_dict["fill_alpha"] = 1.0 + elif element_type == "labels": + # set default fill_alpha for labels if not given by user explicitly or implicitly (as part of color) + param_dict["fill_alpha"] = 0.4 cmap = param_dict.get("cmap") palette = param_dict.get("palette") @@ -2412,8 +2429,8 @@ def _validate_label_render_params( sdata: sd.SpatialData, element: str | None, cmap: list[Colormap | str] | Colormap | str | None, - color: str | None, - fill_alpha: float | int, + color: ColorLike | None, + fill_alpha: float | int | None, contour_px: int | None, groups: list[str] | str | None, palette: list[str] | str | None, @@ -2462,15 +2479,16 @@ def _validate_label_render_params( element_params[el]["table_layer"] = param_dict["table_layer"] element_params[el]["table_name"] = None - element_params[el]["color"] = None - color = param_dict["color"] - if color is not None: - color, table_name = _validate_col_for_column_table(sdata, el, color, param_dict["table_name"], labels=True) + element_params[el]["color"] = param_dict["color"] # literal Color or None + element_params[el]["col_for_color"] = None + if (col_for_color := param_dict["col_for_color"]) is not None: + col_for_color, table_name = _validate_col_for_column_table( + sdata, el, col_for_color, param_dict["table_name"], labels=True + ) element_params[el]["table_name"] = table_name - element_params[el]["color"] = color + element_params[el]["col_for_color"] = col_for_color - element_params[el]["palette"] = param_dict["palette"] if element_params[el]["table_name"] is not None else None - element_params[el]["groups"] = param_dict["groups"] if element_params[el]["table_name"] is not None else None + _gate_palette_and_groups(element_params[el], param_dict) element_params[el]["colorbar"] = param_dict["colorbar"] element_params[el]["colorbar_params"] = param_dict["colorbar_params"] @@ -2537,8 +2555,7 @@ def _validate_points_render_params( element_params[el]["table_name"] = table_name element_params[el]["col_for_color"] = col_for_color - element_params[el]["palette"] = param_dict["palette"] if param_dict["col_for_color"] is not None else None - element_params[el]["groups"] = param_dict["groups"] if param_dict["col_for_color"] is not None else None + _gate_palette_and_groups(element_params[el], param_dict) element_params[el]["ds_reduction"] = param_dict["ds_reduction"] element_params[el]["colorbar"] = param_dict["colorbar"] element_params[el]["colorbar_params"] = param_dict["colorbar_params"] @@ -2621,8 +2638,7 @@ def _validate_shape_render_params( element_params[el]["table_name"] = table_name element_params[el]["col_for_color"] = col_for_color - element_params[el]["palette"] = param_dict["palette"] if param_dict["col_for_color"] is not None else None - element_params[el]["groups"] = param_dict["groups"] if param_dict["col_for_color"] is not None else None + _gate_palette_and_groups(element_params[el], param_dict) element_params[el]["method"] = param_dict["method"] element_params[el]["ds_reduction"] = param_dict["ds_reduction"] element_params[el]["colorbar"] = param_dict["colorbar"] diff --git a/tests/_images/Labels_alpha_overwrites_opacity_from_color.png b/tests/_images/Labels_alpha_overwrites_opacity_from_color.png new file mode 100644 index 00000000..b5379f0a Binary files /dev/null and b/tests/_images/Labels_alpha_overwrites_opacity_from_color.png differ diff --git a/tests/_images/Labels_can_color_by_hex.png b/tests/_images/Labels_can_color_by_hex.png new file mode 100644 index 00000000..d4615119 Binary files /dev/null and b/tests/_images/Labels_can_color_by_hex.png differ diff --git a/tests/_images/Labels_can_color_by_hex_with_alpha.png b/tests/_images/Labels_can_color_by_hex_with_alpha.png new file mode 100644 index 00000000..b88663de Binary files /dev/null and b/tests/_images/Labels_can_color_by_hex_with_alpha.png differ diff --git a/tests/_images/Labels_can_color_by_rgba_array.png b/tests/_images/Labels_can_color_by_rgba_array.png new file mode 100644 index 00000000..d3b5bd44 Binary files /dev/null and b/tests/_images/Labels_can_color_by_rgba_array.png differ diff --git a/tests/pl/test_render_labels.py b/tests/pl/test_render_labels.py index a585d4eb..4449dbd3 100644 --- a/tests/pl/test_render_labels.py +++ b/tests/pl/test_render_labels.py @@ -84,6 +84,18 @@ def test_plot_can_stack_render_labels(self, sdata_blobs: SpatialData): .pl.show() ) + def test_plot_can_color_by_rgba_array(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_labels("blobs_labels", color=[0.5, 0.5, 1.0, 0.5]).pl.show() + + def test_plot_can_color_by_hex(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_labels("blobs_labels", color="#88a136").pl.show() + + def test_plot_can_color_by_hex_with_alpha(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_labels("blobs_labels", color="#88a13688").pl.show() + + def test_plot_alpha_overwrites_opacity_from_color(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_labels("blobs_labels", color=[0.5, 0.5, 1.0, 0.5], fill_alpha=1.0).pl.show() + def test_plot_can_color_labels_by_continuous_variable(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum").pl.show()