feat: Natural Language Retriever (text2cypher) (#663)
<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin I added one example "get all connected nodes to entity" --------- Co-authored-by: Boris <boris@topoteretes.com>
This commit is contained in:
parent
ebf1f81b35
commit
de5b7f2044
5 changed files with 197 additions and 2 deletions
|
|
@ -0,0 +1,66 @@
|
||||||
|
You are an expert Neo4j Cypher query generator tasked with translating natural language questions into precise, optimized Cypher queries.
|
||||||
|
|
||||||
|
TASK:
|
||||||
|
Generate a valid, executable Cypher query that accurately answers the user's question based on the provided graph schema.
|
||||||
|
|
||||||
|
GRAPH SCHEMA INFORMATION:
|
||||||
|
- You will be given node labels and their properties in format: NodeLabels [list of properties]
|
||||||
|
- You will be given relationship types between nodes
|
||||||
|
- ONLY use node labels, properties, and relationship types that exist in the provided schema
|
||||||
|
- Respect relationship directions (source→target) exactly as specified in the schema
|
||||||
|
- Properties may have specific formats (e.g., dates, codes) - infer these from examples when possible
|
||||||
|
|
||||||
|
QUERY REQUIREMENTS:
|
||||||
|
1. Return ONLY the exact Cypher query with NO explanations, comments, or markdown
|
||||||
|
2. Generate syntactically correct Neo4j Cypher code (Neo4j 4.4+ compatible)
|
||||||
|
3. Be precise - match the exact property names and relationship types from the schema
|
||||||
|
4. Handle complex queries by breaking them into logical pattern matching parts
|
||||||
|
5. Use parameters (e.g., $name) for literal values when appropriate
|
||||||
|
6. Use appropriate data types for parameters (strings, numbers, booleans)
|
||||||
|
|
||||||
|
PERFORMANCE OPTIMIZATION:
|
||||||
|
1. Use indexes and constraints when available (assume they exist on ID properties)
|
||||||
|
2. Include LIMIT clauses for queries that could return large result sets
|
||||||
|
3. Use efficient patterns - avoid unnecessary pattern complexity
|
||||||
|
4. Consider using OPTIONAL MATCH for parts that might not exist
|
||||||
|
5. For aggregation, use efficient aggregation functions (count, sum, avg)
|
||||||
|
6. For pathfinding, consider using shortestPath() or apoc.algo.* procedures
|
||||||
|
|
||||||
|
ERROR PREVENTION:
|
||||||
|
1. Validate your query steps mentally before finalizing
|
||||||
|
2. Ensure relationship directions match schema
|
||||||
|
3. Check property names match exactly what's in the schema
|
||||||
|
4. Use pattern variables consistently throughout the query
|
||||||
|
5. If previous attempts failed, analyze the failures and adjust your approach
|
||||||
|
|
||||||
|
Node schemas:
|
||||||
|
- EntityType
|
||||||
|
Properties: description, ontology_valid, name, created_at, type, version, topological_rank, updated_at, metadata, id
|
||||||
|
Purpose: Represents the categories or classifications for entities in the database.
|
||||||
|
|
||||||
|
- Entity
|
||||||
|
Properties: description, ontology_valid, name, created_at, type, version, topological_rank, updated_at, metadata, id
|
||||||
|
Purpose: Represents individual entities that belong to a specific type or classification.
|
||||||
|
|
||||||
|
- TextDocument
|
||||||
|
Properties: raw_data_location, name, mime_type, external_metadata, created_at, type, version, topological_rank, updated_at, metadata, id
|
||||||
|
Purpose: Represents documents containing text data, along with metadata about their storage and format.
|
||||||
|
|
||||||
|
- DocumentChunk
|
||||||
|
Properties: version, created_at, type, topological_rank, cut_type, text, metadata, chunk_index, chunk_size, updated_at, id
|
||||||
|
Purpose: Represents segmented portions of larger documents, useful for processing or analysis at a more granular level.
|
||||||
|
|
||||||
|
- TextSummary
|
||||||
|
Properties: topological_rank, metadata, id, type, updated_at, created_at, text, version
|
||||||
|
Purpose: Represents summarized content generated from larger text documents, retaining essential information and metadata.
|
||||||
|
|
||||||
|
Edge schema (relationship properties):
|
||||||
|
`{{edge_schemas}}`
|
||||||
|
|
||||||
|
These queries don't work. Do NOT use them:
|
||||||
|
`{{previous_attempts}}`
|
||||||
|
|
||||||
|
Example 1:
|
||||||
|
Get all nodes connected to John
|
||||||
|
MATCH (n:Entity {name: 'John'})--(neighbor)
|
||||||
|
RETURN n, neighbor
|
||||||
117
cognee/modules/retrieval/natural_language_retriever.py
Normal file
117
cognee/modules/retrieval/natural_language_retriever.py
Normal file
|
|
@ -0,0 +1,117 @@
|
||||||
|
from typing import Any, Optional
|
||||||
|
import logging
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
from cognee.infrastructure.databases.graph.networkx.adapter import NetworkXAdapter
|
||||||
|
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||||
|
from cognee.infrastructure.llm.prompts import render_prompt
|
||||||
|
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||||
|
from cognee.modules.retrieval.exceptions import SearchTypeNotSupported
|
||||||
|
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||||
|
|
||||||
|
# Module-level logger keyed by retriever name (not __name__) so log output
# is grouped under "NaturalLanguageRetriever" regardless of module path.
logger = logging.getLogger("NaturalLanguageRetriever")
|
||||||
|
|
||||||
|
|
||||||
|
class NaturalLanguageRetriever(BaseRetriever):
    """Retriever that answers natural language questions via LLM-generated Cypher.

    The user's question plus the live graph schema are rendered into a system
    prompt; the LLM produces a Cypher query which is executed against the graph
    database. Failed attempts (empty results or execution errors) are fed back
    to the LLM as context for up to ``max_attempts`` retries.
    """

    def __init__(
        self,
        system_prompt_path: str = "natural_language_retriever_system.txt",
        max_attempts: int = 3,
    ):
        """Initialize retriever with optional custom prompt paths.

        Args:
            system_prompt_path: Template file used to build the LLM system prompt.
            max_attempts: Maximum generate-and-execute attempts before giving up.
        """
        self.system_prompt_path = system_prompt_path
        self.max_attempts = max_attempts

    async def _get_graph_schema(self, graph_engine) -> tuple:
        """Retrieve the node and edge schemas from the graph database.

        Returns:
            Tuple ``(node_schemas, edge_schemas)``: distinct node labels with
            their collected property keys, and distinct relationship property keys.
        """
        node_schemas = await graph_engine.query(
            """
            MATCH (n)
            UNWIND keys(n) AS prop
            RETURN DISTINCT labels(n) AS NodeLabels, collect(DISTINCT prop) AS Properties;
            """
        )
        edge_schemas = await graph_engine.query(
            """
            MATCH ()-[r]->()
            UNWIND keys(r) AS key
            RETURN DISTINCT key;
            """
        )
        return node_schemas, edge_schemas

    async def _generate_cypher_query(self, query: str, edge_schemas, previous_attempts=None) -> str:
        """Generate a Cypher query using LLM based on natural language query and schema information.

        Args:
            query: The user's natural language question.
            edge_schemas: Relationship property keys injected into the prompt.
            previous_attempts: Text describing earlier failed queries so the LLM
                avoids repeating them; falsy values render as "No attempts yet".
        """
        llm_client = get_llm_client()
        system_prompt = render_prompt(
            self.system_prompt_path,
            context={
                "edge_schemas": edge_schemas,
                "previous_attempts": previous_attempts or "No attempts yet",
            },
        )

        return await llm_client.acreate_structured_output(
            text_input=query,
            system_prompt=system_prompt,
            response_model=str,
        )

    async def _execute_cypher_query(self, query: str, graph_engine: GraphDBInterface) -> Any:
        """Execute the natural language query against the graph with multiple attempts.

        Each failed attempt (empty result or raised error) is appended to the
        feedback string handed back to the LLM on the next attempt.

        Returns:
            The (non-empty) query result on success, or ``[]`` after exhausting
            ``self.max_attempts`` attempts.
        """
        node_schemas, edge_schemas = await self._get_graph_schema(graph_engine)
        previous_attempts = ""
        cypher_query = ""

        for attempt in range(self.max_attempts):
            logger.info(
                "Starting attempt %d/%d for query generation", attempt + 1, self.max_attempts
            )
            try:
                cypher_query = await self._generate_cypher_query(
                    query, edge_schemas, previous_attempts
                )

                # Truncate long queries for the log while always keeping the
                # attempt prefix (the original dropped it for short queries).
                preview = (
                    f"{cypher_query[:100]}..." if len(cypher_query) > 100 else cypher_query
                )
                logger.info(
                    "Executing generated Cypher query (attempt %d): %s", attempt + 1, preview
                )
                context = await graph_engine.query(cypher_query)

                if context:
                    result_count = len(context) if isinstance(context, list) else 1
                    logger.info(
                        "Successfully executed query (attempt %d): returned %d result(s)",
                        attempt + 1,
                        result_count,
                    )
                    return context

                previous_attempts += f"Query: {cypher_query} -> Result: None\n"

            except Exception as e:
                # cypher_query is pre-initialized to "", so the original
                # `'cypher_query' in locals()` guard was dead code.
                previous_attempts += (
                    f"Query: {cypher_query or 'Not generated'} -> Executed with error: {e}\n"
                )
                logger.error("Error executing query: %s", e)

        logger.warning(
            "Failed to get results after %d attempts for query: '%s...'",
            self.max_attempts,
            query[:50],
        )
        return []

    async def get_context(self, query: str) -> Optional[Any]:
        """Retrieves relevant context using a natural language query converted to Cypher.

        Raises:
            SearchTypeNotSupported: If the configured graph engine is a
                NetworkX adapter (no Cypher support).
        """
        try:
            graph_engine = await get_graph_engine()

            if isinstance(graph_engine, NetworkXAdapter):
                raise SearchTypeNotSupported("Natural language search type not supported.")

            return await self._execute_cypher_query(query, graph_engine)
        except Exception as e:
            logger.error("Failed to execute natural language search retrieval: %s", str(e))
            # Bare raise preserves the original traceback (raise e would reset it).
            raise

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Returns a completion based on the query and context.

        If no context is supplied, it is fetched via ``get_context``.
        """
        if context is None:
            context = await self.get_context(query)

        return context
|
||||||
|
|
@ -13,6 +13,7 @@ from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
||||||
)
|
)
|
||||||
from cognee.modules.retrieval.code_retriever import CodeRetriever
|
from cognee.modules.retrieval.code_retriever import CodeRetriever
|
||||||
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
|
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
|
||||||
|
from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
|
||||||
from cognee.modules.search.types import SearchType
|
from cognee.modules.search.types import SearchType
|
||||||
from cognee.modules.storage.utils import JSONEncoder
|
from cognee.modules.storage.utils import JSONEncoder
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
|
|
@ -67,6 +68,7 @@ async def specific_search(
|
||||||
).get_completion,
|
).get_completion,
|
||||||
SearchType.CODE: CodeRetriever().get_completion,
|
SearchType.CODE: CodeRetriever().get_completion,
|
||||||
SearchType.CYPHER: CypherSearchRetriever().get_completion,
|
SearchType.CYPHER: CypherSearchRetriever().get_completion,
|
||||||
|
SearchType.NATURAL_LANGUAGE: NaturalLanguageRetriever().get_completion,
|
||||||
}
|
}
|
||||||
|
|
||||||
search_task = search_tasks.get(query_type)
|
search_task = search_tasks.get(query_type)
|
||||||
|
|
|
||||||
|
|
@ -10,3 +10,4 @@ class SearchType(Enum):
|
||||||
GRAPH_SUMMARY_COMPLETION = "GRAPH_SUMMARY_COMPLETION"
|
GRAPH_SUMMARY_COMPLETION = "GRAPH_SUMMARY_COMPLETION"
|
||||||
CODE = "CODE"
|
CODE = "CODE"
|
||||||
CYPHER = "CYPHER"
|
CYPHER = "CYPHER"
|
||||||
|
NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
|
||||||
|
|
|
||||||
|
|
@ -70,14 +70,23 @@ async def main():
|
||||||
query_type=SearchType.SUMMARIES, query_text=random_node_name
|
query_type=SearchType.SUMMARIES, query_text=random_node_name
|
||||||
)
|
)
|
||||||
assert len(search_results) != 0, "Query related summaries don't exist."
|
assert len(search_results) != 0, "Query related summaries don't exist."
|
||||||
print("\nExtracted summaries are:\n")
|
print("\nExtracted results are:\n")
|
||||||
|
for result in search_results:
|
||||||
|
print(f"{result}\n")
|
||||||
|
|
||||||
|
search_results = await cognee.search(
|
||||||
|
query_type=SearchType.NATURAL_LANGUAGE,
|
||||||
|
query_text=f"Find nodes connected to node with name {random_node_name}",
|
||||||
|
)
|
||||||
|
assert len(search_results) != 0, "Query related natural language don't exist."
|
||||||
|
print("\nExtracted results are:\n")
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
print(f"{result}\n")
|
print(f"{result}\n")
|
||||||
|
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
history = await get_history(user.id)
|
history = await get_history(user.id)
|
||||||
|
|
||||||
assert len(history) == 6, "Search history is not correct."
|
assert len(history) == 10, "Search history is not correct."
|
||||||
|
|
||||||
await cognee.prune.prune_data()
|
await cognee.prune.prune_data()
|
||||||
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue