From 0323ab68f115b206d80fd7040a69e36b9fc583a1 Mon Sep 17 00:00:00 2001 From: Vasilije <8619304+Vasilije1990@users.noreply.github.com> Date: Wed, 10 Jan 2024 23:59:01 +0100 Subject: [PATCH] Fix context retireval --- .../classifiers/classifier.py | 12 ++++++++---- .../database/create_database_tst.py | 4 +++- .../database/graph_database/graph.py | 17 ++++++++--------- .../database/vectordb/basevectordb.py | 2 +- .../database/vectordb/vectordb.py | 2 +- 5 files changed, 21 insertions(+), 16 deletions(-) diff --git a/cognitive_architecture/classifiers/classifier.py b/cognitive_architecture/classifiers/classifier.py index 8a04d72d2..a958f14fd 100644 --- a/cognitive_architecture/classifiers/classifier.py +++ b/cognitive_architecture/classifiers/classifier.py @@ -77,7 +77,7 @@ async def classify_call(query, document_summaries): llm = ChatOpenAI(temperature=0, model=config.model) prompt_classify = ChatPromptTemplate.from_template( - """You are a classifier. Determine what document are relevant for the given query: {query}, Document summaries:{document_summaries}""" + """You are a classifier. Determine what document are relevant for the given query: {query}, Document summaries and ids:{document_summaries}""" ) json_structure = [{ "name": "classifier", @@ -88,6 +88,10 @@ async def classify_call(query, document_summaries): "DocumentSummary": { "type": "string", "description": "The summary of the document and the topic it deals with." + }, + "d_id": { + "type": "string", + "description": "The id of the document" } @@ -98,11 +102,11 @@ async def classify_call(query, document_summaries): arguments_str = classifier_output.additional_kwargs['function_call']['arguments'] print("This is the arguments string", arguments_str) arguments_dict = json.loads(arguments_str) - classfier_value = arguments_dict.get('DocumentSummary', None) + classfier_id = arguments_dict.get('d_id', None) - print("This is the classifier value", classfier_value) + print("This is the classifier id ", classfier_id) - return classfier_value + return classfier_id diff --git a/cognitive_architecture/database/create_database_tst.py b/cognitive_architecture/database/create_database_tst.py index 46339dd16..3af3a42f5 100644 --- a/cognitive_architecture/database/create_database_tst.py +++ b/cognitive_architecture/database/create_database_tst.py @@ -107,4 +107,6 @@ if __name__ == "__main__": create_database(username, password, host, database_name) print(f"Database {database_name} created successfully.") - create_tables(engine) \ No newline at end of file + create_tables(engine) + + diff --git a/cognitive_architecture/database/graph_database/graph.py b/cognitive_architecture/database/graph_database/graph.py index f787eecac..b970da286 100644 --- a/cognitive_architecture/database/graph_database/graph.py +++ b/cognitive_architecture/database/graph_database/graph.py @@ -476,7 +476,6 @@ class Neo4jGraphDB(AbstractGraphDB): return create_statements - async def get_memory_linked_document_summaries(self, user_id: str, memory_type: str = "PublicMemory"): """ Retrieve a list of summaries for all documents associated with a given memory type for a user. @@ -486,7 +485,7 @@ class Neo4jGraphDB(AbstractGraphDB): memory_type (str): The type of memory node ('SemanticMemory' or 'PublicMemory'). Returns: - List[str]: A list of document categories associated with the memory type for the user. + List[Dict[str, Union[str, None]]]: A list of dictionaries containing document summary and d_id. Raises: Exception: If an error occurs during the database query execution. @@ -498,18 +497,18 @@ class Neo4jGraphDB(AbstractGraphDB): try: query = f''' MATCH (user:User {{userId: '{user_id}'}})-[:{relationship}]->(memory:{memory_type})-[:HAS_DOCUMENT]->(document:Document) - RETURN document.summary AS summary + RETURN document.d_id AS d_id, document.summary AS summary ''' logging.info(f"Generated Cypher query: {query}") result = self.query(query) - logging.info("Result: ", result) - return [record.get("summary", "No summary available") for record in result] + logging.info(f"Result: {result}") + return [{"d_id": record.get("d_id", None), "summary": record.get("summary", "No summary available")} for + record in result] except Exception as e: logging.error(f"An error occurred while retrieving document summary: {str(e)}") return None - # async def get_document_categories(self, user_id: str): # """ # Retrieve a list of categories for all documents associated with a given user. @@ -568,13 +567,13 @@ class Neo4jGraphDB(AbstractGraphDB): # logging.error(f"An error occurred while retrieving document IDs: {str(e)}") # return None - async def get_memory_linked_document_ids(self, user_id: str, summary: str, memory_type: str = "PUBLIC"): + async def get_memory_linked_document_ids(self, user_id: str, summary_id: str, memory_type: str = "PublicMemory"): """ Retrieve a list of document IDs for a specific category associated with a given memory type for a user. Args: user_id (str): The unique identifier of the user. - summary (str): The specific document summary to filter by. + summary_id (str): The specific document summary id to filter by. memory_type (str): The type of memory node ('SemanticMemory' or 'PublicMemory'). Returns: @@ -591,7 +590,7 @@ class Neo4jGraphDB(AbstractGraphDB): try: query = f''' MATCH (user:User {{userId: '{user_id}'}})-[:{relationship}]->(memory:{memory_type})-[:HAS_DOCUMENT]->(document:Document) - WHERE apoc.text.fuzzyMatch(document.summary, '{summary}') > 0.8 + WHERE document.d_id = '{summary_id}' RETURN document.d_id AS d_id ''' logging.info(f"Generated Cypher query: {query}") diff --git a/cognitive_architecture/database/vectordb/basevectordb.py b/cognitive_architecture/database/vectordb/basevectordb.py index 8d6d73ddb..8a0c1a6be 100644 --- a/cognitive_architecture/database/vectordb/basevectordb.py +++ b/cognitive_architecture/database/vectordb/basevectordb.py @@ -278,7 +278,7 @@ class BaseMemory: n_of_observations: Optional[int] = 2, ): logging.info(namespace) - logging.info("The search type is %", search_type) + logging.info("The search type is %", str(search_type)) logging.info(params) logging.info(observation) diff --git a/cognitive_architecture/database/vectordb/vectordb.py b/cognitive_architecture/database/vectordb/vectordb.py index 9ff57c77f..251b79c45 100644 --- a/cognitive_architecture/database/vectordb/vectordb.py +++ b/cognitive_architecture/database/vectordb/vectordb.py @@ -199,7 +199,7 @@ class WeaviateVectorDB(VectorDB): client = self.init_weaviate(namespace =self.namespace) if search_type is None: search_type = 'hybrid' - logging.info("The search type is s%", search_type) + logging.info("The search type is s%", (search_type))