feat: implement the first version of graph-based completion in search

This commit is contained in:
hajdul88 2025-01-08 16:10:29 +01:00
parent 35892f974b
commit d39140f28b
6 changed files with 60 additions and 6 deletions

View file

@ -15,12 +15,14 @@ from cognee.tasks.chunks import query_chunks
from cognee.tasks.graph import query_graph_connections
from cognee.tasks.summarization import query_summaries
from cognee.tasks.completion import query_completion
from cognee.tasks.completion import graph_query_completion
class SearchType(Enum):
    """Kinds of search that can be requested; each value equals its member name."""

    # Mapped to ``query_summaries`` in the search-task dispatch table.
    SUMMARIES = "SUMMARIES"
    # Mapped to ``query_graph_connections``.
    INSIGHTS = "INSIGHTS"
    # Mapped to ``query_chunks``.
    CHUNKS = "CHUNKS"
    # Mapped to ``query_completion``.
    COMPLETION = "COMPLETION"
    # Mapped to ``graph_query_completion``.
    GRAPH_COMPLETION = "GRAPH_COMPLETION"
async def search(query_type: SearchType, query_text: str, user: User = None,
datasets: Union[list[str], str, None] = None) -> list:
@ -58,6 +60,7 @@ async def specific_search(query_type: SearchType, query: str, user) -> list:
SearchType.INSIGHTS: query_graph_connections,
SearchType.CHUNKS: query_chunks,
SearchType.COMPLETION: query_completion,
SearchType.GRAPH_COMPLETION: graph_query_completion
}
search_task = search_tasks.get(query_type)

View file

@ -0,0 +1 @@
Answer the question using the provided context. If the provided context is not connected to the question, just answer "The provided knowledge base does not contain the answer to the question". Be as brief as possible.

View file

@ -0,0 +1,2 @@
The question is: `{{ question }}`
and here is the provided context, a set of relationships from a knowledge graph separated by \n---\n, each represented as a node1 -- relation -- node2 triplet: `{{ context }}`

View file

@ -1 +1,2 @@
from .query_completion import query_completion
from .query_completion import query_completion
from .graph_query_completion import graph_query_completion

View file

@ -0,0 +1,45 @@
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.tasks.completion.exceptions import NoRelevantDataFound
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
def _node_label(node):
    """Return the label used for a node: its ``text`` attribute, else its ``name``."""
    attributes = node.attributes
    # ``.get`` rather than ``['text']``: a node without a ``text`` key must fall
    # back to ``name`` instead of raising KeyError. An empty/None ``text`` also
    # falls through to ``name``.
    return attributes.get("text") or attributes.get("name")


def retrieved_edges_to_string(retrieved_edges):
    """Render retrieved edges as a single context string.

    Each edge is written as ``<node1> -- <relationship_type> -- <node2>``, where
    a node is represented by its ``text`` attribute or, when that is missing or
    empty, by its ``name`` attribute. The per-edge strings are joined with a
    separator line consisting of ``---``.

    Parameters:
    - retrieved_edges: Iterable of edge objects exposing ``node1``, ``node2``
      (each with an ``attributes`` mapping) and an ``attributes`` mapping that
      contains ``relationship_type``.

    Returns:
    - str: The joined edge descriptions; an empty string when there are no edges.

    Raises:
    - KeyError: If an edge has no ``relationship_type`` attribute.
    """
    edge_strings = [
        f"{_node_label(edge.node1)} -- {edge.attributes['relationship_type']} -- {_node_label(edge.node2)}"
        for edge in retrieved_edges
    ]
    return "\n---\n".join(edge_strings)
async def graph_query_completion(query: str) -> list:
    """
    Answer a query with the LLM, using retrieved knowledge-graph triplets as context.

    Parameters:
    - query (str): The query string to compute.

    Returns:
    - list: Answer to the query, as a single-element list holding the LLM output.

    Raises:
    - NoRelevantDataFound: If the triplet search yields no results.
    """
    found_triplets = await brute_force_triplet_search(query, top_k=5)

    # Truthiness check: treats a None result the same as an empty one instead
    # of failing inside len().
    if not found_triplets:
        raise NoRelevantDataFound

    # The triplets are flattened to text and injected into the user prompt template.
    args = {
        "question": query,
        "context": retrieved_edges_to_string(found_triplets),
    }
    user_prompt = render_prompt("graph_context_for_question.txt", args)
    system_prompt = read_query_prompt("answer_simple_question_restricted.txt")

    llm_client = get_llm_client()
    computed_answer = await llm_client.acreate_structured_output(
        text_input=user_prompt,
        system_prompt=system_prompt,
        response_model=str,
    )

    return [computed_answer]

View file

@ -1,8 +1,8 @@
import cognee
import asyncio
import logging
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.retrieval.brute_force_triplet_search import format_triplets
from cognee.api.v1.search import SearchType
from cognee.shared.utils import setup_logging
job_1 = """
@ -184,11 +184,13 @@ async def main(enable_steps):
# Step 4: Query insights
if enable_steps.get("retriever"):
results = await brute_force_triplet_search('Who has the most experience with graphic design?')
print(format_triplets(results))
search_results = await cognee.search(
SearchType.GRAPH_COMPLETION, query_text='Who has experience in design tools?'
)
print(search_results)
if __name__ == '__main__':
setup_logging(logging.ERROR)
setup_logging(logging.INFO)
rebuild_kg = True
retrieve = True