ran black, fixed some linting issues
This commit is contained in:
parent
1f3ac1ec97
commit
b3f29d3f2d
36 changed files with 1778 additions and 966 deletions
191
api.py
191
api.py
|
|
@ -39,12 +39,13 @@ app = FastAPI(debug=True)
|
|||
#
|
||||
# auth = JWTBearer(jwks)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""
|
||||
Root endpoint that returns a welcome message.
|
||||
"""
|
||||
return { "message": "Hello, World, I am alive!" }
|
||||
return {"message": "Hello, World, I am alive!"}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
|
|
@ -61,8 +62,8 @@ class Payload(BaseModel):
|
|||
|
||||
@app.post("/add-memory", response_model=dict)
|
||||
async def add_memory(
|
||||
payload: Payload,
|
||||
# files: List[UploadFile] = File(...),
|
||||
payload: Payload,
|
||||
# files: List[UploadFile] = File(...),
|
||||
):
|
||||
try:
|
||||
logging.info(" Adding to Memory ")
|
||||
|
|
@ -70,68 +71,76 @@ async def add_memory(
|
|||
async with session_scope(session=AsyncSessionLocal()) as session:
|
||||
from main import load_documents_to_vectorstore
|
||||
|
||||
if 'settings' in decoded_payload and decoded_payload['settings'] is not None:
|
||||
settings_for_loader = decoded_payload['settings']
|
||||
if (
|
||||
"settings" in decoded_payload
|
||||
and decoded_payload["settings"] is not None
|
||||
):
|
||||
settings_for_loader = decoded_payload["settings"]
|
||||
else:
|
||||
settings_for_loader = None
|
||||
|
||||
if 'content' in decoded_payload and decoded_payload['content'] is not None:
|
||||
content = decoded_payload['content']
|
||||
if "content" in decoded_payload and decoded_payload["content"] is not None:
|
||||
content = decoded_payload["content"]
|
||||
else:
|
||||
content = None
|
||||
|
||||
output = await load_documents_to_vectorstore(session, decoded_payload['user_id'], content=content,
|
||||
loader_settings=settings_for_loader)
|
||||
output = await load_documents_to_vectorstore(
|
||||
session,
|
||||
decoded_payload["user_id"],
|
||||
content=content,
|
||||
loader_settings=settings_for_loader,
|
||||
)
|
||||
return JSONResponse(content={"response": output}, status_code=200)
|
||||
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
content={"response": {"error": str(e)}}, status_code=503
|
||||
)
|
||||
return JSONResponse(content={"response": {"error": str(e)}}, status_code=503)
|
||||
|
||||
|
||||
@app.post("/add-architecture-public-memory", response_model=dict)
|
||||
async def add_memory(
|
||||
payload: Payload,
|
||||
# files: List[UploadFile] = File(...),
|
||||
payload: Payload,
|
||||
# files: List[UploadFile] = File(...),
|
||||
):
|
||||
try:
|
||||
logging.info(" Adding to Memory ")
|
||||
decoded_payload = payload.payload
|
||||
async with session_scope(session=AsyncSessionLocal()) as session:
|
||||
from main import load_documents_to_vectorstore
|
||||
if 'content' in decoded_payload and decoded_payload['content'] is not None:
|
||||
content = decoded_payload['content']
|
||||
|
||||
if "content" in decoded_payload and decoded_payload["content"] is not None:
|
||||
content = decoded_payload["content"]
|
||||
else:
|
||||
content = None
|
||||
|
||||
user_id = 'system_user'
|
||||
loader_settings = {
|
||||
"format": "PDF",
|
||||
"source": "DEVICE",
|
||||
"path": [".data"]
|
||||
}
|
||||
user_id = "system_user"
|
||||
loader_settings = {"format": "PDF", "source": "DEVICE", "path": [".data"]}
|
||||
|
||||
output = await load_documents_to_vectorstore(session, user_id=user_id, content=content,
|
||||
loader_settings=loader_settings)
|
||||
output = await load_documents_to_vectorstore(
|
||||
session,
|
||||
user_id=user_id,
|
||||
content=content,
|
||||
loader_settings=loader_settings,
|
||||
)
|
||||
return JSONResponse(content={"response": output}, status_code=200)
|
||||
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
content={"response": {"error": str(e)}}, status_code=503
|
||||
)
|
||||
return JSONResponse(content={"response": {"error": str(e)}}, status_code=503)
|
||||
|
||||
|
||||
@app.post("/user-query-to-graph")
|
||||
async def user_query_to_graph(payload: Payload):
|
||||
try:
|
||||
from main import user_query_to_graph_db
|
||||
|
||||
decoded_payload = payload.payload
|
||||
# Execute the query - replace this with the actual execution method
|
||||
async with session_scope(session=AsyncSessionLocal()) as session:
|
||||
# Assuming you have a method in Neo4jGraphDB to execute the query
|
||||
result = await user_query_to_graph_db(session=session, user_id=decoded_payload['user_id'],
|
||||
query_input=decoded_payload['query'])
|
||||
result = await user_query_to_graph_db(
|
||||
session=session,
|
||||
user_id=decoded_payload["user_id"],
|
||||
query_input=decoded_payload["query"],
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
|
@ -144,17 +153,23 @@ async def document_to_graph_db(payload: Payload):
|
|||
logging.info("Adding documents to graph db")
|
||||
try:
|
||||
decoded_payload = payload.payload
|
||||
if 'settings' in decoded_payload and decoded_payload['settings'] is not None:
|
||||
settings_for_loader = decoded_payload['settings']
|
||||
if "settings" in decoded_payload and decoded_payload["settings"] is not None:
|
||||
settings_for_loader = decoded_payload["settings"]
|
||||
else:
|
||||
settings_for_loader = None
|
||||
if 'memory_type' in decoded_payload and decoded_payload['memory_type'] is not None:
|
||||
memory_type = decoded_payload['memory_type']
|
||||
if (
|
||||
"memory_type" in decoded_payload
|
||||
and decoded_payload["memory_type"] is not None
|
||||
):
|
||||
memory_type = decoded_payload["memory_type"]
|
||||
else:
|
||||
memory_type = None
|
||||
async with session_scope(session=AsyncSessionLocal()) as session:
|
||||
result = await add_documents_to_graph_db(session=session, user_id=decoded_payload['user_id'],
|
||||
document_memory_types=memory_type)
|
||||
result = await add_documents_to_graph_db(
|
||||
session=session,
|
||||
user_id=decoded_payload["user_id"],
|
||||
document_memory_types=memory_type,
|
||||
)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -166,10 +181,13 @@ async def cognitive_context_enrichment(payload: Payload):
|
|||
try:
|
||||
decoded_payload = payload.payload
|
||||
async with session_scope(session=AsyncSessionLocal()) as session:
|
||||
result = await user_context_enrichment(session, user_id=decoded_payload['user_id'],
|
||||
query=decoded_payload['query'],
|
||||
generative_response=decoded_payload['generative_response'],
|
||||
memory_type=decoded_payload['memory_type'])
|
||||
result = await user_context_enrichment(
|
||||
session,
|
||||
user_id=decoded_payload["user_id"],
|
||||
query=decoded_payload["query"],
|
||||
generative_response=decoded_payload["generative_response"],
|
||||
memory_type=decoded_payload["memory_type"],
|
||||
)
|
||||
return JSONResponse(content={"response": result}, status_code=200)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -182,8 +200,11 @@ async def classify_user_query(payload: Payload):
|
|||
decoded_payload = payload.payload
|
||||
async with session_scope(session=AsyncSessionLocal()) as session:
|
||||
from main import relevance_feedback
|
||||
result = await relevance_feedback(query=decoded_payload['query'],
|
||||
input_type=decoded_payload['knowledge_type'])
|
||||
|
||||
result = await relevance_feedback(
|
||||
query=decoded_payload["query"],
|
||||
input_type=decoded_payload["knowledge_type"],
|
||||
)
|
||||
return JSONResponse(content={"response": result}, status_code=200)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -197,9 +218,14 @@ async def user_query_classfier(payload: Payload):
|
|||
|
||||
# Execute the query - replace this with the actual execution method
|
||||
async with session_scope(session=AsyncSessionLocal()) as session:
|
||||
from cognitive_architecture.classifiers.classifier import classify_user_query
|
||||
from cognitive_architecture.classifiers.classifier import (
|
||||
classify_user_query,
|
||||
)
|
||||
|
||||
# Assuming you have a method in Neo4jGraphDB to execute the query
|
||||
result = await classify_user_query(session, decoded_payload['user_id'], decoded_payload['query'])
|
||||
result = await classify_user_query(
|
||||
session, decoded_payload["user_id"], decoded_payload["query"]
|
||||
)
|
||||
return JSONResponse(content={"response": result}, status_code=200)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -211,41 +237,43 @@ async def drop_db(payload: Payload):
|
|||
try:
|
||||
decoded_payload = payload.payload
|
||||
|
||||
if decoded_payload['operation'] == 'drop':
|
||||
|
||||
if os.environ.get('AWS_ENV') == 'dev':
|
||||
host = os.environ.get('POSTGRES_HOST')
|
||||
username = os.environ.get('POSTGRES_USER')
|
||||
password = os.environ.get('POSTGRES_PASSWORD')
|
||||
database_name = os.environ.get('POSTGRES_DB')
|
||||
if decoded_payload["operation"] == "drop":
|
||||
if os.environ.get("AWS_ENV") == "dev":
|
||||
host = os.environ.get("POSTGRES_HOST")
|
||||
username = os.environ.get("POSTGRES_USER")
|
||||
password = os.environ.get("POSTGRES_PASSWORD")
|
||||
database_name = os.environ.get("POSTGRES_DB")
|
||||
else:
|
||||
pass
|
||||
|
||||
from cognitive_architecture.database.create_database import drop_database, create_admin_engine
|
||||
from cognitive_architecture.database.create_database import (
|
||||
drop_database,
|
||||
create_admin_engine,
|
||||
)
|
||||
|
||||
engine = create_admin_engine(username, password, host, database_name)
|
||||
connection = engine.raw_connection()
|
||||
drop_database(connection, database_name)
|
||||
return JSONResponse(content={"response": "DB dropped"}, status_code=200)
|
||||
else:
|
||||
|
||||
if os.environ.get('AWS_ENV') == 'dev':
|
||||
host = os.environ.get('POSTGRES_HOST')
|
||||
username = os.environ.get('POSTGRES_USER')
|
||||
password = os.environ.get('POSTGRES_PASSWORD')
|
||||
database_name = os.environ.get('POSTGRES_DB')
|
||||
if os.environ.get("AWS_ENV") == "dev":
|
||||
host = os.environ.get("POSTGRES_HOST")
|
||||
username = os.environ.get("POSTGRES_USER")
|
||||
password = os.environ.get("POSTGRES_PASSWORD")
|
||||
database_name = os.environ.get("POSTGRES_DB")
|
||||
else:
|
||||
pass
|
||||
|
||||
from cognitive_architecture.database.create_database import create_database, create_admin_engine
|
||||
from cognitive_architecture.database.create_database import (
|
||||
create_database,
|
||||
create_admin_engine,
|
||||
)
|
||||
|
||||
engine = create_admin_engine(username, password, host, database_name)
|
||||
connection = engine.raw_connection()
|
||||
create_database(connection, database_name)
|
||||
return JSONResponse(content={"response": " DB drop"}, status_code=200)
|
||||
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
|
@ -255,18 +283,18 @@ async def create_public_memory(payload: Payload):
|
|||
try:
|
||||
decoded_payload = payload.payload
|
||||
|
||||
if 'user_id' in decoded_payload and decoded_payload['user_id'] is not None:
|
||||
user_id = decoded_payload['user_id']
|
||||
if "user_id" in decoded_payload and decoded_payload["user_id"] is not None:
|
||||
user_id = decoded_payload["user_id"]
|
||||
else:
|
||||
user_id = None
|
||||
|
||||
if 'labels' in decoded_payload and decoded_payload['labels'] is not None:
|
||||
labels = decoded_payload['labels']
|
||||
if "labels" in decoded_payload and decoded_payload["labels"] is not None:
|
||||
labels = decoded_payload["labels"]
|
||||
else:
|
||||
labels = None
|
||||
|
||||
if 'topic' in decoded_payload and decoded_payload['topic'] is not None:
|
||||
topic = decoded_payload['topic']
|
||||
if "topic" in decoded_payload and decoded_payload["topic"] is not None:
|
||||
topic = decoded_payload["topic"]
|
||||
else:
|
||||
topic = None
|
||||
|
||||
|
|
@ -286,21 +314,26 @@ async def attach_user_to_public_memory(payload: Payload):
|
|||
try:
|
||||
decoded_payload = payload.payload
|
||||
|
||||
if 'topic' in decoded_payload and decoded_payload['topic'] is not None:
|
||||
topic = decoded_payload['topic']
|
||||
if "topic" in decoded_payload and decoded_payload["topic"] is not None:
|
||||
topic = decoded_payload["topic"]
|
||||
else:
|
||||
topic = None
|
||||
if 'labels' in decoded_payload and decoded_payload['labels'] is not None:
|
||||
labels = decoded_payload['labels']
|
||||
if "labels" in decoded_payload and decoded_payload["labels"] is not None:
|
||||
labels = decoded_payload["labels"]
|
||||
else:
|
||||
labels = ['sr']
|
||||
labels = ["sr"]
|
||||
|
||||
# Execute the query - replace this with the actual execution method
|
||||
async with session_scope(session=AsyncSessionLocal()) as session:
|
||||
from main import attach_user_to_memory, create_public_memory
|
||||
|
||||
# Assuming you have a method in Neo4jGraphDB to execute the query
|
||||
await create_public_memory(user_id=decoded_payload['user_id'], topic=topic, labels=labels)
|
||||
result = await attach_user_to_memory(user_id=decoded_payload['user_id'], topic=topic, labels=labels)
|
||||
await create_public_memory(
|
||||
user_id=decoded_payload["user_id"], topic=topic, labels=labels
|
||||
)
|
||||
result = await attach_user_to_memory(
|
||||
user_id=decoded_payload["user_id"], topic=topic, labels=labels
|
||||
)
|
||||
return JSONResponse(content={"response": result}, status_code=200)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -312,17 +345,21 @@ async def unlink_user_from_public_memory(payload: Payload):
|
|||
try:
|
||||
decoded_payload = payload.payload
|
||||
|
||||
if 'topic' in decoded_payload and decoded_payload['topic'] is not None:
|
||||
topic = decoded_payload['topic']
|
||||
if "topic" in decoded_payload and decoded_payload["topic"] is not None:
|
||||
topic = decoded_payload["topic"]
|
||||
else:
|
||||
topic = None
|
||||
|
||||
# Execute the query - replace this with the actual execution method
|
||||
async with session_scope(session=AsyncSessionLocal()) as session:
|
||||
from main import unlink_user_from_memory
|
||||
|
||||
# Assuming you have a method in Neo4jGraphDB to execute the query
|
||||
result = await unlink_user_from_memory(user_id=decoded_payload['user_id'], topic=topic,
|
||||
labels=decoded_payload['labels'])
|
||||
result = await unlink_user_from_memory(
|
||||
user_id=decoded_payload["user_id"],
|
||||
topic=topic,
|
||||
labels=decoded_payload["labels"],
|
||||
)
|
||||
return JSONResponse(content={"response": result}, status_code=200)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -3,12 +3,7 @@ import logging
|
|||
from langchain.prompts import ChatPromptTemplate
|
||||
import json
|
||||
|
||||
#TO DO, ADD ALL CLASSIFIERS HERE
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# TO DO, ADD ALL CLASSIFIERS HERE
|
||||
|
||||
|
||||
from langchain.chains import create_extraction_chain
|
||||
|
|
@ -16,6 +11,7 @@ from langchain.chat_models import ChatOpenAI
|
|||
|
||||
from ..config import Config
|
||||
from ..database.vectordb.loaders.loaders import _document_loader
|
||||
|
||||
config = Config()
|
||||
config.load()
|
||||
OPENAI_API_KEY = config.openai_key
|
||||
|
|
@ -23,150 +19,164 @@ from langchain.document_loaders import TextLoader
|
|||
from langchain.document_loaders import DirectoryLoader
|
||||
|
||||
|
||||
async def classify_documents(query:str, document_id:str, content:str):
|
||||
|
||||
document_context = content
|
||||
async def classify_documents(query: str, document_id: str, content: str):
|
||||
document_context = content
|
||||
logging.info("This is the document context", document_context)
|
||||
|
||||
llm = ChatOpenAI(temperature=0, model=config.model)
|
||||
prompt_classify = ChatPromptTemplate.from_template(
|
||||
"""You are a summarizer and classifier. Determine what book this is and where does it belong in the output : {query}, Id: {d_id} Document context is: {context}"""
|
||||
)
|
||||
json_structure = [{
|
||||
"name": "summarizer",
|
||||
"description": "Summarization and classification",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"DocumentCategory": {
|
||||
"type": "string",
|
||||
"description": "The classification of documents in groups such as legal, medical, etc."
|
||||
json_structure = [
|
||||
{
|
||||
"name": "summarizer",
|
||||
"description": "Summarization and classification",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"DocumentCategory": {
|
||||
"type": "string",
|
||||
"description": "The classification of documents in groups such as legal, medical, etc.",
|
||||
},
|
||||
"Title": {
|
||||
"type": "string",
|
||||
"description": "The title of the document",
|
||||
},
|
||||
"Summary": {
|
||||
"type": "string",
|
||||
"description": "The summary of the document",
|
||||
},
|
||||
"d_id": {"type": "string", "description": "The id of the document"},
|
||||
},
|
||||
"Title": {
|
||||
"type": "string",
|
||||
"description": "The title of the document"
|
||||
},
|
||||
"Summary": {
|
||||
"type": "string",
|
||||
"description": "The summary of the document"
|
||||
},
|
||||
"d_id": {
|
||||
"type": "string",
|
||||
"description": "The id of the document"
|
||||
}
|
||||
|
||||
|
||||
}, "required": ["DocumentCategory", "Title", "Summary","d_id"] }
|
||||
}]
|
||||
chain_filter = prompt_classify | llm.bind(function_call={"name": "summarizer"}, functions=json_structure)
|
||||
classifier_output = await chain_filter.ainvoke({"query": query, "d_id": document_id, "context": str(document_context)})
|
||||
arguments_str = classifier_output.additional_kwargs['function_call']['arguments']
|
||||
"required": ["DocumentCategory", "Title", "Summary", "d_id"],
|
||||
},
|
||||
}
|
||||
]
|
||||
chain_filter = prompt_classify | llm.bind(
|
||||
function_call={"name": "summarizer"}, functions=json_structure
|
||||
)
|
||||
classifier_output = await chain_filter.ainvoke(
|
||||
{"query": query, "d_id": document_id, "context": str(document_context)}
|
||||
)
|
||||
arguments_str = classifier_output.additional_kwargs["function_call"]["arguments"]
|
||||
print("This is the arguments string", arguments_str)
|
||||
arguments_dict = json.loads(arguments_str)
|
||||
return arguments_dict
|
||||
|
||||
|
||||
|
||||
# classify retrievals according to type of retrieval
|
||||
def classify_retrieval():
|
||||
pass
|
||||
|
||||
|
||||
async def classify_user_input(query, input_type):
|
||||
|
||||
llm = ChatOpenAI(temperature=0, model=config.model)
|
||||
prompt_classify = ChatPromptTemplate.from_template(
|
||||
"""You are a classifier. Determine with a True or False if the following input: {query}, is relevant for the following memory category: {input_type}"""
|
||||
)
|
||||
json_structure = [{
|
||||
"name": "classifier",
|
||||
"description": "Classification",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"InputClassification": {
|
||||
"type": "boolean",
|
||||
"description": "The classification of the input"
|
||||
}
|
||||
}, "required": ["InputClassification"] }
|
||||
}]
|
||||
chain_filter = prompt_classify | llm.bind(function_call={"name": "classifier"}, functions=json_structure)
|
||||
classifier_output = await chain_filter.ainvoke({"query": query, "input_type": input_type})
|
||||
arguments_str = classifier_output.additional_kwargs['function_call']['arguments']
|
||||
json_structure = [
|
||||
{
|
||||
"name": "classifier",
|
||||
"description": "Classification",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"InputClassification": {
|
||||
"type": "boolean",
|
||||
"description": "The classification of the input",
|
||||
}
|
||||
},
|
||||
"required": ["InputClassification"],
|
||||
},
|
||||
}
|
||||
]
|
||||
chain_filter = prompt_classify | llm.bind(
|
||||
function_call={"name": "classifier"}, functions=json_structure
|
||||
)
|
||||
classifier_output = await chain_filter.ainvoke(
|
||||
{"query": query, "input_type": input_type}
|
||||
)
|
||||
arguments_str = classifier_output.additional_kwargs["function_call"]["arguments"]
|
||||
logging.info("This is the arguments string %s", arguments_str)
|
||||
arguments_dict = json.loads(arguments_str)
|
||||
logging.info("Relevant summary is %s", arguments_dict.get('DocumentSummary', None))
|
||||
InputClassification = arguments_dict.get('InputClassification', None)
|
||||
logging.info("Relevant summary is %s", arguments_dict.get("DocumentSummary", None))
|
||||
InputClassification = arguments_dict.get("InputClassification", None)
|
||||
logging.info("This is the classification %s", InputClassification)
|
||||
return InputClassification
|
||||
|
||||
|
||||
# classify documents according to type of document
|
||||
async def classify_call(query, document_summaries):
|
||||
|
||||
llm = ChatOpenAI(temperature=0, model=config.model)
|
||||
prompt_classify = ChatPromptTemplate.from_template(
|
||||
"""You are a classifier. Determine what document are relevant for the given query: {query}, Document summaries and ids:{document_summaries}"""
|
||||
)
|
||||
json_structure = [{
|
||||
"name": "classifier",
|
||||
"description": "Classification",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"DocumentSummary": {
|
||||
"type": "string",
|
||||
"description": "The summary of the document and the topic it deals with."
|
||||
json_structure = [
|
||||
{
|
||||
"name": "classifier",
|
||||
"description": "Classification",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"DocumentSummary": {
|
||||
"type": "string",
|
||||
"description": "The summary of the document and the topic it deals with.",
|
||||
},
|
||||
"d_id": {"type": "string", "description": "The id of the document"},
|
||||
},
|
||||
"d_id": {
|
||||
"type": "string",
|
||||
"description": "The id of the document"
|
||||
}
|
||||
|
||||
|
||||
}, "required": ["DocumentSummary"] }
|
||||
}]
|
||||
chain_filter = prompt_classify | llm.bind(function_call={"name": "classifier"}, functions=json_structure)
|
||||
classifier_output = await chain_filter.ainvoke({"query": query, "document_summaries": document_summaries})
|
||||
arguments_str = classifier_output.additional_kwargs['function_call']['arguments']
|
||||
"required": ["DocumentSummary"],
|
||||
},
|
||||
}
|
||||
]
|
||||
chain_filter = prompt_classify | llm.bind(
|
||||
function_call={"name": "classifier"}, functions=json_structure
|
||||
)
|
||||
classifier_output = await chain_filter.ainvoke(
|
||||
{"query": query, "document_summaries": document_summaries}
|
||||
)
|
||||
arguments_str = classifier_output.additional_kwargs["function_call"]["arguments"]
|
||||
print("This is the arguments string", arguments_str)
|
||||
arguments_dict = json.loads(arguments_str)
|
||||
logging.info("Relevant summary is %s", arguments_dict.get('DocumentSummary', None))
|
||||
classfier_id = arguments_dict.get('d_id', None)
|
||||
logging.info("Relevant summary is %s", arguments_dict.get("DocumentSummary", None))
|
||||
classfier_id = arguments_dict.get("d_id", None)
|
||||
|
||||
print("This is the classifier id ", classfier_id)
|
||||
|
||||
return classfier_id
|
||||
|
||||
|
||||
|
||||
async def classify_user_query(query, context, document_types):
|
||||
|
||||
llm = ChatOpenAI(temperature=0, model=config.model)
|
||||
prompt_classify = ChatPromptTemplate.from_template(
|
||||
"""You are a classifier. You store user memories, thoughts and feelings. Determine if you need to use them to answer this query : {query}"""
|
||||
)
|
||||
json_structure = [{
|
||||
"name": "classifier",
|
||||
"description": "Classification",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"UserQueryClassifier": {
|
||||
"type": "bool",
|
||||
"description": "The classification of documents in groups such as legal, medical, etc."
|
||||
}
|
||||
|
||||
|
||||
}, "required": ["UserQueryClassiffier"] }
|
||||
}]
|
||||
chain_filter = prompt_classify | llm.bind(function_call={"name": "classifier"}, functions=json_structure)
|
||||
classifier_output = await chain_filter.ainvoke({"query": query, "context": context, "document_types": document_types})
|
||||
arguments_str = classifier_output.additional_kwargs['function_call']['arguments']
|
||||
json_structure = [
|
||||
{
|
||||
"name": "classifier",
|
||||
"description": "Classification",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"UserQueryClassifier": {
|
||||
"type": "bool",
|
||||
"description": "The classification of documents in groups such as legal, medical, etc.",
|
||||
}
|
||||
},
|
||||
"required": ["UserQueryClassiffier"],
|
||||
},
|
||||
}
|
||||
]
|
||||
chain_filter = prompt_classify | llm.bind(
|
||||
function_call={"name": "classifier"}, functions=json_structure
|
||||
)
|
||||
classifier_output = await chain_filter.ainvoke(
|
||||
{"query": query, "context": context, "document_types": document_types}
|
||||
)
|
||||
arguments_str = classifier_output.additional_kwargs["function_call"]["arguments"]
|
||||
print("This is the arguments string", arguments_str)
|
||||
arguments_dict = json.loads(arguments_str)
|
||||
classfier_value = arguments_dict.get('UserQueryClassifier', None)
|
||||
classfier_value = arguments_dict.get("UserQueryClassifier", None)
|
||||
|
||||
print("This is the classifier value", classfier_value)
|
||||
|
||||
return classfier_value
|
||||
return classfier_value
|
||||
|
|
|
|||
|
|
@ -10,52 +10,65 @@ from dotenv import load_dotenv
|
|||
|
||||
base_dir = Path(__file__).resolve().parent.parent
|
||||
# Load the .env file from the base directory
|
||||
dotenv_path = base_dir / '.env'
|
||||
dotenv_path = base_dir / ".env"
|
||||
load_dotenv(dotenv_path=dotenv_path)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
# Paths and Directories
|
||||
memgpt_dir: str = field(default_factory=lambda: os.getenv('COG_ARCH_DIR', 'cognitive_achitecture'))
|
||||
config_path: str = field(default_factory=lambda: os.path.join(os.getenv('COG_ARCH_DIR', 'cognitive_achitecture'), 'config'))
|
||||
memgpt_dir: str = field(
|
||||
default_factory=lambda: os.getenv("COG_ARCH_DIR", "cognitive_achitecture")
|
||||
)
|
||||
config_path: str = field(
|
||||
default_factory=lambda: os.path.join(
|
||||
os.getenv("COG_ARCH_DIR", "cognitive_achitecture"), "config"
|
||||
)
|
||||
)
|
||||
|
||||
vectordb:str = 'lancedb'
|
||||
vectordb: str = "lancedb"
|
||||
|
||||
# Model parameters
|
||||
model: str = 'gpt-4-1106-preview'
|
||||
model_endpoint: str = 'openai'
|
||||
openai_key: Optional[str] = os.getenv('OPENAI_API_KEY')
|
||||
model: str = "gpt-4-1106-preview"
|
||||
model_endpoint: str = "openai"
|
||||
openai_key: Optional[str] = os.getenv("OPENAI_API_KEY")
|
||||
openai_temperature: float = float(os.getenv("OPENAI_TEMPERATURE", 0.0))
|
||||
|
||||
# Embedding parameters
|
||||
embedding_model: str = 'openai'
|
||||
embedding_model: str = "openai"
|
||||
embedding_dim: int = 1536
|
||||
embedding_chunk_size: int = 300
|
||||
|
||||
# Database parameters
|
||||
if os.getenv('ENV') == 'prod' or os.getenv('ENV') == 'dev' or os.getenv('AWS_ENV') == 'dev' or os.getenv('AWS_ENV') == 'prd':
|
||||
graph_database_url: str = os.getenv('GRAPH_DB_URL_PROD')
|
||||
graph_database_username: str = os.getenv('GRAPH_DB_USER')
|
||||
graph_database_password: str = os.getenv('GRAPH_DB_PW')
|
||||
if (
|
||||
os.getenv("ENV") == "prod"
|
||||
or os.getenv("ENV") == "dev"
|
||||
or os.getenv("AWS_ENV") == "dev"
|
||||
or os.getenv("AWS_ENV") == "prd"
|
||||
):
|
||||
graph_database_url: str = os.getenv("GRAPH_DB_URL_PROD")
|
||||
graph_database_username: str = os.getenv("GRAPH_DB_USER")
|
||||
graph_database_password: str = os.getenv("GRAPH_DB_PW")
|
||||
else:
|
||||
graph_database_url: str = os.getenv('GRAPH_DB_URL')
|
||||
graph_database_username: str = os.getenv('GRAPH_DB_USER')
|
||||
graph_database_password: str = os.getenv('GRAPH_DB_PW')
|
||||
weaviate_url: str = os.getenv('WEAVIATE_URL')
|
||||
weaviate_api_key: str = os.getenv('WEAVIATE_API_KEY')
|
||||
postgres_user: str = os.getenv('POSTGRES_USER')
|
||||
postgres_password: str = os.getenv('POSTGRES_PASSWORD')
|
||||
postgres_db: str = os.getenv('POSTGRES_DB')
|
||||
if os.getenv('ENV') == 'prod' or os.getenv('ENV') == 'dev' or os.getenv('AWS_ENV') == 'dev' or os.getenv('AWS_ENV') == 'prd':
|
||||
postgres_host: str = os.getenv('POSTGRES_PROD_HOST')
|
||||
elif os.getenv('ENV') == 'docker':
|
||||
postgres_host: str = os.getenv('POSTGRES_HOST_DOCKER')
|
||||
elif os.getenv('ENV') == 'local':
|
||||
postgres_host: str = os.getenv('POSTGRES_HOST_LOCAL')
|
||||
|
||||
|
||||
|
||||
|
||||
graph_database_url: str = os.getenv("GRAPH_DB_URL")
|
||||
graph_database_username: str = os.getenv("GRAPH_DB_USER")
|
||||
graph_database_password: str = os.getenv("GRAPH_DB_PW")
|
||||
weaviate_url: str = os.getenv("WEAVIATE_URL")
|
||||
weaviate_api_key: str = os.getenv("WEAVIATE_API_KEY")
|
||||
postgres_user: str = os.getenv("POSTGRES_USER")
|
||||
postgres_password: str = os.getenv("POSTGRES_PASSWORD")
|
||||
postgres_db: str = os.getenv("POSTGRES_DB")
|
||||
if (
|
||||
os.getenv("ENV") == "prod"
|
||||
or os.getenv("ENV") == "dev"
|
||||
or os.getenv("AWS_ENV") == "dev"
|
||||
or os.getenv("AWS_ENV") == "prd"
|
||||
):
|
||||
postgres_host: str = os.getenv("POSTGRES_PROD_HOST")
|
||||
elif os.getenv("ENV") == "docker":
|
||||
postgres_host: str = os.getenv("POSTGRES_HOST_DOCKER")
|
||||
elif os.getenv("ENV") == "local":
|
||||
postgres_host: str = os.getenv("POSTGRES_HOST_LOCAL")
|
||||
|
||||
# Client ID
|
||||
anon_clientid: Optional[str] = field(default_factory=lambda: uuid.uuid4().hex)
|
||||
|
|
@ -84,12 +97,12 @@ class Config:
|
|||
|
||||
# Save the current settings to the config file
|
||||
for attr, value in self.__dict__.items():
|
||||
section, option = attr.split('_', 1)
|
||||
section, option = attr.split("_", 1)
|
||||
if not config.has_section(section):
|
||||
config.add_section(section)
|
||||
config.set(section, option, str(value))
|
||||
|
||||
with open(self.config_path, 'w') as configfile:
|
||||
with open(self.config_path, "w") as configfile:
|
||||
config.write(configfile)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -22,12 +22,17 @@ from sqlalchemy import create_engine, text
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from contextlib import contextmanager
|
||||
from dotenv import load_dotenv
|
||||
from relationaldb.database import Base # Assuming all models are imported within this module
|
||||
from relationaldb.database import DatabaseConfig # Assuming DatabaseConfig is defined as before
|
||||
from relationaldb.database import (
|
||||
Base,
|
||||
) # Assuming all models are imported within this module
|
||||
from relationaldb.database import (
|
||||
DatabaseConfig,
|
||||
) # Assuming DatabaseConfig is defined as before
|
||||
|
||||
load_dotenv()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
def __init__(self, config: DatabaseConfig):
|
||||
self.config = config
|
||||
|
|
@ -36,7 +41,7 @@ class DatabaseManager:
|
|||
|
||||
@contextmanager
|
||||
def get_connection(self):
|
||||
if self.db_type in ['sqlite', 'duckdb']:
|
||||
if self.db_type in ["sqlite", "duckdb"]:
|
||||
# For SQLite and DuckDB, the engine itself manages connections
|
||||
yield self.engine
|
||||
else:
|
||||
|
|
@ -47,7 +52,7 @@ class DatabaseManager:
|
|||
connection.close()
|
||||
|
||||
def database_exists(self, db_name):
|
||||
if self.db_type in ['sqlite', 'duckdb']:
|
||||
if self.db_type in ["sqlite", "duckdb"]:
|
||||
# For SQLite and DuckDB, check if the database file exists
|
||||
return os.path.exists(db_name)
|
||||
else:
|
||||
|
|
@ -57,14 +62,14 @@ class DatabaseManager:
|
|||
return result is not None
|
||||
|
||||
def create_database(self, db_name):
|
||||
if self.db_type not in ['sqlite', 'duckdb']:
|
||||
if self.db_type not in ["sqlite", "duckdb"]:
|
||||
# For databases like PostgreSQL, create the database explicitly
|
||||
with self.get_connection() as connection:
|
||||
connection.execution_options(isolation_level="AUTOCOMMIT")
|
||||
connection.execute(f"CREATE DATABASE {db_name}")
|
||||
|
||||
def drop_database(self, db_name):
|
||||
if self.db_type in ['sqlite', 'duckdb']:
|
||||
if self.db_type in ["sqlite", "duckdb"]:
|
||||
# For SQLite and DuckDB, simply remove the database file
|
||||
os.remove(db_name)
|
||||
else:
|
||||
|
|
@ -75,9 +80,10 @@ class DatabaseManager:
|
|||
def create_tables(self):
|
||||
Base.metadata.create_all(bind=self.engine)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage with SQLite
|
||||
config = DatabaseConfig(db_type='sqlite', db_name='mydatabase.db')
|
||||
config = DatabaseConfig(db_type="sqlite", db_name="mydatabase.db")
|
||||
|
||||
# For DuckDB, you would set db_type to 'duckdb' and provide the database file name
|
||||
# config = DatabaseConfig(db_type='duckdb', db_name='mydatabase.duckdb')
|
||||
|
|
@ -139,4 +145,4 @@ if __name__ == "__main__":
|
|||
# connection.close()
|
||||
# engine.dispose()
|
||||
#
|
||||
# create_tables(engine)
|
||||
# create_tables(engine)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
import logging
|
||||
import os
|
||||
|
||||
|
|
@ -26,13 +25,19 @@ from abc import ABC, abstractmethod
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Optional
|
||||
from ...utils import format_dict, append_uuid_to_variable_names, create_edge_variable_mapping, \
|
||||
create_node_variable_mapping, get_unsumarized_vector_db_namespace
|
||||
from ...utils import (
|
||||
format_dict,
|
||||
append_uuid_to_variable_names,
|
||||
create_edge_variable_mapping,
|
||||
create_node_variable_mapping,
|
||||
get_unsumarized_vector_db_namespace,
|
||||
)
|
||||
from ...llm.queries import generate_summary, generate_graph
|
||||
import logging
|
||||
from neo4j import AsyncGraphDatabase, Neo4jError
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, Dict, Optional, List
|
||||
|
||||
DEFAULT_PRESET = "promethai_chat"
|
||||
preset_options = [DEFAULT_PRESET]
|
||||
PROMETHAI_DIR = os.path.join(os.path.expanduser("~"), ".")
|
||||
|
|
@ -41,7 +46,13 @@ load_dotenv()
|
|||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
||||
from ...config import Config
|
||||
|
||||
from ...shared.data_models import Node, Edge, KnowledgeGraph, GraphQLQuery, MemorySummary
|
||||
from ...shared.data_models import (
|
||||
Node,
|
||||
Edge,
|
||||
KnowledgeGraph,
|
||||
GraphQLQuery,
|
||||
MemorySummary,
|
||||
)
|
||||
|
||||
config = Config()
|
||||
config.load()
|
||||
|
|
@ -53,8 +64,8 @@ OPENAI_API_KEY = config.openai_key
|
|||
|
||||
aclient = instructor.patch(OpenAI())
|
||||
|
||||
class AbstractGraphDB(ABC):
|
||||
|
||||
class AbstractGraphDB(ABC):
|
||||
@abstractmethod
|
||||
def query(self, query: str, params=None):
|
||||
pass
|
||||
|
|
@ -73,8 +84,12 @@ class AbstractGraphDB(ABC):
|
|||
|
||||
|
||||
class Neo4jGraphDB(AbstractGraphDB):
|
||||
def __init__(self, url: str, username: str, password: str, driver: Optional[Any] = None):
|
||||
self.driver = driver or AsyncGraphDatabase.driver(url, auth=(username, password))
|
||||
def __init__(
|
||||
self, url: str, username: str, password: str, driver: Optional[Any] = None
|
||||
):
|
||||
self.driver = driver or AsyncGraphDatabase.driver(
|
||||
url, auth=(username, password)
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.driver.close()
|
||||
|
|
@ -84,7 +99,9 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
async with self.driver.session() as session:
|
||||
yield session
|
||||
|
||||
async def query(self, query: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
|
||||
async def query(
|
||||
self, query: str, params: Optional[Dict[str, Any]] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
try:
|
||||
async with self.get_session() as session:
|
||||
result = await session.run(query, parameters=params)
|
||||
|
|
@ -93,30 +110,28 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
logging.error(f"Neo4j query error: {e.message}")
|
||||
raise
|
||||
|
||||
|
||||
# class Neo4jGraphDB(AbstractGraphDB):
|
||||
# def __init__(self, url, username, password):
|
||||
# # self.graph = Neo4jGraph(url=url, username=username, password=password)
|
||||
# from neo4j import GraphDatabase
|
||||
# self.driver = GraphDatabase.driver(url, auth=(username, password))
|
||||
# self.openai_key = config.openai_key
|
||||
#
|
||||
#
|
||||
#
|
||||
# def close(self):
|
||||
# # Method to close the Neo4j driver instance
|
||||
# self.driver.close()
|
||||
#
|
||||
# def query(self, query, params=None):
|
||||
# try:
|
||||
# with self.driver.session() as session:
|
||||
# result = session.run(query, params).data()
|
||||
# return result
|
||||
# except Exception as e:
|
||||
# logging.error(f"An error occurred while executing the query: {e}")
|
||||
# raise e
|
||||
#
|
||||
|
||||
# class Neo4jGraphDB(AbstractGraphDB):
|
||||
# def __init__(self, url, username, password):
|
||||
# # self.graph = Neo4jGraph(url=url, username=username, password=password)
|
||||
# from neo4j import GraphDatabase
|
||||
# self.driver = GraphDatabase.driver(url, auth=(username, password))
|
||||
# self.openai_key = config.openai_key
|
||||
#
|
||||
#
|
||||
#
|
||||
# def close(self):
|
||||
# # Method to close the Neo4j driver instance
|
||||
# self.driver.close()
|
||||
#
|
||||
# def query(self, query, params=None):
|
||||
# try:
|
||||
# with self.driver.session() as session:
|
||||
# result = session.run(query, params).data()
|
||||
# return result
|
||||
# except Exception as e:
|
||||
# logging.error(f"An error occurred while executing the query: {e}")
|
||||
# raise e
|
||||
#
|
||||
|
||||
def create_base_cognitive_architecture(self, user_id: str):
|
||||
# Create the user and memory components if they don't exist
|
||||
|
|
@ -131,16 +146,22 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
"""
|
||||
return user_memory_cypher
|
||||
|
||||
async def retrieve_memory(self, user_id: str, memory_type: str, timestamp: float = None, summarized: bool = None):
|
||||
if memory_type == 'SemanticMemory':
|
||||
relationship = 'SEMANTIC_MEMORY'
|
||||
memory_rel = 'HAS_KNOWLEDGE'
|
||||
elif memory_type == 'EpisodicMemory':
|
||||
relationship = 'EPISODIC_MEMORY'
|
||||
memory_rel = 'HAS_EVENT'
|
||||
elif memory_type == 'Buffer':
|
||||
relationship = 'BUFFER'
|
||||
memory_rel = 'CURRENTLY_HOLDING'
|
||||
async def retrieve_memory(
|
||||
self,
|
||||
user_id: str,
|
||||
memory_type: str,
|
||||
timestamp: float = None,
|
||||
summarized: bool = None,
|
||||
):
|
||||
if memory_type == "SemanticMemory":
|
||||
relationship = "SEMANTIC_MEMORY"
|
||||
memory_rel = "HAS_KNOWLEDGE"
|
||||
elif memory_type == "EpisodicMemory":
|
||||
relationship = "EPISODIC_MEMORY"
|
||||
memory_rel = "HAS_EVENT"
|
||||
elif memory_type == "Buffer":
|
||||
relationship = "BUFFER"
|
||||
memory_rel = "CURRENTLY_HOLDING"
|
||||
if timestamp is not None and summarized is not None:
|
||||
query = f"""
|
||||
MATCH (user:User {{userId: '{user_id}' }})-[:HAS_{relationship}]->(memory:{memory_type})
|
||||
|
|
@ -172,79 +193,100 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
output = self.query(query, params={"user_id": user_id})
|
||||
print("Here is the output", output)
|
||||
|
||||
reduced_graph = await generate_summary(input = output)
|
||||
reduced_graph = await generate_summary(input=output)
|
||||
return reduced_graph
|
||||
|
||||
|
||||
def cypher_statement_correcting(self, input: str) ->str:
|
||||
def cypher_statement_correcting(self, input: str) -> str:
|
||||
return aclient.chat.completions.create(
|
||||
model=config.model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""Check the cypher query for syntax issues, and fix any if found and return it as is: {input}. """,
|
||||
|
||||
},
|
||||
{"role": "system", "content": """You are a top-tier algorithm
|
||||
designed for checking cypher queries for neo4j graph databases. You have to return input provided to you as is"""}
|
||||
{
|
||||
"role": "system",
|
||||
"content": """You are a top-tier algorithm
|
||||
designed for checking cypher queries for neo4j graph databases. You have to return input provided to you as is""",
|
||||
},
|
||||
],
|
||||
response_model=GraphQLQuery,
|
||||
)
|
||||
|
||||
def generate_create_statements_for_nodes_with_uuid(self, nodes, unique_mapping, base_node_mapping):
|
||||
def generate_create_statements_for_nodes_with_uuid(
|
||||
self, nodes, unique_mapping, base_node_mapping
|
||||
):
|
||||
create_statements = []
|
||||
for node in nodes:
|
||||
original_variable_name = base_node_mapping[node['id']]
|
||||
original_variable_name = base_node_mapping[node["id"]]
|
||||
unique_variable_name = unique_mapping[original_variable_name]
|
||||
node_label = node['category'].capitalize()
|
||||
properties = {k: v for k, v in node.items() if k not in ['id', 'category']}
|
||||
node_label = node["category"].capitalize()
|
||||
properties = {k: v for k, v in node.items() if k not in ["id", "category"]}
|
||||
try:
|
||||
properties = format_dict(properties)
|
||||
except:
|
||||
pass
|
||||
create_statements.append(f"CREATE ({unique_variable_name}:{node_label} {properties})")
|
||||
create_statements.append(
|
||||
f"CREATE ({unique_variable_name}:{node_label} {properties})"
|
||||
)
|
||||
return create_statements
|
||||
|
||||
# Update the function to generate Cypher CREATE statements for edges with unique variable names
|
||||
def generate_create_statements_for_edges_with_uuid(self, user_id, edges, unique_mapping, base_node_mapping):
|
||||
def generate_create_statements_for_edges_with_uuid(
|
||||
self, user_id, edges, unique_mapping, base_node_mapping
|
||||
):
|
||||
create_statements = []
|
||||
with_statement = f"WITH {', '.join(unique_mapping.values())}, user , semantic, episodic, buffer"
|
||||
create_statements.append(with_statement)
|
||||
|
||||
for edge in edges:
|
||||
# print("HERE IS THE EDGE", edge)
|
||||
source_variable = unique_mapping[base_node_mapping[edge['source']]]
|
||||
target_variable = unique_mapping[base_node_mapping[edge['target']]]
|
||||
relationship = edge['description'].replace(" ", "_").upper()
|
||||
create_statements.append(f"CREATE ({source_variable})-[:{relationship}]->({target_variable})")
|
||||
source_variable = unique_mapping[base_node_mapping[edge["source"]]]
|
||||
target_variable = unique_mapping[base_node_mapping[edge["target"]]]
|
||||
relationship = edge["description"].replace(" ", "_").upper()
|
||||
create_statements.append(
|
||||
f"CREATE ({source_variable})-[:{relationship}]->({target_variable})"
|
||||
)
|
||||
return create_statements
|
||||
|
||||
def generate_memory_type_relationships_with_uuid_and_time_context(self, user_id, nodes, unique_mapping, base_node_mapping):
|
||||
def generate_memory_type_relationships_with_uuid_and_time_context(
|
||||
self, user_id, nodes, unique_mapping, base_node_mapping
|
||||
):
|
||||
create_statements = []
|
||||
with_statement = f"WITH {', '.join(unique_mapping.values())}, user, semantic, episodic, buffer"
|
||||
create_statements.append(with_statement)
|
||||
|
||||
# Loop through each node and create relationships based on memory_type
|
||||
for node in nodes:
|
||||
original_variable_name = base_node_mapping[node['id']]
|
||||
original_variable_name = base_node_mapping[node["id"]]
|
||||
unique_variable_name = unique_mapping[original_variable_name]
|
||||
if node['memory_type'] == 'semantic':
|
||||
create_statements.append(f"CREATE (semantic)-[:HAS_KNOWLEDGE]->({unique_variable_name})")
|
||||
elif node['memory_type'] == 'episodic':
|
||||
create_statements.append(f"CREATE (episodic)-[:HAS_EVENT]->({unique_variable_name})")
|
||||
if node['category'] == 'time':
|
||||
create_statements.append(f"CREATE (buffer)-[:HAS_TIME_CONTEXT]->({unique_variable_name})")
|
||||
if node["memory_type"] == "semantic":
|
||||
create_statements.append(
|
||||
f"CREATE (semantic)-[:HAS_KNOWLEDGE]->({unique_variable_name})"
|
||||
)
|
||||
elif node["memory_type"] == "episodic":
|
||||
create_statements.append(
|
||||
f"CREATE (episodic)-[:HAS_EVENT]->({unique_variable_name})"
|
||||
)
|
||||
if node["category"] == "time":
|
||||
create_statements.append(
|
||||
f"CREATE (buffer)-[:HAS_TIME_CONTEXT]->({unique_variable_name})"
|
||||
)
|
||||
|
||||
# Assuming buffer holds all actions and times
|
||||
# if node['category'] in ['action', 'time']:
|
||||
create_statements.append(f"CREATE (buffer)-[:CURRENTLY_HOLDING]->({unique_variable_name})")
|
||||
create_statements.append(
|
||||
f"CREATE (buffer)-[:CURRENTLY_HOLDING]->({unique_variable_name})"
|
||||
)
|
||||
|
||||
return create_statements
|
||||
|
||||
async def generate_cypher_query_for_user_prompt_decomposition(self, user_id:str, query:str):
|
||||
|
||||
async def generate_cypher_query_for_user_prompt_decomposition(
|
||||
self, user_id: str, query: str
|
||||
):
|
||||
graph: KnowledgeGraph = generate_graph(query)
|
||||
import time
|
||||
|
||||
for node in graph.nodes:
|
||||
node.created_at = time.time()
|
||||
node.summarized = False
|
||||
|
|
@ -254,19 +296,41 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
edge.summarized = False
|
||||
graph_dic = graph.dict()
|
||||
|
||||
node_variable_mapping = create_node_variable_mapping(graph_dic['nodes'])
|
||||
edge_variable_mapping = create_edge_variable_mapping(graph_dic['edges'])
|
||||
node_variable_mapping = create_node_variable_mapping(graph_dic["nodes"])
|
||||
edge_variable_mapping = create_edge_variable_mapping(graph_dic["edges"])
|
||||
# Create unique variable names for each node
|
||||
unique_node_variable_mapping = append_uuid_to_variable_names(node_variable_mapping)
|
||||
unique_edge_variable_mapping = append_uuid_to_variable_names(edge_variable_mapping)
|
||||
create_nodes_statements = self.generate_create_statements_for_nodes_with_uuid(graph_dic['nodes'], unique_node_variable_mapping, node_variable_mapping)
|
||||
create_edges_statements =self.generate_create_statements_for_edges_with_uuid(user_id, graph_dic['edges'], unique_node_variable_mapping, node_variable_mapping)
|
||||
unique_node_variable_mapping = append_uuid_to_variable_names(
|
||||
node_variable_mapping
|
||||
)
|
||||
unique_edge_variable_mapping = append_uuid_to_variable_names(
|
||||
edge_variable_mapping
|
||||
)
|
||||
create_nodes_statements = self.generate_create_statements_for_nodes_with_uuid(
|
||||
graph_dic["nodes"], unique_node_variable_mapping, node_variable_mapping
|
||||
)
|
||||
create_edges_statements = self.generate_create_statements_for_edges_with_uuid(
|
||||
user_id,
|
||||
graph_dic["edges"],
|
||||
unique_node_variable_mapping,
|
||||
node_variable_mapping,
|
||||
)
|
||||
|
||||
memory_type_statements_with_uuid_and_time_context = self.generate_memory_type_relationships_with_uuid_and_time_context(user_id,
|
||||
graph_dic['nodes'], unique_node_variable_mapping, node_variable_mapping)
|
||||
memory_type_statements_with_uuid_and_time_context = (
|
||||
self.generate_memory_type_relationships_with_uuid_and_time_context(
|
||||
user_id,
|
||||
graph_dic["nodes"],
|
||||
unique_node_variable_mapping,
|
||||
node_variable_mapping,
|
||||
)
|
||||
)
|
||||
|
||||
# # Combine all statements
|
||||
cypher_statements = [self.create_base_cognitive_architecture(user_id)] + create_nodes_statements + create_edges_statements + memory_type_statements_with_uuid_and_time_context
|
||||
cypher_statements = (
|
||||
[self.create_base_cognitive_architecture(user_id)]
|
||||
+ create_nodes_statements
|
||||
+ create_edges_statements
|
||||
+ memory_type_statements_with_uuid_and_time_context
|
||||
)
|
||||
cypher_statements_joined = "\n".join(cypher_statements)
|
||||
logging.info("User Cypher Query raw: %s", cypher_statements_joined)
|
||||
# corrected_cypher_statements = self.cypher_statement_correcting(input = cypher_statements_joined)
|
||||
|
|
@ -274,15 +338,15 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
# return corrected_cypher_statements.query
|
||||
return cypher_statements_joined
|
||||
|
||||
|
||||
def update_user_query_for_user_prompt_decomposition(self, user_id, user_query):
|
||||
pass
|
||||
|
||||
|
||||
def delete_all_user_memories(self, user_id):
|
||||
try:
|
||||
# Check if the user exists
|
||||
user_exists = self.query(f"MATCH (user:User {{userId: '{user_id}'}}) RETURN user")
|
||||
user_exists = self.query(
|
||||
f"MATCH (user:User {{userId: '{user_id}'}}) RETURN user"
|
||||
)
|
||||
if not user_exists:
|
||||
return f"No user found with ID: {user_id}"
|
||||
|
||||
|
|
@ -304,12 +368,14 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
def delete_specific_memory_type(self, user_id, memory_type):
|
||||
try:
|
||||
# Check if the user exists
|
||||
user_exists = self.query(f"MATCH (user:User {{userId: '{user_id}'}}) RETURN user")
|
||||
user_exists = self.query(
|
||||
f"MATCH (user:User {{userId: '{user_id}'}}) RETURN user"
|
||||
)
|
||||
if not user_exists:
|
||||
return f"No user found with ID: {user_id}"
|
||||
|
||||
# Validate memory type
|
||||
if memory_type not in ['SemanticMemory', 'EpisodicMemory', 'Buffer']:
|
||||
if memory_type not in ["SemanticMemory", "EpisodicMemory", "Buffer"]:
|
||||
return "Invalid memory type. Choose from 'SemanticMemory', 'EpisodicMemory', or 'Buffer'."
|
||||
|
||||
# Delete specific memory type nodes and relationships for the given user
|
||||
|
|
@ -322,7 +388,9 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
except Exception as e:
|
||||
return f"An error occurred: {str(e)}"
|
||||
|
||||
def retrieve_semantic_memory(self, user_id: str, timestamp: float = None, summarized: bool = None):
|
||||
def retrieve_semantic_memory(
|
||||
self, user_id: str, timestamp: float = None, summarized: bool = None
|
||||
):
|
||||
if timestamp is not None and summarized is not None:
|
||||
query = f"""
|
||||
MATCH (user:User {{userId: '{user_id}' }})-[:HAS_SEMANTIC_MEMORY]->(semantic:SemanticMemory)
|
||||
|
|
@ -352,7 +420,9 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
"""
|
||||
return self.query(query, params={"user_id": user_id})
|
||||
|
||||
def retrieve_episodic_memory(self, user_id: str, timestamp: float = None, summarized: bool = None):
|
||||
def retrieve_episodic_memory(
|
||||
self, user_id: str, timestamp: float = None, summarized: bool = None
|
||||
):
|
||||
if timestamp is not None and summarized is not None:
|
||||
query = f"""
|
||||
MATCH (user:User {{userId: '{user_id}' }})-[:HAS_EPISODIC_MEMORY]->(episodic:EpisodicMemory)
|
||||
|
|
@ -382,8 +452,9 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
"""
|
||||
return self.query(query, params={"user_id": user_id})
|
||||
|
||||
|
||||
def retrieve_buffer_memory(self, user_id: str, timestamp: float = None, summarized: bool = None):
|
||||
def retrieve_buffer_memory(
|
||||
self, user_id: str, timestamp: float = None, summarized: bool = None
|
||||
):
|
||||
if timestamp is not None and summarized is not None:
|
||||
query = f"""
|
||||
MATCH (user:User {{userId: '{user_id}' }})-[:HAS_BUFFER]->(buffer:Buffer)
|
||||
|
|
@ -413,8 +484,6 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
"""
|
||||
return self.query(query, params={"user_id": user_id})
|
||||
|
||||
|
||||
|
||||
def retrieve_public_memory(self, user_id: str):
|
||||
query = """
|
||||
MATCH (user:User {userId: $user_id})-[:HAS_PUBLIC_MEMORY]->(public:PublicMemory)
|
||||
|
|
@ -422,23 +491,33 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
RETURN document
|
||||
"""
|
||||
return self.query(query, params={"user_id": user_id})
|
||||
def generate_graph_semantic_memory_document_summary(self, document_summary : str, unique_graphdb_mapping_values: dict, document_namespace: str):
|
||||
""" This function takes a document and generates a document summary in Semantic Memory"""
|
||||
|
||||
def generate_graph_semantic_memory_document_summary(
|
||||
self,
|
||||
document_summary: str,
|
||||
unique_graphdb_mapping_values: dict,
|
||||
document_namespace: str,
|
||||
):
|
||||
"""This function takes a document and generates a document summary in Semantic Memory"""
|
||||
create_statements = []
|
||||
with_statement = f"WITH {', '.join(unique_graphdb_mapping_values.values())}, user, semantic, episodic, buffer"
|
||||
create_statements.append(with_statement)
|
||||
|
||||
# Loop through each node and create relationships based on memory_type
|
||||
|
||||
create_statements.append(f"CREATE (semantic)-[:HAS_KNOWLEDGE]->({unique_graphdb_mapping_values})")
|
||||
|
||||
create_statements.append(
|
||||
f"CREATE (semantic)-[:HAS_KNOWLEDGE]->({unique_graphdb_mapping_values})"
|
||||
)
|
||||
|
||||
return create_statements
|
||||
|
||||
|
||||
def generate_document_summary(self, document_summary : str, unique_graphdb_mapping_values: dict, document_namespace: str):
|
||||
""" This function takes a document and generates a document summary in Semantic Memory"""
|
||||
|
||||
def generate_document_summary(
|
||||
self,
|
||||
document_summary: str,
|
||||
unique_graphdb_mapping_values: dict,
|
||||
document_namespace: str,
|
||||
):
|
||||
"""This function takes a document and generates a document summary in Semantic Memory"""
|
||||
|
||||
# fetch namespace from postgres db
|
||||
# fetch 1st and last page from vector store
|
||||
|
|
@ -450,12 +529,15 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
|
||||
# Loop through each node and create relationships based on memory_type
|
||||
|
||||
create_statements.append(f"CREATE (semantic)-[:HAS_KNOWLEDGE]->({unique_graphdb_mapping_values})")
|
||||
|
||||
create_statements.append(
|
||||
f"CREATE (semantic)-[:HAS_KNOWLEDGE]->({unique_graphdb_mapping_values})"
|
||||
)
|
||||
|
||||
return create_statements
|
||||
|
||||
async def get_memory_linked_document_summaries(self, user_id: str, memory_type: str = "PublicMemory"):
|
||||
async def get_memory_linked_document_summaries(
|
||||
self, user_id: str, memory_type: str = "PublicMemory"
|
||||
):
|
||||
"""
|
||||
Retrieve a list of summaries for all documents associated with a given memory type for a user.
|
||||
|
||||
|
|
@ -474,23 +556,30 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
elif memory_type == "SemanticMemory":
|
||||
relationship = "HAS_SEMANTIC_MEMORY"
|
||||
try:
|
||||
query = f'''
|
||||
query = f"""
|
||||
MATCH (user:User {{userId: '{user_id}'}})-[:{relationship}]->(memory:{memory_type})-[:HAS_DOCUMENT]->(document:Document)
|
||||
RETURN document.d_id AS d_id, document.summary AS summary
|
||||
'''
|
||||
"""
|
||||
logging.info(f"Generated Cypher query: {query}")
|
||||
result = self.query(query)
|
||||
logging.info(f"Result: {result}")
|
||||
return [{"d_id": record.get("d_id", None), "summary": record.get("summary", "No summary available")} for
|
||||
record in result]
|
||||
return [
|
||||
{
|
||||
"d_id": record.get("d_id", None),
|
||||
"summary": record.get("summary", "No summary available"),
|
||||
}
|
||||
for record in result
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred while retrieving document summary: {str(e)}")
|
||||
logging.error(
|
||||
f"An error occurred while retrieving document summary: {str(e)}"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
|
||||
async def get_memory_linked_document_ids(self, user_id: str, summary_id: str, memory_type: str = "PublicMemory"):
|
||||
async def get_memory_linked_document_ids(
|
||||
self, user_id: str, summary_id: str, memory_type: str = "PublicMemory"
|
||||
):
|
||||
"""
|
||||
Retrieve a list of document IDs for a specific category associated with a given memory type for a user.
|
||||
|
||||
|
|
@ -511,11 +600,11 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
elif memory_type == "SemanticMemory":
|
||||
relationship = "HAS_SEMANTIC_MEMORY"
|
||||
try:
|
||||
query = f'''
|
||||
query = f"""
|
||||
MATCH (user:User {{userId: '{user_id}'}})-[:{relationship}]->(memory:{memory_type})-[:HAS_DOCUMENT]->(document:Document)
|
||||
WHERE document.d_id = '{summary_id}'
|
||||
RETURN document.d_id AS d_id
|
||||
'''
|
||||
"""
|
||||
logging.info(f"Generated Cypher query: {query}")
|
||||
result = self.query(query)
|
||||
return [record["d_id"] for record in result]
|
||||
|
|
@ -523,9 +612,13 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
logging.error(f"An error occurred while retrieving document IDs: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
def create_document_node_cypher(self, document_summary: dict, user_id: str,
|
||||
memory_type: str = "PublicMemory",public_memory_id:str=None) -> str:
|
||||
def create_document_node_cypher(
|
||||
self,
|
||||
document_summary: dict,
|
||||
user_id: str,
|
||||
memory_type: str = "PublicMemory",
|
||||
public_memory_id: str = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a Cypher query to create a Document node. If the memory type is 'Semantic',
|
||||
link it to a SemanticMemory node for a user. If the memory type is 'PublicMemory',
|
||||
|
|
@ -546,37 +639,46 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
# Validate the input parameters
|
||||
if not isinstance(document_summary, dict):
|
||||
raise ValueError("The document_summary must be a dictionary.")
|
||||
if not all(key in document_summary for key in ['DocumentCategory', 'Title', 'Summary', 'd_id']):
|
||||
raise ValueError("The document_summary dictionary is missing required keys.")
|
||||
if not all(
|
||||
key in document_summary
|
||||
for key in ["DocumentCategory", "Title", "Summary", "d_id"]
|
||||
):
|
||||
raise ValueError(
|
||||
"The document_summary dictionary is missing required keys."
|
||||
)
|
||||
if not isinstance(user_id, str) or not user_id:
|
||||
raise ValueError("The user_id must be a non-empty string.")
|
||||
if memory_type not in ["SemanticMemory", "PublicMemory"]:
|
||||
raise ValueError("The memory_type must be either 'Semantic' or 'PublicMemory'.")
|
||||
raise ValueError(
|
||||
"The memory_type must be either 'Semantic' or 'PublicMemory'."
|
||||
)
|
||||
|
||||
# Escape single quotes in the document summary data
|
||||
title = document_summary['Title'].replace("'", "\\'")
|
||||
summary = document_summary['Summary'].replace("'", "\\'")
|
||||
document_category = document_summary['DocumentCategory'].replace("'", "\\'")
|
||||
d_id = document_summary['d_id'].replace("'", "\\'")
|
||||
title = document_summary["Title"].replace("'", "\\'")
|
||||
summary = document_summary["Summary"].replace("'", "\\'")
|
||||
document_category = document_summary["DocumentCategory"].replace("'", "\\'")
|
||||
d_id = document_summary["d_id"].replace("'", "\\'")
|
||||
|
||||
memory_node_type = "SemanticMemory" if memory_type == "SemanticMemory" else "PublicMemory"
|
||||
memory_node_type = (
|
||||
"SemanticMemory" if memory_type == "SemanticMemory" else "PublicMemory"
|
||||
)
|
||||
|
||||
user_memory_link = ''
|
||||
user_memory_link = ""
|
||||
if memory_type == "SemanticMemory":
|
||||
user_memory_link = f'''
|
||||
user_memory_link = f"""
|
||||
// Ensure the User node exists
|
||||
MERGE (user:User {{ userId: '{user_id}' }})
|
||||
MERGE (memory:SemanticMemory {{ userId: '{user_id}' }})
|
||||
MERGE (user)-[:HAS_SEMANTIC_MEMORY]->(memory)
|
||||
'''
|
||||
"""
|
||||
elif memory_type == "PublicMemory":
|
||||
logging.info(f"Public memory id: {public_memory_id}")
|
||||
user_memory_link = f'''
|
||||
user_memory_link = f"""
|
||||
// Merge with the existing PublicMemory node or create a new one if it does not exist
|
||||
MATCH (memory:PublicMemory {{ memoryId: {public_memory_id} }})
|
||||
'''
|
||||
"""
|
||||
|
||||
cypher_query = f'''
|
||||
cypher_query = f"""
|
||||
{user_memory_link}
|
||||
|
||||
// Create the Document node with its properties
|
||||
|
|
@ -590,13 +692,15 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
|
||||
// Link the Document node to the {memory_node_type} node
|
||||
MERGE (memory)-[:HAS_DOCUMENT]->(document)
|
||||
'''
|
||||
"""
|
||||
|
||||
logging.info(f"Generated Cypher query: {cypher_query}")
|
||||
|
||||
return cypher_query
|
||||
|
||||
def update_document_node_with_db_ids(self, vectordb_namespace: str, document_id: str, user_id: str = None):
|
||||
def update_document_node_with_db_ids(
|
||||
self, vectordb_namespace: str, document_id: str, user_id: str = None
|
||||
):
|
||||
"""
|
||||
Update the namespace of a Document node in the database. The document can be linked
|
||||
either to a SemanticMemory node (if a user ID is provided) or to a PublicMemory node.
|
||||
|
|
@ -612,23 +716,24 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
|
||||
if user_id:
|
||||
# Update for a document linked to a SemanticMemory node
|
||||
cypher_query = f'''
|
||||
cypher_query = f"""
|
||||
MATCH (user:User {{userId: '{user_id}' }})-[:HAS_SEMANTIC_MEMORY]->(:SemanticMemory)-[:HAS_DOCUMENT]->(document:Document {{d_id: '{document_id}'}})
|
||||
SET document.vectordbNamespace = '{vectordb_namespace}'
|
||||
RETURN document
|
||||
'''
|
||||
"""
|
||||
else:
|
||||
# Update for a document linked to a PublicMemory node
|
||||
cypher_query = f'''
|
||||
cypher_query = f"""
|
||||
MATCH (:PublicMemory)-[:HAS_DOCUMENT]->(document:Document {{d_id: '{document_id}'}})
|
||||
SET document.vectordbNamespace = '{vectordb_namespace}'
|
||||
RETURN document
|
||||
'''
|
||||
"""
|
||||
|
||||
return cypher_query
|
||||
|
||||
def run_merge_query(self, user_id: str, memory_type: str,
|
||||
similarity_threshold: float) -> str:
|
||||
def run_merge_query(
|
||||
self, user_id: str, memory_type: str, similarity_threshold: float
|
||||
) -> str:
|
||||
"""
|
||||
Constructs a Cypher query to merge nodes in a Neo4j database based on a similarity threshold.
|
||||
|
||||
|
|
@ -645,29 +750,28 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
Returns:
|
||||
str: A Cypher query string that can be executed in a Neo4j session.
|
||||
"""
|
||||
if memory_type == 'SemanticMemory':
|
||||
relationship_base = 'HAS_SEMANTIC_MEMORY'
|
||||
relationship_type = 'HAS_KNOWLEDGE'
|
||||
memory_label = 'semantic'
|
||||
elif memory_type == 'EpisodicMemory':
|
||||
relationship_base = 'HAS_EPISODIC_MEMORY'
|
||||
if memory_type == "SemanticMemory":
|
||||
relationship_base = "HAS_SEMANTIC_MEMORY"
|
||||
relationship_type = "HAS_KNOWLEDGE"
|
||||
memory_label = "semantic"
|
||||
elif memory_type == "EpisodicMemory":
|
||||
relationship_base = "HAS_EPISODIC_MEMORY"
|
||||
# relationship_type = 'EPISODIC_MEMORY'
|
||||
relationship_type = 'HAS_EVENT'
|
||||
memory_label='episodic'
|
||||
elif memory_type == 'Buffer':
|
||||
relationship_base = 'HAS_BUFFER_MEMORY'
|
||||
relationship_type = 'CURRENTLY_HOLDING'
|
||||
memory_label= 'buffer'
|
||||
relationship_type = "HAS_EVENT"
|
||||
memory_label = "episodic"
|
||||
elif memory_type == "Buffer":
|
||||
relationship_base = "HAS_BUFFER_MEMORY"
|
||||
relationship_type = "CURRENTLY_HOLDING"
|
||||
memory_label = "buffer"
|
||||
|
||||
|
||||
query= f"""MATCH (u:User {{userId: '{user_id}'}})-[:{relationship_base}]->(sm:{memory_type})
|
||||
query = f"""MATCH (u:User {{userId: '{user_id}'}})-[:{relationship_base}]->(sm:{memory_type})
|
||||
MATCH (sm)-[:{relationship_type}]->(n)
|
||||
RETURN labels(n) AS NodeType, collect(n) AS Nodes
|
||||
"""
|
||||
|
||||
node_results = self.query(query)
|
||||
|
||||
node_types = [record['NodeType'] for record in node_results]
|
||||
node_types = [record["NodeType"] for record in node_results]
|
||||
|
||||
for node in node_types:
|
||||
query = f"""
|
||||
|
|
@ -703,16 +807,18 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
- Exception: If an error occurs during the database query execution.
|
||||
"""
|
||||
try:
|
||||
query = f'''
|
||||
query = f"""
|
||||
MATCH (user:User {{userId: '{user_id}'}})-[:HAS_SEMANTIC_MEMORY]->(semantic:SemanticMemory)-[:HAS_DOCUMENT]->(document:Document)
|
||||
WHERE document.documentCategory = '{category}'
|
||||
RETURN document.vectordbNamespace AS namespace
|
||||
'''
|
||||
"""
|
||||
result = self.query(query)
|
||||
namespaces = [record["namespace"] for record in result]
|
||||
return namespaces
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred while retrieving namespaces by document category: {str(e)}")
|
||||
logging.error(
|
||||
f"An error occurred while retrieving namespaces by document category: {str(e)}"
|
||||
)
|
||||
return None
|
||||
|
||||
async def create_memory_node(self, labels, topic=None):
|
||||
|
|
@ -734,7 +840,7 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
topic = "PublicMemory"
|
||||
|
||||
# Prepare labels as a string
|
||||
label_list = ', '.join(f"'{label}'" for label in labels)
|
||||
label_list = ", ".join(f"'{label}'" for label in labels)
|
||||
|
||||
# Cypher query to find or create the memory node with the given description and labels
|
||||
memory_cypher = f"""
|
||||
|
|
@ -746,17 +852,24 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
try:
|
||||
result = self.query(memory_cypher)
|
||||
# Assuming the result is a list of records, where each record contains 'memoryId'
|
||||
memory_id = result[0]['memoryId'] if result else None
|
||||
memory_id = result[0]["memoryId"] if result else None
|
||||
self.close()
|
||||
return memory_id
|
||||
except Neo4jError as e:
|
||||
logging.error(f"Error creating or finding memory node: {e}")
|
||||
raise
|
||||
|
||||
def link_user_to_public(self, user_id: str, public_property_value: str, public_property_name: str = 'name',
|
||||
relationship_type: str = 'HAS_PUBLIC'):
|
||||
def link_user_to_public(
|
||||
self,
|
||||
user_id: str,
|
||||
public_property_value: str,
|
||||
public_property_name: str = "name",
|
||||
relationship_type: str = "HAS_PUBLIC",
|
||||
):
|
||||
if not user_id or not public_property_value:
|
||||
raise ValueError("Valid User ID and Public property value are required for linking.")
|
||||
raise ValueError(
|
||||
"Valid User ID and Public property value are required for linking."
|
||||
)
|
||||
|
||||
try:
|
||||
link_cypher = f"""
|
||||
|
|
@ -784,7 +897,9 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
logging.error(f"Error deleting {topic} memory node: {e}")
|
||||
raise
|
||||
|
||||
def unlink_memory_from_user(self, memory_id: int, user_id: str, topic: str='PublicMemory') -> None:
|
||||
def unlink_memory_from_user(
|
||||
self, memory_id: int, user_id: str, topic: str = "PublicMemory"
|
||||
) -> None:
|
||||
"""
|
||||
Unlink a memory node from a user node.
|
||||
|
||||
|
|
@ -801,9 +916,13 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
raise ValueError("Valid User ID and Memory ID are required for unlinking.")
|
||||
|
||||
if topic not in ["SemanticMemory", "PublicMemory"]:
|
||||
raise ValueError("The memory_type must be either 'SemanticMemory' or 'PublicMemory'.")
|
||||
raise ValueError(
|
||||
"The memory_type must be either 'SemanticMemory' or 'PublicMemory'."
|
||||
)
|
||||
|
||||
relationship_type = "HAS_SEMANTIC_MEMORY" if topic == "SemanticMemory" else "HAS_PUBLIC_MEMORY"
|
||||
relationship_type = (
|
||||
"HAS_SEMANTIC_MEMORY" if topic == "SemanticMemory" else "HAS_PUBLIC_MEMORY"
|
||||
)
|
||||
|
||||
try:
|
||||
unlink_cypher = f"""
|
||||
|
|
@ -815,7 +934,6 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
logging.error(f"Error unlinking {topic} from user: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def link_public_memory_to_user(self, memory_id, user_id):
|
||||
# Link an existing Public Memory node to a User node
|
||||
link_cypher = f"""
|
||||
|
|
@ -825,7 +943,7 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
"""
|
||||
self.query(link_cypher)
|
||||
|
||||
def retrieve_node_id_for_memory_type(self, topic: str = 'SemanticMemory'):
|
||||
def retrieve_node_id_for_memory_type(self, topic: str = "SemanticMemory"):
|
||||
link_cypher = f""" MATCH(publicMemory: {topic})
|
||||
RETURN
|
||||
id(publicMemory)
|
||||
|
|
@ -835,18 +953,14 @@ class Neo4jGraphDB(AbstractGraphDB):
|
|||
return node_ids
|
||||
|
||||
|
||||
|
||||
|
||||
from .networkx_graph import NetworkXGraphDB
|
||||
|
||||
|
||||
class GraphDBFactory:
|
||||
def create_graph_db(self, db_type, **kwargs):
|
||||
if db_type == 'neo4j':
|
||||
if db_type == "neo4j":
|
||||
return Neo4jGraphDB(**kwargs)
|
||||
elif db_type == 'networkx':
|
||||
elif db_type == "networkx":
|
||||
return NetworkXGraphDB(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unsupported database type: {db_type}")
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import networkx as nx
|
|||
|
||||
|
||||
class NetworkXGraphDB:
|
||||
def __init__(self, filename='networkx_graph.pkl'):
|
||||
def __init__(self, filename="networkx_graph.pkl"):
|
||||
self.filename = filename
|
||||
try:
|
||||
self.graph = self.load_graph() # Attempt to load an existing graph
|
||||
|
|
@ -12,32 +12,36 @@ class NetworkXGraphDB:
|
|||
self.graph = nx.Graph() # Create a new graph if loading failed
|
||||
|
||||
def save_graph(self):
|
||||
""" Save the graph to a file using pickle """
|
||||
with open(self.filename, 'wb') as f:
|
||||
"""Save the graph to a file using pickle"""
|
||||
with open(self.filename, "wb") as f:
|
||||
pickle.dump(self.graph, f)
|
||||
|
||||
def load_graph(self):
|
||||
""" Load the graph from a file using pickle """
|
||||
with open(self.filename, 'rb') as f:
|
||||
"""Load the graph from a file using pickle"""
|
||||
with open(self.filename, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
def create_base_cognitive_architecture(self, user_id: str):
|
||||
# Add nodes for user and memory types if they don't exist
|
||||
self.graph.add_node(user_id, type='User')
|
||||
self.graph.add_node(f"{user_id}_semantic", type='SemanticMemory')
|
||||
self.graph.add_node(f"{user_id}_episodic", type='EpisodicMemory')
|
||||
self.graph.add_node(f"{user_id}_buffer", type='Buffer')
|
||||
self.graph.add_node(user_id, type="User")
|
||||
self.graph.add_node(f"{user_id}_semantic", type="SemanticMemory")
|
||||
self.graph.add_node(f"{user_id}_episodic", type="EpisodicMemory")
|
||||
self.graph.add_node(f"{user_id}_buffer", type="Buffer")
|
||||
|
||||
# Add edges to connect user to memory types
|
||||
self.graph.add_edge(user_id, f"{user_id}_semantic", relation='HAS_SEMANTIC_MEMORY')
|
||||
self.graph.add_edge(user_id, f"{user_id}_episodic", relation='HAS_EPISODIC_MEMORY')
|
||||
self.graph.add_edge(user_id, f"{user_id}_buffer", relation='HAS_BUFFER')
|
||||
self.graph.add_edge(
|
||||
user_id, f"{user_id}_semantic", relation="HAS_SEMANTIC_MEMORY"
|
||||
)
|
||||
self.graph.add_edge(
|
||||
user_id, f"{user_id}_episodic", relation="HAS_EPISODIC_MEMORY"
|
||||
)
|
||||
self.graph.add_edge(user_id, f"{user_id}_buffer", relation="HAS_BUFFER")
|
||||
|
||||
self.save_graph() # Save the graph after modifying it
|
||||
|
||||
def delete_all_user_memories(self, user_id: str):
|
||||
# Remove nodes and edges related to the user's memories
|
||||
for memory_type in ['semantic', 'episodic', 'buffer']:
|
||||
for memory_type in ["semantic", "episodic", "buffer"]:
|
||||
memory_node = f"{user_id}_{memory_type}"
|
||||
self.graph.remove_node(memory_node)
|
||||
|
||||
|
|
@ -60,31 +64,59 @@ class NetworkXGraphDB:
|
|||
def retrieve_buffer_memory(self, user_id: str):
|
||||
return [n for n in self.graph.neighbors(f"{user_id}_buffer")]
|
||||
|
||||
def generate_graph_semantic_memory_document_summary(self, document_summary, unique_graphdb_mapping_values, document_namespace, user_id):
|
||||
def generate_graph_semantic_memory_document_summary(
|
||||
self,
|
||||
document_summary,
|
||||
unique_graphdb_mapping_values,
|
||||
document_namespace,
|
||||
user_id,
|
||||
):
|
||||
for node, attributes in unique_graphdb_mapping_values.items():
|
||||
self.graph.add_node(node, **attributes)
|
||||
self.graph.add_edge(f"{user_id}_semantic", node, relation='HAS_KNOWLEDGE')
|
||||
self.graph.add_edge(f"{user_id}_semantic", node, relation="HAS_KNOWLEDGE")
|
||||
self.save_graph()
|
||||
|
||||
def generate_document_summary(self, document_summary, unique_graphdb_mapping_values, document_namespace, user_id):
|
||||
self.generate_graph_semantic_memory_document_summary(document_summary, unique_graphdb_mapping_values, document_namespace, user_id)
|
||||
def generate_document_summary(
|
||||
self,
|
||||
document_summary,
|
||||
unique_graphdb_mapping_values,
|
||||
document_namespace,
|
||||
user_id,
|
||||
):
|
||||
self.generate_graph_semantic_memory_document_summary(
|
||||
document_summary, unique_graphdb_mapping_values, document_namespace, user_id
|
||||
)
|
||||
|
||||
async def get_document_categories(self, user_id):
|
||||
return [self.graph.nodes[n]['category'] for n in self.graph.neighbors(f"{user_id}_semantic") if 'category' in self.graph.nodes[n]]
|
||||
return [
|
||||
self.graph.nodes[n]["category"]
|
||||
for n in self.graph.neighbors(f"{user_id}_semantic")
|
||||
if "category" in self.graph.nodes[n]
|
||||
]
|
||||
|
||||
async def get_document_ids(self, user_id, category):
|
||||
return [n for n in self.graph.neighbors(f"{user_id}_semantic") if self.graph.nodes[n].get('category') == category]
|
||||
return [
|
||||
n
|
||||
for n in self.graph.neighbors(f"{user_id}_semantic")
|
||||
if self.graph.nodes[n].get("category") == category
|
||||
]
|
||||
|
||||
def create_document_node(self, document_summary, user_id):
|
||||
d_id = document_summary['d_id']
|
||||
d_id = document_summary["d_id"]
|
||||
self.graph.add_node(d_id, **document_summary)
|
||||
self.graph.add_edge(f"{user_id}_semantic", d_id, relation='HAS_DOCUMENT')
|
||||
self.graph.add_edge(f"{user_id}_semantic", d_id, relation="HAS_DOCUMENT")
|
||||
self.save_graph()
|
||||
|
||||
def update_document_node_with_namespace(self, user_id, vectordb_namespace, document_id):
|
||||
def update_document_node_with_namespace(
|
||||
self, user_id, vectordb_namespace, document_id
|
||||
):
|
||||
if self.graph.has_node(document_id):
|
||||
self.graph.nodes[document_id]['vectordbNamespace'] = vectordb_namespace
|
||||
self.graph.nodes[document_id]["vectordbNamespace"] = vectordb_namespace
|
||||
self.save_graph()
|
||||
|
||||
def get_namespaces_by_document_category(self, user_id, category):
|
||||
return [self.graph.nodes[n].get('vectordbNamespace') for n in self.graph.neighbors(f"{user_id}_semantic") if self.graph.nodes[n].get('category') == category]
|
||||
return [
|
||||
self.graph.nodes[n].get("vectordbNamespace")
|
||||
for n in self.graph.neighbors(f"{user_id}_semantic")
|
||||
if self.graph.nodes[n].get("category") == category
|
||||
]
|
||||
|
|
|
|||
|
|
@ -31,36 +31,47 @@ import os
|
|||
|
||||
|
||||
class DatabaseConfig:
|
||||
def __init__(self, db_type=None, db_name=None, host=None, user=None, password=None, port=None, config_file=None):
|
||||
def __init__(
|
||||
self,
|
||||
db_type=None,
|
||||
db_name=None,
|
||||
host=None,
|
||||
user=None,
|
||||
password=None,
|
||||
port=None,
|
||||
config_file=None,
|
||||
):
|
||||
if config_file:
|
||||
self.load_from_file(config_file)
|
||||
else:
|
||||
# Load default values from environment variables or use provided values
|
||||
self.db_type = db_type or os.getenv('DB_TYPE', 'sqlite')
|
||||
self.db_name = db_name or os.getenv('DB_NAME', 'database.db')
|
||||
self.host = host or os.getenv('DB_HOST', 'localhost')
|
||||
self.user = user or os.getenv('DB_USER', 'user')
|
||||
self.password = password or os.getenv('DB_PASSWORD', 'password')
|
||||
self.port = port or os.getenv('DB_PORT', '5432')
|
||||
self.db_type = db_type or os.getenv("DB_TYPE", "sqlite")
|
||||
self.db_name = db_name or os.getenv("DB_NAME", "database.db")
|
||||
self.host = host or os.getenv("DB_HOST", "localhost")
|
||||
self.user = user or os.getenv("DB_USER", "user")
|
||||
self.password = password or os.getenv("DB_PASSWORD", "password")
|
||||
self.port = port or os.getenv("DB_PORT", "5432")
|
||||
|
||||
def load_from_file(self, file_path):
|
||||
with open(file_path, 'r') as file:
|
||||
with open(file_path, "r") as file:
|
||||
config = json.load(file)
|
||||
self.db_type = config.get('db_type', 'sqlite')
|
||||
self.db_name = config.get('db_name', 'database.db')
|
||||
self.host = config.get('host', 'localhost')
|
||||
self.user = config.get('user', 'user')
|
||||
self.password = config.get('password', 'password')
|
||||
self.port = config.get('port', '5432')
|
||||
self.db_type = config.get("db_type", "sqlite")
|
||||
self.db_name = config.get("db_name", "database.db")
|
||||
self.host = config.get("host", "localhost")
|
||||
self.user = config.get("user", "user")
|
||||
self.password = config.get("password", "password")
|
||||
self.port = config.get("port", "5432")
|
||||
|
||||
def get_sqlalchemy_database_url(self):
|
||||
if self.db_type == 'sqlite':
|
||||
if self.db_type == "sqlite":
|
||||
db_path = Path(self.db_name).absolute() # Ensure the path is absolute
|
||||
return f"sqlite+aiosqlite:///{db_path}" # SQLite uses file path
|
||||
elif self.db_type == 'duckdb':
|
||||
db_path = Path(self.db_name).absolute() # Ensure the path is absolute for DuckDB as well
|
||||
elif self.db_type == "duckdb":
|
||||
db_path = Path(
|
||||
self.db_name
|
||||
).absolute() # Ensure the path is absolute for DuckDB as well
|
||||
return f"duckdb+aiosqlite:///{db_path}"
|
||||
elif self.db_type == 'postgresql':
|
||||
elif self.db_type == "postgresql":
|
||||
# Ensure optional parameters are handled gracefully
|
||||
port_str = f":{self.port}" if self.port else ""
|
||||
password_str = f":{self.password}" if self.password else ""
|
||||
|
|
@ -68,10 +79,18 @@ class DatabaseConfig:
|
|||
else:
|
||||
raise ValueError(f"Unsupported DB_TYPE: {self.db_type}")
|
||||
|
||||
|
||||
# Example usage with a configuration file:
|
||||
# config = DatabaseConfig(config_file='path/to/config.json')
|
||||
# Or set them programmatically:
|
||||
config = DatabaseConfig(db_type='postgresql', db_name='mydatabase', user='myuser', password='mypassword', host='myhost', port='5432')
|
||||
config = DatabaseConfig(
|
||||
db_type="postgresql",
|
||||
db_name="mydatabase",
|
||||
user="myuser",
|
||||
password="mypassword",
|
||||
host="myhost",
|
||||
port="5432",
|
||||
)
|
||||
|
||||
SQLALCHEMY_DATABASE_URL = config.get_sqlalchemy_database_url()
|
||||
|
||||
|
|
@ -79,7 +98,7 @@ SQLALCHEMY_DATABASE_URL = config.get_sqlalchemy_database_url()
|
|||
engine = create_async_engine(
|
||||
SQLALCHEMY_DATABASE_URL,
|
||||
pool_recycle=3600,
|
||||
echo=True # Enable logging for tutorial purposes
|
||||
echo=True, # Enable logging for tutorial purposes
|
||||
)
|
||||
# Use AsyncSession for the session
|
||||
AsyncSessionLocal = sessionmaker(
|
||||
|
|
@ -90,6 +109,7 @@ AsyncSessionLocal = sessionmaker(
|
|||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
# Use asynccontextmanager to define an async context manager
|
||||
@asynccontextmanager
|
||||
async def get_db():
|
||||
|
|
@ -99,6 +119,7 @@ async def get_db():
|
|||
finally:
|
||||
await db.close()
|
||||
|
||||
|
||||
#
|
||||
# if os.environ.get('AWS_ENV') == 'prd' or os.environ.get('AWS_ENV') == 'dev':
|
||||
# host = os.environ.get('POSTGRES_HOST')
|
||||
|
|
@ -127,4 +148,3 @@ async def get_db():
|
|||
#
|
||||
# # Use the asyncpg driver for async operation
|
||||
# SQLALCHEMY_DATABASE_URL = f"postgresql+asyncpg://{username}:{password}@{host}:5432/{database_name}"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
from .models.sessions import Session
|
||||
|
|
@ -9,9 +8,9 @@ from .models.metadatas import MetaDatas
|
|||
from .models.docs import DocsModel
|
||||
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def session_scope(session):
|
||||
"""Provide a transactional scope around a series of operations."""
|
||||
|
|
@ -44,6 +43,8 @@ def update_entity_graph_summary(session, model, entity_id, new_value):
|
|||
return "Successfully updated entity"
|
||||
else:
|
||||
return "Entity not found"
|
||||
|
||||
|
||||
async def update_entity(session, model, entity_id, new_value):
|
||||
async with session_scope(session) as s:
|
||||
# Retrieve the entity from the database
|
||||
|
|
|
|||
|
|
@ -1,20 +1,20 @@
|
|||
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, String, DateTime, ForeignKey, Boolean
|
||||
from sqlalchemy.orm import relationship
|
||||
import os
|
||||
import sys
|
||||
from ..database import Base
|
||||
from ..database import Base
|
||||
|
||||
|
||||
class DocsModel(Base):
|
||||
__tablename__ = 'docs'
|
||||
__tablename__ = "docs"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
operation_id = Column(String, ForeignKey('operations.id'), index=True)
|
||||
operation_id = Column(String, ForeignKey("operations.id"), index=True)
|
||||
doc_name = Column(String, nullable=True)
|
||||
graph_summary = Column(Boolean, nullable=True)
|
||||
memory_category = Column(String, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, onupdate=datetime.utcnow)
|
||||
|
||||
|
||||
operations = relationship("Operation", back_populates="docs")
|
||||
operations = relationship("Operation", back_populates="docs")
|
||||
|
|
|
|||
|
|
@ -4,23 +4,27 @@ from sqlalchemy import Column, String, DateTime, ForeignKey
|
|||
from sqlalchemy.orm import relationship
|
||||
import os
|
||||
import sys
|
||||
from ..database import Base
|
||||
from ..database import Base
|
||||
|
||||
|
||||
class MemoryModel(Base):
|
||||
__tablename__ = 'memories'
|
||||
__tablename__ = "memories"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String, ForeignKey('users.id'), index=True)
|
||||
operation_id = Column(String, ForeignKey('operations.id'), index=True)
|
||||
user_id = Column(String, ForeignKey("users.id"), index=True)
|
||||
operation_id = Column(String, ForeignKey("operations.id"), index=True)
|
||||
memory_name = Column(String, nullable=True)
|
||||
memory_category = Column(String, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, onupdate=datetime.utcnow)
|
||||
methods_list = Column(String , nullable=True)
|
||||
methods_list = Column(String, nullable=True)
|
||||
attributes_list = Column(String, nullable=True)
|
||||
|
||||
user = relationship("User", back_populates="memories")
|
||||
operation = relationship("Operation", back_populates="memories")
|
||||
metadatas = relationship("MetaDatas", back_populates="memory", cascade="all, delete-orphan")
|
||||
metadatas = relationship(
|
||||
"MetaDatas", back_populates="memory", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Memory(id={self.id}, user_id={self.user_id}, created_at={self.created_at}, updated_at={self.updated_at})>"
|
||||
|
|
|
|||
|
|
@ -4,17 +4,19 @@ from sqlalchemy import Column, String, DateTime, ForeignKey
|
|||
from sqlalchemy.orm import relationship
|
||||
import os
|
||||
import sys
|
||||
from ..database import Base
|
||||
from ..database import Base
|
||||
|
||||
|
||||
class MetaDatas(Base):
|
||||
__tablename__ = 'metadatas'
|
||||
__tablename__ = "metadatas"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String, ForeignKey('users.id'), index=True)
|
||||
user_id = Column(String, ForeignKey("users.id"), index=True)
|
||||
version = Column(String, nullable=False)
|
||||
contract_metadata = Column(String, nullable=False)
|
||||
memory_id = Column(String, ForeignKey('memories.id'), index=True)
|
||||
memory_id = Column(String, ForeignKey("memories.id"), index=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, onupdate=datetime.utcnow)
|
||||
updated_at = Column(DateTime, onupdate=datetime.utcnow)
|
||||
|
||||
user = relationship("User", back_populates="metadatas")
|
||||
memory = relationship("MemoryModel", back_populates="metadatas")
|
||||
|
|
|
|||
|
|
@ -4,13 +4,14 @@ from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
|||
from sqlalchemy.orm import relationship
|
||||
import os
|
||||
import sys
|
||||
from ..database import Base
|
||||
from ..database import Base
|
||||
|
||||
|
||||
class Operation(Base):
|
||||
__tablename__ = 'operations'
|
||||
__tablename__ = "operations"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String, ForeignKey('users.id'), index=True) # Link to User
|
||||
user_id = Column(String, ForeignKey("users.id"), index=True) # Link to User
|
||||
operation_type = Column(String, nullable=True)
|
||||
operation_status = Column(String, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
|
|
|||
|
|
@ -9,10 +9,10 @@ from ..database import Base
|
|||
|
||||
|
||||
class Session(Base):
|
||||
__tablename__ = 'sessions'
|
||||
__tablename__ = "sessions"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String, ForeignKey('users.id'), index=True)
|
||||
user_id = Column(String, ForeignKey("users.id"), index=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, onupdate=datetime.utcnow)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,17 +4,17 @@ from sqlalchemy import Column, String, DateTime
|
|||
from sqlalchemy.orm import relationship
|
||||
import os
|
||||
import sys
|
||||
from .memory import MemoryModel
|
||||
from .memory import MemoryModel
|
||||
from .operation import Operation
|
||||
from .sessions import Session
|
||||
from .metadatas import MetaDatas
|
||||
from .docs import DocsModel
|
||||
|
||||
from ..database import Base
|
||||
from ..database import Base
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'users'
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
session_id = Column(String, nullable=True, unique=True)
|
||||
|
|
@ -22,9 +22,15 @@ class User(Base):
|
|||
updated_at = Column(DateTime, onupdate=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
memories = relationship("MemoryModel", back_populates="user", cascade="all, delete-orphan")
|
||||
operations = relationship("Operation", back_populates="user", cascade="all, delete-orphan")
|
||||
sessions = relationship("Session", back_populates="user", cascade="all, delete-orphan")
|
||||
memories = relationship(
|
||||
"MemoryModel", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
operations = relationship(
|
||||
"Operation", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
sessions = relationship(
|
||||
"Session", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
metadatas = relationship("MetaDatas", back_populates="user")
|
||||
|
||||
def __repr__(self):
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
|
||||
import logging
|
||||
from io import BytesIO
|
||||
import os, sys
|
||||
|
||||
# Add the parent directory to sys.path
|
||||
# sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
print(os.getcwd())
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
# import marvin
|
||||
|
|
@ -22,6 +23,7 @@ from cognitive_architecture.database.relationaldb.models.operation import Operat
|
|||
from cognitive_architecture.database.relationaldb.models.docs import DocsModel
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from cognitive_architecture.database.relationaldb.database import engine
|
||||
|
||||
load_dotenv()
|
||||
from typing import Optional
|
||||
import time
|
||||
|
|
@ -31,7 +33,11 @@ tracemalloc.start()
|
|||
|
||||
from datetime import datetime
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from cognitive_architecture.database.vectordb.vectordb import PineconeVectorDB, WeaviateVectorDB, LanceDB
|
||||
from cognitive_architecture.database.vectordb.vectordb import (
|
||||
PineconeVectorDB,
|
||||
WeaviateVectorDB,
|
||||
LanceDB,
|
||||
)
|
||||
from langchain.schema import Document
|
||||
import uuid
|
||||
import weaviate
|
||||
|
|
@ -43,6 +49,7 @@ from vector_db_type import VectorDBType
|
|||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
||||
# marvin.settings.openai.api_key = os.environ.get("OPENAI_API_KEY")
|
||||
|
||||
|
||||
class VectorDBFactory:
|
||||
def __init__(self):
|
||||
self.db_map = {
|
||||
|
|
@ -63,15 +70,12 @@ class VectorDBFactory:
|
|||
):
|
||||
if db_type in self.db_map:
|
||||
return self.db_map[db_type](
|
||||
user_id,
|
||||
index_name,
|
||||
memory_id,
|
||||
namespace,
|
||||
embeddings
|
||||
user_id, index_name, memory_id, namespace, embeddings
|
||||
)
|
||||
|
||||
raise ValueError(f"Unsupported database type: {db_type}")
|
||||
|
||||
|
||||
class BaseMemory:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -95,21 +99,18 @@ class BaseMemory:
|
|||
self.memory_id,
|
||||
db_type=self.db_type,
|
||||
namespace=self.namespace,
|
||||
embeddings=self.embeddings
|
||||
embeddings=self.embeddings,
|
||||
)
|
||||
|
||||
def init_client(self, embeddings, namespace: str):
|
||||
return self.vector_db.init_client(embeddings, namespace)
|
||||
|
||||
|
||||
|
||||
def create_field(self, field_type, **kwargs):
|
||||
field_mapping = {
|
||||
"Str": fields.Str,
|
||||
"Int": fields.Int,
|
||||
"Float": fields.Float,
|
||||
"Bool": fields.Bool,
|
||||
|
||||
}
|
||||
return field_mapping[field_type](**kwargs)
|
||||
|
||||
|
|
@ -121,7 +122,6 @@ class BaseMemory:
|
|||
dynamic_schema_instance = Schema.from_dict(dynamic_fields)()
|
||||
return dynamic_schema_instance
|
||||
|
||||
|
||||
async def get_version_from_db(self, user_id, memory_id):
|
||||
# Logic to retrieve the version from the database.
|
||||
|
||||
|
|
@ -137,11 +137,11 @@ class BaseMemory:
|
|||
)
|
||||
|
||||
if result:
|
||||
|
||||
version_in_db, created_at = result
|
||||
logging.info(f"version_in_db: {version_in_db}")
|
||||
from ast import literal_eval
|
||||
version_in_db= literal_eval(version_in_db)
|
||||
|
||||
version_in_db = literal_eval(version_in_db)
|
||||
version_in_db = version_in_db.get("version")
|
||||
return [version_in_db, created_at]
|
||||
else:
|
||||
|
|
@ -157,20 +157,33 @@ class BaseMemory:
|
|||
|
||||
# If there is no metadata, insert it.
|
||||
if version_from_db is None:
|
||||
|
||||
session.add(MetaDatas(id = str(uuid.uuid4()), user_id=self.user_id, version = str(int(time.time())) ,memory_id=self.memory_id, contract_metadata=params))
|
||||
session.add(
|
||||
MetaDatas(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=self.user_id,
|
||||
version=str(int(time.time())),
|
||||
memory_id=self.memory_id,
|
||||
contract_metadata=params,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
return params
|
||||
|
||||
# If params version is higher, update the metadata.
|
||||
elif version_in_params > version_from_db[0]:
|
||||
session.add(MetaDatas(id = str(uuid.uuid4()), user_id=self.user_id, memory_id=self.memory_id, contract_metadata=params))
|
||||
session.add(
|
||||
MetaDatas(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=self.user_id,
|
||||
memory_id=self.memory_id,
|
||||
contract_metadata=params,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
return params
|
||||
else:
|
||||
return params
|
||||
|
||||
|
||||
async def add_memories(
|
||||
self,
|
||||
observation: Optional[str] = None,
|
||||
|
|
@ -179,11 +192,14 @@ class BaseMemory:
|
|||
namespace: Optional[str] = None,
|
||||
custom_fields: Optional[str] = None,
|
||||
embeddings: Optional[str] = None,
|
||||
|
||||
):
|
||||
return await self.vector_db.add_memories(
|
||||
observation=observation, loader_settings=loader_settings,
|
||||
params=params, namespace=namespace, metadata_schema_class = None, embeddings=embeddings
|
||||
observation=observation,
|
||||
loader_settings=loader_settings,
|
||||
params=params,
|
||||
namespace=namespace,
|
||||
metadata_schema_class=None,
|
||||
embeddings=embeddings,
|
||||
)
|
||||
# Add other db_type conditions if necessary
|
||||
|
||||
|
|
@ -200,17 +216,15 @@ class BaseMemory:
|
|||
logging.info(observation)
|
||||
|
||||
return await self.vector_db.fetch_memories(
|
||||
observation=observation, search_type= search_type, params=params,
|
||||
observation=observation,
|
||||
search_type=search_type,
|
||||
params=params,
|
||||
namespace=namespace,
|
||||
n_of_observations=n_of_observations
|
||||
n_of_observations=n_of_observations,
|
||||
)
|
||||
|
||||
async def delete_memories(self, namespace:str, params: Optional[str] = None):
|
||||
return await self.vector_db.delete_memories(namespace,params)
|
||||
|
||||
|
||||
async def count_memories(self, namespace:str, params: Optional[str] = None):
|
||||
return await self.vector_db.count_memories(namespace,params)
|
||||
|
||||
|
||||
async def delete_memories(self, namespace: str, params: Optional[str] = None):
|
||||
return await self.vector_db.delete_memories(namespace, params)
|
||||
|
||||
async def count_memories(self, namespace: str, params: Optional[str] = None):
|
||||
return await self.vector_db.count_memories(namespace, params)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from enum import Enum
|
||||
|
||||
|
||||
class ChunkStrategy(Enum):
|
||||
EXACT = 'exact'
|
||||
PARAGRAPH = 'paragraph'
|
||||
SENTENCE = 'sentence'
|
||||
VANILLA = 'vanilla'
|
||||
SUMMARY = 'summary'
|
||||
"""Chunking strategies for the vector database."""
|
||||
EXACT = "exact"
|
||||
PARAGRAPH = "paragraph"
|
||||
SENTENCE = "sentence"
|
||||
VANILLA = "vanilla"
|
||||
SUMMARY = "summary"
|
||||
|
|
|
|||
|
|
@ -1,13 +1,17 @@
|
|||
from cognitive_architecture.database.vectordb.chunkers.chunk_strategy import ChunkStrategy
|
||||
import re
|
||||
def chunk_data(chunk_strategy=None, source_data=None, chunk_size=None, chunk_overlap=None):
|
||||
"""Module for chunking text data based on various strategies."""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from cognitive_architecture.database.vectordb.chunkers.chunk_strategy import ChunkStrategy
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
|
||||
def chunk_data(chunk_strategy=None, source_data=None, chunk_size=None, chunk_overlap=None):
|
||||
"""Chunk the given source data into smaller parts based on the specified strategy."""
|
||||
if chunk_strategy == ChunkStrategy.VANILLA:
|
||||
chunked_data = vanilla_chunker(source_data, chunk_size, chunk_overlap)
|
||||
|
||||
elif chunk_strategy == ChunkStrategy.PARAGRAPH:
|
||||
chunked_data = chunk_data_by_paragraph(source_data,chunk_size, chunk_overlap)
|
||||
|
||||
chunked_data = chunk_data_by_paragraph(source_data, chunk_size, chunk_overlap)
|
||||
elif chunk_strategy == ChunkStrategy.SENTENCE:
|
||||
chunked_data = chunk_by_sentence(source_data, chunk_size, chunk_overlap)
|
||||
elif chunk_strategy == ChunkStrategy.EXACT:
|
||||
|
|
@ -21,68 +25,41 @@ def chunk_data(chunk_strategy=None, source_data=None, chunk_size=None, chunk_ove
|
|||
|
||||
|
||||
def vanilla_chunker(source_data, chunk_size=100, chunk_overlap=20):
|
||||
# adapt this for different chunking strategies
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
# Set a really small chunk size, just to show.
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
length_function=len
|
||||
)
|
||||
# try:
|
||||
# pages = text_splitter.create_documents([source_data])
|
||||
# except:
|
||||
# try:
|
||||
"""Chunk the given source data into smaller parts using a vanilla strategy."""
|
||||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size
|
||||
, chunk_overlap=chunk_overlap
|
||||
, length_function=len)
|
||||
pages = text_splitter.create_documents([source_data])
|
||||
# except:
|
||||
# pages = text_splitter.create_documents(source_data.content)
|
||||
# pages = source_data.load_and_split()
|
||||
return pages
|
||||
|
||||
|
||||
def summary_chunker(source_data, chunk_size=400, chunk_overlap=20):
|
||||
"""
|
||||
Chunk the given source data into smaller parts, returning the first five and last five chunks.
|
||||
|
||||
Parameters:
|
||||
- source_data (str): The source data to be chunked.
|
||||
- chunk_size (int): The size of each chunk.
|
||||
- chunk_overlap (int): The overlap between consecutive chunks.
|
||||
|
||||
Returns:
|
||||
- List: A list containing the first five and last five chunks of the chunked source data.
|
||||
"""
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
length_function=len
|
||||
)
|
||||
|
||||
"""Chunk the given source data into smaller parts, focusing on summarizing content."""
|
||||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size
|
||||
, chunk_overlap=chunk_overlap
|
||||
, length_function=len)
|
||||
try:
|
||||
pages = text_splitter.create_documents([source_data])
|
||||
except:
|
||||
except Exception as e:
|
||||
pages = text_splitter.create_documents(source_data.content)
|
||||
logging.error(f"An error occurred: %s {str(e)}")
|
||||
|
||||
# Return the first 5 and last 5 chunks
|
||||
if len(pages) > 10:
|
||||
return pages[:5] + pages[-5:]
|
||||
else:
|
||||
return pages # Return all chunks if there are 10 or fewer
|
||||
return pages
|
||||
|
||||
|
||||
def chunk_data_exact(data_chunks, chunk_size, chunk_overlap):
|
||||
"""Chunk the data into exact sizes as specified, without considering content."""
|
||||
data = "".join(data_chunks)
|
||||
chunks = []
|
||||
for i in range(0, len(data), chunk_size - chunk_overlap):
|
||||
chunks.append(data[i:i + chunk_size])
|
||||
chunks = [data[i:i + chunk_size] for i in range(0, len(data), chunk_size - chunk_overlap)]
|
||||
return chunks
|
||||
|
||||
|
||||
def chunk_by_sentence(data_chunks, chunk_size, overlap):
|
||||
# Split by periods, question marks, exclamation marks, and ellipses
|
||||
"""Chunk the data by sentences, ensuring each chunk does not exceed the specified size."""
|
||||
data = "".join(data_chunks)
|
||||
|
||||
# The regular expression is used to find series of charaters that end with one the following chaacters (. ! ? ...)
|
||||
sentence_endings = r'(?<=[.!?…]) +'
|
||||
sentence_endings = r"(?<=[.!?…]) +"
|
||||
sentences = re.split(sentence_endings, data)
|
||||
|
||||
sentence_chunks = []
|
||||
|
|
@ -96,6 +73,7 @@ def chunk_by_sentence(data_chunks, chunk_size, overlap):
|
|||
|
||||
|
||||
def chunk_data_by_paragraph(data_chunks, chunk_size, overlap, bound=0.75):
|
||||
"""Chunk the data by paragraphs, with consideration for chunk size and overlap."""
|
||||
data = "".join(data_chunks)
|
||||
total_length = len(data)
|
||||
chunks = []
|
||||
|
|
@ -103,20 +81,13 @@ def chunk_data_by_paragraph(data_chunks, chunk_size, overlap, bound=0.75):
|
|||
start_idx = 0
|
||||
|
||||
while start_idx < total_length:
|
||||
# Set the end index to the minimum of start_idx + default_chunk_size or total_length
|
||||
end_idx = min(start_idx + chunk_size, total_length)
|
||||
next_paragraph_index = data.find("\n\n", start_idx + check_bound, end_idx)
|
||||
|
||||
# Find the next paragraph index within the current chunk and bound
|
||||
next_paragraph_index = data.find('\n\n', start_idx + check_bound, end_idx)
|
||||
|
||||
# If a next paragraph index is found within the current chunk
|
||||
if next_paragraph_index != -1:
|
||||
# Update end_idx to include the paragraph delimiter
|
||||
end_idx = next_paragraph_index + 2
|
||||
|
||||
chunks.append(data[start_idx:end_idx + overlap])
|
||||
|
||||
# Update start_idx to be the current end_idx
|
||||
start_idx = end_idx
|
||||
|
||||
return chunks
|
||||
return chunks
|
||||
|
|
|
|||
|
|
@ -7,17 +7,20 @@ from .response import Response
|
|||
|
||||
|
||||
class CogneeManager:
|
||||
def __init__(self, embeddings: Embeddings = None,
|
||||
vector_db: VectorDB = None,
|
||||
vector_db_key: str = None,
|
||||
embedding_api_key: str = None,
|
||||
webhook_url: str = None,
|
||||
lines_per_batch: int = 1000,
|
||||
webhook_key: str = None,
|
||||
document_id: str = None,
|
||||
chunk_validation_url: str = None,
|
||||
internal_api_key: str = "test123",
|
||||
base_url="http://localhost:8000"):
|
||||
def __init__(
|
||||
self,
|
||||
embeddings: Embeddings = None,
|
||||
vector_db: VectorDB = None,
|
||||
vector_db_key: str = None,
|
||||
embedding_api_key: str = None,
|
||||
webhook_url: str = None,
|
||||
lines_per_batch: int = 1000,
|
||||
webhook_key: str = None,
|
||||
document_id: str = None,
|
||||
chunk_validation_url: str = None,
|
||||
internal_api_key: str = "test123",
|
||||
base_url="http://localhost:8000",
|
||||
):
|
||||
self.embeddings = embeddings if embeddings else Embeddings()
|
||||
self.vector_db = vector_db if vector_db else VectorDB()
|
||||
self.webhook_url = webhook_url
|
||||
|
|
@ -32,12 +35,12 @@ class CogneeManager:
|
|||
|
||||
def serialize(self):
|
||||
data = {
|
||||
'EmbeddingsMetadata': json.dumps(self.embeddings.serialize()),
|
||||
'VectorDBMetadata': json.dumps(self.vector_db.serialize()),
|
||||
'WebhookURL': self.webhook_url,
|
||||
'LinesPerBatch': self.lines_per_batch,
|
||||
'DocumentID': self.document_id,
|
||||
'ChunkValidationURL': self.chunk_validation_url,
|
||||
"EmbeddingsMetadata": json.dumps(self.embeddings.serialize()),
|
||||
"VectorDBMetadata": json.dumps(self.vector_db.serialize()),
|
||||
"WebhookURL": self.webhook_url,
|
||||
"LinesPerBatch": self.lines_per_batch,
|
||||
"DocumentID": self.document_id,
|
||||
"ChunkValidationURL": self.chunk_validation_url,
|
||||
}
|
||||
return {k: v for k, v in data.items() if v is not None}
|
||||
|
||||
|
|
@ -49,11 +52,22 @@ class CogneeManager:
|
|||
|
||||
data = self.serialize()
|
||||
headers = self.generate_headers()
|
||||
multipart_form_data = [('file', (os.path.basename(filepath), open(filepath, 'rb'), 'application/octet-stream'))
|
||||
for filepath in file_paths]
|
||||
multipart_form_data = [
|
||||
(
|
||||
"file",
|
||||
(
|
||||
os.path.basename(filepath),
|
||||
open(filepath, "rb"),
|
||||
"application/octet-stream",
|
||||
),
|
||||
)
|
||||
for filepath in file_paths
|
||||
]
|
||||
|
||||
print(f"embedding {len(file_paths)} documents at {url}")
|
||||
response = requests.post(url, files=multipart_form_data, headers=headers, stream=True, data=data)
|
||||
response = requests.post(
|
||||
url, files=multipart_form_data, headers=headers, stream=True, data=data
|
||||
)
|
||||
|
||||
if response.status_code == 500:
|
||||
print(response.text)
|
||||
|
|
@ -75,9 +89,7 @@ class CogneeManager:
|
|||
"Authorization": self.internal_api_key,
|
||||
}
|
||||
|
||||
data = {
|
||||
'JobIDs': job_ids
|
||||
}
|
||||
data = {"JobIDs": job_ids}
|
||||
|
||||
print(f"retrieving job statuses for {len(job_ids)} jobs at {url}")
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
|
|
@ -101,9 +113,7 @@ class CogneeManager:
|
|||
data = self.serialize()
|
||||
headers = self.generate_headers()
|
||||
|
||||
files = {
|
||||
'SourceData': open(filepath, 'rb')
|
||||
}
|
||||
files = {"SourceData": open(filepath, "rb")}
|
||||
|
||||
print(f"embedding document at file path {filepath} at {url}")
|
||||
response = requests.post(url, headers=headers, data=data, files=files)
|
||||
|
|
@ -146,6 +156,6 @@ class CogneeManager:
|
|||
"Authorization": self.internal_api_key,
|
||||
"X-EmbeddingAPI-Key": self.embeddings_api_key,
|
||||
"X-VectorDB-Key": self.vector_db_key,
|
||||
"X-Webhook-Key": self.webhook_key
|
||||
"X-Webhook-Key": self.webhook_key,
|
||||
}
|
||||
return {k: v for k, v in headers.items() if v is not None}
|
||||
return {k: v for k, v in headers.items() if v is not None}
|
||||
|
|
|
|||
|
|
@ -3,12 +3,15 @@ from ..chunkers.chunk_strategy import ChunkStrategy
|
|||
|
||||
|
||||
class Embeddings:
|
||||
def __init__(self, embeddings_type: EmbeddingsType = EmbeddingsType.OPEN_AI,
|
||||
chunk_size: int = 256,
|
||||
chunk_overlap: int = 128,
|
||||
chunk_strategy: ChunkStrategy = ChunkStrategy.EXACT,
|
||||
docker_image: str = None,
|
||||
hugging_face_model_name: str = None):
|
||||
def __init__(
|
||||
self,
|
||||
embeddings_type: EmbeddingsType = EmbeddingsType.OPEN_AI,
|
||||
chunk_size: int = 256,
|
||||
chunk_overlap: int = 128,
|
||||
chunk_strategy: ChunkStrategy = ChunkStrategy.EXACT,
|
||||
docker_image: str = None,
|
||||
hugging_face_model_name: str = None,
|
||||
):
|
||||
self.embeddings_type = embeddings_type
|
||||
self.chunk_size = chunk_size
|
||||
self.chunk_overlap = chunk_overlap
|
||||
|
|
@ -18,12 +21,14 @@ class Embeddings:
|
|||
|
||||
def serialize(self):
|
||||
data = {
|
||||
'embeddings_type': self.embeddings_type.name if self.embeddings_type else None,
|
||||
'chunk_size': self.chunk_size,
|
||||
'chunk_overlap': self.chunk_overlap,
|
||||
'chunk_strategy': self.chunk_strategy.name if self.chunk_strategy else None,
|
||||
'docker_image': self.docker_image,
|
||||
'hugging_face_model_name': self.hugging_face_model_name
|
||||
"embeddings_type": self.embeddings_type.name
|
||||
if self.embeddings_type
|
||||
else None,
|
||||
"chunk_size": self.chunk_size,
|
||||
"chunk_overlap": self.chunk_overlap,
|
||||
"chunk_strategy": self.chunk_strategy.name if self.chunk_strategy else None,
|
||||
"docker_image": self.docker_image,
|
||||
"hugging_face_model_name": self.hugging_face_model_name,
|
||||
}
|
||||
|
||||
return {k: v for k, v in data.items() if v is not None}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
from enum import Enum
|
||||
|
||||
|
||||
class EmbeddingsType(Enum):
|
||||
OPEN_AI = 'open_ai'
|
||||
COHERE = 'cohere'
|
||||
SELF_HOSTED = 'self_hosted'
|
||||
HUGGING_FACE = 'hugging_face'
|
||||
IMAGE = 'image'
|
||||
OPEN_AI = "open_ai"
|
||||
COHERE = "cohere"
|
||||
SELF_HOSTED = "self_hosted"
|
||||
HUGGING_FACE = "hugging_face"
|
||||
IMAGE = "image"
|
||||
|
|
|
|||
|
|
@ -15,4 +15,4 @@ class Job:
|
|||
return "Job(" + ", ".join(attributes) + ")"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
return self.__str__()
|
||||
|
|
|
|||
|
|
@ -4,7 +4,10 @@ import os
|
|||
import sys
|
||||
|
||||
from cognitive_architecture.database.vectordb.chunkers.chunkers import chunk_data
|
||||
from cognitive_architecture.shared.language_processing import translate_text, detect_language
|
||||
from cognitive_architecture.shared.language_processing import (
|
||||
translate_text,
|
||||
detect_language,
|
||||
)
|
||||
|
||||
from langchain.document_loaders import UnstructuredURLLoader
|
||||
from langchain.document_loaders import DirectoryLoader
|
||||
|
|
@ -15,28 +18,36 @@ import requests
|
|||
|
||||
|
||||
async def fetch_pdf_content(file_url):
|
||||
response = requests.get(file_url)
|
||||
response = requests.get(file_url)
|
||||
pdf_stream = BytesIO(response.content)
|
||||
with fitz.open(stream=pdf_stream, filetype='pdf') as doc:
|
||||
with fitz.open(stream=pdf_stream, filetype="pdf") as doc:
|
||||
return "".join(page.get_text() for page in doc)
|
||||
|
||||
|
||||
async def fetch_text_content(file_url):
|
||||
loader = UnstructuredURLLoader(urls=file_url)
|
||||
return loader.load()
|
||||
|
||||
async def process_content(content, metadata, loader_strategy, chunk_size, chunk_overlap):
|
||||
pages = chunk_data(chunk_strategy=loader_strategy, source_data=content, chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap)
|
||||
|
||||
async def process_content(
|
||||
content, metadata, loader_strategy, chunk_size, chunk_overlap
|
||||
):
|
||||
pages = chunk_data(
|
||||
chunk_strategy=loader_strategy,
|
||||
source_data=content,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
|
||||
if metadata is None:
|
||||
metadata = {"metadata": "None"}
|
||||
|
||||
chunk_count= 0
|
||||
chunk_count = 0
|
||||
|
||||
for chunk in pages:
|
||||
chunk_count+=1
|
||||
chunk_count += 1
|
||||
chunk.metadata = metadata
|
||||
chunk.metadata["chunk_count"]=chunk_count
|
||||
chunk.metadata["chunk_count"] = chunk_count
|
||||
if detect_language(pages) != "en":
|
||||
logging.info("Translating Page")
|
||||
for page in pages:
|
||||
|
|
@ -45,6 +56,7 @@ async def process_content(content, metadata, loader_strategy, chunk_size, chunk
|
|||
|
||||
return pages
|
||||
|
||||
|
||||
async def _document_loader(observation: str, loader_settings: dict):
|
||||
document_format = loader_settings.get("format", "text")
|
||||
loader_strategy = loader_settings.get("strategy", "VANILLA")
|
||||
|
|
@ -65,7 +77,13 @@ async def _document_loader(observation: str, loader_settings: dict):
|
|||
else:
|
||||
raise ValueError(f"Unsupported document format: {document_format}")
|
||||
|
||||
pages = await process_content(content, metadata=None, loader_strategy=loader_strategy, chunk_size= chunk_size, chunk_overlap= chunk_overlap)
|
||||
pages = await process_content(
|
||||
content,
|
||||
metadata=None,
|
||||
loader_strategy=loader_strategy,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
chunked_doc.append(pages)
|
||||
|
||||
elif loader_settings.get("source") == "DEVICE":
|
||||
|
|
@ -76,17 +94,28 @@ async def _document_loader(observation: str, loader_settings: dict):
|
|||
documents = loader.load()
|
||||
for document in documents:
|
||||
# print ("Document: ", document.page_content)
|
||||
pages = await process_content(content= str(document.page_content), metadata=document.metadata, loader_strategy= loader_strategy, chunk_size = chunk_size, chunk_overlap = chunk_overlap)
|
||||
pages = await process_content(
|
||||
content=str(document.page_content),
|
||||
metadata=document.metadata,
|
||||
loader_strategy=loader_strategy,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
chunked_doc.append(pages)
|
||||
else:
|
||||
from langchain.document_loaders import PyPDFLoader
|
||||
|
||||
loader = PyPDFLoader(loader_settings.get("single_document_path"))
|
||||
documents= loader.load()
|
||||
documents = loader.load()
|
||||
|
||||
for document in documents:
|
||||
pages = await process_content(content=str(document.page_content), metadata=document.metadata,
|
||||
loader_strategy=loader_strategy, chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap)
|
||||
pages = await process_content(
|
||||
content=str(document.page_content),
|
||||
metadata=document.metadata,
|
||||
loader_strategy=loader_strategy,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
chunked_doc.append(pages)
|
||||
else:
|
||||
raise ValueError(f"Unsupported source type: {loader_settings.get('source')}")
|
||||
|
|
@ -94,7 +123,6 @@ async def _document_loader(observation: str, loader_settings: dict):
|
|||
return chunked_doc
|
||||
|
||||
|
||||
|
||||
# async def _document_loader( observation: str, loader_settings: dict):
|
||||
#
|
||||
# document_format = loader_settings.get("format", "text")
|
||||
|
|
@ -196,11 +224,3 @@ async def _document_loader(observation: str, loader_settings: dict):
|
|||
# else:
|
||||
# raise ValueError(f"Error: ")
|
||||
# return chunked_doc
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,9 +2,19 @@ from .job import Job
|
|||
|
||||
|
||||
class Response:
|
||||
def __init__(self, error=None, message=None, successful_uploads=None, failed_uploads=None,
|
||||
empty_files_count=None, duplicate_files_count=None, job_id=None,
|
||||
jobs=None, job_status=None, status_code=None):
|
||||
def __init__(
|
||||
self,
|
||||
error=None,
|
||||
message=None,
|
||||
successful_uploads=None,
|
||||
failed_uploads=None,
|
||||
empty_files_count=None,
|
||||
duplicate_files_count=None,
|
||||
job_id=None,
|
||||
jobs=None,
|
||||
job_status=None,
|
||||
status_code=None,
|
||||
):
|
||||
self.error = error
|
||||
self.message = message
|
||||
self.successful_uploads = successful_uploads
|
||||
|
|
@ -18,33 +28,37 @@ class Response:
|
|||
|
||||
@classmethod
|
||||
def from_json(cls, json_dict, status_code):
|
||||
successful_uploads = cls._convert_successful_uploads_to_jobs(json_dict.get('successful_uploads', None))
|
||||
jobs = cls._convert_to_jobs(json_dict.get('Jobs', None))
|
||||
successful_uploads = cls._convert_successful_uploads_to_jobs(
|
||||
json_dict.get("successful_uploads", None)
|
||||
)
|
||||
jobs = cls._convert_to_jobs(json_dict.get("Jobs", None))
|
||||
|
||||
return cls(
|
||||
error=json_dict.get('error'),
|
||||
message=json_dict.get('message'),
|
||||
error=json_dict.get("error"),
|
||||
message=json_dict.get("message"),
|
||||
successful_uploads=successful_uploads,
|
||||
failed_uploads=json_dict.get('failed_uploads'),
|
||||
empty_files_count=json_dict.get('empty_files_count'),
|
||||
duplicate_files_count=json_dict.get('duplicate_files_count'),
|
||||
job_id=json_dict.get('JobID'),
|
||||
failed_uploads=json_dict.get("failed_uploads"),
|
||||
empty_files_count=json_dict.get("empty_files_count"),
|
||||
duplicate_files_count=json_dict.get("duplicate_files_count"),
|
||||
job_id=json_dict.get("JobID"),
|
||||
jobs=jobs,
|
||||
job_status=json_dict.get('JobStatus'),
|
||||
status_code=status_code
|
||||
job_status=json_dict.get("JobStatus"),
|
||||
status_code=status_code,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _convert_successful_uploads_to_jobs(cls, successful_uploads):
|
||||
if not successful_uploads:
|
||||
return None
|
||||
return [Job(filename=key, job_id=val) for key, val in successful_uploads.items()]
|
||||
return [
|
||||
Job(filename=key, job_id=val) for key, val in successful_uploads.items()
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _convert_to_jobs(cls, jobs):
|
||||
if not jobs:
|
||||
return None
|
||||
return [Job(job_id=job['JobID'], job_status=job['JobStatus']) for job in jobs]
|
||||
return [Job(job_id=job["JobID"], job_status=job["JobStatus"]) for job in jobs]
|
||||
|
||||
def __str__(self):
|
||||
attributes = []
|
||||
|
|
@ -69,4 +83,4 @@ class Response:
|
|||
if self.status_code is not None:
|
||||
attributes.append(f"status_code: {self.status_code}")
|
||||
|
||||
return "Response(" + ", ".join(attributes) + ")"
|
||||
return "Response(" + ", ".join(attributes) + ")"
|
||||
|
|
|
|||
|
|
@ -1,14 +1,15 @@
|
|||
from enum import Enum
|
||||
|
||||
|
||||
class VectorDBType(Enum):
|
||||
PINECONE = 'pinecone'
|
||||
WEAVIATE = 'weaviate'
|
||||
MILVUS = 'milvus'
|
||||
QDRANT = 'qdrant'
|
||||
DEEPLAKE = 'deeplake'
|
||||
VESPA = 'vespa'
|
||||
PGVECTOR = 'pgvector'
|
||||
REDIS = 'redis'
|
||||
LANCEDB = 'lancedb'
|
||||
MONGODB = 'mongodb'
|
||||
FAISS = 'faiss'
|
||||
PINECONE = "pinecone"
|
||||
WEAVIATE = "weaviate"
|
||||
MILVUS = "milvus"
|
||||
QDRANT = "qdrant"
|
||||
DEEPLAKE = "deeplake"
|
||||
VESPA = "vespa"
|
||||
PGVECTOR = "pgvector"
|
||||
REDIS = "redis"
|
||||
LANCEDB = "lancedb"
|
||||
MONGODB = "mongodb"
|
||||
FAISS = "faiss"
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
|
||||
# Make sure to install the following packages: dlt, langchain, duckdb, python-dotenv, openai, weaviate-client
|
||||
import logging
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from marshmallow import Schema, fields
|
||||
from cognitive_architecture.database.vectordb.loaders.loaders import _document_loader
|
||||
|
||||
# Add the parent directory to sys.path
|
||||
|
||||
|
||||
|
|
@ -12,6 +12,7 @@ logging.basicConfig(level=logging.INFO)
|
|||
from langchain.retrievers import WeaviateHybridSearchRetriever, ParentDocumentRetriever
|
||||
from weaviate.gql.get import HybridFusion
|
||||
import tracemalloc
|
||||
|
||||
tracemalloc.start()
|
||||
import os
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
|
|
@ -28,6 +29,8 @@ config.load()
|
|||
LTM_MEMORY_ID_DEFAULT = "00000"
|
||||
ST_MEMORY_ID_DEFAULT = "0000"
|
||||
BUFFER_ID_DEFAULT = "0000"
|
||||
|
||||
|
||||
class VectorDB:
|
||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
||||
|
||||
|
|
@ -37,7 +40,7 @@ class VectorDB:
|
|||
index_name: str,
|
||||
memory_id: str,
|
||||
namespace: str = None,
|
||||
embeddings = None,
|
||||
embeddings=None,
|
||||
):
|
||||
self.user_id = user_id
|
||||
self.index_name = index_name
|
||||
|
|
@ -45,6 +48,7 @@ class VectorDB:
|
|||
self.memory_id = memory_id
|
||||
self.embeddings = embeddings
|
||||
|
||||
|
||||
class PineconeVectorDB(VectorDB):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
|
@ -54,13 +58,21 @@ class PineconeVectorDB(VectorDB):
|
|||
# Pinecone initialization logic
|
||||
pass
|
||||
|
||||
|
||||
import langchain.embeddings
|
||||
|
||||
|
||||
class WeaviateVectorDB(VectorDB):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.init_weaviate(embeddings= self.embeddings, namespace = self.namespace)
|
||||
self.init_weaviate(embeddings=self.embeddings, namespace=self.namespace)
|
||||
|
||||
def init_weaviate(self, embeddings=OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY", "")), namespace=None,retriever_type="",):
|
||||
def init_weaviate(
|
||||
self,
|
||||
embeddings=OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY", "")),
|
||||
namespace=None,
|
||||
retriever_type="",
|
||||
):
|
||||
# Weaviate initialization logic
|
||||
auth_config = weaviate.auth.AuthApiKey(
|
||||
api_key=os.environ.get("WEAVIATE_API_KEY")
|
||||
|
|
@ -91,15 +103,16 @@ class WeaviateVectorDB(VectorDB):
|
|||
create_schema_if_missing=True,
|
||||
)
|
||||
return retriever
|
||||
else :
|
||||
else:
|
||||
return client
|
||||
# child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
|
||||
# store = InMemoryStore()
|
||||
# retriever = ParentDocumentRetriever(
|
||||
# vectorstore=vectorstore,
|
||||
# docstore=store,
|
||||
# child_splitter=child_splitter,
|
||||
# )
|
||||
# child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
|
||||
# store = InMemoryStore()
|
||||
# retriever = ParentDocumentRetriever(
|
||||
# vectorstore=vectorstore,
|
||||
# docstore=store,
|
||||
# child_splitter=child_splitter,
|
||||
# )
|
||||
|
||||
from marshmallow import Schema, fields
|
||||
|
||||
def create_document_structure(observation, params, metadata_schema_class=None):
|
||||
|
|
@ -111,10 +124,7 @@ class WeaviateVectorDB(VectorDB):
|
|||
:param metadata_schema_class: Custom metadata schema class (optional).
|
||||
:return: A list containing the validated document data.
|
||||
"""
|
||||
document_data = {
|
||||
"metadata": params,
|
||||
"page_content": observation
|
||||
}
|
||||
document_data = {"metadata": params, "page_content": observation}
|
||||
|
||||
def get_document_schema():
|
||||
class DynamicDocumentSchema(Schema):
|
||||
|
|
@ -128,30 +138,42 @@ class WeaviateVectorDB(VectorDB):
|
|||
loaded_document = CurrentDocumentSchema().load(document_data)
|
||||
return [loaded_document]
|
||||
|
||||
def _stuct(self, observation, params, metadata_schema_class =None):
|
||||
def _stuct(self, observation, params, metadata_schema_class=None):
|
||||
"""Utility function to create the document structure with optional custom fields."""
|
||||
# Construct document data
|
||||
document_data = {
|
||||
"metadata": params,
|
||||
"page_content": observation
|
||||
}
|
||||
document_data = {"metadata": params, "page_content": observation}
|
||||
|
||||
def get_document_schema():
|
||||
class DynamicDocumentSchema(Schema):
|
||||
metadata = fields.Nested(metadata_schema_class, required=True)
|
||||
page_content = fields.Str(required=True)
|
||||
|
||||
return DynamicDocumentSchema
|
||||
|
||||
# Validate and deserialize # Default to "1.0" if not provided
|
||||
CurrentDocumentSchema = get_document_schema()
|
||||
loaded_document = CurrentDocumentSchema().load(document_data)
|
||||
return [loaded_document]
|
||||
async def add_memories(self, observation, loader_settings=None, params=None, namespace=None, metadata_schema_class=None, embeddings = 'hybrid'):
|
||||
|
||||
async def add_memories(
|
||||
self,
|
||||
observation,
|
||||
loader_settings=None,
|
||||
params=None,
|
||||
namespace=None,
|
||||
metadata_schema_class=None,
|
||||
embeddings="hybrid",
|
||||
):
|
||||
# Update Weaviate memories here
|
||||
if namespace is None:
|
||||
namespace = self.namespace
|
||||
params['user_id'] = self.user_id
|
||||
params["user_id"] = self.user_id
|
||||
logging.info("User id is %s", self.user_id)
|
||||
retriever = self.init_weaviate(embeddings=OpenAIEmbeddings(),namespace = namespace, retriever_type="single_document_context")
|
||||
retriever = self.init_weaviate(
|
||||
embeddings=OpenAIEmbeddings(),
|
||||
namespace=namespace,
|
||||
retriever_type="single_document_context",
|
||||
)
|
||||
if loader_settings:
|
||||
# Assuming _document_loader returns a list of documents
|
||||
documents = await _document_loader(observation, loader_settings)
|
||||
|
|
@ -160,27 +182,49 @@ class WeaviateVectorDB(VectorDB):
|
|||
for doc_list in documents:
|
||||
for doc in doc_list:
|
||||
chunk_count += 1
|
||||
params['chunk_count'] = doc.metadata.get("chunk_count", "None")
|
||||
logging.info("Loading document with provided loader settings %s", str(doc))
|
||||
params['source'] = doc.metadata.get("source", "None")
|
||||
params["chunk_count"] = doc.metadata.get("chunk_count", "None")
|
||||
logging.info(
|
||||
"Loading document with provided loader settings %s", str(doc)
|
||||
)
|
||||
params["source"] = doc.metadata.get("source", "None")
|
||||
logging.info("Params are %s", str(params))
|
||||
retriever.add_documents([
|
||||
Document(metadata=params, page_content=doc.page_content)])
|
||||
retriever.add_documents(
|
||||
[Document(metadata=params, page_content=doc.page_content)]
|
||||
)
|
||||
else:
|
||||
chunk_count = 0
|
||||
from cognitive_architecture.database.vectordb.chunkers.chunkers import chunk_data
|
||||
documents = [chunk_data(chunk_strategy="VANILLA", source_data=observation, chunk_size=300,
|
||||
chunk_overlap=20)]
|
||||
from cognitive_architecture.database.vectordb.chunkers.chunkers import (
|
||||
chunk_data,
|
||||
)
|
||||
|
||||
documents = [
|
||||
chunk_data(
|
||||
chunk_strategy="VANILLA",
|
||||
source_data=observation,
|
||||
chunk_size=300,
|
||||
chunk_overlap=20,
|
||||
)
|
||||
]
|
||||
for doc in documents[0]:
|
||||
chunk_count += 1
|
||||
params['chunk_order'] = chunk_count
|
||||
params['source'] = "User loaded"
|
||||
logging.info("Loading document with default loader settings %s", str(doc))
|
||||
params["chunk_order"] = chunk_count
|
||||
params["source"] = "User loaded"
|
||||
logging.info(
|
||||
"Loading document with default loader settings %s", str(doc)
|
||||
)
|
||||
logging.info("Params are %s", str(params))
|
||||
retriever.add_documents([
|
||||
Document(metadata=params, page_content=doc.page_content)])
|
||||
retriever.add_documents(
|
||||
[Document(metadata=params, page_content=doc.page_content)]
|
||||
)
|
||||
|
||||
async def fetch_memories(self, observation: str, namespace: str = None, search_type: str = 'hybrid',params=None, **kwargs):
|
||||
async def fetch_memories(
|
||||
self,
|
||||
observation: str,
|
||||
namespace: str = None,
|
||||
search_type: str = "hybrid",
|
||||
params=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Fetch documents from weaviate.
|
||||
|
||||
|
|
@ -196,12 +240,9 @@ class WeaviateVectorDB(VectorDB):
|
|||
Example:
|
||||
fetch_memories(query="some query", search_type='text', additional_param='value')
|
||||
"""
|
||||
client = self.init_weaviate(namespace =self.namespace)
|
||||
client = self.init_weaviate(namespace=self.namespace)
|
||||
if search_type is None:
|
||||
search_type = 'hybrid'
|
||||
|
||||
|
||||
|
||||
search_type = "hybrid"
|
||||
|
||||
if not namespace:
|
||||
namespace = self.namespace
|
||||
|
|
@ -222,37 +263,41 @@ class WeaviateVectorDB(VectorDB):
|
|||
for prop in class_obj["properties"]
|
||||
]
|
||||
|
||||
base_query = client.query.get(
|
||||
namespace, list(list_objects_of_class(namespace, client.schema.get()))
|
||||
).with_additional(
|
||||
["id", "creationTimeUnix", "lastUpdateTimeUnix", "score", 'distance']
|
||||
).with_where(params_user_id).with_limit(10)
|
||||
base_query = (
|
||||
client.query.get(
|
||||
namespace, list(list_objects_of_class(namespace, client.schema.get()))
|
||||
)
|
||||
.with_additional(
|
||||
["id", "creationTimeUnix", "lastUpdateTimeUnix", "score", "distance"]
|
||||
)
|
||||
.with_where(params_user_id)
|
||||
.with_limit(10)
|
||||
)
|
||||
|
||||
n_of_observations = kwargs.get('n_of_observations', 2)
|
||||
n_of_observations = kwargs.get("n_of_observations", 2)
|
||||
|
||||
# try:
|
||||
if search_type == 'text':
|
||||
if search_type == "text":
|
||||
query_output = (
|
||||
base_query
|
||||
.with_near_text({"concepts": [observation]})
|
||||
base_query.with_near_text({"concepts": [observation]})
|
||||
.with_autocut(n_of_observations)
|
||||
.do()
|
||||
)
|
||||
elif search_type == 'hybrid':
|
||||
elif search_type == "hybrid":
|
||||
query_output = (
|
||||
base_query
|
||||
.with_hybrid(query=observation, fusion_type=HybridFusion.RELATIVE_SCORE)
|
||||
base_query.with_hybrid(
|
||||
query=observation, fusion_type=HybridFusion.RELATIVE_SCORE
|
||||
)
|
||||
.with_autocut(n_of_observations)
|
||||
.do()
|
||||
)
|
||||
elif search_type == 'bm25':
|
||||
elif search_type == "bm25":
|
||||
query_output = (
|
||||
base_query
|
||||
.with_bm25(query=observation)
|
||||
base_query.with_bm25(query=observation)
|
||||
.with_autocut(n_of_observations)
|
||||
.do()
|
||||
)
|
||||
elif search_type == 'summary':
|
||||
elif search_type == "summary":
|
||||
filter_object = {
|
||||
"operator": "And",
|
||||
"operands": [
|
||||
|
|
@ -266,20 +311,32 @@ class WeaviateVectorDB(VectorDB):
|
|||
"operator": "LessThan",
|
||||
"valueNumber": 30,
|
||||
},
|
||||
]
|
||||
],
|
||||
}
|
||||
base_query = client.query.get(
|
||||
namespace, list(list_objects_of_class(namespace, client.schema.get()))
|
||||
).with_additional(
|
||||
["id", "creationTimeUnix", "lastUpdateTimeUnix", "score", 'distance']
|
||||
).with_where(filter_object).with_limit(30)
|
||||
base_query = (
|
||||
client.query.get(
|
||||
namespace,
|
||||
list(list_objects_of_class(namespace, client.schema.get())),
|
||||
)
|
||||
.with_additional(
|
||||
[
|
||||
"id",
|
||||
"creationTimeUnix",
|
||||
"lastUpdateTimeUnix",
|
||||
"score",
|
||||
"distance",
|
||||
]
|
||||
)
|
||||
.with_where(filter_object)
|
||||
.with_limit(30)
|
||||
)
|
||||
query_output = (
|
||||
base_query
|
||||
# .with_hybrid(query=observation, fusion_type=HybridFusion.RELATIVE_SCORE)
|
||||
.do()
|
||||
)
|
||||
|
||||
elif search_type == 'summary_filter_by_object_name':
|
||||
elif search_type == "summary_filter_by_object_name":
|
||||
filter_object = {
|
||||
"operator": "And",
|
||||
"operands": [
|
||||
|
|
@ -293,17 +350,27 @@ class WeaviateVectorDB(VectorDB):
|
|||
"operator": "Equal",
|
||||
"valueText": params,
|
||||
},
|
||||
]
|
||||
],
|
||||
}
|
||||
base_query = client.query.get(
|
||||
namespace, list(list_objects_of_class(namespace, client.schema.get()))
|
||||
).with_additional(
|
||||
["id", "creationTimeUnix", "lastUpdateTimeUnix", "score", 'distance']
|
||||
).with_where(filter_object).with_limit(30).with_hybrid(query=observation, fusion_type=HybridFusion.RELATIVE_SCORE)
|
||||
query_output = (
|
||||
base_query
|
||||
.do()
|
||||
base_query = (
|
||||
client.query.get(
|
||||
namespace,
|
||||
list(list_objects_of_class(namespace, client.schema.get())),
|
||||
)
|
||||
.with_additional(
|
||||
[
|
||||
"id",
|
||||
"creationTimeUnix",
|
||||
"lastUpdateTimeUnix",
|
||||
"score",
|
||||
"distance",
|
||||
]
|
||||
)
|
||||
.with_where(filter_object)
|
||||
.with_limit(30)
|
||||
.with_hybrid(query=observation, fusion_type=HybridFusion.RELATIVE_SCORE)
|
||||
)
|
||||
query_output = base_query.do()
|
||||
# from weaviate.classes import Filter
|
||||
# client = weaviate.connect_to_wcs(
|
||||
# cluster_url=config.weaviate_url,
|
||||
|
|
@ -311,20 +378,18 @@ class WeaviateVectorDB(VectorDB):
|
|||
# )
|
||||
|
||||
return query_output
|
||||
elif search_type == 'generate':
|
||||
generate_prompt = kwargs.get('generate_prompt', "")
|
||||
elif search_type == "generate":
|
||||
generate_prompt = kwargs.get("generate_prompt", "")
|
||||
query_output = (
|
||||
base_query
|
||||
.with_generate(single_prompt=observation)
|
||||
base_query.with_generate(single_prompt=observation)
|
||||
.with_near_text({"concepts": [observation]})
|
||||
.with_autocut(n_of_observations)
|
||||
.do()
|
||||
)
|
||||
elif search_type == 'generate_grouped':
|
||||
generate_prompt = kwargs.get('generate_prompt', "")
|
||||
elif search_type == "generate_grouped":
|
||||
generate_prompt = kwargs.get("generate_prompt", "")
|
||||
query_output = (
|
||||
base_query
|
||||
.with_generate(grouped_task=observation)
|
||||
base_query.with_generate(grouped_task=observation)
|
||||
.with_near_text({"concepts": [observation]})
|
||||
.with_autocut(n_of_observations)
|
||||
.do()
|
||||
|
|
@ -338,12 +403,10 @@ class WeaviateVectorDB(VectorDB):
|
|||
|
||||
return query_output
|
||||
|
||||
|
||||
|
||||
async def delete_memories(self, namespace:str, params: dict = None):
|
||||
async def delete_memories(self, namespace: str, params: dict = None):
|
||||
if namespace is None:
|
||||
namespace = self.namespace
|
||||
client = self.init_weaviate(namespace = self.namespace)
|
||||
client = self.init_weaviate(namespace=self.namespace)
|
||||
if params:
|
||||
where_filter = {
|
||||
"path": ["id"],
|
||||
|
|
@ -366,7 +429,6 @@ class WeaviateVectorDB(VectorDB):
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
async def count_memories(self, namespace: str = None, params: dict = None) -> int:
|
||||
"""
|
||||
Count memories in a Weaviate database.
|
||||
|
|
@ -380,7 +442,7 @@ class WeaviateVectorDB(VectorDB):
|
|||
if namespace is None:
|
||||
namespace = self.namespace
|
||||
|
||||
client = self.init_weaviate(namespace =namespace)
|
||||
client = self.init_weaviate(namespace=namespace)
|
||||
|
||||
try:
|
||||
object_count = client.query.aggregate(namespace).with_meta_count().do()
|
||||
|
|
@ -391,7 +453,7 @@ class WeaviateVectorDB(VectorDB):
|
|||
return 0
|
||||
|
||||
def update_memories(self, observation, namespace: str, params: dict = None):
|
||||
client = self.init_weaviate(namespace = self.namespace)
|
||||
client = self.init_weaviate(namespace=self.namespace)
|
||||
|
||||
client.data_object.update(
|
||||
data_object={
|
||||
|
|
@ -416,12 +478,15 @@ class WeaviateVectorDB(VectorDB):
|
|||
)
|
||||
return
|
||||
|
||||
|
||||
import os
|
||||
import lancedb
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
|
||||
|
||||
class LanceDB(VectorDB):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
|
@ -434,21 +499,28 @@ class LanceDB(VectorDB):
|
|||
db = lancedb.connect(uri, api_key=os.getenv("LANCEDB_API_KEY"))
|
||||
return db
|
||||
|
||||
def create_table(self, name: str, schema: Optional[pa.Schema] = None, data: Optional[pd.DataFrame] = None):
|
||||
def create_table(
|
||||
self,
|
||||
name: str,
|
||||
schema: Optional[pa.Schema] = None,
|
||||
data: Optional[pd.DataFrame] = None,
|
||||
):
|
||||
# Create a table in LanceDB. If schema is not provided, it will be inferred from the data.
|
||||
if data is not None and schema is None:
|
||||
schema = pa.Schema.from_pandas(data)
|
||||
table = self.db.create_table(name, schema=schema)
|
||||
if data is not None:
|
||||
table.add(data.to_dict('records'))
|
||||
table.add(data.to_dict("records"))
|
||||
return table
|
||||
|
||||
def add_memories(self, table_name: str, data: pd.DataFrame):
|
||||
# Add data to an existing table in LanceDB
|
||||
table = self.db.open_table(table_name)
|
||||
table.add(data.to_dict('records'))
|
||||
table.add(data.to_dict("records"))
|
||||
|
||||
def fetch_memories(self, table_name: str, query_vector: List[float], top_k: int = 10):
|
||||
def fetch_memories(
|
||||
self, table_name: str, query_vector: List[float], top_k: int = 10
|
||||
):
|
||||
# Perform a vector search in the specified table
|
||||
table = self.db.open_table(table_name)
|
||||
results = table.search(query_vector).limit(top_k).to_pandas()
|
||||
|
|
|
|||
|
|
@ -16,15 +16,16 @@ sys.path.insert(0, parent_dir)
|
|||
|
||||
environment = os.getenv("AWS_ENV", "dev")
|
||||
|
||||
|
||||
def fetch_secret(secret_name, region_name, env_file_path):
|
||||
print("Initializing session")
|
||||
session = boto3.session.Session()
|
||||
print("Session initialized")
|
||||
client = session.client(service_name="secretsmanager", region_name = region_name)
|
||||
client = session.client(service_name="secretsmanager", region_name=region_name)
|
||||
print("Client initialized")
|
||||
|
||||
try:
|
||||
response = client.get_secret_value(SecretId = secret_name)
|
||||
response = client.get_secret_value(SecretId=secret_name)
|
||||
except Exception as e:
|
||||
print(f"Error retrieving secret: {e}")
|
||||
return None
|
||||
|
|
@ -46,6 +47,7 @@ def fetch_secret(secret_name, region_name, env_file_path):
|
|||
else:
|
||||
print(f"The .env file was not found at: {env_file_path}.")
|
||||
|
||||
|
||||
ENV_FILE_PATH = os.path.abspath("../.env")
|
||||
|
||||
if os.path.exists(ENV_FILE_PATH):
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from ..shared.data_models import Node, Edge, KnowledgeGraph, GraphQLQuery, Memor
|
|||
from ..config import Config
|
||||
import instructor
|
||||
from openai import OpenAI
|
||||
|
||||
config = Config()
|
||||
config.load()
|
||||
|
||||
|
|
@ -23,7 +24,7 @@ import logging
|
|||
# Function to read query prompts from files
|
||||
def read_query_prompt(filename):
|
||||
try:
|
||||
with open(filename, 'r') as file:
|
||||
with open(filename, "r") as file:
|
||||
return file.read()
|
||||
except FileNotFoundError:
|
||||
logging.info(f"Error: File not found. Attempted to read: {filename}")
|
||||
|
|
@ -37,7 +38,9 @@ def read_query_prompt(filename):
|
|||
def generate_graph(input) -> KnowledgeGraph:
|
||||
model = "gpt-4-1106-preview"
|
||||
user_prompt = f"Use the given format to extract information from the following input: {input}."
|
||||
system_prompt = read_query_prompt('cognitive_architecture/llm/prompts/generate_graph_prompt.txt')
|
||||
system_prompt = read_query_prompt(
|
||||
"cognitive_architecture/llm/prompts/generate_graph_prompt.txt"
|
||||
)
|
||||
|
||||
out = aclient.chat.completions.create(
|
||||
model=model,
|
||||
|
|
@ -56,38 +59,40 @@ def generate_graph(input) -> KnowledgeGraph:
|
|||
return out
|
||||
|
||||
|
||||
|
||||
async def generate_summary(input) -> MemorySummary:
|
||||
out = aclient.chat.completions.create(
|
||||
out = aclient.chat.completions.create(
|
||||
model=config.model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""Use the given format summarize and reduce the following input: {input}. """,
|
||||
|
||||
},
|
||||
{ "role":"system", "content": """You are a top-tier algorithm
|
||||
{
|
||||
"role": "system",
|
||||
"content": """You are a top-tier algorithm
|
||||
designed for summarizing existing knowledge graphs in structured formats based on a knowledge graph.
|
||||
## 1. Strict Compliance
|
||||
Adhere to the rules strictly. Non-compliance will result in termination.
|
||||
## 2. Don't forget your main goal is to reduce the number of nodes in the knowledge graph while preserving the information contained in it."""}
|
||||
## 2. Don't forget your main goal is to reduce the number of nodes in the knowledge graph while preserving the information contained in it.""",
|
||||
},
|
||||
],
|
||||
response_model=MemorySummary,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def user_query_to_edges_and_nodes( input: str) ->KnowledgeGraph:
|
||||
system_prompt = read_query_prompt('cognitive_architecture/llm/prompts/generate_graph_prompt.txt')
|
||||
def user_query_to_edges_and_nodes(input: str) -> KnowledgeGraph:
|
||||
system_prompt = read_query_prompt(
|
||||
"cognitive_architecture/llm/prompts/generate_graph_prompt.txt"
|
||||
)
|
||||
return aclient.chat.completions.create(
|
||||
model=config.model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""Use the given format to extract information from the following input: {input}. """,
|
||||
|
||||
},
|
||||
{"role": "system", "content":system_prompt}
|
||||
{"role": "system", "content": system_prompt},
|
||||
],
|
||||
response_model=KnowledgeGraph,
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import os
|
|||
import time
|
||||
|
||||
|
||||
|
||||
HOST = os.getenv("OPENAI_API_BASE")
|
||||
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
|
||||
|
||||
|
|
@ -41,7 +40,9 @@ def retry_with_exponential_backoff(
|
|||
|
||||
# Check if max retries has been reached
|
||||
if num_retries > max_retries:
|
||||
raise Exception(f"Maximum number of retries ({max_retries}) exceeded.")
|
||||
raise Exception(
|
||||
f"Maximum number of retries ({max_retries}) exceeded."
|
||||
)
|
||||
|
||||
# Increment the delay
|
||||
delay *= exponential_base * (1 + jitter * random.random())
|
||||
|
|
@ -90,7 +91,9 @@ def aretry_with_exponential_backoff(
|
|||
|
||||
# Check if max retries has been reached
|
||||
if num_retries > max_retries:
|
||||
raise Exception(f"Maximum number of retries ({max_retries}) exceeded.")
|
||||
raise Exception(
|
||||
f"Maximum number of retries ({max_retries}) exceeded."
|
||||
)
|
||||
|
||||
# Increment the delay
|
||||
delay *= exponential_base * (1 + jitter * random.random())
|
||||
|
|
@ -135,6 +138,3 @@ def get_embedding_with_backoff(text, model="text-embedding-ada-002"):
|
|||
response = create_embedding_with_backoff(input=[text], model=model)
|
||||
embedding = response["data"][0]["embedding"]
|
||||
return embedding
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,7 @@
|
|||
|
||||
DEFAULT_PRESET = "cognitive_architecture_chat"
|
||||
preset_options = [DEFAULT_PRESET]
|
||||
|
||||
|
||||
|
||||
def use_preset():
|
||||
"""Placeholder for different present options"""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,33 +1,36 @@
|
|||
from typing import Optional, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Node(BaseModel):
|
||||
id: int
|
||||
description: str
|
||||
category: str
|
||||
color: str ="blue"
|
||||
color: str = "blue"
|
||||
memory_type: str
|
||||
created_at: Optional[float] = None
|
||||
summarized: Optional[bool] = None
|
||||
|
||||
|
||||
class Edge(BaseModel):
|
||||
source: int
|
||||
target: int
|
||||
description: str
|
||||
color: str= "blue"
|
||||
color: str = "blue"
|
||||
created_at: Optional[float] = None
|
||||
summarized: Optional[bool] = None
|
||||
|
||||
|
||||
class KnowledgeGraph(BaseModel):
|
||||
nodes: List[Node] = Field(..., default_factory=list)
|
||||
edges: List[Edge] = Field(..., default_factory=list)
|
||||
|
||||
|
||||
class GraphQLQuery(BaseModel):
|
||||
query: str
|
||||
|
||||
|
||||
class MemorySummary(BaseModel):
|
||||
nodes: List[Node] = Field(..., default_factory=list)
|
||||
edges: List[Edge] = Field(..., default_factory=list)
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,13 +3,15 @@ from botocore.exceptions import BotoCoreError, ClientError
|
|||
from langdetect import detect, LangDetectException
|
||||
import iso639
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import logging
|
||||
|
||||
# Basic configuration of the logging system
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
|
||||
def detect_language(text):
|
||||
|
|
@ -34,8 +36,8 @@ def detect_language(text):
|
|||
logging.info(f"Detected ISO 639-1 code: {detected_lang_iso639_1}")
|
||||
|
||||
# Special case: map 'hr' (Croatian) to 'sr' (Serbian ISO 639-2)
|
||||
if detected_lang_iso639_1 == 'hr':
|
||||
return 'sr'
|
||||
if detected_lang_iso639_1 == "hr":
|
||||
return "sr"
|
||||
return detected_lang_iso639_1
|
||||
|
||||
except LangDetectException as e:
|
||||
|
|
@ -46,8 +48,12 @@ def detect_language(text):
|
|||
return -1
|
||||
|
||||
|
||||
|
||||
def translate_text(text, source_language:str='sr', target_language:str='en', region_name='eu-west-1'):
|
||||
def translate_text(
|
||||
text,
|
||||
source_language: str = "sr",
|
||||
target_language: str = "en",
|
||||
region_name="eu-west-1",
|
||||
):
|
||||
"""
|
||||
Translate text from source language to target language using AWS Translate.
|
||||
|
||||
|
|
@ -68,9 +74,15 @@ def translate_text(text, source_language:str='sr', target_language:str='en', reg
|
|||
return "Both source and target language codes are required."
|
||||
|
||||
try:
|
||||
translate = boto3.client(service_name='translate', region_name=region_name, use_ssl=True)
|
||||
result = translate.translate_text(Text=text, SourceLanguageCode=source_language, TargetLanguageCode=target_language)
|
||||
return result.get('TranslatedText', 'No translation found.')
|
||||
translate = boto3.client(
|
||||
service_name="translate", region_name=region_name, use_ssl=True
|
||||
)
|
||||
result = translate.translate_text(
|
||||
Text=text,
|
||||
SourceLanguageCode=source_language,
|
||||
TargetLanguageCode=target_language,
|
||||
)
|
||||
return result.get("TranslatedText", "No translation found.")
|
||||
|
||||
except BotoCoreError as e:
|
||||
logging.info(f"BotoCoreError occurred: {e}")
|
||||
|
|
@ -81,8 +93,8 @@ def translate_text(text, source_language:str='sr', target_language:str='en', reg
|
|||
return "Error with AWS client or network issue."
|
||||
|
||||
|
||||
source_language = 'sr'
|
||||
target_language = 'en'
|
||||
source_language = "sr"
|
||||
target_language = "en"
|
||||
text_to_translate = "Ja volim da pecam i idem na reku da šetam pored nje ponekad"
|
||||
|
||||
translated_text = translate_text(text_to_translate, source_language, target_language)
|
||||
|
|
|
|||
|
|
@ -22,12 +22,15 @@ class Node:
|
|||
self.description = description
|
||||
self.color = color
|
||||
|
||||
|
||||
class Edge:
|
||||
def __init__(self, source, target, label, color):
|
||||
self.source = source
|
||||
self.target = target
|
||||
self.label = label
|
||||
self.color = color
|
||||
|
||||
|
||||
# def visualize_knowledge_graph(kg: KnowledgeGraph):
|
||||
# dot = Digraph(comment="Knowledge Graph")
|
||||
#
|
||||
|
|
@ -82,6 +85,7 @@ def get_document_names(doc_input):
|
|||
# doc_input is not valid
|
||||
return []
|
||||
|
||||
|
||||
def format_dict(d):
|
||||
# Initialize an empty list to store formatted items
|
||||
formatted_items = []
|
||||
|
|
@ -89,7 +93,9 @@ def format_dict(d):
|
|||
# Iterate through all key-value pairs
|
||||
for key, value in d.items():
|
||||
# Format key-value pairs with a colon and space, and adding quotes for string values
|
||||
formatted_item = f"{key}: '{value}'" if isinstance(value, str) else f"{key}: {value}"
|
||||
formatted_item = (
|
||||
f"{key}: '{value}'" if isinstance(value, str) else f"{key}: {value}"
|
||||
)
|
||||
formatted_items.append(formatted_item)
|
||||
|
||||
# Join all formatted items with a comma and a space
|
||||
|
|
@ -114,7 +120,7 @@ def create_node_variable_mapping(nodes):
|
|||
mapping = {}
|
||||
for node in nodes:
|
||||
variable_name = f"{node['category']}{node['id']}".lower()
|
||||
mapping[node['id']] = variable_name
|
||||
mapping[node["id"]] = variable_name
|
||||
return mapping
|
||||
|
||||
|
||||
|
|
@ -123,18 +129,23 @@ def create_edge_variable_mapping(edges):
|
|||
for edge in edges:
|
||||
# Construct a unique identifier for the edge
|
||||
variable_name = f"edge{edge['source']}to{edge['target']}".lower()
|
||||
mapping[(edge['source'], edge['target'])] = variable_name
|
||||
mapping[(edge["source"], edge["target"])] = variable_name
|
||||
return mapping
|
||||
|
||||
|
||||
|
||||
def generate_letter_uuid(length=8):
|
||||
"""Generate a random string of uppercase letters with the specified length."""
|
||||
letters = string.ascii_uppercase # A-Z
|
||||
return "".join(random.choice(letters) for _ in range(length))
|
||||
|
||||
|
||||
from cognitive_architecture.database.relationaldb.models.operation import Operation
|
||||
from cognitive_architecture.database.relationaldb.database_crud import session_scope, add_entity, update_entity, fetch_job_id
|
||||
from cognitive_architecture.database.relationaldb.database_crud import (
|
||||
session_scope,
|
||||
add_entity,
|
||||
update_entity,
|
||||
fetch_job_id,
|
||||
)
|
||||
from cognitive_architecture.database.relationaldb.models.metadatas import MetaDatas
|
||||
from cognitive_architecture.database.relationaldb.models.docs import DocsModel
|
||||
from cognitive_architecture.database.relationaldb.models.memory import MemoryModel
|
||||
|
|
@ -142,42 +153,56 @@ from cognitive_architecture.database.relationaldb.models.user import User
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
import logging
|
||||
|
||||
|
||||
async def get_vectordb_namespace(session: AsyncSession, user_id: str):
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(MemoryModel.memory_name).where(MemoryModel.user_id == user_id).order_by(MemoryModel.created_at.desc())
|
||||
select(MemoryModel.memory_name)
|
||||
.where(MemoryModel.user_id == user_id)
|
||||
.order_by(MemoryModel.created_at.desc())
|
||||
)
|
||||
namespace = [row[0] for row in result.fetchall()]
|
||||
return namespace
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred while retrieving the Vectordb_namespace: {str(e)}")
|
||||
logging.error(
|
||||
f"An error occurred while retrieving the Vectordb_namespace: {str(e)}"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def get_vectordb_document_name(session: AsyncSession, user_id: str):
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(DocsModel.doc_name).where(DocsModel.user_id == user_id).order_by(DocsModel.created_at.desc())
|
||||
select(DocsModel.doc_name)
|
||||
.where(DocsModel.user_id == user_id)
|
||||
.order_by(DocsModel.created_at.desc())
|
||||
)
|
||||
doc_names = [row[0] for row in result.fetchall()]
|
||||
return doc_names
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred while retrieving the Vectordb_namespace: {str(e)}")
|
||||
logging.error(
|
||||
f"An error occurred while retrieving the Vectordb_namespace: {str(e)}"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def get_model_id_name(session: AsyncSession, id: str):
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(MemoryModel.memory_name).where(MemoryModel.id == id).order_by(MemoryModel.created_at.desc())
|
||||
select(MemoryModel.memory_name)
|
||||
.where(MemoryModel.id == id)
|
||||
.order_by(MemoryModel.created_at.desc())
|
||||
)
|
||||
doc_names = [row[0] for row in result.fetchall()]
|
||||
return doc_names
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred while retrieving the Vectordb_namespace: {str(e)}")
|
||||
logging.error(
|
||||
f"An error occurred while retrieving the Vectordb_namespace: {str(e)}"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
|
||||
async def get_unsumarized_vector_db_namespace(session: AsyncSession, user_id: str):
|
||||
"""
|
||||
Asynchronously retrieves the latest memory names and document details for a given user.
|
||||
|
|
@ -207,13 +232,13 @@ async def get_unsumarized_vector_db_namespace(session: AsyncSession, user_id: st
|
|||
.join(Operation.memories) # Explicit join with memories table
|
||||
.options(
|
||||
contains_eager(Operation.docs), # Informs ORM of the join for docs
|
||||
contains_eager(Operation.memories) # Informs ORM of the join for memories
|
||||
contains_eager(Operation.memories), # Informs ORM of the join for memories
|
||||
)
|
||||
.where(
|
||||
(Operation.user_id == user_id) & # Filter by user_id
|
||||
or_(
|
||||
(Operation.user_id == user_id)
|
||||
& or_( # Filter by user_id
|
||||
DocsModel.graph_summary == False, # Condition 1: graph_summary is False
|
||||
DocsModel.graph_summary == None # Condition 3: graph_summary is None
|
||||
DocsModel.graph_summary == None, # Condition 3: graph_summary is None
|
||||
) # Filter by user_id
|
||||
)
|
||||
.order_by(Operation.created_at.desc()) # Order by creation date
|
||||
|
|
@ -223,7 +248,11 @@ async def get_unsumarized_vector_db_namespace(session: AsyncSession, user_id: st
|
|||
|
||||
# Extract memory names and document names and IDs
|
||||
# memory_names = [memory.memory_name for op in operations for memory in op.memories]
|
||||
memory_details = [(memory.memory_name, memory.memory_category) for op in operations for memory in op.memories]
|
||||
memory_details = [
|
||||
(memory.memory_name, memory.memory_category)
|
||||
for op in operations
|
||||
for memory in op.memories
|
||||
]
|
||||
docs = [(doc.doc_name, doc.id) for op in operations for doc in op.docs]
|
||||
|
||||
return memory_details, docs
|
||||
|
|
@ -232,6 +261,8 @@ async def get_unsumarized_vector_db_namespace(session: AsyncSession, user_id: st
|
|||
# # Handle the exception as needed
|
||||
# print(f"An error occurred: {e}")
|
||||
# return None
|
||||
|
||||
|
||||
async def get_memory_name_by_doc_id(session: AsyncSession, docs_id: str):
|
||||
"""
|
||||
Asynchronously retrieves memory names associated with a specific document ID.
|
||||
|
|
@ -254,8 +285,12 @@ async def get_memory_name_by_doc_id(session: AsyncSession, docs_id: str):
|
|||
try:
|
||||
result = await session.execute(
|
||||
select(MemoryModel.memory_name)
|
||||
.join(Operation, Operation.id == MemoryModel.operation_id) # Join with Operation
|
||||
.join(DocsModel, DocsModel.operation_id == Operation.id) # Join with DocsModel
|
||||
.join(
|
||||
Operation, Operation.id == MemoryModel.operation_id
|
||||
) # Join with Operation
|
||||
.join(
|
||||
DocsModel, DocsModel.operation_id == Operation.id
|
||||
) # Join with DocsModel
|
||||
.where(DocsModel.id == docs_id) # Filtering based on the passed document ID
|
||||
.distinct() # To avoid duplicate memory names
|
||||
)
|
||||
|
|
@ -269,7 +304,6 @@ async def get_memory_name_by_doc_id(session: AsyncSession, docs_id: str):
|
|||
return None
|
||||
|
||||
|
||||
|
||||
#
|
||||
# async def main():
|
||||
# user_id = "user"
|
||||
|
|
|
|||
|
|
@ -9,8 +9,6 @@ import os
|
|||
print(os.getcwd())
|
||||
|
||||
|
||||
|
||||
|
||||
from cognitive_architecture.database.relationaldb.models.user import User
|
||||
from cognitive_architecture.database.relationaldb.models.memory import MemoryModel
|
||||
|
||||
|
|
@ -27,7 +25,6 @@ import uuid
|
|||
load_dotenv()
|
||||
|
||||
|
||||
|
||||
from cognitive_architecture.database.vectordb.basevectordb import BaseMemory
|
||||
|
||||
from cognitive_architecture.config import Config
|
||||
|
|
@ -36,8 +33,6 @@ config = Config()
|
|||
config.load()
|
||||
|
||||
|
||||
|
||||
|
||||
class DynamicBaseMemory(BaseMemory):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -145,8 +140,8 @@ class Memory:
|
|||
db_type: str = None,
|
||||
namespace: str = None,
|
||||
memory_id: str = None,
|
||||
memory_class = None,
|
||||
job_id:str = None
|
||||
memory_class=None,
|
||||
job_id: str = None,
|
||||
) -> None:
|
||||
self.load_environment_variables()
|
||||
self.memory_id = memory_id
|
||||
|
|
@ -157,20 +152,25 @@ class Memory:
|
|||
self.namespace = namespace
|
||||
self.memory_instances = []
|
||||
self.memory_class = memory_class
|
||||
self.job_id=job_id
|
||||
self.job_id = job_id
|
||||
# self.memory_class = DynamicBaseMemory(
|
||||
# "Memory", user_id, str(self.memory_id), index_name, db_type, namespace
|
||||
# )
|
||||
|
||||
|
||||
|
||||
def load_environment_variables(self) -> None:
|
||||
load_dotenv()
|
||||
self.OPENAI_TEMPERATURE = config.openai_temperature
|
||||
self.OPENAI_API_KEY = config.openai_key
|
||||
|
||||
@classmethod
|
||||
async def create_memory(cls, user_id: str, session, job_id:str=None, memory_label:str=None, **kwargs):
|
||||
async def create_memory(
|
||||
cls,
|
||||
user_id: str,
|
||||
session,
|
||||
job_id: str = None,
|
||||
memory_label: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Class method that acts as a factory method for creating Memory instances.
|
||||
It performs necessary DB checks or updates before instance creation.
|
||||
|
|
@ -180,9 +180,14 @@ class Memory:
|
|||
|
||||
if existing_user:
|
||||
# Handle existing user scenario...
|
||||
memory_id = await cls.check_existing_memory(user_id,memory_label, session)
|
||||
memory_id = await cls.check_existing_memory(user_id, memory_label, session)
|
||||
if memory_id is None:
|
||||
memory_id = await cls.handle_new_memory(user_id = user_id, session= session,job_id=job_id, memory_name= memory_label)
|
||||
memory_id = await cls.handle_new_memory(
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
job_id=job_id,
|
||||
memory_name=memory_label,
|
||||
)
|
||||
logging.info(
|
||||
f"Existing user {user_id} found in the DB. Memory ID: {memory_id}"
|
||||
)
|
||||
|
|
@ -190,16 +195,33 @@ class Memory:
|
|||
# Handle new user scenario...
|
||||
await cls.handle_new_user(user_id, session)
|
||||
|
||||
memory_id = await cls.handle_new_memory(user_id =user_id, session=session, job_id=job_id, memory_name= memory_label)
|
||||
memory_id = await cls.handle_new_memory(
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
job_id=job_id,
|
||||
memory_name=memory_label,
|
||||
)
|
||||
logging.info(
|
||||
f"New user {user_id} created in the DB. Memory ID: {memory_id}"
|
||||
)
|
||||
|
||||
memory_class = DynamicBaseMemory(
|
||||
memory_label, user_id, str(memory_id), index_name=memory_label , db_type=config.vectordb, **kwargs
|
||||
memory_label,
|
||||
user_id,
|
||||
str(memory_id),
|
||||
index_name=memory_label,
|
||||
db_type=config.vectordb,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return cls(user_id=user_id, session=session, memory_id=memory_id, job_id =job_id, memory_class=memory_class, **kwargs)
|
||||
return cls(
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
memory_id=memory_id,
|
||||
job_id=job_id,
|
||||
memory_class=memory_class,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def list_memory_classes(self):
|
||||
"""
|
||||
|
|
@ -215,19 +237,20 @@ class Memory:
|
|||
return result.scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
async def check_existing_memory(user_id: str, memory_label:str, session):
|
||||
async def check_existing_memory(user_id: str, memory_label: str, session):
|
||||
"""Check if a user memory exists in the DB and return it. Filters by user and label"""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(MemoryModel.id).where(MemoryModel.user_id == user_id)
|
||||
select(MemoryModel.id)
|
||||
.where(MemoryModel.user_id == user_id)
|
||||
.filter_by(memory_name=memory_label)
|
||||
.order_by(MemoryModel.created_at)
|
||||
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred: {str(e)}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def handle_new_user(user_id: str, session):
|
||||
"""
|
||||
|
|
@ -251,7 +274,13 @@ class Memory:
|
|||
return f"Error creating user: {str(e)}"
|
||||
|
||||
@staticmethod
|
||||
async def handle_new_memory(user_id: str, session, job_id: str = None, memory_name: str = None, memory_category:str='PUBLIC'):
|
||||
async def handle_new_memory(
|
||||
user_id: str,
|
||||
session,
|
||||
job_id: str = None,
|
||||
memory_name: str = None,
|
||||
memory_category: str = "PUBLIC",
|
||||
):
|
||||
"""
|
||||
Handle new memory creation associated with a user.
|
||||
|
||||
|
|
@ -296,7 +325,6 @@ class Memory:
|
|||
except Exception as e:
|
||||
return f"Error creating memory: {str(e)}"
|
||||
|
||||
|
||||
async def add_memory_instance(self, memory_class_name: str):
|
||||
"""Add a new memory instance to the memory_instances list."""
|
||||
instance = DynamicBaseMemory(
|
||||
|
|
@ -446,7 +474,9 @@ async def main():
|
|||
from database.relationaldb.database import AsyncSessionLocal
|
||||
|
||||
async with session_scope(AsyncSessionLocal()) as session:
|
||||
memory = await Memory.create_memory("677", session, "SEMANTICMEMORY", namespace="SEMANTICMEMORY")
|
||||
memory = await Memory.create_memory(
|
||||
"677", session, "SEMANTICMEMORY", namespace="SEMANTICMEMORY"
|
||||
)
|
||||
ff = memory.memory_instances
|
||||
logging.info("ssss %s", ff)
|
||||
|
||||
|
|
@ -462,8 +492,13 @@ async def main():
|
|||
await memory.add_dynamic_memory_class("semanticmemory", "SEMANTICMEMORY")
|
||||
await memory.add_method_to_class(memory.semanticmemory_class, "add_memories")
|
||||
await memory.add_method_to_class(memory.semanticmemory_class, "fetch_memories")
|
||||
sss = await memory.dynamic_method_call(memory.semanticmemory_class, 'add_memories',
|
||||
observation='some_observation', params=params, loader_settings=loader_settings)
|
||||
sss = await memory.dynamic_method_call(
|
||||
memory.semanticmemory_class,
|
||||
"add_memories",
|
||||
observation="some_observation",
|
||||
params=params,
|
||||
loader_settings=loader_settings,
|
||||
)
|
||||
|
||||
# susu = await memory.dynamic_method_call(
|
||||
# memory.semanticmemory_class,
|
||||
|
|
|
|||
435
main.py
435
main.py
|
|
@ -7,7 +7,10 @@ from cognitive_architecture.database.relationaldb.models.memory import MemoryMod
|
|||
from cognitive_architecture.classifiers.classifier import classify_documents
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from cognitive_architecture.database.relationaldb.database_crud import session_scope, update_entity_graph_summary
|
||||
from cognitive_architecture.database.relationaldb.database_crud import (
|
||||
session_scope,
|
||||
update_entity_graph_summary,
|
||||
)
|
||||
from cognitive_architecture.database.relationaldb.database import AsyncSessionLocal
|
||||
from cognitive_architecture.utils import generate_letter_uuid
|
||||
import instructor
|
||||
|
|
@ -17,12 +20,18 @@ from cognitive_architecture.database.relationaldb.database_crud import fetch_job
|
|||
import uuid
|
||||
from cognitive_architecture.database.relationaldb.models.sessions import Session
|
||||
from cognitive_architecture.database.relationaldb.models.operation import Operation
|
||||
from cognitive_architecture.database.relationaldb.database_crud import session_scope, add_entity, update_entity, fetch_job_id
|
||||
from cognitive_architecture.database.relationaldb.database_crud import (
|
||||
session_scope,
|
||||
add_entity,
|
||||
update_entity,
|
||||
fetch_job_id,
|
||||
)
|
||||
from cognitive_architecture.database.relationaldb.models.metadatas import MetaDatas
|
||||
from cognitive_architecture.database.relationaldb.models.docs import DocsModel
|
||||
from cognitive_architecture.database.relationaldb.models.memory import MemoryModel
|
||||
from cognitive_architecture.database.relationaldb.models.user import User
|
||||
from cognitive_architecture.classifiers.classifier import classify_call
|
||||
|
||||
aclient = instructor.patch(OpenAI())
|
||||
DEFAULT_PRESET = "promethai_chat"
|
||||
preset_options = [DEFAULT_PRESET]
|
||||
|
|
@ -30,6 +39,7 @@ PROMETHAI_DIR = os.path.join(os.path.expanduser("~"), ".")
|
|||
load_dotenv()
|
||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
||||
from cognitive_architecture.config import Config
|
||||
|
||||
config = Config()
|
||||
config.load()
|
||||
from cognitive_architecture.utils import get_document_names
|
||||
|
|
@ -37,14 +47,28 @@ from sqlalchemy.orm import selectinload, joinedload, contains_eager
|
|||
import logging
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from cognitive_architecture.utils import get_document_names, generate_letter_uuid, get_memory_name_by_doc_id, get_unsumarized_vector_db_namespace, get_vectordb_namespace, get_vectordb_document_name
|
||||
from cognitive_architecture.shared.language_processing import translate_text, detect_language
|
||||
from cognitive_architecture.utils import (
|
||||
get_document_names,
|
||||
generate_letter_uuid,
|
||||
get_memory_name_by_doc_id,
|
||||
get_unsumarized_vector_db_namespace,
|
||||
get_vectordb_namespace,
|
||||
get_vectordb_document_name,
|
||||
)
|
||||
from cognitive_architecture.shared.language_processing import (
|
||||
translate_text,
|
||||
detect_language,
|
||||
)
|
||||
from cognitive_architecture.classifiers.classifier import classify_user_input
|
||||
|
||||
async def fetch_document_vectordb_namespace(session: AsyncSession, user_id: str, namespace_id:str, doc_id:str=None):
|
||||
logging.info("user id is", user_id)
|
||||
memory = await Memory.create_memory(user_id, session, namespace=namespace_id, memory_label=namespace_id)
|
||||
|
||||
async def fetch_document_vectordb_namespace(
|
||||
session: AsyncSession, user_id: str, namespace_id: str, doc_id: str = None
|
||||
):
|
||||
logging.info("user id is", user_id)
|
||||
memory = await Memory.create_memory(
|
||||
user_id, session, namespace=namespace_id, memory_label=namespace_id
|
||||
)
|
||||
|
||||
# Managing memory attributes
|
||||
existing_user = await Memory.check_existing_user(user_id, session)
|
||||
|
|
@ -66,15 +90,26 @@ async def fetch_document_vectordb_namespace(session: AsyncSession, user_id: str,
|
|||
print(f"No attribute named in memory.")
|
||||
|
||||
print("Available memory classes:", await memory.list_memory_classes())
|
||||
result = await memory.dynamic_method_call(dynamic_memory_class, 'fetch_memories',
|
||||
observation="placeholder", search_type="summary_filter_by_object_name", params=doc_id)
|
||||
result = await memory.dynamic_method_call(
|
||||
dynamic_memory_class,
|
||||
"fetch_memories",
|
||||
observation="placeholder",
|
||||
search_type="summary_filter_by_object_name",
|
||||
params=doc_id,
|
||||
)
|
||||
logging.info("Result is %s", str(result))
|
||||
|
||||
return result, namespace_id
|
||||
|
||||
|
||||
|
||||
async def load_documents_to_vectorstore(session: AsyncSession, user_id: str, content:str=None, job_id:str=None, loader_settings:dict=None, memory_type:str="PRIVATE"):
|
||||
async def load_documents_to_vectorstore(
|
||||
session: AsyncSession,
|
||||
user_id: str,
|
||||
content: str = None,
|
||||
job_id: str = None,
|
||||
loader_settings: dict = None,
|
||||
memory_type: str = "PRIVATE",
|
||||
):
|
||||
namespace_id = str(generate_letter_uuid()) + "_" + "SEMANTICMEMORY"
|
||||
namespace_class = namespace_id + "_class"
|
||||
|
||||
|
|
@ -96,12 +131,21 @@ async def load_documents_to_vectorstore(session: AsyncSession, user_id: str, con
|
|||
operation_type="DATA_LOAD",
|
||||
),
|
||||
)
|
||||
memory = await Memory.create_memory(user_id, session, namespace=namespace_id, job_id=job_id, memory_label=namespace_id)
|
||||
memory = await Memory.create_memory(
|
||||
user_id,
|
||||
session,
|
||||
namespace=namespace_id,
|
||||
job_id=job_id,
|
||||
memory_label=namespace_id,
|
||||
)
|
||||
if content is not None:
|
||||
document_names = [content[:30]]
|
||||
if loader_settings is not None:
|
||||
document_source = loader_settings.get("document_names") if isinstance(loader_settings.get("document_names"),
|
||||
list) else loader_settings.get("path", "None")
|
||||
document_source = (
|
||||
loader_settings.get("document_names")
|
||||
if isinstance(loader_settings.get("document_names"), list)
|
||||
else loader_settings.get("path", "None")
|
||||
)
|
||||
logging.info("Document source is %s", document_source)
|
||||
# try:
|
||||
document_names = get_document_names(document_source[0])
|
||||
|
|
@ -109,12 +153,21 @@ async def load_documents_to_vectorstore(session: AsyncSession, user_id: str, con
|
|||
# except:
|
||||
# document_names = document_source
|
||||
for doc in document_names:
|
||||
from cognitive_architecture.shared.language_processing import translate_text, detect_language
|
||||
#translates doc titles to english
|
||||
from cognitive_architecture.shared.language_processing import (
|
||||
translate_text,
|
||||
detect_language,
|
||||
)
|
||||
|
||||
# translates doc titles to english
|
||||
if loader_settings is not None:
|
||||
logging.info("Detecting language of document %s", doc)
|
||||
loader_settings["single_document_path"]= loader_settings.get("path", "None")[0] +"/"+doc
|
||||
logging.info("Document path is %s", loader_settings.get("single_document_path", "None"))
|
||||
loader_settings["single_document_path"] = (
|
||||
loader_settings.get("path", "None")[0] + "/" + doc
|
||||
)
|
||||
logging.info(
|
||||
"Document path is %s",
|
||||
loader_settings.get("single_document_path", "None"),
|
||||
)
|
||||
memory_category = loader_settings.get("memory_category", "PUBLIC")
|
||||
if loader_settings is None:
|
||||
memory_category = "CUSTOM"
|
||||
|
|
@ -122,7 +175,7 @@ async def load_documents_to_vectorstore(session: AsyncSession, user_id: str, con
|
|||
doc_ = doc.strip(".pdf").replace("-", " ")
|
||||
doc_ = translate_text(doc_, "sr", "en")
|
||||
else:
|
||||
doc_=doc
|
||||
doc_ = doc
|
||||
doc_id = str(uuid.uuid4())
|
||||
|
||||
logging.info("Document name is %s", doc_)
|
||||
|
|
@ -131,17 +184,15 @@ async def load_documents_to_vectorstore(session: AsyncSession, user_id: str, con
|
|||
DocsModel(
|
||||
id=doc_id,
|
||||
operation_id=job_id,
|
||||
graph_summary= False,
|
||||
memory_category= memory_category,
|
||||
doc_name=doc_
|
||||
)
|
||||
graph_summary=False,
|
||||
memory_category=memory_category,
|
||||
doc_name=doc_,
|
||||
),
|
||||
)
|
||||
# Managing memory attributes
|
||||
existing_user = await Memory.check_existing_user(user_id, session)
|
||||
await memory.manage_memory_attributes(existing_user)
|
||||
params = {
|
||||
"doc_id":doc_id
|
||||
}
|
||||
params = {"doc_id": doc_id}
|
||||
print("Namespace id is %s", namespace_id)
|
||||
await memory.add_dynamic_memory_class(namespace_id.lower(), namespace_id)
|
||||
|
||||
|
|
@ -157,13 +208,18 @@ async def load_documents_to_vectorstore(session: AsyncSession, user_id: str, con
|
|||
print(f"No attribute named in memory.")
|
||||
|
||||
print("Available memory classes:", await memory.list_memory_classes())
|
||||
result = await memory.dynamic_method_call(dynamic_memory_class, 'add_memories',
|
||||
observation=content, params=params, loader_settings=loader_settings)
|
||||
result = await memory.dynamic_method_call(
|
||||
dynamic_memory_class,
|
||||
"add_memories",
|
||||
observation=content,
|
||||
params=params,
|
||||
loader_settings=loader_settings,
|
||||
)
|
||||
await update_entity(session, Operation, job_id, "SUCCESS")
|
||||
return 1
|
||||
|
||||
async def user_query_to_graph_db(session: AsyncSession, user_id: str, query_input: str):
|
||||
|
||||
async def user_query_to_graph_db(session: AsyncSession, user_id: str, query_input: str):
|
||||
try:
|
||||
new_user = User(id=user_id)
|
||||
await add_entity(session, new_user)
|
||||
|
|
@ -189,19 +245,32 @@ async def user_query_to_graph_db(session: AsyncSession, user_id: str, query_inpu
|
|||
else:
|
||||
translated_query = query_input
|
||||
|
||||
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username, password=config.graph_database_password)
|
||||
neo4j_graph_db = Neo4jGraphDB(
|
||||
url=config.graph_database_url,
|
||||
username=config.graph_database_username,
|
||||
password=config.graph_database_password,
|
||||
)
|
||||
|
||||
cypher_query = await neo4j_graph_db.generate_cypher_query_for_user_prompt_decomposition(user_id, translated_query)
|
||||
cypher_query = (
|
||||
await neo4j_graph_db.generate_cypher_query_for_user_prompt_decomposition(
|
||||
user_id, translated_query
|
||||
)
|
||||
)
|
||||
result = neo4j_graph_db.query(cypher_query)
|
||||
|
||||
neo4j_graph_db.run_merge_query(user_id=user_id, memory_type="SemanticMemory", similarity_threshold=0.8)
|
||||
neo4j_graph_db.run_merge_query(user_id=user_id, memory_type="EpisodicMemory", similarity_threshold=0.8)
|
||||
neo4j_graph_db.run_merge_query(
|
||||
user_id=user_id, memory_type="SemanticMemory", similarity_threshold=0.8
|
||||
)
|
||||
neo4j_graph_db.run_merge_query(
|
||||
user_id=user_id, memory_type="EpisodicMemory", similarity_threshold=0.8
|
||||
)
|
||||
neo4j_graph_db.close()
|
||||
|
||||
await update_entity(session, Operation, job_id, "SUCCESS")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# async def add_documents_to_graph_db(session: AsyncSession, user_id: Optional[str] = None,
|
||||
# document_memory_types: Optional[List[str]] = None):
|
||||
# """ Add documents to a graph database, handling multiple memory types """
|
||||
|
|
@ -256,106 +325,159 @@ async def user_query_to_graph_db(session: AsyncSession, user_id: str, query_inpu
|
|||
# return e
|
||||
|
||||
|
||||
async def add_documents_to_graph_db(session: AsyncSession, user_id: str= None, document_memory_types:list=None):
|
||||
async def add_documents_to_graph_db(
|
||||
session: AsyncSession, user_id: str = None, document_memory_types: list = None
|
||||
):
|
||||
""""""
|
||||
if document_memory_types is None:
|
||||
document_memory_types = ['PUBLIC']
|
||||
document_memory_types = ["PUBLIC"]
|
||||
|
||||
logging.info("Document memory types are", document_memory_types)
|
||||
try:
|
||||
# await update_document_vectordb_namespace(postgres_session, user_id)
|
||||
memory_details, docs = await get_unsumarized_vector_db_namespace(session, user_id)
|
||||
memory_details, docs = await get_unsumarized_vector_db_namespace(
|
||||
session, user_id
|
||||
)
|
||||
|
||||
logging.info("Docs are", docs)
|
||||
memory_details= [detail for detail in memory_details if detail[1] in document_memory_types]
|
||||
memory_details = [
|
||||
detail for detail in memory_details if detail[1] in document_memory_types
|
||||
]
|
||||
logging.info("Memory details", memory_details)
|
||||
for doc in docs:
|
||||
logging.info("Memory names are", memory_details)
|
||||
doc_name, doc_id = doc
|
||||
logging.info("Doc id is", doc_id)
|
||||
try:
|
||||
classification_content = await fetch_document_vectordb_namespace(session, user_id, memory_details[0][0], doc_id)
|
||||
retrieval_chunks = [item['text'] for item in
|
||||
classification_content[0]['data']['Get'][memory_details[0][0]]]
|
||||
classification_content = await fetch_document_vectordb_namespace(
|
||||
session, user_id, memory_details[0][0], doc_id
|
||||
)
|
||||
retrieval_chunks = [
|
||||
item["text"]
|
||||
for item in classification_content[0]["data"]["Get"][
|
||||
memory_details[0][0]
|
||||
]
|
||||
]
|
||||
logging.info("Classification content is", classification_content)
|
||||
except:
|
||||
classification_content = ""
|
||||
retrieval_chunks = ""
|
||||
# retrieval_chunks = [item['text'] for item in classification_content[0]['data']['Get'][memory_details[0]]]
|
||||
# Concatenating the extracted text values
|
||||
concatenated_retrievals = ' '.join(retrieval_chunks)
|
||||
concatenated_retrievals = " ".join(retrieval_chunks)
|
||||
print(concatenated_retrievals)
|
||||
logging.info("Retrieval chunks are", retrieval_chunks)
|
||||
classification = await classify_documents(doc_name, document_id =doc_id, content=concatenated_retrievals)
|
||||
classification = await classify_documents(
|
||||
doc_name, document_id=doc_id, content=concatenated_retrievals
|
||||
)
|
||||
|
||||
logging.info("Classification is %s", str(classification))
|
||||
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username,
|
||||
password=config.graph_database_password)
|
||||
if document_memory_types == ['PUBLIC']:
|
||||
await create_public_memory(user_id=user_id, labels=['sr'], topic="PublicMemory")
|
||||
ids = neo4j_graph_db.retrieve_node_id_for_memory_type(topic="PublicMemory")
|
||||
neo4j_graph_db = Neo4jGraphDB(
|
||||
url=config.graph_database_url,
|
||||
username=config.graph_database_username,
|
||||
password=config.graph_database_password,
|
||||
)
|
||||
if document_memory_types == ["PUBLIC"]:
|
||||
await create_public_memory(
|
||||
user_id=user_id, labels=["sr"], topic="PublicMemory"
|
||||
)
|
||||
ids = neo4j_graph_db.retrieve_node_id_for_memory_type(
|
||||
topic="PublicMemory"
|
||||
)
|
||||
neo4j_graph_db.close()
|
||||
print(ids)
|
||||
else:
|
||||
ids = neo4j_graph_db.retrieve_node_id_for_memory_type(topic="SemanticMemory")
|
||||
ids = neo4j_graph_db.retrieve_node_id_for_memory_type(
|
||||
topic="SemanticMemory"
|
||||
)
|
||||
neo4j_graph_db.close()
|
||||
print(ids)
|
||||
|
||||
for id in ids:
|
||||
print(id.get('memoryId'))
|
||||
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username,
|
||||
password=config.graph_database_password)
|
||||
if document_memory_types == ['PUBLIC']:
|
||||
|
||||
rs = neo4j_graph_db.create_document_node_cypher(classification, user_id, public_memory_id=id.get('memoryId'))
|
||||
print(id.get("memoryId"))
|
||||
neo4j_graph_db = Neo4jGraphDB(
|
||||
url=config.graph_database_url,
|
||||
username=config.graph_database_username,
|
||||
password=config.graph_database_password,
|
||||
)
|
||||
if document_memory_types == ["PUBLIC"]:
|
||||
rs = neo4j_graph_db.create_document_node_cypher(
|
||||
classification, user_id, public_memory_id=id.get("memoryId")
|
||||
)
|
||||
neo4j_graph_db.close()
|
||||
else:
|
||||
rs = neo4j_graph_db.create_document_node_cypher(classification, user_id, memory_type='SemanticMemory')
|
||||
rs = neo4j_graph_db.create_document_node_cypher(
|
||||
classification, user_id, memory_type="SemanticMemory"
|
||||
)
|
||||
neo4j_graph_db.close()
|
||||
logging.info("Cypher query is %s", str(rs))
|
||||
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username,
|
||||
password=config.graph_database_password)
|
||||
neo4j_graph_db = Neo4jGraphDB(
|
||||
url=config.graph_database_url,
|
||||
username=config.graph_database_username,
|
||||
password=config.graph_database_password,
|
||||
)
|
||||
neo4j_graph_db.query(rs)
|
||||
neo4j_graph_db.close()
|
||||
logging.info("WE GOT HERE")
|
||||
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username,
|
||||
password=config.graph_database_password)
|
||||
neo4j_graph_db = Neo4jGraphDB(
|
||||
url=config.graph_database_url,
|
||||
username=config.graph_database_username,
|
||||
password=config.graph_database_password,
|
||||
)
|
||||
if memory_details[0][1] == "PUBLIC":
|
||||
|
||||
neo4j_graph_db.update_document_node_with_db_ids( vectordb_namespace=memory_details[0][0],
|
||||
document_id=doc_id)
|
||||
neo4j_graph_db.update_document_node_with_db_ids(
|
||||
vectordb_namespace=memory_details[0][0], document_id=doc_id
|
||||
)
|
||||
neo4j_graph_db.close()
|
||||
else:
|
||||
neo4j_graph_db.update_document_node_with_db_ids( vectordb_namespace=memory_details[0][0],
|
||||
document_id=doc_id, user_id=user_id)
|
||||
neo4j_graph_db.update_document_node_with_db_ids(
|
||||
vectordb_namespace=memory_details[0][0],
|
||||
document_id=doc_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
neo4j_graph_db.close()
|
||||
# await update_entity_graph_summary(session, DocsModel, doc_id, True)
|
||||
except Exception as e:
|
||||
return e
|
||||
|
||||
|
||||
class ResponseString(BaseModel):
|
||||
response: str = Field(default=None) # Defaulting to None or you can use a default string like ""
|
||||
response: str = Field(
|
||||
default=None
|
||||
) # Defaulting to None or you can use a default string like ""
|
||||
quotation: str = Field(default=None) # Same here
|
||||
|
||||
|
||||
#
|
||||
|
||||
|
||||
def generate_graph(input) -> ResponseString:
|
||||
out = aclient.chat.completions.create(
|
||||
out = aclient.chat.completions.create(
|
||||
model="gpt-4-1106-preview",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""Use the given context to answer query and use help of associated context: {input}. """,
|
||||
|
||||
},
|
||||
{ "role":"system", "content": """You are a top-tier algorithm
|
||||
{
|
||||
"role": "system",
|
||||
"content": """You are a top-tier algorithm
|
||||
designed for using context summaries based on cognitive psychology to answer user queries, and provide a simple response.
|
||||
Do not mention anything explicit about cognitive architecture, but use the context to answer the query. If you are using a document, reference document metadata field"""}
|
||||
Do not mention anything explicit about cognitive architecture, but use the context to answer the query. If you are using a document, reference document metadata field""",
|
||||
},
|
||||
],
|
||||
response_model=ResponseString,
|
||||
)
|
||||
return out
|
||||
async def user_context_enrichment(session, user_id:str, query:str, generative_response:bool=False, memory_type:str=None)->str:
|
||||
|
||||
|
||||
async def user_context_enrichment(
|
||||
session,
|
||||
user_id: str,
|
||||
query: str,
|
||||
generative_response: bool = False,
|
||||
memory_type: str = None,
|
||||
) -> str:
|
||||
"""
|
||||
Asynchronously enriches the user context by integrating various memory systems and document classifications.
|
||||
|
||||
|
|
@ -387,31 +509,38 @@ async def user_context_enrichment(session, user_id:str, query:str, generative_re
|
|||
enriched_context = await user_context_enrichment(session, "user123", "How does cognitive architecture work?")
|
||||
```
|
||||
"""
|
||||
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username,
|
||||
password=config.graph_database_password)
|
||||
neo4j_graph_db = Neo4jGraphDB(
|
||||
url=config.graph_database_url,
|
||||
username=config.graph_database_username,
|
||||
password=config.graph_database_password,
|
||||
)
|
||||
|
||||
# await user_query_to_graph_db(session, user_id, query)
|
||||
|
||||
semantic_mem = neo4j_graph_db.retrieve_semantic_memory(user_id=user_id)
|
||||
neo4j_graph_db.close()
|
||||
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username,
|
||||
password=config.graph_database_password)
|
||||
neo4j_graph_db = Neo4jGraphDB(
|
||||
url=config.graph_database_url,
|
||||
username=config.graph_database_username,
|
||||
password=config.graph_database_password,
|
||||
)
|
||||
episodic_mem = neo4j_graph_db.retrieve_episodic_memory(user_id=user_id)
|
||||
neo4j_graph_db.close()
|
||||
# public_mem = neo4j_graph_db.retrieve_public_memory(user_id=user_id)
|
||||
|
||||
|
||||
|
||||
if detect_language(query) != "en":
|
||||
query = translate_text(query, "sr", "en")
|
||||
logging.info("Translated query is %s", str(query))
|
||||
|
||||
if memory_type=='PublicMemory':
|
||||
|
||||
|
||||
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username,
|
||||
password=config.graph_database_password)
|
||||
summaries = await neo4j_graph_db.get_memory_linked_document_summaries(user_id=user_id, memory_type=memory_type)
|
||||
if memory_type == "PublicMemory":
|
||||
neo4j_graph_db = Neo4jGraphDB(
|
||||
url=config.graph_database_url,
|
||||
username=config.graph_database_username,
|
||||
password=config.graph_database_password,
|
||||
)
|
||||
summaries = await neo4j_graph_db.get_memory_linked_document_summaries(
|
||||
user_id=user_id, memory_type=memory_type
|
||||
)
|
||||
neo4j_graph_db.close()
|
||||
logging.info("Summaries are is %s", summaries)
|
||||
# logging.info("Context from graphdb is %s", context)
|
||||
|
|
@ -424,7 +553,9 @@ async def user_context_enrichment(session, user_id:str, query:str, generative_re
|
|||
relevant_summary_id = None
|
||||
|
||||
for _ in range(max_attempts):
|
||||
relevant_summary_id = await classify_call( query= query, document_summaries=str(summaries))
|
||||
relevant_summary_id = await classify_call(
|
||||
query=query, document_summaries=str(summaries)
|
||||
)
|
||||
|
||||
logging.info("Relevant summary id is %s", relevant_summary_id)
|
||||
|
||||
|
|
@ -432,22 +563,32 @@ async def user_context_enrichment(session, user_id:str, query:str, generative_re
|
|||
break
|
||||
|
||||
# logging.info("Relevant categories after the classifier are %s", relevant_categories)
|
||||
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username,
|
||||
password=config.graph_database_password)
|
||||
postgres_id = await neo4j_graph_db.get_memory_linked_document_ids(user_id, summary_id = relevant_summary_id, memory_type=memory_type)
|
||||
neo4j_graph_db = Neo4jGraphDB(
|
||||
url=config.graph_database_url,
|
||||
username=config.graph_database_username,
|
||||
password=config.graph_database_password,
|
||||
)
|
||||
postgres_id = await neo4j_graph_db.get_memory_linked_document_ids(
|
||||
user_id, summary_id=relevant_summary_id, memory_type=memory_type
|
||||
)
|
||||
neo4j_graph_db.close()
|
||||
# postgres_id = neo4j_graph_db.query(get_doc_ids)
|
||||
logging.info("Postgres ids are %s", postgres_id)
|
||||
namespace_id = await get_memory_name_by_doc_id(session, postgres_id[0])
|
||||
logging.info("Namespace ids are %s", namespace_id)
|
||||
params= {"doc_id":postgres_id[0]}
|
||||
params = {"doc_id": postgres_id[0]}
|
||||
namespace_id = namespace_id[0]
|
||||
namespace_class = namespace_id + "_class"
|
||||
if memory_type =='PublicMemory':
|
||||
user_id = 'system_user'
|
||||
if memory_type == "PublicMemory":
|
||||
user_id = "system_user"
|
||||
|
||||
memory = await Memory.create_memory(user_id, session, namespace=namespace_id, job_id="23232",
|
||||
memory_label=namespace_id)
|
||||
memory = await Memory.create_memory(
|
||||
user_id,
|
||||
session,
|
||||
namespace=namespace_id,
|
||||
job_id="23232",
|
||||
memory_label=namespace_id,
|
||||
)
|
||||
|
||||
existing_user = await Memory.check_existing_user(user_id, session)
|
||||
print("here is the existing user", existing_user)
|
||||
|
|
@ -468,17 +609,26 @@ async def user_context_enrichment(session, user_id:str, query:str, generative_re
|
|||
print(f"No attribute named in memory.")
|
||||
|
||||
print("Available memory classes:", await memory.list_memory_classes())
|
||||
results = await memory.dynamic_method_call(dynamic_memory_class, 'fetch_memories',
|
||||
observation=query, params=postgres_id[0], search_type="summary_filter_by_object_name")
|
||||
results = await memory.dynamic_method_call(
|
||||
dynamic_memory_class,
|
||||
"fetch_memories",
|
||||
observation=query,
|
||||
params=postgres_id[0],
|
||||
search_type="summary_filter_by_object_name",
|
||||
)
|
||||
logging.info("Result is %s", str(results))
|
||||
|
||||
|
||||
search_context = ""
|
||||
|
||||
for result in results['data']['Get'][namespace_id]:
|
||||
for result in results["data"]["Get"][namespace_id]:
|
||||
# Assuming 'result' is a dictionary and has keys like 'source', 'text'
|
||||
source = result['source'].replace('-', ' ').replace('.pdf', '').replace('.data/', '')
|
||||
text = result['text']
|
||||
source = (
|
||||
result["source"]
|
||||
.replace("-", " ")
|
||||
.replace(".pdf", "")
|
||||
.replace(".data/", "")
|
||||
)
|
||||
text = result["text"]
|
||||
search_context += f"Document source: {source}, Document text: {text} \n"
|
||||
|
||||
else:
|
||||
|
|
@ -502,7 +652,9 @@ async def user_context_enrichment(session, user_id:str, query:str, generative_re
|
|||
return generative_result.model_dump_json()
|
||||
|
||||
|
||||
async def create_public_memory(user_id: str=None, labels:list=None, topic:str=None) -> Optional[int]:
|
||||
async def create_public_memory(
|
||||
user_id: str = None, labels: list = None, topic: str = None
|
||||
) -> Optional[int]:
|
||||
"""
|
||||
Create a public memory node associated with a user in a Neo4j graph database.
|
||||
If Public Memory exists, it will return the id of the memory.
|
||||
|
|
@ -521,16 +673,17 @@ async def create_public_memory(user_id: str=None, labels:list=None, topic:str=No
|
|||
"""
|
||||
# Validate input parameters
|
||||
if not labels:
|
||||
labels = ['sr'] # Labels for the memory node
|
||||
labels = ["sr"] # Labels for the memory node
|
||||
|
||||
if not topic:
|
||||
topic = "PublicMemory"
|
||||
|
||||
|
||||
try:
|
||||
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url,
|
||||
username=config.graph_database_username,
|
||||
password=config.graph_database_password)
|
||||
neo4j_graph_db = Neo4jGraphDB(
|
||||
url=config.graph_database_url,
|
||||
username=config.graph_database_username,
|
||||
password=config.graph_database_password,
|
||||
)
|
||||
|
||||
# Assuming the topic for public memory is predefined, e.g., "PublicMemory"
|
||||
# Create the memory node
|
||||
|
|
@ -541,7 +694,10 @@ async def create_public_memory(user_id: str=None, labels:list=None, topic:str=No
|
|||
logging.error(f"Error creating public memory node: {e}")
|
||||
return None
|
||||
|
||||
async def attach_user_to_memory(user_id: str=None, labels:list=None, topic:str=None) -> Optional[int]:
|
||||
|
||||
async def attach_user_to_memory(
|
||||
user_id: str = None, labels: list = None, topic: str = None
|
||||
) -> Optional[int]:
|
||||
"""
|
||||
Link user to public memory
|
||||
|
||||
|
|
@ -560,33 +716,41 @@ async def attach_user_to_memory(user_id: str=None, labels:list=None, topic:str=N
|
|||
if not user_id:
|
||||
raise ValueError("User ID is required.")
|
||||
if not labels:
|
||||
labels = ['sr'] # Labels for the memory node
|
||||
labels = ["sr"] # Labels for the memory node
|
||||
|
||||
if not topic:
|
||||
topic = "PublicMemory"
|
||||
|
||||
|
||||
try:
|
||||
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url,
|
||||
username=config.graph_database_username,
|
||||
password=config.graph_database_password)
|
||||
neo4j_graph_db = Neo4jGraphDB(
|
||||
url=config.graph_database_url,
|
||||
username=config.graph_database_username,
|
||||
password=config.graph_database_password,
|
||||
)
|
||||
|
||||
# Assuming the topic for public memory is predefined, e.g., "PublicMemory"
|
||||
ids = neo4j_graph_db.retrieve_node_id_for_memory_type(topic=topic)
|
||||
neo4j_graph_db.close()
|
||||
|
||||
for id in ids:
|
||||
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url,
|
||||
username=config.graph_database_username,
|
||||
password=config.graph_database_password)
|
||||
linked_memory = neo4j_graph_db.link_public_memory_to_user(memory_id=id.get('memoryId'), user_id=user_id)
|
||||
neo4j_graph_db = Neo4jGraphDB(
|
||||
url=config.graph_database_url,
|
||||
username=config.graph_database_username,
|
||||
password=config.graph_database_password,
|
||||
)
|
||||
linked_memory = neo4j_graph_db.link_public_memory_to_user(
|
||||
memory_id=id.get("memoryId"), user_id=user_id
|
||||
)
|
||||
neo4j_graph_db.close()
|
||||
return 1
|
||||
except Neo4jError as e:
|
||||
logging.error(f"Error creating public memory node: {e}")
|
||||
return None
|
||||
|
||||
async def unlink_user_from_memory(user_id: str=None, labels:list=None, topic:str=None) -> Optional[int]:
|
||||
|
||||
async def unlink_user_from_memory(
|
||||
user_id: str = None, labels: list = None, topic: str = None
|
||||
) -> Optional[int]:
|
||||
"""
|
||||
Unlink user from memory
|
||||
|
||||
|
|
@ -604,34 +768,39 @@ async def unlink_user_from_memory(user_id: str=None, labels:list=None, topic:str
|
|||
if not user_id:
|
||||
raise ValueError("User ID is required.")
|
||||
if not labels:
|
||||
labels = ['sr'] # Labels for the memory node
|
||||
labels = ["sr"] # Labels for the memory node
|
||||
|
||||
if not topic:
|
||||
topic = "PublicMemory"
|
||||
|
||||
|
||||
try:
|
||||
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url,
|
||||
username=config.graph_database_username,
|
||||
password=config.graph_database_password)
|
||||
neo4j_graph_db = Neo4jGraphDB(
|
||||
url=config.graph_database_url,
|
||||
username=config.graph_database_username,
|
||||
password=config.graph_database_password,
|
||||
)
|
||||
|
||||
# Assuming the topic for public memory is predefined, e.g., "PublicMemory"
|
||||
ids = neo4j_graph_db.retrieve_node_id_for_memory_type(topic=topic)
|
||||
neo4j_graph_db.close()
|
||||
|
||||
for id in ids:
|
||||
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url,
|
||||
username=config.graph_database_username,
|
||||
password=config.graph_database_password)
|
||||
linked_memory = neo4j_graph_db.unlink_memory_from_user(memory_id=id.get('memoryId'), user_id=user_id)
|
||||
neo4j_graph_db = Neo4jGraphDB(
|
||||
url=config.graph_database_url,
|
||||
username=config.graph_database_username,
|
||||
password=config.graph_database_password,
|
||||
)
|
||||
linked_memory = neo4j_graph_db.unlink_memory_from_user(
|
||||
memory_id=id.get("memoryId"), user_id=user_id
|
||||
)
|
||||
neo4j_graph_db.close()
|
||||
return 1
|
||||
except Neo4jError as e:
|
||||
logging.error(f"Error creating public memory node: {e}")
|
||||
return None
|
||||
|
||||
async def relevance_feedback(query: str, input_type: str):
|
||||
|
||||
async def relevance_feedback(query: str, input_type: str):
|
||||
max_attempts = 6
|
||||
result = None
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
|
|
@ -641,7 +810,6 @@ async def relevance_feedback(query: str, input_type: str):
|
|||
return result
|
||||
|
||||
|
||||
|
||||
async def main():
|
||||
user_id = "user_test_1_1"
|
||||
|
||||
|
|
@ -649,8 +817,6 @@ async def main():
|
|||
# await update_entity(session, DocsModel, "8cd9a022-5a7a-4af5-815a-f988415536ae", True)
|
||||
# output = await get_unsumarized_vector_db_namespace(session, user_id)
|
||||
|
||||
|
||||
|
||||
class GraphQLQuery(BaseModel):
|
||||
query: str
|
||||
|
||||
|
|
@ -713,7 +879,7 @@ async def main():
|
|||
# print(out)
|
||||
# load_doc_to_graph = await add_documents_to_graph_db(session, user_id)
|
||||
# print(load_doc_to_graph)
|
||||
user_id = 'test_user'
|
||||
user_id = "test_user"
|
||||
# loader_settings = {
|
||||
# "format": "PDF",
|
||||
# "source": "DEVICE",
|
||||
|
|
@ -723,10 +889,15 @@ async def main():
|
|||
# await create_public_memory(user_id=user_id, labels=['sr'], topic="PublicMemory")
|
||||
# await add_documents_to_graph_db(session, user_id)
|
||||
#
|
||||
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username,
|
||||
password=config.graph_database_password)
|
||||
neo4j_graph_db = Neo4jGraphDB(
|
||||
url=config.graph_database_url,
|
||||
username=config.graph_database_username,
|
||||
password=config.graph_database_password,
|
||||
)
|
||||
|
||||
out = neo4j_graph_db.run_merge_query(user_id = user_id, memory_type="SemanticMemory", similarity_threshold=0.5)
|
||||
out = neo4j_graph_db.run_merge_query(
|
||||
user_id=user_id, memory_type="SemanticMemory", similarity_threshold=0.5
|
||||
)
|
||||
bb = neo4j_graph_db.query(out)
|
||||
print(bb)
|
||||
|
||||
|
|
@ -798,6 +969,4 @@ async def main():
|
|||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
|
|
|
|||
201
poetry.lock
generated
201
poetry.lock
generated
|
|
@ -167,6 +167,20 @@ doc = ["Sphinx", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-
|
|||
test = ["anyio[trio]", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=4)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"]
|
||||
trio = ["trio (<0.22)"]
|
||||
|
||||
[[package]]
|
||||
name = "astroid"
|
||||
version = "3.0.3"
|
||||
description = "An abstract syntax tree for Python with inference support."
|
||||
optional = false
|
||||
python-versions = ">=3.8.0"
|
||||
files = [
|
||||
{file = "astroid-3.0.3-py3-none-any.whl", hash = "sha256:92fcf218b89f449cdf9f7b39a269f8d5d617b27be68434912e11e79203963a17"},
|
||||
{file = "astroid-3.0.3.tar.gz", hash = "sha256:4148645659b08b70d72460ed1921158027a9e53ae8b7234149b1400eddacbb93"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""}
|
||||
|
||||
[[package]]
|
||||
name = "astunparse"
|
||||
version = "1.6.3"
|
||||
|
|
@ -520,6 +534,17 @@ files = [
|
|||
[package.dependencies]
|
||||
pycparser = "*"
|
||||
|
||||
[[package]]
|
||||
name = "cfgv"
|
||||
version = "3.4.0"
|
||||
description = "Validate configuration and produce human readable error messages."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"},
|
||||
{file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "chardet"
|
||||
version = "5.2.0"
|
||||
|
|
@ -1027,6 +1052,32 @@ files = [
|
|||
[package.dependencies]
|
||||
packaging = "*"
|
||||
|
||||
[[package]]
|
||||
name = "dill"
|
||||
version = "0.3.8"
|
||||
description = "serialize all of Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"},
|
||||
{file = "dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
graph = ["objgraph (>=1.7.2)"]
|
||||
profile = ["gprof2dot (>=2022.7.29)"]
|
||||
|
||||
[[package]]
|
||||
name = "distlib"
|
||||
version = "0.3.8"
|
||||
description = "Distribution utilities"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "distlib-0.3.8-py2.py3-none-any.whl", hash = "sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784"},
|
||||
{file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "distro"
|
||||
version = "1.9.0"
|
||||
|
|
@ -1962,6 +2013,20 @@ files = [
|
|||
[package.extras]
|
||||
tests = ["freezegun", "pytest", "pytest-cov"]
|
||||
|
||||
[[package]]
|
||||
name = "identify"
|
||||
version = "2.5.34"
|
||||
description = "File identification library for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "identify-2.5.34-py2.py3-none-any.whl", hash = "sha256:a4316013779e433d08b96e5eabb7f641e6c7942e4ab5d4c509ebd2e7a8994aed"},
|
||||
{file = "identify-2.5.34.tar.gz", hash = "sha256:ee17bc9d499899bc9eaec1ac7bf2dc9eedd480db9d88b96d123d3b64a9d34f5d"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
license = ["ukkonen"]
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "3.6"
|
||||
|
|
@ -2037,6 +2102,20 @@ files = [
|
|||
{file = "iso639-0.1.4.tar.gz", hash = "sha256:88b70cf6c64ee9c2c2972292818c8beb32db9ea6f4de1f8471a9b081a3d92e98"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "isort"
|
||||
version = "5.13.2"
|
||||
description = "A Python utility / library to sort Python imports."
|
||||
optional = false
|
||||
python-versions = ">=3.8.0"
|
||||
files = [
|
||||
{file = "isort-5.13.2-py3-none-any.whl", hash = "sha256:8ca5e72a8d85860d5a3fa69b8745237f2939afe12dbf656afbcb47fe72d947a6"},
|
||||
{file = "isort-5.13.2.tar.gz", hash = "sha256:48fdfcb9face5d58a4f6dde2e72a1fb8dcaf8ab26f95ab49fab84c2ddefb0109"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
colors = ["colorama (>=0.4.6)"]
|
||||
|
||||
[[package]]
|
||||
name = "itsdangerous"
|
||||
version = "2.1.2"
|
||||
|
|
@ -2760,6 +2839,17 @@ pillow = ">=8"
|
|||
pyparsing = ">=2.3.1"
|
||||
python-dateutil = ">=2.7"
|
||||
|
||||
[[package]]
|
||||
name = "mccabe"
|
||||
version = "0.7.0"
|
||||
description = "McCabe checker, plugin for flake8"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"},
|
||||
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mdurl"
|
||||
version = "0.1.2"
|
||||
|
|
@ -2996,6 +3086,20 @@ plot = ["matplotlib"]
|
|||
tgrep = ["pyparsing"]
|
||||
twitter = ["twython"]
|
||||
|
||||
[[package]]
|
||||
name = "nodeenv"
|
||||
version = "1.8.0"
|
||||
description = "Node.js virtual environment builder"
|
||||
optional = false
|
||||
python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*"
|
||||
files = [
|
||||
{file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"},
|
||||
{file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
setuptools = "*"
|
||||
|
||||
[[package]]
|
||||
name = "numpy"
|
||||
version = "1.26.2"
|
||||
|
|
@ -3317,8 +3421,8 @@ files = [
|
|||
[package.dependencies]
|
||||
numpy = [
|
||||
{version = ">=1.22.4,<2", markers = "python_version < \"3.11\""},
|
||||
{version = ">=1.23.2,<2", markers = "python_version == \"3.11\""},
|
||||
{version = ">=1.26.0,<2", markers = "python_version >= \"3.12\""},
|
||||
{version = ">=1.23.2,<2", markers = "python_version == \"3.11\""},
|
||||
]
|
||||
python-dateutil = ">=2.8.2"
|
||||
pytz = ">=2020.1"
|
||||
|
|
@ -3618,6 +3722,21 @@ urllib3 = ">=1.21.1"
|
|||
[package.extras]
|
||||
grpc = ["googleapis-common-protos (>=1.53.0)", "grpc-gateway-protoc-gen-openapiv2 (==0.1.0)", "grpcio (>=1.44.0)", "lz4 (>=3.1.3)", "protobuf (>=3.19.5,<3.20.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "platformdirs"
|
||||
version = "4.2.0"
|
||||
description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "platformdirs-4.2.0-py3-none-any.whl", hash = "sha256:0614df2a2f37e1a662acbd8e2b25b92ccf8632929bc6d43467e17fe89c75e068"},
|
||||
{file = "platformdirs-4.2.0.tar.gz", hash = "sha256:ef0cc731df711022c174543cb70a9b5bd22e5a9337c8624ef2c2ceb8ddad8768"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"]
|
||||
test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"]
|
||||
|
||||
[[package]]
|
||||
name = "plotly"
|
||||
version = "5.18.0"
|
||||
|
|
@ -3663,6 +3782,24 @@ docs = ["sphinx (>=1.7.1)"]
|
|||
redis = ["redis"]
|
||||
tests = ["pytest (>=5.4.1)", "pytest-cov (>=2.8.1)", "pytest-mypy (>=0.8.0)", "pytest-timeout (>=2.1.0)", "redis", "sphinx (>=6.0.0)", "types-redis"]
|
||||
|
||||
[[package]]
|
||||
name = "pre-commit"
|
||||
version = "3.6.1"
|
||||
description = "A framework for managing and maintaining multi-language pre-commit hooks."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "pre_commit-3.6.1-py2.py3-none-any.whl", hash = "sha256:9fe989afcf095d2c4796ce7c553cf28d4d4a9b9346de3cda079bcf40748454a4"},
|
||||
{file = "pre_commit-3.6.1.tar.gz", hash = "sha256:c90961d8aa706f75d60935aba09469a6b0bcb8345f127c3fbee4bdc5f114cf4b"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
cfgv = ">=2.0.0"
|
||||
identify = ">=1.0.0"
|
||||
nodeenv = ">=0.11.1"
|
||||
pyyaml = ">=5.1"
|
||||
virtualenv = ">=20.10.0"
|
||||
|
||||
[[package]]
|
||||
name = "preshed"
|
||||
version = "3.0.9"
|
||||
|
|
@ -4076,6 +4213,35 @@ benchmarks = ["pytest-benchmark"]
|
|||
tests = ["datasets", "duckdb", "ml_dtypes", "pandas", "pillow", "polars[pandas,pyarrow]", "pytest", "tensorflow", "tqdm"]
|
||||
torch = ["torch"]
|
||||
|
||||
[[package]]
|
||||
name = "pylint"
|
||||
version = "3.0.3"
|
||||
description = "python code static checker"
|
||||
optional = false
|
||||
python-versions = ">=3.8.0"
|
||||
files = [
|
||||
{file = "pylint-3.0.3-py3-none-any.whl", hash = "sha256:7a1585285aefc5165db81083c3e06363a27448f6b467b3b0f30dbd0ac1f73810"},
|
||||
{file = "pylint-3.0.3.tar.gz", hash = "sha256:58c2398b0301e049609a8429789ec6edf3aabe9b6c5fec916acd18639c16de8b"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
astroid = ">=3.0.1,<=3.1.0-dev0"
|
||||
colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""}
|
||||
dill = [
|
||||
{version = ">=0.2", markers = "python_version < \"3.11\""},
|
||||
{version = ">=0.3.7", markers = "python_version >= \"3.12\""},
|
||||
{version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
|
||||
]
|
||||
isort = ">=4.2.5,<5.13.0 || >5.13.0,<6"
|
||||
mccabe = ">=0.6,<0.8"
|
||||
platformdirs = ">=2.2.0"
|
||||
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
|
||||
tomlkit = ">=0.10.1"
|
||||
|
||||
[package.extras]
|
||||
spelling = ["pyenchant (>=3.2,<4.0)"]
|
||||
testutils = ["gitpython (>3)"]
|
||||
|
||||
[[package]]
|
||||
name = "pymupdf"
|
||||
version = "1.23.8"
|
||||
|
|
@ -5790,6 +5956,17 @@ dev = ["tokenizers[testing]"]
|
|||
docs = ["setuptools_rust", "sphinx", "sphinx_rtd_theme"]
|
||||
testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"]
|
||||
|
||||
[[package]]
|
||||
name = "tomli"
|
||||
version = "2.0.1"
|
||||
description = "A lil' TOML parser"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
|
||||
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tomlkit"
|
||||
version = "0.12.3"
|
||||
|
|
@ -6213,6 +6390,26 @@ testing = ["pytest (>=7.4.0)"]
|
|||
tooling = ["black (>=23.7.0)", "pyright (>=1.1.325)", "ruff (>=0.0.287)"]
|
||||
tooling-extras = ["pyaml (>=23.7.0)", "pypandoc-binary (>=1.11)", "pytest (>=7.4.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "virtualenv"
|
||||
version = "20.25.0"
|
||||
description = "Virtual Python Environment builder"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "virtualenv-20.25.0-py3-none-any.whl", hash = "sha256:4238949c5ffe6876362d9c0180fc6c3a824a7b12b80604eeb8085f2ed7460de3"},
|
||||
{file = "virtualenv-20.25.0.tar.gz", hash = "sha256:bf51c0d9c7dd63ea8e44086fa1e4fb1093a31e963b86959257378aef020e1f1b"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
distlib = ">=0.3.7,<1"
|
||||
filelock = ">=3.12.2,<4"
|
||||
platformdirs = ">=3.9.1,<5"
|
||||
|
||||
[package.extras]
|
||||
docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"]
|
||||
test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"]
|
||||
|
||||
[[package]]
|
||||
name = "wasabi"
|
||||
version = "1.1.2"
|
||||
|
|
@ -6513,4 +6710,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "671f878d3fc3b864ac68ef553f3f48ac247bfee0ae60540f260fea7fda727e86"
|
||||
content-hash = "d484dd5ab17563c78699c17296b56155a967f10c432f715a96efbd07e15b34e1"
|
||||
|
|
|
|||
|
|
@ -62,6 +62,7 @@ iso639 = "^0.1.4"
|
|||
debugpy = "^1.8.0"
|
||||
lancedb = "^0.5.5"
|
||||
pyarrow = "^15.0.0"
|
||||
pylint = "^3.0.3"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue