diff --git a/cognee/infrastructure/InfrastructureConfig.py b/cognee/infrastructure/InfrastructureConfig.py index aeff8c78c..fff2834ce 100644 --- a/cognee/infrastructure/InfrastructureConfig.py +++ b/cognee/infrastructure/InfrastructureConfig.py @@ -24,7 +24,7 @@ class InfrastructureConfig(): self.vector_engine = WeaviateAdapter( config.weaviate_url, config.weaviate_api_key, - config.openai_key + embedding_engine = DefaultEmbeddingEngine() ) return { diff --git a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py index a299aac8a..4119c4361 100644 --- a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py +++ b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py @@ -9,6 +9,7 @@ from ..models.DataPoint import DataPoint from ..models.ScoredResult import ScoredResult from ..embeddings.EmbeddingEngine import EmbeddingEngine + class WeaviateAdapter(VectorDBInterface): async_pool: Pool = None embedding_engine: EmbeddingEngine = None @@ -17,12 +18,12 @@ class WeaviateAdapter(VectorDBInterface): self.embedding_engine = embedding_engine self.client = weaviate.connect_to_wcs( - cluster_url = url, - auth_credentials = weaviate.auth.AuthApiKey(api_key), + cluster_url=url, + auth_credentials=weaviate.auth.AuthApiKey(api_key), # headers = { # "X-OpenAI-Api-Key": openai_api_key # }, - additional_config = wvc.init.AdditionalConfig(timeout = wvc.init.Timeout(init = 30)) + additional_config=wvc.init.AdditionalConfig(timeout=wvc.init.Timeout(init=30)) ) async def embed_data(self, data: List[str]) -> List[float]: @@ -30,12 +31,12 @@ class WeaviateAdapter(VectorDBInterface): async def create_collection(self, collection_name: str): return self.client.collections.create( - name = collection_name, - properties = [ + name=collection_name, + properties=[ wvcc.Property( - name = "text", - data_type = wvcc.DataType.TEXT, - skip_vectorization = True + name="text", + data_type=wvcc.DataType.TEXT, + skip_vectorization=True ) ] ) @@ -44,13 +45,14 @@ class WeaviateAdapter(VectorDBInterface): return self.client.collections.get(collection_name) async def create_data_points(self, collection_name: str, data_points: List[DataPoint]): - data_vectors = await self.embed_data(list(map(lambda data_point: data_point.get_embeddable_data(), data_points))) + data_vectors = await self.embed_data( + list(map(lambda data_point: data_point.get_embeddable_data(), data_points))) def convert_to_weaviate_data_points(data_point: DataPoint): return DataObject( - uuid = data_point.id, - properties = data_point.payload, - vector = data_vectors[data_points.index(data_point)] + uuid=data_point.id, + properties=data_point.payload, + vector=data_vectors[data_points.index(data_point)] ) objects = list(map(convert_to_weaviate_data_points, data_points)) @@ -58,35 +60,35 @@ class WeaviateAdapter(VectorDBInterface): return self.get_collection(collection_name).data.insert_many(objects) async def search( - self, - collection_name: str, - query_text: Optional[str] = None, - query_vector: Optional[List[float]] = None, - limit: int = None, - with_vector: bool = False + self, + collection_name: str, + query_text: Optional[str] = None, + query_vector: Optional[List[float]] = None, + limit: int = None, + with_vector: bool = False ): if query_text is None and query_vector is None: raise ValueError("One of query_text or query_vector must be provided!") - + search_result = self.get_collection(collection_name).query.hybrid( - query = None, - vector = query_vector if query_vector is not None else (await self.embed_data([query_text]))[0], - limit = limit, - include_vector = with_vector, - return_metadata = wvc.query.MetadataQuery(score = True), + query=None, + vector=query_vector if query_vector is not None else (await self.embed_data([query_text]))[0], + limit=limit, + include_vector=with_vector, + return_metadata=wvc.query.MetadataQuery(score=True), ) return list(map(lambda result: ScoredResult( - id = str(result.uuid), - payload = result.properties, - score = float(result.metadata.score) + id=str(result.uuid), + payload=result.properties, + score=float(result.metadata.score) ), search_result.objects)) - async def batch_search(self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False): + async def batch_search(self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False): def query_search(query_vector): - return self.search(collection_name, query_vector = query_vector, limit = limit, with_vector = with_vectors) + return self.search(collection_name, query_vector=query_vector, limit=limit, with_vector=with_vectors) return [await query_search(query_vector) for query_vector in await self.embed_data(query_texts)] async def prune(self): - self.client.collections.delete_all() + self.client.collections.delete_all() \ No newline at end of file