fix deployment

This commit is contained in:
Vasilije 2024-05-17 09:52:14 +02:00
parent 3c261ce6a1
commit d2e17dd4b7
6 changed files with 160 additions and 71 deletions

View file

@ -87,9 +87,8 @@ async def cognify(datasets: Union[str, List[str]] = None):
chunk_engine = infrastructure_config.get_config()["chunk_engine"]
chunk_strategy = infrastructure_config.get_config()["chunk_strategy"]
for (dataset_name, files) in dataset_files:
for file_metadata in files:
async def process_batch(files_batch):
for dataset_name, file_metadata in files_batch:
with open(file_metadata["file_path"], "rb") as file:
try:
file_type = guess_file_type(file)
@ -102,21 +101,68 @@ async def cognify(datasets: Union[str, List[str]] = None):
data_chunks[dataset_name] = []
for subchunk in subchunks:
data_chunks[dataset_name].append(dict(text = subchunk, chunk_id = str(uuid4()), file_metadata = file_metadata))
data_chunks[dataset_name].append(
dict(text=subchunk, chunk_id=str(uuid4()), file_metadata=file_metadata))
except FileTypeException:
logger.warning("File (%s) has an unknown file type. We are skipping it.", file_metadata["id"])
added_chunks = await add_data_chunks(data_chunks)
# await asyncio.gather(
# *[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"]) for chunk in
# added_chunks]
# )
batch_size = 20
file_count = 0
files_batch = []
added_chunks: list[tuple[str, str, dict]] = await add_data_chunks(data_chunks)
for (dataset_name, files) in dataset_files:
for file_metadata in files:
files_batch.append((dataset_name, file_metadata))
file_count += 1
await asyncio.gather(
*[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"]) for chunk in added_chunks]
)
if file_count >= batch_size:
await process_batch(files_batch)
files_batch = []
file_count = 0
# Process any remaining files in the last batch
if files_batch:
await process_batch(files_batch)
return graph_client.graph
#
# for (dataset_name, files) in dataset_files:
# for file_metadata in files:
# with open(file_metadata["file_path"], "rb") as file:
# try:
# file_type = guess_file_type(file)
# text = extract_text_from_file(file, file_type)
# if text is None:
# text = "empty file"
# subchunks = chunk_engine.chunk_data(chunk_strategy, text, config.chunk_size, config.chunk_overlap)
#
# if dataset_name not in data_chunks:
# data_chunks[dataset_name] = []
#
# for subchunk in subchunks:
# data_chunks[dataset_name].append(dict(text = subchunk, chunk_id = str(uuid4()), file_metadata = file_metadata))
# except FileTypeException:
# logger.warning("File (%s) has an unknown file type. We are skipping it.", file_metadata["id"])
#
#
#
#
# added_chunks: list[tuple[str, str, dict]] = await add_data_chunks(data_chunks)
#
# await asyncio.gather(
# *[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"]) for chunk in added_chunks]
# )
#
# return graph_client.graph
async def process_text(chunk_collection: str, chunk_id: str, input_text: str, file_metadata: dict):
print(f"Processing chunk ({chunk_id}) from document ({file_metadata['id']}).")
@ -124,36 +170,46 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
graph_topology = infrastructure_config.get_config()["graph_topology"]
print("got here")
document_id = await add_document_node(
graph_client,
parent_node_id = f"{file_metadata['name']}.{file_metadata['extension']}", #make a param of defaultgraph model to make sure when user passes his stuff, it doesn't break pipeline
document_metadata = file_metadata,
)
# print("got here2")
# await add_label_nodes(graph_client, document_id, chunk_id, file_metadata["keywords"].split("|"))
# classified_categories = await get_content_categories(input_text)
#
await add_label_nodes(graph_client, document_id, chunk_id, file_metadata["keywords"].split("|"))
# print("classified_categories", classified_categories)
# await add_classification_nodes(
# graph_client,
# parent_node_id = document_id,
# categories = classified_categories,
# )
classified_categories = await get_content_categories(input_text)
await add_classification_nodes(
graph_client,
parent_node_id = document_id,
categories = classified_categories,
)
classified_categories= [{'data_type': 'text', 'category_name': 'Source code in various programming languages'}]
# print(f"Chunk ({chunk_id}) classified.")
print("document_id", document_id)
content_summary = await get_content_summary(input_text)
await add_summary_nodes(graph_client, document_id, content_summary)
print(f"Chunk ({chunk_id}) classified.")
# print("document_id", document_id)
#
# content_summary = await get_content_summary(input_text)
# await add_summary_nodes(graph_client, document_id, content_summary)
print(f"Chunk ({chunk_id}) summarized.")
#
cognitive_layers = await get_cognitive_layers(input_text, classified_categories)
cognitive_layers = (await add_cognitive_layers(graph_client, document_id, cognitive_layers))[:2]
#
layer_graphs = await get_layer_graphs(input_text, cognitive_layers)
await add_cognitive_layer_graphs(graph_client, chunk_collection, chunk_id, layer_graphs)
print("got here 4444")
#
# if infrastructure_config.get_config()["connect_documents"] is True:
# db_engine = infrastructure_config.get_config()["database_engine"]
@ -200,7 +256,7 @@ if __name__ == "__main__":
#
# await add("data://" +data_directory_path, "example")
infrastructure_config.set_config( {"chunk_engine": LangchainChunkEngine() , "chunk_strategy": ChunkStrategy.CODE,'embedding_engine': LiteLLMEmbeddingEngine()})
infrastructure_config.set_config( {"chunk_engine": LangchainChunkEngine() , "chunk_strategy": ChunkStrategy.CODE,'embedding_engine': LiteLLMEmbeddingEngine() })
from cognee.shared.SourceCodeGraph import SourceCodeGraph
from cognee.api.v1.config import config

View file

@ -51,13 +51,13 @@ class Config:
# Model parameters
llm_provider: str = os.getenv("LLM_PROVIDER","openai") #openai, or custom or ollama
custom_model: str = os.getenv("CUSTOM_LLM_MODEL", "mistralai/Mixtral-8x7B-Instruct-v0.1") #"mistralai/Mixtral-8x7B-Instruct-v0.1"
custom_model: str = os.getenv("CUSTOM_LLM_MODEL", "llama3-70b-8192") #"mistralai/Mixtral-8x7B-Instruct-v0.1"
custom_endpoint: str = os.getenv("CUSTOM_ENDPOINT", "https://api.endpoints.anyscale.com/v1") #"https://api.endpoints.anyscale.com/v1" # pass claude endpoint
custom_key: Optional[str] = os.getenv("CUSTOM_LLM_API_KEY")
ollama_endpoint: str = os.getenv("CUSTOM_OLLAMA_ENDPOINT", "http://localhost:11434/v1") #"http://localhost:11434/v1"
ollama_key: Optional[str] = "ollama"
ollama_model: str = os.getenv("CUSTOM_OLLAMA_MODEL", "mistral:instruct") #"mistral:instruct"
openai_model: str = os.getenv("OPENAI_MODEL", "gpt-4-1106-preview" ) #"gpt-4-1106-preview"
openai_model: str = os.getenv("OPENAI_MODEL", "gpt-4o" ) #"gpt-4o"
model_endpoint: str = "openai"
openai_key: Optional[str] = os.getenv("OPENAI_API_KEY")
openai_temperature: float = float(os.getenv("OPENAI_TEMPERATURE", 0.0))

View file

@ -1,25 +1,37 @@
import asyncio
import os
from typing import List, Type
from pydantic import BaseModel
import instructor
from tenacity import retry, stop_after_attempt
from openai import AsyncOpenAI
import openai
from cognee.infrastructure import infrastructure_config
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt
class GenericAPIAdapter(LLMInterface):
"""Adapter for Ollama's API"""
"""Adapter for Generic API LLM provider API """
def __init__(self, api_endpoint, api_key: str, model: str):
self.aclient = instructor.patch(
AsyncOpenAI(
base_url = api_endpoint,
api_key = api_key, # required, but unused
),
mode = instructor.Mode.JSON,
)
if infrastructure_config.get_config()["llm_provider"] == 'groq':
from groq import groq
self.aclient = instructor.from_openai(client = groq.Groq(
api_key=api_key,
), mode=instructor.Mode.MD_JSON)
else:
self.aclient = instructor.patch(
AsyncOpenAI(
base_url = api_endpoint,
api_key = api_key, # required, but unused
),
mode = instructor.Mode.JSON,
)
self.model = model
@retry(stop = stop_after_attempt(5))
@ -75,20 +87,21 @@ class GenericAPIAdapter(LLMInterface):
return embeddings
@retry(stop=stop_after_attempt(5))
async def acreate_structured_output(self, text_input: str, system_prompt: str,
response_model: Type[BaseModel]) -> BaseModel:
@retry(stop = stop_after_attempt(5))
async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
"""Generate a response from a user query."""
return await self.aclient.chat.completions.create(
model=self.model,
messages=[
model = self.model,
messages = [
{
"role": "user",
"content": f"""Use the given format to
extract information from the following input: {text_input}. {system_prompt} """,
}
extract information from the following input: {text_input}. """,
},
{"role": "system", "content": system_prompt},
],
response_model=response_model,
response_model = response_model,
)
def show_prompt(self, text_input: str, system_prompt: str) -> str:

View file

@ -15,12 +15,12 @@ async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
identified_chunks = []
for (dataset_name, chunks) in dataset_data_chunks.items():
# try:
# # if not await vector_client.collection_exists(dataset_name):
# # logging.error(f"Creating collection {str(dataset_name)}")
# await vector_client.create_collection(dataset_name)
# except Exception:
# pass
try:
# if not await vector_client.collection_exists(dataset_name):
# logging.error(f"Creating collection {str(dataset_name)}")
await vector_client.create_collection(dataset_name)
except Exception:
pass
dataset_chunks = [
dict(
@ -33,29 +33,29 @@ async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
identified_chunks.extend(dataset_chunks)
# # if not await vector_client.collection_exists(dataset_name):
# try:
# logging.error("Collection still not found. Creating collection again.")
# await vector_client.create_collection(dataset_name)
# except:
# pass
#
# async def create_collection_retry(dataset_name, dataset_chunks):
# await vector_client.create_data_points(
# dataset_name,
# [
# DataPoint(
# id = chunk["chunk_id"],
# payload = dict(text = chunk["text"]),
# embed_field = "text"
# ) for chunk in dataset_chunks
# ],
# )
#
# try:
# await create_collection_retry(dataset_name, dataset_chunks)
# except Exception:
# logging.error("Collection not found in create data points.")
# await create_collection_retry(dataset_name, dataset_chunks)
# if not await vector_client.collection_exists(dataset_name):
try:
logging.error("Collection still not found. Creating collection again.")
await vector_client.create_collection(dataset_name)
except:
pass
async def create_collection_retry(dataset_name, dataset_chunks):
await vector_client.create_data_points(
dataset_name,
[
DataPoint(
id = chunk["chunk_id"],
payload = dict(text = chunk["text"]),
embed_field = "text"
) for chunk in dataset_chunks
],
)
try:
await create_collection_retry(dataset_name, dataset_chunks)
except Exception:
logging.error("Collection not found in create data points.")
await create_collection_retry(dataset_name, dataset_chunks)
return identified_chunks

21
poetry.lock generated
View file

@ -2107,6 +2107,25 @@ files = [
[package.dependencies]
colorama = ">=0.4"
[[package]]
name = "groq"
version = "0.5.0"
description = "The official Python library for the groq API"
optional = false
python-versions = ">=3.7"
files = [
{file = "groq-0.5.0-py3-none-any.whl", hash = "sha256:a7e6be1118bcdfea3ed071ec00f505a34d4e6ec28c435adb5a5afd33545683a1"},
{file = "groq-0.5.0.tar.gz", hash = "sha256:d476cdc3383b45d2a4dc1876142a9542e663ea1029f9e07a05de24f895cae48c"},
]
[package.dependencies]
anyio = ">=3.5.0,<5"
distro = ">=1.7.0,<2"
httpx = ">=0.23.0,<1"
pydantic = ">=1.9.0,<3"
sniffio = "*"
typing-extensions = ">=4.7,<5"
[[package]]
name = "grpcio"
version = "1.63.0"
@ -7837,4 +7856,4 @@ weaviate = ["weaviate-client"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9.0,<3.12"
content-hash = "459e5945fb184ecdc2c9a817e89b2a83b9e0c97273d1a7df8f75c9c0b355ba47"
content-hash = "4c2e75aba3260da9e4023e5fa1c4d117b06e4f042bda9441664a0e91af4069a1"

View file

@ -66,6 +66,7 @@ lancedb = "^0.6.10"
importlib-metadata = "6.8.0"
deepeval = "^0.21.36"
litellm = "^1.37.3"
groq = "^0.5.0"