import contextlib
import json
import os
from typing import Optional

import requests

from .embeddings import Embeddings
from .vector_db import VectorDB
from .response import Response


class CogneeManager:
    """Client for an HTTP job/embedding service rooted at ``base_url``.

    Talks to the ``/jobs``, ``/jobs/status``, ``/jobs/{id}/status`` and
    ``/embed`` endpoints. Every request method returns a ``Response``: on a
    500 (or a non-JSON error body) it carries the raw body text as ``error``;
    otherwise it is built with ``Response.from_json``.
    """

    def __init__(
        self,
        embeddings: Optional[Embeddings] = None,
        vector_db: Optional[VectorDB] = None,
        vector_db_key: Optional[str] = None,
        embedding_api_key: Optional[str] = None,
        webhook_url: Optional[str] = None,
        lines_per_batch: int = 1000,
        webhook_key: Optional[str] = None,
        document_id: Optional[str] = None,
        chunk_validation_url: Optional[str] = None,
        # NOTE(review): hard-coded default credential; callers should pass a
        # real key in any non-local deployment.
        internal_api_key: str = "test123",
        base_url: str = "http://localhost:8000",
        timeout: Optional[float] = None,
    ):
        """Store configuration used to build request form data and headers.

        Args:
            embeddings: Embeddings config; a default ``Embeddings()`` is
                created when omitted.
            vector_db: Vector DB config; a default ``VectorDB()`` is created
                when omitted.
            vector_db_key: Sent as the ``X-VectorDB-Key`` header.
            embedding_api_key: Sent as the ``X-EmbeddingAPI-Key`` header.
            webhook_url: Sent as the ``WebhookURL`` form field.
            lines_per_batch: Sent as the ``LinesPerBatch`` form field.
            webhook_key: Sent as the ``X-Webhook-Key`` header.
            document_id: Sent as the ``DocumentID`` form field.
            chunk_validation_url: Sent as the ``ChunkValidationURL`` field.
            internal_api_key: Sent as the ``Authorization`` header.
            base_url: Default service root; each method may override it.
            timeout: Forwarded to ``requests`` as its ``timeout`` argument.
                ``None`` (the default) means wait indefinitely.
        """
        self.embeddings = embeddings if embeddings is not None else Embeddings()
        self.vector_db = vector_db if vector_db is not None else VectorDB()
        self.webhook_url = webhook_url
        self.lines_per_batch = lines_per_batch
        self.webhook_key = webhook_key
        self.document_id = document_id
        self.chunk_validation_url = chunk_validation_url
        self.vector_db_key = vector_db_key
        self.embeddings_api_key = embedding_api_key
        self.internal_api_key = internal_api_key
        self.base_url = base_url
        self.timeout = timeout

    def serialize(self) -> dict:
        """Return the form fields sent by ``upload`` and ``embed``.

        The embeddings and vector DB metadata are JSON-encoded strings.
        Fields whose value is ``None`` are omitted.
        """
        data = {
            "EmbeddingsMetadata": json.dumps(self.embeddings.serialize()),
            "VectorDBMetadata": json.dumps(self.vector_db.serialize()),
            "WebhookURL": self.webhook_url,
            "LinesPerBatch": self.lines_per_batch,
            "DocumentID": self.document_id,
            "ChunkValidationURL": self.chunk_validation_url,
        }
        return {k: v for k, v in data.items() if v is not None}

    def upload(self, file_paths: list[str], base_url=None) -> Response:
        """POST one or more files to ``/jobs`` as multipart ``file`` parts.

        Args:
            file_paths: Paths of the files to upload.
            base_url: Overrides ``self.base_url`` for this call when truthy.

        Returns:
            A ``Response`` built from the server reply.
        """
        url = self._resolve_url("/jobs", base_url)
        data = self.serialize()
        headers = self.generate_headers()
        # ExitStack closes every opened file once the request is sent, even
        # if opening a later path fails partway through.
        with contextlib.ExitStack() as stack:
            multipart_form_data = [
                (
                    "file",
                    (
                        os.path.basename(filepath),
                        stack.enter_context(open(filepath, "rb")),
                        "application/octet-stream",
                    ),
                )
                for filepath in file_paths
            ]
            print(f"embedding {len(file_paths)} documents at {url}")
            response = requests.post(
                url,
                files=multipart_form_data,
                headers=headers,
                stream=True,
                data=data,
                timeout=self.timeout,
            )
        return self._handle_response(response)

    def get_job_statuses(self, job_ids: list[int], base_url=None) -> Response:
        """POST ``{"JobIDs": job_ids}`` to ``/jobs/status``.

        Only the ``Authorization`` header is sent.
        """
        url = self._resolve_url("/jobs/status", base_url)
        headers = {
            "Authorization": self.internal_api_key,
        }
        data = {"JobIDs": job_ids}
        print(f"retrieving job statuses for {len(job_ids)} jobs at {url}")
        response = requests.post(url, headers=headers, json=data, timeout=self.timeout)
        return self._handle_response(response)

    def embed(self, filepath, base_url=None) -> Response:
        """POST a single file to ``/embed`` as the ``SourceData`` part."""
        url = self._resolve_url("/embed", base_url)
        data = self.serialize()
        headers = self.generate_headers()
        # Close the file once the request is sent (the original leaked it).
        with open(filepath, "rb") as source:
            files = {"SourceData": source}
            print(f"embedding document at file path {filepath} at {url}")
            response = requests.post(
                url, headers=headers, data=data, files=files, timeout=self.timeout
            )
        return self._handle_response(response)

    def get_job_status(self, job_id, base_url=None) -> Response:
        """GET ``/jobs/{job_id}/status`` with only the ``Authorization`` header."""
        url = self._resolve_url("/jobs/" + str(job_id) + "/status", base_url)
        headers = {
            "Authorization": self.internal_api_key,
        }
        print(f"retrieving job status for job {job_id} at {url}")
        response = requests.get(url, headers=headers, timeout=self.timeout)
        return self._handle_response(response)

    def generate_headers(self) -> dict:
        """Return auth headers for ``upload``/``embed``, omitting unset keys."""
        headers = {
            "Authorization": self.internal_api_key,
            "X-EmbeddingAPI-Key": self.embeddings_api_key,
            "X-VectorDB-Key": self.vector_db_key,
            "X-Webhook-Key": self.webhook_key,
        }
        return {k: v for k, v in headers.items() if v is not None}

    def _resolve_url(self, path: str, base_url=None) -> str:
        """Join ``path`` onto ``base_url`` (if truthy) or ``self.base_url``.

        A trailing ``/`` on the root is dropped so the result never contains
        a doubled slash such as ``host//jobs``.
        """
        root = base_url if base_url else self.base_url
        return root.rstrip("/") + path

    def _handle_response(self, response) -> Response:
        """Convert a ``requests`` response into a ``Response``.

        A 500, or an error status whose body is not JSON, yields
        ``Response(error=<raw body>, status_code=...)``. Otherwise the JSON
        body goes to ``Response.from_json``; for 4xx the ``error`` field is
        printed first.
        """
        status = response.status_code
        if status == 500:
            print(response.text)
            return Response(error=response.text, status_code=status)
        try:
            response_json = response.json()
        except ValueError:
            # Success replies are expected to be JSON; keep raising there.
            if status < 400:
                raise
            # Error pages (e.g. HTML from a proxy) are not JSON; report the
            # raw text instead of crashing.
            print(response.text)
            return Response(error=response.text, status_code=status)
        if 400 <= status < 500:
            # .get avoids a KeyError when the error body lacks an "error" key.
            error = (
                response_json.get("error")
                if isinstance(response_json, dict)
                else response_json
            )
            print(f"Error: {error}")
        return Response.from_json(response_json, status)