Fix context retrieval

This commit is contained in:
Vasilije 2024-01-10 23:59:01 +01:00
parent 7afe9a7e45
commit 0323ab68f1
5 changed files with 21 additions and 16 deletions

View file

@ -77,7 +77,7 @@ async def classify_call(query, document_summaries):
llm = ChatOpenAI(temperature=0, model=config.model)
prompt_classify = ChatPromptTemplate.from_template(
"""You are a classifier. Determine what document are relevant for the given query: {query}, Document summaries:{document_summaries}"""
"""You are a classifier. Determine what document are relevant for the given query: {query}, Document summaries and ids:{document_summaries}"""
)
json_structure = [{
"name": "classifier",
@ -88,6 +88,10 @@ async def classify_call(query, document_summaries):
"DocumentSummary": {
"type": "string",
"description": "The summary of the document and the topic it deals with."
},
"d_id": {
"type": "string",
"description": "The id of the document"
}
@ -98,11 +102,11 @@ async def classify_call(query, document_summaries):
arguments_str = classifier_output.additional_kwargs['function_call']['arguments']
print("This is the arguments string", arguments_str)
arguments_dict = json.loads(arguments_str)
classfier_value = arguments_dict.get('DocumentSummary', None)
classfier_id = arguments_dict.get('d_id', None)
print("This is the classifier value", classfier_value)
print("This is the classifier id ", classfier_id)
return classfier_value
return classfier_id

View file

@ -107,4 +107,6 @@ if __name__ == "__main__":
create_database(username, password, host, database_name)
print(f"Database {database_name} created successfully.")
create_tables(engine)
create_tables(engine)

View file

@ -476,7 +476,6 @@ class Neo4jGraphDB(AbstractGraphDB):
return create_statements
async def get_memory_linked_document_summaries(self, user_id: str, memory_type: str = "PublicMemory"):
"""
Retrieve a list of summaries for all documents associated with a given memory type for a user.
@ -486,7 +485,7 @@ class Neo4jGraphDB(AbstractGraphDB):
memory_type (str): The type of memory node ('SemanticMemory' or 'PublicMemory').
Returns:
List[str]: A list of document categories associated with the memory type for the user.
List[Dict[str, Union[str, None]]]: A list of dictionaries containing document summary and d_id.
Raises:
Exception: If an error occurs during the database query execution.
@ -498,18 +497,18 @@ class Neo4jGraphDB(AbstractGraphDB):
try:
query = f'''
MATCH (user:User {{userId: '{user_id}'}})-[:{relationship}]->(memory:{memory_type})-[:HAS_DOCUMENT]->(document:Document)
RETURN document.summary AS summary
RETURN document.d_id AS d_id, document.summary AS summary
'''
logging.info(f"Generated Cypher query: {query}")
result = self.query(query)
logging.info("Result: ", result)
return [record.get("summary", "No summary available") for record in result]
logging.info(f"Result: {result}")
return [{"d_id": record.get("d_id", None), "summary": record.get("summary", "No summary available")} for
record in result]
except Exception as e:
logging.error(f"An error occurred while retrieving document summary: {str(e)}")
return None
# async def get_document_categories(self, user_id: str):
# """
# Retrieve a list of categories for all documents associated with a given user.
@ -568,13 +567,13 @@ class Neo4jGraphDB(AbstractGraphDB):
# logging.error(f"An error occurred while retrieving document IDs: {str(e)}")
# return None
async def get_memory_linked_document_ids(self, user_id: str, summary: str, memory_type: str = "PUBLIC"):
async def get_memory_linked_document_ids(self, user_id: str, summary_id: str, memory_type: str = "PublicMemory"):
"""
Retrieve a list of document IDs for a specific category associated with a given memory type for a user.
Args:
user_id (str): The unique identifier of the user.
summary (str): The specific document summary to filter by.
summary_id (str): The specific document summary id to filter by.
memory_type (str): The type of memory node ('SemanticMemory' or 'PublicMemory').
Returns:
@ -591,7 +590,7 @@ class Neo4jGraphDB(AbstractGraphDB):
try:
query = f'''
MATCH (user:User {{userId: '{user_id}'}})-[:{relationship}]->(memory:{memory_type})-[:HAS_DOCUMENT]->(document:Document)
WHERE apoc.text.fuzzyMatch(document.summary, '{summary}') > 0.8
WHERE document.d_id = '{summary_id}'
RETURN document.d_id AS d_id
'''
logging.info(f"Generated Cypher query: {query}")

View file

@ -278,7 +278,7 @@ class BaseMemory:
n_of_observations: Optional[int] = 2,
):
logging.info(namespace)
logging.info("The search type is %", search_type)
logging.info("The search type is %", str(search_type))
logging.info(params)
logging.info(observation)

View file

@ -199,7 +199,7 @@ class WeaviateVectorDB(VectorDB):
client = self.init_weaviate(namespace =self.namespace)
if search_type is None:
search_type = 'hybrid'
logging.info("The search type is s%", search_type)
logging.info("The search type is s%", (search_type))