Docstring modules. (#877)
<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

Co-authored-by: Vasilije <8619304+Vasilije1990@users.noreply.github.com>
This commit is contained in:
parent
bb68d6a0df
commit
ff997f48b5
15 changed files with 528 additions and 29 deletions
|
|
@ -6,6 +6,25 @@ from cognee.modules.engine.models import Entity
|
||||||
|
|
||||||
|
|
||||||
class DocumentChunk(DataPoint):
|
class DocumentChunk(DataPoint):
|
||||||
|
"""
|
||||||
|
Represents a chunk of text from a document with associated metadata.
|
||||||
|
|
||||||
|
Public methods include:
|
||||||
|
|
||||||
|
- None; this class defines no public methods of its own.
|
||||||
|
|
||||||
|
Instance variables include:
|
||||||
|
|
||||||
|
- text: The textual content of the chunk.
|
||||||
|
- chunk_size: The size of the chunk.
|
||||||
|
- chunk_index: The index of the chunk in the original document.
|
||||||
|
- cut_type: The type of cut that defined this chunk.
|
||||||
|
- is_part_of: The document to which this chunk belongs.
|
||||||
|
- contains: A list of entities contained within the chunk (default is None).
|
||||||
|
- metadata: A dictionary to hold meta information related to the chunk, including index
|
||||||
|
fields.
|
||||||
|
"""
|
||||||
|
|
||||||
text: str
|
text: str
|
||||||
chunk_size: int
|
chunk_size: int
|
||||||
chunk_index: int
|
chunk_index: int
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,14 @@ from cognee.infrastructure.engine import DataPoint
|
||||||
|
|
||||||
|
|
||||||
class EntityType(DataPoint):
|
class EntityType(DataPoint):
|
||||||
|
"""
|
||||||
|
Represents a type of entity with a name and description.
|
||||||
|
|
||||||
|
This class inherits from DataPoint and includes two primary attributes: `name` and
|
||||||
|
`description`. Additionally, it contains a metadata dictionary that specifies
|
||||||
|
`index_fields` for indexing purposes.
|
||||||
|
"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,5 +7,11 @@ from cognee.infrastructure.databases.vector.pgvector import (
|
||||||
|
|
||||||
|
|
||||||
async def setup():
|
async def setup():
|
||||||
|
"""
|
||||||
|
Set up the necessary databases and tables.
|
||||||
|
|
||||||
|
This function asynchronously creates a relational database and its corresponding tables,
|
||||||
|
followed by creating a PGVector database and its tables.
|
||||||
|
"""
|
||||||
await create_relational_db_and_tables()
|
await create_relational_db_and_tables()
|
||||||
await create_pgvector_db_and_tables()
|
await create_pgvector_db_and_tables()
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,21 @@ logger = get_logger("entity_completion_retriever")
|
||||||
|
|
||||||
|
|
||||||
class EntityCompletionRetriever(BaseRetriever):
|
class EntityCompletionRetriever(BaseRetriever):
|
||||||
"""Retriever that uses entity-based completion for generating responses."""
|
"""
|
||||||
|
Retriever that uses entity-based completion for generating responses.
|
||||||
|
|
||||||
|
Public methods:
|
||||||
|
|
||||||
|
- get_context
|
||||||
|
- get_completion
|
||||||
|
|
||||||
|
Instance variables:
|
||||||
|
|
||||||
|
- extractor
|
||||||
|
- context_provider
|
||||||
|
- user_prompt_path
|
||||||
|
- system_prompt_path
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -26,7 +40,24 @@ class EntityCompletionRetriever(BaseRetriever):
|
||||||
self.system_prompt_path = system_prompt_path
|
self.system_prompt_path = system_prompt_path
|
||||||
|
|
||||||
async def get_context(self, query: str) -> Any:
|
async def get_context(self, query: str) -> Any:
|
||||||
"""Get context using entity extraction and context provider."""
|
"""
|
||||||
|
Get context using entity extraction and context provider.
|
||||||
|
|
||||||
|
Logs the processing of the query and retrieves entities. If entities are extracted, it
|
||||||
|
attempts to retrieve the corresponding context using the context provider. Returns None
|
||||||
|
if no entities or context are found, or logs the error if an exception occurs.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
|
||||||
|
- query (str): The query string for which context is being retrieved.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
|
||||||
|
- Any: The context retrieved from the context provider or None if not found or an
|
||||||
|
error occurred.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
logger.info(f"Processing query: {query[:100]}")
|
logger.info(f"Processing query: {query[:100]}")
|
||||||
|
|
||||||
|
|
@ -47,7 +78,26 @@ class EntityCompletionRetriever(BaseRetriever):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> List[str]:
|
async def get_completion(self, query: str, context: Optional[Any] = None) -> List[str]:
|
||||||
"""Generate completion using provided context or fetch new context."""
|
"""
|
||||||
|
Generate completion using provided context or fetch new context.
|
||||||
|
|
||||||
|
If context is not provided, it fetches context using the query. If no context is
|
||||||
|
available, it returns an error message. Logs an error if completion generation fails due
|
||||||
|
to an exception.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
|
||||||
|
- query (str): The query string for which completion is being generated.
|
||||||
|
- context (Optional[Any]): Optional context to be used for generating completion;
|
||||||
|
fetched if not provided. (default None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
|
||||||
|
- List[str]: A list containing the generated completion or an error message if no
|
||||||
|
relevant entities were found.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
if context is None:
|
if context is None:
|
||||||
context = await self.get_context(query)
|
context = await self.get_context(query)
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,16 @@ from cognee.infrastructure.databases.vector.exceptions.exceptions import Collect
|
||||||
|
|
||||||
|
|
||||||
class ChunksRetriever(BaseRetriever):
|
class ChunksRetriever(BaseRetriever):
|
||||||
"""Retriever for handling document chunk-based searches."""
|
"""
|
||||||
|
Handles document chunk-based searches by retrieving relevant chunks and generating
|
||||||
|
completions from them.
|
||||||
|
|
||||||
|
Public methods:
|
||||||
|
|
||||||
|
- get_context: Retrieves document chunks based on a query.
|
||||||
|
- get_completion: Generates a completion using provided context or retrieves context if
|
||||||
|
not given.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -16,7 +25,22 @@ class ChunksRetriever(BaseRetriever):
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
|
||||||
async def get_context(self, query: str) -> Any:
|
async def get_context(self, query: str) -> Any:
|
||||||
"""Retrieves document chunks context based on the query."""
|
"""
|
||||||
|
Retrieves document chunks context based on the query.
|
||||||
|
|
||||||
|
Searches for document chunks relevant to the specified query using a vector engine.
|
||||||
|
Raises a NoDataError if no data is found in the system.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
|
||||||
|
- query (str): The query string to search for relevant document chunks.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
|
||||||
|
- Any: A list of document chunk payloads retrieved from the search.
|
||||||
|
"""
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -27,7 +51,25 @@ class ChunksRetriever(BaseRetriever):
|
||||||
return [result.payload for result in found_chunks]
|
return [result.payload for result in found_chunks]
|
||||||
|
|
||||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||||
"""Generates a completion using document chunks context."""
|
"""
|
||||||
|
Generates a completion using document chunks context.
|
||||||
|
|
||||||
|
If the context is not provided, it retrieves the context based on the query. Returns the
|
||||||
|
context, which can be used for further processing or generation of outputs.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
|
||||||
|
- query (str): The query string to be used for generating a completion.
|
||||||
|
- context (Optional[Any]): Optional pre-fetched context to use for generating the
|
||||||
|
completion; if None, it retrieves the context for the query. (default None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
|
||||||
|
- Any: The context used for the completion or the retrieved context if none was
|
||||||
|
provided.
|
||||||
|
"""
|
||||||
if context is None:
|
if context is None:
|
||||||
context = await self.get_context(query)
|
context = await self.get_context(query)
|
||||||
return context
|
return context
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,12 @@ class CodeRetriever(BaseRetriever):
|
||||||
"""Retriever for handling code-based searches."""
|
"""Retriever for handling code-based searches."""
|
||||||
|
|
||||||
class CodeQueryInfo(BaseModel):
|
class CodeQueryInfo(BaseModel):
|
||||||
"""Response model for information extraction from the query"""
|
"""
|
||||||
|
Model for representing the result of a query related to code files.
|
||||||
|
|
||||||
|
This class holds a list of filenames and the corresponding source code extracted from a
|
||||||
|
query. It is used to encapsulate response data in a structured format.
|
||||||
|
"""
|
||||||
|
|
||||||
filenames: List[str] = []
|
filenames: List[str] = []
|
||||||
sourcecode: str
|
sourcecode: str
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,13 @@ from cognee.infrastructure.databases.vector.exceptions import CollectionNotFound
|
||||||
|
|
||||||
|
|
||||||
class CompletionRetriever(BaseRetriever):
|
class CompletionRetriever(BaseRetriever):
|
||||||
"""Retriever for handling LLM-based completion searches."""
|
"""
|
||||||
|
Retriever for handling LLM-based completion searches.
|
||||||
|
|
||||||
|
Public methods:
|
||||||
|
- get_context(query: str) -> str
|
||||||
|
- get_completion(query: str, context: Optional[Any] = None) -> Any
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -22,7 +28,24 @@ class CompletionRetriever(BaseRetriever):
|
||||||
self.top_k = top_k if top_k is not None else 1
|
self.top_k = top_k if top_k is not None else 1
|
||||||
|
|
||||||
async def get_context(self, query: str) -> str:
|
async def get_context(self, query: str) -> str:
|
||||||
"""Retrieves relevant document chunks as context."""
|
"""
|
||||||
|
Retrieves relevant document chunks as context.
|
||||||
|
|
||||||
|
Fetches document chunks based on a query from a vector engine and combines their text.
|
||||||
|
Returns an empty string if no chunks are found. Raises NoDataError if the collection is not
|
||||||
|
found.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
|
||||||
|
- query (str): The query string used to search for relevant document chunks.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
|
||||||
|
- str: A string containing the combined text of the retrieved document chunks, or an
|
||||||
|
empty string if none are found.
|
||||||
|
"""
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -38,7 +61,24 @@ class CompletionRetriever(BaseRetriever):
|
||||||
raise NoDataError("No data found in the system, please add data first.") from error
|
raise NoDataError("No data found in the system, please add data first.") from error
|
||||||
|
|
||||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||||
"""Generates an LLM completion using the context."""
|
"""
|
||||||
|
Generates an LLM completion using the context.
|
||||||
|
|
||||||
|
Retrieves context if not provided and generates a completion based on the query and
|
||||||
|
context using an external completion generator.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
|
||||||
|
- query (str): The input query for which the completion is generated.
|
||||||
|
- context (Optional[Any]): Optional context to use for generating the completion; if
|
||||||
|
not provided, it will be retrieved using get_context. (default None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
|
||||||
|
- Any: A list containing the generated completion from the LLM.
|
||||||
|
"""
|
||||||
if context is None:
|
if context is None:
|
||||||
context = await self.get_context(query)
|
context = await self.get_context(query)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,13 @@ logger = get_logger("CypherSearchRetriever")
|
||||||
|
|
||||||
|
|
||||||
class CypherSearchRetriever(BaseRetriever):
|
class CypherSearchRetriever(BaseRetriever):
|
||||||
"""Retriever for handling cypher-based search"""
|
"""
|
||||||
|
Retriever for handling cypher-based search.
|
||||||
|
|
||||||
|
Public methods include:
|
||||||
|
- get_context: Retrieves relevant context using a cypher query.
|
||||||
|
- get_completion: Returns the graph connections context.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -22,7 +28,22 @@ class CypherSearchRetriever(BaseRetriever):
|
||||||
self.system_prompt_path = system_prompt_path
|
self.system_prompt_path = system_prompt_path
|
||||||
|
|
||||||
async def get_context(self, query: str) -> Any:
|
async def get_context(self, query: str) -> Any:
|
||||||
"""Retrieves relevant context using a cypher query."""
|
"""
|
||||||
|
Retrieves relevant context using a cypher query.
|
||||||
|
|
||||||
|
If the graph engine is an instance of NetworkXAdapter, raises SearchTypeNotSupported. If
|
||||||
|
any error occurs during execution, logs the error and raises CypherSearchError.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
|
||||||
|
- query (str): The cypher query used to retrieve context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
|
||||||
|
- Any: The result of the cypher query execution.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
|
|
||||||
|
|
@ -38,7 +59,23 @@ class CypherSearchRetriever(BaseRetriever):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||||
"""Returns the graph connections context."""
|
"""
|
||||||
|
Returns the graph connections context.
|
||||||
|
|
||||||
|
If no context is provided, it retrieves the context using the specified query.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
|
||||||
|
- query (str): The query to retrieve context.
|
||||||
|
- context (Optional[Any]): Optional context to use, otherwise fetched using the
|
||||||
|
query. (default None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
|
||||||
|
- Any: The context, either provided or retrieved.
|
||||||
|
"""
|
||||||
if context is None:
|
if context is None:
|
||||||
context = await self.get_context(query)
|
context = await self.get_context(query)
|
||||||
return context
|
return context
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,21 @@ logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
|
"""
|
||||||
|
Handles graph context completion for question answering tasks, extending context based
|
||||||
|
on retrieved triplets.
|
||||||
|
|
||||||
|
Public methods:
|
||||||
|
- get_completion
|
||||||
|
|
||||||
|
Instance variables:
|
||||||
|
- user_prompt_path
|
||||||
|
- system_prompt_path
|
||||||
|
- top_k
|
||||||
|
- node_type
|
||||||
|
- node_name
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
user_prompt_path: str = "graph_context_for_question.txt",
|
user_prompt_path: str = "graph_context_for_question.txt",
|
||||||
|
|
@ -28,6 +43,30 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
async def get_completion(
|
async def get_completion(
|
||||||
self, query: str, context: Optional[Any] = None, context_extension_rounds=4
|
self, query: str, context: Optional[Any] = None, context_extension_rounds=4
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
|
"""
|
||||||
|
Extends the context for a given query by retrieving related triplets and generating new
|
||||||
|
completions based on them.
|
||||||
|
|
||||||
|
The method runs for a specified number of rounds to enhance context until no new
|
||||||
|
triplets are found or the maximum rounds are reached. It retrieves triplet suggestions
|
||||||
|
based on a generated completion from previous iterations, logging the process of context
|
||||||
|
extension.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
|
||||||
|
- query (str): The input query for which the completion is generated.
|
||||||
|
- context (Optional[Any]): The existing context to use for enhancing the query; if
|
||||||
|
None, it will be initialized from triplets generated for the query. (default None)
|
||||||
|
- context_extension_rounds: The maximum number of rounds to extend the context with
|
||||||
|
new triplets before halting. (default 4)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
|
||||||
|
- List[str]: A list containing the generated answer based on the query and the
|
||||||
|
extended context.
|
||||||
|
"""
|
||||||
triplets = []
|
triplets = []
|
||||||
|
|
||||||
if context is None:
|
if context is None:
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,21 @@ logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
|
"""
|
||||||
|
Handles graph completion by generating responses based on a series of interactions with
|
||||||
|
a language model. This class extends from GraphCompletionRetriever and is designed to
|
||||||
|
manage the retrieval and validation process for user queries, integrating follow-up
|
||||||
|
questions based on reasoning. The public methods are:
|
||||||
|
|
||||||
|
- get_completion
|
||||||
|
|
||||||
|
Instance variables include:
|
||||||
|
- validation_system_prompt_path
|
||||||
|
- validation_user_prompt_path
|
||||||
|
- followup_system_prompt_path
|
||||||
|
- followup_user_prompt_path
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
user_prompt_path: str = "graph_context_for_question.txt",
|
user_prompt_path: str = "graph_context_for_question.txt",
|
||||||
|
|
@ -36,6 +51,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
async def get_completion(
|
async def get_completion(
|
||||||
self, query: str, context: Optional[Any] = None, max_iter=4
|
self, query: str, context: Optional[Any] = None, max_iter=4
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
|
"""
|
||||||
|
Generate completion responses based on a user query and contextual information.
|
||||||
|
|
||||||
|
This method interacts with a language model client to retrieve a structured response,
|
||||||
|
using a series of iterations to refine the answers and generate follow-up questions
|
||||||
|
based on reasoning derived from previous outputs. It raises exceptions if the context
|
||||||
|
retrieval fails or if the model encounters issues in generating outputs.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
|
||||||
|
- query (str): The user's query to be processed and answered.
|
||||||
|
- context (Optional[Any]): Optional context that may assist in answering the query.
|
||||||
|
If not provided, it will be fetched based on the query. (default None)
|
||||||
|
- max_iter: The maximum number of iterations to refine the answer and generate
|
||||||
|
follow-up questions. (default 4)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
|
||||||
|
- List[str]: A list containing the generated answer to the user's query.
|
||||||
|
"""
|
||||||
llm_client = get_llm_client()
|
llm_client = get_llm_client()
|
||||||
followup_question = ""
|
followup_question = ""
|
||||||
triplets = []
|
triplets = []
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,17 @@ from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
|
||||||
|
|
||||||
|
|
||||||
class GraphCompletionRetriever(BaseRetriever):
|
class GraphCompletionRetriever(BaseRetriever):
|
||||||
"""Retriever for handling graph-based completion searches."""
|
"""
|
||||||
|
Retriever for handling graph-based completion searches.
|
||||||
|
|
||||||
|
This class provides methods to retrieve graph nodes and edges, resolve them into a
|
||||||
|
human-readable format, and generate completions based on graph context. Public methods
|
||||||
|
include:
|
||||||
|
- resolve_edges_to_text
|
||||||
|
- get_triplets
|
||||||
|
- get_context
|
||||||
|
- get_completion
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -45,7 +55,19 @@ class GraphCompletionRetriever(BaseRetriever):
|
||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
|
async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
|
||||||
"""Converts retrieved graph edges into a human-readable string format."""
|
"""
|
||||||
|
Converts retrieved graph edges into a human-readable string format.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
|
||||||
|
- retrieved_edges (list): A list of edges retrieved from the graph.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
|
||||||
|
- str: A formatted string representation of the nodes and their connections.
|
||||||
|
"""
|
||||||
nodes = self._get_nodes(retrieved_edges)
|
nodes = self._get_nodes(retrieved_edges)
|
||||||
node_section = "\n".join(
|
node_section = "\n".join(
|
||||||
f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
|
f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
|
||||||
|
|
@ -58,7 +80,19 @@ class GraphCompletionRetriever(BaseRetriever):
|
||||||
return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"
|
return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"
|
||||||
|
|
||||||
async def get_triplets(self, query: str) -> list:
|
async def get_triplets(self, query: str) -> list:
|
||||||
"""Retrieves relevant graph triplets."""
|
"""
|
||||||
|
Retrieves relevant graph triplets based on a query string.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
|
||||||
|
- query (str): The query string used to search for relevant triplets in the graph.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
|
||||||
|
- list: A list of found triplets that match the query.
|
||||||
|
"""
|
||||||
subclasses = get_all_subclasses(DataPoint)
|
subclasses = get_all_subclasses(DataPoint)
|
||||||
vector_index_collections = []
|
vector_index_collections = []
|
||||||
|
|
||||||
|
|
@ -82,7 +116,20 @@ class GraphCompletionRetriever(BaseRetriever):
|
||||||
return found_triplets
|
return found_triplets
|
||||||
|
|
||||||
async def get_context(self, query: str) -> str:
|
async def get_context(self, query: str) -> str:
|
||||||
"""Retrieves and resolves graph triplets into context."""
|
"""
|
||||||
|
Retrieves and resolves graph triplets into context based on a query.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
|
||||||
|
- query (str): The query string used to retrieve context from the graph triplets.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
|
||||||
|
- str: A string representing the resolved context from the retrieved triplets, or an
|
||||||
|
empty string if no triplets are found.
|
||||||
|
"""
|
||||||
triplets = await self.get_triplets(query)
|
triplets = await self.get_triplets(query)
|
||||||
|
|
||||||
if len(triplets) == 0:
|
if len(triplets) == 0:
|
||||||
|
|
@ -91,7 +138,21 @@ class GraphCompletionRetriever(BaseRetriever):
|
||||||
return await self.resolve_edges_to_text(triplets)
|
return await self.resolve_edges_to_text(triplets)
|
||||||
|
|
||||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||||
"""Generates a completion using graph connections context."""
|
"""
|
||||||
|
Generates a completion using graph connections context based on a query.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
|
||||||
|
- query (str): The query string for which a completion is generated.
|
||||||
|
- context (Optional[Any]): Optional context to use for generating the completion; if
|
||||||
|
not provided, context is retrieved based on the query. (default None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
|
||||||
|
- Any: A generated completion based on the query and context provided.
|
||||||
|
"""
|
||||||
if context is None:
|
if context is None:
|
||||||
context = await self.get_context(query)
|
context = await self.get_context(query)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,16 @@ from cognee.modules.retrieval.utils.completion import summarize_text
|
||||||
|
|
||||||
|
|
||||||
class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
|
class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
|
||||||
"""Retriever for handling graph-based completion searches with summarized context."""
|
"""
|
||||||
|
Retriever for handling graph-based completion searches with summarized context.
|
||||||
|
|
||||||
|
This class inherits from the GraphCompletionRetriever and is intended to manage the
|
||||||
|
retrieval of graph edges with an added functionality to summarize the retrieved
|
||||||
|
information efficiently. Public methods include:
|
||||||
|
|
||||||
|
- __init__()
|
||||||
|
- resolve_edges_to_text()
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -27,6 +36,23 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
|
||||||
self.summarize_prompt_path = summarize_prompt_path
|
self.summarize_prompt_path = summarize_prompt_path
|
||||||
|
|
||||||
async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
|
async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
|
||||||
"""Converts retrieved graph edges into a summary without redundancies."""
|
"""
|
||||||
|
Convert retrieved graph edges into a summary without redundancies.
|
||||||
|
|
||||||
|
This asynchronous method processes a list of retrieved edges and summarizes their
|
||||||
|
content using a specified prompt path. It relies on the parent's implementation to
|
||||||
|
convert the edges to text before summarizing. Raises an error if the summarization fails
|
||||||
|
due to an invalid prompt path.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
|
||||||
|
- retrieved_edges (list): List of graph edges retrieved for summarization.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
|
||||||
|
- str: A summary string representing the content of the retrieved edges.
|
||||||
|
"""
|
||||||
direct_text = await super().resolve_edges_to_text(retrieved_edges)
|
direct_text = await super().resolve_edges_to_text(retrieved_edges)
|
||||||
return await summarize_text(direct_text, self.summarize_prompt_path)
|
return await summarize_text(direct_text, self.summarize_prompt_path)
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,17 @@ from cognee.infrastructure.databases.vector.exceptions.exceptions import Collect
|
||||||
|
|
||||||
|
|
||||||
class InsightsRetriever(BaseRetriever):
|
class InsightsRetriever(BaseRetriever):
|
||||||
"""Retriever for handling graph connection-based insights."""
|
"""
|
||||||
|
Retriever for handling graph connection-based insights.
|
||||||
|
|
||||||
|
Public methods include:
|
||||||
|
- get_context
|
||||||
|
- get_completion
|
||||||
|
|
||||||
|
Instance variables include:
|
||||||
|
- exploration_levels
|
||||||
|
- top_k
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, exploration_levels: int = 1, top_k: int = 5):
|
def __init__(self, exploration_levels: int = 1, top_k: int = 5):
|
||||||
"""Initialize retriever with exploration levels and search parameters."""
|
"""Initialize retriever with exploration levels and search parameters."""
|
||||||
|
|
@ -17,7 +27,24 @@ class InsightsRetriever(BaseRetriever):
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
|
||||||
async def get_context(self, query: str) -> list:
|
async def get_context(self, query: str) -> list:
|
||||||
"""Find the neighbours of a given node in the graph."""
|
"""
|
||||||
|
Find neighbours of a given node in the graph.
|
||||||
|
|
||||||
|
If the provided query does not correspond to an existing node,
|
||||||
|
search for similar entities and retrieve their connections.
|
||||||
|
Re-raises NoDataError if no data is found in the system.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
|
||||||
|
- query (str): A string identifier for the node whose neighbours are to be
|
||||||
|
retrieved.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
|
||||||
|
- list: A list of unique connections found for the queried node.
|
||||||
|
"""
|
||||||
if query is None:
|
if query is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -67,7 +94,24 @@ class InsightsRetriever(BaseRetriever):
|
||||||
return unique_node_connections
|
return unique_node_connections
|
||||||
|
|
||||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||||
"""Returns the graph connections context."""
|
"""
|
||||||
|
Returns the graph connections context.
|
||||||
|
|
||||||
|
If a context is not provided, it fetches the context using the query provided.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
|
||||||
|
- query (str): A string identifier used to fetch the context.
|
||||||
|
- context (Optional[Any]): An optional context to use for the completion; if None,
|
||||||
|
it fetches the context based on the query. (default None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
|
||||||
|
- Any: The context used for the completion, which is either provided or fetched
|
||||||
|
based on the query.
|
||||||
|
"""
|
||||||
if context is None:
|
if context is None:
|
||||||
context = await self.get_context(query)
|
context = await self.get_context(query)
|
||||||
return context
|
return context
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,15 @@ logger = logging.getLogger("NaturalLanguageRetriever")
|
||||||
|
|
||||||
|
|
||||||
class NaturalLanguageRetriever(BaseRetriever):
|
class NaturalLanguageRetriever(BaseRetriever):
|
||||||
"""Retriever for handling natural language search"""
|
"""
|
||||||
|
Retriever for handling natural language search.
|
||||||
|
|
||||||
|
Public methods include:
|
||||||
|
|
||||||
|
- get_context: Retrieves relevant context using a natural language query converted to
|
||||||
|
Cypher.
|
||||||
|
- get_completion: Returns a completion based on the query and context.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -97,7 +105,24 @@ class NaturalLanguageRetriever(BaseRetriever):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def get_context(self, query: str) -> Optional[Any]:
|
async def get_context(self, query: str) -> Optional[Any]:
|
||||||
"""Retrieves relevant context using a natural language query converted to Cypher."""
|
"""
|
||||||
|
Retrieves relevant context using a natural language query converted to Cypher.
|
||||||
|
|
||||||
|
This method raises a SearchTypeNotSupported exception if the graph engine does not
|
||||||
|
support natural language search. It also logs errors if the execution of the retrieval
|
||||||
|
fails.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
|
||||||
|
- query (str): The natural language query used to retrieve context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
|
||||||
|
- Optional[Any]: Returns the context retrieved from the graph database based on the
|
||||||
|
query.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
|
|
||||||
|
|
@ -110,7 +135,25 @@ class NaturalLanguageRetriever(BaseRetriever):
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||||
"""Returns a completion based on the query and context."""
|
"""
|
||||||
|
Returns a completion based on the query and context.
|
||||||
|
|
||||||
|
If context is not provided, it retrieves the context using the given query. No
|
||||||
|
exceptions are explicitly raised from this method, but it relies on the get_context
|
||||||
|
method for possible exceptions.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
|
||||||
|
- query (str): The natural language query to get a completion from.
|
||||||
|
- context (Optional[Any]): The context on which to base the completion; if not
|
||||||
|
provided, it will be retrieved using the query. (default None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
|
||||||
|
- Any: Returns the completion derived from the given query and context.
|
||||||
|
"""
|
||||||
if context is None:
|
if context is None:
|
||||||
context = await self.get_context(query)
|
context = await self.get_context(query)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,14 +7,39 @@ from cognee.infrastructure.databases.vector.exceptions.exceptions import Collect
|
||||||
|
|
||||||
|
|
||||||
class SummariesRetriever(BaseRetriever):
|
class SummariesRetriever(BaseRetriever):
|
||||||
"""Retriever for handling summary-based searches."""
|
"""
|
||||||
|
Retriever for handling summary-based searches.
|
||||||
|
|
||||||
|
Public methods:
|
||||||
|
- __init__
|
||||||
|
- get_context
|
||||||
|
- get_completion
|
||||||
|
|
||||||
|
Instance variables:
|
||||||
|
- top_k: int - Number of top summaries to retrieve.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, top_k: int = 5):
|
def __init__(self, top_k: int = 5):
|
||||||
"""Initialize retriever with search parameters."""
|
"""Initialize retriever with search parameters."""
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
|
||||||
async def get_context(self, query: str) -> Any:
|
async def get_context(self, query: str) -> Any:
|
||||||
"""Retrieves summary context based on the query."""
|
"""
|
||||||
|
Retrieves summary context based on the query.
|
||||||
|
|
||||||
|
On encountering a missing collection, raises NoDataError with a message to add data
|
||||||
|
first.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
|
||||||
|
- query (str): The search query for which to retrieve summary context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
|
||||||
|
- Any: A list of payloads from the retrieved summaries.
|
||||||
|
"""
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -27,7 +52,24 @@ class SummariesRetriever(BaseRetriever):
|
||||||
return [summary.payload for summary in summaries_results]
|
return [summary.payload for summary in summaries_results]
|
||||||
|
|
||||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||||
"""Generates a completion using summaries context."""
|
"""
|
||||||
|
Generates a completion using summaries context.
|
||||||
|
|
||||||
|
If no context is provided, retrieves context using the query. Returns the provided
|
||||||
|
context or the retrieved context if none was given.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
|
||||||
|
- query (str): The search query for generating the completion.
|
||||||
|
- context (Optional[Any]): Optional context for the completion; if not provided,
|
||||||
|
will be retrieved based on the query. (default None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
|
||||||
|
- Any: The generated completion context, which is either provided or retrieved.
|
||||||
|
"""
|
||||||
if context is None:
|
if context is None:
|
||||||
context = await self.get_context(query)
|
context = await self.get_context(query)
|
||||||
return context
|
return context
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue