Fix async embeddings, fix batch processing, and decouple LanceDB issues

This commit is contained in:
Vasilije 2024-05-17 11:13:39 +02:00
parent 79311ee510
commit db0b19bb30
6 changed files with 35 additions and 60 deletions

View file

@ -87,26 +87,22 @@ async def cognify(datasets: Union[str, List[str]] = None):
chunk_engine = infrastructure_config.get_config()["chunk_engine"] chunk_engine = infrastructure_config.get_config()["chunk_engine"]
chunk_strategy = infrastructure_config.get_config()["chunk_strategy"] chunk_strategy = infrastructure_config.get_config()["chunk_strategy"]
async def process_batch(files_batch): async def process_batch(files_batch):
for dataset_name, file_metadata in files_batch: for dataset_name, file_metadata, document_id in files_batch:
with open(file_metadata["file_path"], "rb") as file: with open(file_metadata["file_path"], "rb") as file:
try: try:
document_id = await add_document_node(
graph_client,
parent_node_id = f"DefaultGraphModel__{USER_ID}",
document_metadata = file_metadata,
)
file_type = guess_file_type(file) file_type = guess_file_type(file)
text = extract_text_from_file(file, file_type) text = extract_text_from_file(file, file_type)
if text is None: if text is None:
text = "empty file" text = "empty file"
if text == "":
text = "empty file"
subchunks = chunk_engine.chunk_data(chunk_strategy, text, config.chunk_size, config.chunk_overlap) subchunks = chunk_engine.chunk_data(chunk_strategy, text, config.chunk_size, config.chunk_overlap)
if dataset_name not in data_chunks: if dataset_name not in data_chunks:
data_chunks[dataset_name] = [] data_chunks[dataset_name] = []
for subchunk in subchunks: for subchunk in subchunks:
data_chunks[dataset_name].append(dict(document_id = document_id, chunk_id = str(uuid4()), text = subchunk)) data_chunks[dataset_name].append(dict(document_id = document_id, chunk_id = str(uuid4()), text = subchunk, file_metadata = file_metadata))
except FileTypeException: except FileTypeException:
logger.warning("File (%s) has an unknown file type. We are skipping it.", file_metadata["id"]) logger.warning("File (%s) has an unknown file type. We are skipping it.", file_metadata["id"])
@ -124,8 +120,14 @@ async def cognify(datasets: Union[str, List[str]] = None):
files_batch = [] files_batch = []
for (dataset_name, files) in dataset_files: for (dataset_name, files) in dataset_files:
for file_metadata in files: for file_metadata in files:
files_batch.append((dataset_name, file_metadata)) document_id = await add_document_node(
graph_client,
parent_node_id=f"DefaultGraphModel__{USER_ID}",
document_metadata=file_metadata,
)
files_batch.append((dataset_name, file_metadata, document_id))
file_count += 1 file_count += 1
if file_count >= batch_size: if file_count >= batch_size:
@ -199,11 +201,6 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
classified_categories= [{'data_type': 'text', 'category_name': 'Source code in various programming languages'}] classified_categories= [{'data_type': 'text', 'category_name': 'Source code in various programming languages'}]
await asyncio.gather(
*[process_text(chunk["document_id"], chunk["chunk_id"], chunk["collection"], chunk["text"]) for chunk in added_chunks]
)
return graph_client.graph
# #
# async def process_text(document_id: str, chunk_id: str, chunk_collection: str, input_text: str): # async def process_text(document_id: str, chunk_id: str, chunk_collection: str, input_text: str):
# raw_document_id = document_id.split("__")[-1] # raw_document_id = document_id.split("__")[-1]

View file

@ -27,25 +27,23 @@ class DefaultEmbeddingEngine(EmbeddingEngine):
class LiteLLMEmbeddingEngine(EmbeddingEngine): class LiteLLMEmbeddingEngine(EmbeddingEngine):
import asyncio
from typing import List
async def embed_text(self, text: List[str]) -> List[float]: async def embed_text(self, text: List[str]) -> List[List[float]]:
async def get_embedding(text_):
response = await aembedding(config.litellm_embedding_model, input=text_)
return response.data[0]['embedding']
print("text", text)
try:
text = str(text[0])
except:
text = str(text)
response = await aembedding(config.litellm_embedding_model, input=text)
tasks = [get_embedding(text_) for text_ in text]
result = await asyncio.gather(*tasks)
return result
# embedding = response.data[0].embedding # embedding = response.data[0].embedding
# embeddings_list = list(map(lambda embedding: embedding.tolist(), embedding_model.embed(text))) # # embeddings_list = list(map(lambda embedding: embedding.tolist(), embedding_model.embed(text)))
print("response", type(response.data[0]['embedding'])) # print("response", type(response.data[0]['embedding']))
return response.data[0]['embedding'] # print("response", response.data[0])
# return [response.data[0]['embedding']]
def get_vector_size(self) -> int: def get_vector_size(self) -> int:

View file

@ -37,10 +37,11 @@ class LanceDBAdapter(VectorDBInterface):
async def create_collection(self, collection_name: str, payload_schema: BaseModel): async def create_collection(self, collection_name: str, payload_schema: BaseModel):
data_point_types = get_type_hints(DataPoint) data_point_types = get_type_hints(DataPoint)
vector_size = self.embedding_engine.get_vector_size()
class LanceDataPoint(LanceModel): class LanceDataPoint(LanceModel):
id: data_point_types["id"] = Field(...) id: data_point_types["id"] = Field(...)
vector: Vector(self.embedding_engine.get_vector_size()) vector: Vector(vector_size)
payload: payload_schema payload: payload_schema
if not await self.collection_exists(collection_name): if not await self.collection_exists(collection_name):
@ -68,10 +69,11 @@ class LanceDBAdapter(VectorDBInterface):
IdType = TypeVar("IdType") IdType = TypeVar("IdType")
PayloadSchema = TypeVar("PayloadSchema") PayloadSchema = TypeVar("PayloadSchema")
vector_size = self.embedding_engine.get_vector_size()
class LanceDataPoint(LanceModel, Generic[IdType, PayloadSchema]): class LanceDataPoint(LanceModel, Generic[IdType, PayloadSchema]):
id: IdType id: IdType
vector: Vector(self.embedding_engine.get_vector_size()) vector: Vector(vector_size)
payload: PayloadSchema payload: PayloadSchema
lance_data_points = [ lance_data_points = [

View file

@ -34,6 +34,7 @@ async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
collection = dataset_name, collection = dataset_name,
text = chunk["text"], text = chunk["text"],
document_id = chunk["document_id"], document_id = chunk["document_id"],
file_metadata = chunk["file_metadata"],
) for chunk in chunks ) for chunk in chunks
] ]

View file

@ -25,6 +25,7 @@ async def add_document_node(graph_client: GraphDBInterface, parent_node_id, docu
dict(relationship_name = "has_document"), dict(relationship_name = "has_document"),
) )
await add_label_nodes(graph_client, document_id, document_metadata["keywords"].split("|")) #
# await add_label_nodes(graph_client, document_id, document_metadata["keywords"].split("|"))
return document_id return document_id

View file

@ -19,7 +19,7 @@ async def add_label_nodes(graph_client, parent_node_id: str, keywords: List[str]
id = keyword_id, id = keyword_id,
name = keyword.lower().capitalize(), name = keyword.lower().capitalize(),
keyword = keyword.lower(), keyword = keyword.lower(),
type = "Keyword", entity_type = "Keyword",
created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"), created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"), updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
), ),
@ -45,29 +45,6 @@ async def add_label_nodes(graph_client, parent_node_id: str, keywords: List[str]
references: References references: References
# Add data to vector # Add data to vector
keyword_data_points = [
DataPoint(
id = str(uuid4()),
payload = dict(
value = keyword_data["keyword"],
references = dict(
node_id = keyword_node_id,
cognitive_layer = parent_node_id,
),
),
embed_field = "value"
) for (keyword_node_id, keyword_data) in keyword_nodes
]
try:
await vector_client.create_collection(parent_node_id)
except Exception:
# It's ok if the collection already exists.
pass
await vector_client.create_data_points(parent_node_id, keyword_data_points)
keyword_data_points = [ keyword_data_points = [
DataPoint[PayloadSchema]( DataPoint[PayloadSchema](
id = str(uuid4()), id = str(uuid4()),
@ -84,9 +61,8 @@ async def add_label_nodes(graph_client, parent_node_id: str, keywords: List[str]
try: try:
await vector_client.create_collection(parent_node_id, payload_schema = PayloadSchema) await vector_client.create_collection(parent_node_id, payload_schema = PayloadSchema)
except Exception: except Exception as e:
# It's ok if the collection already exists. # It's ok if the collection already exists.
pass print(e)
await vector_client.create_data_points(parent_node_id, keyword_data_points)
await vector_client.create_data_points(parent_node_id, keyword_data_points)