Docstring modules. (#877)
<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. Co-authored-by: Vasilije <8619304+Vasilije1990@users.noreply.github.com>
This commit is contained in:
parent
bb68d6a0df
commit
ff997f48b5
15 changed files with 528 additions and 29 deletions
|
|
@ -6,6 +6,25 @@ from cognee.modules.engine.models import Entity
|
|||
|
||||
|
||||
class DocumentChunk(DataPoint):
|
||||
"""
|
||||
Represents a chunk of text from a document with associated metadata.
|
||||
|
||||
Public methods include:
|
||||
|
||||
- No public methods defined in the provided code.
|
||||
|
||||
Instance variables include:
|
||||
|
||||
- text: The textual content of the chunk.
|
||||
- chunk_size: The size of the chunk.
|
||||
- chunk_index: The index of the chunk in the original document.
|
||||
- cut_type: The type of cut that defined this chunk.
|
||||
- is_part_of: The document to which this chunk belongs.
|
||||
- contains: A list of entities contained within the chunk (default is None).
|
||||
- metadata: A dictionary to hold meta information related to the chunk, including index
|
||||
fields.
|
||||
"""
|
||||
|
||||
text: str
|
||||
chunk_size: int
|
||||
chunk_index: int
|
||||
|
|
|
|||
|
|
@ -2,6 +2,14 @@ from cognee.infrastructure.engine import DataPoint
|
|||
|
||||
|
||||
class EntityType(DataPoint):
|
||||
"""
|
||||
Represents a type of entity with a name and description.
|
||||
|
||||
This class inherits from DataPoint and includes two primary attributes: `name` and
|
||||
`description`. Additionally, it contains a metadata dictionary that specifies
|
||||
`index_fields` for indexing purposes.
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
|
||||
|
|
|
|||
|
|
@ -7,5 +7,11 @@ from cognee.infrastructure.databases.vector.pgvector import (
|
|||
|
||||
|
||||
async def setup():
|
||||
"""
|
||||
Set up the necessary databases and tables.
|
||||
|
||||
This function asynchronously creates a relational database and its corresponding tables,
|
||||
followed by creating a PGVector database and its tables.
|
||||
"""
|
||||
await create_relational_db_and_tables()
|
||||
await create_pgvector_db_and_tables()
|
||||
|
|
|
|||
|
|
@ -11,7 +11,21 @@ logger = get_logger("entity_completion_retriever")
|
|||
|
||||
|
||||
class EntityCompletionRetriever(BaseRetriever):
|
||||
"""Retriever that uses entity-based completion for generating responses."""
|
||||
"""
|
||||
Retriever that uses entity-based completion for generating responses.
|
||||
|
||||
Public methods:
|
||||
|
||||
- get_context
|
||||
- get_completion
|
||||
|
||||
Instance variables:
|
||||
|
||||
- extractor
|
||||
- context_provider
|
||||
- user_prompt_path
|
||||
- system_prompt_path
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -26,7 +40,24 @@ class EntityCompletionRetriever(BaseRetriever):
|
|||
self.system_prompt_path = system_prompt_path
|
||||
|
||||
async def get_context(self, query: str) -> Any:
|
||||
"""Get context using entity extraction and context provider."""
|
||||
"""
|
||||
Get context using entity extraction and context provider.
|
||||
|
||||
Logs the processing of the query and retrieves entities. If entities are extracted, it
|
||||
attempts to retrieve the corresponding context using the context provider. Returns None
|
||||
if no entities or context are found, or logs the error if an exception occurs.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The query string for which context is being retrieved.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- Any: The context retrieved from the context provider or None if not found or an
|
||||
error occurred.
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Processing query: {query[:100]}")
|
||||
|
||||
|
|
@ -47,7 +78,26 @@ class EntityCompletionRetriever(BaseRetriever):
|
|||
return None
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> List[str]:
|
||||
"""Generate completion using provided context or fetch new context."""
|
||||
"""
|
||||
Generate completion using provided context or fetch new context.
|
||||
|
||||
If context is not provided, it fetches context using the query. If no context is
|
||||
available, it returns an error message. Logs an error if completion generation fails due
|
||||
to an exception.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The query string for which completion is being generated.
|
||||
- context (Optional[Any]): Optional context to be used for generating completion;
|
||||
fetched if not provided. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- List[str]: A list containing the generated completion or an error message if no
|
||||
relevant entities were found.
|
||||
"""
|
||||
try:
|
||||
if context is None:
|
||||
context = await self.get_context(query)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,16 @@ from cognee.infrastructure.databases.vector.exceptions.exceptions import Collect
|
|||
|
||||
|
||||
class ChunksRetriever(BaseRetriever):
|
||||
"""Retriever for handling document chunk-based searches."""
|
||||
"""
|
||||
Handles document chunk-based searches by retrieving relevant chunks and generating
|
||||
completions from them.
|
||||
|
||||
Public methods:
|
||||
|
||||
- get_context: Retrieves document chunks based on a query.
|
||||
- get_completion: Generates a completion using provided context or retrieves context if
|
||||
not given.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -16,7 +25,22 @@ class ChunksRetriever(BaseRetriever):
|
|||
self.top_k = top_k
|
||||
|
||||
async def get_context(self, query: str) -> Any:
|
||||
"""Retrieves document chunks context based on the query."""
|
||||
"""
|
||||
Retrieves document chunks context based on the query.
|
||||
|
||||
Searches for document chunks relevant to the specified query using a vector engine.
|
||||
Raises a NoDataError if no data is found in the system.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The query string to search for relevant document chunks.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- Any: A list of document chunk payloads retrieved from the search.
|
||||
"""
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
try:
|
||||
|
|
@ -27,7 +51,25 @@ class ChunksRetriever(BaseRetriever):
|
|||
return [result.payload for result in found_chunks]
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
"""Generates a completion using document chunks context."""
|
||||
"""
|
||||
Generates a completion using document chunks context.
|
||||
|
||||
If the context is not provided, it retrieves the context based on the query. Returns the
|
||||
context, which can be used for further processing or generation of outputs.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The query string to be used for generating a completion.
|
||||
- context (Optional[Any]): Optional pre-fetched context to use for generating the
|
||||
completion; if None, it retrieves the context for the query. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- Any: The context used for the completion or the retrieved context if none was
|
||||
provided.
|
||||
"""
|
||||
if context is None:
|
||||
context = await self.get_context(query)
|
||||
return context
|
||||
|
|
|
|||
|
|
@ -14,7 +14,12 @@ class CodeRetriever(BaseRetriever):
|
|||
"""Retriever for handling code-based searches."""
|
||||
|
||||
class CodeQueryInfo(BaseModel):
|
||||
"""Response model for information extraction from the query"""
|
||||
"""
|
||||
Model for representing the result of a query related to code files.
|
||||
|
||||
This class holds a list of filenames and the corresponding source code extracted from a
|
||||
query. It is used to encapsulate response data in a structured format.
|
||||
"""
|
||||
|
||||
filenames: List[str] = []
|
||||
sourcecode: str
|
||||
|
|
|
|||
|
|
@ -8,7 +8,13 @@ from cognee.infrastructure.databases.vector.exceptions import CollectionNotFound
|
|||
|
||||
|
||||
class CompletionRetriever(BaseRetriever):
|
||||
"""Retriever for handling LLM-based completion searches."""
|
||||
"""
|
||||
Retriever for handling LLM-based completion searches.
|
||||
|
||||
Public methods:
|
||||
- get_context(query: str) -> str
|
||||
- get_completion(query: str, context: Optional[Any] = None) -> Any
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -22,7 +28,24 @@ class CompletionRetriever(BaseRetriever):
|
|||
self.top_k = top_k if top_k is not None else 1
|
||||
|
||||
async def get_context(self, query: str) -> str:
|
||||
"""Retrieves relevant document chunks as context."""
|
||||
"""
|
||||
Retrieves relevant document chunks as context.
|
||||
|
||||
Fetches document chunks based on a query from a vector engine and combines their text.
|
||||
Returns empty string if no chunks are found. Raises NoDataError if the collection is not
|
||||
found.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The query string used to search for relevant document chunks.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- str: A string containing the combined text of the retrieved document chunks, or an
|
||||
empty string if none are found.
|
||||
"""
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
try:
|
||||
|
|
@ -38,7 +61,24 @@ class CompletionRetriever(BaseRetriever):
|
|||
raise NoDataError("No data found in the system, please add data first.") from error
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
"""Generates an LLM completion using the context."""
|
||||
"""
|
||||
Generates an LLM completion using the context.
|
||||
|
||||
Retrieves context if not provided and generates a completion based on the query and
|
||||
context using an external completion generator.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The input query for which the completion is generated.
|
||||
- context (Optional[Any]): Optional context to use for generating the completion; if
|
||||
not provided, it will be retrieved using get_context. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- Any: A list containing the generated completion from the LLM.
|
||||
"""
|
||||
if context is None:
|
||||
context = await self.get_context(query)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,13 @@ logger = get_logger("CypherSearchRetriever")
|
|||
|
||||
|
||||
class CypherSearchRetriever(BaseRetriever):
|
||||
"""Retriever for handling cypher-based search"""
|
||||
"""
|
||||
Retriever for handling cypher-based search.
|
||||
|
||||
Public methods include:
|
||||
- get_context: Retrieves relevant context using a cypher query.
|
||||
- get_completion: Returns the graph connections context.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -22,7 +28,22 @@ class CypherSearchRetriever(BaseRetriever):
|
|||
self.system_prompt_path = system_prompt_path
|
||||
|
||||
async def get_context(self, query: str) -> Any:
|
||||
"""Retrieves relevant context using a cypher query."""
|
||||
"""
|
||||
Retrieves relevant context using a cypher query.
|
||||
|
||||
If the graph engine is an instance of NetworkXAdapter, raises SearchTypeNotSupported. If
|
||||
any error occurs during execution, logs the error and raises CypherSearchError.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The cypher query used to retrieve context.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- Any: The result of the cypher query execution.
|
||||
"""
|
||||
try:
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
|
|
@ -38,7 +59,23 @@ class CypherSearchRetriever(BaseRetriever):
|
|||
return result
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
"""Returns the graph connections context."""
|
||||
"""
|
||||
Returns the graph connections context.
|
||||
|
||||
If no context is provided, it retrieves the context using the specified query.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The query to retrieve context.
|
||||
- context (Optional[Any]): Optional context to use, otherwise fetched using the
|
||||
query. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- Any: The context, either provided or retrieved.
|
||||
"""
|
||||
if context is None:
|
||||
context = await self.get_context(query)
|
||||
return context
|
||||
|
|
|
|||
|
|
@ -9,6 +9,21 @@ logger = get_logger()
|
|||
|
||||
|
||||
class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||
"""
|
||||
Handles graph context completion for question answering tasks, extending context based
|
||||
on retrieved triplets.
|
||||
|
||||
Public methods:
|
||||
- get_completion
|
||||
|
||||
Instance variables:
|
||||
- user_prompt_path
|
||||
- system_prompt_path
|
||||
- top_k
|
||||
- node_type
|
||||
- node_name
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_prompt_path: str = "graph_context_for_question.txt",
|
||||
|
|
@ -28,6 +43,30 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, context_extension_rounds=4
|
||||
) -> List[str]:
|
||||
"""
|
||||
Extends the context for a given query by retrieving related triplets and generating new
|
||||
completions based on them.
|
||||
|
||||
The method runs for a specified number of rounds to enhance context until no new
|
||||
triplets are found or the maximum rounds are reached. It retrieves triplet suggestions
|
||||
based on a generated completion from previous iterations, logging the process of context
|
||||
extension.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The input query for which the completion is generated.
|
||||
- context (Optional[Any]): The existing context to use for enhancing the query; if
|
||||
None, it will be initialized from triplets generated for the query. (default None)
|
||||
- context_extension_rounds: The maximum number of rounds to extend the context with
|
||||
new triplets before halting. (default 4)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- List[str]: A list containing the generated answer based on the query and the
|
||||
extended context.
|
||||
"""
|
||||
triplets = []
|
||||
|
||||
if context is None:
|
||||
|
|
|
|||
|
|
@ -9,6 +9,21 @@ logger = get_logger()
|
|||
|
||||
|
||||
class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||
"""
|
||||
Handles graph completion by generating responses based on a series of interactions with
|
||||
a language model. This class extends from GraphCompletionRetriever and is designed to
|
||||
manage the retrieval and validation process for user queries, integrating follow-up
|
||||
questions based on reasoning. The public methods are:
|
||||
|
||||
- get_completion
|
||||
|
||||
Instance variables include:
|
||||
- validation_system_prompt_path
|
||||
- validation_user_prompt_path
|
||||
- followup_system_prompt_path
|
||||
- followup_user_prompt_path
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_prompt_path: str = "graph_context_for_question.txt",
|
||||
|
|
@ -36,6 +51,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, max_iter=4
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generate completion responses based on a user query and contextual information.
|
||||
|
||||
This method interacts with a language model client to retrieve a structured response,
|
||||
using a series of iterations to refine the answers and generate follow-up questions
|
||||
based on reasoning derived from previous outputs. It raises exceptions if the context
|
||||
retrieval fails or if the model encounters issues in generating outputs.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The user's query to be processed and answered.
|
||||
- context (Optional[Any]): Optional context that may assist in answering the query.
|
||||
If not provided, it will be fetched based on the query. (default None)
|
||||
- max_iter: The maximum number of iterations to refine the answer and generate
|
||||
follow-up questions. (default 4)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- List[str]: A list containing the generated answer to the user's query.
|
||||
"""
|
||||
llm_client = get_llm_client()
|
||||
followup_question = ""
|
||||
triplets = []
|
||||
|
|
|
|||
|
|
@ -11,7 +11,17 @@ from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
|
|||
|
||||
|
||||
class GraphCompletionRetriever(BaseRetriever):
|
||||
"""Retriever for handling graph-based completion searches."""
|
||||
"""
|
||||
Retriever for handling graph-based completion searches.
|
||||
|
||||
This class provides methods to retrieve graph nodes and edges, resolve them into a
|
||||
human-readable format, and generate completions based on graph context. Public methods
|
||||
include:
|
||||
- resolve_edges_to_text
|
||||
- get_triplets
|
||||
- get_context
|
||||
- get_completion
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -45,7 +55,19 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
return nodes
|
||||
|
||||
async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
|
||||
"""Converts retrieved graph edges into a human-readable string format."""
|
||||
"""
|
||||
Converts retrieved graph edges into a human-readable string format.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- retrieved_edges (list): A list of edges retrieved from the graph.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- str: A formatted string representation of the nodes and their connections.
|
||||
"""
|
||||
nodes = self._get_nodes(retrieved_edges)
|
||||
node_section = "\n".join(
|
||||
f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
|
||||
|
|
@ -58,7 +80,19 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"
|
||||
|
||||
async def get_triplets(self, query: str) -> list:
|
||||
"""Retrieves relevant graph triplets."""
|
||||
"""
|
||||
Retrieves relevant graph triplets based on a query string.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The query string used to search for relevant triplets in the graph.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- list: A list of found triplets that match the query.
|
||||
"""
|
||||
subclasses = get_all_subclasses(DataPoint)
|
||||
vector_index_collections = []
|
||||
|
||||
|
|
@ -82,7 +116,20 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
return found_triplets
|
||||
|
||||
async def get_context(self, query: str) -> str:
|
||||
"""Retrieves and resolves graph triplets into context."""
|
||||
"""
|
||||
Retrieves and resolves graph triplets into context based on a query.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The query string used to retrieve context from the graph triplets.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- str: A string representing the resolved context from the retrieved triplets, or an
|
||||
empty string if no triplets are found.
|
||||
"""
|
||||
triplets = await self.get_triplets(query)
|
||||
|
||||
if len(triplets) == 0:
|
||||
|
|
@ -91,7 +138,21 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
return await self.resolve_edges_to_text(triplets)
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
"""Generates a completion using graph connections context."""
|
||||
"""
|
||||
Generates a completion using graph connections context based on a query.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The query string for which a completion is generated.
|
||||
- context (Optional[Any]): Optional context to use for generating the completion; if
|
||||
not provided, context is retrieved based on the query. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- Any: A generated completion based on the query and context provided.
|
||||
"""
|
||||
if context is None:
|
||||
context = await self.get_context(query)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,16 @@ from cognee.modules.retrieval.utils.completion import summarize_text
|
|||
|
||||
|
||||
class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
|
||||
"""Retriever for handling graph-based completion searches with summarized context."""
|
||||
"""
|
||||
Retriever for handling graph-based completion searches with summarized context.
|
||||
|
||||
This class inherits from the GraphCompletionRetriever and is intended to manage the
|
||||
retrieval of graph edges with an added functionality to summarize the retrieved
|
||||
information efficiently. Public methods include:
|
||||
|
||||
- __init__()
|
||||
- resolve_edges_to_text()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -27,6 +36,23 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
|
|||
self.summarize_prompt_path = summarize_prompt_path
|
||||
|
||||
async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
|
||||
"""Converts retrieved graph edges into a summary without redundancies."""
|
||||
"""
|
||||
Convert retrieved graph edges into a summary without redundancies.
|
||||
|
||||
This asynchronous method processes a list of retrieved edges and summarizes their
|
||||
content using a specified prompt path. It relies on the parent's implementation to
|
||||
convert the edges to text before summarizing. Raises an error if the summarization fails
|
||||
due to an invalid prompt path.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- retrieved_edges (list): List of graph edges retrieved for summarization.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- str: A summary string representing the content of the retrieved edges.
|
||||
"""
|
||||
direct_text = await super().resolve_edges_to_text(retrieved_edges)
|
||||
return await summarize_text(direct_text, self.summarize_prompt_path)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,17 @@ from cognee.infrastructure.databases.vector.exceptions.exceptions import Collect
|
|||
|
||||
|
||||
class InsightsRetriever(BaseRetriever):
|
||||
"""Retriever for handling graph connection-based insights."""
|
||||
"""
|
||||
Retriever for handling graph connection-based insights.
|
||||
|
||||
Public methods include:
|
||||
- get_context
|
||||
- get_completion
|
||||
|
||||
Instance variables include:
|
||||
- exploration_levels
|
||||
- top_k
|
||||
"""
|
||||
|
||||
def __init__(self, exploration_levels: int = 1, top_k: int = 5):
|
||||
"""Initialize retriever with exploration levels and search parameters."""
|
||||
|
|
@ -17,7 +27,24 @@ class InsightsRetriever(BaseRetriever):
|
|||
self.top_k = top_k
|
||||
|
||||
async def get_context(self, query: str) -> list:
|
||||
"""Find the neighbours of a given node in the graph."""
|
||||
"""
|
||||
Find neighbours of a given node in the graph.
|
||||
|
||||
If the provided query does not correspond to an existing node,
|
||||
search for similar entities and retrieve their connections.
|
||||
Reraises NoDataError if there is no data found in the system.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): A string identifier for the node whose neighbours are to be
|
||||
retrieved.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- list: A list of unique connections found for the queried node.
|
||||
"""
|
||||
if query is None:
|
||||
return []
|
||||
|
||||
|
|
@ -67,7 +94,24 @@ class InsightsRetriever(BaseRetriever):
|
|||
return unique_node_connections
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
"""Returns the graph connections context."""
|
||||
"""
|
||||
Returns the graph connections context.
|
||||
|
||||
If a context is not provided, it fetches the context using the query provided.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): A string identifier used to fetch the context.
|
||||
- context (Optional[Any]): An optional context to use for the completion; if None,
|
||||
it fetches the context based on the query. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- Any: The context used for the completion, which is either provided or fetched
|
||||
based on the query.
|
||||
"""
|
||||
if context is None:
|
||||
context = await self.get_context(query)
|
||||
return context
|
||||
|
|
|
|||
|
|
@ -12,7 +12,15 @@ logger = logging.getLogger("NaturalLanguageRetriever")
|
|||
|
||||
|
||||
class NaturalLanguageRetriever(BaseRetriever):
|
||||
"""Retriever for handling natural language search"""
|
||||
"""
|
||||
Retriever for handling natural language search.
|
||||
|
||||
Public methods include:
|
||||
|
||||
- get_context: Retrieves relevant context using a natural language query converted to
|
||||
Cypher.
|
||||
- get_completion: Returns a completion based on the query and context.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -97,7 +105,24 @@ class NaturalLanguageRetriever(BaseRetriever):
|
|||
return []
|
||||
|
||||
async def get_context(self, query: str) -> Optional[Any]:
|
||||
"""Retrieves relevant context using a natural language query converted to Cypher."""
|
||||
"""
|
||||
Retrieves relevant context using a natural language query converted to Cypher.
|
||||
|
||||
This method raises a SearchTypeNotSupported exception if the graph engine does not
|
||||
support natural language search. It also logs errors if the execution of the retrieval
|
||||
fails.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The natural language query used to retrieve context.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- Optional[Any]: Returns the context retrieved from the graph database based on the
|
||||
query.
|
||||
"""
|
||||
try:
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
|
|
@ -110,7 +135,25 @@ class NaturalLanguageRetriever(BaseRetriever):
|
|||
raise e
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
"""Returns a completion based on the query and context."""
|
||||
"""
|
||||
Returns a completion based on the query and context.
|
||||
|
||||
If context is not provided, it retrieves the context using the given query. No
|
||||
exceptions are explicitly raised from this method, but it relies on the get_context
|
||||
method for possible exceptions.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The natural language query to get a completion from.
|
||||
- context (Optional[Any]): The context in which to base the completion; if not
|
||||
provided, it will be retrieved using the query. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- Any: Returns the completion derived from the given query and context.
|
||||
"""
|
||||
if context is None:
|
||||
context = await self.get_context(query)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,14 +7,39 @@ from cognee.infrastructure.databases.vector.exceptions.exceptions import Collect
|
|||
|
||||
|
||||
class SummariesRetriever(BaseRetriever):
|
||||
"""Retriever for handling summary-based searches."""
|
||||
"""
|
||||
Retriever for handling summary-based searches.
|
||||
|
||||
Public methods:
|
||||
- __init__
|
||||
- get_context
|
||||
- get_completion
|
||||
|
||||
Instance variables:
|
||||
- top_k: int - Number of top summaries to retrieve.
|
||||
"""
|
||||
|
||||
def __init__(self, top_k: int = 5):
|
||||
"""Initialize retriever with search parameters."""
|
||||
self.top_k = top_k
|
||||
|
||||
async def get_context(self, query: str) -> Any:
|
||||
"""Retrieves summary context based on the query."""
|
||||
"""
|
||||
Retrieves summary context based on the query.
|
||||
|
||||
On encountering a missing collection, raises NoDataError with a message to add data
|
||||
first.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The search query for which to retrieve summary context.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- Any: A list of payloads from the retrieved summaries.
|
||||
"""
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
try:
|
||||
|
|
@ -27,7 +52,24 @@ class SummariesRetriever(BaseRetriever):
|
|||
return [summary.payload for summary in summaries_results]
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
"""Generates a completion using summaries context."""
|
||||
"""
|
||||
Generates a completion using summaries context.
|
||||
|
||||
If no context is provided, retrieves context using the query. Returns the provided
|
||||
context or the retrieved context if none was given.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The search query for generating the completion.
|
||||
- context (Optional[Any]): Optional context for the completion; if not provided,
|
||||
will be retrieved based on the query. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- Any: The generated completion context, which is either provided or retrieved.
|
||||
"""
|
||||
if context is None:
|
||||
context = await self.get_context(query)
|
||||
return context
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue