Feat: Adds subgraph retriever to graph based completion searches (#874)
## Description

Adds a subgraph retriever to graph-based completion searches.

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
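For reviewers, a minimal usage sketch of the new filter, assuming the dataset name and node-set tag used in the updated tests (illustrative only, not part of this diff):

```python
import cognee
from cognee.modules.engine.models import NodeSet
from cognee.modules.search.types import SearchType


async def demo():
    # Tag ingested data with a node set (as the updated tests do).
    await cognee.add(
        ["Neo4j is a graph database that supports cypher."],
        "my_dataset",  # hypothetical dataset name
        node_set=["first"],
    )
    await cognee.cognify(["my_dataset"])

    # Restrict graph completion to the subgraph around the "first" node set.
    return await cognee.search(
        query_text="What is in the context?",
        query_type=SearchType.GRAPH_COMPLETION,
        node_type=NodeSet,
        node_name=["first"],
    )
```

The `node_type`/`node_name` pair flows from the public `search()` API through `specific_search` and the graph-completion retrievers down to `brute_force_triplet_search`, which projects only the matching subgraph before scoring triplets.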
parent 834d959b11
commit 965033e161

19 changed files with 302 additions and 31 deletions
@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Union, Optional, List, Type

 from cognee.modules.users.models import User
 from cognee.modules.search.types import SearchType
@@ -13,6 +13,8 @@ async def search(
     datasets: Union[list[str], str, None] = None,
     system_prompt_path: str = "answer_simple_question.txt",
     top_k: int = 10,
+    node_type: Optional[Type] = None,
+    node_name: Optional[List[str]] = None,
 ) -> list:
     # We use lists from now on for datasets
     if isinstance(datasets, str):
@@ -22,12 +24,14 @@ async def search(
         user = await get_default_user()

     filtered_search_results = await search_function(
-        query_text,
-        query_type,
-        datasets,
-        user,
+        query_text=query_text,
+        query_type=query_type,
+        datasets=datasets,
+        user=user,
         system_prompt_path=system_prompt_path,
         top_k=top_k,
+        node_type=node_type,
+        node_name=node_name,
     )

     return filtered_search_results

@@ -37,3 +37,17 @@ class EntityAlreadyExistsError(CogneeApiError):
         status_code=status.HTTP_409_CONFLICT,
     ):
         super().__init__(message, name, status_code)
+
+
+class NodesetFilterNotSupportedError(CogneeApiError):
+    """Nodeset filter is not supported by the current database"""
+
+    def __init__(
+        self,
+        message: str = "The nodeset filter is not supported in the current graph database.",
+        name: str = "NodeSetFilterNotSupportedError",
+        status_code=status.HTTP_404_NOT_FOUND,
+    ):
+        self.message = message
+        self.name = name
+        self.status_code = status_code

@@ -2,7 +2,7 @@ import inspect
 from functools import wraps
 from abc import abstractmethod, ABC
 from datetime import datetime, timezone
-from typing import Optional, Dict, Any, List, Tuple
+from typing import Optional, Dict, Any, List, Tuple, Type
 from uuid import NAMESPACE_OID, UUID, uuid5
 from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.engine import DataPoint
@@ -183,6 +183,13 @@ class GraphDBInterface(ABC):
         """Get all neighboring nodes."""
         raise NotImplementedError

+    @abstractmethod
+    async def get_nodeset_subgraph(
+        self, node_type: Type[Any], node_name: List[str]
+    ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
+        """Get nodeset subgraph"""
+        raise NotImplementedError
+
     @abstractmethod
     async def get_connections(
         self, node_id: str

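`get_nodeset_subgraph` is now an abstract method on `GraphDBInterface`, so every adapter must define it. Backends without native support raise the new `NodesetFilterNotSupportedError` (see the Memgraph and NetworkX hunks below); a minimal sketch of that fallback for a hypothetical third-party adapter:

```python
from typing import Any, List, Tuple, Type

from cognee.infrastructure.databases.exceptions.exceptions import NodesetFilterNotSupportedError
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface


class MyGraphAdapter(GraphDBInterface):  # hypothetical; other abstract methods omitted
    async def get_nodeset_subgraph(
        self, node_type: Type[Any], node_name: List[str]
    ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
        """Nodeset filtering is not supported by this backend."""
        raise NodesetFilterNotSupportedError
```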
@@ -1,11 +1,12 @@
 """Adapter for Kuzu graph database."""

+from cognee.infrastructure.databases.exceptions.exceptions import NodesetFilterNotSupportedError
 from cognee.shared.logging_utils import get_logger
 import json
 import os
 import shutil
 import asyncio
-from typing import Dict, Any, List, Union, Optional, Tuple
+from typing import Dict, Any, List, Union, Optional, Tuple, Type
 from datetime import datetime, timezone
 from uuid import UUID
 from contextlib import asynccontextmanager
@@ -728,6 +729,65 @@ class KuzuAdapter(GraphDBInterface):
             logger.error(f"Failed to get graph data: {e}")
             raise

+    async def get_nodeset_subgraph(
+        self, node_type: Type[Any], node_name: List[str]
+    ) -> Tuple[List[Tuple[str, dict]], List[Tuple[str, str, str, dict]]]:
+        label = node_type.__name__
+        primary_query = """
+            UNWIND $names AS wantedName
+            MATCH (n:Node)
+            WHERE n.type = $label AND n.name = wantedName
+            RETURN DISTINCT n.id
+        """
+        primary_rows = await self.query(primary_query, {"names": node_name, "label": label})
+        primary_ids = [row[0] for row in primary_rows]
+        if not primary_ids:
+            return [], []
+
+        neighbor_query = """
+            MATCH (n:Node)-[:EDGE]-(nbr:Node)
+            WHERE n.id IN $ids
+            RETURN DISTINCT nbr.id
+        """
+        nbr_rows = await self.query(neighbor_query, {"ids": primary_ids})
+        neighbor_ids = [row[0] for row in nbr_rows]
+
+        all_ids = list({*primary_ids, *neighbor_ids})
+
+        nodes_query = """
+            MATCH (n:Node)
+            WHERE n.id IN $ids
+            RETURN n.id, n.name, n.type, n.properties
+        """
+        node_rows = await self.query(nodes_query, {"ids": all_ids})
+        nodes: List[Tuple[str, dict]] = []
+        for node_id, name, typ, props in node_rows:
+            data = {"id": node_id, "name": name, "type": typ}
+            if props:
+                try:
+                    data.update(json.loads(props))
+                except json.JSONDecodeError:
+                    logger.warning(f"Failed to parse JSON props for node {node_id}")
+            nodes.append((node_id, data))
+
+        edges_query = """
+            MATCH (a:Node)-[r:EDGE]-(b:Node)
+            WHERE a.id IN $ids AND b.id IN $ids
+            RETURN a.id, b.id, r.relationship_name, r.properties
+        """
+        edge_rows = await self.query(edges_query, {"ids": all_ids})
+        edges: List[Tuple[str, str, str, dict]] = []
+        for from_id, to_id, rel_type, props in edge_rows:
+            data = {}
+            if props:
+                try:
+                    data = json.loads(props)
+                except json.JSONDecodeError:
+                    logger.warning(f"Failed to parse JSON props for edge {from_id}->{to_id}")
+            edges.append((from_id, to_id, rel_type, data))
+
+        return nodes, edges
+
     async def get_filtered_graph_data(
         self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
     ):

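The Kuzu implementation resolves the set in three round trips: match nodes whose stored `type` equals the class name and whose `name` is in the requested list, expand one hop to neighbours, then fetch the induced nodes and edges, JSON-decoding the `properties` column. A hedged call sketch (the adapter instance and node-set tag are assumptions):

```python
from cognee.modules.engine.models import NodeSet


async def fetch_subgraph(adapter):  # adapter: an initialized KuzuAdapter
    nodes, edges = await adapter.get_nodeset_subgraph(NodeSet, ["first"])
    # nodes: [(node_id, {"id": ..., "name": ..., "type": ..., **decoded_props}), ...]
    # edges: [(source_id, target_id, relationship_name, decoded_props), ...]
    return nodes, edges
```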
@@ -4,7 +4,7 @@ import json
 from cognee.shared.logging_utils import get_logger, ERROR
 import asyncio
 from textwrap import dedent
-from typing import Optional, Any, List, Dict
+from typing import Optional, Any, List, Dict, Type, Tuple
 from contextlib import asynccontextmanager
 from uuid import UUID
 from neo4j import AsyncSession
@@ -13,6 +13,7 @@ from neo4j.exceptions import Neo4jError
 from cognee.infrastructure.engine import DataPoint
 from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
 from cognee.modules.storage.utils import JSONEncoder
+from cognee.infrastructure.databases.exceptions.exceptions import NodesetFilterNotSupportedError

 logger = get_logger("MemgraphAdapter", level=ERROR)

@@ -482,6 +483,12 @@ class MemgraphAdapter(GraphDBInterface):

         return (nodes, edges)

+    async def get_nodeset_subgraph(
+        self, node_type: Type[Any], node_name: List[str]
+    ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
+        """Get nodeset subgraph"""
+        raise NodesetFilterNotSupportedError
+
     async def get_filtered_graph_data(self, attribute_filters):
         """
         Fetches nodes and relationships filtered by specified attribute values.

@@ -6,7 +6,7 @@ import json
 from cognee.shared.logging_utils import get_logger, ERROR
 import asyncio
 from textwrap import dedent
-from typing import Optional, Any, List, Dict
+from typing import Optional, Any, List, Dict, Type, Tuple
 from contextlib import asynccontextmanager
 from uuid import UUID
 from neo4j import AsyncSession
@@ -517,6 +517,54 @@ class Neo4jAdapter(GraphDBInterface):

         return (nodes, edges)

+    async def get_nodeset_subgraph(
+        self, node_type: Type[Any], node_name: List[str]
+    ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
+        label = node_type.__name__
+
+        query = f"""
+        UNWIND $names AS wantedName
+        MATCH (n:`{label}`)
+        WHERE n.name = wantedName
+        WITH collect(DISTINCT n) AS primary
+        UNWIND primary AS p
+        OPTIONAL MATCH (p)--(nbr)
+        WITH primary, collect(DISTINCT nbr) AS nbrs
+        WITH primary + nbrs AS nodelist
+        UNWIND nodelist AS node
+        WITH collect(DISTINCT node) AS nodes
+        MATCH (a)-[r]-(b)
+        WHERE a IN nodes AND b IN nodes
+        WITH nodes, collect(DISTINCT r) AS rels
+        RETURN
+          [n IN nodes |
+            {{ id: n.id,
+              properties: properties(n) }}] AS rawNodes,
+          [r IN rels |
+            {{ type: type(r),
+              properties: properties(r) }}] AS rawRels
+        """
+
+        result = await self.query(query, {"names": node_name})
+        if not result:
+            return [], []
+
+        raw_nodes = result[0]["rawNodes"]
+        raw_rels = result[0]["rawRels"]
+
+        nodes = [(n["properties"]["id"], n["properties"]) for n in raw_nodes]
+        edges = [
+            (
+                r["properties"]["source_node_id"],
+                r["properties"]["target_node_id"],
+                r["type"],
+                r["properties"],
+            )
+            for r in raw_rels
+        ]
+
+        return nodes, edges
+
     async def get_filtered_graph_data(self, attribute_filters):
         """
         Fetches nodes and relationships filtered by specified attribute values.

@@ -4,8 +4,10 @@ from datetime import datetime, timezone
 import os
 import json
 import asyncio
+
+from cognee.infrastructure.databases.exceptions.exceptions import NodesetFilterNotSupportedError
 from cognee.shared.logging_utils import get_logger
-from typing import Dict, Any, List, Union
+from typing import Dict, Any, List, Union, Type, Tuple
 from uuid import UUID
 import aiofiles
 import aiofiles.os as aiofiles_os
@@ -396,6 +398,12 @@ class NetworkXAdapter(GraphDBInterface):
             logger.error("Failed to delete graph: %s", error)
             raise error

+    async def get_nodeset_subgraph(
+        self, node_type: Type[Any], node_name: List[str]
+    ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
+        """Get nodeset subgraph"""
+        raise NodesetFilterNotSupportedError
+
     async def get_filtered_graph_data(
         self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
     ):

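Memgraph and NetworkX opt out rather than implement the filter, so cross-backend callers may want a guard. A sketch of one way to degrade gracefully (hypothetical helper, not part of this diff):

```python
from cognee.infrastructure.databases.exceptions.exceptions import NodesetFilterNotSupportedError


async def nodeset_subgraph_or_empty(adapter, node_type, node_name):
    """Return the nodeset subgraph, or an empty graph on unsupported backends."""
    try:
        return await adapter.get_nodeset_subgraph(node_type, node_name)
    except NodesetFilterNotSupportedError:
        return [], []
```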
@@ -5,4 +5,3 @@ class NodeSet(DataPoint):
     """NodeSet data point."""
-
     name: str
     metadata: dict = {"index_fields": ["name"]}

@@ -1,5 +1,5 @@
 from cognee.shared.logging_utils import get_logger
-from typing import List, Dict, Union
+from typing import List, Dict, Union, Optional, Type

 from cognee.exceptions import InvalidValueError
 from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
@@ -61,22 +61,33 @@ class CogneeGraph(CogneeAbstractGraph):
         node_dimension=1,
         edge_dimension=1,
         memory_fragment_filter=[],
+        node_type: Optional[Type] = None,
+        node_name: Optional[List[str]] = None,
     ) -> None:
         if node_dimension < 1 or edge_dimension < 1:
             raise InvalidValueError(message="Dimensions must be positive integers")

         try:
-            if len(memory_fragment_filter) == 0:
+            if node_type is not None and node_name is not None:
+                nodes_data, edges_data = await adapter.get_nodeset_subgraph(
+                    node_type=node_type, node_name=node_name
+                )
+                if not nodes_data or not edges_data:
+                    raise EntityNotFoundError(
+                        message="Nodeset does not exist, or empty nodeset projected from the database."
+                    )
+            elif len(memory_fragment_filter) == 0:
                 nodes_data, edges_data = await adapter.get_graph_data()
+                if not nodes_data or not edges_data:
+                    raise EntityNotFoundError(message="Empty graph projected from the database.")
             else:
                 nodes_data, edges_data = await adapter.get_filtered_graph_data(
                     attribute_filters=memory_fragment_filter
                 )
-
-            if not nodes_data:
-                raise EntityNotFoundError(message="No node data retrieved from the database.")
-            if not edges_data:
-                raise EntityNotFoundError(message="No edge data retrieved from the database.")
+                if not nodes_data or not edges_data:
+                    raise EntityNotFoundError(
+                        message="Empty filtered graph projected from the database."
+                    )

             for node_id, properties in nodes_data:
                 node_attributes = {key: properties.get(key) for key in node_properties_to_project}

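With the dispatch above, a populated `node_type`/`node_name` pair takes priority over both the full-graph and attribute-filtered projections. A hedged way to exercise it directly, via `get_memory_fragment` from the triplet-search utilities changed below (the import path is an assumption based on the module layout):

```python
from cognee.modules.engine.models import NodeSet
from cognee.modules.retrieval.utils.brute_force_triplet_search import get_memory_fragment


async def project_first_nodeset():
    # Forwards node_type/node_name into CogneeGraph's projection,
    # selecting the get_nodeset_subgraph branch shown above.
    return await get_memory_fragment(node_type=NodeSet, node_name=["first"])
```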
@@ -95,6 +95,7 @@ def expand_with_nodes_and_edges(
                 name=ont_node_name,
                 description=ont_node_name,
                 ontology_valid=True,
+                belongs_to_set=data_chunk.belongs_to_set,
             )

         for source, relation, target in ontology_entity_type_edges:
@@ -144,6 +145,7 @@ def expand_with_nodes_and_edges(
                 is_a=type_node,
                 description=node.description,
                 ontology_valid=ontology_validated_source_ent,
+                belongs_to_set=data_chunk.belongs_to_set,
             )

             added_nodes_map[entity_node_key] = entity_node
@@ -174,6 +176,7 @@ def expand_with_nodes_and_edges(
                 name=ont_node_name,
                 description=ont_node_name,
                 ontology_valid=True,
+                belongs_to_set=data_chunk.belongs_to_set,
             )

         for source, relation, target in ontology_entity_edges:

@@ -1,4 +1,4 @@
-from typing import Any, Optional, List
+from typing import Any, Optional, List, Type
 from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
 from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@@ -14,11 +14,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
         user_prompt_path: str = "graph_context_for_question.txt",
         system_prompt_path: str = "answer_simple_question.txt",
         top_k: Optional[int] = 5,
+        node_type: Optional[Type] = None,
+        node_name: Optional[List[str]] = None,
     ):
         super().__init__(
             user_prompt_path=user_prompt_path,
             system_prompt_path=system_prompt_path,
             top_k=top_k,
+            node_type=node_type,
+            node_name=node_name,
         )

     async def get_completion(

@@ -1,4 +1,4 @@
-from typing import Any, Optional, List
+from typing import Any, Optional, List, Type
 from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
 from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@@ -18,11 +18,15 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
         followup_system_prompt_path: str = "cot_followup_system_prompt.txt",
         followup_user_prompt_path: str = "cot_followup_user_prompt.txt",
         top_k: Optional[int] = 5,
+        node_type: Optional[Type] = None,
+        node_name: Optional[List[str]] = None,
     ):
         super().__init__(
             user_prompt_path=user_prompt_path,
             system_prompt_path=system_prompt_path,
             top_k=top_k,
+            node_type=node_type,
+            node_name=node_name,
         )
         self.validation_system_prompt_path = validation_system_prompt_path
         self.validation_user_prompt_path = validation_user_prompt_path

@@ -1,4 +1,4 @@
-from typing import Any, Optional
+from typing import Any, Optional, Type, List
 from collections import Counter
 import string

@@ -18,11 +18,15 @@ class GraphCompletionRetriever(BaseRetriever):
         user_prompt_path: str = "graph_context_for_question.txt",
         system_prompt_path: str = "answer_simple_question.txt",
         top_k: Optional[int] = 5,
+        node_type: Optional[Type] = None,
+        node_name: Optional[List[str]] = None,
     ):
         """Initialize retriever with prompt paths and search parameters."""
         self.user_prompt_path = user_prompt_path
         self.system_prompt_path = system_prompt_path
         self.top_k = top_k if top_k is not None else 5
+        self.node_type = node_type
+        self.node_name = node_name

     def _get_nodes(self, retrieved_edges: list) -> dict:
         """Creates a dictionary of nodes with their names and content."""
@@ -68,7 +72,11 @@ class GraphCompletionRetriever(BaseRetriever):
                     vector_index_collections.append(f"{subclass.__name__}_{field_name}")

         found_triplets = await brute_force_triplet_search(
-            query, top_k=self.top_k, collections=vector_index_collections or None
+            query,
+            top_k=self.top_k,
+            collections=vector_index_collections or None,
+            node_type=self.node_type,
+            node_name=self.node_name,
         )

         return found_triplets

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Type, List

 from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
 from cognee.modules.retrieval.utils.completion import summarize_text
@@ -13,12 +13,16 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
         system_prompt_path: str = "answer_simple_question.txt",
         summarize_prompt_path: str = "summarize_search_results.txt",
         top_k: Optional[int] = 5,
+        node_type: Optional[Type] = None,
+        node_name: Optional[List[str]] = None,
     ):
         """Initialize retriever with default prompt paths and search parameters."""
         super().__init__(
             user_prompt_path=user_prompt_path,
             system_prompt_path=system_prompt_path,
             top_k=top_k,
+            node_type=node_type,
+            node_name=node_name,
         )
         self.summarize_prompt_path = summarize_prompt_path

@@ -1,5 +1,5 @@
 import asyncio
-from typing import List, Optional
+from typing import List, Optional, Type

 from cognee.shared.logging_utils import get_logger, ERROR
 from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
@@ -55,6 +55,8 @@ def format_triplets(edges):

 async def get_memory_fragment(
     properties_to_project: Optional[List[str]] = None,
+    node_type: Optional[Type] = None,
+    node_name: Optional[List[str]] = None,
 ) -> CogneeGraph:
     """Creates and initializes a CogneeGraph memory fragment with optional property projections."""
     graph_engine = await get_graph_engine()
@@ -68,6 +70,8 @@ async def get_memory_fragment(
             graph_engine,
             node_properties_to_project=properties_to_project,
             edge_properties_to_project=["relationship_name"],
+            node_type=node_type,
+            node_name=node_name,
         )
     except EntityNotFoundError:
         pass
@@ -82,6 +86,8 @@ async def brute_force_triplet_search(
     collections: List[str] = None,
     properties_to_project: List[str] = None,
     memory_fragment: Optional[CogneeGraph] = None,
+    node_type: Optional[Type] = None,
+    node_name: Optional[List[str]] = None,
 ) -> list:
     if user is None:
         user = await get_default_user()
@@ -93,6 +99,8 @@ async def brute_force_triplet_search(
         collections=collections,
         properties_to_project=properties_to_project,
         memory_fragment=memory_fragment,
+        node_type=node_type,
+        node_name=node_name,
     )
     return retrieved_results

@@ -104,6 +112,8 @@ async def brute_force_search(
     collections: List[str] = None,
     properties_to_project: List[str] = None,
     memory_fragment: Optional[CogneeGraph] = None,
+    node_type: Optional[Type] = None,
+    node_name: Optional[List[str]] = None,
 ) -> list:
     """
     Performs a brute force search to retrieve the top triplets from the graph.
@@ -125,7 +135,9 @@ async def brute_force_search(
         raise ValueError("top_k must be a positive integer.")

     if memory_fragment is None:
-        memory_fragment = await get_memory_fragment(properties_to_project)
+        memory_fragment = await get_memory_fragment(
+            properties_to_project, node_type=node_type, node_name=node_name
+        )

     if collections is None:
         collections = [

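End to end, the same pair can be passed straight to the triplet search; a hedged sketch (import path assumed as above, node-set tag illustrative):

```python
from cognee.modules.engine.models import NodeSet
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search


async def search_first_nodeset():
    # Scores triplets only against the projected NodeSet subgraph.
    return await brute_force_triplet_search(
        "What is in the context?",
        top_k=5,
        node_type=NodeSet,
        node_name=["first"],
    )
```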
@@ -1,5 +1,5 @@
 import json
-from typing import Callable
+from typing import Callable, Optional, List, Type

 from cognee.exceptions import InvalidValueError
 from cognee.infrastructure.engine.utils import parse_id
@@ -33,12 +33,20 @@ async def search(
     user: User,
     system_prompt_path="answer_simple_question.txt",
     top_k: int = 10,
+    node_type: Optional[Type] = None,
+    node_name: Optional[List[str]] = None,
 ):
     query = await log_query(query_text, query_type.value, user.id)

     own_document_ids = await get_document_ids_for_user(user.id, datasets)
     search_results = await specific_search(
-        query_type, query_text, user, system_prompt_path=system_prompt_path, top_k=top_k
+        query_type,
+        query_text,
+        user,
+        system_prompt_path=system_prompt_path,
+        top_k=top_k,
+        node_type=node_type,
+        node_name=node_name,
     )

     filtered_search_results = []
@@ -61,29 +69,39 @@ async def specific_search(
     user: User,
     system_prompt_path="answer_simple_question.txt",
     top_k: int = 10,
+    node_type: Optional[Type] = None,
+    node_name: Optional[List[str]] = None,
 ) -> list:
     search_tasks: dict[SearchType, Callable] = {
         SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
         SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion,
         SearchType.CHUNKS: ChunksRetriever(top_k=top_k).get_completion,
         SearchType.RAG_COMPLETION: CompletionRetriever(
-            system_prompt_path=system_prompt_path,
-            top_k=top_k,
+            system_prompt_path=system_prompt_path, top_k=top_k
         ).get_completion,
         SearchType.GRAPH_COMPLETION: GraphCompletionRetriever(
             system_prompt_path=system_prompt_path,
             top_k=top_k,
+            node_type=node_type,
+            node_name=node_name,
         ).get_completion,
         SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever(
             system_prompt_path=system_prompt_path,
             top_k=top_k,
+            node_type=node_type,
+            node_name=node_name,
         ).get_completion,
         SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever(
             system_prompt_path=system_prompt_path,
             top_k=top_k,
+            node_type=node_type,
+            node_name=node_name,
         ).get_completion,
         SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
-            system_prompt_path=system_prompt_path, top_k=top_k
+            system_prompt_path=system_prompt_path,
+            top_k=top_k,
+            node_type=node_type,
+            node_name=node_name,
         ).get_completion,
         SearchType.CODE: CodeRetriever(top_k=top_k).get_completion,
         SearchType.CYPHER: CypherSearchRetriever().get_completion,

@@ -2,6 +2,9 @@ import os
 import shutil
 import cognee
 import pathlib
+
+from cognee.modules.engine.models import NodeSet
+from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
 from cognee.shared.logging_utils import get_logger
 from cognee.modules.search.types import SearchType
 from cognee.modules.search.operations import get_history
@@ -84,6 +87,30 @@ async def main():
     history = await get_history(user.id)
     assert len(history) == 6, "Search history is not correct."

+    nodeset_text = "Neo4j is a graph database that supports cypher."
+
+    await cognee.add([nodeset_text], dataset_name, node_set=["first"])
+
+    await cognee.cognify([dataset_name])
+
+    context_nonempty = await GraphCompletionRetriever(
+        node_type=NodeSet,
+        node_name=["first"],
+    ).get_context("What is in the context?")
+
+    context_empty = await GraphCompletionRetriever(
+        node_type=NodeSet,
+        node_name=["nonexistent"],
+    ).get_context("What is in the context?")
+
+    assert isinstance(context_nonempty, str) and context_nonempty != "", (
+        f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
+    )
+
+    assert context_empty == "", (
+        f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
+    )
+
     await cognee.prune.prune_data()
     assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

@@ -1,10 +1,12 @@
 import os
 import pathlib
 import cognee
+from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
 from cognee.modules.search.operations import get_history
 from cognee.modules.users.methods import get_default_user
 from cognee.shared.logging_utils import get_logger
 from cognee.modules.search.types import SearchType
+from cognee.modules.engine.models import NodeSet

 logger = get_logger()

@@ -89,6 +91,30 @@ async def main():

     assert len(history) == 6, "Search history is not correct."

+    nodeset_text = "Neo4j is a graph database that supports cypher."
+
+    await cognee.add([nodeset_text], dataset_name, node_set=["first"])
+
+    await cognee.cognify([dataset_name])
+
+    context_nonempty = await GraphCompletionRetriever(
+        node_type=NodeSet,
+        node_name=["first"],
+    ).get_context("What is in the context?")
+
+    context_empty = await GraphCompletionRetriever(
+        node_type=NodeSet,
+        node_name=["nonexistent"],
+    ).get_context("What is in the context?")
+
+    assert isinstance(context_nonempty, str) and context_nonempty != "", (
+        f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
+    )
+
+    assert context_empty == "", (
+        f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
+    )
+
     await cognee.prune.prune_data()
     assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

@@ -3,6 +3,7 @@ import uuid
 from unittest.mock import AsyncMock, MagicMock, patch

 import pytest
+from pylint.checkers.utils import node_type

 from cognee.exceptions import InvalidValueError
 from cognee.modules.search.methods.search import search, specific_search
@@ -68,7 +69,13 @@ async def test_search(
     mock_log_query.assert_called_once_with(query_text, query_type.value, mock_user.id)
     mock_get_document_ids.assert_called_once_with(mock_user.id, datasets)
     mock_specific_search.assert_called_once_with(
-        query_type, query_text, mock_user, system_prompt_path="answer_simple_question.txt", top_k=10
+        query_type,
+        query_text,
+        mock_user,
+        system_prompt_path="answer_simple_question.txt",
+        top_k=10,
+        node_type=None,
+        node_name=None,
     )

     # Only the first two results should be included (doc_id3 is filtered out)