diff --git a/cpp/dolfinx/fem/interpolate.h b/cpp/dolfinx/fem/interpolate.h
index d2a098d8c8c..f46270c22c4 100644
--- a/cpp/dolfinx/fem/interpolate.h
+++ b/cpp/dolfinx/fem/interpolate.h
@@ -1097,7 +1097,8 @@ geometry::PointOwnershipData<T> create_interpolation_data(
       x[3 * i + j] = coords[i + j * num_points];
 
   // Determine ownership of each point
-  return geometry::determine_point_ownership<T>(mesh1, x, padding);
+  return geometry::determine_point_ownership<T>(mesh1, x, padding,
+                                                std::nullopt);
 }
 
 /// @brief Interpolate a finite element Function defined on a mesh to a
diff --git a/cpp/dolfinx/fem/petsc.h b/cpp/dolfinx/fem/petsc.h
index 5fbcdcef9ad..eaccf2369c2 100644
--- a/cpp/dolfinx/fem/petsc.h
+++ b/cpp/dolfinx/fem/petsc.h
@@ -462,7 +462,8 @@ void apply_lifting(
 template <std::floating_point T>
 void set_bc(
     Vec b,
-    const std::vector<std::reference_wrapper<const DirichletBC<PetscScalar, T>>> bcs,
+    const std::vector<std::reference_wrapper<const DirichletBC<PetscScalar, T>>>
+        bcs,
     const Vec x0, PetscScalar alpha = 1)
 {
   PetscInt n = 0;
diff --git a/cpp/dolfinx/geometry/BoundingBoxTree.h b/cpp/dolfinx/geometry/BoundingBoxTree.h
index 64ede45a057..c2c932169ea 100644
--- a/cpp/dolfinx/geometry/BoundingBoxTree.h
+++ b/cpp/dolfinx/geometry/BoundingBoxTree.h
@@ -223,8 +223,8 @@ class BoundingBoxTree
   /// compute the bounding box for (may be empty, if none).
   /// @param[in] padding Value to pad (extend) the the bounding box of
   /// each entity by.
-  BoundingBoxTree(const mesh::Mesh<T>& mesh, int tdim,
-                  std::span<const std::int32_t> entities, double padding = 0)
+  BoundingBoxTree(const mesh::Mesh<T>& mesh, int tdim, double padding,
+                  std::span<const std::int32_t> entities)
       : _tdim(tdim)
   {
     if (tdim < 0 or tdim > mesh.topology()->dim())
@@ -266,7 +266,7 @@ class BoundingBoxTree
   /// build the bounding box tree for
   /// @param[in] padding Value to pad (extend) the the bounding box of
   /// each entity by.
-  BoundingBoxTree(const mesh::Mesh<T>& mesh, int tdim, T padding = 0)
+  BoundingBoxTree(const mesh::Mesh<T>& mesh, int tdim, T padding)
       : BoundingBoxTree::BoundingBoxTree(
             mesh, tdim, range(mesh.topology_mutable(), tdim), padding)
   {
diff --git a/cpp/dolfinx/geometry/utils.h b/cpp/dolfinx/geometry/utils.h
index 0f643e62643..c1cf24e6cef 100644
--- a/cpp/dolfinx/geometry/utils.h
+++ b/cpp/dolfinx/geometry/utils.h
@@ -663,41 +663,43 @@ graph::AdjacencyList<std::int32_t> compute_colliding_cells(
 /// @param[in] mesh The mesh
 /// @param[in] points Points to check for collision (`shape=(num_points,
 /// 3)`). Storage is row-major.
+/// @param[in] cells Cells to check for ownership
 /// @param[in] padding Amount of absolute padding of bounding boxes of the mesh.
 /// Each bounding box of the mesh is padded with this amount, to increase
 /// the number of candidates, avoiding rounding errors in determining the owner
 /// of a point if the point is on the surface of a cell in the mesh.
-/// @return Tuple `(src_owner, dest_owner, dest_points, dest_cells)`,
-/// where src_owner is a list of ranks corresponding to the input
-/// points. dest_owner is a list of ranks corresponding to dest_points,
-/// the points that this process owns. dest_cells contains the
-/// corresponding cell for each entry in dest_points.
+/// @return Point ownership data.
 ///
 /// @note `dest_owner` is sorted
-/// @note Returns -1 if no colliding process is found
+/// @note `src_owner` is -1 if no colliding process is found
 /// @note dest_points is flattened row-major, shape `(dest_owner.size(),
 /// 3)`
-/// @note Only looks through cells owned by the process
 /// @note A large padding value can increase the runtime of the function by
 /// orders of magnitude, because for non-colliding cells
 /// one has to determine the closest cell among all processes with an
 /// intersecting bounding box, which is an expensive operation to perform.
 template <std::floating_point T>
-PointOwnershipData<T> determine_point_ownership(const mesh::Mesh<T>& mesh,
-                                                std::span<const T> points,
-                                                T padding)
+PointOwnershipData<T>
+determine_point_ownership(const mesh::Mesh<T>& mesh, std::span<const T> points,
+                          T padding,
+                          std::optional<std::span<const std::int32_t>> cells)
 {
   MPI_Comm comm = mesh.comm();
 
+  const int tdim = mesh.topology()->dim();
+
+  std::vector<std::int32_t> local_cells;
+  if (not(cells.has_value()))
+  {
+    auto cell_map = mesh.topology()->index_map(tdim);
+    local_cells.resize(cell_map->size_local());
+    std::iota(local_cells.begin(), local_cells.end(), 0);
+    cells
+        = std::span<const std::int32_t>(local_cells.data(), local_cells.size());
+  }
   // Create a global bounding-box tree to find candidate processes with
   // cells that could collide with the points
-  const int tdim = mesh.topology()->dim();
-  auto cell_map = mesh.topology()->index_map(tdim);
-  const std::int32_t num_cells = cell_map->size_local();
-  // NOTE: Should we send the cells in as input?
-  std::vector<std::int32_t> cells(num_cells, 0);
-  std::iota(cells.begin(), cells.end(), 0);
-  BoundingBoxTree bb(mesh, tdim, cells, padding);
+  BoundingBoxTree bb(mesh, tdim, padding, cells.value());
   BoundingBoxTree global_bbtree = bb.create_global_tree(comm);
 
   // Compute collisions:
diff --git a/cpp/test/mesh/read_named_meshtags.cpp b/cpp/test/mesh/read_named_meshtags.cpp
index cf3830aa1ae..d5b561c98a2 100644
--- a/cpp/test/mesh/read_named_meshtags.cpp
+++ b/cpp/test/mesh/read_named_meshtags.cpp
@@ -49,7 +49,8 @@ void test_read_named_meshtags()
                                             material_values);
   mt_materials.name = "material";
 
-  io::XDMFFile file(mesh->comm(), mesh_file_name, "w", io::XDMFFile::Encoding::HDF5);
+  io::XDMFFile file(mesh->comm(), mesh_file_name, "w",
+                    io::XDMFFile::Encoding::HDF5);
   file.write_mesh(*mesh);
   file.write_meshtags(mt_domains, mesh->geometry(),
                       "/Xdmf/Domain/mesh/Geometry");
@@ -58,7 +59,7 @@ void test_read_named_meshtags()
   file.close();
 
   io::XDMFFile mesh_file(MPI_COMM_WORLD, mesh_file_name, "r",
-                        io::XDMFFile::Encoding::HDF5);
+                         io::XDMFFile::Encoding::HDF5);
   mesh = std::make_shared<mesh::Mesh<double>>(mesh_file.read_mesh(
       fem::CoordinateElement<double>(mesh::CellType::triangle, 1),
       mesh::GhostMode::none, "mesh"));
diff --git a/python/demo/demo_static-condensation.py b/python/demo/demo_static-condensation.py
index 953f7f64af2..7468d04c96f 100644
--- a/python/demo/demo_static-condensation.py
+++ b/python/demo/demo_static-condensation.py
@@ -183,7 +183,7 @@ def tabulate_A(A_, w_, c_, coords_, entity_local_index, permutation=ffi.NULL):
 A.assemble()
 
 # Create bounding box for function evaluation
-bb_tree = geometry.bb_tree(msh, 2)
+bb_tree = geometry.bb_tree(msh, 2, 0.0)
 
 # Check against standard table value
 p = np.array([[48.0, 52.0, 0.0]], dtype=np.float64)
diff --git a/python/dolfinx/fem/bcs.py b/python/dolfinx/fem/bcs.py
index e433bc81d95..0c393022351 100644
--- a/python/dolfinx/fem/bcs.py
+++ b/python/dolfinx/fem/bcs.py
@@ -11,12 +11,11 @@
 import numbers
 import typing
 
-import numpy.typing as npt
-
 if typing.TYPE_CHECKING:
     from dolfinx.fem.function import Constant, Function
 
 import numpy as np
+import numpy.typing as npt
 
 import dolfinx
 from dolfinx import cpp as _cpp
diff --git a/python/dolfinx/fem/forms.py b/python/dolfinx/fem/forms.py
index b5491941e18..0808f05e7d4 100644
--- a/python/dolfinx/fem/forms.py
+++ b/python/dolfinx/fem/forms.py
@@ -111,7 +111,7 @@ def integral_types(self):
 
 def get_integration_domains(
     integral_type: IntegralType,
-    subdomain: typing.Optional[typing.Union[MeshTags, list[tuple[int, np.ndarray]]]],
+    subdomain: typing.Optional[typing.Union[MeshTags, list[tuple[int, npt.NDArray[np.int32]]]]],
     subdomain_ids: list[int],
 ) -> list[tuple[int, np.ndarray]]:
     """Get integration domains from subdomain data.
diff --git a/python/dolfinx/fem/function.py b/python/dolfinx/fem/function.py
index 01123fc61bf..0ca872224ca 100644
--- a/python/dolfinx/fem/function.py
+++ b/python/dolfinx/fem/function.py
@@ -89,7 +89,7 @@ class Expression:
     def __init__(
         self,
         e: ufl.core.expr.Expr,
-        X: np.ndarray,
+        X: typing.Union[npt.NDArray[np.float32], npt.NDArray[np.float64]],
         comm: typing.Optional[_MPI.Comm] = None,
         form_compiler_options: typing.Optional[dict] = None,
         jit_options: typing.Optional[dict] = None,
@@ -195,7 +195,7 @@ def _create_expression(dtype):
     def eval(
         self,
         mesh: Mesh,
-        entities: np.ndarray,
+        entities: npt.NDArray[np.int32],
         values: typing.Optional[np.ndarray] = None,
     ) -> np.ndarray:
         """Evaluate Expression on entities.
@@ -411,8 +411,8 @@ def interpolate_nonmatching(
     def interpolate(
         self,
         u0: typing.Union[typing.Callable, Expression, Function],
-        cells0: typing.Optional[np.ndarray] = None,
-        cells1: typing.Optional[np.ndarray] = None,
+        cells0: typing.Optional[npt.NDArray[np.int32]] = None,
+        cells1: typing.Optional[npt.NDArray[np.int32]] = None,
     ) -> None:
         """Interpolate an expression.
 
@@ -563,7 +563,7 @@ class ElementMetaData(typing.NamedTuple):
 def _create_dolfinx_element(
     cell_type: _cpp.mesh.CellType,
     ufl_e: ufl.FiniteElementBase,
-    dtype: np.dtype,
+    dtype: npt.DTypeLike,
 ) -> typing.Union[_cpp.fem.FiniteElement_float32, _cpp.fem.FiniteElement_float64]:
     """Create a DOLFINx element from a basix.ufl element."""
     if np.issubdtype(dtype, np.float32):
diff --git a/python/dolfinx/geometry.py b/python/dolfinx/geometry.py
index 97b44095293..7fc48faa6fd 100644
--- a/python/dolfinx/geometry.py
+++ b/python/dolfinx/geometry.py
@@ -103,8 +103,8 @@ def create_global_tree(self, comm) -> BoundingBoxTree:
 def bb_tree(
     mesh: Mesh,
     dim: int,
+    padding: float,
     entities: typing.Optional[npt.NDArray[np.int32]] = None,
-    padding: float = 0.0,
 ) -> BoundingBoxTree:
     """Create a bounding box tree for use in collision detection.
 
@@ -128,11 +128,11 @@ def bb_tree(
     dtype = mesh.geometry.x.dtype
     if np.issubdtype(dtype, np.float32):
         return BoundingBoxTree(
-            _cpp.geometry.BoundingBoxTree_float32(mesh._cpp_object, dim, entities, padding)
+            _cpp.geometry.BoundingBoxTree_float32(mesh._cpp_object, dim, padding, entities)
         )
     elif np.issubdtype(dtype, np.float64):
         return BoundingBoxTree(
-            _cpp.geometry.BoundingBoxTree_float64(mesh._cpp_object, dim, entities, padding)
+            _cpp.geometry.BoundingBoxTree_float64(mesh._cpp_object, dim, padding, entities)
         )
     else:
         raise NotImplementedError(f"Type {dtype} not supported.")
@@ -270,3 +270,42 @@ def compute_distance_gjk(
 
     """
     return _cpp.geometry.compute_distance_gjk(p, q)
+
+
+def determine_point_ownership(
+    mesh: Mesh,
+    points: npt.NDArray[np.floating],
+    padding: float,
+    cells: typing.Optional[npt.NDArray[np.int32]] = None,
+) -> PointOwnershipData:
+    """Build point ownership data for a mesh-points pair.
+
+    First, potential collisions are found by computing intersections
+    between the bounding boxes of the cells and the set of points.
+    Then, actual containment pairs are determined using the GJK algorithm.
+
+    Args:
+        mesh: The mesh
+        points: Points to check for collision (``shape=(num_points, gdim)``)
+        padding: Amount of absolute padding of bounding boxes of the mesh.
+            Each bounding box of the mesh is padded with this amount, to increase
+            the number of candidates, avoiding rounding errors in determining the owner
+            of a point if the point is on the surface of a cell in the mesh.
+        cells: Cells to check for ownership
+            If ``None`` then all cells are considered.
+
+    Returns:
+        Point ownership data
+
+    Note:
+        ``dest_owner`` is sorted
+
+        ``src_owner`` is -1 if no colliding process is found
+
+        A large padding value will increase the run-time of the code by orders
+            of magnitude. General advice is to use a padding on the scale of the
+            cell size.
+    """
+    return PointOwnershipData(
+        _cpp.geometry.determine_point_ownership(mesh._cpp_object, points, padding, cells)
+    )
diff --git a/python/dolfinx/graph.py b/python/dolfinx/graph.py
index dd16514a79a..df8d27dbfae 100644
--- a/python/dolfinx/graph.py
+++ b/python/dolfinx/graph.py
@@ -7,7 +7,10 @@
 
 from __future__ import annotations
 
+import typing
+
 import numpy as np
+import numpy.typing as npt
 
 from dolfinx import cpp as _cpp
 from dolfinx.cpp.graph import partitioner
@@ -31,7 +34,10 @@
 __all__ = ["adjacencylist", "partitioner"]
 
 
-def adjacencylist(data: np.ndarray, offsets=None):
+def adjacencylist(
+    data: typing.Union[npt.NDArray[np.int32], npt.NDArray[np.int64]],
+    offsets: typing.Optional[npt.NDArray[np.int32]] = None,
+):
     """Create an AdjacencyList for int32 or int64 datasets.
 
     Args:
diff --git a/python/dolfinx/wrappers/geometry.cpp b/python/dolfinx/wrappers/geometry.cpp
index bf3958a1fc2..14bbc3937d4 100644
--- a/python/dolfinx/wrappers/geometry.cpp
+++ b/python/dolfinx/wrappers/geometry.cpp
@@ -16,6 +16,7 @@
 #include <nanobind/nanobind.h>
 #include <nanobind/ndarray.h>
 #include <nanobind/stl/array.h>
+#include <nanobind/stl/optional.h>
 #include <nanobind/stl/tuple.h>
 #include <nanobind/stl/vector.h>
 #include <span>
@@ -33,18 +34,17 @@ void declare_bbtree(nb::module_& m, std::string type)
       .def(
           "__init__",
           [](dolfinx::geometry::BoundingBoxTree<T>* bbt,
-             const dolfinx::mesh::Mesh<T>& mesh, int dim,
+             const dolfinx::mesh::Mesh<T>& mesh, int dim, double padding,
              nb::ndarray<const std::int32_t, nb::ndim<1>, nb::c_contig>
-                 entities,
-             double padding)
+                 entities)
           {
             new (bbt) dolfinx::geometry::BoundingBoxTree<T>(
-                mesh, dim,
-                std::span<const std::int32_t>(entities.data(), entities.size()),
-                padding);
+                mesh, dim, padding,
+                std::span<const std::int32_t>(entities.data(),
+                                              entities.size()));
           },
-          nb::arg("mesh"), nb::arg("dim"), nb::arg("entities"),
-          nb::arg("padding") = 0.0)
+          nb::arg("mesh"), nb::arg("dim"), nb::arg("padding"),
+          nb::arg("entities"))
       .def_prop_ro("num_bboxes",
                    &dolfinx::geometry::BoundingBoxTree<T>::num_bboxes)
       .def(
@@ -180,15 +180,27 @@ void declare_bbtree(nb::module_& m, std::string type)
                 mesh, dim, std::span(indices.data(), indices.size()), _p));
       },
       nb::arg("mesh"), nb::arg("dim"), nb::arg("indices"), nb::arg("points"));
-  m.def("determine_point_ownership",
-        [](const dolfinx::mesh::Mesh<T>& mesh,
-           nb::ndarray<const T, nb::c_contig> points, const T padding)
-        {
-          std::size_t p_s0 = points.ndim() == 1 ? 1 : points.shape(0);
-          std::span<const T> _p(points.data(), 3 * p_s0);
-          return dolfinx::geometry::determine_point_ownership<T>(mesh, _p,
-                                                                 padding);
-        });
+  m.def(
+      "determine_point_ownership",
+      [](const dolfinx::mesh::Mesh<T>& mesh,
+         nb::ndarray<const T, nb::c_contig> points, const T padding,
+         std::optional<
+             nb::ndarray<const std::int32_t, nb::ndim<1>, nb::c_contig>>
+             cells)
+      {
+        std::size_t p_s0 = points.ndim() == 1 ? 1 : points.shape(0);
+        std::span<const T> _p(points.data(), 3 * p_s0);
+        std::optional<std::span<const std::int32_t>> _cells
+            = cells.has_value()
+                  ? std::span<const std::int32_t>(cells.value().data(),
+                                                  cells.value().size())
+                  : std::optional<std::span<const std::int32_t>>(std::nullopt);
+        return dolfinx::geometry::determine_point_ownership<T>(mesh, _p,
+                                                               padding, _cells);
+      },
+      nb::arg("mesh"), nb::arg("points"), nb::arg("padding"),
+      nb::arg("cells").none(),
+      "Compute point ownership data for mesh-points pair.");
 
   std::string pod_pyclass_name = "PointOwnershipData_" + type;
   nb::class_<dolfinx::geometry::PointOwnershipData<T>>(m,
diff --git a/python/test/unit/fem/test_function.py b/python/test/unit/fem/test_function.py
index 550b0528e7a..c131aa47c4f 100644
--- a/python/test/unit/fem/test_function.py
+++ b/python/test/unit/fem/test_function.py
@@ -90,7 +90,7 @@ def e3(x):
     u3.interpolate(e3)
 
     x0 = (mesh.geometry.x[0] + mesh.geometry.x[1]) / 2.0
-    tree = bb_tree(mesh, mesh.geometry.dim)
+    tree = bb_tree(mesh, mesh.geometry.dim, 0.0)
     cell_candidates = compute_collisions_points(tree, x0)
     cell = compute_colliding_cells(mesh, cell_candidates, x0).array
     assert len(cell) > 0
diff --git a/python/test/unit/fem/test_interpolation.py b/python/test/unit/fem/test_interpolation.py
index 126e265fc88..994888b1a25 100644
--- a/python/test/unit/fem/test_interpolation.py
+++ b/python/test/unit/fem/test_interpolation.py
@@ -1025,7 +1025,7 @@ def f_test2(x):
     u1_exact.x.scatter_forward()
 
     # Find the single cell in mesh1 which is overlapped by mesh2
-    tree1 = bb_tree(mesh1, mesh1.topology.dim)
+    tree1 = bb_tree(mesh1, mesh1.topology.dim, 0.0)
     cells_overlapped1 = compute_collisions_points(
         tree1, np.array([p0_mesh2, p0_mesh2, 0.0]) / 2
     ).array
diff --git a/python/test/unit/geometry/test_bounding_box_tree.py b/python/test/unit/geometry/test_bounding_box_tree.py
index 2cc328300b6..be485b60ff0 100644
--- a/python/test/unit/geometry/test_bounding_box_tree.py
+++ b/python/test/unit/geometry/test_bounding_box_tree.py
@@ -12,6 +12,7 @@
 
 from dolfinx import cpp as _cpp
 from dolfinx.geometry import (
+    PointOwnershipData,
     bb_tree,
     compute_closest_entity,
     compute_colliding_cells,
@@ -19,9 +20,11 @@
     compute_collisions_trees,
     compute_distance_gjk,
     create_midpoint_tree,
+    determine_point_ownership,
 )
 from dolfinx.mesh import (
     CellType,
+    compute_midpoints,
     create_box,
     create_unit_cube,
     create_unit_interval,
@@ -147,7 +150,7 @@ def rotation_matrix(axis, angle):
 @pytest.mark.parametrize("dtype", [np.float32, np.float64])
 def test_empty_tree(dtype):
     mesh = create_unit_interval(MPI.COMM_WORLD, 16, dtype=dtype)
-    bbtree = bb_tree(mesh, mesh.topology.dim, np.array([], dtype=dtype))
+    bbtree = bb_tree(mesh, mesh.topology.dim, 0.0, np.array([], dtype=dtype))
     assert bbtree.num_bboxes == 0
 
 
@@ -164,7 +167,7 @@ def test_compute_collisions_point_1d(dtype):
 
     # Compute collision
     tdim = mesh.topology.dim
-    tree = bb_tree(mesh, tdim)
+    tree = bb_tree(mesh, tdim, 0.0)
     entities = compute_collisions_points(tree, p)
     assert len(entities.array) == 1
 
@@ -209,8 +212,8 @@ def locator_B(x):
     cells_B = np.sort(np.unique(np.hstack([v_to_c.links(vertex) for vertex in vertices_B])))
 
     # Find colliding entities using bounding box trees
-    tree_A = bb_tree(mesh_A, mesh_A.topology.dim)
-    tree_B = bb_tree(mesh_B, mesh_B.topology.dim)
+    tree_A = bb_tree(mesh_A, mesh_A.topology.dim, 0.0)
+    tree_B = bb_tree(mesh_B, mesh_B.topology.dim, 0.0)
     entities = compute_collisions_trees(tree_A, tree_B)
     entities_A = np.sort(np.unique([q[0] for q in entities]))
     entities_B = np.sort(np.unique([q[1] for q in entities]))
@@ -226,8 +229,8 @@ def test_compute_collisions_tree_2d(point, dtype):
     mesh_B = create_unit_square(MPI.COMM_WORLD, 5, 5, dtype=dtype)
     bgeom = mesh_B.geometry.x
     bgeom += point
-    tree_A = bb_tree(mesh_A, mesh_A.topology.dim)
-    tree_B = bb_tree(mesh_B, mesh_B.topology.dim)
+    tree_A = bb_tree(mesh_A, mesh_A.topology.dim, 0.0)
+    tree_B = bb_tree(mesh_B, mesh_B.topology.dim, 0.0)
     entities = compute_collisions_trees(tree_A, tree_B)
 
     entities_A = np.sort(np.unique([q[0] for q in entities]))
@@ -248,8 +251,8 @@ def test_compute_collisions_tree_3d(point, dtype):
     bgeom = mesh_B.geometry.x
     bgeom += point
 
-    tree_A = bb_tree(mesh_A, mesh_A.topology.dim)
-    tree_B = bb_tree(mesh_B, mesh_B.topology.dim)
+    tree_A = bb_tree(mesh_A, mesh_A.topology.dim, 0.0)
+    tree_B = bb_tree(mesh_B, mesh_B.topology.dim, 0.0)
     entities = compute_collisions_trees(tree_A, tree_B)
     entities_A = np.sort(np.unique([q[0] for q in entities]))
     entities_B = np.sort(np.unique([q[1] for q in entities]))
@@ -266,7 +269,7 @@ def test_compute_closest_entity_1d(dim, dtype):
     N = 16
     points = np.array([[-ref_distance, 0, 0], [2 / N, 2 * ref_distance, 0]], dtype=dtype)
     mesh = create_unit_interval(MPI.COMM_WORLD, N, dtype=dtype)
-    tree = bb_tree(mesh, dim)
+    tree = bb_tree(mesh, dim, 0.0)
     num_entities_local = (
         mesh.topology.index_map(dim).size_local + mesh.topology.index_map(dim).num_ghosts
     )
@@ -300,7 +303,7 @@ def test_compute_closest_entity_2d(dim, dtype):
     points = np.array([-1.0, -0.01, 0.0], dtype=dtype)
     mesh = create_unit_square(MPI.COMM_WORLD, 15, 15, dtype=dtype)
     mesh.topology.create_entities(dim)
-    tree = bb_tree(mesh, dim)
+    tree = bb_tree(mesh, dim, 0.0)
     num_entities_local = (
         mesh.topology.index_map(dim).size_local + mesh.topology.index_map(dim).num_ghosts
     )
@@ -332,7 +335,7 @@ def test_compute_closest_entity_3d(dim, dtype):
     mesh = create_unit_cube(MPI.COMM_WORLD, 8, 8, 8, dtype=dtype)
     mesh.topology.create_entities(dim)
 
-    tree = bb_tree(mesh, dim)
+    tree = bb_tree(mesh, dim, 0.0)
     num_entities_local = (
         mesh.topology.index_map(dim).size_local + mesh.topology.index_map(dim).num_ghosts
     )
@@ -365,7 +368,7 @@ def test_compute_closest_sub_entity(dim, dtype):
     mesh = create_unit_cube(MPI.COMM_WORLD, 8, 8, 8, dtype=dtype)
     mesh.topology.create_entities(dim)
     left_entities = locate_entities(mesh, dim, lambda x: x[0] <= xc)
-    tree = bb_tree(mesh, dim, left_entities)
+    tree = bb_tree(mesh, dim, 0.0, left_entities)
     midpoint_tree = create_midpoint_tree(mesh, dim, left_entities)
     closest_entities = compute_closest_entity(tree, midpoint_tree, mesh, points)
 
@@ -393,7 +396,7 @@ def test_surface_bbtree(dtype):
     tdim = mesh.topology.dim
     f_to_c = mesh.topology.connectivity(tdim - 1, tdim)
     cells = np.array([f_to_c.links(f)[0] for f in sf], dtype=np.int32)
-    bbtree = bb_tree(mesh, tdim, cells)
+    bbtree = bb_tree(mesh, tdim, 0.0, cells)
 
     # test collision (should not collide with any)
     p = np.array([0.5, 0.5, 0.5])
@@ -410,7 +413,7 @@ def test_sub_bbtree_codim1(dtype):
     top_facets = locate_entities_boundary(mesh, fdim, lambda x: np.isclose(x[2], 1))
     f_to_c = mesh.topology.connectivity(tdim - 1, tdim)
     cells = np.array([f_to_c.links(f)[0] for f in top_facets], dtype=np.int32)
-    bbtree = bb_tree(mesh, tdim, cells)
+    bbtree = bb_tree(mesh, tdim, 0.0, cells)
 
     # Compute a BBtree for all processes
     process_bbtree = bbtree.create_global_tree(mesh.comm)
@@ -438,7 +441,7 @@ def test_serial_global_bb_tree(dtype, comm):
     # entity tree with a serial mesh
     x = np.array([[2.0, 2.0, 3.0], [0.3, 0.2, 0.1]], dtype=dtype)
 
-    tree = bb_tree(mesh, mesh.topology.dim)
+    tree = bb_tree(mesh, mesh.topology.dim, 0.0)
     global_tree = tree.create_global_tree(mesh.comm)
 
     tree_col = compute_collisions_points(tree, x)
@@ -462,12 +465,12 @@ def test_sub_bbtree_box(ct, N, dtype):
     facets = locate_entities_boundary(mesh, fdim, lambda x: np.isclose(x[1], 1.0))
     f_to_c = mesh.topology.connectivity(fdim, tdim)
     cells = np.int32(np.unique([f_to_c.links(f)[0] for f in facets]))
-    bbtree = bb_tree(mesh, tdim, cells)
+    bbtree = bb_tree(mesh, tdim, 0.0, cells)
     num_boxes = bbtree.num_bboxes
     if num_boxes > 0:
         bbox = bbtree.get_bbox(num_boxes - 1)
         assert np.isclose(bbox[0][1], (N - 1) / N)
-    tree = bb_tree(mesh, tdim)
+    tree = bb_tree(mesh, tdim, 0.0)
     assert num_boxes < tree.num_bboxes
 
 
@@ -486,13 +489,227 @@ def test_surface_bbtree_collision(dtype):
 
     # Compute unique set of cells (some will be counted multiple times)
     cells = np.array(list(set([f_to_c.links(f)[0] for f in sf])), dtype=np.int32)
-    bbtree1 = bb_tree(mesh1, tdim, cells)
+    bbtree1 = bb_tree(mesh1, tdim, 0.0, cells)
 
     mesh2.topology.create_connectivity(mesh2.topology.dim - 1, mesh2.topology.dim)
     sf = exterior_facet_indices(mesh2.topology)
     f_to_c = mesh2.topology.connectivity(tdim - 1, tdim)
     cells = np.array(list(set([f_to_c.links(f)[0] for f in sf])), dtype=np.int32)
-    bbtree2 = bb_tree(mesh2, tdim, cells)
+    bbtree2 = bb_tree(mesh2, tdim, 0.0, cells)
 
     collisions = compute_collisions_trees(bbtree1, bbtree2)
     assert len(collisions) == 1
+
+
+@pytest.mark.parametrize("dim", [2, 3])
+@pytest.mark.parametrize("affine", [True, False])
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+def test_determine_point_ownership(dim, affine, dtype):
+    """Find point owners (ranks and cells) using bounding box trees + global communication
+    and compare to point ownership data results."""
+    comm = MPI.COMM_WORLD
+    rank = comm.Get_rank()
+    mpi_dtype = MPI.DOUBLE if dtype == np.float64 else MPI.FLOAT
+
+    tdim = dim
+    num_cells_side = 4
+    if tdim == 2:
+        ct = CellType.triangle if affine else CellType.quadrilateral
+        mesh = create_unit_square(MPI.COMM_WORLD, num_cells_side, num_cells_side, ct, dtype=dtype)
+    else:
+        ct = CellType.tetrahedron if affine else CellType.hexahedron
+        mesh = create_unit_cube(
+            MPI.COMM_WORLD,
+            num_cells_side,
+            num_cells_side,
+            num_cells_side,
+            ct,
+            dtype=dtype,
+        )
+    cell_map = mesh.topology.index_map(tdim)
+
+    tree = bb_tree(mesh, mesh.topology.dim, 0.0, np.arange(cell_map.size_local))
+    num_global_cells = num_cells_side**tdim
+    if affine:
+        num_global_cells *= 2 * (3 ** (tdim - 2))
+    local_midpoints = compute_midpoints(
+        mesh, tdim, np.arange(mesh.topology.index_map(tdim).size_local)
+    )
+    midpoints_per_rank = np.zeros(comm.size, dtype=np.int32)
+    midpoints_offsets = np.zeros(comm.size, dtype=np.int32)
+    comm.Allgather(np.array([local_midpoints.shape[0]], dtype=np.int32), midpoints_per_rank)
+    midpoints_offsets[1:] = np.cumsum(midpoints_per_rank[:-1])
+    all_midpoints = np.zeros((num_global_cells, 3), dtype=dtype)
+    comm.Allgatherv(
+        local_midpoints, [all_midpoints, midpoints_per_rank * 3, midpoints_offsets * 3, mpi_dtype]
+    )
+    # Find potential owner cells
+    tree_col = compute_collisions_points(tree, all_midpoints)
+
+    mesh.topology.create_connectivity(tdim - 1, 0)
+    mesh.topology.create_connectivity(0, tdim)
+    cfc = mesh.topology.connectivity(tdim, tdim - 1)
+    fpc = mesh.topology.connectivity(tdim - 1, 0)
+
+    # Narrow it down to a single owner cell
+    def is_inside(mesh, icell, point):
+        fdim = tdim - 1
+        is_inside = True
+        cpoints = mesh.geometry.x[mesh.geometry.dofmap[icell, :]]  # cell points
+        ccentroid = np.average(cpoints, axis=0)  # cell centroid
+        for ifacet in cfc.links(icell):
+            fpoints_indices = _cpp.mesh.entities_to_geometry(
+                mesh._cpp_object,
+                0,
+                fpc.links(ifacet),
+                False,
+            )
+            fpoints_indices = fpoints_indices.reshape(fpoints_indices.size)
+            fpoints = mesh.geometry.x[fpoints_indices]
+            fcentroid = np.average(fpoints, axis=0)  # facet centroid
+            # Compute facet normal pointing to outside of owner cell
+            normal = np.zeros(3, dtype=dtype)
+            facet_vector1 = fpoints[1, :] - fpoints[0, :]
+            if fdim == 1:
+                normal[0] = -facet_vector1[1]
+                normal[1] = +facet_vector1[0]
+            elif fdim == 2:
+                facet_vector2 = fpoints[2, :] - fpoints[0, :]
+                normal = np.cross(facet_vector1, facet_vector2)
+            else:
+                raise ValueError("Unexpected facet dimension.")
+            normal /= np.linalg.norm(normal)
+            # Re-align if pointing to inside the parent cell
+            normal = -normal if (np.dot((ccentroid - fcentroid), normal) > 0) else normal
+            # Test the point
+            signed_distance = np.dot((point - fcentroid), normal)
+            if signed_distance > 1e-9:
+                is_inside = False
+                break
+        return is_inside
+
+    processwise_owners = np.zeros(2 * num_global_cells, dtype=np.int32)
+    owners = np.empty_like(processwise_owners)
+    for ipoint in range(num_global_cells):
+        potential_owners = tree_col.links(ipoint)
+        owner_cells = []
+        for cell in potential_owners:
+            if is_inside(mesh, cell, all_midpoints[ipoint, :]):
+                owner_cells.append(cell)
+        if owner_cells:
+            assert len(owner_cells) == 1
+            processwise_owners[2 * ipoint] = rank
+            processwise_owners[2 * ipoint + 1] = owner_cells[0]
+
+    # Since ghost cells are left out and the points considered are midpoints
+    # of cells, they are only contained in a single process / cell
+    # The value at a given index is null if it doesn't correspond
+    # to the current process.
+    # We can sum the processwise arrays to obtain a global array
+    comm.Allreduce(processwise_owners, owners, op=MPI.SUM)
+    owner_ranks = owners[[2 * i for i in range(num_global_cells)]]
+    owner_cells = owners[[2 * i + 1 for i in range(num_global_cells)]]
+
+    # Reorganize ownership data (point, index, rank, cell) into dictionary
+    ownership_data = {}
+    for ipoint in range(num_global_cells):
+        ownership_data[tuple(all_midpoints[ipoint])] = (
+            ipoint,
+            owner_ranks[ipoint],
+            owner_cells[ipoint],
+        )
+
+    def check_po(po: PointOwnershipData, src_points, ownership_data, global_dest_owners):
+        """
+        Check point ownership data
+
+        po: PointOwnershipData object to check
+        src_points: Points sent by process
+        ownership_data: {point:(global_index,rank,cell}
+        global_dest_owners: Rank who sent each point
+        """
+        # Check src_owner: Check owner ranks of sent points
+        src_owner = po.src_owner()
+        for ipoint in range(src_points.shape[0]):
+            assert ownership_data[tuple(src_points[ipoint])][1] == src_owner[ipoint]
+
+        dest_points = po.dest_points()
+        dest_owners = po.dest_owner()
+        dest_cells = po.dest_cells()
+
+        # Check dest_points: All points that should have been found have been found
+        dest_points_indices = list(range(dest_points.shape[0]))
+        for point, data in ownership_data.items():
+            (iglobal, processor, _) = data
+            if processor == rank:
+                found = False
+                point = np.array(point, dtype=dtype)
+                for jpoint in dest_points_indices:
+                    found = np.allclose(point, dest_points[jpoint])
+                    if found:
+                        break
+                assert found
+                dest_points_indices.remove(jpoint)
+
+        # Check dest_owners and dest_cells
+        # dest_owners: Ranks that asked about the points we own
+        # dest_cells: Local index of cell that contains the points we own
+        for ipoint in range(dest_points.shape[0]):
+            iglobal = ownership_data[tuple(dest_points[ipoint])][0]
+            c = ownership_data[tuple(dest_points[ipoint])][2]
+            assert dest_owners[ipoint] == global_dest_owners[iglobal]
+            assert dest_cells[ipoint] == c
+
+    def set_local_range(array):
+        N = array.shape[0]
+        n = N // comm.size
+        r = N % comm.size
+        # First r processes has one extra value
+        if rank < r:
+            (start, stop) = [rank * (n + 1), (rank + 1) * (n + 1)]
+        else:
+            (start, stop) = [rank * n + r, (rank + 1) * n + r]
+        return array[start:stop], start, stop
+
+    def compute_global_owners(N, start, stop):
+        """Compute array of ranks who own each point"""
+        mask_points_owned = np.zeros(N, np.int32)
+        global_owners = np.empty_like(mask_points_owned)
+        mask_points_owned[start:stop] = rank
+        comm.Allreduce(mask_points_owned, global_owners, op=MPI.SUM)
+        return global_owners
+
+    # All cells
+    points, start, stop = set_local_range(all_midpoints)
+    owners = compute_global_owners(np.int64(all_midpoints.shape[0]), start, stop)
+    all_cells = np.arange(cell_map.size_local, dtype=dtype)
+    po = determine_point_ownership(mesh, points, 0.0, all_cells)
+
+    check_po(po, points, ownership_data, owners)
+
+    # Left half
+    num_left_cells = np.rint(num_global_cells / 2).astype(np.int32)
+    left_midpoints = np.zeros((num_left_cells, 3), dtype=dtype)
+    counter = 0
+    indices_left = []
+    for ipoint in range(num_global_cells):
+        if all_midpoints[ipoint, 0] <= 0.5:
+            left_midpoints[counter] = all_midpoints[ipoint]
+            indices_left.append(ipoint)
+            counter += 1
+    points, start, stop = set_local_range(left_midpoints)
+    owners = compute_global_owners(np.int64(all_midpoints.shape[0]), start, stop)
+    left_cells = locate_entities(mesh, tdim, lambda x: x[0] <= 0.5)
+    left_cells = np.array(
+        [cell for cell in left_cells if cell < cell_map.size_local], dtype=np.int32
+    )  # Filter out ghost cells
+    lpo = determine_point_ownership(mesh, points, 0.0, left_cells)
+
+    left_ownership_data = {}
+    for idx, ipoint in enumerate(indices_left):
+        left_ownership_data[tuple(all_midpoints[ipoint])] = (
+            idx,
+            owner_ranks[ipoint],
+            owner_cells[ipoint],
+        )
+    check_po(lpo, points, left_ownership_data, owners)
diff --git a/python/test/unit/geometry/test_gjk.py b/python/test/unit/geometry/test_gjk.py
index 8b895a4633f..fb2dc706343 100644
--- a/python/test/unit/geometry/test_gjk.py
+++ b/python/test/unit/geometry/test_gjk.py
@@ -193,7 +193,7 @@ def test_collision_2nd_order_triangle(dtype):
     sample_points = np.array([[0.1, 0.3, 0.0], [0.2, 0.5, 0.0], [0.6, 0.6, 0.0]])
 
     # Create boundingboxtree
-    tree = geometry.bb_tree(mesh, mesh.geometry.dim)
+    tree = geometry.bb_tree(mesh, mesh.geometry.dim, 0.0)
     cell_candidates = geometry.compute_collisions_points(tree, sample_points)
     colliding_cells = geometry.compute_colliding_cells(mesh, cell_candidates, sample_points)
     # Check for collision
diff --git a/python/test/unit/mesh/test_manifold_point_search.py b/python/test/unit/mesh/test_manifold_point_search.py
index ec3428ca3f9..80704d2e0bc 100644
--- a/python/test/unit/mesh/test_manifold_point_search.py
+++ b/python/test/unit/mesh/test_manifold_point_search.py
@@ -18,7 +18,7 @@ def test_manifold_point_search():
     cells = np.array([[0, 1, 2], [0, 1, 3]], dtype=np.int64)
     domain = ufl.Mesh(element("Lagrange", "triangle", 1, shape=(2,)))
     mesh = create_mesh(MPI.COMM_WORLD, cells, vertices, domain)
-    bb = bb_tree(mesh, mesh.topology.dim)
+    bb = bb_tree(mesh, mesh.topology.dim, 0.0)
 
     # Find cell colliding with point
     points = np.array([[0.5, 0.25, 0.75], [0.25, 0.5, 0.75]], dtype=default_real_type)