cognee/cognitive_architecture/database/vectordb/cognee_manager.py

151 lines
No EOL
5.4 KiB
Python

import json
import os
from contextlib import ExitStack

import requests

from .embeddings import Embeddings
from .response import Response
from .vector_db import VectorDB


class CogneeManager:
    """HTTP client for the Cognee embedding / job service.

    Holds the embedding and vector-DB configuration that is sent as form data,
    the API keys that are sent as request headers, and the service base URL.
    Every request method returns a ``Response`` built from the server reply.
    """

    def __init__(self, embeddings: Embeddings = None,
                 vector_db: VectorDB = None,
                 vector_db_key: str = None,
                 embedding_api_key: str = None,
                 webhook_url: str = None,
                 lines_per_batch: int = 1000,
                 webhook_key: str = None,
                 document_id: str = None,
                 chunk_validation_url: str = None,
                 # NOTE(review): "test123" is a hard-coded placeholder
                 # credential; callers should pass a real key explicitly.
                 internal_api_key: str = "test123",
                 base_url="http://localhost:8000"):
        """Store the client configuration.

        Args:
            embeddings: Embedding configuration; a default ``Embeddings()`` is
                created when omitted.
            vector_db: Vector-DB configuration; a default ``VectorDB()`` is
                created when omitted.
            vector_db_key: Sent as the ``X-VectorDB-Key`` header.
            embedding_api_key: Sent as the ``X-EmbeddingAPI-Key`` header.
            webhook_url: Sent as the ``WebhookURL`` form field.
            lines_per_batch: Sent as the ``LinesPerBatch`` form field.
            webhook_key: Sent as the ``X-Webhook-Key`` header.
            document_id: Sent as the ``DocumentID`` form field.
            chunk_validation_url: Sent as the ``ChunkValidationURL`` form field.
            internal_api_key: Sent as the ``Authorization`` header.
            base_url: Default service root used when a method is not given one.
        """
        self.embeddings = embeddings if embeddings is not None else Embeddings()
        self.vector_db = vector_db if vector_db is not None else VectorDB()
        self.webhook_url = webhook_url
        self.lines_per_batch = lines_per_batch
        self.webhook_key = webhook_key
        self.document_id = document_id
        self.chunk_validation_url = chunk_validation_url
        self.vector_db_key = vector_db_key
        # Attribute name kept as ``embeddings_api_key`` for backward
        # compatibility, although the parameter is ``embedding_api_key``.
        self.embeddings_api_key = embedding_api_key
        self.internal_api_key = internal_api_key
        self.base_url = base_url

    def serialize(self):
        """Return the configuration as form fields, omitting ``None`` values.

        Embedding and vector-DB metadata are JSON-encoded strings of their
        objects' ``serialize()`` output.
        """
        data = {
            'EmbeddingsMetadata': json.dumps(self.embeddings.serialize()),
            'VectorDBMetadata': json.dumps(self.vector_db.serialize()),
            'WebhookURL': self.webhook_url,
            'LinesPerBatch': self.lines_per_batch,
            'DocumentID': self.document_id,
            'ChunkValidationURL': self.chunk_validation_url,
        }
        return {k: v for k, v in data.items() if v is not None}

    def _endpoint(self, path, base_url=None):
        """Append *path* to *base_url* if it is truthy, else to ``self.base_url``."""
        return (base_url or self.base_url) + path

    @staticmethod
    def _parse_response(response):
        """Convert an HTTP reply into a ``Response``.

        A 500 reply, or any reply whose body is not valid JSON, yields a
        ``Response`` whose ``error`` is the raw body text. Otherwise the decoded
        JSON is handed to ``Response.from_json``; for 4xx replies the server's
        ``error`` field is printed first.
        """
        if response.status_code == 500:
            print(response.text)
            return Response(error=response.text, status_code=response.status_code)
        try:
            response_json = response.json()
        except ValueError:
            # Body is not JSON (e.g. an HTML error page from a proxy on a 502).
            # requests' JSON decode errors subclass ValueError.
            print(response.text)
            return Response(error=response.text, status_code=response.status_code)
        if 400 <= response.status_code < 500:
            # Use .get so a 4xx body without an 'error' key does not raise.
            if isinstance(response_json, dict):
                error = response_json.get('error')
            else:
                error = response_json
            print(f"Error: {error}")
        return Response.from_json(response_json, response.status_code)

    def upload(self, file_paths: list[str], base_url=None):
        """POST the given files to ``/jobs`` as multipart ``file`` parts.

        Args:
            file_paths: Local paths of the files to upload.
            base_url: Overrides ``self.base_url`` for this call when truthy.

        Returns:
            A ``Response`` describing the server reply.
        """
        url = self._endpoint("/jobs", base_url)
        data = self.serialize()
        headers = self.generate_headers()
        # ExitStack closes every opened file once the request is done, even if
        # a later open() or the POST itself raises.
        with ExitStack() as stack:
            multipart_form_data = [
                ('file', (os.path.basename(filepath),
                          stack.enter_context(open(filepath, 'rb')),
                          'application/octet-stream'))
                for filepath in file_paths
            ]
            print(f"embedding {len(file_paths)} documents at {url}")
            response = requests.post(url, files=multipart_form_data, headers=headers,
                                     stream=True, data=data)
        return self._parse_response(response)

    def get_job_statuses(self, job_ids: list[int], base_url=None):
        """POST ``{'JobIDs': job_ids}`` as JSON to ``/jobs/status``.

        Only the ``Authorization`` header is sent.
        Returns a ``Response`` describing the server reply.
        """
        url = self._endpoint("/jobs/status", base_url)
        headers = {
            "Authorization": self.internal_api_key,
        }
        data = {
            'JobIDs': job_ids
        }
        print(f"retrieving job statuses for {len(job_ids)} jobs at {url}")
        response = requests.post(url, headers=headers, json=data)
        return self._parse_response(response)

    def embed(self, filepath, base_url=None):
        """POST a single file to ``/embed`` as the ``SourceData`` part.

        The configuration from ``serialize()`` is sent as form data and the
        keys from ``generate_headers()`` as headers.
        Returns a ``Response`` describing the server reply.
        """
        url = self._endpoint("/embed", base_url)
        data = self.serialize()
        headers = self.generate_headers()
        # The with-block closes the file after the request is sent.
        with open(filepath, 'rb') as source:
            files = {
                'SourceData': source
            }
            print(f"embedding document at file path {filepath} at {url}")
            response = requests.post(url, headers=headers, data=data, files=files)
        return self._parse_response(response)

    def get_job_status(self, job_id, base_url=None):
        """GET ``/jobs/<job_id>/status``, sending only the ``Authorization`` header.

        Returns a ``Response`` describing the server reply.
        """
        url = self._endpoint("/jobs/" + str(job_id) + "/status", base_url)
        headers = {
            "Authorization": self.internal_api_key,
        }
        print(f"retrieving job status for job {job_id} at {url}")
        response = requests.get(url, headers=headers)
        return self._parse_response(response)

    def generate_headers(self):
        """Return the auth/key headers, omitting any whose value is ``None``."""
        headers = {
            "Authorization": self.internal_api_key,
            "X-EmbeddingAPI-Key": self.embeddings_api_key,
            "X-VectorDB-Key": self.vector_db_key,
            "X-Webhook-Key": self.webhook_key
        }
        return {k: v for k, v in headers.items() if v is not None}