async embeddings fix + processing fix + decouple issues with the lancedb
This commit is contained in:
parent
79311ee510
commit
db0b19bb30
6 changed files with 35 additions and 60 deletions
|
|
@ -87,26 +87,22 @@ async def cognify(datasets: Union[str, List[str]] = None):
|
||||||
chunk_engine = infrastructure_config.get_config()["chunk_engine"]
|
chunk_engine = infrastructure_config.get_config()["chunk_engine"]
|
||||||
chunk_strategy = infrastructure_config.get_config()["chunk_strategy"]
|
chunk_strategy = infrastructure_config.get_config()["chunk_strategy"]
|
||||||
async def process_batch(files_batch):
|
async def process_batch(files_batch):
|
||||||
for dataset_name, file_metadata in files_batch:
|
for dataset_name, file_metadata, document_id in files_batch:
|
||||||
with open(file_metadata["file_path"], "rb") as file:
|
with open(file_metadata["file_path"], "rb") as file:
|
||||||
try:
|
try:
|
||||||
document_id = await add_document_node(
|
|
||||||
graph_client,
|
|
||||||
parent_node_id = f"DefaultGraphModel__{USER_ID}",
|
|
||||||
document_metadata = file_metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
file_type = guess_file_type(file)
|
file_type = guess_file_type(file)
|
||||||
text = extract_text_from_file(file, file_type)
|
text = extract_text_from_file(file, file_type)
|
||||||
if text is None:
|
if text is None:
|
||||||
text = "empty file"
|
text = "empty file"
|
||||||
|
if text == "":
|
||||||
|
text = "empty file"
|
||||||
subchunks = chunk_engine.chunk_data(chunk_strategy, text, config.chunk_size, config.chunk_overlap)
|
subchunks = chunk_engine.chunk_data(chunk_strategy, text, config.chunk_size, config.chunk_overlap)
|
||||||
|
|
||||||
if dataset_name not in data_chunks:
|
if dataset_name not in data_chunks:
|
||||||
data_chunks[dataset_name] = []
|
data_chunks[dataset_name] = []
|
||||||
|
|
||||||
for subchunk in subchunks:
|
for subchunk in subchunks:
|
||||||
data_chunks[dataset_name].append(dict(document_id = document_id, chunk_id = str(uuid4()), text = subchunk))
|
data_chunks[dataset_name].append(dict(document_id = document_id, chunk_id = str(uuid4()), text = subchunk, file_metadata = file_metadata))
|
||||||
|
|
||||||
except FileTypeException:
|
except FileTypeException:
|
||||||
logger.warning("File (%s) has an unknown file type. We are skipping it.", file_metadata["id"])
|
logger.warning("File (%s) has an unknown file type. We are skipping it.", file_metadata["id"])
|
||||||
|
|
@ -124,8 +120,14 @@ async def cognify(datasets: Union[str, List[str]] = None):
|
||||||
files_batch = []
|
files_batch = []
|
||||||
|
|
||||||
for (dataset_name, files) in dataset_files:
|
for (dataset_name, files) in dataset_files:
|
||||||
|
|
||||||
for file_metadata in files:
|
for file_metadata in files:
|
||||||
files_batch.append((dataset_name, file_metadata))
|
document_id = await add_document_node(
|
||||||
|
graph_client,
|
||||||
|
parent_node_id=f"DefaultGraphModel__{USER_ID}",
|
||||||
|
document_metadata=file_metadata,
|
||||||
|
)
|
||||||
|
files_batch.append((dataset_name, file_metadata, document_id))
|
||||||
file_count += 1
|
file_count += 1
|
||||||
|
|
||||||
if file_count >= batch_size:
|
if file_count >= batch_size:
|
||||||
|
|
@ -199,11 +201,6 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
|
||||||
classified_categories= [{'data_type': 'text', 'category_name': 'Source code in various programming languages'}]
|
classified_categories= [{'data_type': 'text', 'category_name': 'Source code in various programming languages'}]
|
||||||
|
|
||||||
|
|
||||||
await asyncio.gather(
|
|
||||||
*[process_text(chunk["document_id"], chunk["chunk_id"], chunk["collection"], chunk["text"]) for chunk in added_chunks]
|
|
||||||
)
|
|
||||||
|
|
||||||
return graph_client.graph
|
|
||||||
#
|
#
|
||||||
# async def process_text(document_id: str, chunk_id: str, chunk_collection: str, input_text: str):
|
# async def process_text(document_id: str, chunk_id: str, chunk_collection: str, input_text: str):
|
||||||
# raw_document_id = document_id.split("__")[-1]
|
# raw_document_id = document_id.split("__")[-1]
|
||||||
|
|
|
||||||
|
|
@ -27,25 +27,23 @@ class DefaultEmbeddingEngine(EmbeddingEngine):
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
|
import asyncio
|
||||||
|
from typing import List
|
||||||
|
|
||||||
async def embed_text(self, text: List[str]) -> List[float]:
|
async def embed_text(self, text: List[str]) -> List[List[float]]:
|
||||||
|
async def get_embedding(text_):
|
||||||
|
response = await aembedding(config.litellm_embedding_model, input=text_)
|
||||||
|
return response.data[0]['embedding']
|
||||||
print("text", text)
|
|
||||||
try:
|
|
||||||
text = str(text[0])
|
|
||||||
except:
|
|
||||||
text = str(text)
|
|
||||||
|
|
||||||
|
|
||||||
response = await aembedding(config.litellm_embedding_model, input=text)
|
|
||||||
|
|
||||||
|
tasks = [get_embedding(text_) for text_ in text]
|
||||||
|
result = await asyncio.gather(*tasks)
|
||||||
|
return result
|
||||||
|
|
||||||
# embedding = response.data[0].embedding
|
# embedding = response.data[0].embedding
|
||||||
# embeddings_list = list(map(lambda embedding: embedding.tolist(), embedding_model.embed(text)))
|
# # embeddings_list = list(map(lambda embedding: embedding.tolist(), embedding_model.embed(text)))
|
||||||
print("response", type(response.data[0]['embedding']))
|
# print("response", type(response.data[0]['embedding']))
|
||||||
return response.data[0]['embedding']
|
# print("response", response.data[0])
|
||||||
|
# return [response.data[0]['embedding']]
|
||||||
|
|
||||||
|
|
||||||
def get_vector_size(self) -> int:
|
def get_vector_size(self) -> int:
|
||||||
|
|
|
||||||
|
|
@ -37,10 +37,11 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
|
|
||||||
async def create_collection(self, collection_name: str, payload_schema: BaseModel):
|
async def create_collection(self, collection_name: str, payload_schema: BaseModel):
|
||||||
data_point_types = get_type_hints(DataPoint)
|
data_point_types = get_type_hints(DataPoint)
|
||||||
|
vector_size = self.embedding_engine.get_vector_size()
|
||||||
|
|
||||||
class LanceDataPoint(LanceModel):
|
class LanceDataPoint(LanceModel):
|
||||||
id: data_point_types["id"] = Field(...)
|
id: data_point_types["id"] = Field(...)
|
||||||
vector: Vector(self.embedding_engine.get_vector_size())
|
vector: Vector(vector_size)
|
||||||
payload: payload_schema
|
payload: payload_schema
|
||||||
|
|
||||||
if not await self.collection_exists(collection_name):
|
if not await self.collection_exists(collection_name):
|
||||||
|
|
@ -68,10 +69,11 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
|
|
||||||
IdType = TypeVar("IdType")
|
IdType = TypeVar("IdType")
|
||||||
PayloadSchema = TypeVar("PayloadSchema")
|
PayloadSchema = TypeVar("PayloadSchema")
|
||||||
|
vector_size = self.embedding_engine.get_vector_size()
|
||||||
|
|
||||||
class LanceDataPoint(LanceModel, Generic[IdType, PayloadSchema]):
|
class LanceDataPoint(LanceModel, Generic[IdType, PayloadSchema]):
|
||||||
id: IdType
|
id: IdType
|
||||||
vector: Vector(self.embedding_engine.get_vector_size())
|
vector: Vector(vector_size)
|
||||||
payload: PayloadSchema
|
payload: PayloadSchema
|
||||||
|
|
||||||
lance_data_points = [
|
lance_data_points = [
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
|
||||||
collection = dataset_name,
|
collection = dataset_name,
|
||||||
text = chunk["text"],
|
text = chunk["text"],
|
||||||
document_id = chunk["document_id"],
|
document_id = chunk["document_id"],
|
||||||
|
file_metadata = chunk["file_metadata"],
|
||||||
) for chunk in chunks
|
) for chunk in chunks
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ async def add_document_node(graph_client: GraphDBInterface, parent_node_id, docu
|
||||||
dict(relationship_name = "has_document"),
|
dict(relationship_name = "has_document"),
|
||||||
)
|
)
|
||||||
|
|
||||||
await add_label_nodes(graph_client, document_id, document_metadata["keywords"].split("|"))
|
#
|
||||||
|
# await add_label_nodes(graph_client, document_id, document_metadata["keywords"].split("|"))
|
||||||
|
|
||||||
return document_id
|
return document_id
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ async def add_label_nodes(graph_client, parent_node_id: str, keywords: List[str]
|
||||||
id = keyword_id,
|
id = keyword_id,
|
||||||
name = keyword.lower().capitalize(),
|
name = keyword.lower().capitalize(),
|
||||||
keyword = keyword.lower(),
|
keyword = keyword.lower(),
|
||||||
type = "Keyword",
|
entity_type = "Keyword",
|
||||||
created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
),
|
),
|
||||||
|
|
@ -45,29 +45,6 @@ async def add_label_nodes(graph_client, parent_node_id: str, keywords: List[str]
|
||||||
references: References
|
references: References
|
||||||
|
|
||||||
# Add data to vector
|
# Add data to vector
|
||||||
|
|
||||||
keyword_data_points = [
|
|
||||||
DataPoint(
|
|
||||||
id = str(uuid4()),
|
|
||||||
payload = dict(
|
|
||||||
value = keyword_data["keyword"],
|
|
||||||
references = dict(
|
|
||||||
node_id = keyword_node_id,
|
|
||||||
cognitive_layer = parent_node_id,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
embed_field = "value"
|
|
||||||
) for (keyword_node_id, keyword_data) in keyword_nodes
|
|
||||||
]
|
|
||||||
|
|
||||||
try:
|
|
||||||
await vector_client.create_collection(parent_node_id)
|
|
||||||
except Exception:
|
|
||||||
# It's ok if the collection already exists.
|
|
||||||
pass
|
|
||||||
|
|
||||||
await vector_client.create_data_points(parent_node_id, keyword_data_points)
|
|
||||||
|
|
||||||
keyword_data_points = [
|
keyword_data_points = [
|
||||||
DataPoint[PayloadSchema](
|
DataPoint[PayloadSchema](
|
||||||
id = str(uuid4()),
|
id = str(uuid4()),
|
||||||
|
|
@ -84,9 +61,8 @@ async def add_label_nodes(graph_client, parent_node_id: str, keywords: List[str]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await vector_client.create_collection(parent_node_id, payload_schema = PayloadSchema)
|
await vector_client.create_collection(parent_node_id, payload_schema = PayloadSchema)
|
||||||
except Exception:
|
except Exception as e:
|
||||||
# It's ok if the collection already exists.
|
# It's ok if the collection already exists.
|
||||||
pass
|
print(e)
|
||||||
|
|
||||||
await vector_client.create_data_points(parent_node_id, keyword_data_points)
|
|
||||||
|
|
||||||
|
await vector_client.create_data_points(parent_node_id, keyword_data_points)
|
||||||
Loading…
Add table
Reference in a new issue