Sqlite works, made fixes in config so it becomes a basis, added a few mods on top
This commit is contained in:
parent
3a33503b24
commit
91fe3f55a7
2 changed files with 64 additions and 60 deletions
|
|
@ -388,7 +388,7 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
except Exception as e:
|
||||
return f"An error occurred: {str(e)}"
|
||||
|
||||
def retrieve_semantic_memory(
|
||||
async def retrieve_semantic_memory(
|
||||
self, user_id: str, timestamp: float = None, summarized: bool = None
|
||||
):
|
||||
if timestamp is not None and summarized is not None:
|
||||
|
|
@ -418,9 +418,10 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
MATCH (semantic)-[:HAS_KNOWLEDGE]->(knowledge)
|
||||
RETURN knowledge
|
||||
"""
|
||||
return self.query(query, params={"user_id": user_id})
|
||||
output = await self.query(query, params={"user_id": user_id})
|
||||
return output
|
||||
|
||||
def retrieve_episodic_memory(
|
||||
async def retrieve_episodic_memory(
|
||||
self, user_id: str, timestamp: float = None, summarized: bool = None
|
||||
):
|
||||
if timestamp is not None and summarized is not None:
|
||||
|
|
@ -450,9 +451,10 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
MATCH (episodic)-[:HAS_EVENT]->(event)
|
||||
RETURN event
|
||||
"""
|
||||
return self.query(query, params={"user_id": user_id})
|
||||
output = await self.query(query, params={"user_id": user_id})
|
||||
return output
|
||||
|
||||
def retrieve_buffer_memory(
|
||||
async def retrieve_buffer_memory(
|
||||
self, user_id: str, timestamp: float = None, summarized: bool = None
|
||||
):
|
||||
if timestamp is not None and summarized is not None:
|
||||
|
|
@ -482,15 +484,17 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
MATCH (buffer)-[:CURRENTLY_HOLDING]->(item)
|
||||
RETURN item
|
||||
"""
|
||||
return self.query(query, params={"user_id": user_id})
|
||||
output = self.query(query, params={"user_id": user_id})
|
||||
return output
|
||||
|
||||
def retrieve_public_memory(self, user_id: str):
|
||||
async def retrieve_public_memory(self, user_id: str):
|
||||
query = """
|
||||
MATCH (user:User {userId: $user_id})-[:HAS_PUBLIC_MEMORY]->(public:PublicMemory)
|
||||
MATCH (public)-[:HAS_DOCUMENT]->(document)
|
||||
RETURN document
|
||||
"""
|
||||
return self.query(query, params={"user_id": user_id})
|
||||
output = await self.query(query, params={"user_id": user_id})
|
||||
return output
|
||||
|
||||
def generate_graph_semantic_memory_document_summary(
|
||||
self,
|
||||
|
|
@ -698,7 +702,7 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
|
||||
return cypher_query
|
||||
|
||||
def update_document_node_with_db_ids(
|
||||
async def update_document_node_with_db_ids(
|
||||
self, vectordb_namespace: str, document_id: str, user_id: str = None
|
||||
):
|
||||
"""
|
||||
|
|
@ -731,7 +735,7 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
|
||||
return cypher_query
|
||||
|
||||
def run_merge_query(
|
||||
async def run_merge_query(
|
||||
self, user_id: str, memory_type: str, similarity_threshold: float
|
||||
) -> str:
|
||||
"""
|
||||
|
|
@ -769,7 +773,7 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
RETURN labels(n) AS NodeType, collect(n) AS Nodes
|
||||
"""
|
||||
|
||||
node_results = self.query(query)
|
||||
node_results = await self.query(query)
|
||||
|
||||
node_types = [record["NodeType"] for record in node_results]
|
||||
|
||||
|
|
@ -785,11 +789,11 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
CALL apoc.refactor.mergeNodes([n1, n2], {{mergeRels: true}}) YIELD node
|
||||
RETURN node
|
||||
"""
|
||||
self.query(query)
|
||||
self.close()
|
||||
await self.query(query)
|
||||
await self.close()
|
||||
return query
|
||||
|
||||
def get_namespaces_by_document_category(self, user_id: str, category: str):
|
||||
async def get_namespaces_by_document_category(self, user_id: str, category: str):
|
||||
"""
|
||||
Retrieve a list of Vectordb namespaces for documents of a specified category associated with a given user.
|
||||
|
||||
|
|
@ -812,7 +816,7 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
WHERE document.documentCategory = '{category}'
|
||||
RETURN document.vectordbNamespace AS namespace
|
||||
"""
|
||||
result = self.query(query)
|
||||
result = await self.query(query)
|
||||
namespaces = [record["namespace"] for record in result]
|
||||
return namespaces
|
||||
except Exception as e:
|
||||
|
|
@ -850,10 +854,10 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
"""
|
||||
|
||||
try:
|
||||
result = self.query(memory_cypher)
|
||||
result = await self.query(memory_cypher)
|
||||
# Assuming the result is a list of records, where each record contains 'memoryId'
|
||||
memory_id = result[0]["memoryId"] if result else None
|
||||
self.close()
|
||||
await self.close()
|
||||
return memory_id
|
||||
except Neo4jError as e:
|
||||
logging.error(f"Error creating or finding memory node: {e}")
|
||||
|
|
@ -882,7 +886,7 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
logging.error(f"Error linking Public node to user: {e}")
|
||||
raise
|
||||
|
||||
def delete_memory_node(self, memory_id: int, topic: str) -> None:
|
||||
async def delete_memory_node(self, memory_id: int, topic: str) -> None:
|
||||
if not memory_id or not topic:
|
||||
raise ValueError("Memory ID and Topic are required for deletion.")
|
||||
|
||||
|
|
@ -892,12 +896,12 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
DETACH DELETE {topic.lower()}
|
||||
"""
|
||||
logging.info("Delete Cypher Query: %s", delete_cypher)
|
||||
self.query(delete_cypher)
|
||||
await self.query(delete_cypher)
|
||||
except Neo4jError as e:
|
||||
logging.error(f"Error deleting {topic} memory node: {e}")
|
||||
raise
|
||||
|
||||
def unlink_memory_from_user(
|
||||
async def unlink_memory_from_user(
|
||||
self, memory_id: int, user_id: str, topic: str = "PublicMemory"
|
||||
) -> None:
|
||||
"""
|
||||
|
|
@ -929,27 +933,27 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
MATCH (user:User {{userId: '{user_id}'}})-[r:{relationship_type}]->(memory:{topic}) WHERE id(memory) = {memory_id}
|
||||
DELETE r
|
||||
"""
|
||||
self.query(unlink_cypher)
|
||||
await self.query(unlink_cypher)
|
||||
except Neo4jError as e:
|
||||
logging.error(f"Error unlinking {topic} from user: {e}")
|
||||
raise
|
||||
|
||||
def link_public_memory_to_user(self, memory_id, user_id):
|
||||
async def link_public_memory_to_user(self, memory_id, user_id):
|
||||
# Link an existing Public Memory node to a User node
|
||||
link_cypher = f"""
|
||||
MATCH (user:User {{userId: '{user_id}'}})
|
||||
MATCH (publicMemory:PublicMemory) WHERE id(publicMemory) = {memory_id}
|
||||
MERGE (user)-[:HAS_PUBLIC_MEMORY]->(publicMemory)
|
||||
"""
|
||||
self.query(link_cypher)
|
||||
await self.query(link_cypher)
|
||||
|
||||
def retrieve_node_id_for_memory_type(self, topic: str = "SemanticMemory"):
|
||||
async def retrieve_node_id_for_memory_type(self, topic: str = "SemanticMemory"):
|
||||
link_cypher = f""" MATCH(publicMemory: {topic})
|
||||
RETURN
|
||||
id(publicMemory)
|
||||
AS
|
||||
memoryId """
|
||||
node_ids = self.query(link_cypher)
|
||||
node_ids = await self.query(link_cypher)
|
||||
return node_ids
|
||||
|
||||
|
||||
|
|
|
|||
70
main.py
70
main.py
|
|
@ -258,13 +258,13 @@ async def user_query_to_graph_db(session: AsyncSession, user_id: str, query_inpu
|
|||
)
|
||||
result = neo4j_graph_db.query(cypher_query)
|
||||
|
||||
neo4j_graph_db.run_merge_query(
|
||||
await neo4j_graph_db.run_merge_query(
|
||||
user_id=user_id, memory_type="SemanticMemory", similarity_threshold=0.8
|
||||
)
|
||||
neo4j_graph_db.run_merge_query(
|
||||
await neo4j_graph_db.run_merge_query(
|
||||
user_id=user_id, memory_type="EpisodicMemory", similarity_threshold=0.8
|
||||
)
|
||||
neo4j_graph_db.close()
|
||||
await neo4j_graph_db.close()
|
||||
|
||||
await update_entity(session, Operation, job_id, "SUCCESS")
|
||||
|
||||
|
|
@ -381,16 +381,16 @@ async def add_documents_to_graph_db(
|
|||
await create_public_memory(
|
||||
user_id=user_id, labels=["sr"], topic="PublicMemory"
|
||||
)
|
||||
ids = neo4j_graph_db.retrieve_node_id_for_memory_type(
|
||||
ids = await neo4j_graph_db.retrieve_node_id_for_memory_type(
|
||||
topic="PublicMemory"
|
||||
)
|
||||
neo4j_graph_db.close()
|
||||
await neo4j_graph_db.close()
|
||||
print(ids)
|
||||
else:
|
||||
ids = neo4j_graph_db.retrieve_node_id_for_memory_type(
|
||||
ids = await neo4j_graph_db.retrieve_node_id_for_memory_type(
|
||||
topic="SemanticMemory"
|
||||
)
|
||||
neo4j_graph_db.close()
|
||||
await neo4j_graph_db.close()
|
||||
print(ids)
|
||||
|
||||
for id in ids:
|
||||
|
|
@ -404,20 +404,20 @@ async def add_documents_to_graph_db(
|
|||
rs = neo4j_graph_db.create_document_node_cypher(
|
||||
classification, user_id, public_memory_id=id.get("memoryId")
|
||||
)
|
||||
neo4j_graph_db.close()
|
||||
await neo4j_graph_db.close()
|
||||
else:
|
||||
rs = neo4j_graph_db.create_document_node_cypher(
|
||||
classification, user_id, memory_type="SemanticMemory"
|
||||
)
|
||||
neo4j_graph_db.close()
|
||||
await neo4j_graph_db.close()
|
||||
logging.info("Cypher query is %s", str(rs))
|
||||
neo4j_graph_db = Neo4jGraphDB(
|
||||
url=config.graph_database_url,
|
||||
username=config.graph_database_username,
|
||||
password=config.graph_database_password,
|
||||
)
|
||||
neo4j_graph_db.query(rs)
|
||||
neo4j_graph_db.close()
|
||||
await neo4j_graph_db.query(rs)
|
||||
await neo4j_graph_db.close()
|
||||
logging.info("WE GOT HERE")
|
||||
neo4j_graph_db = Neo4jGraphDB(
|
||||
url=config.graph_database_url,
|
||||
|
|
@ -425,17 +425,17 @@ async def add_documents_to_graph_db(
|
|||
password=config.graph_database_password,
|
||||
)
|
||||
if memory_details[0][1] == "PUBLIC":
|
||||
neo4j_graph_db.update_document_node_with_db_ids(
|
||||
await neo4j_graph_db.update_document_node_with_db_ids(
|
||||
vectordb_namespace=memory_details[0][0], document_id=doc_id
|
||||
)
|
||||
neo4j_graph_db.close()
|
||||
await neo4j_graph_db.close()
|
||||
else:
|
||||
neo4j_graph_db.update_document_node_with_db_ids(
|
||||
await neo4j_graph_db.update_document_node_with_db_ids(
|
||||
vectordb_namespace=memory_details[0][0],
|
||||
document_id=doc_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
neo4j_graph_db.close()
|
||||
await neo4j_graph_db.close()
|
||||
# await update_entity_graph_summary(session, DocsModel, doc_id, True)
|
||||
except Exception as e:
|
||||
return e
|
||||
|
|
@ -518,14 +518,14 @@ async def user_context_enrichment(
|
|||
# await user_query_to_graph_db(session, user_id, query)
|
||||
|
||||
semantic_mem = neo4j_graph_db.retrieve_semantic_memory(user_id=user_id)
|
||||
neo4j_graph_db.close()
|
||||
await neo4j_graph_db.close()
|
||||
neo4j_graph_db = Neo4jGraphDB(
|
||||
url=config.graph_database_url,
|
||||
username=config.graph_database_username,
|
||||
password=config.graph_database_password,
|
||||
)
|
||||
episodic_mem = neo4j_graph_db.retrieve_episodic_memory(user_id=user_id)
|
||||
neo4j_graph_db.close()
|
||||
await neo4j_graph_db.close()
|
||||
# public_mem = neo4j_graph_db.retrieve_public_memory(user_id=user_id)
|
||||
|
||||
if detect_language(query) != "en":
|
||||
|
|
@ -541,7 +541,7 @@ async def user_context_enrichment(
|
|||
summaries = await neo4j_graph_db.get_memory_linked_document_summaries(
|
||||
user_id=user_id, memory_type=memory_type
|
||||
)
|
||||
neo4j_graph_db.close()
|
||||
await neo4j_graph_db.close()
|
||||
logging.info("Summaries are is %s", summaries)
|
||||
# logging.info("Context from graphdb is %s", context)
|
||||
# result = neo4j_graph_db.query(document_categories_query)
|
||||
|
|
@ -571,7 +571,7 @@ async def user_context_enrichment(
|
|||
postgres_id = await neo4j_graph_db.get_memory_linked_document_ids(
|
||||
user_id, summary_id=relevant_summary_id, memory_type=memory_type
|
||||
)
|
||||
neo4j_graph_db.close()
|
||||
await neo4j_graph_db.close()
|
||||
# postgres_id = neo4j_graph_db.query(get_doc_ids)
|
||||
logging.info("Postgres ids are %s", postgres_id)
|
||||
namespace_id = await get_memory_name_by_doc_id(session, postgres_id[0])
|
||||
|
|
@ -688,7 +688,7 @@ async def create_public_memory(
|
|||
# Assuming the topic for public memory is predefined, e.g., "PublicMemory"
|
||||
# Create the memory node
|
||||
memory_id = await neo4j_graph_db.create_memory_node(labels=labels, topic=topic)
|
||||
neo4j_graph_db.close()
|
||||
await neo4j_graph_db.close()
|
||||
return memory_id
|
||||
except Neo4jError as e:
|
||||
logging.error(f"Error creating public memory node: {e}")
|
||||
|
|
@ -729,8 +729,8 @@ async def attach_user_to_memory(
|
|||
)
|
||||
|
||||
# Assuming the topic for public memory is predefined, e.g., "PublicMemory"
|
||||
ids = neo4j_graph_db.retrieve_node_id_for_memory_type(topic=topic)
|
||||
neo4j_graph_db.close()
|
||||
ids = await neo4j_graph_db.retrieve_node_id_for_memory_type(topic=topic)
|
||||
await neo4j_graph_db.close()
|
||||
|
||||
for id in ids:
|
||||
neo4j_graph_db = Neo4jGraphDB(
|
||||
|
|
@ -738,10 +738,10 @@ async def attach_user_to_memory(
|
|||
username=config.graph_database_username,
|
||||
password=config.graph_database_password,
|
||||
)
|
||||
linked_memory = neo4j_graph_db.link_public_memory_to_user(
|
||||
linked_memory = await neo4j_graph_db.link_public_memory_to_user(
|
||||
memory_id=id.get("memoryId"), user_id=user_id
|
||||
)
|
||||
neo4j_graph_db.close()
|
||||
await neo4j_graph_db.close()
|
||||
return 1
|
||||
except Neo4jError as e:
|
||||
logging.error(f"Error creating public memory node: {e}")
|
||||
|
|
@ -781,8 +781,8 @@ async def unlink_user_from_memory(
|
|||
)
|
||||
|
||||
# Assuming the topic for public memory is predefined, e.g., "PublicMemory"
|
||||
ids = neo4j_graph_db.retrieve_node_id_for_memory_type(topic=topic)
|
||||
neo4j_graph_db.close()
|
||||
ids = await neo4j_graph_db.retrieve_node_id_for_memory_type(topic=topic)
|
||||
await neo4j_graph_db.close()
|
||||
|
||||
for id in ids:
|
||||
neo4j_graph_db = Neo4jGraphDB(
|
||||
|
|
@ -793,7 +793,7 @@ async def unlink_user_from_memory(
|
|||
linked_memory = neo4j_graph_db.unlink_memory_from_user(
|
||||
memory_id=id.get("memoryId"), user_id=user_id
|
||||
)
|
||||
neo4j_graph_db.close()
|
||||
await neo4j_graph_db.close()
|
||||
return 1
|
||||
except Neo4jError as e:
|
||||
logging.error(f"Error creating public memory node: {e}")
|
||||
|
|
@ -879,14 +879,14 @@ async def main():
|
|||
# print(out)
|
||||
# load_doc_to_graph = await add_documents_to_graph_db(session, user_id)
|
||||
# print(load_doc_to_graph)
|
||||
user_id = "test_user"
|
||||
loader_settings = {
|
||||
"format": "PDF",
|
||||
"source": "DEVICE",
|
||||
"path": [".data"]
|
||||
}
|
||||
await load_documents_to_vectorstore(session, user_id, loader_settings=loader_settings)
|
||||
# await create_public_memory(user_id=user_id, labels=['sr'], topic="PublicMemory")
|
||||
# user_id = "test_user"
|
||||
# loader_settings = {
|
||||
# "format": "PDF",
|
||||
# "source": "DEVICE",
|
||||
# "path": [".data"]
|
||||
# }
|
||||
# await load_documents_to_vectorstore(session, user_id, loader_settings=loader_settings)
|
||||
await create_public_memory(user_id=user_id, labels=['sr'], topic="PublicMemory")
|
||||
# await add_documents_to_graph_db(session, user_id)
|
||||
#
|
||||
# neo4j_graph_db = Neo4jGraphDB(
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue