Ollama fixes, missing libs + config fixes
This commit is contained in:
parent
4e1b2db8ae
commit
365c7bfc0e
2 changed files with 33 additions and 31 deletions
|
|
@ -24,7 +24,7 @@ class InfrastructureConfig():
|
||||||
self.vector_engine = WeaviateAdapter(
|
self.vector_engine = WeaviateAdapter(
|
||||||
config.weaviate_url,
|
config.weaviate_url,
|
||||||
config.weaviate_api_key,
|
config.weaviate_api_key,
|
||||||
config.openai_key
|
embedding_engine = DefaultEmbeddingEngine()
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from ..models.DataPoint import DataPoint
|
||||||
from ..models.ScoredResult import ScoredResult
|
from ..models.ScoredResult import ScoredResult
|
||||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
|
|
||||||
|
|
||||||
class WeaviateAdapter(VectorDBInterface):
|
class WeaviateAdapter(VectorDBInterface):
|
||||||
async_pool: Pool = None
|
async_pool: Pool = None
|
||||||
embedding_engine: EmbeddingEngine = None
|
embedding_engine: EmbeddingEngine = None
|
||||||
|
|
@ -17,12 +18,12 @@ class WeaviateAdapter(VectorDBInterface):
|
||||||
self.embedding_engine = embedding_engine
|
self.embedding_engine = embedding_engine
|
||||||
|
|
||||||
self.client = weaviate.connect_to_wcs(
|
self.client = weaviate.connect_to_wcs(
|
||||||
cluster_url = url,
|
cluster_url=url,
|
||||||
auth_credentials = weaviate.auth.AuthApiKey(api_key),
|
auth_credentials=weaviate.auth.AuthApiKey(api_key),
|
||||||
# headers = {
|
# headers = {
|
||||||
# "X-OpenAI-Api-Key": openai_api_key
|
# "X-OpenAI-Api-Key": openai_api_key
|
||||||
# },
|
# },
|
||||||
additional_config = wvc.init.AdditionalConfig(timeout = wvc.init.Timeout(init = 30))
|
additional_config=wvc.init.AdditionalConfig(timeout=wvc.init.Timeout(init=30))
|
||||||
)
|
)
|
||||||
|
|
||||||
async def embed_data(self, data: List[str]) -> List[float]:
|
async def embed_data(self, data: List[str]) -> List[float]:
|
||||||
|
|
@ -30,12 +31,12 @@ class WeaviateAdapter(VectorDBInterface):
|
||||||
|
|
||||||
async def create_collection(self, collection_name: str):
    """Create a collection that holds a single ``text`` property.

    Args:
        collection_name: Name given to the new collection.

    Returns:
        Whatever ``self.client.collections.create`` returns for the new
        collection.
    """
    text_property = wvcc.Property(
        name="text",
        data_type=wvcc.DataType.TEXT,
        # Server-side vectorization of this property is switched off.
        skip_vectorization=True,
    )

    return self.client.collections.create(
        name=collection_name,
        properties=[text_property],
    )
|
||||||
|
|
@ -44,13 +45,14 @@ class WeaviateAdapter(VectorDBInterface):
|
||||||
return self.client.collections.get(collection_name)
|
return self.client.collections.get(collection_name)
|
||||||
|
|
||||||
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
||||||
data_vectors = await self.embed_data(list(map(lambda data_point: data_point.get_embeddable_data(), data_points)))
|
data_vectors = await self.embed_data(
|
||||||
|
list(map(lambda data_point: data_point.get_embeddable_data(), data_points)))
|
||||||
|
|
||||||
def convert_to_weaviate_data_points(data_point: DataPoint):
|
def convert_to_weaviate_data_points(data_point: DataPoint):
|
||||||
return DataObject(
|
return DataObject(
|
||||||
uuid = data_point.id,
|
uuid=data_point.id,
|
||||||
properties = data_point.payload,
|
properties=data_point.payload,
|
||||||
vector = data_vectors[data_points.index(data_point)]
|
vector=data_vectors[data_points.index(data_point)]
|
||||||
)
|
)
|
||||||
|
|
||||||
objects = list(map(convert_to_weaviate_data_points, data_points))
|
objects = list(map(convert_to_weaviate_data_points, data_points))
|
||||||
|
|
@ -58,35 +60,35 @@ class WeaviateAdapter(VectorDBInterface):
|
||||||
return self.get_collection(collection_name).data.insert_many(objects)
|
return self.get_collection(collection_name).data.insert_many(objects)
|
||||||
|
|
||||||
async def search(
    self,
    collection_name: str,
    query_text: Optional[str] = None,
    query_vector: Optional[List[float]] = None,
    limit: Optional[int] = None,
    with_vector: bool = False,
):
    """Search a collection by vector and return scored results.

    At least one of ``query_text`` and ``query_vector`` must be given. A
    supplied ``query_vector`` is used as-is; otherwise ``query_text`` is
    embedded through ``self.embed_data`` and the first returned vector is
    used.

    Args:
        collection_name: Name of the collection to query.
        query_text: Text to embed when no ``query_vector`` is supplied.
        query_vector: Pre-computed query embedding; takes precedence over
            ``query_text``.
        limit: Forwarded as the ``limit`` argument of the query.
        with_vector: Forwarded as the ``include_vector`` argument.

    Returns:
        A list of ``ScoredResult`` objects, one per returned object, built
        from its uuid (as a string), its properties and its metadata score.

    Raises:
        ValueError: If both ``query_text`` and ``query_vector`` are ``None``.
    """
    if query_text is None and query_vector is None:
        raise ValueError("One of query_text or query_vector must be provided!")

    collection = self.get_collection(collection_name)

    if query_vector is None:
        # Only embed when the caller did not bring a vector of their own.
        query_vector = (await self.embed_data([query_text]))[0]

    search_result = collection.query.hybrid(
        # No keyword query is sent; only the vector is passed to the search.
        query=None,
        vector=query_vector,
        limit=limit,
        include_vector=with_vector,
        return_metadata=wvc.query.MetadataQuery(score=True),
    )

    return [
        ScoredResult(
            id=str(result.uuid),
            payload=result.properties,
            score=float(result.metadata.score),
        )
        for result in search_result.objects
    ]
|
||||||
|
|
||||||
async def batch_search(self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False):
    """Run one vector search per query text and collect the results.

    All texts are embedded with a single ``self.embed_data`` call; the
    resulting vectors are then searched one after another, in input order.

    Args:
        collection_name: Name of the collection to query.
        query_texts: Texts to embed and search for.
        limit: Forwarded to ``self.search`` as ``limit`` for every query.
        with_vectors: Forwarded to ``self.search`` as ``with_vector``.

    Returns:
        A list with one ``self.search`` result per entry of ``query_texts``,
        in the same order. Empty when ``query_texts`` is empty.
    """
    if not query_texts:
        # Nothing to look up, so skip the embedding call entirely.
        return []

    query_vectors = await self.embed_data(query_texts)

    return [
        await self.search(
            collection_name,
            query_vector=query_vector,
            limit=limit,
            with_vector=with_vectors,
        )
        for query_vector in query_vectors
    ]
|
||||||
|
|
||||||
async def prune(self):
    """Delete all collections through ``self.client.collections.delete_all``."""
    self.client.collections.delete_all()
|
||||||
Loading…
Add table
Reference in a new issue