151 lines
No EOL
5.4 KiB
Python
151 lines
No EOL
5.4 KiB
Python
import json
import os
from contextlib import ExitStack

import requests

from .embeddings import Embeddings
from .response import Response
from .vector_db import VectorDB
|
|
|
|
|
|
class CogneeManager:
    """Client for the HTTP service at ``base_url`` that accepts documents for
    embedding and reports job status.

    The embedding / vector-DB configuration and the optional webhook settings
    are sent as form fields (see :meth:`serialize`); the API keys are sent as
    request headers (see :meth:`generate_headers`). Every public request method
    returns a ``Response`` object rather than raising on HTTP error statuses.
    """

    def __init__(self, embeddings: Embeddings = None,
                 vector_db: VectorDB = None,
                 vector_db_key: str = None,
                 embedding_api_key: str = None,
                 webhook_url: str = None,
                 lines_per_batch: int = 1000,
                 webhook_key: str = None,
                 document_id: str = None,
                 chunk_validation_url: str = None,
                 internal_api_key: str = "test123",
                 base_url="http://localhost:8000"):
        """Store the client configuration.

        Args:
            embeddings: Embedding configuration; a default ``Embeddings()`` is
                created when omitted.
            vector_db: Vector-DB configuration; a default ``VectorDB()`` is
                created when omitted.
            vector_db_key: Sent as the ``X-VectorDB-Key`` header.
            embedding_api_key: Sent as the ``X-EmbeddingAPI-Key`` header.
            webhook_url: Sent as the ``WebhookURL`` form field.
            lines_per_batch: Sent as the ``LinesPerBatch`` form field.
            webhook_key: Sent as the ``X-Webhook-Key`` header.
            document_id: Sent as the ``DocumentID`` form field.
            chunk_validation_url: Sent as the ``ChunkValidationURL`` form field.
            internal_api_key: Sent as the ``Authorization`` header on every
                request.
            base_url: Server root used when a method is not given its own
                ``base_url``.
        """
        # NOTE(review): "test123" looks like a development placeholder
        # credential; confirm it is never relied on outside local testing.
        self.embeddings = embeddings if embeddings is not None else Embeddings()
        self.vector_db = vector_db if vector_db is not None else VectorDB()
        self.webhook_url = webhook_url
        self.lines_per_batch = lines_per_batch
        self.webhook_key = webhook_key
        self.document_id = document_id
        self.chunk_validation_url = chunk_validation_url
        self.vector_db_key = vector_db_key
        # Attribute name differs from the parameter name (embedding_ vs
        # embeddings_); generate_headers reads this attribute.
        self.embeddings_api_key = embedding_api_key
        self.internal_api_key = internal_api_key
        self.base_url = base_url

    def serialize(self):
        """Build the form fields sent with upload/embed requests.

        The embeddings and vector-DB configurations are JSON-encoded from
        their own ``serialize()`` output. Fields whose value is ``None`` are
        omitted.

        Returns:
            dict mapping form-field name to value.
        """
        data = {
            'EmbeddingsMetadata': json.dumps(self.embeddings.serialize()),
            'VectorDBMetadata': json.dumps(self.vector_db.serialize()),
            'WebhookURL': self.webhook_url,
            'LinesPerBatch': self.lines_per_batch,
            'DocumentID': self.document_id,
            'ChunkValidationURL': self.chunk_validation_url,
        }
        return {k: v for k, v in data.items() if v is not None}

    def upload(self, file_paths: list[str], base_url=None):
        """POST the given files to ``<base_url>/jobs`` as multipart ``file`` fields.

        The :meth:`serialize` form fields and :meth:`generate_headers` headers
        accompany the files.

        Args:
            file_paths: Paths of the files to upload.
            base_url: Overrides ``self.base_url`` for this call when truthy.

        Returns:
            A ``Response`` built by :meth:`_handle_response`.
        """
        url = self._resolve_url(base_url, "/jobs")
        data = self.serialize()
        headers = self.generate_headers()

        # ExitStack closes every opened file once the request has been sent,
        # including the already-opened ones if a later open() raises.
        with ExitStack() as stack:
            multipart_form_data = [
                ('file', (os.path.basename(filepath),
                          stack.enter_context(open(filepath, 'rb')),
                          'application/octet-stream'))
                for filepath in file_paths
            ]

            print(f"embedding {len(file_paths)} documents at {url}")
            response = requests.post(url, files=multipart_form_data, headers=headers, stream=True, data=data)

        return self._handle_response(response)

    def get_job_statuses(self, job_ids: list[int], base_url=None):
        """POST ``{'JobIDs': job_ids}`` as JSON to ``<base_url>/jobs/status``.

        Only the ``Authorization`` header is sent.

        Args:
            job_ids: IDs of the jobs to query.
            base_url: Overrides ``self.base_url`` for this call when truthy.

        Returns:
            A ``Response`` built by :meth:`_handle_response`.
        """
        url = self._resolve_url(base_url, "/jobs/status")
        headers = {
            "Authorization": self.internal_api_key,
        }
        data = {
            'JobIDs': job_ids
        }

        print(f"retrieving job statuses for {len(job_ids)} jobs at {url}")
        response = requests.post(url, headers=headers, json=data)

        return self._handle_response(response)

    def embed(self, filepath, base_url=None):
        """POST one file to ``<base_url>/embed`` as the ``SourceData`` field.

        The :meth:`serialize` form fields and :meth:`generate_headers` headers
        accompany the file.

        Args:
            filepath: Path of the file to send.
            base_url: Overrides ``self.base_url`` for this call when truthy.

        Returns:
            A ``Response`` built by :meth:`_handle_response`.
        """
        url = self._resolve_url(base_url, "/embed")
        data = self.serialize()
        headers = self.generate_headers()

        # The handle is closed as soon as the request has been sent.
        with open(filepath, 'rb') as source_file:
            files = {
                'SourceData': source_file
            }

            print(f"embedding document at file path {filepath} at {url}")
            response = requests.post(url, headers=headers, data=data, files=files)

        return self._handle_response(response)

    def get_job_status(self, job_id, base_url=None):
        """GET ``<base_url>/jobs/<job_id>/status``.

        Only the ``Authorization`` header is sent.

        Args:
            job_id: ID of the job to query.
            base_url: Overrides ``self.base_url`` for this call when truthy.

        Returns:
            A ``Response`` built by :meth:`_handle_response`.
        """
        url = self._resolve_url(base_url, "/jobs/" + str(job_id) + "/status")
        headers = {
            "Authorization": self.internal_api_key,
        }

        print(f"retrieving job status for job {job_id} at {url}")
        response = requests.get(url, headers=headers)

        return self._handle_response(response)

    def generate_headers(self):
        """Build the auth headers for upload/embed requests.

        Returns:
            dict of header name to value, omitting headers whose value is
            ``None``.
        """
        headers = {
            "Authorization": self.internal_api_key,
            "X-EmbeddingAPI-Key": self.embeddings_api_key,
            "X-VectorDB-Key": self.vector_db_key,
            "X-Webhook-Key": self.webhook_key
        }
        return {k: v for k, v in headers.items() if v is not None}

    def _resolve_url(self, base_url, path):
        """Append *path* to *base_url*, or to ``self.base_url`` when *base_url* is falsy."""
        return (base_url or self.base_url) + path

    @staticmethod
    def _handle_response(response):
        """Convert an HTTP response into a ``Response``.

        - Status 500: the raw body is printed and returned as ``error``.
        - Error status (>= 400) whose body is not JSON: handled the same way
          as a 500 instead of raising from ``response.json()``.
        - 4xx with a JSON body: the body's ``error`` field (or the whole body
          if there is no such field) is printed.
        - Otherwise the JSON body is passed to ``Response.from_json``.

        A success-status response whose body is not JSON still raises the
        decode error, as before.
        """
        if response.status_code == 500:
            print(response.text)
            return Response(error=response.text, status_code=response.status_code)

        try:
            response_json = response.json()
        except ValueError:
            # requests' JSON decode errors subclass ValueError. Non-JSON error
            # pages (e.g. from a proxy on 502/503) are reported like a 500.
            if response.status_code < 400:
                raise
            print(response.text)
            return Response(error=response.text, status_code=response.status_code)

        if 400 <= response.status_code < 500:
            # .get avoids a KeyError when the 4xx body has no "error" field.
            if isinstance(response_json, dict):
                error = response_json.get('error', response_json)
            else:
                error = response_json
            print(f"Error: {error}")

        return Response.from_json(response_json, response.status_code)