fix deployment
This commit is contained in:
parent 3c261ce6a1
commit d2e17dd4b7
6 changed files with 160 additions and 71 deletions
@@ -87,9 +87,8 @@ async def cognify(datasets: Union[str, List[str]] = None):
     chunk_engine = infrastructure_config.get_config()["chunk_engine"]
     chunk_strategy = infrastructure_config.get_config()["chunk_strategy"]

-    for (dataset_name, files) in dataset_files:
-        for file_metadata in files:
+    async def process_batch(files_batch):
+        for dataset_name, file_metadata in files_batch:
             with open(file_metadata["file_path"], "rb") as file:
                 try:
                     file_type = guess_file_type(file)
@@ -102,21 +101,68 @@ async def cognify(datasets: Union[str, List[str]] = None):
                         data_chunks[dataset_name] = []

                     for subchunk in subchunks:
-                        data_chunks[dataset_name].append(dict(text = subchunk, chunk_id = str(uuid4()), file_metadata = file_metadata))
+                        data_chunks[dataset_name].append(
+                            dict(text=subchunk, chunk_id=str(uuid4()), file_metadata=file_metadata))
                 except FileTypeException:
                     logger.warning("File (%s) has an unknown file type. We are skipping it.", file_metadata["id"])

-    added_chunks = await add_data_chunks(data_chunks)
-
-    # await asyncio.gather(
-    #     *[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"]) for chunk in
-    #       added_chunks]
-    # )
+        added_chunks: list[tuple[str, str, dict]] = await add_data_chunks(data_chunks)
+
+        await asyncio.gather(
+            *[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"]) for chunk in added_chunks]
+        )
+
+    batch_size = 20
+    file_count = 0
+    files_batch = []
+
+    for (dataset_name, files) in dataset_files:
+        for file_metadata in files:
+            files_batch.append((dataset_name, file_metadata))
+            file_count += 1
+
+            if file_count >= batch_size:
+                await process_batch(files_batch)
+                files_batch = []
+                file_count = 0
+
+    # Process any remaining files in the last batch
+    if files_batch:
+        await process_batch(files_batch)

     return graph_client.graph

+    #
+    # for (dataset_name, files) in dataset_files:
+    #     for file_metadata in files:
+    #         with open(file_metadata["file_path"], "rb") as file:
+    #             try:
+    #                 file_type = guess_file_type(file)
+    #                 text = extract_text_from_file(file, file_type)
+    #                 if text is None:
+    #                     text = "empty file"
+    #                 subchunks = chunk_engine.chunk_data(chunk_strategy, text, config.chunk_size, config.chunk_overlap)
+    #
+    #                 if dataset_name not in data_chunks:
+    #                     data_chunks[dataset_name] = []
+    #
+    #                 for subchunk in subchunks:
+    #                     data_chunks[dataset_name].append(dict(text = subchunk, chunk_id = str(uuid4()), file_metadata = file_metadata))
+    #             except FileTypeException:
+    #                 logger.warning("File (%s) has an unknown file type. We are skipping it.", file_metadata["id"])
+    #
+    #
+    # added_chunks: list[tuple[str, str, dict]] = await add_data_chunks(data_chunks)
+    #
+    # await asyncio.gather(
+    #     *[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"]) for chunk in added_chunks]
+    # )
+    #
+    # return graph_client.graph

 async def process_text(chunk_collection: str, chunk_id: str, input_text: str, file_metadata: dict):
     print(f"Processing chunk ({chunk_id}) from document ({file_metadata['id']}).")
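For readers skimming the hunk above: the new code replaces the flat per-file loop with an accumulate-and-flush batching loop around `process_batch`, fanning each batch out with `asyncio.gather`. A minimal, self-contained sketch of the same pattern (the `handle` coroutine and item names are hypothetical, not cognee's actual API):

```python
import asyncio

async def handle(item: str) -> None:
    # Stand-in for per-item work (e.g. chunking and embedding one file).
    await asyncio.sleep(0)
    print("processed", item)

async def process_batch(batch: list[str]) -> None:
    # Fan the whole batch out concurrently, as the commit does with process_text.
    await asyncio.gather(*[handle(item) for item in batch])

async def main() -> None:
    items = [f"file_{i}" for i in range(45)]
    batch_size = 20
    batch: list[str] = []

    for item in items:
        batch.append(item)
        if len(batch) >= batch_size:
            await process_batch(batch)
            batch = []

    # Flush the final partial batch, mirroring the trailing "if files_batch" above.
    if batch:
        await process_batch(batch)

asyncio.run(main())
```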
@@ -124,36 +170,46 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, file_metadata: dict):
     graph_topology = infrastructure_config.get_config()["graph_topology"]

+    print("got here")
+
     document_id = await add_document_node(
         graph_client,
         parent_node_id = f"{file_metadata['name']}.{file_metadata['extension']}", # TODO: make this a parameter of the default graph model so user-provided values don't break the pipeline
         document_metadata = file_metadata,
     )
-    # print("got here2")
-    # await add_label_nodes(graph_client, document_id, chunk_id, file_metadata["keywords"].split("|"))
-
-    # classified_categories = await get_content_categories(input_text)
-    #
-    # print("classified_categories", classified_categories)
-    # await add_classification_nodes(
-    #     graph_client,
-    #     parent_node_id = document_id,
-    #     categories = classified_categories,
-    # )
-    classified_categories = [{'data_type': 'text', 'category_name': 'Source code in various programming languages'}]
+    await add_label_nodes(graph_client, document_id, chunk_id, file_metadata["keywords"].split("|"))
+
+    classified_categories = await get_content_categories(input_text)
+    await add_classification_nodes(
+        graph_client,
+        parent_node_id = document_id,
+        categories = classified_categories,
+    )

     # print(f"Chunk ({chunk_id}) classified.")

+    print("document_id", document_id)
+
+    content_summary = await get_content_summary(input_text)
+    await add_summary_nodes(graph_client, document_id, content_summary)
+    print(f"Chunk ({chunk_id}) classified.")
-
-    # print("document_id", document_id)
-    #
-    # content_summary = await get_content_summary(input_text)
-    # await add_summary_nodes(graph_client, document_id, content_summary)

     print(f"Chunk ({chunk_id}) summarized.")

     #
     cognitive_layers = await get_cognitive_layers(input_text, classified_categories)
     cognitive_layers = (await add_cognitive_layers(graph_client, document_id, cognitive_layers))[:2]
     #
     layer_graphs = await get_layer_graphs(input_text, cognitive_layers)
     await add_cognitive_layer_graphs(graph_client, chunk_collection, chunk_id, layer_graphs)

+    print("got here 4444")
     #
     # if infrastructure_config.get_config()["connect_documents"] is True:
     #     db_engine = infrastructure_config.get_config()["database_engine"]
@@ -200,7 +256,7 @@ if __name__ == "__main__":
     #
     # await add("data://" + data_directory_path, "example")

-    infrastructure_config.set_config( {"chunk_engine": LangchainChunkEngine() , "chunk_strategy": ChunkStrategy.CODE,'embedding_engine': LiteLLMEmbeddingEngine()})
+    infrastructure_config.set_config( {"chunk_engine": LangchainChunkEngine() , "chunk_strategy": ChunkStrategy.CODE,'embedding_engine': LiteLLMEmbeddingEngine() })
     from cognee.shared.SourceCodeGraph import SourceCodeGraph
     from cognee.api.v1.config import config
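`infrastructure_config.set_config(...)` above merges overrides into a process-wide config that `get_config()` reads back throughout the diff. A rough dict-backed stand-in, assuming merge semantics (cognee's real `InfrastructureConfig` carries more state than this sketch):

```python
# Minimal stand-in for the set_config/get_config pattern used in this commit.
class InfraConfig:
    def __init__(self) -> None:
        self._config: dict = {}

    def set_config(self, overrides: dict) -> None:
        # Merge caller overrides into the current configuration.
        self._config.update(overrides)

    def get_config(self) -> dict:
        return dict(self._config)

infrastructure_config = InfraConfig()
infrastructure_config.set_config({"chunk_strategy": "CODE", "llm_provider": "groq"})
print(infrastructure_config.get_config()["llm_provider"])  # -> "groq"
```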
@@ -51,13 +51,13 @@ class Config:
     # Model parameters
     llm_provider: str = os.getenv("LLM_PROVIDER", "openai")  # "openai", "custom", or "ollama"
-    custom_model: str = os.getenv("CUSTOM_LLM_MODEL", "mistralai/Mixtral-8x7B-Instruct-v0.1")
+    custom_model: str = os.getenv("CUSTOM_LLM_MODEL", "llama3-70b-8192")
     custom_endpoint: str = os.getenv("CUSTOM_ENDPOINT", "https://api.endpoints.anyscale.com/v1")  # pass a Claude endpoint here if needed
     custom_key: Optional[str] = os.getenv("CUSTOM_LLM_API_KEY")
     ollama_endpoint: str = os.getenv("CUSTOM_OLLAMA_ENDPOINT", "http://localhost:11434/v1")
     ollama_key: Optional[str] = "ollama"
     ollama_model: str = os.getenv("CUSTOM_OLLAMA_MODEL", "mistral:instruct")
-    openai_model: str = os.getenv("OPENAI_MODEL", "gpt-4-1106-preview")
+    openai_model: str = os.getenv("OPENAI_MODEL", "gpt-4o")
     model_endpoint: str = "openai"
     openai_key: Optional[str] = os.getenv("OPENAI_API_KEY")
     openai_temperature: float = float(os.getenv("OPENAI_TEMPERATURE", 0.0))
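These defaults are all env-var driven, so a deployment can switch models without code changes. A small sketch of the lookup behaviour, using only the variable names visible above:

```python
import os

# Illustration only: resolve a model the way the Config defaults above do.
provider = os.getenv("LLM_PROVIDER", "openai")           # "openai", "custom", or "ollama"
openai_model = os.getenv("OPENAI_MODEL", "gpt-4o")       # new default in this commit
custom_model = os.getenv("CUSTOM_LLM_MODEL", "llama3-70b-8192")

model = {"openai": openai_model, "custom": custom_model}.get(provider, openai_model)
print(provider, model)
```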
@@ -1,25 +1,37 @@
 import asyncio
 import os
 from typing import List, Type
 from pydantic import BaseModel
 import instructor
 from tenacity import retry, stop_after_attempt
 from openai import AsyncOpenAI
 import openai

 from cognee.infrastructure import infrastructure_config
 from cognee.infrastructure.llm.llm_interface import LLMInterface
 from cognee.infrastructure.llm.prompts import read_query_prompt


 class GenericAPIAdapter(LLMInterface):
-    """Adapter for Ollama's API"""
+    """Adapter for generic LLM provider APIs."""

     def __init__(self, api_endpoint, api_key: str, model: str):
-        self.aclient = instructor.patch(
-            AsyncOpenAI(
-                base_url = api_endpoint,
-                api_key = api_key, # required, but unused
-            ),
-            mode = instructor.Mode.JSON,
-        )
+        if infrastructure_config.get_config()["llm_provider"] == "groq":
+            from groq import Groq
+            self.aclient = instructor.from_openai(client = Groq(
+                api_key = api_key,
+            ), mode = instructor.Mode.MD_JSON)
+        else:
+            self.aclient = instructor.patch(
+                AsyncOpenAI(
+                    base_url = api_endpoint,
+                    api_key = api_key, # required, but unused
+                ),
+                mode = instructor.Mode.JSON,
+            )

         self.model = model

     @retry(stop = stop_after_attempt(5))
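The adapter's core trick is `instructor` patching an OpenAI-compatible client so chat completions can return Pydantic models directly. A hedged usage sketch, assuming a local Ollama server on the endpoint from the config above (the `UserInfo` model is hypothetical):

```python
import asyncio
import instructor
from openai import AsyncOpenAI
from pydantic import BaseModel

class UserInfo(BaseModel):
    # Hypothetical response model, used only for this illustration.
    name: str
    age: int

async def main() -> None:
    # Patch an OpenAI-compatible client; the patched create() accepts response_model.
    client = instructor.patch(
        AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="ollama"),
        mode=instructor.Mode.JSON,
    )
    user = await client.chat.completions.create(
        model="mistral:instruct",
        response_model=UserInfo,
        messages=[{"role": "user", "content": "John Doe is 30 years old."}],
    )
    print(user.name, user.age)  # parsed and validated by Pydantic

asyncio.run(main())
```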
@@ -75,20 +87,21 @@ class GenericAPIAdapter(LLMInterface):
         return embeddings

-    @retry(stop=stop_after_attempt(5))
-    async def acreate_structured_output(self, text_input: str, system_prompt: str,
-                                        response_model: Type[BaseModel]) -> BaseModel:
+    @retry(stop = stop_after_attempt(5))
+    async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
         """Generate a response from a user query."""

         return await self.aclient.chat.completions.create(
-            model=self.model,
-            messages=[
+            model = self.model,
+            messages = [
                 {
                     "role": "user",
-                    "content": f"""Use the given format to
-                    extract information from the following input: {text_input}. {system_prompt} """,
-                }
+                    "content": f"""Use the given format to
+                    extract information from the following input: {text_input}. """,
+                },
+                {"role": "system", "content": system_prompt},
             ],
-            response_model=response_model,
+            response_model = response_model,
         )

     def show_prompt(self, text_input: str, system_prompt: str) -> str:
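The `@retry(stop = stop_after_attempt(5))` decorator above comes from tenacity, which transparently re-invokes the coroutine when it raises. An isolated sketch (the `flaky` coroutine is hypothetical):

```python
import asyncio
from tenacity import retry, stop_after_attempt

attempts = 0

@retry(stop=stop_after_attempt(5))
async def flaky() -> str:
    # Fails twice, then succeeds; tenacity retries the coroutine transparently.
    global attempts
    attempts += 1
    if attempts < 3:
        raise RuntimeError("transient failure")
    return f"succeeded on attempt {attempts}"

print(asyncio.run(flaky()))
```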
@@ -15,12 +15,12 @@ async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
     identified_chunks = []

     for (dataset_name, chunks) in dataset_data_chunks.items():
-        # try:
-        #     # if not await vector_client.collection_exists(dataset_name):
-        #     #     logging.error(f"Creating collection {str(dataset_name)}")
-        #     await vector_client.create_collection(dataset_name)
-        # except Exception:
-        #     pass
+        try:
+            # if not await vector_client.collection_exists(dataset_name):
+            #     logging.error(f"Creating collection {str(dataset_name)}")
+            await vector_client.create_collection(dataset_name)
+        except Exception:
+            pass

         dataset_chunks = [
             dict(
@@ -33,29 +33,29 @@ async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
         identified_chunks.extend(dataset_chunks)

-        # # if not await vector_client.collection_exists(dataset_name):
-        # try:
-        #     logging.error("Collection still not found. Creating collection again.")
-        #     await vector_client.create_collection(dataset_name)
-        # except:
-        #     pass
-        #
-        # async def create_collection_retry(dataset_name, dataset_chunks):
-        #     await vector_client.create_data_points(
-        #         dataset_name,
-        #         [
-        #             DataPoint(
-        #                 id = chunk["chunk_id"],
-        #                 payload = dict(text = chunk["text"]),
-        #                 embed_field = "text"
-        #             ) for chunk in dataset_chunks
-        #         ],
-        #     )
-        #
-        # try:
-        #     await create_collection_retry(dataset_name, dataset_chunks)
-        # except Exception:
-        #     logging.error("Collection not found in create data points.")
-        #     await create_collection_retry(dataset_name, dataset_chunks)
+        # if not await vector_client.collection_exists(dataset_name):
+        try:
+            logging.error("Collection still not found. Creating collection again.")
+            await vector_client.create_collection(dataset_name)
+        except:
+            pass
+
+        async def create_collection_retry(dataset_name, dataset_chunks):
+            await vector_client.create_data_points(
+                dataset_name,
+                [
+                    DataPoint(
+                        id = chunk["chunk_id"],
+                        payload = dict(text = chunk["text"]),
+                        embed_field = "text"
+                    ) for chunk in dataset_chunks
+                ],
+            )
+
+        try:
+            await create_collection_retry(dataset_name, dataset_chunks)
+        except Exception:
+            logging.error("Collection not found in create data points.")
+            await create_collection_retry(dataset_name, dataset_chunks)

     return identified_chunks
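The uncommented block above amounts to create-if-missing with a broad `except`, plus a one-shot retry around `create_data_points`. A compact sketch of the idempotent-creation half, with a fake client standing in for cognee's `vector_client`:

```python
import asyncio
import logging

class FakeVectorClient:
    # Hypothetical async vector store client, for illustration only.
    def __init__(self) -> None:
        self.collections: set[str] = set()

    async def create_collection(self, name: str) -> None:
        if name in self.collections:
            raise ValueError(f"collection {name} already exists")
        self.collections.add(name)

async def ensure_collection(client: FakeVectorClient, name: str) -> None:
    # Swallow "already exists" so the call is idempotent, which is what the
    # broad try/except in the diff above is doing.
    try:
        await client.create_collection(name)
    except Exception:
        logging.info("collection %s already exists, continuing", name)

async def main() -> None:
    client = FakeVectorClient()
    await ensure_collection(client, "example")
    await ensure_collection(client, "example")  # second call is a no-op
    print(client.collections)

asyncio.run(main())
```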
poetry.lock (generated, 21 changes)
@@ -2107,6 +2107,25 @@ files = [
 [package.dependencies]
 colorama = ">=0.4"

+[[package]]
+name = "groq"
+version = "0.5.0"
+description = "The official Python library for the groq API"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "groq-0.5.0-py3-none-any.whl", hash = "sha256:a7e6be1118bcdfea3ed071ec00f505a34d4e6ec28c435adb5a5afd33545683a1"},
+    {file = "groq-0.5.0.tar.gz", hash = "sha256:d476cdc3383b45d2a4dc1876142a9542e663ea1029f9e07a05de24f895cae48c"},
+]
+
+[package.dependencies]
+anyio = ">=3.5.0,<5"
+distro = ">=1.7.0,<2"
+httpx = ">=0.23.0,<1"
+pydantic = ">=1.9.0,<3"
+sniffio = "*"
+typing-extensions = ">=4.7,<5"
+
 [[package]]
 name = "grpcio"
 version = "1.63.0"
@@ -7837,4 +7856,4 @@ weaviate = ["weaviate-client"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9.0,<3.12"
-content-hash = "459e5945fb184ecdc2c9a817e89b2a83b9e0c97273d1a7df8f75c9c0b355ba47"
+content-hash = "4c2e75aba3260da9e4023e5fa1c4d117b06e4f042bda9441664a0e91af4069a1"
@@ -66,6 +66,7 @@ lancedb = "^0.6.10"
 importlib-metadata = "6.8.0"
 deepeval = "^0.21.36"
 litellm = "^1.37.3"
+groq = "^0.5.0"
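If you are applying this change by hand rather than from the commit, running `poetry add groq@^0.5.0` should reproduce both this pyproject entry and the lockfile block and content-hash shown above.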