Fix context retireval
This commit is contained in:
parent
7afe9a7e45
commit
0323ab68f1
5 changed files with 21 additions and 16 deletions
|
|
@ -77,7 +77,7 @@ async def classify_call(query, document_summaries):
|
||||||
|
|
||||||
llm = ChatOpenAI(temperature=0, model=config.model)
|
llm = ChatOpenAI(temperature=0, model=config.model)
|
||||||
prompt_classify = ChatPromptTemplate.from_template(
|
prompt_classify = ChatPromptTemplate.from_template(
|
||||||
"""You are a classifier. Determine what document are relevant for the given query: {query}, Document summaries:{document_summaries}"""
|
"""You are a classifier. Determine what document are relevant for the given query: {query}, Document summaries and ids:{document_summaries}"""
|
||||||
)
|
)
|
||||||
json_structure = [{
|
json_structure = [{
|
||||||
"name": "classifier",
|
"name": "classifier",
|
||||||
|
|
@ -88,6 +88,10 @@ async def classify_call(query, document_summaries):
|
||||||
"DocumentSummary": {
|
"DocumentSummary": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The summary of the document and the topic it deals with."
|
"description": "The summary of the document and the topic it deals with."
|
||||||
|
},
|
||||||
|
"d_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The id of the document"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -98,11 +102,11 @@ async def classify_call(query, document_summaries):
|
||||||
arguments_str = classifier_output.additional_kwargs['function_call']['arguments']
|
arguments_str = classifier_output.additional_kwargs['function_call']['arguments']
|
||||||
print("This is the arguments string", arguments_str)
|
print("This is the arguments string", arguments_str)
|
||||||
arguments_dict = json.loads(arguments_str)
|
arguments_dict = json.loads(arguments_str)
|
||||||
classfier_value = arguments_dict.get('DocumentSummary', None)
|
classfier_id = arguments_dict.get('d_id', None)
|
||||||
|
|
||||||
print("This is the classifier value", classfier_value)
|
print("This is the classifier id ", classfier_id)
|
||||||
|
|
||||||
return classfier_value
|
return classfier_id
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -107,4 +107,6 @@ if __name__ == "__main__":
|
||||||
create_database(username, password, host, database_name)
|
create_database(username, password, host, database_name)
|
||||||
print(f"Database {database_name} created successfully.")
|
print(f"Database {database_name} created successfully.")
|
||||||
|
|
||||||
create_tables(engine)
|
create_tables(engine)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -476,7 +476,6 @@ class Neo4jGraphDB(AbstractGraphDB):
|
||||||
|
|
||||||
return create_statements
|
return create_statements
|
||||||
|
|
||||||
|
|
||||||
async def get_memory_linked_document_summaries(self, user_id: str, memory_type: str = "PublicMemory"):
|
async def get_memory_linked_document_summaries(self, user_id: str, memory_type: str = "PublicMemory"):
|
||||||
"""
|
"""
|
||||||
Retrieve a list of summaries for all documents associated with a given memory type for a user.
|
Retrieve a list of summaries for all documents associated with a given memory type for a user.
|
||||||
|
|
@ -486,7 +485,7 @@ class Neo4jGraphDB(AbstractGraphDB):
|
||||||
memory_type (str): The type of memory node ('SemanticMemory' or 'PublicMemory').
|
memory_type (str): The type of memory node ('SemanticMemory' or 'PublicMemory').
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[str]: A list of document categories associated with the memory type for the user.
|
List[Dict[str, Union[str, None]]]: A list of dictionaries containing document summary and d_id.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
Exception: If an error occurs during the database query execution.
|
Exception: If an error occurs during the database query execution.
|
||||||
|
|
@ -498,18 +497,18 @@ class Neo4jGraphDB(AbstractGraphDB):
|
||||||
try:
|
try:
|
||||||
query = f'''
|
query = f'''
|
||||||
MATCH (user:User {{userId: '{user_id}'}})-[:{relationship}]->(memory:{memory_type})-[:HAS_DOCUMENT]->(document:Document)
|
MATCH (user:User {{userId: '{user_id}'}})-[:{relationship}]->(memory:{memory_type})-[:HAS_DOCUMENT]->(document:Document)
|
||||||
RETURN document.summary AS summary
|
RETURN document.d_id AS d_id, document.summary AS summary
|
||||||
'''
|
'''
|
||||||
logging.info(f"Generated Cypher query: {query}")
|
logging.info(f"Generated Cypher query: {query}")
|
||||||
result = self.query(query)
|
result = self.query(query)
|
||||||
logging.info("Result: ", result)
|
logging.info(f"Result: {result}")
|
||||||
return [record.get("summary", "No summary available") for record in result]
|
return [{"d_id": record.get("d_id", None), "summary": record.get("summary", "No summary available")} for
|
||||||
|
record in result]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"An error occurred while retrieving document summary: {str(e)}")
|
logging.error(f"An error occurred while retrieving document summary: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
# async def get_document_categories(self, user_id: str):
|
# async def get_document_categories(self, user_id: str):
|
||||||
# """
|
# """
|
||||||
# Retrieve a list of categories for all documents associated with a given user.
|
# Retrieve a list of categories for all documents associated with a given user.
|
||||||
|
|
@ -568,13 +567,13 @@ class Neo4jGraphDB(AbstractGraphDB):
|
||||||
# logging.error(f"An error occurred while retrieving document IDs: {str(e)}")
|
# logging.error(f"An error occurred while retrieving document IDs: {str(e)}")
|
||||||
# return None
|
# return None
|
||||||
|
|
||||||
async def get_memory_linked_document_ids(self, user_id: str, summary: str, memory_type: str = "PUBLIC"):
|
async def get_memory_linked_document_ids(self, user_id: str, summary_id: str, memory_type: str = "PublicMemory"):
|
||||||
"""
|
"""
|
||||||
Retrieve a list of document IDs for a specific category associated with a given memory type for a user.
|
Retrieve a list of document IDs for a specific category associated with a given memory type for a user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The unique identifier of the user.
|
user_id (str): The unique identifier of the user.
|
||||||
summary (str): The specific document summary to filter by.
|
summary_id (str): The specific document summary id to filter by.
|
||||||
memory_type (str): The type of memory node ('SemanticMemory' or 'PublicMemory').
|
memory_type (str): The type of memory node ('SemanticMemory' or 'PublicMemory').
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
@ -591,7 +590,7 @@ class Neo4jGraphDB(AbstractGraphDB):
|
||||||
try:
|
try:
|
||||||
query = f'''
|
query = f'''
|
||||||
MATCH (user:User {{userId: '{user_id}'}})-[:{relationship}]->(memory:{memory_type})-[:HAS_DOCUMENT]->(document:Document)
|
MATCH (user:User {{userId: '{user_id}'}})-[:{relationship}]->(memory:{memory_type})-[:HAS_DOCUMENT]->(document:Document)
|
||||||
WHERE apoc.text.fuzzyMatch(document.summary, '{summary}') > 0.8
|
WHERE document.d_id = '{summary_id}'
|
||||||
RETURN document.d_id AS d_id
|
RETURN document.d_id AS d_id
|
||||||
'''
|
'''
|
||||||
logging.info(f"Generated Cypher query: {query}")
|
logging.info(f"Generated Cypher query: {query}")
|
||||||
|
|
|
||||||
|
|
@ -278,7 +278,7 @@ class BaseMemory:
|
||||||
n_of_observations: Optional[int] = 2,
|
n_of_observations: Optional[int] = 2,
|
||||||
):
|
):
|
||||||
logging.info(namespace)
|
logging.info(namespace)
|
||||||
logging.info("The search type is %", search_type)
|
logging.info("The search type is %", str(search_type))
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
logging.info(observation)
|
logging.info(observation)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -199,7 +199,7 @@ class WeaviateVectorDB(VectorDB):
|
||||||
client = self.init_weaviate(namespace =self.namespace)
|
client = self.init_weaviate(namespace =self.namespace)
|
||||||
if search_type is None:
|
if search_type is None:
|
||||||
search_type = 'hybrid'
|
search_type = 'hybrid'
|
||||||
logging.info("The search type is s%", search_type)
|
logging.info("The search type is s%", (search_type))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue