Ollama fixes, missing libs + config fixes

This commit is contained in:
Vasilije 2024-03-26 17:52:49 +01:00
parent 4e1b2db8ae
commit 365c7bfc0e
2 changed files with 33 additions and 31 deletions

View file

@ -24,7 +24,7 @@ class InfrastructureConfig():
self.vector_engine = WeaviateAdapter( self.vector_engine = WeaviateAdapter(
config.weaviate_url, config.weaviate_url,
config.weaviate_api_key, config.weaviate_api_key,
config.openai_key embedding_engine = DefaultEmbeddingEngine()
) )
return { return {

View file

@ -9,6 +9,7 @@ from ..models.DataPoint import DataPoint
from ..models.ScoredResult import ScoredResult from ..models.ScoredResult import ScoredResult
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
class WeaviateAdapter(VectorDBInterface): class WeaviateAdapter(VectorDBInterface):
async_pool: Pool = None async_pool: Pool = None
embedding_engine: EmbeddingEngine = None embedding_engine: EmbeddingEngine = None
@ -17,12 +18,12 @@ class WeaviateAdapter(VectorDBInterface):
self.embedding_engine = embedding_engine self.embedding_engine = embedding_engine
self.client = weaviate.connect_to_wcs( self.client = weaviate.connect_to_wcs(
cluster_url = url, cluster_url=url,
auth_credentials = weaviate.auth.AuthApiKey(api_key), auth_credentials=weaviate.auth.AuthApiKey(api_key),
# headers = { # headers = {
# "X-OpenAI-Api-Key": openai_api_key # "X-OpenAI-Api-Key": openai_api_key
# }, # },
additional_config = wvc.init.AdditionalConfig(timeout = wvc.init.Timeout(init = 30)) additional_config=wvc.init.AdditionalConfig(timeout=wvc.init.Timeout(init=30))
) )
async def embed_data(self, data: List[str]) -> List[float]: async def embed_data(self, data: List[str]) -> List[float]:
@ -30,12 +31,12 @@ class WeaviateAdapter(VectorDBInterface):
async def create_collection(self, collection_name: str): async def create_collection(self, collection_name: str):
return self.client.collections.create( return self.client.collections.create(
name = collection_name, name=collection_name,
properties = [ properties=[
wvcc.Property( wvcc.Property(
name = "text", name="text",
data_type = wvcc.DataType.TEXT, data_type=wvcc.DataType.TEXT,
skip_vectorization = True skip_vectorization=True
) )
] ]
) )
@ -44,13 +45,14 @@ class WeaviateAdapter(VectorDBInterface):
return self.client.collections.get(collection_name) return self.client.collections.get(collection_name)
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]): async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
data_vectors = await self.embed_data(list(map(lambda data_point: data_point.get_embeddable_data(), data_points))) data_vectors = await self.embed_data(
list(map(lambda data_point: data_point.get_embeddable_data(), data_points)))
def convert_to_weaviate_data_points(data_point: DataPoint): def convert_to_weaviate_data_points(data_point: DataPoint):
return DataObject( return DataObject(
uuid = data_point.id, uuid=data_point.id,
properties = data_point.payload, properties=data_point.payload,
vector = data_vectors[data_points.index(data_point)] vector=data_vectors[data_points.index(data_point)]
) )
objects = list(map(convert_to_weaviate_data_points, data_points)) objects = list(map(convert_to_weaviate_data_points, data_points))
@ -58,35 +60,35 @@ class WeaviateAdapter(VectorDBInterface):
return self.get_collection(collection_name).data.insert_many(objects) return self.get_collection(collection_name).data.insert_many(objects)
async def search( async def search(
self, self,
collection_name: str, collection_name: str,
query_text: Optional[str] = None, query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None, query_vector: Optional[List[float]] = None,
limit: int = None, limit: int = None,
with_vector: bool = False with_vector: bool = False
): ):
if query_text is None and query_vector is None: if query_text is None and query_vector is None:
raise ValueError("One of query_text or query_vector must be provided!") raise ValueError("One of query_text or query_vector must be provided!")
search_result = self.get_collection(collection_name).query.hybrid( search_result = self.get_collection(collection_name).query.hybrid(
query = None, query=None,
vector = query_vector if query_vector is not None else (await self.embed_data([query_text]))[0], vector=query_vector if query_vector is not None else (await self.embed_data([query_text]))[0],
limit = limit, limit=limit,
include_vector = with_vector, include_vector=with_vector,
return_metadata = wvc.query.MetadataQuery(score = True), return_metadata=wvc.query.MetadataQuery(score=True),
) )
return list(map(lambda result: ScoredResult( return list(map(lambda result: ScoredResult(
id = str(result.uuid), id=str(result.uuid),
payload = result.properties, payload=result.properties,
score = float(result.metadata.score) score=float(result.metadata.score)
), search_result.objects)) ), search_result.objects))
async def batch_search(self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False): async def batch_search(self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False):
def query_search(query_vector): def query_search(query_vector):
return self.search(collection_name, query_vector = query_vector, limit = limit, with_vector = with_vectors) return self.search(collection_name, query_vector=query_vector, limit=limit, with_vector=with_vectors)
return [await query_search(query_vector) for query_vector in await self.embed_data(query_texts)] return [await query_search(query_vector) for query_vector in await self.embed_data(query_texts)]
async def prune(self): async def prune(self):
self.client.collections.delete_all() self.client.collections.delete_all()