Deprecate SearchType.INSIGHTS, replace all references to default search type - SearchType.GRAPH_COMPLETION

This commit is contained in:
Daulet Amirkhanov 2025-10-07 16:03:12 +01:00
parent 583923903c
commit 63a1463073
25 changed files with 31 additions and 411 deletions

View file

@ -255,7 +255,7 @@ async def cognify(
# 2. Get entity relationships and connections
relationships = await cognee.search(
"connections between concepts",
query_type=SearchType.INSIGHTS
query_type=SearchType.GRAPH_COMPLETION
)
# 3. Find relevant document chunks

View file

@ -148,7 +148,7 @@ async def cognify(
# 2. Get entity relationships and connections
relationships = await cognee.search(
"connections between concepts",
query_type=SearchType.INSIGHTS
query_type=SearchType.GRAPH_COMPLETION
)
# 3. Find relevant document chunks

View file

@ -1,133 +0,0 @@
import asyncio
from typing import Any, Optional
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
logger = get_logger("InsightsRetriever")
class InsightsRetriever(BaseGraphRetriever):
"""
Retriever for handling graph connection-based insights.
Public methods include:
- get_context
- get_completion
Instance variables include:
- exploration_levels
- top_k
"""
def __init__(self, exploration_levels: int = 1, top_k: Optional[int] = 5):
"""Initialize retriever with exploration levels and search parameters."""
self.exploration_levels = exploration_levels
self.top_k = top_k
async def get_context(self, query: str) -> list:
"""
Find neighbours of a given node in the graph.
If the provided query does not correspond to an existing node,
search for similar entities and retrieve their connections.
Reraises NoDataError if there is no data found in the system.
Parameters:
-----------
- query (str): A string identifier for the node whose neighbours are to be
retrieved.
Returns:
--------
- list: A list of unique connections found for the queried node.
"""
if query is None:
return []
node_id = query
graph_engine = await get_graph_engine()
exact_node = await graph_engine.extract_node(node_id)
if exact_node is not None and "id" in exact_node:
node_connections = await graph_engine.get_connections(str(exact_node["id"]))
else:
vector_engine = get_vector_engine()
try:
results = await asyncio.gather(
vector_engine.search("Entity_name", query_text=query, limit=self.top_k),
vector_engine.search("EntityType_name", query_text=query, limit=self.top_k),
)
except CollectionNotFoundError as error:
logger.error("Entity collections not found")
raise NoDataError("No data found in the system, please add data first.") from error
results = [*results[0], *results[1]]
relevant_results = [result for result in results if result.score < 0.5][: self.top_k]
if len(relevant_results) == 0:
return []
node_connections_results = await asyncio.gather(
*[graph_engine.get_connections(result.id) for result in relevant_results]
)
node_connections = []
for neighbours in node_connections_results:
node_connections.extend(neighbours)
unique_node_connections_map = {}
unique_node_connections = []
for node_connection in node_connections:
if "id" not in node_connection[0] or "id" not in node_connection[2]:
continue
unique_id = f"{node_connection[0]['id']} {node_connection[1]['relationship_name']} {node_connection[2]['id']}"
if unique_id not in unique_node_connections_map:
unique_node_connections_map[unique_id] = True
unique_node_connections.append(node_connection)
return unique_node_connections
# return [
# Edge(
# node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
# node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
# attributes={
# **connection[1],
# "relationship_type": connection[1]["relationship_name"],
# },
# )
# for connection in unique_node_connections
# ]
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
"""
Returns the graph connections context.
If a context is not provided, it fetches the context using the query provided.
Parameters:
-----------
- query (str): A string identifier used to fetch the context.
- context (Optional[Any]): An optional context to use for the completion; if None,
it fetches the context based on the query. (default None)
Returns:
--------
- Any: The context used for the completion, which is either provided or fetched
based on the query.
"""
if context is None:
context = await self.get_context(query)
return context

View file

@ -9,7 +9,6 @@ from cognee.modules.search.exceptions import UnsupportedSearchTypeError
# Retrievers
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
from cognee.modules.retrieval.insights_retriever import InsightsRetriever
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@ -44,10 +43,6 @@ async def get_search_type_tools(
SummariesRetriever(top_k=top_k).get_completion,
SummariesRetriever(top_k=top_k).get_context,
],
SearchType.INSIGHTS: [
InsightsRetriever(top_k=top_k).get_completion,
InsightsRetriever(top_k=top_k).get_context,
],
SearchType.CHUNKS: [
ChunksRetriever(top_k=top_k).get_completion,
ChunksRetriever(top_k=top_k).get_context,

View file

@ -3,7 +3,6 @@ from enum import Enum
class SearchType(Enum):
SUMMARIES = "SUMMARIES"
INSIGHTS = "INSIGHTS"
CHUNKS = "CHUNKS"
RAG_COMPLETION = "RAG_COMPLETION"
GRAPH_COMPLETION = "GRAPH_COMPLETION"

View file

@ -159,7 +159,7 @@ async def main():
random_node_name = random_node.payload["text"]
search_results = await cognee.search(
query_type=SearchType.INSIGHTS, query_text=random_node_name
query_type=SearchType.GRAPH_COMPLETION, query_text=random_node_name
)
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")

View file

@ -61,7 +61,7 @@ async def main():
random_node_name = random_node.payload["text"]
search_results = await cognee.search(
query_type=SearchType.INSIGHTS, query_text=random_node_name
query_type=SearchType.GRAPH_COMPLETION, query_text=random_node_name
)
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")

View file

@ -157,7 +157,7 @@ async def main():
random_node_name = random_node.payload["text"]
search_results = await cognee.search(
query_type=SearchType.INSIGHTS, query_text=random_node_name
query_type=SearchType.GRAPH_COMPLETION, query_text=random_node_name
)
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")

View file

@ -51,7 +51,7 @@ async def main():
random_node_name = random_node.payload["text"]
search_results = await cognee.search(
query_type=SearchType.INSIGHTS, query_text=random_node_name
query_type=SearchType.GRAPH_COMPLETION, query_text=random_node_name
)
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")

View file

@ -56,7 +56,7 @@ async def main():
random_node_name = random_node.payload["text"]
search_results = await cognee.search(
query_type=SearchType.INSIGHTS, query_text=random_node_name
query_type=SearchType.GRAPH_COMPLETION, query_text=random_node_name
)
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")

View file

@ -56,7 +56,7 @@ async def main():
random_node_name = random_node.payload["text"]
search_results = await cognee.search(
query_type=SearchType.INSIGHTS, query_text=random_node_name
query_type=SearchType.GRAPH_COMPLETION, query_text=random_node_name
)
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")

View file

@ -60,7 +60,7 @@ async def main():
random_node_name = random_node.payload["text"]
search_results = await cognee.search(
query_type=SearchType.INSIGHTS, query_text=random_node_name
query_type=SearchType.GRAPH_COMPLETION, query_text=random_node_name
)
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")

View file

@ -167,7 +167,7 @@ async def main():
random_node_name = random_node.payload["text"]
search_results = await cognee.search(
query_type=SearchType.INSIGHTS, query_text=random_node_name
query_type=SearchType.GRAPH_COMPLETION, query_text=random_node_name
)
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")

View file

@ -65,7 +65,7 @@ async def main():
random_node_name = random_node.payload["text"]
search_results = await cognee.search(
query_type=SearchType.INSIGHTS, query_text=random_node_name
query_type=SearchType.GRAPH_COMPLETION, query_text=random_node_name
)
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")

View file

@ -47,7 +47,7 @@ async def main():
random_node_name = random_node.payload["text"]
search_results = await cognee.search(
query_type=SearchType.INSIGHTS, query_text=random_node_name
query_type=SearchType.GRAPH_COMPLETION, query_text=random_node_name
)
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")

View file

@ -1,251 +0,0 @@
import os
import pytest
import pathlib
import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.modules.engine.models import Entity, EntityType
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.insights_retriever import InsightsRetriever
class TestInsightsRetriever:
@pytest.mark.asyncio
async def test_insights_context_simple(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_insights_context_simple"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_insights_context_simple"
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
entityTypePerson = EntityType(
name="Person",
description="An individual",
)
person1 = Entity(
name="Steve Rodger",
is_a=entityTypePerson,
description="An American actor, comedian, and filmmaker",
)
person2 = Entity(
name="Mike Broski",
is_a=entityTypePerson,
description="Financial advisor and philanthropist",
)
person3 = Entity(
name="Christina Mayer",
is_a=entityTypePerson,
description="Maker of next generation of iconic American music videos",
)
entityTypeCompany = EntityType(
name="Company",
description="An organization that operates on an annual basis",
)
company1 = Entity(
name="Apple",
is_a=entityTypeCompany,
description="An American multinational technology company headquartered in Cupertino, California",
)
company2 = Entity(
name="Google",
is_a=entityTypeCompany,
description="An American multinational technology company that specializes in Internet-related services and products",
)
company3 = Entity(
name="Facebook",
is_a=entityTypeCompany,
description="An American social media, messaging, and online platform",
)
entities = [person1, person2, person3, company1, company2, company3]
await add_data_points(entities)
retriever = InsightsRetriever()
context = await retriever.get_context("Mike")
assert context[0][0]["name"] == "Mike Broski", "Failed to get Mike Broski"
@pytest.mark.asyncio
async def test_insights_context_complex(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_insights_context_complex"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_insights_context_complex"
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
entityTypePerson = EntityType(
name="Person",
description="An individual",
)
person1 = Entity(
name="Steve Rodger",
is_a=entityTypePerson,
description="An American actor, comedian, and filmmaker",
)
person2 = Entity(
name="Mike Broski",
is_a=entityTypePerson,
description="Financial advisor and philanthropist",
)
person3 = Entity(
name="Christina Mayer",
is_a=entityTypePerson,
description="Maker of next generation of iconic American music videos",
)
person4 = Entity(
name="Jason Statham",
is_a=entityTypePerson,
description="An American actor",
)
person5 = Entity(
name="Mike Tyson",
is_a=entityTypePerson,
description="A former professional boxer from the United States",
)
entityTypeCompany = EntityType(
name="Company",
description="An organization that operates on an annual basis",
)
company1 = Entity(
name="Apple",
is_a=entityTypeCompany,
description="An American multinational technology company headquartered in Cupertino, California",
)
company2 = Entity(
name="Google",
is_a=entityTypeCompany,
description="An American multinational technology company that specializes in Internet-related services and products",
)
company3 = Entity(
name="Facebook",
is_a=entityTypeCompany,
description="An American social media, messaging, and online platform",
)
entities = [person1, person2, person3, company1, company2, company3]
await add_data_points(entities)
graph_engine = await get_graph_engine()
await graph_engine.add_edges(
[
(
(str)(person1.id),
(str)(company1.id),
"works_for",
dict(
relationship_name="works_for",
source_node_id=person1.id,
target_node_id=company1.id,
),
),
(
(str)(person2.id),
(str)(company2.id),
"works_for",
dict(
relationship_name="works_for",
source_node_id=person2.id,
target_node_id=company2.id,
),
),
(
(str)(person3.id),
(str)(company3.id),
"works_for",
dict(
relationship_name="works_for",
source_node_id=person3.id,
target_node_id=company3.id,
),
),
(
(str)(person4.id),
(str)(company1.id),
"works_for",
dict(
relationship_name="works_for",
source_node_id=person4.id,
target_node_id=company1.id,
),
),
(
(str)(person5.id),
(str)(company1.id),
"works_for",
dict(
relationship_name="works_for",
source_node_id=person5.id,
target_node_id=company1.id,
),
),
]
)
retriever = InsightsRetriever(top_k=20)
context = await retriever.get_context("Christina")
assert context[0][0]["name"] == "Christina Mayer", "Failed to get Christina Mayer"
@pytest.mark.asyncio
async def test_insights_context_on_empty_graph(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_insights_context_on_empty_graph"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_insights_context_on_empty_graph"
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
retriever = InsightsRetriever()
with pytest.raises(NoDataError):
await retriever.get_context("Christina Mayer")
vector_engine = get_vector_engine()
await vector_engine.create_collection("Entity_name", payload_schema=Entity)
await vector_engine.create_collection("EntityType_name", payload_schema=EntityType)
context = await retriever.get_context("Christina Mayer")
assert context == [], "Returned context should be empty on an empty graph"

View file

@ -34,7 +34,7 @@ class CogneeConfig(QABenchmarkConfig):
system_prompt_path: str = "answer_simple_question_benchmark2.txt"
# Search parameters (fallback if not using eval framework)
search_type: SearchType = SearchType.INSIGHTS
search_type: SearchType = SearchType.GRAPH_COMPLETION
# Clean slate on initialization
clean_start: bool = True

View file

@ -57,7 +57,9 @@ async def main():
# Now let's perform some searches
# 1. Search for insights related to "ChromaDB"
insights_results = await cognee.search(query_type=SearchType.INSIGHTS, query_text="ChromaDB")
insights_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION, query_text="ChromaDB"
)
print("\nInsights about ChromaDB:")
for result in insights_results:
print(f"- {result}")

View file

@ -55,7 +55,9 @@ async def main():
# Now let's perform some searches
# 1. Search for insights related to "KuzuDB"
insights_results = await cognee.search(query_type=SearchType.INSIGHTS, query_text="KuzuDB")
insights_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION, query_text="KuzuDB"
)
print("\nInsights about KuzuDB:")
for result in insights_results:
print(f"- {result}")

View file

@ -64,7 +64,9 @@ async def main():
# Now let's perform some searches
# 1. Search for insights related to "Neo4j"
insights_results = await cognee.search(query_type=SearchType.INSIGHTS, query_text="Neo4j")
insights_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION, query_text="Neo4j"
)
print("\nInsights about Neo4j:")
for result in insights_results:
print(f"- {result}")

View file

@ -79,7 +79,7 @@ async def main():
# Now let's perform some searches
# 1. Search for insights related to "Neptune Analytics"
insights_results = await cognee.search(
query_type=SearchType.INSIGHTS, query_text="Neptune Analytics"
query_type=SearchType.GRAPH_COMPLETION, query_text="Neptune Analytics"
)
print("\n========Insights about Neptune Analytics========:")
for result in insights_results:

View file

@ -69,7 +69,9 @@ async def main():
# Now let's perform some searches
# 1. Search for insights related to "PGVector"
insights_results = await cognee.search(query_type=SearchType.INSIGHTS, query_text="PGVector")
insights_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION, query_text="PGVector"
)
print("\nInsights about PGVector:")
for result in insights_results:
print(f"- {result}")

View file

@ -50,7 +50,9 @@ async def main():
query_text = "Tell me about NLP"
print(f"Searching cognee for insights with query: '{query_text}'")
# Query cognee for insights on the added text
search_results = await cognee.search(query_type=SearchType.INSIGHTS, query_text=query_text)
search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION, query_text=query_text
)
print("Search results:")
# Display results

View file

@ -1795,7 +1795,7 @@
}
],
"source": [
"search_results = await cognee.search(query_type=SearchType.INSIGHTS, query_text=node_name)\n",
"search_results = await cognee.search(query_type=SearchType.GRAPH_COMPLETION, query_text=node_name)\n",
"print(\"\\n\\nExtracted sentences are:\\n\")\n",
"for result in search_results:\n",
" print(f\"{result}\\n\")"

View file

@ -295,7 +295,7 @@
"cell_type": "code",
"source": [
"# Search graph insights\n",
"insights_results = await search(query_text=\"Neptune Analytics\", query_type=SearchType.INSIGHTS)\n",
"insights_results = await search(query_text=\"Neptune Analytics\", query_type=SearchType.GRAPH_COMPLETION)\n",
"print(\"\\nInsights about Neptune Analytics:\")\n",
"for result in insights_results:\n",
" src_node = result[0].get(\"name\", result[0][\"type\"])\n",