diff --git a/charts/nucliadb_search/templates/search.vs.yaml b/charts/nucliadb_search/templates/search.vs.yaml index 7c9cff7b4e..106cfd71ea 100644 --- a/charts/nucliadb_search/templates/search.vs.yaml +++ b/charts/nucliadb_search/templates/search.vs.yaml @@ -54,6 +54,10 @@ spec: regex: '^/api/v\d+/kb/[^/]+/suggest' method: regex: "GET|OPTIONS" + - uri: + regex: '^/api/v\d+/kb/[^/]+/graph' + method: + regex: "POST|OPTIONS" retries: attempts: 3 retryOn: connect-failure diff --git a/nucliadb/src/nucliadb/common/nidx.py b/nucliadb/src/nucliadb/common/nidx.py index 96a0e372f5..a5545061da 100644 --- a/nucliadb/src/nucliadb/common/nidx.py +++ b/nucliadb/src/nucliadb/common/nidx.py @@ -262,6 +262,7 @@ def __init__(self, api_client, searcher_client): # Searcher methods self.Search = searcher_client.Search self.Suggest = searcher_client.Suggest + self.GraphSearch = searcher_client.GraphSearch self.Paragraphs = searcher_client.Paragraphs self.Documents = searcher_client.Documents diff --git a/nucliadb/src/nucliadb/search/api/v1/__init__.py b/nucliadb/src/nucliadb/search/api/v1/__init__.py index 7add9d853d..0394eac172 100644 --- a/nucliadb/src/nucliadb/search/api/v1/__init__.py +++ b/nucliadb/src/nucliadb/search/api/v1/__init__.py @@ -17,16 +17,19 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . # -from . import ask # noqa -from . import catalog # noqa -from . import feedback # noqa -from . import find # noqa -from . import knowledgebox # noqa -from . import predict_proxy # noqa -from . import search # noqa -from . import suggest # noqa -from . import summarize # noqa -from .resource import ask as ask_resource # noqa -from .resource import search as search_resource # noqa -from .resource import ingestion_agents as ingestion_agents_resource # noqa -from .router import api # noqa +from . 
import ( # noqa: F401 + ask, + catalog, + feedback, + find, + graph, + knowledgebox, + predict_proxy, + search, + suggest, + summarize, +) +from .resource import ask as ask_resource # noqa: F401 +from .resource import ingestion_agents as ingestion_agents_resource # noqa: F401 +from .resource import search as search_resource # noqa: F401 +from .router import api # noqa: F401 diff --git a/nucliadb/src/nucliadb/search/api/v1/catalog.py b/nucliadb/src/nucliadb/search/api/v1/catalog.py index cf0f16388b..4e64610933 100644 --- a/nucliadb/src/nucliadb/search/api/v1/catalog.py +++ b/nucliadb/src/nucliadb/search/api/v1/catalog.py @@ -36,7 +36,7 @@ from nucliadb.search.search.exceptions import InvalidQueryError from nucliadb.search.search.merge import fetch_resources from nucliadb.search.search.pgcatalog import pgcatalog_search -from nucliadb.search.search.query_parser.catalog import parse_catalog +from nucliadb.search.search.query_parser.parsers import parse_catalog from nucliadb.search.search.utils import ( maybe_log_request_payload, ) diff --git a/nucliadb/src/nucliadb/search/api/v1/graph.py b/nucliadb/src/nucliadb/search/api/v1/graph.py new file mode 100644 index 0000000000..dbd874f98b --- /dev/null +++ b/nucliadb/src/nucliadb/search/api/v1/graph.py @@ -0,0 +1,130 @@ +# Copyright (C) 2021 Bosutech XXI S.L. +# +# nucliadb is offered under the AGPL v3.0 and as commercial software. +# For commercial licensing, contact us at info@nuclia.com. +# +# AGPL: +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. 
+# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# +from fastapi import Header, Request, Response +from fastapi_versioning import version + +from nucliadb.search.api.v1.router import KB_PREFIX, api +from nucliadb.search.requesters.utils import Method, node_query +from nucliadb.search.search.graph_merge import ( + build_graph_nodes_response, + build_graph_relations_response, + build_graph_response, +) +from nucliadb.search.search.query_parser.parsers import ( + parse_graph_node_search, + parse_graph_relation_search, + parse_graph_search, +) +from nucliadb_models.graph.requests import ( + GraphNodesSearchRequest, + GraphRelationsSearchRequest, + GraphSearchRequest, +) +from nucliadb_models.graph.responses import ( + GraphNodesSearchResponse, + GraphRelationsSearchResponse, + GraphSearchResponse, +) +from nucliadb_models.resource import NucliaDBRoles +from nucliadb_models.search import ( + NucliaDBClientType, +) +from nucliadb_utils.authentication import requires + + +@api.post( + f"/{KB_PREFIX}/{{kbid}}/graph", + status_code=200, + summary="Search Knowledge Box graph", + description="Search on the Knowledge Box graph and retrieve triplets of vertex-edge-vertex", + response_model_exclude_unset=True, + include_in_schema=False, + tags=["Search"], +) +@requires(NucliaDBRoles.READER) +@version(1) +async def graph_search_knowledgebox( + request: Request, + response: Response, + kbid: str, + item: GraphSearchRequest, + x_ndb_client: NucliaDBClientType = Header(NucliaDBClientType.API), + x_nucliadb_user: str = Header(""), + x_forwarded_for: str = Header(""), +) -> GraphSearchResponse: + pb_query = parse_graph_search(item) + + results, _, _ = await node_query(kbid, Method.GRAPH, pb_query) + + return build_graph_response(results) + + +@api.post( + f"/{KB_PREFIX}/{{kbid}}/graph/nodes", + status_code=200, + summary="Search Knowledge Box graph nodes", + description="Search on the Knowledge Box graph and 
retrieve nodes (vertices)", + response_model_exclude_unset=True, + include_in_schema=False, + tags=["Search"], +) +@requires(NucliaDBRoles.READER) +@version(1) +async def graph_nodes_search_knowledgebox( + request: Request, + response: Response, + kbid: str, + item: GraphNodesSearchRequest, + x_ndb_client: NucliaDBClientType = Header(NucliaDBClientType.API), + x_nucliadb_user: str = Header(""), + x_forwarded_for: str = Header(""), +) -> GraphNodesSearchResponse: + pb_query = parse_graph_node_search(item) + + results, _, _ = await node_query(kbid, Method.GRAPH, pb_query) + + return build_graph_nodes_response(results) + + +@api.post( + f"/{KB_PREFIX}/{{kbid}}/graph/relations", + status_code=200, + summary="Search Knowledge Box graph relations", + description="Search on the Knowledge Box graph and retrieve relations (edges)", + response_model_exclude_unset=True, + include_in_schema=False, + tags=["Search"], +) +@requires(NucliaDBRoles.READER) +@version(1) +async def graph_relations_search_knowledgebox( + request: Request, + response: Response, + kbid: str, + item: GraphRelationsSearchRequest, + x_ndb_client: NucliaDBClientType = Header(NucliaDBClientType.API), + x_nucliadb_user: str = Header(""), + x_forwarded_for: str = Header(""), +) -> GraphRelationsSearchResponse: + pb_query = parse_graph_relation_search(item) + + results, _, _ = await node_query(kbid, Method.GRAPH, pb_query) + + return build_graph_relations_response(results) diff --git a/nucliadb/src/nucliadb/search/requesters/utils.py b/nucliadb/src/nucliadb/search/requesters/utils.py index 42836cb404..9f5e59fd85 100644 --- a/nucliadb/src/nucliadb/search/requesters/utils.py +++ b/nucliadb/src/nucliadb/search/requesters/utils.py @@ -33,11 +33,14 @@ from nucliadb.common.cluster.utils import get_shard_manager from nucliadb.search import logger from nucliadb.search.search.shards import ( + graph_search_shard, query_shard, suggest_shard, ) from nucliadb.search.settings import settings from 
nucliadb_protos.nodereader_pb2 import ( + GraphSearchRequest, + GraphSearchResponse, SearchRequest, SearchResponse, SuggestRequest, @@ -50,19 +53,22 @@ class Method(Enum): SEARCH = auto() SUGGEST = auto() + GRAPH = auto() METHODS = { Method.SEARCH: query_shard, Method.SUGGEST: suggest_shard, + Method.GRAPH: graph_search_shard, } -REQUEST_TYPE = Union[SuggestRequest, SearchRequest] +REQUEST_TYPE = Union[SuggestRequest, SearchRequest, GraphSearchRequest] T = TypeVar( "T", SuggestResponse, SearchResponse, + GraphSearchResponse, ) @@ -84,6 +90,15 @@ async def node_query( ) -> tuple[list[SearchResponse], bool, list[tuple[AbstractIndexNode, str]]]: ... +@overload +async def node_query( + kbid: str, + method: Method, + pb_query: GraphSearchRequest, + timeout: Optional[float] = None, +) -> tuple[list[GraphSearchResponse], bool, list[tuple[AbstractIndexNode, str]]]: ... + + async def node_query( kbid: str, method: Method, diff --git a/nucliadb/src/nucliadb/search/search/find.py b/nucliadb/src/nucliadb/search/search/find.py index b0d71caa30..dbd2c73f72 100644 --- a/nucliadb/src/nucliadb/search/search/find.py +++ b/nucliadb/src/nucliadb/search/search/find.py @@ -40,7 +40,7 @@ ) from nucliadb.search.search.query import QueryParser from nucliadb.search.search.query_parser.old_filters import OldFilterParams -from nucliadb.search.search.query_parser.parser import parse_find +from nucliadb.search.search.query_parser.parsers import parse_find from nucliadb.search.search.rank_fusion import ( RankFusionAlgorithm, get_rank_fusion, diff --git a/nucliadb/src/nucliadb/search/search/graph_merge.py b/nucliadb/src/nucliadb/search/search/graph_merge.py new file mode 100644 index 0000000000..3d3567a5ea --- /dev/null +++ b/nucliadb/src/nucliadb/search/search/graph_merge.py @@ -0,0 +1,90 @@ +# Copyright (C) 2021 Bosutech XXI S.L. +# +# nucliadb is offered under the AGPL v3.0 and as commercial software. +# For commercial licensing, contact us at info@nuclia.com. 
+# +# AGPL: +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# + + +from nucliadb.common.models_utils.from_proto import RelationNodeTypePbMap +from nucliadb_models.graph import responses as graph_responses +from nucliadb_models.graph.responses import ( + GraphNodesSearchResponse, + GraphRelationsSearchResponse, + GraphSearchResponse, +) +from nucliadb_protos import nodereader_pb2 + + +def build_graph_response(results: list[nodereader_pb2.GraphSearchResponse]) -> GraphSearchResponse: + paths = [] + for shard_results in results: + for pb_path in shard_results.graph: + source = shard_results.nodes[pb_path.source] + relation = shard_results.relations[pb_path.relation] + destination = shard_results.nodes[pb_path.destination] + + path = graph_responses.GraphPath( + source=graph_responses.GraphNode( + value=source.value, + type=RelationNodeTypePbMap[source.ntype], + group=source.subtype, + ), + relation=graph_responses.GraphRelation( + label=relation.label, + ), + destination=graph_responses.GraphNode( + value=destination.value, + type=RelationNodeTypePbMap[destination.ntype], + group=destination.subtype, + ), + ) + paths.append(path) + + response = GraphSearchResponse(paths=paths) + return response + + +def build_graph_nodes_response( + results: list[nodereader_pb2.GraphSearchResponse], +) -> GraphNodesSearchResponse: + nodes = [] + for shard_results in results: + for node in 
shard_results.nodes: + nodes.append( + graph_responses.GraphNode( + value=node.value, + type=RelationNodeTypePbMap[node.ntype], + group=node.subtype, + ) + ) + response = GraphNodesSearchResponse(nodes=nodes) + return response + + +def build_graph_relations_response( + results: list[nodereader_pb2.GraphSearchResponse], +) -> GraphRelationsSearchResponse: + relations = [] + for shard_results in results: + for relation in shard_results.relations: + relations.append( + graph_responses.GraphRelation( + label=relation.label, + ) + ) + response = GraphRelationsSearchResponse(relations=relations) + return response diff --git a/nucliadb/src/nucliadb/search/search/query_parser/models.py b/nucliadb/src/nucliadb/search/search/query_parser/models.py index 094f35677b..6b52df4447 100644 --- a/nucliadb/src/nucliadb/search/search/query_parser/models.py +++ b/nucliadb/src/nucliadb/search/search/query_parser/models.py @@ -28,6 +28,7 @@ ) from nucliadb_models import search as search_models +from nucliadb_protos import nodereader_pb2 ### Retrieval @@ -101,3 +102,11 @@ class CatalogQuery(BaseModel): faceted: list[str] page_size: int page_number: int + + +### Graph + + +# Right now, we don't need a more generic model for graph queries, we can +# directly use the protobuffer directly +GraphRetrieval = nodereader_pb2.GraphSearchRequest diff --git a/nucliadb/src/nucliadb/search/search/query_parser/parsers/__init__.py b/nucliadb/src/nucliadb/search/search/query_parser/parsers/__init__.py new file mode 100644 index 0000000000..2f81467719 --- /dev/null +++ b/nucliadb/src/nucliadb/search/search/query_parser/parsers/__init__.py @@ -0,0 +1,23 @@ +# Copyright (C) 2021 Bosutech XXI S.L. +# +# nucliadb is offered under the AGPL v3.0 and as commercial software. +# For commercial licensing, contact us at info@nuclia.com. 
+# +# AGPL: +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# + +from .catalog import parse_catalog # noqa: F401 +from .find import parse_find # noqa: F401 +from .graph import parse_graph_node_search, parse_graph_relation_search, parse_graph_search # noqa: F401 diff --git a/nucliadb/src/nucliadb/search/search/query_parser/catalog.py b/nucliadb/src/nucliadb/search/search/query_parser/parsers/catalog.py similarity index 98% rename from nucliadb/src/nucliadb/search/search/query_parser/catalog.py rename to nucliadb/src/nucliadb/search/search/query_parser/parsers/catalog.py index 4f8a6053fd..2943b2b5ee 100644 --- a/nucliadb/src/nucliadb/search/search/query_parser/catalog.py +++ b/nucliadb/src/nucliadb/search/search/query_parser/parsers/catalog.py @@ -18,10 +18,10 @@ # along with this program. If not, see . 
# - from nucliadb.common import datamanagers from nucliadb.search.search.exceptions import InvalidQueryError from nucliadb.search.search.filters import translate_label +from nucliadb.search.search.query_parser.filter_expression import FacetFilterTypes, facet_from_filter from nucliadb.search.search.query_parser.models import ( CatalogExpression, CatalogQuery, @@ -44,8 +44,6 @@ SortOrder, ) -from .filter_expression import FacetFilterTypes, facet_from_filter - async def parse_catalog(kbid: str, item: search_models.CatalogRequest) -> CatalogQuery: has_old_filters = ( diff --git a/nucliadb/src/nucliadb/search/search/query_parser/parser.py b/nucliadb/src/nucliadb/search/search/query_parser/parsers/find.py similarity index 99% rename from nucliadb/src/nucliadb/search/search/query_parser/parser.py rename to nucliadb/src/nucliadb/search/search/query_parser/parsers/find.py index c104880477..e24f50a2a0 100644 --- a/nucliadb/src/nucliadb/search/search/query_parser/parser.py +++ b/nucliadb/src/nucliadb/search/search/query_parser/parsers/find.py @@ -18,7 +18,6 @@ # along with this program. If not, see . # - from pydantic import ValidationError from nucliadb.search.search.query_parser.exceptions import InternalParserError diff --git a/nucliadb/src/nucliadb/search/search/query_parser/parsers/graph.py b/nucliadb/src/nucliadb/search/search/query_parser/parsers/graph.py new file mode 100644 index 0000000000..c380401d69 --- /dev/null +++ b/nucliadb/src/nucliadb/search/search/query_parser/parsers/graph.py @@ -0,0 +1,177 @@ +# Copyright (C) 2021 Bosutech XXI S.L. +# +# nucliadb is offered under the AGPL v3.0 and as commercial software. +# For commercial licensing, contact us at info@nuclia.com. +# +# AGPL: +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. 
+# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# + + +from nucliadb.common.models_utils.from_proto import RelationNodeTypeMap +from nucliadb.search.search.query_parser.models import GraphRetrieval +from nucliadb_models.graph import requests as graph_requests +from nucliadb_protos import nodereader_pb2 + + +def parse_graph_search(item: graph_requests.GraphSearchRequest) -> GraphRetrieval: + pb = nodereader_pb2.GraphSearchRequest() + pb.query.path.CopyFrom(_parse_path_query(item.query)) + pb.top_k = item.top_k + return pb + + +def parse_graph_node_search(item: graph_requests.GraphNodesSearchRequest) -> GraphRetrieval: + pb = nodereader_pb2.GraphSearchRequest() + pb.query.path.CopyFrom(_parse_node_query(item.query)) + pb.top_k = item.top_k + return pb + + +def parse_graph_relation_search(item: graph_requests.GraphRelationsSearchRequest) -> GraphRetrieval: + pb = nodereader_pb2.GraphSearchRequest() + pb.query.path.CopyFrom(_parse_relation_query(item.query)) + pb.top_k = item.top_k + return pb + + +def _parse_path_query(expr: graph_requests.GraphPathQuery) -> nodereader_pb2.GraphQuery.PathQuery: + pb = nodereader_pb2.GraphQuery.PathQuery() + + if isinstance(expr, graph_requests.And): + for op in expr.operands: + pb.bool_and.operands.append(_parse_path_query(op)) + + elif isinstance(expr, graph_requests.Or): + for op in expr.operands: + pb.bool_or.operands.append(_parse_path_query(op)) + + elif isinstance(expr, graph_requests.Not): + pb.bool_not.CopyFrom(_parse_path_query(expr.operand)) + + elif isinstance(expr, graph_requests.GraphPath): + if expr.source is not None: + _set_node_to_pb(expr.source, pb.path.source) + + if expr.destination is not 
None: + _set_node_to_pb(expr.destination, pb.path.destination) + + if expr.relation is not None: + relation = expr.relation + if relation.label is not None: + pb.path.relation.value = relation.label + + pb.path.undirected = expr.undirected + + elif isinstance(expr, graph_requests.SourceNode): + _set_node_to_pb(expr, pb.path.source) + + elif isinstance(expr, graph_requests.DestinationNode): + _set_node_to_pb(expr, pb.path.destination) + + elif isinstance(expr, graph_requests.AnyNode): + _set_node_to_pb(expr, pb.path.source) + pb.path.undirected = True + + elif isinstance(expr, graph_requests.Relation): + if expr.label is not None: + pb.path.relation.value = expr.label + + else: # pragma: nocover + # This is a trick so mypy generates an error if this branch can be reached, + # that is, if we are missing some ifs + _a: int = "a" + + return pb + + +def _parse_node_query(expr: graph_requests.GraphNodesQuery) -> nodereader_pb2.GraphQuery.PathQuery: + pb = nodereader_pb2.GraphQuery.PathQuery() + + if isinstance(expr, graph_requests.And): + for op in expr.operands: + pb.bool_and.operands.append(_parse_node_query(op)) + + elif isinstance(expr, graph_requests.Or): + for op in expr.operands: + pb.bool_or.operands.append(_parse_node_query(op)) + + elif isinstance(expr, graph_requests.Not): + pb.bool_not.CopyFrom(_parse_node_query(expr.operand)) + + elif isinstance(expr, graph_requests.SourceNode): + _set_node_to_pb(expr, pb.path.source) + + elif isinstance(expr, graph_requests.DestinationNode): + _set_node_to_pb(expr, pb.path.destination) + + elif isinstance(expr, graph_requests.AnyNode): + _set_node_to_pb(expr, pb.path.source) + pb.path.undirected = True + + else: # pragma: nocover + # This is a trick so mypy generates an error if this branch can be reached, + # that is, if we are missing some ifs + _a: int = "a" + + return pb + + +def _parse_relation_query( + expr: graph_requests.GraphRelationsQuery, +) -> nodereader_pb2.GraphQuery.PathQuery: + pb = 
nodereader_pb2.GraphQuery.PathQuery() + + if isinstance(expr, graph_requests.And): + for op in expr.operands: + pb.bool_and.operands.append(_parse_relation_query(op)) + + elif isinstance(expr, graph_requests.Or): + for op in expr.operands: + pb.bool_or.operands.append(_parse_relation_query(op)) + + elif isinstance(expr, graph_requests.Not): + pb.bool_not.CopyFrom(_parse_relation_query(expr.operand)) + + elif isinstance(expr, graph_requests.Relation): + if expr.label is not None: + pb.path.relation.value = expr.label + + else: # pragma: nocover + # This is a trick so mypy generates an error if this branch can be reached, + # that is, if we are missing some ifs + _a: int = "a" + + return pb + + +def _set_node_to_pb(node: graph_requests.GraphNode, pb: nodereader_pb2.GraphQuery.Node): + if node.value is not None: + pb.value = node.value + if node.match == graph_requests.NodeMatchKind.EXACT: + pb.match_kind = nodereader_pb2.GraphQuery.Node.MatchKind.EXACT + + elif node.match == graph_requests.NodeMatchKind.FUZZY: + pb.match_kind = nodereader_pb2.GraphQuery.Node.MatchKind.FUZZY + + else: # pragma: nocover + # This is a trick so mypy generates an error if this branch can be reached, + # that is, if we are missing some ifs + _a: int = "a" + + if node.type is not None: + pb.node_type = RelationNodeTypeMap[node.type] + + if node.group is not None: + pb.node_subtype = node.group diff --git a/nucliadb/src/nucliadb/search/search/shards.py b/nucliadb/src/nucliadb/search/search/shards.py index ae4b28e899..b713c6ebb4 100644 --- a/nucliadb/src/nucliadb/search/search/shards.py +++ b/nucliadb/src/nucliadb/search/search/shards.py @@ -26,6 +26,8 @@ from nucliadb.common.cluster.base import AbstractIndexNode from nucliadb_protos.nodereader_pb2 import ( GetShardRequest, + GraphSearchRequest, + GraphSearchResponse, SearchRequest, SearchResponse, SuggestRequest, @@ -79,3 +81,16 @@ async def suggest_shard(node: AbstractIndexNode, shard: str, query: SuggestReque req.shard = shard with 
node_observer({"type": "suggest", "node_id": node.id}): return await node.reader.Suggest(req) # type: ignore + + +@backoff.on_exception( + backoff.expo, Exception, jitter=None, factor=0.1, max_tries=3, giveup=should_giveup +) +async def graph_search_shard( + node: AbstractIndexNode, shard: str, query: GraphSearchRequest +) -> GraphSearchResponse: + req = GraphSearchRequest() + req.CopyFrom(query) + req.shard = shard + with node_observer({"type": "graph_search", "node_id": node.id}): + return await node.reader.GraphSearch(req) # type: ignore diff --git a/nucliadb/tests/ndbfixtures/resources.py b/nucliadb/tests/ndbfixtures/resources.py index 60b7e039f3..110e45c7e7 100644 --- a/nucliadb/tests/ndbfixtures/resources.py +++ b/nucliadb/tests/ndbfixtures/resources.py @@ -34,6 +34,7 @@ from nucliadb.ingest.orm.resource import Resource from nucliadb.tests.vectors import V1 from nucliadb.writer.api.v1.router import KB_PREFIX, KBS_PREFIX +from nucliadb_models.metadata import RelationEntity, RelationNodeType from nucliadb_protos import utils_pb2 as upb from nucliadb_protos.knowledgebox_pb2 import SemanticModelMetadata from nucliadb_protos.utils_pb2 import Relation, RelationNode @@ -385,3 +386,81 @@ async def knowledge_graph( assert resp.status_code == 200, resp.content return (nodes, edges, rid) + + +@pytest.fixture(scope="function") +async def kb_with_entity_graph( + nucliadb_reader: AsyncClient, + nucliadb_writer: AsyncClient, + standalone_knowledgebox: str, + entity_graph: tuple[dict[str, RelationEntity], list[tuple[str, str, str]]], +) -> AsyncIterator[str]: + kbid = standalone_knowledgebox + entities, paths = entity_graph + + resp = await nucliadb_writer.post( + f"/kb/{kbid}/resources", + json={ + "usermetadata": { + "relations": [ + { + "relation": "ENTITY", + "label": relation, + "from": entities[source].model_dump(), + "to": entities[target].model_dump(), + } + for source, relation, target in paths + ], + }, + }, + ) + assert resp.status_code == 201 + + yield kbid + + 
+@pytest.fixture(scope="function") +def entity_graph() -> tuple[dict[str, RelationEntity], list[tuple[str, str, str]]]: + entities = { + "Anastasia": RelationEntity(value="Anastasia", type=RelationNodeType.ENTITY, group="PERSON"), + "Anna": RelationEntity(value="Anna", type=RelationNodeType.ENTITY, group="PERSON"), + "Apollo": RelationEntity(value="Apollo", type=RelationNodeType.ENTITY, group="PROJECT"), + "Cat": RelationEntity(value="Cat", type=RelationNodeType.ENTITY, group="ANIMAL"), + "Climbing": RelationEntity(value="Climbing", type=RelationNodeType.ENTITY, group="ACTIVITY"), + "Computer science": RelationEntity( + value="Computer science", type=RelationNodeType.ENTITY, group="STUDY_FIELD" + ), + "Dimitri": RelationEntity(value="Dimitri", type=RelationNodeType.ENTITY, group="PERSON"), + "Erin": RelationEntity(value="Erin", type=RelationNodeType.ENTITY, group="PERSON"), + "Jerry": RelationEntity(value="Jerry", type=RelationNodeType.ENTITY, group="ANIMAL"), + "Margaret": RelationEntity(value="Margaret", type=RelationNodeType.ENTITY, group="PERSON"), + "Mouse": RelationEntity(value="Mouse", type=RelationNodeType.ENTITY, group="ANIMAL"), + "New York": RelationEntity(value="New York", type=RelationNodeType.ENTITY, group="PLACE"), + "Olympic athlete": RelationEntity( + value="Olympic athlete", type=RelationNodeType.ENTITY, group="SPORT" + ), + "Peter": RelationEntity(value="Peter", type=RelationNodeType.ENTITY, group="PERSON"), + "Rocket": RelationEntity(value="Rocket", type=RelationNodeType.ENTITY, group="VEHICLE"), + "Tom": RelationEntity(value="Tom", type=RelationNodeType.ENTITY, group="ANIMAL"), + "UK": RelationEntity(value="UK", type=RelationNodeType.ENTITY, group="PLACE"), + } + graph = [ + ("Anastasia", "IS_FRIEND", "Anna"), + ("Anna", "FOLLOW", "Erin"), + ("Anna", "LIVE_IN", "New York"), + ("Anna", "LOVE", "Cat"), + ("Anna", "WORK_IN", "New York"), + ("Apollo", "IS", "Rocket"), + ("Dimitri", "LOVE", "Anastasia"), + ("Erin", "BORN_IN", "UK"), + ("Erin", "IS", 
"Olympic athlete"), + ("Erin", "LOVE", "Climbing"), + ("Jerry", "IS", "Mouse"), + ("Margaret", "DEVELOPED", "Apollo"), + ("Margaret", "WORK_IN", "Computer science"), + ("Peter", "LIVE_IN", "New York"), + ("Tom", "CHASE", "Jerry"), + ("Tom", "IS", "Cat"), + ] + + return (entities, graph) diff --git a/nucliadb/tests/nucliadb/integration/search/graph/__init__.py b/nucliadb/tests/nucliadb/integration/search/graph/__init__.py new file mode 100644 index 0000000000..3b734776ac --- /dev/null +++ b/nucliadb/tests/nucliadb/integration/search/graph/__init__.py @@ -0,0 +1,19 @@ +# Copyright (C) 2021 Bosutech XXI S.L. +# +# nucliadb is offered under the AGPL v3.0 and as commercial software. +# For commercial licensing, contact us at info@nuclia.com. +# +# AGPL: +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# diff --git a/nucliadb/tests/nucliadb/integration/search/graph/test_graph_crud.py b/nucliadb/tests/nucliadb/integration/search/graph/test_graph_crud.py new file mode 100644 index 0000000000..fe225c5350 --- /dev/null +++ b/nucliadb/tests/nucliadb/integration/search/graph/test_graph_crud.py @@ -0,0 +1,129 @@ +# Copyright (C) 2021 Bosutech XXI S.L. +# +# nucliadb is offered under the AGPL v3.0 and as commercial software. +# For commercial licensing, contact us at info@nuclia.com. 
+# +# AGPL: +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# + +import pytest +from httpx import AsyncClient + +from nucliadb_models.metadata import RelationEntity + + +@pytest.mark.deploy_modes("standalone") +async def test_user_defined_knowledge_graph( + nucliadb_reader: AsyncClient, + nucliadb_writer: AsyncClient, + standalone_knowledgebox: str, + entity_graph: tuple[dict[str, RelationEntity], list[tuple[str, str, str]]], +): + kbid = standalone_knowledgebox + entities, paths = entity_graph + + graph = paths[:3] + resp = await nucliadb_writer.post( + f"/kb/{kbid}/resources", + json={ + "title": "Knowledge graph", + "slug": "knowledge-graph", + "summary": "User defined knowledge graph", + "usermetadata": { + "relations": [ + { + "relation": "ENTITY", + "label": relation, + "from": entities[source].model_dump(), + "to": entities[target].model_dump(), + } + for source, relation, target in graph + ], + }, + }, + ) + assert resp.status_code == 201 + rid = resp.json()["uuid"] + + resp = await nucliadb_reader.get( + f"/kb/{kbid}/resource/{rid}", + params={ + "show": ["basic", "relations"], + }, + ) + assert resp.status_code == 200 + user_graph = resp.json()["usermetadata"]["relations"] + assert len(user_graph) == len(graph) + + # Update graph + + graph = paths + resp = await nucliadb_writer.patch( + f"/kb/{kbid}/resource/{rid}", + json={ + "usermetadata": { + "relations": [ + { + "relation": 
"ENTITY", + "label": relation, + "from": entities[source].model_dump(), + "to": entities[target].model_dump(), + } + for source, relation, target in graph + ], + }, + }, + ) + assert resp.status_code == 200 + + resp = await nucliadb_reader.get( + f"/kb/{kbid}/resource/{rid}", + params={ + "show": ["basic", "relations"], + }, + ) + assert resp.status_code == 200 + user_graph = resp.json()["usermetadata"]["relations"] + assert len(user_graph) == len(graph) + + # Search graph + + resp = await nucliadb_reader.post( + f"/kb/{kbid}/find", + json={ + "query_entities": [ + { + "name": entities["Anna"].value, + "type": entities["Anna"].type.value, + "subtype": entities["Anna"].group, + } + ], + "features": ["relations"], + }, + ) + assert resp.status_code == 200 + body = resp.json() + retrieved_graph = set() + for source, subgraph in body["relations"]["entities"].items(): + for target in subgraph["related_to"]: + if target["direction"] == "in": + path = (source, "-", target["relation_label"], "->", target["entity"]) + else: + path = (source, "<-", target["relation_label"], "-", target["entity"]) + + assert path not in retrieved_graph, "We don't expect duplicated paths" + retrieved_graph.add(path) + + assert len(retrieved_graph) == 5 diff --git a/nucliadb/tests/nucliadb/integration/search/graph/test_graph_nodes_search.py b/nucliadb/tests/nucliadb/integration/search/graph/test_graph_nodes_search.py new file mode 100644 index 0000000000..b237785d02 --- /dev/null +++ b/nucliadb/tests/nucliadb/integration/search/graph/test_graph_nodes_search.py @@ -0,0 +1,104 @@ +# Copyright (C) 2021 Bosutech XXI S.L. +# +# nucliadb is offered under the AGPL v3.0 and as commercial software. +# For commercial licensing, contact us at info@nuclia.com. 
+# +# AGPL: +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# +import pytest +from httpx import AsyncClient + +from nucliadb_models.graph.responses import GraphNodesSearchResponse + +# FIXME: all asserts here are wrong, as we are not deduplicating, nor is Rust +# returning the proper entities. Fix all asserts once both issues are tackled + + +@pytest.mark.deploy_modes("standalone") +async def test_graph_nodes_search( + nucliadb_reader: AsyncClient, + kb_with_entity_graph: str, +): + kbid = kb_with_entity_graph + + # (:PERSON) + resp = await nucliadb_reader.post( + f"/kb/{kbid}/graph/nodes", + json={ + "query": { + "prop": "node", + "group": "PERSON", + }, + "top_k": 100, + }, + ) + assert resp.status_code == 200 + nodes = GraphNodesSearchResponse.model_validate(resp.json()).nodes + assert len(nodes) == 24 + + # (:PERSON)-[]->() + resp = await nucliadb_reader.post( + f"/kb/{kbid}/graph/nodes", + json={ + "query": { + "prop": "source_node", + "group": "PERSON", + }, + "top_k": 100, + }, + ) + assert resp.status_code == 200 + nodes = GraphNodesSearchResponse.model_validate(resp.json()).nodes + assert len(nodes) == 24 + + # ()-[]->(:PERSON) + resp = await nucliadb_reader.post( + f"/kb/{kbid}/graph/nodes", + json={ + "query": { + "prop": "destination_node", + "group": "PERSON", + }, + "top_k": 100, + }, + ) + assert resp.status_code == 200 + nodes = 
GraphNodesSearchResponse.model_validate(resp.json()).nodes + assert len(nodes) == 6 + + # (:Anastasia) -- implemented as an or instead of using prop=node + resp = await nucliadb_reader.post( + f"/kb/{kbid}/graph/nodes", + json={ + "query": { + "or": [ + { + "prop": "source_node", + "value": "Anastasia", + "group": "PERSON", + }, + { + "prop": "destination_node", + "value": "Anastasia", + "group": "PERSON", + }, + ], + }, + "top_k": 100, + }, + ) + assert resp.status_code == 200 + nodes = GraphNodesSearchResponse.model_validate(resp.json()).nodes + assert len(nodes) == 4 diff --git a/nucliadb/tests/nucliadb/integration/search/graph/test_graph_path_search.py b/nucliadb/tests/nucliadb/integration/search/graph/test_graph_path_search.py new file mode 100644 index 0000000000..ad6c796c32 --- /dev/null +++ b/nucliadb/tests/nucliadb/integration/search/graph/test_graph_path_search.py @@ -0,0 +1,467 @@ +# Copyright (C) 2021 Bosutech XXI S.L. +# +# nucliadb is offered under the AGPL v3.0 and as commercial software. +# For commercial licensing, contact us at info@nuclia.com. +# +# AGPL: +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . 
+# +import pytest +from httpx import AsyncClient + +from nucliadb_models.graph import responses as graph_responses +from nucliadb_models.graph.responses import GraphSearchResponse + + +@pytest.mark.deploy_modes("standalone") +async def test_graph_search__node_queries( + nucliadb_reader: AsyncClient, + kb_with_entity_graph: str, +): + kbid = kb_with_entity_graph + + # ()-[]->() + resp = await nucliadb_reader.post( + f"/kb/{kbid}/graph", + json={ + "query": { + "prop": "path", + "source": {}, + }, + "top_k": 100, + }, + ) + assert resp.status_code == 200 + paths = simple_paths(GraphSearchResponse.model_validate(resp.json()).paths) + assert len(paths) == 16 + + # (:PERSON)-[]->() + resp = await nucliadb_reader.post( + f"/kb/{kbid}/graph", + json={ + "query": { + "prop": "source_node", + "group": "PERSON", + }, + "top_k": 100, + }, + ) + assert resp.status_code == 200 + paths = simple_paths(GraphSearchResponse.model_validate(resp.json()).paths) + assert len(paths) == 12 + + # (:PERSON)-[]->() + resp = await nucliadb_reader.post( + f"/kb/{kbid}/graph", + json={ + "query": { + "prop": "path", + "source": { + "group": "PERSON", + }, + }, + "top_k": 100, + }, + ) + assert resp.status_code == 200 + paths = simple_paths(GraphSearchResponse.model_validate(resp.json()).paths) + assert len(paths) == 12 + + # (:Anna)-[]->() + resp = await nucliadb_reader.post( + f"/kb/{kbid}/graph", + json={ + "query": { + "prop": "path", + "source": { + "value": "Anna", + "group": "PERSON", + }, + }, + "top_k": 100, + }, + ) + assert resp.status_code == 200 + paths = simple_paths(GraphSearchResponse.model_validate(resp.json()).paths) + assert len(paths) == 4 + assert ("Anna", "FOLLOW", "Erin") in paths + assert ("Anna", "LIVE_IN", "New York") in paths + assert ("Anna", "LOVE", "Cat") in paths + assert ("Anna", "WORK_IN", "New York") in paths + + # ()-[]->(:Anna) + resp = await nucliadb_reader.post( + f"/kb/{kbid}/graph", + json={ + "query": { + "prop": "path", + "destination": { + "value": 
"Anna", + "group": "PERSON", + }, + }, + "top_k": 100, + }, + ) + assert resp.status_code == 200 + paths = simple_paths(GraphSearchResponse.model_validate(resp.json()).paths) + assert len(paths) == 1 + assert ("Anastasia", "IS_FRIEND", "Anna") in paths + + # (:Anna) + resp = await nucliadb_reader.post( + f"/kb/{kbid}/graph", + json={ + "query": { + "prop": "path", + "source": { + "value": "Anna", + "group": "PERSON", + }, + "undirected": True, + }, + "top_k": 100, + }, + ) + assert resp.status_code == 200 + paths = simple_paths(GraphSearchResponse.model_validate(resp.json()).paths) + assert len(paths) == 5 + assert ("Anna", "FOLLOW", "Erin") in paths + assert ("Anna", "LIVE_IN", "New York") in paths + assert ("Anna", "LOVE", "Cat") in paths + assert ("Anna", "WORK_IN", "New York") in paths + assert ("Anastasia", "IS_FRIEND", "Anna") in paths + + +@pytest.mark.deploy_modes("standalone") +async def test_graph_search__fuzzy_node_queries( + nucliadb_reader: AsyncClient, + kb_with_entity_graph: str, +): + kbid = kb_with_entity_graph + + # (:~Anastas) + resp = await nucliadb_reader.post( + f"/kb/{kbid}/graph", + json={ + "query": { + "prop": "path", + "source": { + "value": "Anastas", + "match": "fuzzy", + }, + }, + "top_k": 100, + }, + ) + assert resp.status_code == 200 + paths = simple_paths(GraphSearchResponse.model_validate(resp.json()).paths) + assert len(paths) == 1 + assert ("Anastasia", "IS_FRIEND", "Anna") in paths + + # (:~AnXstXsiX) + resp = await nucliadb_reader.post( + f"/kb/{kbid}/graph", + json={ + "query": { + "prop": "path", + "source": { + "value": "AnXstXsiX", + "match": "fuzzy", + }, + }, + "top_k": 100, + }, + ) + assert resp.status_code == 200 + paths = simple_paths(GraphSearchResponse.model_validate(resp.json()).paths) + assert len(paths) == 0 + + # (:~AnXstXsia) + resp = await nucliadb_reader.post( + f"/kb/{kbid}/graph", + json={ + "query": { + "prop": "path", + "source": { + "value": "AnXstXsia", + "match": "fuzzy", + }, + }, + "top_k": 100, + }, 
+ ) + assert resp.status_code == 200 + paths = simple_paths(GraphSearchResponse.model_validate(resp.json()).paths) + assert len(paths) == 1 + assert ("Anastasia", "IS_FRIEND", "Anna") in paths + + +@pytest.mark.deploy_modes("standalone") +async def test_graph_search__relation_queries( + nucliadb_reader: AsyncClient, + kb_with_entity_graph: str, +): + kbid = kb_with_entity_graph + + # ()-[LIVE_IN]->() + resp = await nucliadb_reader.post( + f"/kb/{kbid}/graph", + json={ + "query": { + "prop": "relation", + "label": "LIVE_IN", + }, + "top_k": 100, + }, + ) + assert resp.status_code == 200 + paths = simple_paths(GraphSearchResponse.model_validate(resp.json()).paths) + assert len(paths) == 2 + assert ("Anna", "LIVE_IN", "New York") in paths + assert ("Peter", "LIVE_IN", "New York") in paths + + # ()-[LIVE_IN]->() + resp = await nucliadb_reader.post( + f"/kb/{kbid}/graph", + json={ + "query": { + "prop": "path", + "relation": { + "label": "LIVE_IN", + }, + }, + "top_k": 100, + }, + ) + assert resp.status_code == 200 + paths = simple_paths(GraphSearchResponse.model_validate(resp.json()).paths) + assert len(paths) == 2 + assert ("Anna", "LIVE_IN", "New York") in paths + assert ("Peter", "LIVE_IN", "New York") in paths + + # ()-[: LIVE_IN | BORN_IN]->() + resp = await nucliadb_reader.post( + f"/kb/{kbid}/graph", + json={ + "query": { + "or": [ + { + "prop": "path", + "relation": { + "label": "LIVE_IN", + }, + }, + { + "prop": "path", + "relation": { + "label": "BORN_IN", + }, + }, + ] + }, + "top_k": 100, + }, + ) + assert resp.status_code == 200 + paths = simple_paths(GraphSearchResponse.model_validate(resp.json()).paths) + assert len(paths) == 3 + assert ("Anna", "LIVE_IN", "New York") in paths + assert ("Erin", "BORN_IN", "UK") in paths + assert ("Peter", "LIVE_IN", "New York") in paths + + # ()-[:!LIVE_IN]->() + resp = await nucliadb_reader.post( + f"/kb/{kbid}/graph", + json={ + "query": { + "not": { + "prop": "path", + "relation": { + "label": "LIVE_IN", + }, + }, + 
}, + "top_k": 100, + }, + ) + assert resp.status_code == 200 + paths = simple_paths(GraphSearchResponse.model_validate(resp.json()).paths) + assert len(paths) == 14 + + +@pytest.mark.deploy_modes("standalone") +async def test_graph_search__directed_path_queries( + nucliadb_reader: AsyncClient, + kb_with_entity_graph: str, +): + kbid = kb_with_entity_graph + + # (:Erin)-[]->(:UK) + resp = await nucliadb_reader.post( + f"/kb/{kbid}/graph", + json={ + "query": { + "prop": "path", + "source": { + "value": "Erin", + "group": "PERSON", + }, + "destination": { + "value": "UK", + "group": "PLACE", + }, + }, + "top_k": 100, + }, + ) + assert resp.status_code == 200 + paths = simple_paths(GraphSearchResponse.model_validate(resp.json()).paths) + assert len(paths) == 1 + assert ("Erin", "BORN_IN", "UK") in paths + + # (:PERSON)-[]->(:PLACE) + resp = await nucliadb_reader.post( + f"/kb/{kbid}/graph", + json={ + "query": { + "prop": "path", + "source": { + "group": "PERSON", + }, + "destination": { + "group": "PLACE", + }, + }, + "top_k": 100, + }, + ) + assert resp.status_code == 200 + paths = simple_paths(GraphSearchResponse.model_validate(resp.json()).paths) + assert len(paths) == 4 + assert ("Anna", "LIVE_IN", "New York") in paths + assert ("Anna", "WORK_IN", "New York") in paths + assert ("Erin", "BORN_IN", "UK") in paths + assert ("Peter", "LIVE_IN", "New York") in paths + + # (:PERSON)-[LIVE_IN]->(:PLACE) + resp = await nucliadb_reader.post( + f"/kb/{kbid}/graph", + json={ + "query": { + "prop": "path", + "source": { + "group": "PERSON", + }, + "relation": { + "label": "LIVE_IN", + }, + "destination": { + "group": "PLACE", + }, + }, + "top_k": 100, + }, + ) + assert resp.status_code == 200 + paths = simple_paths(GraphSearchResponse.model_validate(resp.json()).paths) + assert len(paths) == 2 + assert ("Anna", "LIVE_IN", "New York") in paths + assert ("Peter", "LIVE_IN", "New York") in paths + + # (:!Anna)-[:LIVE_IN|LOVE]->() + resp = await nucliadb_reader.post( + 
f"/kb/{kbid}/graph", + json={ + "query": { + "and": [ + { + "not": { + "prop": "path", + "source": { + "value": "Anna", + }, + } + }, + { + "or": [ + { + "prop": "path", + "relation": { + "label": "LIVE_IN", + }, + }, + { + "prop": "path", + "relation": { + "label": "LOVE", + }, + }, + ] + }, + ] + }, + "top_k": 100, + }, + ) + assert resp.status_code == 200 + paths = simple_paths(GraphSearchResponse.model_validate(resp.json()).paths) + assert len(paths) == 3 + assert ("Erin", "LOVE", "Climbing") in paths + assert ("Dimitri", "LOVE", "Anastasia") in paths + assert ("Peter", "LIVE_IN", "New York") in paths + + +@pytest.mark.deploy_modes("standalone") +async def test_graph_search__undirected_path_queries( + nucliadb_reader: AsyncClient, + kb_with_entity_graph: str, +): + kbid = kb_with_entity_graph + + # (:Anna)-[:IS_FRIEND]-() + resp = await nucliadb_reader.post( + f"/kb/{kbid}/graph", + json={ + "query": { + "prop": "path", + "source": { + "value": "Anna", + }, + "relation": { + "label": "IS_FRIEND", + }, + "undirected": True, + }, + "top_k": 100, + }, + ) + assert resp.status_code == 200 + paths = simple_paths(GraphSearchResponse.model_validate(resp.json()).paths) + assert len(paths) == 1 + assert ("Anastasia", "IS_FRIEND", "Anna") in paths + + +def simple_paths(paths: list[graph_responses.GraphPath]) -> list[tuple[str, str, str]]: + simple_paths = [] + for path in paths: + # response should never return empty nodes/relations + assert path.source is not None + assert path.source.value is not None + assert path.relation is not None + assert path.relation.label is not None + assert path.destination is not None + assert path.destination.value is not None + simple_paths.append((path.source.value, path.relation.label, path.destination.value)) + return simple_paths diff --git a/nucliadb/tests/nucliadb/integration/search/graph/test_graph_relations_search.py b/nucliadb/tests/nucliadb/integration/search/graph/test_graph_relations_search.py new file mode 100644 index 
0000000000..1db88d1c90 --- /dev/null +++ b/nucliadb/tests/nucliadb/integration/search/graph/test_graph_relations_search.py @@ -0,0 +1,90 @@ +# Copyright (C) 2021 Bosutech XXI S.L. +# +# nucliadb is offered under the AGPL v3.0 and as commercial software. +# For commercial licensing, contact us at info@nuclia.com. +# +# AGPL: +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# +import pytest +from httpx import AsyncClient + +from nucliadb_models.graph.responses import GraphRelationsSearchResponse + +# FIXME: in this case, the number of relations returned is correct but they are +# duplicated, as Rust doesn't deduplicate them. Maybe the response should be +# different? 
+ + +@pytest.mark.deploy_modes("standalone") +async def test_graph_nodes_search( + nucliadb_reader: AsyncClient, + kb_with_entity_graph: str, +): + kbid = kb_with_entity_graph + + # [:LIVE_IN] + resp = await nucliadb_reader.post( + f"/kb/{kbid}/graph/relations", + json={ + "query": { + "prop": "relation", + "label": "LIVE_IN", + }, + "top_k": 100, + }, + ) + assert resp.status_code == 200 + relations = GraphRelationsSearchResponse.model_validate(resp.json()).relations + assert len(relations) == 2 + + # [: LIVE_IN | BORN_IN] + resp = await nucliadb_reader.post( + f"/kb/{kbid}/graph/relations", + json={ + "query": { + "or": [ + { + "prop": "relation", + "label": "LIVE_IN", + }, + { + "prop": "relation", + "label": "BORN_IN", + }, + ] + }, + "top_k": 100, + }, + ) + assert resp.status_code == 200 + relations = GraphRelationsSearchResponse.model_validate(resp.json()).relations + assert len(relations) == 3 + + # [:!LIVE_IN] + resp = await nucliadb_reader.post( + f"/kb/{kbid}/graph/relations", + json={ + "query": { + "not": { + "prop": "relation", + "label": "LIVE_IN", + } + }, + "top_k": 100, + }, + ) + assert resp.status_code == 200 + relations = GraphRelationsSearchResponse.model_validate(resp.json()).relations + assert len(relations) == 14 diff --git a/nucliadb/tests/nucliadb/integration/test_relations.py b/nucliadb/tests/nucliadb/integration/test_relations.py index 036f68c7bf..5ec981a5b0 100644 --- a/nucliadb/tests/nucliadb/integration/test_relations.py +++ b/nucliadb/tests/nucliadb/integration/test_relations.py @@ -142,7 +142,6 @@ async def test_broker_message_relations( @pytest.mark.deploy_modes("standalone") async def test_extracted_relations( - nucliadb_ingest_grpc: WriterStub, nucliadb_reader: AsyncClient, nucliadb_writer: AsyncClient, standalone_knowledgebox, diff --git a/nucliadb/tests/search/unit/search/search/test_pgcatalog.py b/nucliadb/tests/search/unit/search/search/test_pgcatalog.py index baef7776b3..166679e805 100644 --- 
a/nucliadb/tests/search/unit/search/search/test_pgcatalog.py +++ b/nucliadb/tests/search/unit/search/search/test_pgcatalog.py @@ -20,8 +20,8 @@ from datetime import datetime from nucliadb.search.search.pgcatalog import _convert_filter, _prepare_query -from nucliadb.search.search.query_parser.catalog import parse_catalog from nucliadb.search.search.query_parser.models import CatalogExpression +from nucliadb.search.search.query_parser.parsers import parse_catalog from nucliadb_models.filters import CatalogFilterExpression from nucliadb_models.search import ( CatalogRequest, diff --git a/nucliadb/tests/search/unit/search/test_query_parsing.py b/nucliadb/tests/search/unit/search/test_query_parsing.py index 2d287b9145..50af59c91d 100644 --- a/nucliadb/tests/search/unit/search/test_query_parsing.py +++ b/nucliadb/tests/search/unit/search/test_query_parsing.py @@ -24,7 +24,7 @@ from pydantic import ValidationError from nucliadb.search.search.query_parser import models as parser_models -from nucliadb.search.search.query_parser.parser import parse_find +from nucliadb.search.search.query_parser.parsers import parse_find from nucliadb_models import search as search_models from nucliadb_models.search import FindRequest diff --git a/nucliadb/tests/search/unit/test_rank_fusion.py b/nucliadb/tests/search/unit/test_rank_fusion.py index eec54aba0b..5bdf502858 100644 --- a/nucliadb/tests/search/unit/test_rank_fusion.py +++ b/nucliadb/tests/search/unit/test_rank_fusion.py @@ -36,7 +36,7 @@ keyword_result_to_text_block_match, semantic_result_to_text_block_match, ) -from nucliadb.search.search.query_parser.parser import parse_find +from nucliadb.search.search.query_parser.parsers import parse_find from nucliadb.search.search.rank_fusion import LegacyRankFusion, ReciprocalRankFusion, get_rank_fusion from nucliadb_models.search import SCORE_TYPE, FindRequest from nucliadb_protos.nodereader_pb2 import DocumentScored, ParagraphResult diff --git 
a/nucliadb_models/src/nucliadb_models/graph/__init__.py b/nucliadb_models/src/nucliadb_models/graph/__init__.py new file mode 100644 index 0000000000..6e234e4311 --- /dev/null +++ b/nucliadb_models/src/nucliadb_models/graph/__init__.py @@ -0,0 +1,24 @@ +# Copyright (C) 2021 Bosutech XXI S.L. +# +# nucliadb is offered under the AGPL v3.0 and as commercial software. +# For commercial licensing, contact us at info@nuclia.com. +# +# AGPL: +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# + +from . import ( # noqa: F401 + requests, + responses, +) diff --git a/nucliadb_models/src/nucliadb_models/graph/requests.py b/nucliadb_models/src/nucliadb_models/graph/requests.py new file mode 100644 index 0000000000..12c3cfdc86 --- /dev/null +++ b/nucliadb_models/src/nucliadb_models/graph/requests.py @@ -0,0 +1,160 @@ +# Copyright (C) 2021 Bosutech XXI S.L. +# +# nucliadb is offered under the AGPL v3.0 and as commercial software. +# For commercial licensing, contact us at info@nuclia.com. +# +# AGPL: +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. 
+# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# +from enum import Enum +from typing import Annotated, Literal, Optional, Union + +from pydantic import BaseModel, Discriminator, Field, Tag, model_validator +from typing_extensions import Self + +from nucliadb_models.filters import And, Not, Or, filter_discriminator +from nucliadb_models.metadata import RelationNodeType + +## Models for graph nodes and relations + + +class NodeMatchKind(str, Enum): + EXACT = "exact" + FUZZY = "fuzzy" + + +class GraphNode(BaseModel): + value: Optional[str] = None + match: NodeMatchKind = NodeMatchKind.EXACT + type: Optional[RelationNodeType] = RelationNodeType.ENTITY + group: Optional[str] = None + + @model_validator(mode="after") + def validate_fuzzy_usage(self) -> Self: + if self.match == NodeMatchKind.FUZZY: + if self.value is None: + raise ValueError("Fuzzy match can only be used if a node value is provided") + else: + if len(self.value) < 3: + raise ValueError( + "Fuzzy match must be used with values containing at least 3 characters" + ) + return self + + +class GraphRelation(BaseModel): + label: Optional[str] = None + + +## Models for query expressions + + +class AnyNode(GraphNode): + prop: Literal["node"] + + +class SourceNode(GraphNode): + prop: Literal["source_node"] + + +class DestinationNode(GraphNode): + prop: Literal["destination_node"] + + +class Relation(GraphRelation): + prop: Literal["relation"] + + +class GraphPath(BaseModel, extra="forbid"): + prop: Literal["path"] = "path" + source: Optional[GraphNode] = None + relation: Optional[GraphRelation] = None + destination: Optional[GraphNode] = None + undirected: bool = False + + +## Requests 
models + + +class BaseGraphSearchRequest(BaseModel): + top_k: int = Field(default=50, title="Number of results to retrieve") + + +graph_query_discriminator = filter_discriminator + + +# Paths search + +GraphPathQuery = Annotated[ + Union[ + # bool expressions + Annotated[And["GraphPathQuery"], Tag("and")], + Annotated[Or["GraphPathQuery"], Tag("or")], + Annotated[Not["GraphPathQuery"], Tag("not")], + # paths + Annotated[GraphPath, Tag("path")], + # nodes + Annotated[SourceNode, Tag("source_node")], + Annotated[DestinationNode, Tag("destination_node")], + Annotated[AnyNode, Tag("node")], + # relations + Annotated[Relation, Tag("relation")], + ], + Discriminator(graph_query_discriminator), +] + + +class GraphSearchRequest(BaseGraphSearchRequest): + query: GraphPathQuery + + +# Nodes search + +GraphNodesQuery = Annotated[ + Union[ + Annotated[And["GraphNodesQuery"], Tag("and")], + Annotated[Or["GraphNodesQuery"], Tag("or")], + Annotated[Not["GraphNodesQuery"], Tag("not")], + Annotated[SourceNode, Tag("source_node")], + Annotated[DestinationNode, Tag("destination_node")], + Annotated[AnyNode, Tag("node")], + ], + Discriminator(graph_query_discriminator), +] + + +class GraphNodesSearchRequest(BaseGraphSearchRequest): + query: GraphNodesQuery + + +# Relations search + +GraphRelationsQuery = Annotated[ + Union[ + Annotated[Or["GraphRelationsQuery"], Tag("or")], + Annotated[Not["GraphRelationsQuery"], Tag("not")], + Annotated[Relation, Tag("relation")], + ], + Discriminator(graph_query_discriminator), +] + + +class GraphRelationsSearchRequest(BaseGraphSearchRequest): + query: GraphRelationsQuery + + +# We need this to avoid issues with pydantic and generic types defined in another module +GraphSearchRequest.model_rebuild() +GraphNodesSearchRequest.model_rebuild() +GraphRelationsSearchRequest.model_rebuild() diff --git a/nucliadb_models/src/nucliadb_models/graph/responses.py b/nucliadb_models/src/nucliadb_models/graph/responses.py new file mode 100644 index 
0000000000..a9aa0813ed --- /dev/null +++ b/nucliadb_models/src/nucliadb_models/graph/responses.py @@ -0,0 +1,62 @@ +# Copyright (C) 2021 Bosutech XXI S.L. +# +# nucliadb is offered under the AGPL v3.0 and as commercial software. +# For commercial licensing, contact us at info@nuclia.com. +# +# AGPL: +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# +from enum import Enum + +from pydantic import BaseModel + +from nucliadb_models.metadata import RelationNodeType + + +class GraphNode(BaseModel): + value: str + type: RelationNodeType + group: str + + +class GraphNodePosition(str, Enum): + ANY = "any" + SOURCE = "source" + DESTINATION = "destination" + + +class PositionedGraphNode(GraphNode): + position: GraphNodePosition = GraphNodePosition.ANY + + +class GraphRelation(BaseModel): + label: str + + +class GraphPath(BaseModel): + source: GraphNode + relation: GraphRelation + destination: GraphNode + + +class GraphSearchResponse(BaseModel): + paths: list[GraphPath] + + +class GraphNodesSearchResponse(BaseModel): + nodes: list[GraphNode] + + +class GraphRelationsSearchResponse(BaseModel): + relations: list[GraphRelation]