ran black, fixed some linting issues

This commit is contained in:
Vasilije 2024-02-15 18:01:12 +01:00
parent 1f3ac1ec97
commit b3f29d3f2d
36 changed files with 1778 additions and 966 deletions

191
api.py
View file

@ -39,12 +39,13 @@ app = FastAPI(debug=True)
# #
# auth = JWTBearer(jwks) # auth = JWTBearer(jwks)
@app.get("/") @app.get("/")
async def root(): async def root():
""" """
Root endpoint that returns a welcome message. Root endpoint that returns a welcome message.
""" """
return { "message": "Hello, World, I am alive!" } return {"message": "Hello, World, I am alive!"}
@app.get("/health") @app.get("/health")
@ -61,8 +62,8 @@ class Payload(BaseModel):
@app.post("/add-memory", response_model=dict) @app.post("/add-memory", response_model=dict)
async def add_memory( async def add_memory(
payload: Payload, payload: Payload,
# files: List[UploadFile] = File(...), # files: List[UploadFile] = File(...),
): ):
try: try:
logging.info(" Adding to Memory ") logging.info(" Adding to Memory ")
@ -70,68 +71,76 @@ async def add_memory(
async with session_scope(session=AsyncSessionLocal()) as session: async with session_scope(session=AsyncSessionLocal()) as session:
from main import load_documents_to_vectorstore from main import load_documents_to_vectorstore
if 'settings' in decoded_payload and decoded_payload['settings'] is not None: if (
settings_for_loader = decoded_payload['settings'] "settings" in decoded_payload
and decoded_payload["settings"] is not None
):
settings_for_loader = decoded_payload["settings"]
else: else:
settings_for_loader = None settings_for_loader = None
if 'content' in decoded_payload and decoded_payload['content'] is not None: if "content" in decoded_payload and decoded_payload["content"] is not None:
content = decoded_payload['content'] content = decoded_payload["content"]
else: else:
content = None content = None
output = await load_documents_to_vectorstore(session, decoded_payload['user_id'], content=content, output = await load_documents_to_vectorstore(
loader_settings=settings_for_loader) session,
decoded_payload["user_id"],
content=content,
loader_settings=settings_for_loader,
)
return JSONResponse(content={"response": output}, status_code=200) return JSONResponse(content={"response": output}, status_code=200)
except Exception as e: except Exception as e:
return JSONResponse( return JSONResponse(content={"response": {"error": str(e)}}, status_code=503)
content={"response": {"error": str(e)}}, status_code=503
)
@app.post("/add-architecture-public-memory", response_model=dict) @app.post("/add-architecture-public-memory", response_model=dict)
async def add_memory( async def add_memory(
payload: Payload, payload: Payload,
# files: List[UploadFile] = File(...), # files: List[UploadFile] = File(...),
): ):
try: try:
logging.info(" Adding to Memory ") logging.info(" Adding to Memory ")
decoded_payload = payload.payload decoded_payload = payload.payload
async with session_scope(session=AsyncSessionLocal()) as session: async with session_scope(session=AsyncSessionLocal()) as session:
from main import load_documents_to_vectorstore from main import load_documents_to_vectorstore
if 'content' in decoded_payload and decoded_payload['content'] is not None:
content = decoded_payload['content'] if "content" in decoded_payload and decoded_payload["content"] is not None:
content = decoded_payload["content"]
else: else:
content = None content = None
user_id = 'system_user' user_id = "system_user"
loader_settings = { loader_settings = {"format": "PDF", "source": "DEVICE", "path": [".data"]}
"format": "PDF",
"source": "DEVICE",
"path": [".data"]
}
output = await load_documents_to_vectorstore(session, user_id=user_id, content=content, output = await load_documents_to_vectorstore(
loader_settings=loader_settings) session,
user_id=user_id,
content=content,
loader_settings=loader_settings,
)
return JSONResponse(content={"response": output}, status_code=200) return JSONResponse(content={"response": output}, status_code=200)
except Exception as e: except Exception as e:
return JSONResponse( return JSONResponse(content={"response": {"error": str(e)}}, status_code=503)
content={"response": {"error": str(e)}}, status_code=503
)
@app.post("/user-query-to-graph") @app.post("/user-query-to-graph")
async def user_query_to_graph(payload: Payload): async def user_query_to_graph(payload: Payload):
try: try:
from main import user_query_to_graph_db from main import user_query_to_graph_db
decoded_payload = payload.payload decoded_payload = payload.payload
# Execute the query - replace this with the actual execution method # Execute the query - replace this with the actual execution method
async with session_scope(session=AsyncSessionLocal()) as session: async with session_scope(session=AsyncSessionLocal()) as session:
# Assuming you have a method in Neo4jGraphDB to execute the query # Assuming you have a method in Neo4jGraphDB to execute the query
result = await user_query_to_graph_db(session=session, user_id=decoded_payload['user_id'], result = await user_query_to_graph_db(
query_input=decoded_payload['query']) session=session,
user_id=decoded_payload["user_id"],
query_input=decoded_payload["query"],
)
return result return result
@ -144,17 +153,23 @@ async def document_to_graph_db(payload: Payload):
logging.info("Adding documents to graph db") logging.info("Adding documents to graph db")
try: try:
decoded_payload = payload.payload decoded_payload = payload.payload
if 'settings' in decoded_payload and decoded_payload['settings'] is not None: if "settings" in decoded_payload and decoded_payload["settings"] is not None:
settings_for_loader = decoded_payload['settings'] settings_for_loader = decoded_payload["settings"]
else: else:
settings_for_loader = None settings_for_loader = None
if 'memory_type' in decoded_payload and decoded_payload['memory_type'] is not None: if (
memory_type = decoded_payload['memory_type'] "memory_type" in decoded_payload
and decoded_payload["memory_type"] is not None
):
memory_type = decoded_payload["memory_type"]
else: else:
memory_type = None memory_type = None
async with session_scope(session=AsyncSessionLocal()) as session: async with session_scope(session=AsyncSessionLocal()) as session:
result = await add_documents_to_graph_db(session=session, user_id=decoded_payload['user_id'], result = await add_documents_to_graph_db(
document_memory_types=memory_type) session=session,
user_id=decoded_payload["user_id"],
document_memory_types=memory_type,
)
return result return result
except Exception as e: except Exception as e:
@ -166,10 +181,13 @@ async def cognitive_context_enrichment(payload: Payload):
try: try:
decoded_payload = payload.payload decoded_payload = payload.payload
async with session_scope(session=AsyncSessionLocal()) as session: async with session_scope(session=AsyncSessionLocal()) as session:
result = await user_context_enrichment(session, user_id=decoded_payload['user_id'], result = await user_context_enrichment(
query=decoded_payload['query'], session,
generative_response=decoded_payload['generative_response'], user_id=decoded_payload["user_id"],
memory_type=decoded_payload['memory_type']) query=decoded_payload["query"],
generative_response=decoded_payload["generative_response"],
memory_type=decoded_payload["memory_type"],
)
return JSONResponse(content={"response": result}, status_code=200) return JSONResponse(content={"response": result}, status_code=200)
except Exception as e: except Exception as e:
@ -182,8 +200,11 @@ async def classify_user_query(payload: Payload):
decoded_payload = payload.payload decoded_payload = payload.payload
async with session_scope(session=AsyncSessionLocal()) as session: async with session_scope(session=AsyncSessionLocal()) as session:
from main import relevance_feedback from main import relevance_feedback
result = await relevance_feedback(query=decoded_payload['query'],
input_type=decoded_payload['knowledge_type']) result = await relevance_feedback(
query=decoded_payload["query"],
input_type=decoded_payload["knowledge_type"],
)
return JSONResponse(content={"response": result}, status_code=200) return JSONResponse(content={"response": result}, status_code=200)
except Exception as e: except Exception as e:
@ -197,9 +218,14 @@ async def user_query_classfier(payload: Payload):
# Execute the query - replace this with the actual execution method # Execute the query - replace this with the actual execution method
async with session_scope(session=AsyncSessionLocal()) as session: async with session_scope(session=AsyncSessionLocal()) as session:
from cognitive_architecture.classifiers.classifier import classify_user_query from cognitive_architecture.classifiers.classifier import (
classify_user_query,
)
# Assuming you have a method in Neo4jGraphDB to execute the query # Assuming you have a method in Neo4jGraphDB to execute the query
result = await classify_user_query(session, decoded_payload['user_id'], decoded_payload['query']) result = await classify_user_query(
session, decoded_payload["user_id"], decoded_payload["query"]
)
return JSONResponse(content={"response": result}, status_code=200) return JSONResponse(content={"response": result}, status_code=200)
except Exception as e: except Exception as e:
@ -211,41 +237,43 @@ async def drop_db(payload: Payload):
try: try:
decoded_payload = payload.payload decoded_payload = payload.payload
if decoded_payload['operation'] == 'drop': if decoded_payload["operation"] == "drop":
if os.environ.get("AWS_ENV") == "dev":
if os.environ.get('AWS_ENV') == 'dev': host = os.environ.get("POSTGRES_HOST")
host = os.environ.get('POSTGRES_HOST') username = os.environ.get("POSTGRES_USER")
username = os.environ.get('POSTGRES_USER') password = os.environ.get("POSTGRES_PASSWORD")
password = os.environ.get('POSTGRES_PASSWORD') database_name = os.environ.get("POSTGRES_DB")
database_name = os.environ.get('POSTGRES_DB')
else: else:
pass pass
from cognitive_architecture.database.create_database import drop_database, create_admin_engine from cognitive_architecture.database.create_database import (
drop_database,
create_admin_engine,
)
engine = create_admin_engine(username, password, host, database_name) engine = create_admin_engine(username, password, host, database_name)
connection = engine.raw_connection() connection = engine.raw_connection()
drop_database(connection, database_name) drop_database(connection, database_name)
return JSONResponse(content={"response": "DB dropped"}, status_code=200) return JSONResponse(content={"response": "DB dropped"}, status_code=200)
else: else:
if os.environ.get("AWS_ENV") == "dev":
if os.environ.get('AWS_ENV') == 'dev': host = os.environ.get("POSTGRES_HOST")
host = os.environ.get('POSTGRES_HOST') username = os.environ.get("POSTGRES_USER")
username = os.environ.get('POSTGRES_USER') password = os.environ.get("POSTGRES_PASSWORD")
password = os.environ.get('POSTGRES_PASSWORD') database_name = os.environ.get("POSTGRES_DB")
database_name = os.environ.get('POSTGRES_DB')
else: else:
pass pass
from cognitive_architecture.database.create_database import create_database, create_admin_engine from cognitive_architecture.database.create_database import (
create_database,
create_admin_engine,
)
engine = create_admin_engine(username, password, host, database_name) engine = create_admin_engine(username, password, host, database_name)
connection = engine.raw_connection() connection = engine.raw_connection()
create_database(connection, database_name) create_database(connection, database_name)
return JSONResponse(content={"response": " DB drop"}, status_code=200) return JSONResponse(content={"response": " DB drop"}, status_code=200)
except Exception as e: except Exception as e:
return HTTPException(status_code=500, detail=str(e)) return HTTPException(status_code=500, detail=str(e))
@ -255,18 +283,18 @@ async def create_public_memory(payload: Payload):
try: try:
decoded_payload = payload.payload decoded_payload = payload.payload
if 'user_id' in decoded_payload and decoded_payload['user_id'] is not None: if "user_id" in decoded_payload and decoded_payload["user_id"] is not None:
user_id = decoded_payload['user_id'] user_id = decoded_payload["user_id"]
else: else:
user_id = None user_id = None
if 'labels' in decoded_payload and decoded_payload['labels'] is not None: if "labels" in decoded_payload and decoded_payload["labels"] is not None:
labels = decoded_payload['labels'] labels = decoded_payload["labels"]
else: else:
labels = None labels = None
if 'topic' in decoded_payload and decoded_payload['topic'] is not None: if "topic" in decoded_payload and decoded_payload["topic"] is not None:
topic = decoded_payload['topic'] topic = decoded_payload["topic"]
else: else:
topic = None topic = None
@ -286,21 +314,26 @@ async def attach_user_to_public_memory(payload: Payload):
try: try:
decoded_payload = payload.payload decoded_payload = payload.payload
if 'topic' in decoded_payload and decoded_payload['topic'] is not None: if "topic" in decoded_payload and decoded_payload["topic"] is not None:
topic = decoded_payload['topic'] topic = decoded_payload["topic"]
else: else:
topic = None topic = None
if 'labels' in decoded_payload and decoded_payload['labels'] is not None: if "labels" in decoded_payload and decoded_payload["labels"] is not None:
labels = decoded_payload['labels'] labels = decoded_payload["labels"]
else: else:
labels = ['sr'] labels = ["sr"]
# Execute the query - replace this with the actual execution method # Execute the query - replace this with the actual execution method
async with session_scope(session=AsyncSessionLocal()) as session: async with session_scope(session=AsyncSessionLocal()) as session:
from main import attach_user_to_memory, create_public_memory from main import attach_user_to_memory, create_public_memory
# Assuming you have a method in Neo4jGraphDB to execute the query # Assuming you have a method in Neo4jGraphDB to execute the query
await create_public_memory(user_id=decoded_payload['user_id'], topic=topic, labels=labels) await create_public_memory(
result = await attach_user_to_memory(user_id=decoded_payload['user_id'], topic=topic, labels=labels) user_id=decoded_payload["user_id"], topic=topic, labels=labels
)
result = await attach_user_to_memory(
user_id=decoded_payload["user_id"], topic=topic, labels=labels
)
return JSONResponse(content={"response": result}, status_code=200) return JSONResponse(content={"response": result}, status_code=200)
except Exception as e: except Exception as e:
@ -312,17 +345,21 @@ async def unlink_user_from_public_memory(payload: Payload):
try: try:
decoded_payload = payload.payload decoded_payload = payload.payload
if 'topic' in decoded_payload and decoded_payload['topic'] is not None: if "topic" in decoded_payload and decoded_payload["topic"] is not None:
topic = decoded_payload['topic'] topic = decoded_payload["topic"]
else: else:
topic = None topic = None
# Execute the query - replace this with the actual execution method # Execute the query - replace this with the actual execution method
async with session_scope(session=AsyncSessionLocal()) as session: async with session_scope(session=AsyncSessionLocal()) as session:
from main import unlink_user_from_memory from main import unlink_user_from_memory
# Assuming you have a method in Neo4jGraphDB to execute the query # Assuming you have a method in Neo4jGraphDB to execute the query
result = await unlink_user_from_memory(user_id=decoded_payload['user_id'], topic=topic, result = await unlink_user_from_memory(
labels=decoded_payload['labels']) user_id=decoded_payload["user_id"],
topic=topic,
labels=decoded_payload["labels"],
)
return JSONResponse(content={"response": result}, status_code=200) return JSONResponse(content={"response": result}, status_code=200)
except Exception as e: except Exception as e:

View file

@ -3,12 +3,7 @@ import logging
from langchain.prompts import ChatPromptTemplate from langchain.prompts import ChatPromptTemplate
import json import json
#TO DO, ADD ALL CLASSIFIERS HERE # TO DO, ADD ALL CLASSIFIERS HERE
from langchain.chains import create_extraction_chain from langchain.chains import create_extraction_chain
@ -16,6 +11,7 @@ from langchain.chat_models import ChatOpenAI
from ..config import Config from ..config import Config
from ..database.vectordb.loaders.loaders import _document_loader from ..database.vectordb.loaders.loaders import _document_loader
config = Config() config = Config()
config.load() config.load()
OPENAI_API_KEY = config.openai_key OPENAI_API_KEY = config.openai_key
@ -23,149 +19,163 @@ from langchain.document_loaders import TextLoader
from langchain.document_loaders import DirectoryLoader from langchain.document_loaders import DirectoryLoader
async def classify_documents(query:str, document_id:str, content:str): async def classify_documents(query: str, document_id: str, content: str):
document_context = content
document_context = content
logging.info("This is the document context", document_context) logging.info("This is the document context", document_context)
llm = ChatOpenAI(temperature=0, model=config.model) llm = ChatOpenAI(temperature=0, model=config.model)
prompt_classify = ChatPromptTemplate.from_template( prompt_classify = ChatPromptTemplate.from_template(
"""You are a summarizer and classifier. Determine what book this is and where does it belong in the output : {query}, Id: {d_id} Document context is: {context}""" """You are a summarizer and classifier. Determine what book this is and where does it belong in the output : {query}, Id: {d_id} Document context is: {context}"""
) )
json_structure = [{ json_structure = [
"name": "summarizer", {
"description": "Summarization and classification", "name": "summarizer",
"parameters": { "description": "Summarization and classification",
"type": "object", "parameters": {
"properties": { "type": "object",
"DocumentCategory": { "properties": {
"type": "string", "DocumentCategory": {
"description": "The classification of documents in groups such as legal, medical, etc." "type": "string",
"description": "The classification of documents in groups such as legal, medical, etc.",
},
"Title": {
"type": "string",
"description": "The title of the document",
},
"Summary": {
"type": "string",
"description": "The summary of the document",
},
"d_id": {"type": "string", "description": "The id of the document"},
}, },
"Title": { "required": ["DocumentCategory", "Title", "Summary", "d_id"],
"type": "string", },
"description": "The title of the document" }
}, ]
"Summary": { chain_filter = prompt_classify | llm.bind(
"type": "string", function_call={"name": "summarizer"}, functions=json_structure
"description": "The summary of the document" )
}, classifier_output = await chain_filter.ainvoke(
"d_id": { {"query": query, "d_id": document_id, "context": str(document_context)}
"type": "string", )
"description": "The id of the document" arguments_str = classifier_output.additional_kwargs["function_call"]["arguments"]
}
}, "required": ["DocumentCategory", "Title", "Summary","d_id"] }
}]
chain_filter = prompt_classify | llm.bind(function_call={"name": "summarizer"}, functions=json_structure)
classifier_output = await chain_filter.ainvoke({"query": query, "d_id": document_id, "context": str(document_context)})
arguments_str = classifier_output.additional_kwargs['function_call']['arguments']
print("This is the arguments string", arguments_str) print("This is the arguments string", arguments_str)
arguments_dict = json.loads(arguments_str) arguments_dict = json.loads(arguments_str)
return arguments_dict return arguments_dict
# classify retrievals according to type of retrieval # classify retrievals according to type of retrieval
def classify_retrieval(): def classify_retrieval():
pass pass
async def classify_user_input(query, input_type): async def classify_user_input(query, input_type):
llm = ChatOpenAI(temperature=0, model=config.model) llm = ChatOpenAI(temperature=0, model=config.model)
prompt_classify = ChatPromptTemplate.from_template( prompt_classify = ChatPromptTemplate.from_template(
"""You are a classifier. Determine with a True or False if the following input: {query}, is relevant for the following memory category: {input_type}""" """You are a classifier. Determine with a True or False if the following input: {query}, is relevant for the following memory category: {input_type}"""
) )
json_structure = [{ json_structure = [
"name": "classifier", {
"description": "Classification", "name": "classifier",
"parameters": { "description": "Classification",
"type": "object", "parameters": {
"properties": { "type": "object",
"InputClassification": { "properties": {
"type": "boolean", "InputClassification": {
"description": "The classification of the input" "type": "boolean",
} "description": "The classification of the input",
}, "required": ["InputClassification"] } }
}] },
chain_filter = prompt_classify | llm.bind(function_call={"name": "classifier"}, functions=json_structure) "required": ["InputClassification"],
classifier_output = await chain_filter.ainvoke({"query": query, "input_type": input_type}) },
arguments_str = classifier_output.additional_kwargs['function_call']['arguments'] }
]
chain_filter = prompt_classify | llm.bind(
function_call={"name": "classifier"}, functions=json_structure
)
classifier_output = await chain_filter.ainvoke(
{"query": query, "input_type": input_type}
)
arguments_str = classifier_output.additional_kwargs["function_call"]["arguments"]
logging.info("This is the arguments string %s", arguments_str) logging.info("This is the arguments string %s", arguments_str)
arguments_dict = json.loads(arguments_str) arguments_dict = json.loads(arguments_str)
logging.info("Relevant summary is %s", arguments_dict.get('DocumentSummary', None)) logging.info("Relevant summary is %s", arguments_dict.get("DocumentSummary", None))
InputClassification = arguments_dict.get('InputClassification', None) InputClassification = arguments_dict.get("InputClassification", None)
logging.info("This is the classification %s", InputClassification) logging.info("This is the classification %s", InputClassification)
return InputClassification return InputClassification
# classify documents according to type of document # classify documents according to type of document
async def classify_call(query, document_summaries): async def classify_call(query, document_summaries):
llm = ChatOpenAI(temperature=0, model=config.model) llm = ChatOpenAI(temperature=0, model=config.model)
prompt_classify = ChatPromptTemplate.from_template( prompt_classify = ChatPromptTemplate.from_template(
"""You are a classifier. Determine what document are relevant for the given query: {query}, Document summaries and ids:{document_summaries}""" """You are a classifier. Determine what document are relevant for the given query: {query}, Document summaries and ids:{document_summaries}"""
) )
json_structure = [{ json_structure = [
"name": "classifier", {
"description": "Classification", "name": "classifier",
"parameters": { "description": "Classification",
"type": "object", "parameters": {
"properties": { "type": "object",
"DocumentSummary": { "properties": {
"type": "string", "DocumentSummary": {
"description": "The summary of the document and the topic it deals with." "type": "string",
"description": "The summary of the document and the topic it deals with.",
},
"d_id": {"type": "string", "description": "The id of the document"},
}, },
"d_id": { "required": ["DocumentSummary"],
"type": "string", },
"description": "The id of the document" }
} ]
chain_filter = prompt_classify | llm.bind(
function_call={"name": "classifier"}, functions=json_structure
}, "required": ["DocumentSummary"] } )
}] classifier_output = await chain_filter.ainvoke(
chain_filter = prompt_classify | llm.bind(function_call={"name": "classifier"}, functions=json_structure) {"query": query, "document_summaries": document_summaries}
classifier_output = await chain_filter.ainvoke({"query": query, "document_summaries": document_summaries}) )
arguments_str = classifier_output.additional_kwargs['function_call']['arguments'] arguments_str = classifier_output.additional_kwargs["function_call"]["arguments"]
print("This is the arguments string", arguments_str) print("This is the arguments string", arguments_str)
arguments_dict = json.loads(arguments_str) arguments_dict = json.loads(arguments_str)
logging.info("Relevant summary is %s", arguments_dict.get('DocumentSummary', None)) logging.info("Relevant summary is %s", arguments_dict.get("DocumentSummary", None))
classfier_id = arguments_dict.get('d_id', None) classfier_id = arguments_dict.get("d_id", None)
print("This is the classifier id ", classfier_id) print("This is the classifier id ", classfier_id)
return classfier_id return classfier_id
async def classify_user_query(query, context, document_types): async def classify_user_query(query, context, document_types):
llm = ChatOpenAI(temperature=0, model=config.model) llm = ChatOpenAI(temperature=0, model=config.model)
prompt_classify = ChatPromptTemplate.from_template( prompt_classify = ChatPromptTemplate.from_template(
"""You are a classifier. You store user memories, thoughts and feelings. Determine if you need to use them to answer this query : {query}""" """You are a classifier. You store user memories, thoughts and feelings. Determine if you need to use them to answer this query : {query}"""
) )
json_structure = [{ json_structure = [
"name": "classifier", {
"description": "Classification", "name": "classifier",
"parameters": { "description": "Classification",
"type": "object", "parameters": {
"properties": { "type": "object",
"UserQueryClassifier": { "properties": {
"type": "bool", "UserQueryClassifier": {
"description": "The classification of documents in groups such as legal, medical, etc." "type": "bool",
} "description": "The classification of documents in groups such as legal, medical, etc.",
}
},
}, "required": ["UserQueryClassiffier"] } "required": ["UserQueryClassiffier"],
}] },
chain_filter = prompt_classify | llm.bind(function_call={"name": "classifier"}, functions=json_structure) }
classifier_output = await chain_filter.ainvoke({"query": query, "context": context, "document_types": document_types}) ]
arguments_str = classifier_output.additional_kwargs['function_call']['arguments'] chain_filter = prompt_classify | llm.bind(
function_call={"name": "classifier"}, functions=json_structure
)
classifier_output = await chain_filter.ainvoke(
{"query": query, "context": context, "document_types": document_types}
)
arguments_str = classifier_output.additional_kwargs["function_call"]["arguments"]
print("This is the arguments string", arguments_str) print("This is the arguments string", arguments_str)
arguments_dict = json.loads(arguments_str) arguments_dict = json.loads(arguments_str)
classfier_value = arguments_dict.get('UserQueryClassifier', None) classfier_value = arguments_dict.get("UserQueryClassifier", None)
print("This is the classifier value", classfier_value) print("This is the classifier value", classfier_value)

View file

@ -10,52 +10,65 @@ from dotenv import load_dotenv
base_dir = Path(__file__).resolve().parent.parent base_dir = Path(__file__).resolve().parent.parent
# Load the .env file from the base directory # Load the .env file from the base directory
dotenv_path = base_dir / '.env' dotenv_path = base_dir / ".env"
load_dotenv(dotenv_path=dotenv_path) load_dotenv(dotenv_path=dotenv_path)
@dataclass @dataclass
class Config: class Config:
# Paths and Directories # Paths and Directories
memgpt_dir: str = field(default_factory=lambda: os.getenv('COG_ARCH_DIR', 'cognitive_achitecture')) memgpt_dir: str = field(
config_path: str = field(default_factory=lambda: os.path.join(os.getenv('COG_ARCH_DIR', 'cognitive_achitecture'), 'config')) default_factory=lambda: os.getenv("COG_ARCH_DIR", "cognitive_achitecture")
)
config_path: str = field(
default_factory=lambda: os.path.join(
os.getenv("COG_ARCH_DIR", "cognitive_achitecture"), "config"
)
)
vectordb:str = 'lancedb' vectordb: str = "lancedb"
# Model parameters # Model parameters
model: str = 'gpt-4-1106-preview' model: str = "gpt-4-1106-preview"
model_endpoint: str = 'openai' model_endpoint: str = "openai"
openai_key: Optional[str] = os.getenv('OPENAI_API_KEY') openai_key: Optional[str] = os.getenv("OPENAI_API_KEY")
openai_temperature: float = float(os.getenv("OPENAI_TEMPERATURE", 0.0)) openai_temperature: float = float(os.getenv("OPENAI_TEMPERATURE", 0.0))
# Embedding parameters # Embedding parameters
embedding_model: str = 'openai' embedding_model: str = "openai"
embedding_dim: int = 1536 embedding_dim: int = 1536
embedding_chunk_size: int = 300 embedding_chunk_size: int = 300
# Database parameters # Database parameters
if os.getenv('ENV') == 'prod' or os.getenv('ENV') == 'dev' or os.getenv('AWS_ENV') == 'dev' or os.getenv('AWS_ENV') == 'prd': if (
graph_database_url: str = os.getenv('GRAPH_DB_URL_PROD') os.getenv("ENV") == "prod"
graph_database_username: str = os.getenv('GRAPH_DB_USER') or os.getenv("ENV") == "dev"
graph_database_password: str = os.getenv('GRAPH_DB_PW') or os.getenv("AWS_ENV") == "dev"
or os.getenv("AWS_ENV") == "prd"
):
graph_database_url: str = os.getenv("GRAPH_DB_URL_PROD")
graph_database_username: str = os.getenv("GRAPH_DB_USER")
graph_database_password: str = os.getenv("GRAPH_DB_PW")
else: else:
graph_database_url: str = os.getenv('GRAPH_DB_URL') graph_database_url: str = os.getenv("GRAPH_DB_URL")
graph_database_username: str = os.getenv('GRAPH_DB_USER') graph_database_username: str = os.getenv("GRAPH_DB_USER")
graph_database_password: str = os.getenv('GRAPH_DB_PW') graph_database_password: str = os.getenv("GRAPH_DB_PW")
weaviate_url: str = os.getenv('WEAVIATE_URL') weaviate_url: str = os.getenv("WEAVIATE_URL")
weaviate_api_key: str = os.getenv('WEAVIATE_API_KEY') weaviate_api_key: str = os.getenv("WEAVIATE_API_KEY")
postgres_user: str = os.getenv('POSTGRES_USER') postgres_user: str = os.getenv("POSTGRES_USER")
postgres_password: str = os.getenv('POSTGRES_PASSWORD') postgres_password: str = os.getenv("POSTGRES_PASSWORD")
postgres_db: str = os.getenv('POSTGRES_DB') postgres_db: str = os.getenv("POSTGRES_DB")
if os.getenv('ENV') == 'prod' or os.getenv('ENV') == 'dev' or os.getenv('AWS_ENV') == 'dev' or os.getenv('AWS_ENV') == 'prd': if (
postgres_host: str = os.getenv('POSTGRES_PROD_HOST') os.getenv("ENV") == "prod"
elif os.getenv('ENV') == 'docker': or os.getenv("ENV") == "dev"
postgres_host: str = os.getenv('POSTGRES_HOST_DOCKER') or os.getenv("AWS_ENV") == "dev"
elif os.getenv('ENV') == 'local': or os.getenv("AWS_ENV") == "prd"
postgres_host: str = os.getenv('POSTGRES_HOST_LOCAL') ):
postgres_host: str = os.getenv("POSTGRES_PROD_HOST")
elif os.getenv("ENV") == "docker":
postgres_host: str = os.getenv("POSTGRES_HOST_DOCKER")
elif os.getenv("ENV") == "local":
postgres_host: str = os.getenv("POSTGRES_HOST_LOCAL")
# Client ID # Client ID
anon_clientid: Optional[str] = field(default_factory=lambda: uuid.uuid4().hex) anon_clientid: Optional[str] = field(default_factory=lambda: uuid.uuid4().hex)
@ -84,12 +97,12 @@ class Config:
# Save the current settings to the config file # Save the current settings to the config file
for attr, value in self.__dict__.items(): for attr, value in self.__dict__.items():
section, option = attr.split('_', 1) section, option = attr.split("_", 1)
if not config.has_section(section): if not config.has_section(section):
config.add_section(section) config.add_section(section)
config.set(section, option, str(value)) config.set(section, option, str(value))
with open(self.config_path, 'w') as configfile: with open(self.config_path, "w") as configfile:
config.write(configfile) config.write(configfile)
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:

View file

@ -22,12 +22,17 @@ from sqlalchemy import create_engine, text
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from contextlib import contextmanager from contextlib import contextmanager
from dotenv import load_dotenv from dotenv import load_dotenv
from relationaldb.database import Base # Assuming all models are imported within this module from relationaldb.database import (
from relationaldb.database import DatabaseConfig # Assuming DatabaseConfig is defined as before Base,
) # Assuming all models are imported within this module
from relationaldb.database import (
DatabaseConfig,
) # Assuming DatabaseConfig is defined as before
load_dotenv() load_dotenv()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DatabaseManager: class DatabaseManager:
def __init__(self, config: DatabaseConfig): def __init__(self, config: DatabaseConfig):
self.config = config self.config = config
@ -36,7 +41,7 @@ class DatabaseManager:
@contextmanager @contextmanager
def get_connection(self): def get_connection(self):
if self.db_type in ['sqlite', 'duckdb']: if self.db_type in ["sqlite", "duckdb"]:
# For SQLite and DuckDB, the engine itself manages connections # For SQLite and DuckDB, the engine itself manages connections
yield self.engine yield self.engine
else: else:
@ -47,7 +52,7 @@ class DatabaseManager:
connection.close() connection.close()
def database_exists(self, db_name): def database_exists(self, db_name):
if self.db_type in ['sqlite', 'duckdb']: if self.db_type in ["sqlite", "duckdb"]:
# For SQLite and DuckDB, check if the database file exists # For SQLite and DuckDB, check if the database file exists
return os.path.exists(db_name) return os.path.exists(db_name)
else: else:
@ -57,14 +62,14 @@ class DatabaseManager:
return result is not None return result is not None
def create_database(self, db_name): def create_database(self, db_name):
if self.db_type not in ['sqlite', 'duckdb']: if self.db_type not in ["sqlite", "duckdb"]:
# For databases like PostgreSQL, create the database explicitly # For databases like PostgreSQL, create the database explicitly
with self.get_connection() as connection: with self.get_connection() as connection:
connection.execution_options(isolation_level="AUTOCOMMIT") connection.execution_options(isolation_level="AUTOCOMMIT")
connection.execute(f"CREATE DATABASE {db_name}") connection.execute(f"CREATE DATABASE {db_name}")
def drop_database(self, db_name): def drop_database(self, db_name):
if self.db_type in ['sqlite', 'duckdb']: if self.db_type in ["sqlite", "duckdb"]:
# For SQLite and DuckDB, simply remove the database file # For SQLite and DuckDB, simply remove the database file
os.remove(db_name) os.remove(db_name)
else: else:
@ -75,9 +80,10 @@ class DatabaseManager:
def create_tables(self): def create_tables(self):
Base.metadata.create_all(bind=self.engine) Base.metadata.create_all(bind=self.engine)
if __name__ == "__main__": if __name__ == "__main__":
# Example usage with SQLite # Example usage with SQLite
config = DatabaseConfig(db_type='sqlite', db_name='mydatabase.db') config = DatabaseConfig(db_type="sqlite", db_name="mydatabase.db")
# For DuckDB, you would set db_type to 'duckdb' and provide the database file name # For DuckDB, you would set db_type to 'duckdb' and provide the database file name
# config = DatabaseConfig(db_type='duckdb', db_name='mydatabase.duckdb') # config = DatabaseConfig(db_type='duckdb', db_name='mydatabase.duckdb')

View file

@ -1,4 +1,3 @@
import logging import logging
import os import os
@ -26,13 +25,19 @@ from abc import ABC, abstractmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import List, Dict, Optional from typing import List, Dict, Optional
from ...utils import format_dict, append_uuid_to_variable_names, create_edge_variable_mapping, \ from ...utils import (
create_node_variable_mapping, get_unsumarized_vector_db_namespace format_dict,
append_uuid_to_variable_names,
create_edge_variable_mapping,
create_node_variable_mapping,
get_unsumarized_vector_db_namespace,
)
from ...llm.queries import generate_summary, generate_graph from ...llm.queries import generate_summary, generate_graph
import logging import logging
from neo4j import AsyncGraphDatabase, Neo4jError from neo4j import AsyncGraphDatabase, Neo4jError
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any, Dict, Optional, List from typing import Any, Dict, Optional, List
DEFAULT_PRESET = "promethai_chat" DEFAULT_PRESET = "promethai_chat"
preset_options = [DEFAULT_PRESET] preset_options = [DEFAULT_PRESET]
PROMETHAI_DIR = os.path.join(os.path.expanduser("~"), ".") PROMETHAI_DIR = os.path.join(os.path.expanduser("~"), ".")
@ -41,7 +46,13 @@ load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "") OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
from ...config import Config from ...config import Config
from ...shared.data_models import Node, Edge, KnowledgeGraph, GraphQLQuery, MemorySummary from ...shared.data_models import (
Node,
Edge,
KnowledgeGraph,
GraphQLQuery,
MemorySummary,
)
config = Config() config = Config()
config.load() config.load()
@ -53,8 +64,8 @@ OPENAI_API_KEY = config.openai_key
aclient = instructor.patch(OpenAI()) aclient = instructor.patch(OpenAI())
class AbstractGraphDB(ABC):
class AbstractGraphDB(ABC):
@abstractmethod @abstractmethod
def query(self, query: str, params=None): def query(self, query: str, params=None):
pass pass
@ -73,8 +84,12 @@ class AbstractGraphDB(ABC):
class Neo4jGraphDB(AbstractGraphDB): class Neo4jGraphDB(AbstractGraphDB):
def __init__(self, url: str, username: str, password: str, driver: Optional[Any] = None): def __init__(
self.driver = driver or AsyncGraphDatabase.driver(url, auth=(username, password)) self, url: str, username: str, password: str, driver: Optional[Any] = None
):
self.driver = driver or AsyncGraphDatabase.driver(
url, auth=(username, password)
)
async def close(self) -> None: async def close(self) -> None:
await self.driver.close() await self.driver.close()
@ -84,7 +99,9 @@ class Neo4jGraphDB(AbstractGraphDB):
async with self.driver.session() as session: async with self.driver.session() as session:
yield session yield session
async def query(self, query: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]: async def query(
self, query: str, params: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
try: try:
async with self.get_session() as session: async with self.get_session() as session:
result = await session.run(query, parameters=params) result = await session.run(query, parameters=params)
@ -93,30 +110,28 @@ class Neo4jGraphDB(AbstractGraphDB):
logging.error(f"Neo4j query error: {e.message}") logging.error(f"Neo4j query error: {e.message}")
raise raise
# class Neo4jGraphDB(AbstractGraphDB):
# class Neo4jGraphDB(AbstractGraphDB): # def __init__(self, url, username, password):
# def __init__(self, url, username, password): # # self.graph = Neo4jGraph(url=url, username=username, password=password)
# # self.graph = Neo4jGraph(url=url, username=username, password=password) # from neo4j import GraphDatabase
# from neo4j import GraphDatabase # self.driver = GraphDatabase.driver(url, auth=(username, password))
# self.driver = GraphDatabase.driver(url, auth=(username, password)) # self.openai_key = config.openai_key
# self.openai_key = config.openai_key #
# #
# #
# # def close(self):
# def close(self): # # Method to close the Neo4j driver instance
# # Method to close the Neo4j driver instance # self.driver.close()
# self.driver.close() #
# # def query(self, query, params=None):
# def query(self, query, params=None): # try:
# try: # with self.driver.session() as session:
# with self.driver.session() as session: # result = session.run(query, params).data()
# result = session.run(query, params).data() # return result
# return result # except Exception as e:
# except Exception as e: # logging.error(f"An error occurred while executing the query: {e}")
# logging.error(f"An error occurred while executing the query: {e}") # raise e
# raise e #
#
def create_base_cognitive_architecture(self, user_id: str): def create_base_cognitive_architecture(self, user_id: str):
# Create the user and memory components if they don't exist # Create the user and memory components if they don't exist
@ -131,16 +146,22 @@ class Neo4jGraphDB(AbstractGraphDB):
""" """
return user_memory_cypher return user_memory_cypher
async def retrieve_memory(self, user_id: str, memory_type: str, timestamp: float = None, summarized: bool = None): async def retrieve_memory(
if memory_type == 'SemanticMemory': self,
relationship = 'SEMANTIC_MEMORY' user_id: str,
memory_rel = 'HAS_KNOWLEDGE' memory_type: str,
elif memory_type == 'EpisodicMemory': timestamp: float = None,
relationship = 'EPISODIC_MEMORY' summarized: bool = None,
memory_rel = 'HAS_EVENT' ):
elif memory_type == 'Buffer': if memory_type == "SemanticMemory":
relationship = 'BUFFER' relationship = "SEMANTIC_MEMORY"
memory_rel = 'CURRENTLY_HOLDING' memory_rel = "HAS_KNOWLEDGE"
elif memory_type == "EpisodicMemory":
relationship = "EPISODIC_MEMORY"
memory_rel = "HAS_EVENT"
elif memory_type == "Buffer":
relationship = "BUFFER"
memory_rel = "CURRENTLY_HOLDING"
if timestamp is not None and summarized is not None: if timestamp is not None and summarized is not None:
query = f""" query = f"""
MATCH (user:User {{userId: '{user_id}' }})-[:HAS_{relationship}]->(memory:{memory_type}) MATCH (user:User {{userId: '{user_id}' }})-[:HAS_{relationship}]->(memory:{memory_type})
@ -172,79 +193,100 @@ class Neo4jGraphDB(AbstractGraphDB):
output = self.query(query, params={"user_id": user_id}) output = self.query(query, params={"user_id": user_id})
print("Here is the output", output) print("Here is the output", output)
reduced_graph = await generate_summary(input = output) reduced_graph = await generate_summary(input=output)
return reduced_graph return reduced_graph
def cypher_statement_correcting(self, input: str) -> str:
def cypher_statement_correcting(self, input: str) ->str:
return aclient.chat.completions.create( return aclient.chat.completions.create(
model=config.model, model=config.model,
messages=[ messages=[
{ {
"role": "user", "role": "user",
"content": f"""Check the cypher query for syntax issues, and fix any if found and return it as is: {input}. """, "content": f"""Check the cypher query for syntax issues, and fix any if found and return it as is: {input}. """,
}, },
{"role": "system", "content": """You are a top-tier algorithm {
designed for checking cypher queries for neo4j graph databases. You have to return input provided to you as is"""} "role": "system",
"content": """You are a top-tier algorithm
designed for checking cypher queries for neo4j graph databases. You have to return input provided to you as is""",
},
], ],
response_model=GraphQLQuery, response_model=GraphQLQuery,
) )
def generate_create_statements_for_nodes_with_uuid(self, nodes, unique_mapping, base_node_mapping): def generate_create_statements_for_nodes_with_uuid(
self, nodes, unique_mapping, base_node_mapping
):
create_statements = [] create_statements = []
for node in nodes: for node in nodes:
original_variable_name = base_node_mapping[node['id']] original_variable_name = base_node_mapping[node["id"]]
unique_variable_name = unique_mapping[original_variable_name] unique_variable_name = unique_mapping[original_variable_name]
node_label = node['category'].capitalize() node_label = node["category"].capitalize()
properties = {k: v for k, v in node.items() if k not in ['id', 'category']} properties = {k: v for k, v in node.items() if k not in ["id", "category"]}
try: try:
properties = format_dict(properties) properties = format_dict(properties)
except: except:
pass pass
create_statements.append(f"CREATE ({unique_variable_name}:{node_label} {properties})") create_statements.append(
f"CREATE ({unique_variable_name}:{node_label} {properties})"
)
return create_statements return create_statements
# Update the function to generate Cypher CREATE statements for edges with unique variable names # Update the function to generate Cypher CREATE statements for edges with unique variable names
def generate_create_statements_for_edges_with_uuid(self, user_id, edges, unique_mapping, base_node_mapping): def generate_create_statements_for_edges_with_uuid(
self, user_id, edges, unique_mapping, base_node_mapping
):
create_statements = [] create_statements = []
with_statement = f"WITH {', '.join(unique_mapping.values())}, user , semantic, episodic, buffer" with_statement = f"WITH {', '.join(unique_mapping.values())}, user , semantic, episodic, buffer"
create_statements.append(with_statement) create_statements.append(with_statement)
for edge in edges: for edge in edges:
# print("HERE IS THE EDGE", edge) # print("HERE IS THE EDGE", edge)
source_variable = unique_mapping[base_node_mapping[edge['source']]] source_variable = unique_mapping[base_node_mapping[edge["source"]]]
target_variable = unique_mapping[base_node_mapping[edge['target']]] target_variable = unique_mapping[base_node_mapping[edge["target"]]]
relationship = edge['description'].replace(" ", "_").upper() relationship = edge["description"].replace(" ", "_").upper()
create_statements.append(f"CREATE ({source_variable})-[:{relationship}]->({target_variable})") create_statements.append(
f"CREATE ({source_variable})-[:{relationship}]->({target_variable})"
)
return create_statements return create_statements
def generate_memory_type_relationships_with_uuid_and_time_context(self, user_id, nodes, unique_mapping, base_node_mapping): def generate_memory_type_relationships_with_uuid_and_time_context(
self, user_id, nodes, unique_mapping, base_node_mapping
):
create_statements = [] create_statements = []
with_statement = f"WITH {', '.join(unique_mapping.values())}, user, semantic, episodic, buffer" with_statement = f"WITH {', '.join(unique_mapping.values())}, user, semantic, episodic, buffer"
create_statements.append(with_statement) create_statements.append(with_statement)
# Loop through each node and create relationships based on memory_type # Loop through each node and create relationships based on memory_type
for node in nodes: for node in nodes:
original_variable_name = base_node_mapping[node['id']] original_variable_name = base_node_mapping[node["id"]]
unique_variable_name = unique_mapping[original_variable_name] unique_variable_name = unique_mapping[original_variable_name]
if node['memory_type'] == 'semantic': if node["memory_type"] == "semantic":
create_statements.append(f"CREATE (semantic)-[:HAS_KNOWLEDGE]->({unique_variable_name})") create_statements.append(
elif node['memory_type'] == 'episodic': f"CREATE (semantic)-[:HAS_KNOWLEDGE]->({unique_variable_name})"
create_statements.append(f"CREATE (episodic)-[:HAS_EVENT]->({unique_variable_name})") )
if node['category'] == 'time': elif node["memory_type"] == "episodic":
create_statements.append(f"CREATE (buffer)-[:HAS_TIME_CONTEXT]->({unique_variable_name})") create_statements.append(
f"CREATE (episodic)-[:HAS_EVENT]->({unique_variable_name})"
)
if node["category"] == "time":
create_statements.append(
f"CREATE (buffer)-[:HAS_TIME_CONTEXT]->({unique_variable_name})"
)
# Assuming buffer holds all actions and times # Assuming buffer holds all actions and times
# if node['category'] in ['action', 'time']: # if node['category'] in ['action', 'time']:
create_statements.append(f"CREATE (buffer)-[:CURRENTLY_HOLDING]->({unique_variable_name})") create_statements.append(
f"CREATE (buffer)-[:CURRENTLY_HOLDING]->({unique_variable_name})"
)
return create_statements return create_statements
async def generate_cypher_query_for_user_prompt_decomposition(self, user_id:str, query:str): async def generate_cypher_query_for_user_prompt_decomposition(
self, user_id: str, query: str
):
graph: KnowledgeGraph = generate_graph(query) graph: KnowledgeGraph = generate_graph(query)
import time import time
for node in graph.nodes: for node in graph.nodes:
node.created_at = time.time() node.created_at = time.time()
node.summarized = False node.summarized = False
@ -254,19 +296,41 @@ class Neo4jGraphDB(AbstractGraphDB):
edge.summarized = False edge.summarized = False
graph_dic = graph.dict() graph_dic = graph.dict()
node_variable_mapping = create_node_variable_mapping(graph_dic['nodes']) node_variable_mapping = create_node_variable_mapping(graph_dic["nodes"])
edge_variable_mapping = create_edge_variable_mapping(graph_dic['edges']) edge_variable_mapping = create_edge_variable_mapping(graph_dic["edges"])
# Create unique variable names for each node # Create unique variable names for each node
unique_node_variable_mapping = append_uuid_to_variable_names(node_variable_mapping) unique_node_variable_mapping = append_uuid_to_variable_names(
unique_edge_variable_mapping = append_uuid_to_variable_names(edge_variable_mapping) node_variable_mapping
create_nodes_statements = self.generate_create_statements_for_nodes_with_uuid(graph_dic['nodes'], unique_node_variable_mapping, node_variable_mapping) )
create_edges_statements =self.generate_create_statements_for_edges_with_uuid(user_id, graph_dic['edges'], unique_node_variable_mapping, node_variable_mapping) unique_edge_variable_mapping = append_uuid_to_variable_names(
edge_variable_mapping
)
create_nodes_statements = self.generate_create_statements_for_nodes_with_uuid(
graph_dic["nodes"], unique_node_variable_mapping, node_variable_mapping
)
create_edges_statements = self.generate_create_statements_for_edges_with_uuid(
user_id,
graph_dic["edges"],
unique_node_variable_mapping,
node_variable_mapping,
)
memory_type_statements_with_uuid_and_time_context = self.generate_memory_type_relationships_with_uuid_and_time_context(user_id, memory_type_statements_with_uuid_and_time_context = (
graph_dic['nodes'], unique_node_variable_mapping, node_variable_mapping) self.generate_memory_type_relationships_with_uuid_and_time_context(
user_id,
graph_dic["nodes"],
unique_node_variable_mapping,
node_variable_mapping,
)
)
# # Combine all statements # # Combine all statements
cypher_statements = [self.create_base_cognitive_architecture(user_id)] + create_nodes_statements + create_edges_statements + memory_type_statements_with_uuid_and_time_context cypher_statements = (
[self.create_base_cognitive_architecture(user_id)]
+ create_nodes_statements
+ create_edges_statements
+ memory_type_statements_with_uuid_and_time_context
)
cypher_statements_joined = "\n".join(cypher_statements) cypher_statements_joined = "\n".join(cypher_statements)
logging.info("User Cypher Query raw: %s", cypher_statements_joined) logging.info("User Cypher Query raw: %s", cypher_statements_joined)
# corrected_cypher_statements = self.cypher_statement_correcting(input = cypher_statements_joined) # corrected_cypher_statements = self.cypher_statement_correcting(input = cypher_statements_joined)
@ -274,15 +338,15 @@ class Neo4jGraphDB(AbstractGraphDB):
# return corrected_cypher_statements.query # return corrected_cypher_statements.query
return cypher_statements_joined return cypher_statements_joined
def update_user_query_for_user_prompt_decomposition(self, user_id, user_query): def update_user_query_for_user_prompt_decomposition(self, user_id, user_query):
pass pass
def delete_all_user_memories(self, user_id): def delete_all_user_memories(self, user_id):
try: try:
# Check if the user exists # Check if the user exists
user_exists = self.query(f"MATCH (user:User {{userId: '{user_id}'}}) RETURN user") user_exists = self.query(
f"MATCH (user:User {{userId: '{user_id}'}}) RETURN user"
)
if not user_exists: if not user_exists:
return f"No user found with ID: {user_id}" return f"No user found with ID: {user_id}"
@ -304,12 +368,14 @@ class Neo4jGraphDB(AbstractGraphDB):
def delete_specific_memory_type(self, user_id, memory_type): def delete_specific_memory_type(self, user_id, memory_type):
try: try:
# Check if the user exists # Check if the user exists
user_exists = self.query(f"MATCH (user:User {{userId: '{user_id}'}}) RETURN user") user_exists = self.query(
f"MATCH (user:User {{userId: '{user_id}'}}) RETURN user"
)
if not user_exists: if not user_exists:
return f"No user found with ID: {user_id}" return f"No user found with ID: {user_id}"
# Validate memory type # Validate memory type
if memory_type not in ['SemanticMemory', 'EpisodicMemory', 'Buffer']: if memory_type not in ["SemanticMemory", "EpisodicMemory", "Buffer"]:
return "Invalid memory type. Choose from 'SemanticMemory', 'EpisodicMemory', or 'Buffer'." return "Invalid memory type. Choose from 'SemanticMemory', 'EpisodicMemory', or 'Buffer'."
# Delete specific memory type nodes and relationships for the given user # Delete specific memory type nodes and relationships for the given user
@ -322,7 +388,9 @@ class Neo4jGraphDB(AbstractGraphDB):
except Exception as e: except Exception as e:
return f"An error occurred: {str(e)}" return f"An error occurred: {str(e)}"
def retrieve_semantic_memory(self, user_id: str, timestamp: float = None, summarized: bool = None): def retrieve_semantic_memory(
self, user_id: str, timestamp: float = None, summarized: bool = None
):
if timestamp is not None and summarized is not None: if timestamp is not None and summarized is not None:
query = f""" query = f"""
MATCH (user:User {{userId: '{user_id}' }})-[:HAS_SEMANTIC_MEMORY]->(semantic:SemanticMemory) MATCH (user:User {{userId: '{user_id}' }})-[:HAS_SEMANTIC_MEMORY]->(semantic:SemanticMemory)
@ -352,7 +420,9 @@ class Neo4jGraphDB(AbstractGraphDB):
""" """
return self.query(query, params={"user_id": user_id}) return self.query(query, params={"user_id": user_id})
def retrieve_episodic_memory(self, user_id: str, timestamp: float = None, summarized: bool = None): def retrieve_episodic_memory(
self, user_id: str, timestamp: float = None, summarized: bool = None
):
if timestamp is not None and summarized is not None: if timestamp is not None and summarized is not None:
query = f""" query = f"""
MATCH (user:User {{userId: '{user_id}' }})-[:HAS_EPISODIC_MEMORY]->(episodic:EpisodicMemory) MATCH (user:User {{userId: '{user_id}' }})-[:HAS_EPISODIC_MEMORY]->(episodic:EpisodicMemory)
@ -382,8 +452,9 @@ class Neo4jGraphDB(AbstractGraphDB):
""" """
return self.query(query, params={"user_id": user_id}) return self.query(query, params={"user_id": user_id})
def retrieve_buffer_memory(
def retrieve_buffer_memory(self, user_id: str, timestamp: float = None, summarized: bool = None): self, user_id: str, timestamp: float = None, summarized: bool = None
):
if timestamp is not None and summarized is not None: if timestamp is not None and summarized is not None:
query = f""" query = f"""
MATCH (user:User {{userId: '{user_id}' }})-[:HAS_BUFFER]->(buffer:Buffer) MATCH (user:User {{userId: '{user_id}' }})-[:HAS_BUFFER]->(buffer:Buffer)
@ -413,8 +484,6 @@ class Neo4jGraphDB(AbstractGraphDB):
""" """
return self.query(query, params={"user_id": user_id}) return self.query(query, params={"user_id": user_id})
def retrieve_public_memory(self, user_id: str): def retrieve_public_memory(self, user_id: str):
query = """ query = """
MATCH (user:User {userId: $user_id})-[:HAS_PUBLIC_MEMORY]->(public:PublicMemory) MATCH (user:User {userId: $user_id})-[:HAS_PUBLIC_MEMORY]->(public:PublicMemory)
@ -422,23 +491,33 @@ class Neo4jGraphDB(AbstractGraphDB):
RETURN document RETURN document
""" """
return self.query(query, params={"user_id": user_id}) return self.query(query, params={"user_id": user_id})
def generate_graph_semantic_memory_document_summary(self, document_summary : str, unique_graphdb_mapping_values: dict, document_namespace: str):
""" This function takes a document and generates a document summary in Semantic Memory""" def generate_graph_semantic_memory_document_summary(
self,
document_summary: str,
unique_graphdb_mapping_values: dict,
document_namespace: str,
):
"""This function takes a document and generates a document summary in Semantic Memory"""
create_statements = [] create_statements = []
with_statement = f"WITH {', '.join(unique_graphdb_mapping_values.values())}, user, semantic, episodic, buffer" with_statement = f"WITH {', '.join(unique_graphdb_mapping_values.values())}, user, semantic, episodic, buffer"
create_statements.append(with_statement) create_statements.append(with_statement)
# Loop through each node and create relationships based on memory_type # Loop through each node and create relationships based on memory_type
create_statements.append(f"CREATE (semantic)-[:HAS_KNOWLEDGE]->({unique_graphdb_mapping_values})") create_statements.append(
f"CREATE (semantic)-[:HAS_KNOWLEDGE]->({unique_graphdb_mapping_values})"
)
return create_statements return create_statements
def generate_document_summary(
def generate_document_summary(self, document_summary : str, unique_graphdb_mapping_values: dict, document_namespace: str): self,
""" This function takes a document and generates a document summary in Semantic Memory""" document_summary: str,
unique_graphdb_mapping_values: dict,
document_namespace: str,
):
"""This function takes a document and generates a document summary in Semantic Memory"""
# fetch namespace from postgres db # fetch namespace from postgres db
# fetch 1st and last page from vector store # fetch 1st and last page from vector store
@ -450,12 +529,15 @@ class Neo4jGraphDB(AbstractGraphDB):
# Loop through each node and create relationships based on memory_type # Loop through each node and create relationships based on memory_type
create_statements.append(f"CREATE (semantic)-[:HAS_KNOWLEDGE]->({unique_graphdb_mapping_values})") create_statements.append(
f"CREATE (semantic)-[:HAS_KNOWLEDGE]->({unique_graphdb_mapping_values})"
)
return create_statements return create_statements
async def get_memory_linked_document_summaries(self, user_id: str, memory_type: str = "PublicMemory"): async def get_memory_linked_document_summaries(
self, user_id: str, memory_type: str = "PublicMemory"
):
""" """
Retrieve a list of summaries for all documents associated with a given memory type for a user. Retrieve a list of summaries for all documents associated with a given memory type for a user.
@ -474,23 +556,30 @@ class Neo4jGraphDB(AbstractGraphDB):
elif memory_type == "SemanticMemory": elif memory_type == "SemanticMemory":
relationship = "HAS_SEMANTIC_MEMORY" relationship = "HAS_SEMANTIC_MEMORY"
try: try:
query = f''' query = f"""
MATCH (user:User {{userId: '{user_id}'}})-[:{relationship}]->(memory:{memory_type})-[:HAS_DOCUMENT]->(document:Document) MATCH (user:User {{userId: '{user_id}'}})-[:{relationship}]->(memory:{memory_type})-[:HAS_DOCUMENT]->(document:Document)
RETURN document.d_id AS d_id, document.summary AS summary RETURN document.d_id AS d_id, document.summary AS summary
''' """
logging.info(f"Generated Cypher query: {query}") logging.info(f"Generated Cypher query: {query}")
result = self.query(query) result = self.query(query)
logging.info(f"Result: {result}") logging.info(f"Result: {result}")
return [{"d_id": record.get("d_id", None), "summary": record.get("summary", "No summary available")} for return [
record in result] {
"d_id": record.get("d_id", None),
"summary": record.get("summary", "No summary available"),
}
for record in result
]
except Exception as e: except Exception as e:
logging.error(f"An error occurred while retrieving document summary: {str(e)}") logging.error(
f"An error occurred while retrieving document summary: {str(e)}"
)
return None return None
async def get_memory_linked_document_ids(
self, user_id: str, summary_id: str, memory_type: str = "PublicMemory"
async def get_memory_linked_document_ids(self, user_id: str, summary_id: str, memory_type: str = "PublicMemory"): ):
""" """
Retrieve a list of document IDs for a specific category associated with a given memory type for a user. Retrieve a list of document IDs for a specific category associated with a given memory type for a user.
@ -511,11 +600,11 @@ class Neo4jGraphDB(AbstractGraphDB):
elif memory_type == "SemanticMemory": elif memory_type == "SemanticMemory":
relationship = "HAS_SEMANTIC_MEMORY" relationship = "HAS_SEMANTIC_MEMORY"
try: try:
query = f''' query = f"""
MATCH (user:User {{userId: '{user_id}'}})-[:{relationship}]->(memory:{memory_type})-[:HAS_DOCUMENT]->(document:Document) MATCH (user:User {{userId: '{user_id}'}})-[:{relationship}]->(memory:{memory_type})-[:HAS_DOCUMENT]->(document:Document)
WHERE document.d_id = '{summary_id}' WHERE document.d_id = '{summary_id}'
RETURN document.d_id AS d_id RETURN document.d_id AS d_id
''' """
logging.info(f"Generated Cypher query: {query}") logging.info(f"Generated Cypher query: {query}")
result = self.query(query) result = self.query(query)
return [record["d_id"] for record in result] return [record["d_id"] for record in result]
@ -523,9 +612,13 @@ class Neo4jGraphDB(AbstractGraphDB):
logging.error(f"An error occurred while retrieving document IDs: {str(e)}") logging.error(f"An error occurred while retrieving document IDs: {str(e)}")
return None return None
def create_document_node_cypher(
def create_document_node_cypher(self, document_summary: dict, user_id: str, self,
memory_type: str = "PublicMemory",public_memory_id:str=None) -> str: document_summary: dict,
user_id: str,
memory_type: str = "PublicMemory",
public_memory_id: str = None,
) -> str:
""" """
Generate a Cypher query to create a Document node. If the memory type is 'Semantic', Generate a Cypher query to create a Document node. If the memory type is 'Semantic',
link it to a SemanticMemory node for a user. If the memory type is 'PublicMemory', link it to a SemanticMemory node for a user. If the memory type is 'PublicMemory',
@ -546,37 +639,46 @@ class Neo4jGraphDB(AbstractGraphDB):
# Validate the input parameters # Validate the input parameters
if not isinstance(document_summary, dict): if not isinstance(document_summary, dict):
raise ValueError("The document_summary must be a dictionary.") raise ValueError("The document_summary must be a dictionary.")
if not all(key in document_summary for key in ['DocumentCategory', 'Title', 'Summary', 'd_id']): if not all(
raise ValueError("The document_summary dictionary is missing required keys.") key in document_summary
for key in ["DocumentCategory", "Title", "Summary", "d_id"]
):
raise ValueError(
"The document_summary dictionary is missing required keys."
)
if not isinstance(user_id, str) or not user_id: if not isinstance(user_id, str) or not user_id:
raise ValueError("The user_id must be a non-empty string.") raise ValueError("The user_id must be a non-empty string.")
if memory_type not in ["SemanticMemory", "PublicMemory"]: if memory_type not in ["SemanticMemory", "PublicMemory"]:
raise ValueError("The memory_type must be either 'Semantic' or 'PublicMemory'.") raise ValueError(
"The memory_type must be either 'Semantic' or 'PublicMemory'."
)
# Escape single quotes in the document summary data # Escape single quotes in the document summary data
title = document_summary['Title'].replace("'", "\\'") title = document_summary["Title"].replace("'", "\\'")
summary = document_summary['Summary'].replace("'", "\\'") summary = document_summary["Summary"].replace("'", "\\'")
document_category = document_summary['DocumentCategory'].replace("'", "\\'") document_category = document_summary["DocumentCategory"].replace("'", "\\'")
d_id = document_summary['d_id'].replace("'", "\\'") d_id = document_summary["d_id"].replace("'", "\\'")
memory_node_type = "SemanticMemory" if memory_type == "SemanticMemory" else "PublicMemory" memory_node_type = (
"SemanticMemory" if memory_type == "SemanticMemory" else "PublicMemory"
)
user_memory_link = '' user_memory_link = ""
if memory_type == "SemanticMemory": if memory_type == "SemanticMemory":
user_memory_link = f''' user_memory_link = f"""
// Ensure the User node exists // Ensure the User node exists
MERGE (user:User {{ userId: '{user_id}' }}) MERGE (user:User {{ userId: '{user_id}' }})
MERGE (memory:SemanticMemory {{ userId: '{user_id}' }}) MERGE (memory:SemanticMemory {{ userId: '{user_id}' }})
MERGE (user)-[:HAS_SEMANTIC_MEMORY]->(memory) MERGE (user)-[:HAS_SEMANTIC_MEMORY]->(memory)
''' """
elif memory_type == "PublicMemory": elif memory_type == "PublicMemory":
logging.info(f"Public memory id: {public_memory_id}") logging.info(f"Public memory id: {public_memory_id}")
user_memory_link = f''' user_memory_link = f"""
// Merge with the existing PublicMemory node or create a new one if it does not exist // Merge with the existing PublicMemory node or create a new one if it does not exist
MATCH (memory:PublicMemory {{ memoryId: {public_memory_id} }}) MATCH (memory:PublicMemory {{ memoryId: {public_memory_id} }})
''' """
cypher_query = f''' cypher_query = f"""
{user_memory_link} {user_memory_link}
// Create the Document node with its properties // Create the Document node with its properties
@ -590,13 +692,15 @@ class Neo4jGraphDB(AbstractGraphDB):
// Link the Document node to the {memory_node_type} node // Link the Document node to the {memory_node_type} node
MERGE (memory)-[:HAS_DOCUMENT]->(document) MERGE (memory)-[:HAS_DOCUMENT]->(document)
''' """
logging.info(f"Generated Cypher query: {cypher_query}") logging.info(f"Generated Cypher query: {cypher_query}")
return cypher_query return cypher_query
def update_document_node_with_db_ids(self, vectordb_namespace: str, document_id: str, user_id: str = None): def update_document_node_with_db_ids(
self, vectordb_namespace: str, document_id: str, user_id: str = None
):
""" """
Update the namespace of a Document node in the database. The document can be linked Update the namespace of a Document node in the database. The document can be linked
either to a SemanticMemory node (if a user ID is provided) or to a PublicMemory node. either to a SemanticMemory node (if a user ID is provided) or to a PublicMemory node.
@ -612,23 +716,24 @@ class Neo4jGraphDB(AbstractGraphDB):
if user_id: if user_id:
# Update for a document linked to a SemanticMemory node # Update for a document linked to a SemanticMemory node
cypher_query = f''' cypher_query = f"""
MATCH (user:User {{userId: '{user_id}' }})-[:HAS_SEMANTIC_MEMORY]->(:SemanticMemory)-[:HAS_DOCUMENT]->(document:Document {{d_id: '{document_id}'}}) MATCH (user:User {{userId: '{user_id}' }})-[:HAS_SEMANTIC_MEMORY]->(:SemanticMemory)-[:HAS_DOCUMENT]->(document:Document {{d_id: '{document_id}'}})
SET document.vectordbNamespace = '{vectordb_namespace}' SET document.vectordbNamespace = '{vectordb_namespace}'
RETURN document RETURN document
''' """
else: else:
# Update for a document linked to a PublicMemory node # Update for a document linked to a PublicMemory node
cypher_query = f''' cypher_query = f"""
MATCH (:PublicMemory)-[:HAS_DOCUMENT]->(document:Document {{d_id: '{document_id}'}}) MATCH (:PublicMemory)-[:HAS_DOCUMENT]->(document:Document {{d_id: '{document_id}'}})
SET document.vectordbNamespace = '{vectordb_namespace}' SET document.vectordbNamespace = '{vectordb_namespace}'
RETURN document RETURN document
''' """
return cypher_query return cypher_query
def run_merge_query(self, user_id: str, memory_type: str, def run_merge_query(
similarity_threshold: float) -> str: self, user_id: str, memory_type: str, similarity_threshold: float
) -> str:
""" """
Constructs a Cypher query to merge nodes in a Neo4j database based on a similarity threshold. Constructs a Cypher query to merge nodes in a Neo4j database based on a similarity threshold.
@ -645,29 +750,28 @@ class Neo4jGraphDB(AbstractGraphDB):
Returns: Returns:
str: A Cypher query string that can be executed in a Neo4j session. str: A Cypher query string that can be executed in a Neo4j session.
""" """
if memory_type == 'SemanticMemory': if memory_type == "SemanticMemory":
relationship_base = 'HAS_SEMANTIC_MEMORY' relationship_base = "HAS_SEMANTIC_MEMORY"
relationship_type = 'HAS_KNOWLEDGE' relationship_type = "HAS_KNOWLEDGE"
memory_label = 'semantic' memory_label = "semantic"
elif memory_type == 'EpisodicMemory': elif memory_type == "EpisodicMemory":
relationship_base = 'HAS_EPISODIC_MEMORY' relationship_base = "HAS_EPISODIC_MEMORY"
# relationship_type = 'EPISODIC_MEMORY' # relationship_type = 'EPISODIC_MEMORY'
relationship_type = 'HAS_EVENT' relationship_type = "HAS_EVENT"
memory_label='episodic' memory_label = "episodic"
elif memory_type == 'Buffer': elif memory_type == "Buffer":
relationship_base = 'HAS_BUFFER_MEMORY' relationship_base = "HAS_BUFFER_MEMORY"
relationship_type = 'CURRENTLY_HOLDING' relationship_type = "CURRENTLY_HOLDING"
memory_label= 'buffer' memory_label = "buffer"
query = f"""MATCH (u:User {{userId: '{user_id}'}})-[:{relationship_base}]->(sm:{memory_type})
query= f"""MATCH (u:User {{userId: '{user_id}'}})-[:{relationship_base}]->(sm:{memory_type})
MATCH (sm)-[:{relationship_type}]->(n) MATCH (sm)-[:{relationship_type}]->(n)
RETURN labels(n) AS NodeType, collect(n) AS Nodes RETURN labels(n) AS NodeType, collect(n) AS Nodes
""" """
node_results = self.query(query) node_results = self.query(query)
node_types = [record['NodeType'] for record in node_results] node_types = [record["NodeType"] for record in node_results]
for node in node_types: for node in node_types:
query = f""" query = f"""
@ -703,16 +807,18 @@ class Neo4jGraphDB(AbstractGraphDB):
- Exception: If an error occurs during the database query execution. - Exception: If an error occurs during the database query execution.
""" """
try: try:
query = f''' query = f"""
MATCH (user:User {{userId: '{user_id}'}})-[:HAS_SEMANTIC_MEMORY]->(semantic:SemanticMemory)-[:HAS_DOCUMENT]->(document:Document) MATCH (user:User {{userId: '{user_id}'}})-[:HAS_SEMANTIC_MEMORY]->(semantic:SemanticMemory)-[:HAS_DOCUMENT]->(document:Document)
WHERE document.documentCategory = '{category}' WHERE document.documentCategory = '{category}'
RETURN document.vectordbNamespace AS namespace RETURN document.vectordbNamespace AS namespace
''' """
result = self.query(query) result = self.query(query)
namespaces = [record["namespace"] for record in result] namespaces = [record["namespace"] for record in result]
return namespaces return namespaces
except Exception as e: except Exception as e:
logging.error(f"An error occurred while retrieving namespaces by document category: {str(e)}") logging.error(
f"An error occurred while retrieving namespaces by document category: {str(e)}"
)
return None return None
async def create_memory_node(self, labels, topic=None): async def create_memory_node(self, labels, topic=None):
@ -734,7 +840,7 @@ class Neo4jGraphDB(AbstractGraphDB):
topic = "PublicMemory" topic = "PublicMemory"
# Prepare labels as a string # Prepare labels as a string
label_list = ', '.join(f"'{label}'" for label in labels) label_list = ", ".join(f"'{label}'" for label in labels)
# Cypher query to find or create the memory node with the given description and labels # Cypher query to find or create the memory node with the given description and labels
memory_cypher = f""" memory_cypher = f"""
@ -746,17 +852,24 @@ class Neo4jGraphDB(AbstractGraphDB):
try: try:
result = self.query(memory_cypher) result = self.query(memory_cypher)
# Assuming the result is a list of records, where each record contains 'memoryId' # Assuming the result is a list of records, where each record contains 'memoryId'
memory_id = result[0]['memoryId'] if result else None memory_id = result[0]["memoryId"] if result else None
self.close() self.close()
return memory_id return memory_id
except Neo4jError as e: except Neo4jError as e:
logging.error(f"Error creating or finding memory node: {e}") logging.error(f"Error creating or finding memory node: {e}")
raise raise
def link_user_to_public(self, user_id: str, public_property_value: str, public_property_name: str = 'name', def link_user_to_public(
relationship_type: str = 'HAS_PUBLIC'): self,
user_id: str,
public_property_value: str,
public_property_name: str = "name",
relationship_type: str = "HAS_PUBLIC",
):
if not user_id or not public_property_value: if not user_id or not public_property_value:
raise ValueError("Valid User ID and Public property value are required for linking.") raise ValueError(
"Valid User ID and Public property value are required for linking."
)
try: try:
link_cypher = f""" link_cypher = f"""
@ -784,7 +897,9 @@ class Neo4jGraphDB(AbstractGraphDB):
logging.error(f"Error deleting {topic} memory node: {e}") logging.error(f"Error deleting {topic} memory node: {e}")
raise raise
def unlink_memory_from_user(self, memory_id: int, user_id: str, topic: str='PublicMemory') -> None: def unlink_memory_from_user(
self, memory_id: int, user_id: str, topic: str = "PublicMemory"
) -> None:
""" """
Unlink a memory node from a user node. Unlink a memory node from a user node.
@ -801,9 +916,13 @@ class Neo4jGraphDB(AbstractGraphDB):
raise ValueError("Valid User ID and Memory ID are required for unlinking.") raise ValueError("Valid User ID and Memory ID are required for unlinking.")
if topic not in ["SemanticMemory", "PublicMemory"]: if topic not in ["SemanticMemory", "PublicMemory"]:
raise ValueError("The memory_type must be either 'SemanticMemory' or 'PublicMemory'.") raise ValueError(
"The memory_type must be either 'SemanticMemory' or 'PublicMemory'."
)
relationship_type = "HAS_SEMANTIC_MEMORY" if topic == "SemanticMemory" else "HAS_PUBLIC_MEMORY" relationship_type = (
"HAS_SEMANTIC_MEMORY" if topic == "SemanticMemory" else "HAS_PUBLIC_MEMORY"
)
try: try:
unlink_cypher = f""" unlink_cypher = f"""
@ -815,7 +934,6 @@ class Neo4jGraphDB(AbstractGraphDB):
logging.error(f"Error unlinking {topic} from user: {e}") logging.error(f"Error unlinking {topic} from user: {e}")
raise raise
def link_public_memory_to_user(self, memory_id, user_id): def link_public_memory_to_user(self, memory_id, user_id):
# Link an existing Public Memory node to a User node # Link an existing Public Memory node to a User node
link_cypher = f""" link_cypher = f"""
@ -825,7 +943,7 @@ class Neo4jGraphDB(AbstractGraphDB):
""" """
self.query(link_cypher) self.query(link_cypher)
def retrieve_node_id_for_memory_type(self, topic: str = 'SemanticMemory'): def retrieve_node_id_for_memory_type(self, topic: str = "SemanticMemory"):
link_cypher = f""" MATCH(publicMemory: {topic}) link_cypher = f""" MATCH(publicMemory: {topic})
RETURN RETURN
id(publicMemory) id(publicMemory)
@ -835,18 +953,14 @@ class Neo4jGraphDB(AbstractGraphDB):
return node_ids return node_ids
from .networkx_graph import NetworkXGraphDB from .networkx_graph import NetworkXGraphDB
class GraphDBFactory: class GraphDBFactory:
def create_graph_db(self, db_type, **kwargs): def create_graph_db(self, db_type, **kwargs):
if db_type == 'neo4j': if db_type == "neo4j":
return Neo4jGraphDB(**kwargs) return Neo4jGraphDB(**kwargs)
elif db_type == 'networkx': elif db_type == "networkx":
return NetworkXGraphDB(**kwargs) return NetworkXGraphDB(**kwargs)
else: else:
raise ValueError(f"Unsupported database type: {db_type}") raise ValueError(f"Unsupported database type: {db_type}")

View file

@ -4,7 +4,7 @@ import networkx as nx
class NetworkXGraphDB: class NetworkXGraphDB:
def __init__(self, filename='networkx_graph.pkl'): def __init__(self, filename="networkx_graph.pkl"):
self.filename = filename self.filename = filename
try: try:
self.graph = self.load_graph() # Attempt to load an existing graph self.graph = self.load_graph() # Attempt to load an existing graph
@ -12,32 +12,36 @@ class NetworkXGraphDB:
self.graph = nx.Graph() # Create a new graph if loading failed self.graph = nx.Graph() # Create a new graph if loading failed
def save_graph(self): def save_graph(self):
""" Save the graph to a file using pickle """ """Save the graph to a file using pickle"""
with open(self.filename, 'wb') as f: with open(self.filename, "wb") as f:
pickle.dump(self.graph, f) pickle.dump(self.graph, f)
def load_graph(self): def load_graph(self):
""" Load the graph from a file using pickle """ """Load the graph from a file using pickle"""
with open(self.filename, 'rb') as f: with open(self.filename, "rb") as f:
return pickle.load(f) return pickle.load(f)
def create_base_cognitive_architecture(self, user_id: str): def create_base_cognitive_architecture(self, user_id: str):
# Add nodes for user and memory types if they don't exist # Add nodes for user and memory types if they don't exist
self.graph.add_node(user_id, type='User') self.graph.add_node(user_id, type="User")
self.graph.add_node(f"{user_id}_semantic", type='SemanticMemory') self.graph.add_node(f"{user_id}_semantic", type="SemanticMemory")
self.graph.add_node(f"{user_id}_episodic", type='EpisodicMemory') self.graph.add_node(f"{user_id}_episodic", type="EpisodicMemory")
self.graph.add_node(f"{user_id}_buffer", type='Buffer') self.graph.add_node(f"{user_id}_buffer", type="Buffer")
# Add edges to connect user to memory types # Add edges to connect user to memory types
self.graph.add_edge(user_id, f"{user_id}_semantic", relation='HAS_SEMANTIC_MEMORY') self.graph.add_edge(
self.graph.add_edge(user_id, f"{user_id}_episodic", relation='HAS_EPISODIC_MEMORY') user_id, f"{user_id}_semantic", relation="HAS_SEMANTIC_MEMORY"
self.graph.add_edge(user_id, f"{user_id}_buffer", relation='HAS_BUFFER') )
self.graph.add_edge(
user_id, f"{user_id}_episodic", relation="HAS_EPISODIC_MEMORY"
)
self.graph.add_edge(user_id, f"{user_id}_buffer", relation="HAS_BUFFER")
self.save_graph() # Save the graph after modifying it self.save_graph() # Save the graph after modifying it
def delete_all_user_memories(self, user_id: str): def delete_all_user_memories(self, user_id: str):
# Remove nodes and edges related to the user's memories # Remove nodes and edges related to the user's memories
for memory_type in ['semantic', 'episodic', 'buffer']: for memory_type in ["semantic", "episodic", "buffer"]:
memory_node = f"{user_id}_{memory_type}" memory_node = f"{user_id}_{memory_type}"
self.graph.remove_node(memory_node) self.graph.remove_node(memory_node)
@ -60,31 +64,59 @@ class NetworkXGraphDB:
def retrieve_buffer_memory(self, user_id: str): def retrieve_buffer_memory(self, user_id: str):
return [n for n in self.graph.neighbors(f"{user_id}_buffer")] return [n for n in self.graph.neighbors(f"{user_id}_buffer")]
def generate_graph_semantic_memory_document_summary(self, document_summary, unique_graphdb_mapping_values, document_namespace, user_id): def generate_graph_semantic_memory_document_summary(
self,
document_summary,
unique_graphdb_mapping_values,
document_namespace,
user_id,
):
for node, attributes in unique_graphdb_mapping_values.items(): for node, attributes in unique_graphdb_mapping_values.items():
self.graph.add_node(node, **attributes) self.graph.add_node(node, **attributes)
self.graph.add_edge(f"{user_id}_semantic", node, relation='HAS_KNOWLEDGE') self.graph.add_edge(f"{user_id}_semantic", node, relation="HAS_KNOWLEDGE")
self.save_graph() self.save_graph()
def generate_document_summary(self, document_summary, unique_graphdb_mapping_values, document_namespace, user_id): def generate_document_summary(
self.generate_graph_semantic_memory_document_summary(document_summary, unique_graphdb_mapping_values, document_namespace, user_id) self,
document_summary,
unique_graphdb_mapping_values,
document_namespace,
user_id,
):
self.generate_graph_semantic_memory_document_summary(
document_summary, unique_graphdb_mapping_values, document_namespace, user_id
)
async def get_document_categories(self, user_id): async def get_document_categories(self, user_id):
return [self.graph.nodes[n]['category'] for n in self.graph.neighbors(f"{user_id}_semantic") if 'category' in self.graph.nodes[n]] return [
self.graph.nodes[n]["category"]
for n in self.graph.neighbors(f"{user_id}_semantic")
if "category" in self.graph.nodes[n]
]
async def get_document_ids(self, user_id, category): async def get_document_ids(self, user_id, category):
return [n for n in self.graph.neighbors(f"{user_id}_semantic") if self.graph.nodes[n].get('category') == category] return [
n
for n in self.graph.neighbors(f"{user_id}_semantic")
if self.graph.nodes[n].get("category") == category
]
def create_document_node(self, document_summary, user_id): def create_document_node(self, document_summary, user_id):
d_id = document_summary['d_id'] d_id = document_summary["d_id"]
self.graph.add_node(d_id, **document_summary) self.graph.add_node(d_id, **document_summary)
self.graph.add_edge(f"{user_id}_semantic", d_id, relation='HAS_DOCUMENT') self.graph.add_edge(f"{user_id}_semantic", d_id, relation="HAS_DOCUMENT")
self.save_graph() self.save_graph()
def update_document_node_with_namespace(self, user_id, vectordb_namespace, document_id): def update_document_node_with_namespace(
self, user_id, vectordb_namespace, document_id
):
if self.graph.has_node(document_id): if self.graph.has_node(document_id):
self.graph.nodes[document_id]['vectordbNamespace'] = vectordb_namespace self.graph.nodes[document_id]["vectordbNamespace"] = vectordb_namespace
self.save_graph() self.save_graph()
def get_namespaces_by_document_category(self, user_id, category): def get_namespaces_by_document_category(self, user_id, category):
return [self.graph.nodes[n].get('vectordbNamespace') for n in self.graph.neighbors(f"{user_id}_semantic") if self.graph.nodes[n].get('category') == category] return [
self.graph.nodes[n].get("vectordbNamespace")
for n in self.graph.neighbors(f"{user_id}_semantic")
if self.graph.nodes[n].get("category") == category
]

View file

@ -31,36 +31,47 @@ import os
class DatabaseConfig: class DatabaseConfig:
def __init__(self, db_type=None, db_name=None, host=None, user=None, password=None, port=None, config_file=None): def __init__(
self,
db_type=None,
db_name=None,
host=None,
user=None,
password=None,
port=None,
config_file=None,
):
if config_file: if config_file:
self.load_from_file(config_file) self.load_from_file(config_file)
else: else:
# Load default values from environment variables or use provided values # Load default values from environment variables or use provided values
self.db_type = db_type or os.getenv('DB_TYPE', 'sqlite') self.db_type = db_type or os.getenv("DB_TYPE", "sqlite")
self.db_name = db_name or os.getenv('DB_NAME', 'database.db') self.db_name = db_name or os.getenv("DB_NAME", "database.db")
self.host = host or os.getenv('DB_HOST', 'localhost') self.host = host or os.getenv("DB_HOST", "localhost")
self.user = user or os.getenv('DB_USER', 'user') self.user = user or os.getenv("DB_USER", "user")
self.password = password or os.getenv('DB_PASSWORD', 'password') self.password = password or os.getenv("DB_PASSWORD", "password")
self.port = port or os.getenv('DB_PORT', '5432') self.port = port or os.getenv("DB_PORT", "5432")
def load_from_file(self, file_path): def load_from_file(self, file_path):
with open(file_path, 'r') as file: with open(file_path, "r") as file:
config = json.load(file) config = json.load(file)
self.db_type = config.get('db_type', 'sqlite') self.db_type = config.get("db_type", "sqlite")
self.db_name = config.get('db_name', 'database.db') self.db_name = config.get("db_name", "database.db")
self.host = config.get('host', 'localhost') self.host = config.get("host", "localhost")
self.user = config.get('user', 'user') self.user = config.get("user", "user")
self.password = config.get('password', 'password') self.password = config.get("password", "password")
self.port = config.get('port', '5432') self.port = config.get("port", "5432")
def get_sqlalchemy_database_url(self): def get_sqlalchemy_database_url(self):
if self.db_type == 'sqlite': if self.db_type == "sqlite":
db_path = Path(self.db_name).absolute() # Ensure the path is absolute db_path = Path(self.db_name).absolute() # Ensure the path is absolute
return f"sqlite+aiosqlite:///{db_path}" # SQLite uses file path return f"sqlite+aiosqlite:///{db_path}" # SQLite uses file path
elif self.db_type == 'duckdb': elif self.db_type == "duckdb":
db_path = Path(self.db_name).absolute() # Ensure the path is absolute for DuckDB as well db_path = Path(
self.db_name
).absolute() # Ensure the path is absolute for DuckDB as well
return f"duckdb+aiosqlite:///{db_path}" return f"duckdb+aiosqlite:///{db_path}"
elif self.db_type == 'postgresql': elif self.db_type == "postgresql":
# Ensure optional parameters are handled gracefully # Ensure optional parameters are handled gracefully
port_str = f":{self.port}" if self.port else "" port_str = f":{self.port}" if self.port else ""
password_str = f":{self.password}" if self.password else "" password_str = f":{self.password}" if self.password else ""
@ -68,10 +79,18 @@ class DatabaseConfig:
else: else:
raise ValueError(f"Unsupported DB_TYPE: {self.db_type}") raise ValueError(f"Unsupported DB_TYPE: {self.db_type}")
# Example usage with a configuration file: # Example usage with a configuration file:
# config = DatabaseConfig(config_file='path/to/config.json') # config = DatabaseConfig(config_file='path/to/config.json')
# Or set them programmatically: # Or set them programmatically:
config = DatabaseConfig(db_type='postgresql', db_name='mydatabase', user='myuser', password='mypassword', host='myhost', port='5432') config = DatabaseConfig(
db_type="postgresql",
db_name="mydatabase",
user="myuser",
password="mypassword",
host="myhost",
port="5432",
)
SQLALCHEMY_DATABASE_URL = config.get_sqlalchemy_database_url() SQLALCHEMY_DATABASE_URL = config.get_sqlalchemy_database_url()
@ -79,7 +98,7 @@ SQLALCHEMY_DATABASE_URL = config.get_sqlalchemy_database_url()
engine = create_async_engine( engine = create_async_engine(
SQLALCHEMY_DATABASE_URL, SQLALCHEMY_DATABASE_URL,
pool_recycle=3600, pool_recycle=3600,
echo=True # Enable logging for tutorial purposes echo=True, # Enable logging for tutorial purposes
) )
# Use AsyncSession for the session # Use AsyncSession for the session
AsyncSessionLocal = sessionmaker( AsyncSessionLocal = sessionmaker(
@ -90,6 +109,7 @@ AsyncSessionLocal = sessionmaker(
Base = declarative_base() Base = declarative_base()
# Use asynccontextmanager to define an async context manager # Use asynccontextmanager to define an async context manager
@asynccontextmanager @asynccontextmanager
async def get_db(): async def get_db():
@ -99,6 +119,7 @@ async def get_db():
finally: finally:
await db.close() await db.close()
# #
# if os.environ.get('AWS_ENV') == 'prd' or os.environ.get('AWS_ENV') == 'dev': # if os.environ.get('AWS_ENV') == 'prd' or os.environ.get('AWS_ENV') == 'dev':
# host = os.environ.get('POSTGRES_HOST') # host = os.environ.get('POSTGRES_HOST')
@ -127,4 +148,3 @@ async def get_db():
# #
# # Use the asyncpg driver for async operation # # Use the asyncpg driver for async operation
# SQLALCHEMY_DATABASE_URL = f"postgresql+asyncpg://{username}:{password}@{host}:5432/{database_name}" # SQLALCHEMY_DATABASE_URL = f"postgresql+asyncpg://{username}:{password}@{host}:5432/{database_name}"

View file

@ -1,4 +1,3 @@
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import logging import logging
from .models.sessions import Session from .models.sessions import Session
@ -9,9 +8,9 @@ from .models.metadatas import MetaDatas
from .models.docs import DocsModel from .models.docs import DocsModel
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@asynccontextmanager @asynccontextmanager
async def session_scope(session): async def session_scope(session):
"""Provide a transactional scope around a series of operations.""" """Provide a transactional scope around a series of operations."""
@ -44,6 +43,8 @@ def update_entity_graph_summary(session, model, entity_id, new_value):
return "Successfully updated entity" return "Successfully updated entity"
else: else:
return "Entity not found" return "Entity not found"
async def update_entity(session, model, entity_id, new_value): async def update_entity(session, model, entity_id, new_value):
async with session_scope(session) as s: async with session_scope(session) as s:
# Retrieve the entity from the database # Retrieve the entity from the database

View file

@ -1,20 +1,20 @@
from datetime import datetime from datetime import datetime
from sqlalchemy import Column, String, DateTime, ForeignKey, Boolean from sqlalchemy import Column, String, DateTime, ForeignKey, Boolean
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
import os import os
import sys import sys
from ..database import Base from ..database import Base
class DocsModel(Base): class DocsModel(Base):
__tablename__ = 'docs' __tablename__ = "docs"
id = Column(String, primary_key=True) id = Column(String, primary_key=True)
operation_id = Column(String, ForeignKey('operations.id'), index=True) operation_id = Column(String, ForeignKey("operations.id"), index=True)
doc_name = Column(String, nullable=True) doc_name = Column(String, nullable=True)
graph_summary = Column(Boolean, nullable=True) graph_summary = Column(Boolean, nullable=True)
memory_category = Column(String, nullable=True) memory_category = Column(String, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow) created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, onupdate=datetime.utcnow) updated_at = Column(DateTime, onupdate=datetime.utcnow)
operations = relationship("Operation", back_populates="docs") operations = relationship("Operation", back_populates="docs")

View file

@ -4,23 +4,27 @@ from sqlalchemy import Column, String, DateTime, ForeignKey
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
import os import os
import sys import sys
from ..database import Base from ..database import Base
class MemoryModel(Base): class MemoryModel(Base):
__tablename__ = 'memories' __tablename__ = "memories"
id = Column(String, primary_key=True) id = Column(String, primary_key=True)
user_id = Column(String, ForeignKey('users.id'), index=True) user_id = Column(String, ForeignKey("users.id"), index=True)
operation_id = Column(String, ForeignKey('operations.id'), index=True) operation_id = Column(String, ForeignKey("operations.id"), index=True)
memory_name = Column(String, nullable=True) memory_name = Column(String, nullable=True)
memory_category = Column(String, nullable=True) memory_category = Column(String, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow) created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, onupdate=datetime.utcnow) updated_at = Column(DateTime, onupdate=datetime.utcnow)
methods_list = Column(String , nullable=True) methods_list = Column(String, nullable=True)
attributes_list = Column(String, nullable=True) attributes_list = Column(String, nullable=True)
user = relationship("User", back_populates="memories") user = relationship("User", back_populates="memories")
operation = relationship("Operation", back_populates="memories") operation = relationship("Operation", back_populates="memories")
metadatas = relationship("MetaDatas", back_populates="memory", cascade="all, delete-orphan") metadatas = relationship(
"MetaDatas", back_populates="memory", cascade="all, delete-orphan"
)
def __repr__(self): def __repr__(self):
return f"<Memory(id={self.id}, user_id={self.user_id}, created_at={self.created_at}, updated_at={self.updated_at})>" return f"<Memory(id={self.id}, user_id={self.user_id}, created_at={self.created_at}, updated_at={self.updated_at})>"

View file

@ -4,17 +4,19 @@ from sqlalchemy import Column, String, DateTime, ForeignKey
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
import os import os
import sys import sys
from ..database import Base from ..database import Base
class MetaDatas(Base): class MetaDatas(Base):
__tablename__ = 'metadatas' __tablename__ = "metadatas"
id = Column(String, primary_key=True) id = Column(String, primary_key=True)
user_id = Column(String, ForeignKey('users.id'), index=True) user_id = Column(String, ForeignKey("users.id"), index=True)
version = Column(String, nullable=False) version = Column(String, nullable=False)
contract_metadata = Column(String, nullable=False) contract_metadata = Column(String, nullable=False)
memory_id = Column(String, ForeignKey('memories.id'), index=True) memory_id = Column(String, ForeignKey("memories.id"), index=True)
created_at = Column(DateTime, default=datetime.utcnow) created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, onupdate=datetime.utcnow) updated_at = Column(DateTime, onupdate=datetime.utcnow)
user = relationship("User", back_populates="metadatas") user = relationship("User", back_populates="metadatas")
memory = relationship("MemoryModel", back_populates="metadatas") memory = relationship("MemoryModel", back_populates="metadatas")

View file

@ -4,13 +4,14 @@ from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
import os import os
import sys import sys
from ..database import Base from ..database import Base
class Operation(Base): class Operation(Base):
__tablename__ = 'operations' __tablename__ = "operations"
id = Column(String, primary_key=True) id = Column(String, primary_key=True)
user_id = Column(String, ForeignKey('users.id'), index=True) # Link to User user_id = Column(String, ForeignKey("users.id"), index=True) # Link to User
operation_type = Column(String, nullable=True) operation_type = Column(String, nullable=True)
operation_status = Column(String, nullable=True) operation_status = Column(String, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow) created_at = Column(DateTime, default=datetime.utcnow)

View file

@ -9,10 +9,10 @@ from ..database import Base
class Session(Base): class Session(Base):
__tablename__ = 'sessions' __tablename__ = "sessions"
id = Column(String, primary_key=True) id = Column(String, primary_key=True)
user_id = Column(String, ForeignKey('users.id'), index=True) user_id = Column(String, ForeignKey("users.id"), index=True)
created_at = Column(DateTime, default=datetime.utcnow) created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, onupdate=datetime.utcnow) updated_at = Column(DateTime, onupdate=datetime.utcnow)

View file

@ -4,17 +4,17 @@ from sqlalchemy import Column, String, DateTime
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
import os import os
import sys import sys
from .memory import MemoryModel from .memory import MemoryModel
from .operation import Operation from .operation import Operation
from .sessions import Session from .sessions import Session
from .metadatas import MetaDatas from .metadatas import MetaDatas
from .docs import DocsModel from .docs import DocsModel
from ..database import Base from ..database import Base
class User(Base): class User(Base):
__tablename__ = 'users' __tablename__ = "users"
id = Column(String, primary_key=True, index=True) id = Column(String, primary_key=True, index=True)
session_id = Column(String, nullable=True, unique=True) session_id = Column(String, nullable=True, unique=True)
@ -22,9 +22,15 @@ class User(Base):
updated_at = Column(DateTime, onupdate=datetime.utcnow) updated_at = Column(DateTime, onupdate=datetime.utcnow)
# Relationships # Relationships
memories = relationship("MemoryModel", back_populates="user", cascade="all, delete-orphan") memories = relationship(
operations = relationship("Operation", back_populates="user", cascade="all, delete-orphan") "MemoryModel", back_populates="user", cascade="all, delete-orphan"
sessions = relationship("Session", back_populates="user", cascade="all, delete-orphan") )
operations = relationship(
"Operation", back_populates="user", cascade="all, delete-orphan"
)
sessions = relationship(
"Session", back_populates="user", cascade="all, delete-orphan"
)
metadatas = relationship("MetaDatas", back_populates="user") metadatas = relationship("MetaDatas", back_populates="user")
def __repr__(self): def __repr__(self):

View file

@ -1,11 +1,12 @@
import logging import logging
from io import BytesIO from io import BytesIO
import os, sys import os, sys
# Add the parent directory to sys.path # Add the parent directory to sys.path
# sys.path.append(os.path.dirname(os.path.abspath(__file__))) # sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import sqlalchemy as sa import sqlalchemy as sa
print(os.getcwd()) print(os.getcwd())
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
# import marvin # import marvin
@ -22,6 +23,7 @@ from cognitive_architecture.database.relationaldb.models.operation import Operat
from cognitive_architecture.database.relationaldb.models.docs import DocsModel from cognitive_architecture.database.relationaldb.models.docs import DocsModel
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from cognitive_architecture.database.relationaldb.database import engine from cognitive_architecture.database.relationaldb.database import engine
load_dotenv() load_dotenv()
from typing import Optional from typing import Optional
import time import time
@ -31,7 +33,11 @@ tracemalloc.start()
from datetime import datetime from datetime import datetime
from langchain.embeddings.openai import OpenAIEmbeddings from langchain.embeddings.openai import OpenAIEmbeddings
from cognitive_architecture.database.vectordb.vectordb import PineconeVectorDB, WeaviateVectorDB, LanceDB from cognitive_architecture.database.vectordb.vectordb import (
PineconeVectorDB,
WeaviateVectorDB,
LanceDB,
)
from langchain.schema import Document from langchain.schema import Document
import uuid import uuid
import weaviate import weaviate
@ -43,6 +49,7 @@ from vector_db_type import VectorDBType
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "") OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
# marvin.settings.openai.api_key = os.environ.get("OPENAI_API_KEY") # marvin.settings.openai.api_key = os.environ.get("OPENAI_API_KEY")
class VectorDBFactory: class VectorDBFactory:
def __init__(self): def __init__(self):
self.db_map = { self.db_map = {
@ -63,15 +70,12 @@ class VectorDBFactory:
): ):
if db_type in self.db_map: if db_type in self.db_map:
return self.db_map[db_type]( return self.db_map[db_type](
user_id, user_id, index_name, memory_id, namespace, embeddings
index_name,
memory_id,
namespace,
embeddings
) )
raise ValueError(f"Unsupported database type: {db_type}") raise ValueError(f"Unsupported database type: {db_type}")
class BaseMemory: class BaseMemory:
def __init__( def __init__(
self, self,
@ -95,21 +99,18 @@ class BaseMemory:
self.memory_id, self.memory_id,
db_type=self.db_type, db_type=self.db_type,
namespace=self.namespace, namespace=self.namespace,
embeddings=self.embeddings embeddings=self.embeddings,
) )
def init_client(self, embeddings, namespace: str): def init_client(self, embeddings, namespace: str):
return self.vector_db.init_client(embeddings, namespace) return self.vector_db.init_client(embeddings, namespace)
def create_field(self, field_type, **kwargs): def create_field(self, field_type, **kwargs):
field_mapping = { field_mapping = {
"Str": fields.Str, "Str": fields.Str,
"Int": fields.Int, "Int": fields.Int,
"Float": fields.Float, "Float": fields.Float,
"Bool": fields.Bool, "Bool": fields.Bool,
} }
return field_mapping[field_type](**kwargs) return field_mapping[field_type](**kwargs)
@ -121,7 +122,6 @@ class BaseMemory:
dynamic_schema_instance = Schema.from_dict(dynamic_fields)() dynamic_schema_instance = Schema.from_dict(dynamic_fields)()
return dynamic_schema_instance return dynamic_schema_instance
async def get_version_from_db(self, user_id, memory_id): async def get_version_from_db(self, user_id, memory_id):
# Logic to retrieve the version from the database. # Logic to retrieve the version from the database.
@ -137,11 +137,11 @@ class BaseMemory:
) )
if result: if result:
version_in_db, created_at = result version_in_db, created_at = result
logging.info(f"version_in_db: {version_in_db}") logging.info(f"version_in_db: {version_in_db}")
from ast import literal_eval from ast import literal_eval
version_in_db= literal_eval(version_in_db)
version_in_db = literal_eval(version_in_db)
version_in_db = version_in_db.get("version") version_in_db = version_in_db.get("version")
return [version_in_db, created_at] return [version_in_db, created_at]
else: else:
@ -157,20 +157,33 @@ class BaseMemory:
# If there is no metadata, insert it. # If there is no metadata, insert it.
if version_from_db is None: if version_from_db is None:
session.add(
session.add(MetaDatas(id = str(uuid.uuid4()), user_id=self.user_id, version = str(int(time.time())) ,memory_id=self.memory_id, contract_metadata=params)) MetaDatas(
id=str(uuid.uuid4()),
user_id=self.user_id,
version=str(int(time.time())),
memory_id=self.memory_id,
contract_metadata=params,
)
)
session.commit() session.commit()
return params return params
# If params version is higher, update the metadata. # If params version is higher, update the metadata.
elif version_in_params > version_from_db[0]: elif version_in_params > version_from_db[0]:
session.add(MetaDatas(id = str(uuid.uuid4()), user_id=self.user_id, memory_id=self.memory_id, contract_metadata=params)) session.add(
MetaDatas(
id=str(uuid.uuid4()),
user_id=self.user_id,
memory_id=self.memory_id,
contract_metadata=params,
)
)
session.commit() session.commit()
return params return params
else: else:
return params return params
async def add_memories( async def add_memories(
self, self,
observation: Optional[str] = None, observation: Optional[str] = None,
@ -179,11 +192,14 @@ class BaseMemory:
namespace: Optional[str] = None, namespace: Optional[str] = None,
custom_fields: Optional[str] = None, custom_fields: Optional[str] = None,
embeddings: Optional[str] = None, embeddings: Optional[str] = None,
): ):
return await self.vector_db.add_memories( return await self.vector_db.add_memories(
observation=observation, loader_settings=loader_settings, observation=observation,
params=params, namespace=namespace, metadata_schema_class = None, embeddings=embeddings loader_settings=loader_settings,
params=params,
namespace=namespace,
metadata_schema_class=None,
embeddings=embeddings,
) )
# Add other db_type conditions if necessary # Add other db_type conditions if necessary
@ -200,17 +216,15 @@ class BaseMemory:
logging.info(observation) logging.info(observation)
return await self.vector_db.fetch_memories( return await self.vector_db.fetch_memories(
observation=observation, search_type= search_type, params=params, observation=observation,
search_type=search_type,
params=params,
namespace=namespace, namespace=namespace,
n_of_observations=n_of_observations n_of_observations=n_of_observations,
) )
async def delete_memories(self, namespace:str, params: Optional[str] = None): async def delete_memories(self, namespace: str, params: Optional[str] = None):
return await self.vector_db.delete_memories(namespace,params) return await self.vector_db.delete_memories(namespace, params)
async def count_memories(self, namespace:str, params: Optional[str] = None):
return await self.vector_db.count_memories(namespace,params)
async def count_memories(self, namespace: str, params: Optional[str] = None):
return await self.vector_db.count_memories(namespace, params)

View file

@ -1,8 +1,10 @@
from enum import Enum from enum import Enum
class ChunkStrategy(Enum): class ChunkStrategy(Enum):
EXACT = 'exact' """Chunking strategies for the vector database."""
PARAGRAPH = 'paragraph' EXACT = "exact"
SENTENCE = 'sentence' PARAGRAPH = "paragraph"
VANILLA = 'vanilla' SENTENCE = "sentence"
SUMMARY = 'summary' VANILLA = "vanilla"
SUMMARY = "summary"

View file

@ -1,13 +1,17 @@
from cognitive_architecture.database.vectordb.chunkers.chunk_strategy import ChunkStrategy """Module for chunking text data based on various strategies."""
import re
def chunk_data(chunk_strategy=None, source_data=None, chunk_size=None, chunk_overlap=None):
import re
import logging
from cognitive_architecture.database.vectordb.chunkers.chunk_strategy import ChunkStrategy
from langchain.text_splitter import RecursiveCharacterTextSplitter
def chunk_data(chunk_strategy=None, source_data=None, chunk_size=None, chunk_overlap=None):
"""Chunk the given source data into smaller parts based on the specified strategy."""
if chunk_strategy == ChunkStrategy.VANILLA: if chunk_strategy == ChunkStrategy.VANILLA:
chunked_data = vanilla_chunker(source_data, chunk_size, chunk_overlap) chunked_data = vanilla_chunker(source_data, chunk_size, chunk_overlap)
elif chunk_strategy == ChunkStrategy.PARAGRAPH: elif chunk_strategy == ChunkStrategy.PARAGRAPH:
chunked_data = chunk_data_by_paragraph(source_data,chunk_size, chunk_overlap) chunked_data = chunk_data_by_paragraph(source_data, chunk_size, chunk_overlap)
elif chunk_strategy == ChunkStrategy.SENTENCE: elif chunk_strategy == ChunkStrategy.SENTENCE:
chunked_data = chunk_by_sentence(source_data, chunk_size, chunk_overlap) chunked_data = chunk_by_sentence(source_data, chunk_size, chunk_overlap)
elif chunk_strategy == ChunkStrategy.EXACT: elif chunk_strategy == ChunkStrategy.EXACT:
@ -21,68 +25,41 @@ def chunk_data(chunk_strategy=None, source_data=None, chunk_size=None, chunk_ove
def vanilla_chunker(source_data, chunk_size=100, chunk_overlap=20): def vanilla_chunker(source_data, chunk_size=100, chunk_overlap=20):
# adapt this for different chunking strategies """Chunk the given source data into smaller parts using a vanilla strategy."""
from langchain.text_splitter import RecursiveCharacterTextSplitter text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size
text_splitter = RecursiveCharacterTextSplitter( , chunk_overlap=chunk_overlap
# Set a really small chunk size, just to show. , length_function=len)
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
length_function=len
)
# try:
# pages = text_splitter.create_documents([source_data])
# except:
# try:
pages = text_splitter.create_documents([source_data]) pages = text_splitter.create_documents([source_data])
# except:
# pages = text_splitter.create_documents(source_data.content)
# pages = source_data.load_and_split()
return pages return pages
def summary_chunker(source_data, chunk_size=400, chunk_overlap=20): def summary_chunker(source_data, chunk_size=400, chunk_overlap=20):
""" """Chunk the given source data into smaller parts, focusing on summarizing content."""
Chunk the given source data into smaller parts, returning the first five and last five chunks. text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size
, chunk_overlap=chunk_overlap
Parameters: , length_function=len)
- source_data (str): The source data to be chunked.
- chunk_size (int): The size of each chunk.
- chunk_overlap (int): The overlap between consecutive chunks.
Returns:
- List: A list containing the first five and last five chunks of the chunked source data.
"""
from langchain.text_splitter import RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
length_function=len
)
try: try:
pages = text_splitter.create_documents([source_data]) pages = text_splitter.create_documents([source_data])
except: except Exception as e:
pages = text_splitter.create_documents(source_data.content) pages = text_splitter.create_documents(source_data.content)
logging.error(f"An error occurred: %s {str(e)}")
# Return the first 5 and last 5 chunks
if len(pages) > 10: if len(pages) > 10:
return pages[:5] + pages[-5:] return pages[:5] + pages[-5:]
else: return pages
return pages # Return all chunks if there are 10 or fewer
def chunk_data_exact(data_chunks, chunk_size, chunk_overlap): def chunk_data_exact(data_chunks, chunk_size, chunk_overlap):
"""Chunk the data into exact sizes as specified, without considering content."""
data = "".join(data_chunks) data = "".join(data_chunks)
chunks = [] chunks = [data[i:i + chunk_size] for i in range(0, len(data), chunk_size - chunk_overlap)]
for i in range(0, len(data), chunk_size - chunk_overlap):
chunks.append(data[i:i + chunk_size])
return chunks return chunks
def chunk_by_sentence(data_chunks, chunk_size, overlap): def chunk_by_sentence(data_chunks, chunk_size, overlap):
# Split by periods, question marks, exclamation marks, and ellipses """Chunk the data by sentences, ensuring each chunk does not exceed the specified size."""
data = "".join(data_chunks) data = "".join(data_chunks)
sentence_endings = r"(?<=[.!?…]) +"
# The regular expression is used to find series of charaters that end with one the following chaacters (. ! ? ...)
sentence_endings = r'(?<=[.!?…]) +'
sentences = re.split(sentence_endings, data) sentences = re.split(sentence_endings, data)
sentence_chunks = [] sentence_chunks = []
@ -96,6 +73,7 @@ def chunk_by_sentence(data_chunks, chunk_size, overlap):
def chunk_data_by_paragraph(data_chunks, chunk_size, overlap, bound=0.75): def chunk_data_by_paragraph(data_chunks, chunk_size, overlap, bound=0.75):
"""Chunk the data by paragraphs, with consideration for chunk size and overlap."""
data = "".join(data_chunks) data = "".join(data_chunks)
total_length = len(data) total_length = len(data)
chunks = [] chunks = []
@ -103,20 +81,13 @@ def chunk_data_by_paragraph(data_chunks, chunk_size, overlap, bound=0.75):
start_idx = 0 start_idx = 0
while start_idx < total_length: while start_idx < total_length:
# Set the end index to the minimum of start_idx + default_chunk_size or total_length
end_idx = min(start_idx + chunk_size, total_length) end_idx = min(start_idx + chunk_size, total_length)
next_paragraph_index = data.find("\n\n", start_idx + check_bound, end_idx)
# Find the next paragraph index within the current chunk and bound
next_paragraph_index = data.find('\n\n', start_idx + check_bound, end_idx)
# If a next paragraph index is found within the current chunk
if next_paragraph_index != -1: if next_paragraph_index != -1:
# Update end_idx to include the paragraph delimiter
end_idx = next_paragraph_index + 2 end_idx = next_paragraph_index + 2
chunks.append(data[start_idx:end_idx + overlap]) chunks.append(data[start_idx:end_idx + overlap])
# Update start_idx to be the current end_idx
start_idx = end_idx start_idx = end_idx
return chunks return chunks

View file

@ -7,17 +7,20 @@ from .response import Response
class CogneeManager: class CogneeManager:
def __init__(self, embeddings: Embeddings = None, def __init__(
vector_db: VectorDB = None, self,
vector_db_key: str = None, embeddings: Embeddings = None,
embedding_api_key: str = None, vector_db: VectorDB = None,
webhook_url: str = None, vector_db_key: str = None,
lines_per_batch: int = 1000, embedding_api_key: str = None,
webhook_key: str = None, webhook_url: str = None,
document_id: str = None, lines_per_batch: int = 1000,
chunk_validation_url: str = None, webhook_key: str = None,
internal_api_key: str = "test123", document_id: str = None,
base_url="http://localhost:8000"): chunk_validation_url: str = None,
internal_api_key: str = "test123",
base_url="http://localhost:8000",
):
self.embeddings = embeddings if embeddings else Embeddings() self.embeddings = embeddings if embeddings else Embeddings()
self.vector_db = vector_db if vector_db else VectorDB() self.vector_db = vector_db if vector_db else VectorDB()
self.webhook_url = webhook_url self.webhook_url = webhook_url
@ -32,12 +35,12 @@ class CogneeManager:
def serialize(self): def serialize(self):
data = { data = {
'EmbeddingsMetadata': json.dumps(self.embeddings.serialize()), "EmbeddingsMetadata": json.dumps(self.embeddings.serialize()),
'VectorDBMetadata': json.dumps(self.vector_db.serialize()), "VectorDBMetadata": json.dumps(self.vector_db.serialize()),
'WebhookURL': self.webhook_url, "WebhookURL": self.webhook_url,
'LinesPerBatch': self.lines_per_batch, "LinesPerBatch": self.lines_per_batch,
'DocumentID': self.document_id, "DocumentID": self.document_id,
'ChunkValidationURL': self.chunk_validation_url, "ChunkValidationURL": self.chunk_validation_url,
} }
return {k: v for k, v in data.items() if v is not None} return {k: v for k, v in data.items() if v is not None}
@ -49,11 +52,22 @@ class CogneeManager:
data = self.serialize() data = self.serialize()
headers = self.generate_headers() headers = self.generate_headers()
multipart_form_data = [('file', (os.path.basename(filepath), open(filepath, 'rb'), 'application/octet-stream')) multipart_form_data = [
for filepath in file_paths] (
"file",
(
os.path.basename(filepath),
open(filepath, "rb"),
"application/octet-stream",
),
)
for filepath in file_paths
]
print(f"embedding {len(file_paths)} documents at {url}") print(f"embedding {len(file_paths)} documents at {url}")
response = requests.post(url, files=multipart_form_data, headers=headers, stream=True, data=data) response = requests.post(
url, files=multipart_form_data, headers=headers, stream=True, data=data
)
if response.status_code == 500: if response.status_code == 500:
print(response.text) print(response.text)
@ -75,9 +89,7 @@ class CogneeManager:
"Authorization": self.internal_api_key, "Authorization": self.internal_api_key,
} }
data = { data = {"JobIDs": job_ids}
'JobIDs': job_ids
}
print(f"retrieving job statuses for {len(job_ids)} jobs at {url}") print(f"retrieving job statuses for {len(job_ids)} jobs at {url}")
response = requests.post(url, headers=headers, json=data) response = requests.post(url, headers=headers, json=data)
@ -101,9 +113,7 @@ class CogneeManager:
data = self.serialize() data = self.serialize()
headers = self.generate_headers() headers = self.generate_headers()
files = { files = {"SourceData": open(filepath, "rb")}
'SourceData': open(filepath, 'rb')
}
print(f"embedding document at file path {filepath} at {url}") print(f"embedding document at file path {filepath} at {url}")
response = requests.post(url, headers=headers, data=data, files=files) response = requests.post(url, headers=headers, data=data, files=files)
@ -146,6 +156,6 @@ class CogneeManager:
"Authorization": self.internal_api_key, "Authorization": self.internal_api_key,
"X-EmbeddingAPI-Key": self.embeddings_api_key, "X-EmbeddingAPI-Key": self.embeddings_api_key,
"X-VectorDB-Key": self.vector_db_key, "X-VectorDB-Key": self.vector_db_key,
"X-Webhook-Key": self.webhook_key "X-Webhook-Key": self.webhook_key,
} }
return {k: v for k, v in headers.items() if v is not None} return {k: v for k, v in headers.items() if v is not None}

View file

@ -3,12 +3,15 @@ from ..chunkers.chunk_strategy import ChunkStrategy
class Embeddings: class Embeddings:
def __init__(self, embeddings_type: EmbeddingsType = EmbeddingsType.OPEN_AI, def __init__(
chunk_size: int = 256, self,
chunk_overlap: int = 128, embeddings_type: EmbeddingsType = EmbeddingsType.OPEN_AI,
chunk_strategy: ChunkStrategy = ChunkStrategy.EXACT, chunk_size: int = 256,
docker_image: str = None, chunk_overlap: int = 128,
hugging_face_model_name: str = None): chunk_strategy: ChunkStrategy = ChunkStrategy.EXACT,
docker_image: str = None,
hugging_face_model_name: str = None,
):
self.embeddings_type = embeddings_type self.embeddings_type = embeddings_type
self.chunk_size = chunk_size self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap self.chunk_overlap = chunk_overlap
@ -18,12 +21,14 @@ class Embeddings:
def serialize(self): def serialize(self):
data = { data = {
'embeddings_type': self.embeddings_type.name if self.embeddings_type else None, "embeddings_type": self.embeddings_type.name
'chunk_size': self.chunk_size, if self.embeddings_type
'chunk_overlap': self.chunk_overlap, else None,
'chunk_strategy': self.chunk_strategy.name if self.chunk_strategy else None, "chunk_size": self.chunk_size,
'docker_image': self.docker_image, "chunk_overlap": self.chunk_overlap,
'hugging_face_model_name': self.hugging_face_model_name "chunk_strategy": self.chunk_strategy.name if self.chunk_strategy else None,
"docker_image": self.docker_image,
"hugging_face_model_name": self.hugging_face_model_name,
} }
return {k: v for k, v in data.items() if v is not None} return {k: v for k, v in data.items() if v is not None}

View file

@ -1,8 +1,9 @@
from enum import Enum from enum import Enum
class EmbeddingsType(Enum): class EmbeddingsType(Enum):
OPEN_AI = 'open_ai' OPEN_AI = "open_ai"
COHERE = 'cohere' COHERE = "cohere"
SELF_HOSTED = 'self_hosted' SELF_HOSTED = "self_hosted"
HUGGING_FACE = 'hugging_face' HUGGING_FACE = "hugging_face"
IMAGE = 'image' IMAGE = "image"

View file

@ -4,7 +4,10 @@ import os
import sys import sys
from cognitive_architecture.database.vectordb.chunkers.chunkers import chunk_data from cognitive_architecture.database.vectordb.chunkers.chunkers import chunk_data
from cognitive_architecture.shared.language_processing import translate_text, detect_language from cognitive_architecture.shared.language_processing import (
translate_text,
detect_language,
)
from langchain.document_loaders import UnstructuredURLLoader from langchain.document_loaders import UnstructuredURLLoader
from langchain.document_loaders import DirectoryLoader from langchain.document_loaders import DirectoryLoader
@ -15,28 +18,36 @@ import requests
async def fetch_pdf_content(file_url): async def fetch_pdf_content(file_url):
response = requests.get(file_url) response = requests.get(file_url)
pdf_stream = BytesIO(response.content) pdf_stream = BytesIO(response.content)
with fitz.open(stream=pdf_stream, filetype='pdf') as doc: with fitz.open(stream=pdf_stream, filetype="pdf") as doc:
return "".join(page.get_text() for page in doc) return "".join(page.get_text() for page in doc)
async def fetch_text_content(file_url): async def fetch_text_content(file_url):
loader = UnstructuredURLLoader(urls=file_url) loader = UnstructuredURLLoader(urls=file_url)
return loader.load() return loader.load()
async def process_content(content, metadata, loader_strategy, chunk_size, chunk_overlap):
pages = chunk_data(chunk_strategy=loader_strategy, source_data=content, chunk_size=chunk_size, async def process_content(
chunk_overlap=chunk_overlap) content, metadata, loader_strategy, chunk_size, chunk_overlap
):
pages = chunk_data(
chunk_strategy=loader_strategy,
source_data=content,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
if metadata is None: if metadata is None:
metadata = {"metadata": "None"} metadata = {"metadata": "None"}
chunk_count= 0 chunk_count = 0
for chunk in pages: for chunk in pages:
chunk_count+=1 chunk_count += 1
chunk.metadata = metadata chunk.metadata = metadata
chunk.metadata["chunk_count"]=chunk_count chunk.metadata["chunk_count"] = chunk_count
if detect_language(pages) != "en": if detect_language(pages) != "en":
logging.info("Translating Page") logging.info("Translating Page")
for page in pages: for page in pages:
@ -45,6 +56,7 @@ async def process_content(content, metadata, loader_strategy, chunk_size, chunk
return pages return pages
async def _document_loader(observation: str, loader_settings: dict): async def _document_loader(observation: str, loader_settings: dict):
document_format = loader_settings.get("format", "text") document_format = loader_settings.get("format", "text")
loader_strategy = loader_settings.get("strategy", "VANILLA") loader_strategy = loader_settings.get("strategy", "VANILLA")
@ -65,7 +77,13 @@ async def _document_loader(observation: str, loader_settings: dict):
else: else:
raise ValueError(f"Unsupported document format: {document_format}") raise ValueError(f"Unsupported document format: {document_format}")
pages = await process_content(content, metadata=None, loader_strategy=loader_strategy, chunk_size= chunk_size, chunk_overlap= chunk_overlap) pages = await process_content(
content,
metadata=None,
loader_strategy=loader_strategy,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
chunked_doc.append(pages) chunked_doc.append(pages)
elif loader_settings.get("source") == "DEVICE": elif loader_settings.get("source") == "DEVICE":
@ -76,17 +94,28 @@ async def _document_loader(observation: str, loader_settings: dict):
documents = loader.load() documents = loader.load()
for document in documents: for document in documents:
# print ("Document: ", document.page_content) # print ("Document: ", document.page_content)
pages = await process_content(content= str(document.page_content), metadata=document.metadata, loader_strategy= loader_strategy, chunk_size = chunk_size, chunk_overlap = chunk_overlap) pages = await process_content(
content=str(document.page_content),
metadata=document.metadata,
loader_strategy=loader_strategy,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
chunked_doc.append(pages) chunked_doc.append(pages)
else: else:
from langchain.document_loaders import PyPDFLoader from langchain.document_loaders import PyPDFLoader
loader = PyPDFLoader(loader_settings.get("single_document_path")) loader = PyPDFLoader(loader_settings.get("single_document_path"))
documents= loader.load() documents = loader.load()
for document in documents: for document in documents:
pages = await process_content(content=str(document.page_content), metadata=document.metadata, pages = await process_content(
loader_strategy=loader_strategy, chunk_size=chunk_size, content=str(document.page_content),
chunk_overlap=chunk_overlap) metadata=document.metadata,
loader_strategy=loader_strategy,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
chunked_doc.append(pages) chunked_doc.append(pages)
else: else:
raise ValueError(f"Unsupported source type: {loader_settings.get('source')}") raise ValueError(f"Unsupported source type: {loader_settings.get('source')}")
@ -94,7 +123,6 @@ async def _document_loader(observation: str, loader_settings: dict):
return chunked_doc return chunked_doc
# async def _document_loader( observation: str, loader_settings: dict): # async def _document_loader( observation: str, loader_settings: dict):
# #
# document_format = loader_settings.get("format", "text") # document_format = loader_settings.get("format", "text")
@ -196,11 +224,3 @@ async def _document_loader(observation: str, loader_settings: dict):
# else: # else:
# raise ValueError(f"Error: ") # raise ValueError(f"Error: ")
# return chunked_doc # return chunked_doc

View file

@ -2,9 +2,19 @@ from .job import Job
class Response: class Response:
def __init__(self, error=None, message=None, successful_uploads=None, failed_uploads=None, def __init__(
empty_files_count=None, duplicate_files_count=None, job_id=None, self,
jobs=None, job_status=None, status_code=None): error=None,
message=None,
successful_uploads=None,
failed_uploads=None,
empty_files_count=None,
duplicate_files_count=None,
job_id=None,
jobs=None,
job_status=None,
status_code=None,
):
self.error = error self.error = error
self.message = message self.message = message
self.successful_uploads = successful_uploads self.successful_uploads = successful_uploads
@ -18,33 +28,37 @@ class Response:
@classmethod @classmethod
def from_json(cls, json_dict, status_code): def from_json(cls, json_dict, status_code):
successful_uploads = cls._convert_successful_uploads_to_jobs(json_dict.get('successful_uploads', None)) successful_uploads = cls._convert_successful_uploads_to_jobs(
jobs = cls._convert_to_jobs(json_dict.get('Jobs', None)) json_dict.get("successful_uploads", None)
)
jobs = cls._convert_to_jobs(json_dict.get("Jobs", None))
return cls( return cls(
error=json_dict.get('error'), error=json_dict.get("error"),
message=json_dict.get('message'), message=json_dict.get("message"),
successful_uploads=successful_uploads, successful_uploads=successful_uploads,
failed_uploads=json_dict.get('failed_uploads'), failed_uploads=json_dict.get("failed_uploads"),
empty_files_count=json_dict.get('empty_files_count'), empty_files_count=json_dict.get("empty_files_count"),
duplicate_files_count=json_dict.get('duplicate_files_count'), duplicate_files_count=json_dict.get("duplicate_files_count"),
job_id=json_dict.get('JobID'), job_id=json_dict.get("JobID"),
jobs=jobs, jobs=jobs,
job_status=json_dict.get('JobStatus'), job_status=json_dict.get("JobStatus"),
status_code=status_code status_code=status_code,
) )
@classmethod @classmethod
def _convert_successful_uploads_to_jobs(cls, successful_uploads): def _convert_successful_uploads_to_jobs(cls, successful_uploads):
if not successful_uploads: if not successful_uploads:
return None return None
return [Job(filename=key, job_id=val) for key, val in successful_uploads.items()] return [
Job(filename=key, job_id=val) for key, val in successful_uploads.items()
]
@classmethod @classmethod
def _convert_to_jobs(cls, jobs): def _convert_to_jobs(cls, jobs):
if not jobs: if not jobs:
return None return None
return [Job(job_id=job['JobID'], job_status=job['JobStatus']) for job in jobs] return [Job(job_id=job["JobID"], job_status=job["JobStatus"]) for job in jobs]
def __str__(self): def __str__(self):
attributes = [] attributes = []

View file

@ -1,14 +1,15 @@
from enum import Enum from enum import Enum
class VectorDBType(Enum): class VectorDBType(Enum):
PINECONE = 'pinecone' PINECONE = "pinecone"
WEAVIATE = 'weaviate' WEAVIATE = "weaviate"
MILVUS = 'milvus' MILVUS = "milvus"
QDRANT = 'qdrant' QDRANT = "qdrant"
DEEPLAKE = 'deeplake' DEEPLAKE = "deeplake"
VESPA = 'vespa' VESPA = "vespa"
PGVECTOR = 'pgvector' PGVECTOR = "pgvector"
REDIS = 'redis' REDIS = "redis"
LANCEDB = 'lancedb' LANCEDB = "lancedb"
MONGODB = 'mongodb' MONGODB = "mongodb"
FAISS = 'faiss' FAISS = "faiss"

View file

@ -1,10 +1,10 @@
# Make sure to install the following packages: dlt, langchain, duckdb, python-dotenv, openai, weaviate-client # Make sure to install the following packages: dlt, langchain, duckdb, python-dotenv, openai, weaviate-client
import logging import logging
from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter
from marshmallow import Schema, fields from marshmallow import Schema, fields
from cognitive_architecture.database.vectordb.loaders.loaders import _document_loader from cognitive_architecture.database.vectordb.loaders.loaders import _document_loader
# Add the parent directory to sys.path # Add the parent directory to sys.path
@ -12,6 +12,7 @@ logging.basicConfig(level=logging.INFO)
from langchain.retrievers import WeaviateHybridSearchRetriever, ParentDocumentRetriever from langchain.retrievers import WeaviateHybridSearchRetriever, ParentDocumentRetriever
from weaviate.gql.get import HybridFusion from weaviate.gql.get import HybridFusion
import tracemalloc import tracemalloc
tracemalloc.start() tracemalloc.start()
import os import os
from langchain.embeddings.openai import OpenAIEmbeddings from langchain.embeddings.openai import OpenAIEmbeddings
@ -28,6 +29,8 @@ config.load()
LTM_MEMORY_ID_DEFAULT = "00000" LTM_MEMORY_ID_DEFAULT = "00000"
ST_MEMORY_ID_DEFAULT = "0000" ST_MEMORY_ID_DEFAULT = "0000"
BUFFER_ID_DEFAULT = "0000" BUFFER_ID_DEFAULT = "0000"
class VectorDB: class VectorDB:
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "") OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
@ -37,7 +40,7 @@ class VectorDB:
index_name: str, index_name: str,
memory_id: str, memory_id: str,
namespace: str = None, namespace: str = None,
embeddings = None, embeddings=None,
): ):
self.user_id = user_id self.user_id = user_id
self.index_name = index_name self.index_name = index_name
@ -45,6 +48,7 @@ class VectorDB:
self.memory_id = memory_id self.memory_id = memory_id
self.embeddings = embeddings self.embeddings = embeddings
class PineconeVectorDB(VectorDB): class PineconeVectorDB(VectorDB):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -54,13 +58,21 @@ class PineconeVectorDB(VectorDB):
# Pinecone initialization logic # Pinecone initialization logic
pass pass
import langchain.embeddings import langchain.embeddings
class WeaviateVectorDB(VectorDB): class WeaviateVectorDB(VectorDB):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.init_weaviate(embeddings= self.embeddings, namespace = self.namespace) self.init_weaviate(embeddings=self.embeddings, namespace=self.namespace)
def init_weaviate(self, embeddings=OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY", "")), namespace=None,retriever_type="",): def init_weaviate(
self,
embeddings=OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY", "")),
namespace=None,
retriever_type="",
):
# Weaviate initialization logic # Weaviate initialization logic
auth_config = weaviate.auth.AuthApiKey( auth_config = weaviate.auth.AuthApiKey(
api_key=os.environ.get("WEAVIATE_API_KEY") api_key=os.environ.get("WEAVIATE_API_KEY")
@ -91,15 +103,16 @@ class WeaviateVectorDB(VectorDB):
create_schema_if_missing=True, create_schema_if_missing=True,
) )
return retriever return retriever
else : else:
return client return client
# child_splitter = RecursiveCharacterTextSplitter(chunk_size=400) # child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
# store = InMemoryStore() # store = InMemoryStore()
# retriever = ParentDocumentRetriever( # retriever = ParentDocumentRetriever(
# vectorstore=vectorstore, # vectorstore=vectorstore,
# docstore=store, # docstore=store,
# child_splitter=child_splitter, # child_splitter=child_splitter,
# ) # )
from marshmallow import Schema, fields from marshmallow import Schema, fields
def create_document_structure(observation, params, metadata_schema_class=None): def create_document_structure(observation, params, metadata_schema_class=None):
@ -111,10 +124,7 @@ class WeaviateVectorDB(VectorDB):
:param metadata_schema_class: Custom metadata schema class (optional). :param metadata_schema_class: Custom metadata schema class (optional).
:return: A list containing the validated document data. :return: A list containing the validated document data.
""" """
document_data = { document_data = {"metadata": params, "page_content": observation}
"metadata": params,
"page_content": observation
}
def get_document_schema(): def get_document_schema():
class DynamicDocumentSchema(Schema): class DynamicDocumentSchema(Schema):
@ -128,30 +138,42 @@ class WeaviateVectorDB(VectorDB):
loaded_document = CurrentDocumentSchema().load(document_data) loaded_document = CurrentDocumentSchema().load(document_data)
return [loaded_document] return [loaded_document]
def _stuct(self, observation, params, metadata_schema_class =None): def _stuct(self, observation, params, metadata_schema_class=None):
"""Utility function to create the document structure with optional custom fields.""" """Utility function to create the document structure with optional custom fields."""
# Construct document data # Construct document data
document_data = { document_data = {"metadata": params, "page_content": observation}
"metadata": params,
"page_content": observation
}
def get_document_schema(): def get_document_schema():
class DynamicDocumentSchema(Schema): class DynamicDocumentSchema(Schema):
metadata = fields.Nested(metadata_schema_class, required=True) metadata = fields.Nested(metadata_schema_class, required=True)
page_content = fields.Str(required=True) page_content = fields.Str(required=True)
return DynamicDocumentSchema return DynamicDocumentSchema
# Validate and deserialize # Default to "1.0" if not provided # Validate and deserialize # Default to "1.0" if not provided
CurrentDocumentSchema = get_document_schema() CurrentDocumentSchema = get_document_schema()
loaded_document = CurrentDocumentSchema().load(document_data) loaded_document = CurrentDocumentSchema().load(document_data)
return [loaded_document] return [loaded_document]
async def add_memories(self, observation, loader_settings=None, params=None, namespace=None, metadata_schema_class=None, embeddings = 'hybrid'):
async def add_memories(
self,
observation,
loader_settings=None,
params=None,
namespace=None,
metadata_schema_class=None,
embeddings="hybrid",
):
# Update Weaviate memories here # Update Weaviate memories here
if namespace is None: if namespace is None:
namespace = self.namespace namespace = self.namespace
params['user_id'] = self.user_id params["user_id"] = self.user_id
logging.info("User id is %s", self.user_id) logging.info("User id is %s", self.user_id)
retriever = self.init_weaviate(embeddings=OpenAIEmbeddings(),namespace = namespace, retriever_type="single_document_context") retriever = self.init_weaviate(
embeddings=OpenAIEmbeddings(),
namespace=namespace,
retriever_type="single_document_context",
)
if loader_settings: if loader_settings:
# Assuming _document_loader returns a list of documents # Assuming _document_loader returns a list of documents
documents = await _document_loader(observation, loader_settings) documents = await _document_loader(observation, loader_settings)
@ -160,27 +182,49 @@ class WeaviateVectorDB(VectorDB):
for doc_list in documents: for doc_list in documents:
for doc in doc_list: for doc in doc_list:
chunk_count += 1 chunk_count += 1
params['chunk_count'] = doc.metadata.get("chunk_count", "None") params["chunk_count"] = doc.metadata.get("chunk_count", "None")
logging.info("Loading document with provided loader settings %s", str(doc)) logging.info(
params['source'] = doc.metadata.get("source", "None") "Loading document with provided loader settings %s", str(doc)
)
params["source"] = doc.metadata.get("source", "None")
logging.info("Params are %s", str(params)) logging.info("Params are %s", str(params))
retriever.add_documents([ retriever.add_documents(
Document(metadata=params, page_content=doc.page_content)]) [Document(metadata=params, page_content=doc.page_content)]
)
else: else:
chunk_count = 0 chunk_count = 0
from cognitive_architecture.database.vectordb.chunkers.chunkers import chunk_data from cognitive_architecture.database.vectordb.chunkers.chunkers import (
documents = [chunk_data(chunk_strategy="VANILLA", source_data=observation, chunk_size=300, chunk_data,
chunk_overlap=20)] )
documents = [
chunk_data(
chunk_strategy="VANILLA",
source_data=observation,
chunk_size=300,
chunk_overlap=20,
)
]
for doc in documents[0]: for doc in documents[0]:
chunk_count += 1 chunk_count += 1
params['chunk_order'] = chunk_count params["chunk_order"] = chunk_count
params['source'] = "User loaded" params["source"] = "User loaded"
logging.info("Loading document with default loader settings %s", str(doc)) logging.info(
"Loading document with default loader settings %s", str(doc)
)
logging.info("Params are %s", str(params)) logging.info("Params are %s", str(params))
retriever.add_documents([ retriever.add_documents(
Document(metadata=params, page_content=doc.page_content)]) [Document(metadata=params, page_content=doc.page_content)]
)
async def fetch_memories(self, observation: str, namespace: str = None, search_type: str = 'hybrid',params=None, **kwargs): async def fetch_memories(
self,
observation: str,
namespace: str = None,
search_type: str = "hybrid",
params=None,
**kwargs,
):
""" """
Fetch documents from weaviate. Fetch documents from weaviate.
@ -196,12 +240,9 @@ class WeaviateVectorDB(VectorDB):
Example: Example:
fetch_memories(query="some query", search_type='text', additional_param='value') fetch_memories(query="some query", search_type='text', additional_param='value')
""" """
client = self.init_weaviate(namespace =self.namespace) client = self.init_weaviate(namespace=self.namespace)
if search_type is None: if search_type is None:
search_type = 'hybrid' search_type = "hybrid"
if not namespace: if not namespace:
namespace = self.namespace namespace = self.namespace
@ -222,37 +263,41 @@ class WeaviateVectorDB(VectorDB):
for prop in class_obj["properties"] for prop in class_obj["properties"]
] ]
base_query = client.query.get( base_query = (
namespace, list(list_objects_of_class(namespace, client.schema.get())) client.query.get(
).with_additional( namespace, list(list_objects_of_class(namespace, client.schema.get()))
["id", "creationTimeUnix", "lastUpdateTimeUnix", "score", 'distance'] )
).with_where(params_user_id).with_limit(10) .with_additional(
["id", "creationTimeUnix", "lastUpdateTimeUnix", "score", "distance"]
)
.with_where(params_user_id)
.with_limit(10)
)
n_of_observations = kwargs.get('n_of_observations', 2) n_of_observations = kwargs.get("n_of_observations", 2)
# try: # try:
if search_type == 'text': if search_type == "text":
query_output = ( query_output = (
base_query base_query.with_near_text({"concepts": [observation]})
.with_near_text({"concepts": [observation]})
.with_autocut(n_of_observations) .with_autocut(n_of_observations)
.do() .do()
) )
elif search_type == 'hybrid': elif search_type == "hybrid":
query_output = ( query_output = (
base_query base_query.with_hybrid(
.with_hybrid(query=observation, fusion_type=HybridFusion.RELATIVE_SCORE) query=observation, fusion_type=HybridFusion.RELATIVE_SCORE
)
.with_autocut(n_of_observations) .with_autocut(n_of_observations)
.do() .do()
) )
elif search_type == 'bm25': elif search_type == "bm25":
query_output = ( query_output = (
base_query base_query.with_bm25(query=observation)
.with_bm25(query=observation)
.with_autocut(n_of_observations) .with_autocut(n_of_observations)
.do() .do()
) )
elif search_type == 'summary': elif search_type == "summary":
filter_object = { filter_object = {
"operator": "And", "operator": "And",
"operands": [ "operands": [
@ -266,20 +311,32 @@ class WeaviateVectorDB(VectorDB):
"operator": "LessThan", "operator": "LessThan",
"valueNumber": 30, "valueNumber": 30,
}, },
] ],
} }
base_query = client.query.get( base_query = (
namespace, list(list_objects_of_class(namespace, client.schema.get())) client.query.get(
).with_additional( namespace,
["id", "creationTimeUnix", "lastUpdateTimeUnix", "score", 'distance'] list(list_objects_of_class(namespace, client.schema.get())),
).with_where(filter_object).with_limit(30) )
.with_additional(
[
"id",
"creationTimeUnix",
"lastUpdateTimeUnix",
"score",
"distance",
]
)
.with_where(filter_object)
.with_limit(30)
)
query_output = ( query_output = (
base_query base_query
# .with_hybrid(query=observation, fusion_type=HybridFusion.RELATIVE_SCORE) # .with_hybrid(query=observation, fusion_type=HybridFusion.RELATIVE_SCORE)
.do() .do()
) )
elif search_type == 'summary_filter_by_object_name': elif search_type == "summary_filter_by_object_name":
filter_object = { filter_object = {
"operator": "And", "operator": "And",
"operands": [ "operands": [
@ -293,17 +350,27 @@ class WeaviateVectorDB(VectorDB):
"operator": "Equal", "operator": "Equal",
"valueText": params, "valueText": params,
}, },
] ],
} }
base_query = client.query.get( base_query = (
namespace, list(list_objects_of_class(namespace, client.schema.get())) client.query.get(
).with_additional( namespace,
["id", "creationTimeUnix", "lastUpdateTimeUnix", "score", 'distance'] list(list_objects_of_class(namespace, client.schema.get())),
).with_where(filter_object).with_limit(30).with_hybrid(query=observation, fusion_type=HybridFusion.RELATIVE_SCORE) )
query_output = ( .with_additional(
base_query [
.do() "id",
"creationTimeUnix",
"lastUpdateTimeUnix",
"score",
"distance",
]
)
.with_where(filter_object)
.with_limit(30)
.with_hybrid(query=observation, fusion_type=HybridFusion.RELATIVE_SCORE)
) )
query_output = base_query.do()
# from weaviate.classes import Filter # from weaviate.classes import Filter
# client = weaviate.connect_to_wcs( # client = weaviate.connect_to_wcs(
# cluster_url=config.weaviate_url, # cluster_url=config.weaviate_url,
@ -311,20 +378,18 @@ class WeaviateVectorDB(VectorDB):
# ) # )
return query_output return query_output
elif search_type == 'generate': elif search_type == "generate":
generate_prompt = kwargs.get('generate_prompt', "") generate_prompt = kwargs.get("generate_prompt", "")
query_output = ( query_output = (
base_query base_query.with_generate(single_prompt=observation)
.with_generate(single_prompt=observation)
.with_near_text({"concepts": [observation]}) .with_near_text({"concepts": [observation]})
.with_autocut(n_of_observations) .with_autocut(n_of_observations)
.do() .do()
) )
elif search_type == 'generate_grouped': elif search_type == "generate_grouped":
generate_prompt = kwargs.get('generate_prompt', "") generate_prompt = kwargs.get("generate_prompt", "")
query_output = ( query_output = (
base_query base_query.with_generate(grouped_task=observation)
.with_generate(grouped_task=observation)
.with_near_text({"concepts": [observation]}) .with_near_text({"concepts": [observation]})
.with_autocut(n_of_observations) .with_autocut(n_of_observations)
.do() .do()
@ -338,12 +403,10 @@ class WeaviateVectorDB(VectorDB):
return query_output return query_output
async def delete_memories(self, namespace: str, params: dict = None):
async def delete_memories(self, namespace:str, params: dict = None):
if namespace is None: if namespace is None:
namespace = self.namespace namespace = self.namespace
client = self.init_weaviate(namespace = self.namespace) client = self.init_weaviate(namespace=self.namespace)
if params: if params:
where_filter = { where_filter = {
"path": ["id"], "path": ["id"],
@ -366,7 +429,6 @@ class WeaviateVectorDB(VectorDB):
}, },
) )
async def count_memories(self, namespace: str = None, params: dict = None) -> int: async def count_memories(self, namespace: str = None, params: dict = None) -> int:
""" """
Count memories in a Weaviate database. Count memories in a Weaviate database.
@ -380,7 +442,7 @@ class WeaviateVectorDB(VectorDB):
if namespace is None: if namespace is None:
namespace = self.namespace namespace = self.namespace
client = self.init_weaviate(namespace =namespace) client = self.init_weaviate(namespace=namespace)
try: try:
object_count = client.query.aggregate(namespace).with_meta_count().do() object_count = client.query.aggregate(namespace).with_meta_count().do()
@ -391,7 +453,7 @@ class WeaviateVectorDB(VectorDB):
return 0 return 0
def update_memories(self, observation, namespace: str, params: dict = None): def update_memories(self, observation, namespace: str, params: dict = None):
client = self.init_weaviate(namespace = self.namespace) client = self.init_weaviate(namespace=self.namespace)
client.data_object.update( client.data_object.update(
data_object={ data_object={
@ -416,12 +478,15 @@ class WeaviateVectorDB(VectorDB):
) )
return return
import os import os
import lancedb import lancedb
from pydantic import BaseModel from pydantic import BaseModel
from typing import List, Optional from typing import List, Optional
import pandas as pd import pandas as pd
import pyarrow as pa import pyarrow as pa
class LanceDB(VectorDB): class LanceDB(VectorDB):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -434,21 +499,28 @@ class LanceDB(VectorDB):
db = lancedb.connect(uri, api_key=os.getenv("LANCEDB_API_KEY")) db = lancedb.connect(uri, api_key=os.getenv("LANCEDB_API_KEY"))
return db return db
def create_table(self, name: str, schema: Optional[pa.Schema] = None, data: Optional[pd.DataFrame] = None): def create_table(
self,
name: str,
schema: Optional[pa.Schema] = None,
data: Optional[pd.DataFrame] = None,
):
# Create a table in LanceDB. If schema is not provided, it will be inferred from the data. # Create a table in LanceDB. If schema is not provided, it will be inferred from the data.
if data is not None and schema is None: if data is not None and schema is None:
schema = pa.Schema.from_pandas(data) schema = pa.Schema.from_pandas(data)
table = self.db.create_table(name, schema=schema) table = self.db.create_table(name, schema=schema)
if data is not None: if data is not None:
table.add(data.to_dict('records')) table.add(data.to_dict("records"))
return table return table
def add_memories(self, table_name: str, data: pd.DataFrame): def add_memories(self, table_name: str, data: pd.DataFrame):
# Add data to an existing table in LanceDB # Add data to an existing table in LanceDB
table = self.db.open_table(table_name) table = self.db.open_table(table_name)
table.add(data.to_dict('records')) table.add(data.to_dict("records"))
def fetch_memories(self, table_name: str, query_vector: List[float], top_k: int = 10): def fetch_memories(
self, table_name: str, query_vector: List[float], top_k: int = 10
):
# Perform a vector search in the specified table # Perform a vector search in the specified table
table = self.db.open_table(table_name) table = self.db.open_table(table_name)
results = table.search(query_vector).limit(top_k).to_pandas() results = table.search(query_vector).limit(top_k).to_pandas()

View file

@ -16,15 +16,16 @@ sys.path.insert(0, parent_dir)
environment = os.getenv("AWS_ENV", "dev") environment = os.getenv("AWS_ENV", "dev")
def fetch_secret(secret_name, region_name, env_file_path): def fetch_secret(secret_name, region_name, env_file_path):
print("Initializing session") print("Initializing session")
session = boto3.session.Session() session = boto3.session.Session()
print("Session initialized") print("Session initialized")
client = session.client(service_name="secretsmanager", region_name = region_name) client = session.client(service_name="secretsmanager", region_name=region_name)
print("Client initialized") print("Client initialized")
try: try:
response = client.get_secret_value(SecretId = secret_name) response = client.get_secret_value(SecretId=secret_name)
except Exception as e: except Exception as e:
print(f"Error retrieving secret: {e}") print(f"Error retrieving secret: {e}")
return None return None
@ -46,6 +47,7 @@ def fetch_secret(secret_name, region_name, env_file_path):
else: else:
print(f"The .env file was not found at: {env_file_path}.") print(f"The .env file was not found at: {env_file_path}.")
ENV_FILE_PATH = os.path.abspath("../.env") ENV_FILE_PATH = os.path.abspath("../.env")
if os.path.exists(ENV_FILE_PATH): if os.path.exists(ENV_FILE_PATH):

View file

@ -6,6 +6,7 @@ from ..shared.data_models import Node, Edge, KnowledgeGraph, GraphQLQuery, Memor
from ..config import Config from ..config import Config
import instructor import instructor
from openai import OpenAI from openai import OpenAI
config = Config() config = Config()
config.load() config.load()
@ -23,7 +24,7 @@ import logging
# Function to read query prompts from files # Function to read query prompts from files
def read_query_prompt(filename): def read_query_prompt(filename):
try: try:
with open(filename, 'r') as file: with open(filename, "r") as file:
return file.read() return file.read()
except FileNotFoundError: except FileNotFoundError:
logging.info(f"Error: File not found. Attempted to read: {filename}") logging.info(f"Error: File not found. Attempted to read: {filename}")
@ -37,7 +38,9 @@ def read_query_prompt(filename):
def generate_graph(input) -> KnowledgeGraph: def generate_graph(input) -> KnowledgeGraph:
model = "gpt-4-1106-preview" model = "gpt-4-1106-preview"
user_prompt = f"Use the given format to extract information from the following input: {input}." user_prompt = f"Use the given format to extract information from the following input: {input}."
system_prompt = read_query_prompt('cognitive_architecture/llm/prompts/generate_graph_prompt.txt') system_prompt = read_query_prompt(
"cognitive_architecture/llm/prompts/generate_graph_prompt.txt"
)
out = aclient.chat.completions.create( out = aclient.chat.completions.create(
model=model, model=model,
@ -56,38 +59,40 @@ def generate_graph(input) -> KnowledgeGraph:
return out return out
async def generate_summary(input) -> MemorySummary: async def generate_summary(input) -> MemorySummary:
out = aclient.chat.completions.create( out = aclient.chat.completions.create(
model=config.model, model=config.model,
messages=[ messages=[
{ {
"role": "user", "role": "user",
"content": f"""Use the given format summarize and reduce the following input: {input}. """, "content": f"""Use the given format summarize and reduce the following input: {input}. """,
}, },
{ "role":"system", "content": """You are a top-tier algorithm {
"role": "system",
"content": """You are a top-tier algorithm
designed for summarizing existing knowledge graphs in structured formats based on a knowledge graph. designed for summarizing existing knowledge graphs in structured formats based on a knowledge graph.
## 1. Strict Compliance ## 1. Strict Compliance
Adhere to the rules strictly. Non-compliance will result in termination. Adhere to the rules strictly. Non-compliance will result in termination.
## 2. Don't forget your main goal is to reduce the number of nodes in the knowledge graph while preserving the information contained in it."""} ## 2. Don't forget your main goal is to reduce the number of nodes in the knowledge graph while preserving the information contained in it.""",
},
], ],
response_model=MemorySummary, response_model=MemorySummary,
) )
return out return out
def user_query_to_edges_and_nodes( input: str) ->KnowledgeGraph: def user_query_to_edges_and_nodes(input: str) -> KnowledgeGraph:
system_prompt = read_query_prompt('cognitive_architecture/llm/prompts/generate_graph_prompt.txt') system_prompt = read_query_prompt(
"cognitive_architecture/llm/prompts/generate_graph_prompt.txt"
)
return aclient.chat.completions.create( return aclient.chat.completions.create(
model=config.model, model=config.model,
messages=[ messages=[
{ {
"role": "user", "role": "user",
"content": f"""Use the given format to extract information from the following input: {input}. """, "content": f"""Use the given format to extract information from the following input: {input}. """,
}, },
{"role": "system", "content":system_prompt} {"role": "system", "content": system_prompt},
], ],
response_model=KnowledgeGraph, response_model=KnowledgeGraph,
) )

View file

@ -4,7 +4,6 @@ import os
import time import time
HOST = os.getenv("OPENAI_API_BASE") HOST = os.getenv("OPENAI_API_BASE")
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
@ -41,7 +40,9 @@ def retry_with_exponential_backoff(
# Check if max retries has been reached # Check if max retries has been reached
if num_retries > max_retries: if num_retries > max_retries:
raise Exception(f"Maximum number of retries ({max_retries}) exceeded.") raise Exception(
f"Maximum number of retries ({max_retries}) exceeded."
)
# Increment the delay # Increment the delay
delay *= exponential_base * (1 + jitter * random.random()) delay *= exponential_base * (1 + jitter * random.random())
@ -90,7 +91,9 @@ def aretry_with_exponential_backoff(
# Check if max retries has been reached # Check if max retries has been reached
if num_retries > max_retries: if num_retries > max_retries:
raise Exception(f"Maximum number of retries ({max_retries}) exceeded.") raise Exception(
f"Maximum number of retries ({max_retries}) exceeded."
)
# Increment the delay # Increment the delay
delay *= exponential_base * (1 + jitter * random.random()) delay *= exponential_base * (1 + jitter * random.random())
@ -135,6 +138,3 @@ def get_embedding_with_backoff(text, model="text-embedding-ada-002"):
response = create_embedding_with_backoff(input=[text], model=model) response = create_embedding_with_backoff(input=[text], model=model)
embedding = response["data"][0]["embedding"] embedding = response["data"][0]["embedding"]
return embedding return embedding

View file

@ -1,9 +1,7 @@
DEFAULT_PRESET = "cognitive_architecture_chat" DEFAULT_PRESET = "cognitive_architecture_chat"
preset_options = [DEFAULT_PRESET] preset_options = [DEFAULT_PRESET]
def use_preset(): def use_preset():
"""Placeholder for different present options""" """Placeholder for different present options"""

View file

@ -1,33 +1,36 @@
from typing import Optional, List from typing import Optional, List
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class Node(BaseModel): class Node(BaseModel):
id: int id: int
description: str description: str
category: str category: str
color: str ="blue" color: str = "blue"
memory_type: str memory_type: str
created_at: Optional[float] = None created_at: Optional[float] = None
summarized: Optional[bool] = None summarized: Optional[bool] = None
class Edge(BaseModel): class Edge(BaseModel):
source: int source: int
target: int target: int
description: str description: str
color: str= "blue" color: str = "blue"
created_at: Optional[float] = None created_at: Optional[float] = None
summarized: Optional[bool] = None summarized: Optional[bool] = None
class KnowledgeGraph(BaseModel): class KnowledgeGraph(BaseModel):
nodes: List[Node] = Field(..., default_factory=list) nodes: List[Node] = Field(..., default_factory=list)
edges: List[Edge] = Field(..., default_factory=list) edges: List[Edge] = Field(..., default_factory=list)
class GraphQLQuery(BaseModel): class GraphQLQuery(BaseModel):
query: str query: str
class MemorySummary(BaseModel): class MemorySummary(BaseModel):
nodes: List[Node] = Field(..., default_factory=list) nodes: List[Node] = Field(..., default_factory=list)
edges: List[Edge] = Field(..., default_factory=list) edges: List[Edge] = Field(..., default_factory=list)

View file

@ -3,13 +3,15 @@ from botocore.exceptions import BotoCoreError, ClientError
from langdetect import detect, LangDetectException from langdetect import detect, LangDetectException
import iso639 import iso639
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
import logging import logging
# Basic configuration of the logging system # Basic configuration of the logging system
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
def detect_language(text): def detect_language(text):
@ -34,8 +36,8 @@ def detect_language(text):
logging.info(f"Detected ISO 639-1 code: {detected_lang_iso639_1}") logging.info(f"Detected ISO 639-1 code: {detected_lang_iso639_1}")
# Special case: map 'hr' (Croatian) to 'sr' (Serbian ISO 639-2) # Special case: map 'hr' (Croatian) to 'sr' (Serbian ISO 639-2)
if detected_lang_iso639_1 == 'hr': if detected_lang_iso639_1 == "hr":
return 'sr' return "sr"
return detected_lang_iso639_1 return detected_lang_iso639_1
except LangDetectException as e: except LangDetectException as e:
@ -46,8 +48,12 @@ def detect_language(text):
return -1 return -1
def translate_text(
def translate_text(text, source_language:str='sr', target_language:str='en', region_name='eu-west-1'): text,
source_language: str = "sr",
target_language: str = "en",
region_name="eu-west-1",
):
""" """
Translate text from source language to target language using AWS Translate. Translate text from source language to target language using AWS Translate.
@ -68,9 +74,15 @@ def translate_text(text, source_language:str='sr', target_language:str='en', reg
return "Both source and target language codes are required." return "Both source and target language codes are required."
try: try:
translate = boto3.client(service_name='translate', region_name=region_name, use_ssl=True) translate = boto3.client(
result = translate.translate_text(Text=text, SourceLanguageCode=source_language, TargetLanguageCode=target_language) service_name="translate", region_name=region_name, use_ssl=True
return result.get('TranslatedText', 'No translation found.') )
result = translate.translate_text(
Text=text,
SourceLanguageCode=source_language,
TargetLanguageCode=target_language,
)
return result.get("TranslatedText", "No translation found.")
except BotoCoreError as e: except BotoCoreError as e:
logging.info(f"BotoCoreError occurred: {e}") logging.info(f"BotoCoreError occurred: {e}")
@ -81,8 +93,8 @@ def translate_text(text, source_language:str='sr', target_language:str='en', reg
return "Error with AWS client or network issue." return "Error with AWS client or network issue."
source_language = 'sr' source_language = "sr"
target_language = 'en' target_language = "en"
text_to_translate = "Ja volim da pecam i idem na reku da šetam pored nje ponekad" text_to_translate = "Ja volim da pecam i idem na reku da šetam pored nje ponekad"
translated_text = translate_text(text_to_translate, source_language, target_language) translated_text = translate_text(text_to_translate, source_language, target_language)

View file

@ -22,12 +22,15 @@ class Node:
self.description = description self.description = description
self.color = color self.color = color
class Edge: class Edge:
def __init__(self, source, target, label, color): def __init__(self, source, target, label, color):
self.source = source self.source = source
self.target = target self.target = target
self.label = label self.label = label
self.color = color self.color = color
# def visualize_knowledge_graph(kg: KnowledgeGraph): # def visualize_knowledge_graph(kg: KnowledgeGraph):
# dot = Digraph(comment="Knowledge Graph") # dot = Digraph(comment="Knowledge Graph")
# #
@ -82,6 +85,7 @@ def get_document_names(doc_input):
# doc_input is not valid # doc_input is not valid
return [] return []
def format_dict(d): def format_dict(d):
# Initialize an empty list to store formatted items # Initialize an empty list to store formatted items
formatted_items = [] formatted_items = []
@ -89,7 +93,9 @@ def format_dict(d):
# Iterate through all key-value pairs # Iterate through all key-value pairs
for key, value in d.items(): for key, value in d.items():
# Format key-value pairs with a colon and space, and adding quotes for string values # Format key-value pairs with a colon and space, and adding quotes for string values
formatted_item = f"{key}: '{value}'" if isinstance(value, str) else f"{key}: {value}" formatted_item = (
f"{key}: '{value}'" if isinstance(value, str) else f"{key}: {value}"
)
formatted_items.append(formatted_item) formatted_items.append(formatted_item)
# Join all formatted items with a comma and a space # Join all formatted items with a comma and a space
@ -114,7 +120,7 @@ def create_node_variable_mapping(nodes):
mapping = {} mapping = {}
for node in nodes: for node in nodes:
variable_name = f"{node['category']}{node['id']}".lower() variable_name = f"{node['category']}{node['id']}".lower()
mapping[node['id']] = variable_name mapping[node["id"]] = variable_name
return mapping return mapping
@ -123,18 +129,23 @@ def create_edge_variable_mapping(edges):
for edge in edges: for edge in edges:
# Construct a unique identifier for the edge # Construct a unique identifier for the edge
variable_name = f"edge{edge['source']}to{edge['target']}".lower() variable_name = f"edge{edge['source']}to{edge['target']}".lower()
mapping[(edge['source'], edge['target'])] = variable_name mapping[(edge["source"], edge["target"])] = variable_name
return mapping return mapping
def generate_letter_uuid(length=8): def generate_letter_uuid(length=8):
"""Generate a random string of uppercase letters with the specified length.""" """Generate a random string of uppercase letters with the specified length."""
letters = string.ascii_uppercase # A-Z letters = string.ascii_uppercase # A-Z
return "".join(random.choice(letters) for _ in range(length)) return "".join(random.choice(letters) for _ in range(length))
from cognitive_architecture.database.relationaldb.models.operation import Operation from cognitive_architecture.database.relationaldb.models.operation import Operation
from cognitive_architecture.database.relationaldb.database_crud import session_scope, add_entity, update_entity, fetch_job_id from cognitive_architecture.database.relationaldb.database_crud import (
session_scope,
add_entity,
update_entity,
fetch_job_id,
)
from cognitive_architecture.database.relationaldb.models.metadatas import MetaDatas from cognitive_architecture.database.relationaldb.models.metadatas import MetaDatas
from cognitive_architecture.database.relationaldb.models.docs import DocsModel from cognitive_architecture.database.relationaldb.models.docs import DocsModel
from cognitive_architecture.database.relationaldb.models.memory import MemoryModel from cognitive_architecture.database.relationaldb.models.memory import MemoryModel
@ -142,42 +153,56 @@ from cognitive_architecture.database.relationaldb.models.user import User
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
import logging import logging
async def get_vectordb_namespace(session: AsyncSession, user_id: str): async def get_vectordb_namespace(session: AsyncSession, user_id: str):
try: try:
result = await session.execute( result = await session.execute(
select(MemoryModel.memory_name).where(MemoryModel.user_id == user_id).order_by(MemoryModel.created_at.desc()) select(MemoryModel.memory_name)
.where(MemoryModel.user_id == user_id)
.order_by(MemoryModel.created_at.desc())
) )
namespace = [row[0] for row in result.fetchall()] namespace = [row[0] for row in result.fetchall()]
return namespace return namespace
except Exception as e: except Exception as e:
logging.error(f"An error occurred while retrieving the Vectordb_namespace: {str(e)}") logging.error(
f"An error occurred while retrieving the Vectordb_namespace: {str(e)}"
)
return None return None
async def get_vectordb_document_name(session: AsyncSession, user_id: str): async def get_vectordb_document_name(session: AsyncSession, user_id: str):
try: try:
result = await session.execute( result = await session.execute(
select(DocsModel.doc_name).where(DocsModel.user_id == user_id).order_by(DocsModel.created_at.desc()) select(DocsModel.doc_name)
.where(DocsModel.user_id == user_id)
.order_by(DocsModel.created_at.desc())
) )
doc_names = [row[0] for row in result.fetchall()] doc_names = [row[0] for row in result.fetchall()]
return doc_names return doc_names
except Exception as e: except Exception as e:
logging.error(f"An error occurred while retrieving the Vectordb_namespace: {str(e)}") logging.error(
f"An error occurred while retrieving the Vectordb_namespace: {str(e)}"
)
return None return None
async def get_model_id_name(session: AsyncSession, id: str): async def get_model_id_name(session: AsyncSession, id: str):
try: try:
result = await session.execute( result = await session.execute(
select(MemoryModel.memory_name).where(MemoryModel.id == id).order_by(MemoryModel.created_at.desc()) select(MemoryModel.memory_name)
.where(MemoryModel.id == id)
.order_by(MemoryModel.created_at.desc())
) )
doc_names = [row[0] for row in result.fetchall()] doc_names = [row[0] for row in result.fetchall()]
return doc_names return doc_names
except Exception as e: except Exception as e:
logging.error(f"An error occurred while retrieving the Vectordb_namespace: {str(e)}") logging.error(
f"An error occurred while retrieving the Vectordb_namespace: {str(e)}"
)
return None return None
async def get_unsumarized_vector_db_namespace(session: AsyncSession, user_id: str): async def get_unsumarized_vector_db_namespace(session: AsyncSession, user_id: str):
""" """
Asynchronously retrieves the latest memory names and document details for a given user. Asynchronously retrieves the latest memory names and document details for a given user.
@ -207,13 +232,13 @@ async def get_unsumarized_vector_db_namespace(session: AsyncSession, user_id: st
.join(Operation.memories) # Explicit join with memories table .join(Operation.memories) # Explicit join with memories table
.options( .options(
contains_eager(Operation.docs), # Informs ORM of the join for docs contains_eager(Operation.docs), # Informs ORM of the join for docs
contains_eager(Operation.memories) # Informs ORM of the join for memories contains_eager(Operation.memories), # Informs ORM of the join for memories
) )
.where( .where(
(Operation.user_id == user_id) & # Filter by user_id (Operation.user_id == user_id)
or_( & or_( # Filter by user_id
DocsModel.graph_summary == False, # Condition 1: graph_summary is False DocsModel.graph_summary == False, # Condition 1: graph_summary is False
DocsModel.graph_summary == None # Condition 3: graph_summary is None DocsModel.graph_summary == None, # Condition 3: graph_summary is None
) # Filter by user_id ) # Filter by user_id
) )
.order_by(Operation.created_at.desc()) # Order by creation date .order_by(Operation.created_at.desc()) # Order by creation date
@ -223,7 +248,11 @@ async def get_unsumarized_vector_db_namespace(session: AsyncSession, user_id: st
# Extract memory names and document names and IDs # Extract memory names and document names and IDs
# memory_names = [memory.memory_name for op in operations for memory in op.memories] # memory_names = [memory.memory_name for op in operations for memory in op.memories]
memory_details = [(memory.memory_name, memory.memory_category) for op in operations for memory in op.memories] memory_details = [
(memory.memory_name, memory.memory_category)
for op in operations
for memory in op.memories
]
docs = [(doc.doc_name, doc.id) for op in operations for doc in op.docs] docs = [(doc.doc_name, doc.id) for op in operations for doc in op.docs]
return memory_details, docs return memory_details, docs
@ -232,6 +261,8 @@ async def get_unsumarized_vector_db_namespace(session: AsyncSession, user_id: st
# # Handle the exception as needed # # Handle the exception as needed
# print(f"An error occurred: {e}") # print(f"An error occurred: {e}")
# return None # return None
async def get_memory_name_by_doc_id(session: AsyncSession, docs_id: str): async def get_memory_name_by_doc_id(session: AsyncSession, docs_id: str):
""" """
Asynchronously retrieves memory names associated with a specific document ID. Asynchronously retrieves memory names associated with a specific document ID.
@ -254,8 +285,12 @@ async def get_memory_name_by_doc_id(session: AsyncSession, docs_id: str):
try: try:
result = await session.execute( result = await session.execute(
select(MemoryModel.memory_name) select(MemoryModel.memory_name)
.join(Operation, Operation.id == MemoryModel.operation_id) # Join with Operation .join(
.join(DocsModel, DocsModel.operation_id == Operation.id) # Join with DocsModel Operation, Operation.id == MemoryModel.operation_id
) # Join with Operation
.join(
DocsModel, DocsModel.operation_id == Operation.id
) # Join with DocsModel
.where(DocsModel.id == docs_id) # Filtering based on the passed document ID .where(DocsModel.id == docs_id) # Filtering based on the passed document ID
.distinct() # To avoid duplicate memory names .distinct() # To avoid duplicate memory names
) )
@ -269,7 +304,6 @@ async def get_memory_name_by_doc_id(session: AsyncSession, docs_id: str):
return None return None
# #
# async def main(): # async def main():
# user_id = "user" # user_id = "user"

View file

@ -9,8 +9,6 @@ import os
print(os.getcwd()) print(os.getcwd())
from cognitive_architecture.database.relationaldb.models.user import User from cognitive_architecture.database.relationaldb.models.user import User
from cognitive_architecture.database.relationaldb.models.memory import MemoryModel from cognitive_architecture.database.relationaldb.models.memory import MemoryModel
@ -27,7 +25,6 @@ import uuid
load_dotenv() load_dotenv()
from cognitive_architecture.database.vectordb.basevectordb import BaseMemory from cognitive_architecture.database.vectordb.basevectordb import BaseMemory
from cognitive_architecture.config import Config from cognitive_architecture.config import Config
@ -36,8 +33,6 @@ config = Config()
config.load() config.load()
class DynamicBaseMemory(BaseMemory): class DynamicBaseMemory(BaseMemory):
def __init__( def __init__(
self, self,
@ -145,8 +140,8 @@ class Memory:
db_type: str = None, db_type: str = None,
namespace: str = None, namespace: str = None,
memory_id: str = None, memory_id: str = None,
memory_class = None, memory_class=None,
job_id:str = None job_id: str = None,
) -> None: ) -> None:
self.load_environment_variables() self.load_environment_variables()
self.memory_id = memory_id self.memory_id = memory_id
@ -157,20 +152,25 @@ class Memory:
self.namespace = namespace self.namespace = namespace
self.memory_instances = [] self.memory_instances = []
self.memory_class = memory_class self.memory_class = memory_class
self.job_id=job_id self.job_id = job_id
# self.memory_class = DynamicBaseMemory( # self.memory_class = DynamicBaseMemory(
# "Memory", user_id, str(self.memory_id), index_name, db_type, namespace # "Memory", user_id, str(self.memory_id), index_name, db_type, namespace
# ) # )
def load_environment_variables(self) -> None: def load_environment_variables(self) -> None:
load_dotenv() load_dotenv()
self.OPENAI_TEMPERATURE = config.openai_temperature self.OPENAI_TEMPERATURE = config.openai_temperature
self.OPENAI_API_KEY = config.openai_key self.OPENAI_API_KEY = config.openai_key
@classmethod @classmethod
async def create_memory(cls, user_id: str, session, job_id:str=None, memory_label:str=None, **kwargs): async def create_memory(
cls,
user_id: str,
session,
job_id: str = None,
memory_label: str = None,
**kwargs,
):
""" """
Class method that acts as a factory method for creating Memory instances. Class method that acts as a factory method for creating Memory instances.
It performs necessary DB checks or updates before instance creation. It performs necessary DB checks or updates before instance creation.
@ -180,9 +180,14 @@ class Memory:
if existing_user: if existing_user:
# Handle existing user scenario... # Handle existing user scenario...
memory_id = await cls.check_existing_memory(user_id,memory_label, session) memory_id = await cls.check_existing_memory(user_id, memory_label, session)
if memory_id is None: if memory_id is None:
memory_id = await cls.handle_new_memory(user_id = user_id, session= session,job_id=job_id, memory_name= memory_label) memory_id = await cls.handle_new_memory(
user_id=user_id,
session=session,
job_id=job_id,
memory_name=memory_label,
)
logging.info( logging.info(
f"Existing user {user_id} found in the DB. Memory ID: {memory_id}" f"Existing user {user_id} found in the DB. Memory ID: {memory_id}"
) )
@ -190,16 +195,33 @@ class Memory:
# Handle new user scenario... # Handle new user scenario...
await cls.handle_new_user(user_id, session) await cls.handle_new_user(user_id, session)
memory_id = await cls.handle_new_memory(user_id =user_id, session=session, job_id=job_id, memory_name= memory_label) memory_id = await cls.handle_new_memory(
user_id=user_id,
session=session,
job_id=job_id,
memory_name=memory_label,
)
logging.info( logging.info(
f"New user {user_id} created in the DB. Memory ID: {memory_id}" f"New user {user_id} created in the DB. Memory ID: {memory_id}"
) )
memory_class = DynamicBaseMemory( memory_class = DynamicBaseMemory(
memory_label, user_id, str(memory_id), index_name=memory_label , db_type=config.vectordb, **kwargs memory_label,
user_id,
str(memory_id),
index_name=memory_label,
db_type=config.vectordb,
**kwargs,
) )
return cls(user_id=user_id, session=session, memory_id=memory_id, job_id =job_id, memory_class=memory_class, **kwargs) return cls(
user_id=user_id,
session=session,
memory_id=memory_id,
job_id=job_id,
memory_class=memory_class,
**kwargs,
)
async def list_memory_classes(self): async def list_memory_classes(self):
""" """
@ -215,19 +237,20 @@ class Memory:
return result.scalar_one_or_none() return result.scalar_one_or_none()
@staticmethod @staticmethod
async def check_existing_memory(user_id: str, memory_label:str, session): async def check_existing_memory(user_id: str, memory_label: str, session):
"""Check if a user memory exists in the DB and return it. Filters by user and label""" """Check if a user memory exists in the DB and return it. Filters by user and label"""
try: try:
result = await session.execute( result = await session.execute(
select(MemoryModel.id).where(MemoryModel.user_id == user_id) select(MemoryModel.id)
.where(MemoryModel.user_id == user_id)
.filter_by(memory_name=memory_label) .filter_by(memory_name=memory_label)
.order_by(MemoryModel.created_at) .order_by(MemoryModel.created_at)
) )
return result.scalar_one_or_none() return result.scalar_one_or_none()
except Exception as e: except Exception as e:
logging.error(f"An error occurred: {str(e)}") logging.error(f"An error occurred: {str(e)}")
return None return None
@staticmethod @staticmethod
async def handle_new_user(user_id: str, session): async def handle_new_user(user_id: str, session):
""" """
@ -251,7 +274,13 @@ class Memory:
return f"Error creating user: {str(e)}" return f"Error creating user: {str(e)}"
@staticmethod @staticmethod
async def handle_new_memory(user_id: str, session, job_id: str = None, memory_name: str = None, memory_category:str='PUBLIC'): async def handle_new_memory(
user_id: str,
session,
job_id: str = None,
memory_name: str = None,
memory_category: str = "PUBLIC",
):
""" """
Handle new memory creation associated with a user. Handle new memory creation associated with a user.
@ -296,7 +325,6 @@ class Memory:
except Exception as e: except Exception as e:
return f"Error creating memory: {str(e)}" return f"Error creating memory: {str(e)}"
async def add_memory_instance(self, memory_class_name: str): async def add_memory_instance(self, memory_class_name: str):
"""Add a new memory instance to the memory_instances list.""" """Add a new memory instance to the memory_instances list."""
instance = DynamicBaseMemory( instance = DynamicBaseMemory(
@ -446,7 +474,9 @@ async def main():
from database.relationaldb.database import AsyncSessionLocal from database.relationaldb.database import AsyncSessionLocal
async with session_scope(AsyncSessionLocal()) as session: async with session_scope(AsyncSessionLocal()) as session:
memory = await Memory.create_memory("677", session, "SEMANTICMEMORY", namespace="SEMANTICMEMORY") memory = await Memory.create_memory(
"677", session, "SEMANTICMEMORY", namespace="SEMANTICMEMORY"
)
ff = memory.memory_instances ff = memory.memory_instances
logging.info("ssss %s", ff) logging.info("ssss %s", ff)
@ -462,8 +492,13 @@ async def main():
await memory.add_dynamic_memory_class("semanticmemory", "SEMANTICMEMORY") await memory.add_dynamic_memory_class("semanticmemory", "SEMANTICMEMORY")
await memory.add_method_to_class(memory.semanticmemory_class, "add_memories") await memory.add_method_to_class(memory.semanticmemory_class, "add_memories")
await memory.add_method_to_class(memory.semanticmemory_class, "fetch_memories") await memory.add_method_to_class(memory.semanticmemory_class, "fetch_memories")
sss = await memory.dynamic_method_call(memory.semanticmemory_class, 'add_memories', sss = await memory.dynamic_method_call(
observation='some_observation', params=params, loader_settings=loader_settings) memory.semanticmemory_class,
"add_memories",
observation="some_observation",
params=params,
loader_settings=loader_settings,
)
# susu = await memory.dynamic_method_call( # susu = await memory.dynamic_method_call(
# memory.semanticmemory_class, # memory.semanticmemory_class,

435
main.py
View file

@ -7,7 +7,10 @@ from cognitive_architecture.database.relationaldb.models.memory import MemoryMod
from cognitive_architecture.classifiers.classifier import classify_documents from cognitive_architecture.classifiers.classifier import classify_documents
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
from cognitive_architecture.database.relationaldb.database_crud import session_scope, update_entity_graph_summary from cognitive_architecture.database.relationaldb.database_crud import (
session_scope,
update_entity_graph_summary,
)
from cognitive_architecture.database.relationaldb.database import AsyncSessionLocal from cognitive_architecture.database.relationaldb.database import AsyncSessionLocal
from cognitive_architecture.utils import generate_letter_uuid from cognitive_architecture.utils import generate_letter_uuid
import instructor import instructor
@ -17,12 +20,18 @@ from cognitive_architecture.database.relationaldb.database_crud import fetch_job
import uuid import uuid
from cognitive_architecture.database.relationaldb.models.sessions import Session from cognitive_architecture.database.relationaldb.models.sessions import Session
from cognitive_architecture.database.relationaldb.models.operation import Operation from cognitive_architecture.database.relationaldb.models.operation import Operation
from cognitive_architecture.database.relationaldb.database_crud import session_scope, add_entity, update_entity, fetch_job_id from cognitive_architecture.database.relationaldb.database_crud import (
session_scope,
add_entity,
update_entity,
fetch_job_id,
)
from cognitive_architecture.database.relationaldb.models.metadatas import MetaDatas from cognitive_architecture.database.relationaldb.models.metadatas import MetaDatas
from cognitive_architecture.database.relationaldb.models.docs import DocsModel from cognitive_architecture.database.relationaldb.models.docs import DocsModel
from cognitive_architecture.database.relationaldb.models.memory import MemoryModel from cognitive_architecture.database.relationaldb.models.memory import MemoryModel
from cognitive_architecture.database.relationaldb.models.user import User from cognitive_architecture.database.relationaldb.models.user import User
from cognitive_architecture.classifiers.classifier import classify_call from cognitive_architecture.classifiers.classifier import classify_call
aclient = instructor.patch(OpenAI()) aclient = instructor.patch(OpenAI())
DEFAULT_PRESET = "promethai_chat" DEFAULT_PRESET = "promethai_chat"
preset_options = [DEFAULT_PRESET] preset_options = [DEFAULT_PRESET]
@ -30,6 +39,7 @@ PROMETHAI_DIR = os.path.join(os.path.expanduser("~"), ".")
load_dotenv() load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "") OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
from cognitive_architecture.config import Config from cognitive_architecture.config import Config
config = Config() config = Config()
config.load() config.load()
from cognitive_architecture.utils import get_document_names from cognitive_architecture.utils import get_document_names
@ -37,14 +47,28 @@ from sqlalchemy.orm import selectinload, joinedload, contains_eager
import logging import logging
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from cognitive_architecture.utils import get_document_names, generate_letter_uuid, get_memory_name_by_doc_id, get_unsumarized_vector_db_namespace, get_vectordb_namespace, get_vectordb_document_name from cognitive_architecture.utils import (
from cognitive_architecture.shared.language_processing import translate_text, detect_language get_document_names,
generate_letter_uuid,
get_memory_name_by_doc_id,
get_unsumarized_vector_db_namespace,
get_vectordb_namespace,
get_vectordb_document_name,
)
from cognitive_architecture.shared.language_processing import (
translate_text,
detect_language,
)
from cognitive_architecture.classifiers.classifier import classify_user_input from cognitive_architecture.classifiers.classifier import classify_user_input
async def fetch_document_vectordb_namespace(session: AsyncSession, user_id: str, namespace_id:str, doc_id:str=None):
logging.info("user id is", user_id)
memory = await Memory.create_memory(user_id, session, namespace=namespace_id, memory_label=namespace_id)
async def fetch_document_vectordb_namespace(
session: AsyncSession, user_id: str, namespace_id: str, doc_id: str = None
):
logging.info("user id is", user_id)
memory = await Memory.create_memory(
user_id, session, namespace=namespace_id, memory_label=namespace_id
)
# Managing memory attributes # Managing memory attributes
existing_user = await Memory.check_existing_user(user_id, session) existing_user = await Memory.check_existing_user(user_id, session)
@ -66,15 +90,26 @@ async def fetch_document_vectordb_namespace(session: AsyncSession, user_id: str,
print(f"No attribute named in memory.") print(f"No attribute named in memory.")
print("Available memory classes:", await memory.list_memory_classes()) print("Available memory classes:", await memory.list_memory_classes())
result = await memory.dynamic_method_call(dynamic_memory_class, 'fetch_memories', result = await memory.dynamic_method_call(
observation="placeholder", search_type="summary_filter_by_object_name", params=doc_id) dynamic_memory_class,
"fetch_memories",
observation="placeholder",
search_type="summary_filter_by_object_name",
params=doc_id,
)
logging.info("Result is %s", str(result)) logging.info("Result is %s", str(result))
return result, namespace_id return result, namespace_id
async def load_documents_to_vectorstore(
async def load_documents_to_vectorstore(session: AsyncSession, user_id: str, content:str=None, job_id:str=None, loader_settings:dict=None, memory_type:str="PRIVATE"): session: AsyncSession,
user_id: str,
content: str = None,
job_id: str = None,
loader_settings: dict = None,
memory_type: str = "PRIVATE",
):
namespace_id = str(generate_letter_uuid()) + "_" + "SEMANTICMEMORY" namespace_id = str(generate_letter_uuid()) + "_" + "SEMANTICMEMORY"
namespace_class = namespace_id + "_class" namespace_class = namespace_id + "_class"
@ -96,12 +131,21 @@ async def load_documents_to_vectorstore(session: AsyncSession, user_id: str, con
operation_type="DATA_LOAD", operation_type="DATA_LOAD",
), ),
) )
memory = await Memory.create_memory(user_id, session, namespace=namespace_id, job_id=job_id, memory_label=namespace_id) memory = await Memory.create_memory(
user_id,
session,
namespace=namespace_id,
job_id=job_id,
memory_label=namespace_id,
)
if content is not None: if content is not None:
document_names = [content[:30]] document_names = [content[:30]]
if loader_settings is not None: if loader_settings is not None:
document_source = loader_settings.get("document_names") if isinstance(loader_settings.get("document_names"), document_source = (
list) else loader_settings.get("path", "None") loader_settings.get("document_names")
if isinstance(loader_settings.get("document_names"), list)
else loader_settings.get("path", "None")
)
logging.info("Document source is %s", document_source) logging.info("Document source is %s", document_source)
# try: # try:
document_names = get_document_names(document_source[0]) document_names = get_document_names(document_source[0])
@ -109,12 +153,21 @@ async def load_documents_to_vectorstore(session: AsyncSession, user_id: str, con
# except: # except:
# document_names = document_source # document_names = document_source
for doc in document_names: for doc in document_names:
from cognitive_architecture.shared.language_processing import translate_text, detect_language from cognitive_architecture.shared.language_processing import (
#translates doc titles to english translate_text,
detect_language,
)
# translates doc titles to english
if loader_settings is not None: if loader_settings is not None:
logging.info("Detecting language of document %s", doc) logging.info("Detecting language of document %s", doc)
loader_settings["single_document_path"]= loader_settings.get("path", "None")[0] +"/"+doc loader_settings["single_document_path"] = (
logging.info("Document path is %s", loader_settings.get("single_document_path", "None")) loader_settings.get("path", "None")[0] + "/" + doc
)
logging.info(
"Document path is %s",
loader_settings.get("single_document_path", "None"),
)
memory_category = loader_settings.get("memory_category", "PUBLIC") memory_category = loader_settings.get("memory_category", "PUBLIC")
if loader_settings is None: if loader_settings is None:
memory_category = "CUSTOM" memory_category = "CUSTOM"
@ -122,7 +175,7 @@ async def load_documents_to_vectorstore(session: AsyncSession, user_id: str, con
doc_ = doc.strip(".pdf").replace("-", " ") doc_ = doc.strip(".pdf").replace("-", " ")
doc_ = translate_text(doc_, "sr", "en") doc_ = translate_text(doc_, "sr", "en")
else: else:
doc_=doc doc_ = doc
doc_id = str(uuid.uuid4()) doc_id = str(uuid.uuid4())
logging.info("Document name is %s", doc_) logging.info("Document name is %s", doc_)
@ -131,17 +184,15 @@ async def load_documents_to_vectorstore(session: AsyncSession, user_id: str, con
DocsModel( DocsModel(
id=doc_id, id=doc_id,
operation_id=job_id, operation_id=job_id,
graph_summary= False, graph_summary=False,
memory_category= memory_category, memory_category=memory_category,
doc_name=doc_ doc_name=doc_,
) ),
) )
# Managing memory attributes # Managing memory attributes
existing_user = await Memory.check_existing_user(user_id, session) existing_user = await Memory.check_existing_user(user_id, session)
await memory.manage_memory_attributes(existing_user) await memory.manage_memory_attributes(existing_user)
params = { params = {"doc_id": doc_id}
"doc_id":doc_id
}
print("Namespace id is %s", namespace_id) print("Namespace id is %s", namespace_id)
await memory.add_dynamic_memory_class(namespace_id.lower(), namespace_id) await memory.add_dynamic_memory_class(namespace_id.lower(), namespace_id)
@ -157,13 +208,18 @@ async def load_documents_to_vectorstore(session: AsyncSession, user_id: str, con
print(f"No attribute named in memory.") print(f"No attribute named in memory.")
print("Available memory classes:", await memory.list_memory_classes()) print("Available memory classes:", await memory.list_memory_classes())
result = await memory.dynamic_method_call(dynamic_memory_class, 'add_memories', result = await memory.dynamic_method_call(
observation=content, params=params, loader_settings=loader_settings) dynamic_memory_class,
"add_memories",
observation=content,
params=params,
loader_settings=loader_settings,
)
await update_entity(session, Operation, job_id, "SUCCESS") await update_entity(session, Operation, job_id, "SUCCESS")
return 1 return 1
async def user_query_to_graph_db(session: AsyncSession, user_id: str, query_input: str):
async def user_query_to_graph_db(session: AsyncSession, user_id: str, query_input: str):
try: try:
new_user = User(id=user_id) new_user = User(id=user_id)
await add_entity(session, new_user) await add_entity(session, new_user)
@ -189,19 +245,32 @@ async def user_query_to_graph_db(session: AsyncSession, user_id: str, query_inpu
else: else:
translated_query = query_input translated_query = query_input
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username, password=config.graph_database_password) neo4j_graph_db = Neo4jGraphDB(
url=config.graph_database_url,
username=config.graph_database_username,
password=config.graph_database_password,
)
cypher_query = await neo4j_graph_db.generate_cypher_query_for_user_prompt_decomposition(user_id, translated_query) cypher_query = (
await neo4j_graph_db.generate_cypher_query_for_user_prompt_decomposition(
user_id, translated_query
)
)
result = neo4j_graph_db.query(cypher_query) result = neo4j_graph_db.query(cypher_query)
neo4j_graph_db.run_merge_query(user_id=user_id, memory_type="SemanticMemory", similarity_threshold=0.8) neo4j_graph_db.run_merge_query(
neo4j_graph_db.run_merge_query(user_id=user_id, memory_type="EpisodicMemory", similarity_threshold=0.8) user_id=user_id, memory_type="SemanticMemory", similarity_threshold=0.8
)
neo4j_graph_db.run_merge_query(
user_id=user_id, memory_type="EpisodicMemory", similarity_threshold=0.8
)
neo4j_graph_db.close() neo4j_graph_db.close()
await update_entity(session, Operation, job_id, "SUCCESS") await update_entity(session, Operation, job_id, "SUCCESS")
return result return result
# async def add_documents_to_graph_db(session: AsyncSession, user_id: Optional[str] = None, # async def add_documents_to_graph_db(session: AsyncSession, user_id: Optional[str] = None,
# document_memory_types: Optional[List[str]] = None): # document_memory_types: Optional[List[str]] = None):
# """ Add documents to a graph database, handling multiple memory types """ # """ Add documents to a graph database, handling multiple memory types """
@ -256,106 +325,159 @@ async def user_query_to_graph_db(session: AsyncSession, user_id: str, query_inpu
# return e # return e
async def add_documents_to_graph_db(session: AsyncSession, user_id: str= None, document_memory_types:list=None): async def add_documents_to_graph_db(
session: AsyncSession, user_id: str = None, document_memory_types: list = None
):
"""""" """"""
if document_memory_types is None: if document_memory_types is None:
document_memory_types = ['PUBLIC'] document_memory_types = ["PUBLIC"]
logging.info("Document memory types are", document_memory_types) logging.info("Document memory types are", document_memory_types)
try: try:
# await update_document_vectordb_namespace(postgres_session, user_id) # await update_document_vectordb_namespace(postgres_session, user_id)
memory_details, docs = await get_unsumarized_vector_db_namespace(session, user_id) memory_details, docs = await get_unsumarized_vector_db_namespace(
session, user_id
)
logging.info("Docs are", docs) logging.info("Docs are", docs)
memory_details= [detail for detail in memory_details if detail[1] in document_memory_types] memory_details = [
detail for detail in memory_details if detail[1] in document_memory_types
]
logging.info("Memory details", memory_details) logging.info("Memory details", memory_details)
for doc in docs: for doc in docs:
logging.info("Memory names are", memory_details) logging.info("Memory names are", memory_details)
doc_name, doc_id = doc doc_name, doc_id = doc
logging.info("Doc id is", doc_id) logging.info("Doc id is", doc_id)
try: try:
classification_content = await fetch_document_vectordb_namespace(session, user_id, memory_details[0][0], doc_id) classification_content = await fetch_document_vectordb_namespace(
retrieval_chunks = [item['text'] for item in session, user_id, memory_details[0][0], doc_id
classification_content[0]['data']['Get'][memory_details[0][0]]] )
retrieval_chunks = [
item["text"]
for item in classification_content[0]["data"]["Get"][
memory_details[0][0]
]
]
logging.info("Classification content is", classification_content) logging.info("Classification content is", classification_content)
except: except:
classification_content = "" classification_content = ""
retrieval_chunks = "" retrieval_chunks = ""
# retrieval_chunks = [item['text'] for item in classification_content[0]['data']['Get'][memory_details[0]]] # retrieval_chunks = [item['text'] for item in classification_content[0]['data']['Get'][memory_details[0]]]
# Concatenating the extracted text values # Concatenating the extracted text values
concatenated_retrievals = ' '.join(retrieval_chunks) concatenated_retrievals = " ".join(retrieval_chunks)
print(concatenated_retrievals) print(concatenated_retrievals)
logging.info("Retrieval chunks are", retrieval_chunks) logging.info("Retrieval chunks are", retrieval_chunks)
classification = await classify_documents(doc_name, document_id =doc_id, content=concatenated_retrievals) classification = await classify_documents(
doc_name, document_id=doc_id, content=concatenated_retrievals
)
logging.info("Classification is %s", str(classification)) logging.info("Classification is %s", str(classification))
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username, neo4j_graph_db = Neo4jGraphDB(
password=config.graph_database_password) url=config.graph_database_url,
if document_memory_types == ['PUBLIC']: username=config.graph_database_username,
await create_public_memory(user_id=user_id, labels=['sr'], topic="PublicMemory") password=config.graph_database_password,
ids = neo4j_graph_db.retrieve_node_id_for_memory_type(topic="PublicMemory") )
if document_memory_types == ["PUBLIC"]:
await create_public_memory(
user_id=user_id, labels=["sr"], topic="PublicMemory"
)
ids = neo4j_graph_db.retrieve_node_id_for_memory_type(
topic="PublicMemory"
)
neo4j_graph_db.close() neo4j_graph_db.close()
print(ids) print(ids)
else: else:
ids = neo4j_graph_db.retrieve_node_id_for_memory_type(topic="SemanticMemory") ids = neo4j_graph_db.retrieve_node_id_for_memory_type(
topic="SemanticMemory"
)
neo4j_graph_db.close() neo4j_graph_db.close()
print(ids) print(ids)
for id in ids: for id in ids:
print(id.get('memoryId')) print(id.get("memoryId"))
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username, neo4j_graph_db = Neo4jGraphDB(
password=config.graph_database_password) url=config.graph_database_url,
if document_memory_types == ['PUBLIC']: username=config.graph_database_username,
password=config.graph_database_password,
rs = neo4j_graph_db.create_document_node_cypher(classification, user_id, public_memory_id=id.get('memoryId')) )
if document_memory_types == ["PUBLIC"]:
rs = neo4j_graph_db.create_document_node_cypher(
classification, user_id, public_memory_id=id.get("memoryId")
)
neo4j_graph_db.close() neo4j_graph_db.close()
else: else:
rs = neo4j_graph_db.create_document_node_cypher(classification, user_id, memory_type='SemanticMemory') rs = neo4j_graph_db.create_document_node_cypher(
classification, user_id, memory_type="SemanticMemory"
)
neo4j_graph_db.close() neo4j_graph_db.close()
logging.info("Cypher query is %s", str(rs)) logging.info("Cypher query is %s", str(rs))
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username, neo4j_graph_db = Neo4jGraphDB(
password=config.graph_database_password) url=config.graph_database_url,
username=config.graph_database_username,
password=config.graph_database_password,
)
neo4j_graph_db.query(rs) neo4j_graph_db.query(rs)
neo4j_graph_db.close() neo4j_graph_db.close()
logging.info("WE GOT HERE") logging.info("WE GOT HERE")
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username, neo4j_graph_db = Neo4jGraphDB(
password=config.graph_database_password) url=config.graph_database_url,
username=config.graph_database_username,
password=config.graph_database_password,
)
if memory_details[0][1] == "PUBLIC": if memory_details[0][1] == "PUBLIC":
neo4j_graph_db.update_document_node_with_db_ids(
neo4j_graph_db.update_document_node_with_db_ids( vectordb_namespace=memory_details[0][0], vectordb_namespace=memory_details[0][0], document_id=doc_id
document_id=doc_id) )
neo4j_graph_db.close() neo4j_graph_db.close()
else: else:
neo4j_graph_db.update_document_node_with_db_ids( vectordb_namespace=memory_details[0][0], neo4j_graph_db.update_document_node_with_db_ids(
document_id=doc_id, user_id=user_id) vectordb_namespace=memory_details[0][0],
document_id=doc_id,
user_id=user_id,
)
neo4j_graph_db.close() neo4j_graph_db.close()
# await update_entity_graph_summary(session, DocsModel, doc_id, True) # await update_entity_graph_summary(session, DocsModel, doc_id, True)
except Exception as e: except Exception as e:
return e return e
class ResponseString(BaseModel): class ResponseString(BaseModel):
response: str = Field(default=None) # Defaulting to None or you can use a default string like "" response: str = Field(
default=None
) # Defaulting to None or you can use a default string like ""
quotation: str = Field(default=None) # Same here quotation: str = Field(default=None) # Same here
# #
def generate_graph(input) -> ResponseString: def generate_graph(input) -> ResponseString:
out = aclient.chat.completions.create( out = aclient.chat.completions.create(
model="gpt-4-1106-preview", model="gpt-4-1106-preview",
messages=[ messages=[
{ {
"role": "user", "role": "user",
"content": f"""Use the given context to answer query and use help of associated context: {input}. """, "content": f"""Use the given context to answer query and use help of associated context: {input}. """,
}, },
{ "role":"system", "content": """You are a top-tier algorithm {
"role": "system",
"content": """You are a top-tier algorithm
designed for using context summaries based on cognitive psychology to answer user queries, and provide a simple response. designed for using context summaries based on cognitive psychology to answer user queries, and provide a simple response.
Do not mention anything explicit about cognitive architecture, but use the context to answer the query. If you are using a document, reference document metadata field"""} Do not mention anything explicit about cognitive architecture, but use the context to answer the query. If you are using a document, reference document metadata field""",
},
], ],
response_model=ResponseString, response_model=ResponseString,
) )
return out return out
async def user_context_enrichment(session, user_id:str, query:str, generative_response:bool=False, memory_type:str=None)->str:
async def user_context_enrichment(
session,
user_id: str,
query: str,
generative_response: bool = False,
memory_type: str = None,
) -> str:
""" """
Asynchronously enriches the user context by integrating various memory systems and document classifications. Asynchronously enriches the user context by integrating various memory systems and document classifications.
@ -387,31 +509,38 @@ async def user_context_enrichment(session, user_id:str, query:str, generative_re
enriched_context = await user_context_enrichment(session, "user123", "How does cognitive architecture work?") enriched_context = await user_context_enrichment(session, "user123", "How does cognitive architecture work?")
``` ```
""" """
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username, neo4j_graph_db = Neo4jGraphDB(
password=config.graph_database_password) url=config.graph_database_url,
username=config.graph_database_username,
password=config.graph_database_password,
)
# await user_query_to_graph_db(session, user_id, query) # await user_query_to_graph_db(session, user_id, query)
semantic_mem = neo4j_graph_db.retrieve_semantic_memory(user_id=user_id) semantic_mem = neo4j_graph_db.retrieve_semantic_memory(user_id=user_id)
neo4j_graph_db.close() neo4j_graph_db.close()
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username, neo4j_graph_db = Neo4jGraphDB(
password=config.graph_database_password) url=config.graph_database_url,
username=config.graph_database_username,
password=config.graph_database_password,
)
episodic_mem = neo4j_graph_db.retrieve_episodic_memory(user_id=user_id) episodic_mem = neo4j_graph_db.retrieve_episodic_memory(user_id=user_id)
neo4j_graph_db.close() neo4j_graph_db.close()
# public_mem = neo4j_graph_db.retrieve_public_memory(user_id=user_id) # public_mem = neo4j_graph_db.retrieve_public_memory(user_id=user_id)
if detect_language(query) != "en": if detect_language(query) != "en":
query = translate_text(query, "sr", "en") query = translate_text(query, "sr", "en")
logging.info("Translated query is %s", str(query)) logging.info("Translated query is %s", str(query))
if memory_type=='PublicMemory': if memory_type == "PublicMemory":
neo4j_graph_db = Neo4jGraphDB(
url=config.graph_database_url,
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username, username=config.graph_database_username,
password=config.graph_database_password) password=config.graph_database_password,
summaries = await neo4j_graph_db.get_memory_linked_document_summaries(user_id=user_id, memory_type=memory_type) )
summaries = await neo4j_graph_db.get_memory_linked_document_summaries(
user_id=user_id, memory_type=memory_type
)
neo4j_graph_db.close() neo4j_graph_db.close()
logging.info("Summaries are is %s", summaries) logging.info("Summaries are is %s", summaries)
# logging.info("Context from graphdb is %s", context) # logging.info("Context from graphdb is %s", context)
@ -424,7 +553,9 @@ async def user_context_enrichment(session, user_id:str, query:str, generative_re
relevant_summary_id = None relevant_summary_id = None
for _ in range(max_attempts): for _ in range(max_attempts):
relevant_summary_id = await classify_call( query= query, document_summaries=str(summaries)) relevant_summary_id = await classify_call(
query=query, document_summaries=str(summaries)
)
logging.info("Relevant summary id is %s", relevant_summary_id) logging.info("Relevant summary id is %s", relevant_summary_id)
@ -432,22 +563,32 @@ async def user_context_enrichment(session, user_id:str, query:str, generative_re
break break
# logging.info("Relevant categories after the classifier are %s", relevant_categories) # logging.info("Relevant categories after the classifier are %s", relevant_categories)
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username, neo4j_graph_db = Neo4jGraphDB(
password=config.graph_database_password) url=config.graph_database_url,
postgres_id = await neo4j_graph_db.get_memory_linked_document_ids(user_id, summary_id = relevant_summary_id, memory_type=memory_type) username=config.graph_database_username,
password=config.graph_database_password,
)
postgres_id = await neo4j_graph_db.get_memory_linked_document_ids(
user_id, summary_id=relevant_summary_id, memory_type=memory_type
)
neo4j_graph_db.close() neo4j_graph_db.close()
# postgres_id = neo4j_graph_db.query(get_doc_ids) # postgres_id = neo4j_graph_db.query(get_doc_ids)
logging.info("Postgres ids are %s", postgres_id) logging.info("Postgres ids are %s", postgres_id)
namespace_id = await get_memory_name_by_doc_id(session, postgres_id[0]) namespace_id = await get_memory_name_by_doc_id(session, postgres_id[0])
logging.info("Namespace ids are %s", namespace_id) logging.info("Namespace ids are %s", namespace_id)
params= {"doc_id":postgres_id[0]} params = {"doc_id": postgres_id[0]}
namespace_id = namespace_id[0] namespace_id = namespace_id[0]
namespace_class = namespace_id + "_class" namespace_class = namespace_id + "_class"
if memory_type =='PublicMemory': if memory_type == "PublicMemory":
user_id = 'system_user' user_id = "system_user"
memory = await Memory.create_memory(user_id, session, namespace=namespace_id, job_id="23232", memory = await Memory.create_memory(
memory_label=namespace_id) user_id,
session,
namespace=namespace_id,
job_id="23232",
memory_label=namespace_id,
)
existing_user = await Memory.check_existing_user(user_id, session) existing_user = await Memory.check_existing_user(user_id, session)
print("here is the existing user", existing_user) print("here is the existing user", existing_user)
@ -468,17 +609,26 @@ async def user_context_enrichment(session, user_id:str, query:str, generative_re
print(f"No attribute named in memory.") print(f"No attribute named in memory.")
print("Available memory classes:", await memory.list_memory_classes()) print("Available memory classes:", await memory.list_memory_classes())
results = await memory.dynamic_method_call(dynamic_memory_class, 'fetch_memories', results = await memory.dynamic_method_call(
observation=query, params=postgres_id[0], search_type="summary_filter_by_object_name") dynamic_memory_class,
"fetch_memories",
observation=query,
params=postgres_id[0],
search_type="summary_filter_by_object_name",
)
logging.info("Result is %s", str(results)) logging.info("Result is %s", str(results))
search_context = "" search_context = ""
for result in results['data']['Get'][namespace_id]: for result in results["data"]["Get"][namespace_id]:
# Assuming 'result' is a dictionary and has keys like 'source', 'text' # Assuming 'result' is a dictionary and has keys like 'source', 'text'
source = result['source'].replace('-', ' ').replace('.pdf', '').replace('.data/', '') source = (
text = result['text'] result["source"]
.replace("-", " ")
.replace(".pdf", "")
.replace(".data/", "")
)
text = result["text"]
search_context += f"Document source: {source}, Document text: {text} \n" search_context += f"Document source: {source}, Document text: {text} \n"
else: else:
@ -502,7 +652,9 @@ async def user_context_enrichment(session, user_id:str, query:str, generative_re
return generative_result.model_dump_json() return generative_result.model_dump_json()
async def create_public_memory(user_id: str=None, labels:list=None, topic:str=None) -> Optional[int]: async def create_public_memory(
user_id: str = None, labels: list = None, topic: str = None
) -> Optional[int]:
""" """
Create a public memory node associated with a user in a Neo4j graph database. Create a public memory node associated with a user in a Neo4j graph database.
If Public Memory exists, it will return the id of the memory. If Public Memory exists, it will return the id of the memory.
@ -521,16 +673,17 @@ async def create_public_memory(user_id: str=None, labels:list=None, topic:str=No
""" """
# Validate input parameters # Validate input parameters
if not labels: if not labels:
labels = ['sr'] # Labels for the memory node labels = ["sr"] # Labels for the memory node
if not topic: if not topic:
topic = "PublicMemory" topic = "PublicMemory"
try: try:
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, neo4j_graph_db = Neo4jGraphDB(
username=config.graph_database_username, url=config.graph_database_url,
password=config.graph_database_password) username=config.graph_database_username,
password=config.graph_database_password,
)
# Assuming the topic for public memory is predefined, e.g., "PublicMemory" # Assuming the topic for public memory is predefined, e.g., "PublicMemory"
# Create the memory node # Create the memory node
@ -541,7 +694,10 @@ async def create_public_memory(user_id: str=None, labels:list=None, topic:str=No
logging.error(f"Error creating public memory node: {e}") logging.error(f"Error creating public memory node: {e}")
return None return None
async def attach_user_to_memory(user_id: str=None, labels:list=None, topic:str=None) -> Optional[int]:
async def attach_user_to_memory(
user_id: str = None, labels: list = None, topic: str = None
) -> Optional[int]:
""" """
Link user to public memory Link user to public memory
@ -560,33 +716,41 @@ async def attach_user_to_memory(user_id: str=None, labels:list=None, topic:str=N
if not user_id: if not user_id:
raise ValueError("User ID is required.") raise ValueError("User ID is required.")
if not labels: if not labels:
labels = ['sr'] # Labels for the memory node labels = ["sr"] # Labels for the memory node
if not topic: if not topic:
topic = "PublicMemory" topic = "PublicMemory"
try: try:
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, neo4j_graph_db = Neo4jGraphDB(
username=config.graph_database_username, url=config.graph_database_url,
password=config.graph_database_password) username=config.graph_database_username,
password=config.graph_database_password,
)
# Assuming the topic for public memory is predefined, e.g., "PublicMemory" # Assuming the topic for public memory is predefined, e.g., "PublicMemory"
ids = neo4j_graph_db.retrieve_node_id_for_memory_type(topic=topic) ids = neo4j_graph_db.retrieve_node_id_for_memory_type(topic=topic)
neo4j_graph_db.close() neo4j_graph_db.close()
for id in ids: for id in ids:
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, neo4j_graph_db = Neo4jGraphDB(
username=config.graph_database_username, url=config.graph_database_url,
password=config.graph_database_password) username=config.graph_database_username,
linked_memory = neo4j_graph_db.link_public_memory_to_user(memory_id=id.get('memoryId'), user_id=user_id) password=config.graph_database_password,
)
linked_memory = neo4j_graph_db.link_public_memory_to_user(
memory_id=id.get("memoryId"), user_id=user_id
)
neo4j_graph_db.close() neo4j_graph_db.close()
return 1 return 1
except Neo4jError as e: except Neo4jError as e:
logging.error(f"Error creating public memory node: {e}") logging.error(f"Error creating public memory node: {e}")
return None return None
async def unlink_user_from_memory(user_id: str=None, labels:list=None, topic:str=None) -> Optional[int]:
async def unlink_user_from_memory(
user_id: str = None, labels: list = None, topic: str = None
) -> Optional[int]:
""" """
Unlink user from memory Unlink user from memory
@ -604,34 +768,39 @@ async def unlink_user_from_memory(user_id: str=None, labels:list=None, topic:str
if not user_id: if not user_id:
raise ValueError("User ID is required.") raise ValueError("User ID is required.")
if not labels: if not labels:
labels = ['sr'] # Labels for the memory node labels = ["sr"] # Labels for the memory node
if not topic: if not topic:
topic = "PublicMemory" topic = "PublicMemory"
try: try:
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, neo4j_graph_db = Neo4jGraphDB(
username=config.graph_database_username, url=config.graph_database_url,
password=config.graph_database_password) username=config.graph_database_username,
password=config.graph_database_password,
)
# Assuming the topic for public memory is predefined, e.g., "PublicMemory" # Assuming the topic for public memory is predefined, e.g., "PublicMemory"
ids = neo4j_graph_db.retrieve_node_id_for_memory_type(topic=topic) ids = neo4j_graph_db.retrieve_node_id_for_memory_type(topic=topic)
neo4j_graph_db.close() neo4j_graph_db.close()
for id in ids: for id in ids:
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, neo4j_graph_db = Neo4jGraphDB(
username=config.graph_database_username, url=config.graph_database_url,
password=config.graph_database_password) username=config.graph_database_username,
linked_memory = neo4j_graph_db.unlink_memory_from_user(memory_id=id.get('memoryId'), user_id=user_id) password=config.graph_database_password,
)
linked_memory = neo4j_graph_db.unlink_memory_from_user(
memory_id=id.get("memoryId"), user_id=user_id
)
neo4j_graph_db.close() neo4j_graph_db.close()
return 1 return 1
except Neo4jError as e: except Neo4jError as e:
logging.error(f"Error creating public memory node: {e}") logging.error(f"Error creating public memory node: {e}")
return None return None
async def relevance_feedback(query: str, input_type: str):
async def relevance_feedback(query: str, input_type: str):
max_attempts = 6 max_attempts = 6
result = None result = None
for attempt in range(1, max_attempts + 1): for attempt in range(1, max_attempts + 1):
@ -641,7 +810,6 @@ async def relevance_feedback(query: str, input_type: str):
return result return result
async def main(): async def main():
user_id = "user_test_1_1" user_id = "user_test_1_1"
@ -649,8 +817,6 @@ async def main():
# await update_entity(session, DocsModel, "8cd9a022-5a7a-4af5-815a-f988415536ae", True) # await update_entity(session, DocsModel, "8cd9a022-5a7a-4af5-815a-f988415536ae", True)
# output = await get_unsumarized_vector_db_namespace(session, user_id) # output = await get_unsumarized_vector_db_namespace(session, user_id)
class GraphQLQuery(BaseModel): class GraphQLQuery(BaseModel):
query: str query: str
@ -713,7 +879,7 @@ async def main():
# print(out) # print(out)
# load_doc_to_graph = await add_documents_to_graph_db(session, user_id) # load_doc_to_graph = await add_documents_to_graph_db(session, user_id)
# print(load_doc_to_graph) # print(load_doc_to_graph)
user_id = 'test_user' user_id = "test_user"
# loader_settings = { # loader_settings = {
# "format": "PDF", # "format": "PDF",
# "source": "DEVICE", # "source": "DEVICE",
@ -723,10 +889,15 @@ async def main():
# await create_public_memory(user_id=user_id, labels=['sr'], topic="PublicMemory") # await create_public_memory(user_id=user_id, labels=['sr'], topic="PublicMemory")
# await add_documents_to_graph_db(session, user_id) # await add_documents_to_graph_db(session, user_id)
# #
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username, neo4j_graph_db = Neo4jGraphDB(
password=config.graph_database_password) url=config.graph_database_url,
username=config.graph_database_username,
password=config.graph_database_password,
)
out = neo4j_graph_db.run_merge_query(user_id = user_id, memory_type="SemanticMemory", similarity_threshold=0.5) out = neo4j_graph_db.run_merge_query(
user_id=user_id, memory_type="SemanticMemory", similarity_threshold=0.5
)
bb = neo4j_graph_db.query(out) bb = neo4j_graph_db.query(out)
print(bb) print(bb)
@ -798,6 +969,4 @@ async def main():
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio
asyncio.run(main()) asyncio.run(main())

201
poetry.lock generated
View file

@ -167,6 +167,20 @@ doc = ["Sphinx", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-
test = ["anyio[trio]", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=4)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] test = ["anyio[trio]", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=4)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"]
trio = ["trio (<0.22)"] trio = ["trio (<0.22)"]
[[package]]
name = "astroid"
version = "3.0.3"
description = "An abstract syntax tree for Python with inference support."
optional = false
python-versions = ">=3.8.0"
files = [
{file = "astroid-3.0.3-py3-none-any.whl", hash = "sha256:92fcf218b89f449cdf9f7b39a269f8d5d617b27be68434912e11e79203963a17"},
{file = "astroid-3.0.3.tar.gz", hash = "sha256:4148645659b08b70d72460ed1921158027a9e53ae8b7234149b1400eddacbb93"},
]
[package.dependencies]
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""}
[[package]] [[package]]
name = "astunparse" name = "astunparse"
version = "1.6.3" version = "1.6.3"
@ -520,6 +534,17 @@ files = [
[package.dependencies] [package.dependencies]
pycparser = "*" pycparser = "*"
[[package]]
name = "cfgv"
version = "3.4.0"
description = "Validate configuration and produce human readable error messages."
optional = false
python-versions = ">=3.8"
files = [
{file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"},
{file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"},
]
[[package]] [[package]]
name = "chardet" name = "chardet"
version = "5.2.0" version = "5.2.0"
@ -1027,6 +1052,32 @@ files = [
[package.dependencies] [package.dependencies]
packaging = "*" packaging = "*"
[[package]]
name = "dill"
version = "0.3.8"
description = "serialize all of Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"},
{file = "dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"},
]
[package.extras]
graph = ["objgraph (>=1.7.2)"]
profile = ["gprof2dot (>=2022.7.29)"]
[[package]]
name = "distlib"
version = "0.3.8"
description = "Distribution utilities"
optional = false
python-versions = "*"
files = [
{file = "distlib-0.3.8-py2.py3-none-any.whl", hash = "sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784"},
{file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"},
]
[[package]] [[package]]
name = "distro" name = "distro"
version = "1.9.0" version = "1.9.0"
@ -1962,6 +2013,20 @@ files = [
[package.extras] [package.extras]
tests = ["freezegun", "pytest", "pytest-cov"] tests = ["freezegun", "pytest", "pytest-cov"]
[[package]]
name = "identify"
version = "2.5.34"
description = "File identification library for Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "identify-2.5.34-py2.py3-none-any.whl", hash = "sha256:a4316013779e433d08b96e5eabb7f641e6c7942e4ab5d4c509ebd2e7a8994aed"},
{file = "identify-2.5.34.tar.gz", hash = "sha256:ee17bc9d499899bc9eaec1ac7bf2dc9eedd480db9d88b96d123d3b64a9d34f5d"},
]
[package.extras]
license = ["ukkonen"]
[[package]] [[package]]
name = "idna" name = "idna"
version = "3.6" version = "3.6"
@ -2037,6 +2102,20 @@ files = [
{file = "iso639-0.1.4.tar.gz", hash = "sha256:88b70cf6c64ee9c2c2972292818c8beb32db9ea6f4de1f8471a9b081a3d92e98"}, {file = "iso639-0.1.4.tar.gz", hash = "sha256:88b70cf6c64ee9c2c2972292818c8beb32db9ea6f4de1f8471a9b081a3d92e98"},
] ]
[[package]]
name = "isort"
version = "5.13.2"
description = "A Python utility / library to sort Python imports."
optional = false
python-versions = ">=3.8.0"
files = [
{file = "isort-5.13.2-py3-none-any.whl", hash = "sha256:8ca5e72a8d85860d5a3fa69b8745237f2939afe12dbf656afbcb47fe72d947a6"},
{file = "isort-5.13.2.tar.gz", hash = "sha256:48fdfcb9face5d58a4f6dde2e72a1fb8dcaf8ab26f95ab49fab84c2ddefb0109"},
]
[package.extras]
colors = ["colorama (>=0.4.6)"]
[[package]] [[package]]
name = "itsdangerous" name = "itsdangerous"
version = "2.1.2" version = "2.1.2"
@ -2760,6 +2839,17 @@ pillow = ">=8"
pyparsing = ">=2.3.1" pyparsing = ">=2.3.1"
python-dateutil = ">=2.7" python-dateutil = ">=2.7"
[[package]]
name = "mccabe"
version = "0.7.0"
description = "McCabe checker, plugin for flake8"
optional = false
python-versions = ">=3.6"
files = [
{file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"},
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
]
[[package]] [[package]]
name = "mdurl" name = "mdurl"
version = "0.1.2" version = "0.1.2"
@ -2996,6 +3086,20 @@ plot = ["matplotlib"]
tgrep = ["pyparsing"] tgrep = ["pyparsing"]
twitter = ["twython"] twitter = ["twython"]
[[package]]
name = "nodeenv"
version = "1.8.0"
description = "Node.js virtual environment builder"
optional = false
python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*"
files = [
{file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"},
{file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"},
]
[package.dependencies]
setuptools = "*"
[[package]] [[package]]
name = "numpy" name = "numpy"
version = "1.26.2" version = "1.26.2"
@ -3317,8 +3421,8 @@ files = [
[package.dependencies] [package.dependencies]
numpy = [ numpy = [
{version = ">=1.22.4,<2", markers = "python_version < \"3.11\""}, {version = ">=1.22.4,<2", markers = "python_version < \"3.11\""},
{version = ">=1.23.2,<2", markers = "python_version == \"3.11\""},
{version = ">=1.26.0,<2", markers = "python_version >= \"3.12\""}, {version = ">=1.26.0,<2", markers = "python_version >= \"3.12\""},
{version = ">=1.23.2,<2", markers = "python_version == \"3.11\""},
] ]
python-dateutil = ">=2.8.2" python-dateutil = ">=2.8.2"
pytz = ">=2020.1" pytz = ">=2020.1"
@ -3618,6 +3722,21 @@ urllib3 = ">=1.21.1"
[package.extras] [package.extras]
grpc = ["googleapis-common-protos (>=1.53.0)", "grpc-gateway-protoc-gen-openapiv2 (==0.1.0)", "grpcio (>=1.44.0)", "lz4 (>=3.1.3)", "protobuf (>=3.19.5,<3.20.0)"] grpc = ["googleapis-common-protos (>=1.53.0)", "grpc-gateway-protoc-gen-openapiv2 (==0.1.0)", "grpcio (>=1.44.0)", "lz4 (>=3.1.3)", "protobuf (>=3.19.5,<3.20.0)"]
[[package]]
name = "platformdirs"
version = "4.2.0"
description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
optional = false
python-versions = ">=3.8"
files = [
{file = "platformdirs-4.2.0-py3-none-any.whl", hash = "sha256:0614df2a2f37e1a662acbd8e2b25b92ccf8632929bc6d43467e17fe89c75e068"},
{file = "platformdirs-4.2.0.tar.gz", hash = "sha256:ef0cc731df711022c174543cb70a9b5bd22e5a9337c8624ef2c2ceb8ddad8768"},
]
[package.extras]
docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"]
test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"]
[[package]] [[package]]
name = "plotly" name = "plotly"
version = "5.18.0" version = "5.18.0"
@ -3663,6 +3782,24 @@ docs = ["sphinx (>=1.7.1)"]
redis = ["redis"] redis = ["redis"]
tests = ["pytest (>=5.4.1)", "pytest-cov (>=2.8.1)", "pytest-mypy (>=0.8.0)", "pytest-timeout (>=2.1.0)", "redis", "sphinx (>=6.0.0)", "types-redis"] tests = ["pytest (>=5.4.1)", "pytest-cov (>=2.8.1)", "pytest-mypy (>=0.8.0)", "pytest-timeout (>=2.1.0)", "redis", "sphinx (>=6.0.0)", "types-redis"]
[[package]]
name = "pre-commit"
version = "3.6.1"
description = "A framework for managing and maintaining multi-language pre-commit hooks."
optional = false
python-versions = ">=3.9"
files = [
{file = "pre_commit-3.6.1-py2.py3-none-any.whl", hash = "sha256:9fe989afcf095d2c4796ce7c553cf28d4d4a9b9346de3cda079bcf40748454a4"},
{file = "pre_commit-3.6.1.tar.gz", hash = "sha256:c90961d8aa706f75d60935aba09469a6b0bcb8345f127c3fbee4bdc5f114cf4b"},
]
[package.dependencies]
cfgv = ">=2.0.0"
identify = ">=1.0.0"
nodeenv = ">=0.11.1"
pyyaml = ">=5.1"
virtualenv = ">=20.10.0"
[[package]] [[package]]
name = "preshed" name = "preshed"
version = "3.0.9" version = "3.0.9"
@ -4076,6 +4213,35 @@ benchmarks = ["pytest-benchmark"]
tests = ["datasets", "duckdb", "ml_dtypes", "pandas", "pillow", "polars[pandas,pyarrow]", "pytest", "tensorflow", "tqdm"] tests = ["datasets", "duckdb", "ml_dtypes", "pandas", "pillow", "polars[pandas,pyarrow]", "pytest", "tensorflow", "tqdm"]
torch = ["torch"] torch = ["torch"]
[[package]]
name = "pylint"
version = "3.0.3"
description = "python code static checker"
optional = false
python-versions = ">=3.8.0"
files = [
{file = "pylint-3.0.3-py3-none-any.whl", hash = "sha256:7a1585285aefc5165db81083c3e06363a27448f6b467b3b0f30dbd0ac1f73810"},
{file = "pylint-3.0.3.tar.gz", hash = "sha256:58c2398b0301e049609a8429789ec6edf3aabe9b6c5fec916acd18639c16de8b"},
]
[package.dependencies]
astroid = ">=3.0.1,<=3.1.0-dev0"
colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""}
dill = [
{version = ">=0.2", markers = "python_version < \"3.11\""},
{version = ">=0.3.7", markers = "python_version >= \"3.12\""},
{version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
]
isort = ">=4.2.5,<5.13.0 || >5.13.0,<6"
mccabe = ">=0.6,<0.8"
platformdirs = ">=2.2.0"
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
tomlkit = ">=0.10.1"
[package.extras]
spelling = ["pyenchant (>=3.2,<4.0)"]
testutils = ["gitpython (>3)"]
[[package]] [[package]]
name = "pymupdf" name = "pymupdf"
version = "1.23.8" version = "1.23.8"
@ -5790,6 +5956,17 @@ dev = ["tokenizers[testing]"]
docs = ["setuptools_rust", "sphinx", "sphinx_rtd_theme"] docs = ["setuptools_rust", "sphinx", "sphinx_rtd_theme"]
testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"]
[[package]]
name = "tomli"
version = "2.0.1"
description = "A lil' TOML parser"
optional = false
python-versions = ">=3.7"
files = [
{file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
]
[[package]] [[package]]
name = "tomlkit" name = "tomlkit"
version = "0.12.3" version = "0.12.3"
@ -6213,6 +6390,26 @@ testing = ["pytest (>=7.4.0)"]
tooling = ["black (>=23.7.0)", "pyright (>=1.1.325)", "ruff (>=0.0.287)"] tooling = ["black (>=23.7.0)", "pyright (>=1.1.325)", "ruff (>=0.0.287)"]
tooling-extras = ["pyaml (>=23.7.0)", "pypandoc-binary (>=1.11)", "pytest (>=7.4.0)"] tooling-extras = ["pyaml (>=23.7.0)", "pypandoc-binary (>=1.11)", "pytest (>=7.4.0)"]
[[package]]
name = "virtualenv"
version = "20.25.0"
description = "Virtual Python Environment builder"
optional = false
python-versions = ">=3.7"
files = [
{file = "virtualenv-20.25.0-py3-none-any.whl", hash = "sha256:4238949c5ffe6876362d9c0180fc6c3a824a7b12b80604eeb8085f2ed7460de3"},
{file = "virtualenv-20.25.0.tar.gz", hash = "sha256:bf51c0d9c7dd63ea8e44086fa1e4fb1093a31e963b86959257378aef020e1f1b"},
]
[package.dependencies]
distlib = ">=0.3.7,<1"
filelock = ">=3.12.2,<4"
platformdirs = ">=3.9.1,<5"
[package.extras]
docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"]
test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"]
[[package]] [[package]]
name = "wasabi" name = "wasabi"
version = "1.1.2" version = "1.1.2"
@ -6513,4 +6710,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "671f878d3fc3b864ac68ef553f3f48ac247bfee0ae60540f260fea7fda727e86" content-hash = "d484dd5ab17563c78699c17296b56155a967f10c432f715a96efbd07e15b34e1"

View file

@ -62,6 +62,7 @@ iso639 = "^0.1.4"
debugpy = "^1.8.0" debugpy = "^1.8.0"
lancedb = "^0.5.5" lancedb = "^0.5.5"
pyarrow = "^15.0.0" pyarrow = "^15.0.0"
pylint = "^3.0.3"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]