feat: Natural Language Retriever (text2cypher) (#663)
<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin I added one example "get all connected nodes to entity" --------- Co-authored-by: Boris <boris@topoteretes.com>
This commit is contained in:
parent
ebf1f81b35
commit
de5b7f2044
5 changed files with 197 additions and 2 deletions
|
|
@ -0,0 +1,66 @@
|
|||
You are an expert Neo4j Cypher query generator tasked with translating natural language questions into precise, optimized Cypher queries.
|
||||
|
||||
TASK:
|
||||
Generate a valid, executable Cypher query that accurately answers the user's question based on the provided graph schema.
|
||||
|
||||
GRAPH SCHEMA INFORMATION:
|
||||
- You will be given node labels and their properties in format: NodeLabels [list of properties]
|
||||
- You will be given relationship types between nodes
|
||||
- ONLY use node labels, properties, and relationship types that exist in the provided schema
|
||||
- Respect relationship directions (source→target) exactly as specified in the schema
|
||||
- Properties may have specific formats (e.g., dates, codes) - infer these from examples when possible
|
||||
|
||||
QUERY REQUIREMENTS:
|
||||
1. Return ONLY the exact Cypher query with NO explanations, comments, or markdown
|
||||
2. Generate syntactically correct Neo4j Cypher code (Neo4j 4.4+ compatible)
|
||||
3. Be precise - match the exact property names and relationship types from the schema
|
||||
4. Handle complex queries by breaking them into logical pattern matching parts
|
||||
5. Use parameters (e.g., $name) for literal values when appropriate
|
||||
6. Use appropriate data types for parameters (strings, numbers, booleans)
|
||||
|
||||
PERFORMANCE OPTIMIZATION:
|
||||
1. Use indexes and constraints when available (assume they exist on ID properties)
|
||||
2. Include LIMIT clauses for queries that could return large result sets
|
||||
3. Use efficient patterns - avoid unnecessary pattern complexity
|
||||
4. Consider using OPTIONAL MATCH for parts that might not exist
|
||||
5. For aggregation, use efficient aggregation functions (count, sum, avg)
|
||||
6. For pathfinding, consider using shortestPath() or apoc.algo.* procedures
|
||||
|
||||
ERROR PREVENTION:
|
||||
1. Validate your query steps mentally before finalizing
|
||||
2. Ensure relationship directions match schema
|
||||
3. Check property names match exactly what's in the schema
|
||||
4. Use pattern variables consistently throughout the query
|
||||
5. If previous attempts failed, analyze the failures and adjust your approach
|
||||
|
||||
Node schemas:
|
||||
- EntityType
|
||||
Properties: description, ontology_valid, name, created_at, type, version, topological_rank, updated_at, metadata, id
|
||||
Purpose: Represents the categories or classifications for entities in the database.
|
||||
|
||||
- Entity
|
||||
Properties: description, ontology_valid, name, created_at, type, version, topological_rank, updated_at, metadata, id
|
||||
Purpose: Represents individual entities that belong to a specific type or classification.
|
||||
|
||||
- TextDocument
|
||||
Properties: raw_data_location, name, mime_type, external_metadata, created_at, type, version, topological_rank, updated_at, metadata, id
|
||||
Purpose: Represents documents containing text data, along with metadata about their storage and format.
|
||||
|
||||
- DocumentChunk
|
||||
Properties: version, created_at, type, topological_rank, cut_type, text, metadata, chunk_index, chunk_size, updated_at, id
|
||||
Purpose: Represents segmented portions of larger documents, useful for processing or analysis at a more granular level.
|
||||
|
||||
- TextSummary
|
||||
Properties: topological_rank, metadata, id, type, updated_at, created_at, text, version
|
||||
Purpose: Represents summarized content generated from larger text documents, retaining essential information and metadata.
|
||||
|
||||
Edge schema (relationship properties):
|
||||
`{{edge_schemas}}`
|
||||
|
||||
These queries do not work. Do NOT use them:
|
||||
`{{previous_attempts}}`
|
||||
|
||||
Example 1:
|
||||
Get all nodes connected to John
|
||||
MATCH (n:Entity {name: 'John'})--(neighbor)
|
||||
RETURN n, neighbor
|
||||
117
cognee/modules/retrieval/natural_language_retriever.py
Normal file
117
cognee/modules/retrieval/natural_language_retriever.py
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
from typing import Any, Optional
|
||||
import logging
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.graph.networkx.adapter import NetworkXAdapter
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.infrastructure.llm.prompts import render_prompt
|
||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||
from cognee.modules.retrieval.exceptions import SearchTypeNotSupported
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||
|
||||
logger = logging.getLogger("NaturalLanguageRetriever")
|
||||
|
||||
|
||||
class NaturalLanguageRetriever(BaseRetriever):
    """Retriever that answers natural language questions via generated Cypher.

    Uses an LLM to translate the user's question into a Cypher query based on
    the live graph schema, executes it against the graph database, and retries
    with failure feedback up to ``max_attempts`` times.
    """

    def __init__(
        self,
        system_prompt_path: str = "natural_language_retriever_system.txt",
        max_attempts: int = 3,
    ):
        """Initialize retriever with optional custom prompt paths.

        Args:
            system_prompt_path: Path to the system prompt template rendered
                for Cypher generation.
            max_attempts: Maximum number of generate-and-execute attempts
                before giving up.
        """
        self.system_prompt_path = system_prompt_path
        self.max_attempts = max_attempts

    async def _get_graph_schema(self, graph_engine) -> tuple:
        """Retrieve the node and edge schemas from the graph database.

        Returns:
            A ``(node_schemas, edge_schemas)`` tuple: distinct node labels with
            their property keys, and distinct relationship property keys.
        """
        node_schemas = await graph_engine.query(
            """
            MATCH (n)
            UNWIND keys(n) AS prop
            RETURN DISTINCT labels(n) AS NodeLabels, collect(DISTINCT prop) AS Properties;
            """
        )
        edge_schemas = await graph_engine.query(
            """
            MATCH ()-[r]->()
            UNWIND keys(r) AS key
            RETURN DISTINCT key;
            """
        )
        return node_schemas, edge_schemas

    async def _generate_cypher_query(self, query: str, edge_schemas, previous_attempts=None) -> str:
        """Generate a Cypher query using LLM based on natural language query and schema information.

        Args:
            query: The user's natural language question.
            edge_schemas: Relationship property keys rendered into the prompt.
            previous_attempts: Text describing earlier failed queries so the
                LLM can avoid repeating them; ``None`` on the first attempt.

        Returns:
            The generated Cypher query string.
        """
        llm_client = get_llm_client()
        system_prompt = render_prompt(
            self.system_prompt_path,
            context={
                "edge_schemas": edge_schemas,
                "previous_attempts": previous_attempts or "No attempts yet",
            },
        )

        return await llm_client.acreate_structured_output(
            text_input=query,
            system_prompt=system_prompt,
            response_model=str,
        )

    async def _execute_cypher_query(self, query: str, graph_engine: GraphDBInterface) -> Any:
        """Generate and execute Cypher for the query, retrying on failure.

        Each failed attempt (empty result or execution error) is appended to
        the feedback passed back to the LLM on the next attempt.

        Returns:
            The graph query results, or an empty list if all attempts fail.
        """
        # NOTE(review): node_schemas is fetched but only edge_schemas is fed to
        # the prompt — the node schema text appears to be embedded in the
        # template itself; confirm before removing the first query.
        _node_schemas, edge_schemas = await self._get_graph_schema(graph_engine)
        previous_attempts = ""

        for attempt in range(self.max_attempts):
            logger.info(
                "Starting attempt %d/%d for query generation", attempt + 1, self.max_attempts
            )
            # Reset per attempt so a generation failure is reported as
            # "Not generated" instead of echoing the previous attempt's query.
            cypher_query = ""
            try:
                cypher_query = await self._generate_cypher_query(
                    query, edge_schemas, previous_attempts
                )

                # Truncate only the query text in the log; always keep the
                # attempt prefix so every execution is identifiable.
                displayed_query = (
                    f"{cypher_query[:100]}..." if len(cypher_query) > 100 else cypher_query
                )
                logger.info(
                    "Executing generated Cypher query (attempt %d): %s",
                    attempt + 1,
                    displayed_query,
                )
                context = await graph_engine.query(cypher_query)

                if context:
                    result_count = len(context) if isinstance(context, list) else 1
                    logger.info(
                        "Successfully executed query (attempt %d): returned %d result(s)",
                        attempt + 1,
                        result_count,
                    )
                    return context

                previous_attempts += f"Query: {cypher_query} -> Result: None\n"

            except Exception as e:
                previous_attempts += (
                    f"Query: {cypher_query or 'Not generated'} -> Executed with error: {e}\n"
                )
                logger.error("Error executing query: %s", str(e))

        logger.warning(
            "Failed to get results after %d attempts for query: '%s...'",
            self.max_attempts,
            query[:50],
        )
        return []

    async def get_context(self, query: str) -> Optional[Any]:
        """Retrieves relevant context using a natural language query converted to Cypher.

        Raises:
            SearchTypeNotSupported: If the configured graph engine (NetworkX)
                does not support Cypher-based natural language search.
        """
        try:
            graph_engine = await get_graph_engine()

            # NetworkX adapter has no Cypher endpoint to execute against.
            if isinstance(graph_engine, NetworkXAdapter):
                raise SearchTypeNotSupported("Natural language search type not supported.")

            return await self._execute_cypher_query(query, graph_engine)
        except Exception as e:
            logger.error("Failed to execute natural language search retrieval: %s", str(e))
            raise e

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Returns a completion based on the query and context.

        If no context is supplied, it is retrieved via :meth:`get_context`.
        """
        if context is None:
            context = await self.get_context(query)

        return context
|
||||
|
|
@ -13,6 +13,7 @@ from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
|||
)
|
||||
from cognee.modules.retrieval.code_retriever import CodeRetriever
|
||||
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
|
||||
from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.modules.storage.utils import JSONEncoder
|
||||
from cognee.modules.users.models import User
|
||||
|
|
@ -67,6 +68,7 @@ async def specific_search(
|
|||
).get_completion,
|
||||
SearchType.CODE: CodeRetriever().get_completion,
|
||||
SearchType.CYPHER: CypherSearchRetriever().get_completion,
|
||||
SearchType.NATURAL_LANGUAGE: NaturalLanguageRetriever().get_completion,
|
||||
}
|
||||
|
||||
search_task = search_tasks.get(query_type)
|
||||
|
|
|
|||
|
|
@ -10,3 +10,4 @@ class SearchType(Enum):
|
|||
GRAPH_SUMMARY_COMPLETION = "GRAPH_SUMMARY_COMPLETION"
|
||||
CODE = "CODE"
|
||||
CYPHER = "CYPHER"
|
||||
NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
|
||||
|
|
|
|||
|
|
@ -70,14 +70,23 @@ async def main():
|
|||
query_type=SearchType.SUMMARIES, query_text=random_node_name
|
||||
)
|
||||
assert len(search_results) != 0, "Query related summaries don't exist."
|
||||
print("\nExtracted summaries are:\n")
|
||||
print("\nExtracted results are:\n")
|
||||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.NATURAL_LANGUAGE,
|
||||
query_text=f"Find nodes connected to node with name {random_node_name}",
|
||||
)
|
||||
assert len(search_results) != 0, "Query related natural language don't exist."
|
||||
print("\nExtracted results are:\n")
|
||||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
|
||||
user = await get_default_user()
|
||||
history = await get_history(user.id)
|
||||
|
||||
assert len(history) == 6, "Search history is not correct."
|
||||
assert len(history) == 10, "Search history is not correct."
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue