Fix context retireval
This commit is contained in:
parent
7afe9a7e45
commit
0323ab68f1
5 changed files with 21 additions and 16 deletions
|
|
@ -77,7 +77,7 @@ async def classify_call(query, document_summaries):
|
|||
|
||||
llm = ChatOpenAI(temperature=0, model=config.model)
|
||||
prompt_classify = ChatPromptTemplate.from_template(
|
||||
"""You are a classifier. Determine what document are relevant for the given query: {query}, Document summaries:{document_summaries}"""
|
||||
"""You are a classifier. Determine what document are relevant for the given query: {query}, Document summaries and ids:{document_summaries}"""
|
||||
)
|
||||
json_structure = [{
|
||||
"name": "classifier",
|
||||
|
|
@ -88,6 +88,10 @@ async def classify_call(query, document_summaries):
|
|||
"DocumentSummary": {
|
||||
"type": "string",
|
||||
"description": "The summary of the document and the topic it deals with."
|
||||
},
|
||||
"d_id": {
|
||||
"type": "string",
|
||||
"description": "The id of the document"
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -98,11 +102,11 @@ async def classify_call(query, document_summaries):
|
|||
arguments_str = classifier_output.additional_kwargs['function_call']['arguments']
|
||||
print("This is the arguments string", arguments_str)
|
||||
arguments_dict = json.loads(arguments_str)
|
||||
classfier_value = arguments_dict.get('DocumentSummary', None)
|
||||
classfier_id = arguments_dict.get('d_id', None)
|
||||
|
||||
print("This is the classifier value", classfier_value)
|
||||
print("This is the classifier id ", classfier_id)
|
||||
|
||||
return classfier_value
|
||||
return classfier_id
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -107,4 +107,6 @@ if __name__ == "__main__":
|
|||
create_database(username, password, host, database_name)
|
||||
print(f"Database {database_name} created successfully.")
|
||||
|
||||
create_tables(engine)
|
||||
create_tables(engine)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -476,7 +476,6 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
|
||||
return create_statements
|
||||
|
||||
|
||||
async def get_memory_linked_document_summaries(self, user_id: str, memory_type: str = "PublicMemory"):
|
||||
"""
|
||||
Retrieve a list of summaries for all documents associated with a given memory type for a user.
|
||||
|
|
@ -486,7 +485,7 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
memory_type (str): The type of memory node ('SemanticMemory' or 'PublicMemory').
|
||||
|
||||
Returns:
|
||||
List[str]: A list of document categories associated with the memory type for the user.
|
||||
List[Dict[str, Union[str, None]]]: A list of dictionaries containing document summary and d_id.
|
||||
|
||||
Raises:
|
||||
Exception: If an error occurs during the database query execution.
|
||||
|
|
@ -498,18 +497,18 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
try:
|
||||
query = f'''
|
||||
MATCH (user:User {{userId: '{user_id}'}})-[:{relationship}]->(memory:{memory_type})-[:HAS_DOCUMENT]->(document:Document)
|
||||
RETURN document.summary AS summary
|
||||
RETURN document.d_id AS d_id, document.summary AS summary
|
||||
'''
|
||||
logging.info(f"Generated Cypher query: {query}")
|
||||
result = self.query(query)
|
||||
logging.info("Result: ", result)
|
||||
return [record.get("summary", "No summary available") for record in result]
|
||||
logging.info(f"Result: {result}")
|
||||
return [{"d_id": record.get("d_id", None), "summary": record.get("summary", "No summary available")} for
|
||||
record in result]
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred while retrieving document summary: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
# async def get_document_categories(self, user_id: str):
|
||||
# """
|
||||
# Retrieve a list of categories for all documents associated with a given user.
|
||||
|
|
@ -568,13 +567,13 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
# logging.error(f"An error occurred while retrieving document IDs: {str(e)}")
|
||||
# return None
|
||||
|
||||
async def get_memory_linked_document_ids(self, user_id: str, summary: str, memory_type: str = "PUBLIC"):
|
||||
async def get_memory_linked_document_ids(self, user_id: str, summary_id: str, memory_type: str = "PublicMemory"):
|
||||
"""
|
||||
Retrieve a list of document IDs for a specific category associated with a given memory type for a user.
|
||||
|
||||
Args:
|
||||
user_id (str): The unique identifier of the user.
|
||||
summary (str): The specific document summary to filter by.
|
||||
summary_id (str): The specific document summary id to filter by.
|
||||
memory_type (str): The type of memory node ('SemanticMemory' or 'PublicMemory').
|
||||
|
||||
Returns:
|
||||
|
|
@ -591,7 +590,7 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
try:
|
||||
query = f'''
|
||||
MATCH (user:User {{userId: '{user_id}'}})-[:{relationship}]->(memory:{memory_type})-[:HAS_DOCUMENT]->(document:Document)
|
||||
WHERE apoc.text.fuzzyMatch(document.summary, '{summary}') > 0.8
|
||||
WHERE document.d_id = '{summary_id}'
|
||||
RETURN document.d_id AS d_id
|
||||
'''
|
||||
logging.info(f"Generated Cypher query: {query}")
|
||||
|
|
|
|||
|
|
@ -278,7 +278,7 @@ class BaseMemory:
|
|||
n_of_observations: Optional[int] = 2,
|
||||
):
|
||||
logging.info(namespace)
|
||||
logging.info("The search type is %", search_type)
|
||||
logging.info("The search type is %", str(search_type))
|
||||
logging.info(params)
|
||||
logging.info(observation)
|
||||
|
||||
|
|
|
|||
|
|
@ -199,7 +199,7 @@ class WeaviateVectorDB(VectorDB):
|
|||
client = self.init_weaviate(namespace =self.namespace)
|
||||
if search_type is None:
|
||||
search_type = 'hybrid'
|
||||
logging.info("The search type is s%", search_type)
|
||||
logging.info("The search type is s%", (search_type))
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue