Feat: Adds context extension search (#865)
<!-- .github/pull_request_template.md -->
## Description
Adds context extension search.

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
parent
e0798ff25f
commit
d6639217c3
7 changed files with 277 additions and 8 deletions
|
|
@ -1,5 +1,8 @@
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
|
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
|
||||||
|
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
||||||
|
GraphCompletionContextExtensionRetriever,
|
||||||
|
)
|
||||||
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||||
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
||||||
|
|
@ -12,6 +15,7 @@ from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||||
# Lookup table from an engine identifier string to the retriever class
# registered under that identifier.
retriever_options: Dict[str, Any] = {
    # Graph-based completion retrievers.
    "cognee_graph_completion": GraphCompletionRetriever,
    "cognee_graph_completion_cot": GraphCompletionCotRetriever,
    "cognee_graph_completion_context_extension": GraphCompletionContextExtensionRetriever,
    # Remaining completion retrievers.
    "cognee_completion": CompletionRetriever,
    "graph_summary_completion": GraphSummaryCompletionRetriever,
}
|
||||||
|
|
|
||||||
|
|
@ -14,9 +14,7 @@ class EvalConfig(BaseSettings):
|
||||||
|
|
||||||
# Question answering params
|
# Question answering params
|
||||||
answering_questions: bool = True
|
answering_questions: bool = True
|
||||||
qa_engine: str = (
|
qa_engine: str = "cognee_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'
|
||||||
"cognee_completion" # Options: 'cognee_completion' or 'cognee_graph_completion'
|
|
||||||
)
|
|
||||||
|
|
||||||
# Evaluation params
|
# Evaluation params
|
||||||
evaluating_answers: bool = True
|
evaluating_answers: bool = True
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,74 @@
|
||||||
|
from typing import Any, Optional, List
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||||
|
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||||
|
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||||
|
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
    """Graph completion retriever that grows its context over several rounds.

    Each round generates a completion from the current context, uses that
    completion as a new graph search query, and merges the triplets it finds
    into the set collected so far. The final answer is generated from the
    context built out of all collected triplets.
    """

    def __init__(
        self,
        user_prompt_path: str = "graph_context_for_question.txt",
        system_prompt_path: str = "answer_simple_question.txt",
        top_k: Optional[int] = 5,
    ):
        """Forward the prompt paths and ``top_k`` to ``GraphCompletionRetriever``."""
        super().__init__(
            user_prompt_path=user_prompt_path,
            system_prompt_path=system_prompt_path,
            top_k=top_k,
        )

    @staticmethod
    def _unique_in_order(triplets) -> List[Any]:
        """Drop duplicate triplets while keeping first-seen order."""
        # dict.fromkeys relies on the same hash/equality as set() but keeps
        # insertion order, so the rendered context is stable between runs.
        return list(dict.fromkeys(triplets))

    @staticmethod
    def _join_contexts(seed_context: Optional[Any], extension_text: Any) -> Any:
        """Combine a caller-supplied context with the retrieved-triplet text."""
        if seed_context is None:
            return extension_text
        if not extension_text:
            return seed_context
        # NOTE(review): assumes both parts are plain text, as produced by
        # resolve_edges_to_text - confirm if callers pass another type.
        return f"{seed_context}\n{extension_text}"

    async def get_completion(
        self,
        query: str,
        context: Optional[Any] = None,
        context_extension_rounds: int = 4,
    ) -> List[str]:
        """Answer ``query`` after iteratively extending the graph context.

        Args:
            query: The user question.
            context: Optional pre-built context. When omitted, the initial
                context is retrieved from the graph using ``query``. When
                given, it is kept and extended rather than replaced.
            context_extension_rounds: Maximum number of extension rounds.
                Values below 1 skip the extension phase entirely.

        Returns:
            A single-element list holding the generated answer.
        """
        seed_context = context
        triplets: List[Any] = []

        if seed_context is None:
            # De-duplicate up front so the "no new triplets" check below
            # compares unique counts on both sides.
            triplets = self._unique_in_order(await self.get_triplets(query))
            context = await self.resolve_edges_to_text(triplets)

        for round_idx in range(1, context_extension_rounds + 1):
            known_count = len(triplets)

            logger.info(
                f"Context extension: round {round_idx} - generating next graph locational query."
            )
            completion = await generate_completion(
                query=query,
                context=context,
                user_prompt_path=self.user_prompt_path,
                system_prompt_path=self.system_prompt_path,
            )

            # The intermediate completion serves as the next graph search query.
            found = await self.get_triplets(completion)
            triplets = self._unique_in_order([*triplets, *found])
            extension_text = await self.resolve_edges_to_text(triplets)
            # Keep any caller-provided context instead of overwriting it.
            context = self._join_contexts(seed_context, extension_text)

            num_triplets = len(triplets)

            if num_triplets == known_count:
                logger.info(
                    f"Context extension: round {round_idx} – no new triplets found; stopping early."
                )
                break

            logger.info(
                f"Context extension: round {round_idx} - "
                f"number of unique retrieved triplets: {num_triplets}"
            )

        answer = await generate_completion(
            query=query,
            context=context,
            user_prompt_path=self.user_prompt_path,
            system_prompt_path=self.system_prompt_path,
        )

        return [answer]
|
||||||
|
|
@ -11,6 +11,10 @@ from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionR
|
||||||
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
||||||
GraphSummaryCompletionRetriever,
|
GraphSummaryCompletionRetriever,
|
||||||
)
|
)
|
||||||
|
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||||
|
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
||||||
|
GraphCompletionContextExtensionRetriever,
|
||||||
|
)
|
||||||
from cognee.modules.retrieval.code_retriever import CodeRetriever
|
from cognee.modules.retrieval.code_retriever import CodeRetriever
|
||||||
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
|
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
|
||||||
from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
|
from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
|
||||||
|
|
@ -19,8 +23,7 @@ from cognee.modules.storage.utils import JSONEncoder
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.modules.users.permissions.methods import get_document_ids_for_user
|
from cognee.modules.users.permissions.methods import get_document_ids_for_user
|
||||||
from cognee.shared.utils import send_telemetry
|
from cognee.shared.utils import send_telemetry
|
||||||
from ..operations import log_query, log_result
|
from cognee.modules.search.operations import log_query, log_result
|
||||||
from ...retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
|
||||||
|
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
|
|
@ -75,6 +78,10 @@ async def specific_search(
|
||||||
system_prompt_path=system_prompt_path,
|
system_prompt_path=system_prompt_path,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
).get_completion,
|
).get_completion,
|
||||||
|
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever(
|
||||||
|
system_prompt_path=system_prompt_path,
|
||||||
|
top_k=top_k,
|
||||||
|
).get_completion,
|
||||||
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
|
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
|
||||||
system_prompt_path=system_prompt_path, top_k=top_k
|
system_prompt_path=system_prompt_path, top_k=top_k
|
||||||
).get_completion,
|
).get_completion,
|
||||||
|
|
|
||||||
|
|
@ -12,3 +12,4 @@ class SearchType(Enum):
|
||||||
CYPHER = "CYPHER"
|
CYPHER = "CYPHER"
|
||||||
NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
|
NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
|
||||||
GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
|
GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
|
||||||
|
GRAPH_COMPLETION_CONTEXT_EXTENSION = "GRAPH_COMPLETION_CONTEXT_EXTENSION"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,185 @@
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import pathlib
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import cognee
|
||||||
|
from cognee.low_level import setup, DataPoint
|
||||||
|
from cognee.tasks.storage import add_data_points
|
||||||
|
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
|
||||||
|
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
||||||
|
GraphCompletionContextExtensionRetriever,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGraphCompletionRetriever:
    """Integration tests for ``GraphCompletionContextExtensionRetriever``.

    Every test points cognee at its own storage directories next to this
    file, prunes existing data and system state, and then exercises both
    ``get_context`` and ``get_completion``.
    """

    @staticmethod
    async def _reset_storage(system_subdir: str, data_subdir: str) -> None:
        """Point cognee at fresh directories beside this file and prune them."""
        tests_dir = pathlib.Path(__file__).parent
        cognee.config.system_root_directory(os.path.join(tests_dir, system_subdir))
        cognee.config.data_root_directory(os.path.join(tests_dir, data_subdir))

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)

    @staticmethod
    def _check_answer(answer) -> None:
        """Assert that ``answer`` is a list made only of non-blank strings."""
        assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
        assert all(isinstance(item, str) and item.strip() for item in answer), (
            "Answer must contain only non-empty strings"
        )

    @pytest.mark.asyncio
    async def test_graph_completion_extension_context_simple(self):
        """Two companies with five employees: Canva's staff must be retrieved."""
        await self._reset_storage(
            ".cognee_system/test_graph_context",
            ".data_storage/test_graph_context",
        )
        await setup()

        class Company(DataPoint):
            name: str

        class Person(DataPoint):
            name: str
            works_for: Company

        figma = Company(name="Figma")
        canva = Company(name="Canva")
        employees = [
            Person(name=full_name, works_for=employer)
            for full_name, employer in (
                ("Steve Rodger", figma),
                ("Ike Loma", figma),
                ("Jason Statham", figma),
                ("Mike Broski", canva),
                ("Christina Mayer", canva),
            )
        ]

        await add_data_points([figma, canva, *employees])

        question = "Who works at Canva?"
        retriever = GraphCompletionContextExtensionRetriever()

        context = await retriever.get_context(question)

        assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
        assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"

        self._check_answer(await retriever.get_completion(question))

    @pytest.mark.asyncio
    async def test_graph_completion_extension_context_complex(self):
        """Richer graph (cars, homes, locations): Figma's staff must be retrieved."""
        await self._reset_storage(
            ".cognee_system/test_graph_completion_context",
            ".data_storage/test_graph_completion_context",
        )
        await setup()

        class Company(DataPoint):
            name: str
            metadata: dict = {"index_fields": ["name"]}

        class Car(DataPoint):
            brand: str
            model: str
            year: int

        class Location(DataPoint):
            country: str
            city: str

        class Home(DataPoint):
            location: Location
            rooms: int
            sqm: int

        class Person(DataPoint):
            name: str
            works_for: Company
            owns: Optional[list[Union[Car, Home]]] = None

        figma = Company(name="Figma")
        canva = Company(name="Canva")

        mike_rodger = Person(name="Mike Rodger", works_for=figma)
        mike_rodger.owns = [Car(brand="Toyota", model="Camry", year=2020)]

        ike_loma = Person(name="Ike Loma", works_for=figma)
        ike_loma.owns = [
            Car(brand="Tesla", model="Model S", year=2021),
            Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
        ]

        jason_statham = Person(name="Jason Statham", works_for=figma)

        mike_broski = Person(name="Mike Broski", works_for=canva)
        mike_broski.owns = [Car(brand="Ford", model="Mustang", year=1978)]

        christina_mayer = Person(name="Christina Mayer", works_for=canva)
        christina_mayer.owns = [Car(brand="Honda", model="Civic", year=2023)]

        await add_data_points(
            [
                figma,
                canva,
                mike_rodger,
                ike_loma,
                jason_statham,
                mike_broski,
                christina_mayer,
            ]
        )

        question = "Who works at Figma?"
        retriever = GraphCompletionContextExtensionRetriever(top_k=20)

        context = await retriever.get_context(question)

        print(context)

        assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
        assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
        assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"

        self._check_answer(await retriever.get_completion(question))

    @pytest.mark.asyncio
    async def test_get_graph_completion_extension_context_on_empty_graph(self):
        """Before ``setup()`` retrieval raises; afterwards the context is empty."""
        await self._reset_storage(
            ".cognee_system/test_graph_completion_context",
            ".data_storage/test_graph_completion_context",
        )

        question = "Who works at Figma?"
        retriever = GraphCompletionContextExtensionRetriever()

        # setup() has not been called yet, so the retrieval is expected to fail.
        with pytest.raises(DatabaseNotCreatedError):
            await retriever.get_context(question)

        await setup()

        context = await retriever.get_context(question)
        assert context == "", "Context should be empty on an empty graph"

        self._check_answer(await retriever.get_completion(question))
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running the suite directly as a script, without the pytest runner.
if __name__ == "__main__":
    from asyncio import run

    test = TestGraphCompletionRetriever()

    async def main():
        """Run every test coroutine of the class in sequence."""
        # The method names must match the ones defined on the class above;
        # the earlier names without "_extension_" do not exist there.
        await test.test_graph_completion_extension_context_simple()
        await test.test_graph_completion_extension_context_complex()
        await test.test_get_graph_completion_extension_context_on_empty_graph()

    run(main())
|
||||||
|
|
@ -12,7 +12,7 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet
|
||||||
|
|
||||||
class TestGraphCompletionRetriever:
|
class TestGraphCompletionRetriever:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_graph_completion_context_simple(self):
|
async def test_graph_completion_cot_context_simple(self):
|
||||||
system_directory_path = os.path.join(
|
system_directory_path = os.path.join(
|
||||||
pathlib.Path(__file__).parent, ".cognee_system/test_graph_context"
|
pathlib.Path(__file__).parent, ".cognee_system/test_graph_context"
|
||||||
)
|
)
|
||||||
|
|
@ -60,7 +60,7 @@ class TestGraphCompletionRetriever:
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_graph_completion_context_complex(self):
|
async def test_graph_completion_cot_context_complex(self):
|
||||||
system_directory_path = os.path.join(
|
system_directory_path = os.path.join(
|
||||||
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
|
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
|
||||||
)
|
)
|
||||||
|
|
@ -139,7 +139,7 @@ class TestGraphCompletionRetriever:
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_graph_completion_context_on_empty_graph(self):
|
async def test_get_graph_completion_cot_context_on_empty_graph(self):
|
||||||
system_directory_path = os.path.join(
|
system_directory_path = os.path.join(
|
||||||
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
|
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue