Feat: Adds chain of thought retriever (#864)
<!-- .github/pull_request_template.md --> ## Description Adds chain of thought retriever ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
parent
08bc472b00
commit
e0798ff25f
10 changed files with 309 additions and 2 deletions
|
|
@ -1,5 +1,6 @@
|
|||
from typing import List, Dict
|
||||
from typing import List, Dict, Any
|
||||
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
|
||||
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
||||
GraphSummaryCompletionRetriever,
|
||||
|
|
@ -8,8 +9,9 @@ from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
|||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||
|
||||
|
||||
retriever_options: Dict[str, BaseRetriever] = {
|
||||
retriever_options: Dict[str, Any] = {
|
||||
"cognee_graph_completion": GraphCompletionRetriever,
|
||||
"cognee_graph_completion_cot": GraphCompletionCotRetriever,
|
||||
"cognee_completion": CompletionRetriever,
|
||||
"graph_summary_completion": GraphSummaryCompletionRetriever,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -45,6 +45,8 @@ def run():
|
|||
# Streamlit Dashboard Application Logic
|
||||
# ----------------------------------------------------------------------------
|
||||
def main():
|
||||
metrics_volume.reload()
|
||||
|
||||
st.set_page_config(page_title="Metrics Dashboard", layout="wide")
|
||||
st.title("📊 Cognee Evaluations Dashboard")
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
You are a helpful assistant whose job is to ask exactly one clarifying follow-up question,
|
||||
to collect the missing piece of information needed to fully answer the user’s original query.
|
||||
Respond with the question only (no extra text, no punctuation beyond what’s needed).
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
Based on the following, ask exactly one question that would directly resolve the gap identified in the validation reasoning and allow a valid answer.
|
||||
Think in a way that with the followup question you are exploring a knowledge graph which contains entities, entity types and document chunks
|
||||
|
||||
<QUERY>
|
||||
`{{ query}}`
|
||||
</QUERY>
|
||||
|
||||
<ANSWER>
|
||||
`{{ answer }}`
|
||||
</ANSWER>
|
||||
|
||||
<REASONING>
|
||||
`{{ reasoning }}`
|
||||
</REASONING>
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
You are a helpful agent who are allowed to use only the provided question answer and context.
|
||||
I want to you find reasoning what is missing from the context or why the answer is not answering the question or not correct strictly based on the context.
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
<QUESTION>
|
||||
`{{ query}}`
|
||||
</QUESTION>
|
||||
|
||||
<ANSWER>
|
||||
`{{ answer }}`
|
||||
</ANSWER>
|
||||
|
||||
<CONTEXT>
|
||||
`{{ context }}`
|
||||
</CONTEXT>
|
||||
84
cognee/modules/retrieval/graph_completion_cot_retriever.py
Normal file
84
cognee/modules/retrieval/graph_completion_cot_retriever.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
from typing import Any, Optional, List
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||
def __init__(
|
||||
self,
|
||||
user_prompt_path: str = "graph_context_for_question.txt",
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
validation_user_prompt_path: str = "cot_validation_user_prompt.txt",
|
||||
validation_system_prompt_path: str = "cot_validation_system_prompt.txt",
|
||||
followup_system_prompt_path: str = "cot_followup_system_prompt.txt",
|
||||
followup_user_prompt_path: str = "cot_followup_user_prompt.txt",
|
||||
top_k: Optional[int] = 5,
|
||||
):
|
||||
super().__init__(
|
||||
user_prompt_path=user_prompt_path,
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
)
|
||||
self.validation_system_prompt_path = validation_system_prompt_path
|
||||
self.validation_user_prompt_path = validation_user_prompt_path
|
||||
self.followup_system_prompt_path = followup_system_prompt_path
|
||||
self.followup_user_prompt_path = followup_user_prompt_path
|
||||
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, max_iter=4
|
||||
) -> List[str]:
|
||||
llm_client = get_llm_client()
|
||||
followup_question = ""
|
||||
triplets = []
|
||||
answer = [""]
|
||||
|
||||
for round_idx in range(max_iter + 1):
|
||||
if round_idx == 0:
|
||||
if context is None:
|
||||
context = await self.get_context(query)
|
||||
else:
|
||||
triplets += await self.get_triplets(followup_question)
|
||||
context = await self.resolve_edges_to_text(list(set(triplets)))
|
||||
|
||||
answer = await generate_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
)
|
||||
logger.info(f"Chain-of-thought: round {round_idx} - answer: {answer}")
|
||||
if round_idx < max_iter:
|
||||
valid_args = {"query": query, "answer": answer, "context": context}
|
||||
valid_user_prompt = render_prompt(
|
||||
filename=self.validation_user_prompt_path, context=valid_args
|
||||
)
|
||||
valid_system_prompt = read_query_prompt(
|
||||
prompt_file_name=self.validation_system_prompt_path
|
||||
)
|
||||
|
||||
reasoning = await llm_client.acreate_structured_output(
|
||||
text_input=valid_user_prompt,
|
||||
system_prompt=valid_system_prompt,
|
||||
response_model=str,
|
||||
)
|
||||
followup_args = {"query": query, "answer": answer, "reasoning": reasoning}
|
||||
followup_prompt = render_prompt(
|
||||
filename=self.followup_user_prompt_path, context=followup_args
|
||||
)
|
||||
followup_system = read_query_prompt(
|
||||
prompt_file_name=self.followup_system_prompt_path
|
||||
)
|
||||
|
||||
followup_question = await llm_client.acreate_structured_output(
|
||||
text_input=followup_prompt, system_prompt=followup_system, response_model=str
|
||||
)
|
||||
logger.info(
|
||||
f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
|
||||
)
|
||||
|
||||
return [answer]
|
||||
|
|
@ -20,6 +20,7 @@ from cognee.modules.users.models import User
|
|||
from cognee.modules.users.permissions.methods import get_document_ids_for_user
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from ..operations import log_query, log_result
|
||||
from ...retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||
|
||||
|
||||
async def search(
|
||||
|
|
@ -70,6 +71,10 @@ async def specific_search(
|
|||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
).get_completion,
|
||||
SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
).get_completion,
|
||||
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
|
||||
system_prompt_path=system_prompt_path, top_k=top_k
|
||||
).get_completion,
|
||||
|
|
|
|||
|
|
@ -11,3 +11,4 @@ class SearchType(Enum):
|
|||
CODE = "CODE"
|
||||
CYPHER = "CYPHER"
|
||||
NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
|
||||
GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,183 @@
|
|||
import os
|
||||
import pytest
|
||||
import pathlib
|
||||
from typing import Optional, Union
|
||||
|
||||
import cognee
|
||||
from cognee.low_level import setup, DataPoint
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
|
||||
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||
|
||||
|
||||
class TestGraphCompletionRetriever:
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_completion_context_simple(self):
|
||||
system_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_graph_context"
|
||||
)
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_graph_context"
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
await setup()
|
||||
|
||||
class Company(DataPoint):
|
||||
name: str
|
||||
|
||||
class Person(DataPoint):
|
||||
name: str
|
||||
works_for: Company
|
||||
|
||||
company1 = Company(name="Figma")
|
||||
company2 = Company(name="Canva")
|
||||
person1 = Person(name="Steve Rodger", works_for=company1)
|
||||
person2 = Person(name="Ike Loma", works_for=company1)
|
||||
person3 = Person(name="Jason Statham", works_for=company1)
|
||||
person4 = Person(name="Mike Broski", works_for=company2)
|
||||
person5 = Person(name="Christina Mayer", works_for=company2)
|
||||
|
||||
entities = [company1, company2, person1, person2, person3, person4, person5]
|
||||
|
||||
await add_data_points(entities)
|
||||
|
||||
retriever = GraphCompletionCotRetriever()
|
||||
|
||||
context = await retriever.get_context("Who works at Canva?")
|
||||
|
||||
assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
|
||||
assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
|
||||
|
||||
answer = await retriever.get_completion("Who works at Canva?")
|
||||
|
||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||
"Answer must contain only non-empty strings"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_completion_context_complex(self):
|
||||
system_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
|
||||
)
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
await setup()
|
||||
|
||||
class Company(DataPoint):
|
||||
name: str
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
class Car(DataPoint):
|
||||
brand: str
|
||||
model: str
|
||||
year: int
|
||||
|
||||
class Location(DataPoint):
|
||||
country: str
|
||||
city: str
|
||||
|
||||
class Home(DataPoint):
|
||||
location: Location
|
||||
rooms: int
|
||||
sqm: int
|
||||
|
||||
class Person(DataPoint):
|
||||
name: str
|
||||
works_for: Company
|
||||
owns: Optional[list[Union[Car, Home]]] = None
|
||||
|
||||
company1 = Company(name="Figma")
|
||||
company2 = Company(name="Canva")
|
||||
|
||||
person1 = Person(name="Mike Rodger", works_for=company1)
|
||||
person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
|
||||
|
||||
person2 = Person(name="Ike Loma", works_for=company1)
|
||||
person2.owns = [
|
||||
Car(brand="Tesla", model="Model S", year=2021),
|
||||
Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
|
||||
]
|
||||
|
||||
person3 = Person(name="Jason Statham", works_for=company1)
|
||||
|
||||
person4 = Person(name="Mike Broski", works_for=company2)
|
||||
person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
|
||||
|
||||
person5 = Person(name="Christina Mayer", works_for=company2)
|
||||
person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
|
||||
|
||||
entities = [company1, company2, person1, person2, person3, person4, person5]
|
||||
|
||||
await add_data_points(entities)
|
||||
|
||||
retriever = GraphCompletionCotRetriever(top_k=20)
|
||||
|
||||
context = await retriever.get_context("Who works at Figma?")
|
||||
|
||||
print(context)
|
||||
|
||||
assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
|
||||
assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
|
||||
assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"
|
||||
|
||||
answer = await retriever.get_completion("Who works at Figma?")
|
||||
|
||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||
"Answer must contain only non-empty strings"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_completion_context_on_empty_graph(self):
|
||||
system_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
|
||||
)
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
retriever = GraphCompletionCotRetriever()
|
||||
|
||||
with pytest.raises(DatabaseNotCreatedError):
|
||||
await retriever.get_context("Who works at Figma?")
|
||||
|
||||
await setup()
|
||||
|
||||
context = await retriever.get_context("Who works at Figma?")
|
||||
assert context == "", "Context should be empty on an empty graph"
|
||||
|
||||
answer = await retriever.get_completion("Who works at Figma?")
|
||||
|
||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||
"Answer must contain only non-empty strings"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from asyncio import run
|
||||
|
||||
test = TestGraphCompletionRetriever()
|
||||
|
||||
async def main():
|
||||
await test.test_graph_completion_context_simple()
|
||||
await test.test_graph_completion_context_complex()
|
||||
await test.test_get_graph_completion_context_on_empty_graph()
|
||||
|
||||
run(main())
|
||||
Loading…
Add table
Reference in a new issue