fix deployment

This commit is contained in:
Vasilije 2024-05-17 09:52:14 +02:00
parent 3c261ce6a1
commit d2e17dd4b7
6 changed files with 160 additions and 71 deletions

View file

@ -87,9 +87,8 @@ async def cognify(datasets: Union[str, List[str]] = None):
chunk_engine = infrastructure_config.get_config()["chunk_engine"] chunk_engine = infrastructure_config.get_config()["chunk_engine"]
chunk_strategy = infrastructure_config.get_config()["chunk_strategy"] chunk_strategy = infrastructure_config.get_config()["chunk_strategy"]
async def process_batch(files_batch):
for (dataset_name, files) in dataset_files: for dataset_name, file_metadata in files_batch:
for file_metadata in files:
with open(file_metadata["file_path"], "rb") as file: with open(file_metadata["file_path"], "rb") as file:
try: try:
file_type = guess_file_type(file) file_type = guess_file_type(file)
@ -102,21 +101,68 @@ async def cognify(datasets: Union[str, List[str]] = None):
data_chunks[dataset_name] = [] data_chunks[dataset_name] = []
for subchunk in subchunks: for subchunk in subchunks:
data_chunks[dataset_name].append(dict(text = subchunk, chunk_id = str(uuid4()), file_metadata = file_metadata)) data_chunks[dataset_name].append(
dict(text=subchunk, chunk_id=str(uuid4()), file_metadata=file_metadata))
except FileTypeException: except FileTypeException:
logger.warning("File (%s) has an unknown file type. We are skipping it.", file_metadata["id"]) logger.warning("File (%s) has an unknown file type. We are skipping it.", file_metadata["id"])
added_chunks = await add_data_chunks(data_chunks)
# await asyncio.gather(
# *[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"]) for chunk in
# added_chunks]
# )
batch_size = 20
file_count = 0
files_batch = []
added_chunks: list[tuple[str, str, dict]] = await add_data_chunks(data_chunks) for (dataset_name, files) in dataset_files:
for file_metadata in files:
files_batch.append((dataset_name, file_metadata))
file_count += 1
await asyncio.gather( if file_count >= batch_size:
*[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"]) for chunk in added_chunks] await process_batch(files_batch)
) files_batch = []
file_count = 0
# Process any remaining files in the last batch
if files_batch:
await process_batch(files_batch)
return graph_client.graph return graph_client.graph
#
# for (dataset_name, files) in dataset_files:
# for file_metadata in files:
# with open(file_metadata["file_path"], "rb") as file:
# try:
# file_type = guess_file_type(file)
# text = extract_text_from_file(file, file_type)
# if text is None:
# text = "empty file"
# subchunks = chunk_engine.chunk_data(chunk_strategy, text, config.chunk_size, config.chunk_overlap)
#
# if dataset_name not in data_chunks:
# data_chunks[dataset_name] = []
#
# for subchunk in subchunks:
# data_chunks[dataset_name].append(dict(text = subchunk, chunk_id = str(uuid4()), file_metadata = file_metadata))
# except FileTypeException:
# logger.warning("File (%s) has an unknown file type. We are skipping it.", file_metadata["id"])
#
#
#
#
# added_chunks: list[tuple[str, str, dict]] = await add_data_chunks(data_chunks)
#
# await asyncio.gather(
# *[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"]) for chunk in added_chunks]
# )
#
# return graph_client.graph
async def process_text(chunk_collection: str, chunk_id: str, input_text: str, file_metadata: dict): async def process_text(chunk_collection: str, chunk_id: str, input_text: str, file_metadata: dict):
print(f"Processing chunk ({chunk_id}) from document ({file_metadata['id']}).") print(f"Processing chunk ({chunk_id}) from document ({file_metadata['id']}).")
@ -124,36 +170,46 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
graph_topology = infrastructure_config.get_config()["graph_topology"] graph_topology = infrastructure_config.get_config()["graph_topology"]
print("got here")
document_id = await add_document_node( document_id = await add_document_node(
graph_client, graph_client,
parent_node_id = f"{file_metadata['name']}.{file_metadata['extension']}", #make a param of defaultgraph model to make sure when user passes his stuff, it doesn't break pipeline parent_node_id = f"{file_metadata['name']}.{file_metadata['extension']}", #make a param of defaultgraph model to make sure when user passes his stuff, it doesn't break pipeline
document_metadata = file_metadata, document_metadata = file_metadata,
) )
# print("got here2")
# await add_label_nodes(graph_client, document_id, chunk_id, file_metadata["keywords"].split("|"))
# classified_categories = await get_content_categories(input_text)
# #
await add_label_nodes(graph_client, document_id, chunk_id, file_metadata["keywords"].split("|")) # print("classified_categories", classified_categories)
# await add_classification_nodes(
# graph_client,
# parent_node_id = document_id,
# categories = classified_categories,
# )
classified_categories = await get_content_categories(input_text) classified_categories= [{'data_type': 'text', 'category_name': 'Source code in various programming languages'}]
await add_classification_nodes(
graph_client,
parent_node_id = document_id,
categories = classified_categories,
)
# print(f"Chunk ({chunk_id}) classified.")
print("document_id", document_id)
content_summary = await get_content_summary(input_text) print(f"Chunk ({chunk_id}) classified.")
await add_summary_nodes(graph_client, document_id, content_summary)
# print("document_id", document_id)
#
# content_summary = await get_content_summary(input_text)
# await add_summary_nodes(graph_client, document_id, content_summary)
print(f"Chunk ({chunk_id}) summarized.") print(f"Chunk ({chunk_id}) summarized.")
#
cognitive_layers = await get_cognitive_layers(input_text, classified_categories) cognitive_layers = await get_cognitive_layers(input_text, classified_categories)
cognitive_layers = (await add_cognitive_layers(graph_client, document_id, cognitive_layers))[:2] cognitive_layers = (await add_cognitive_layers(graph_client, document_id, cognitive_layers))[:2]
# #
layer_graphs = await get_layer_graphs(input_text, cognitive_layers) layer_graphs = await get_layer_graphs(input_text, cognitive_layers)
await add_cognitive_layer_graphs(graph_client, chunk_collection, chunk_id, layer_graphs) await add_cognitive_layer_graphs(graph_client, chunk_collection, chunk_id, layer_graphs)
print("got here 4444")
# #
# if infrastructure_config.get_config()["connect_documents"] is True: # if infrastructure_config.get_config()["connect_documents"] is True:
# db_engine = infrastructure_config.get_config()["database_engine"] # db_engine = infrastructure_config.get_config()["database_engine"]
@ -200,7 +256,7 @@ if __name__ == "__main__":
# #
# await add("data://" +data_directory_path, "example") # await add("data://" +data_directory_path, "example")
infrastructure_config.set_config( {"chunk_engine": LangchainChunkEngine() , "chunk_strategy": ChunkStrategy.CODE,'embedding_engine': LiteLLMEmbeddingEngine()}) infrastructure_config.set_config( {"chunk_engine": LangchainChunkEngine() , "chunk_strategy": ChunkStrategy.CODE,'embedding_engine': LiteLLMEmbeddingEngine() })
from cognee.shared.SourceCodeGraph import SourceCodeGraph from cognee.shared.SourceCodeGraph import SourceCodeGraph
from cognee.api.v1.config import config from cognee.api.v1.config import config

View file

@ -51,13 +51,13 @@ class Config:
# Model parameters # Model parameters
llm_provider: str = os.getenv("LLM_PROVIDER","openai") #openai, or custom or ollama llm_provider: str = os.getenv("LLM_PROVIDER","openai") #openai, or custom or ollama
custom_model: str = os.getenv("CUSTOM_LLM_MODEL", "mistralai/Mixtral-8x7B-Instruct-v0.1") #"mistralai/Mixtral-8x7B-Instruct-v0.1" custom_model: str = os.getenv("CUSTOM_LLM_MODEL", "llama3-70b-8192") #"mistralai/Mixtral-8x7B-Instruct-v0.1"
custom_endpoint: str = os.getenv("CUSTOM_ENDPOINT", "https://api.endpoints.anyscale.com/v1") #"https://api.endpoints.anyscale.com/v1" # pass claude endpoint custom_endpoint: str = os.getenv("CUSTOM_ENDPOINT", "https://api.endpoints.anyscale.com/v1") #"https://api.endpoints.anyscale.com/v1" # pass claude endpoint
custom_key: Optional[str] = os.getenv("CUSTOM_LLM_API_KEY") custom_key: Optional[str] = os.getenv("CUSTOM_LLM_API_KEY")
ollama_endpoint: str = os.getenv("CUSTOM_OLLAMA_ENDPOINT", "http://localhost:11434/v1") #"http://localhost:11434/v1" ollama_endpoint: str = os.getenv("CUSTOM_OLLAMA_ENDPOINT", "http://localhost:11434/v1") #"http://localhost:11434/v1"
ollama_key: Optional[str] = "ollama" ollama_key: Optional[str] = "ollama"
ollama_model: str = os.getenv("CUSTOM_OLLAMA_MODEL", "mistral:instruct") #"mistral:instruct" ollama_model: str = os.getenv("CUSTOM_OLLAMA_MODEL", "mistral:instruct") #"mistral:instruct"
openai_model: str = os.getenv("OPENAI_MODEL", "gpt-4-1106-preview" ) #"gpt-4-1106-preview" openai_model: str = os.getenv("OPENAI_MODEL", "gpt-4o" ) #"gpt-4o"
model_endpoint: str = "openai" model_endpoint: str = "openai"
openai_key: Optional[str] = os.getenv("OPENAI_API_KEY") openai_key: Optional[str] = os.getenv("OPENAI_API_KEY")
openai_temperature: float = float(os.getenv("OPENAI_TEMPERATURE", 0.0)) openai_temperature: float = float(os.getenv("OPENAI_TEMPERATURE", 0.0))

View file

@ -1,25 +1,37 @@
import asyncio import asyncio
import os
from typing import List, Type from typing import List, Type
from pydantic import BaseModel from pydantic import BaseModel
import instructor import instructor
from tenacity import retry, stop_after_attempt from tenacity import retry, stop_after_attempt
from openai import AsyncOpenAI from openai import AsyncOpenAI
import openai import openai
from cognee.infrastructure import infrastructure_config
from cognee.infrastructure.llm.llm_interface import LLMInterface from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt from cognee.infrastructure.llm.prompts import read_query_prompt
class GenericAPIAdapter(LLMInterface): class GenericAPIAdapter(LLMInterface):
"""Adapter for Ollama's API""" """Adapter for Generic API LLM provider API """
def __init__(self, api_endpoint, api_key: str, model: str):
    """Create the instructor-wrapped async client for the configured provider.

    Args:
        api_endpoint: Base URL handed to ``AsyncOpenAI``. Not used when the
            configured ``llm_provider`` is ``"groq"``.
        api_key: API key handed to the underlying client.
        model: Model name stored on ``self.model`` and sent with each
            completion request.
    """
    if infrastructure_config.get_config()["llm_provider"] == "groq":
        # Imported here so the groq package is only required when the groq
        # provider is actually selected. The package exposes its client
        # classes at top level; ``from groq import groq`` is not importable.
        from groq import AsyncGroq

        # The adapter awaits ``self.aclient.chat.completions.create(...)``,
        # so the async client is required rather than the blocking ``Groq``.
        # NOTE(review): ``instructor.from_groq`` is used instead of
        # ``from_openai`` because the client is not an OpenAI client; the
        # mode is JSON rather than MD_JSON — confirm against the installed
        # instructor version.
        self.aclient = instructor.from_groq(
            AsyncGroq(api_key = api_key),
            mode = instructor.Mode.JSON,
        )
    else:
        self.aclient = instructor.patch(
            AsyncOpenAI(
                base_url = api_endpoint,
                api_key = api_key,  # required, but unused
            ),
            mode = instructor.Mode.JSON,
        )

    self.model = model
@retry(stop = stop_after_attempt(5)) @retry(stop = stop_after_attempt(5))
@ -75,20 +87,21 @@ class GenericAPIAdapter(LLMInterface):
return embeddings return embeddings
@retry(stop = stop_after_attempt(5))
async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
    """Generate a structured response for ``text_input``.

    The call is retried by tenacity and stops after five attempts.

    Args:
        text_input: Input text embedded in the user message.
        system_prompt: Instructions sent as the system message.
        response_model: Pydantic model class passed to the client as
            ``response_model``.

    Returns:
        Whatever ``self.aclient.chat.completions.create`` returns for the
        given ``response_model``.
    """
    # System instructions go first, followed by the user content, which is
    # the conventional ordering for chat-completion requests.
    messages = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": f"""Use the given format to
extract information from the following input: {text_input}. """,
        },
    ]

    return await self.aclient.chat.completions.create(
        model = self.model,
        messages = messages,
        response_model = response_model,
    )
def show_prompt(self, text_input: str, system_prompt: str) -> str: def show_prompt(self, text_input: str, system_prompt: str) -> str:

View file

@ -15,12 +15,12 @@ async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
identified_chunks = [] identified_chunks = []
for (dataset_name, chunks) in dataset_data_chunks.items(): for (dataset_name, chunks) in dataset_data_chunks.items():
# try: try:
# # if not await vector_client.collection_exists(dataset_name): # if not await vector_client.collection_exists(dataset_name):
# # logging.error(f"Creating collection {str(dataset_name)}") # logging.error(f"Creating collection {str(dataset_name)}")
# await vector_client.create_collection(dataset_name) await vector_client.create_collection(dataset_name)
# except Exception: except Exception:
# pass pass
dataset_chunks = [ dataset_chunks = [
dict( dict(
@ -33,29 +33,29 @@ async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
identified_chunks.extend(dataset_chunks) identified_chunks.extend(dataset_chunks)
# # if not await vector_client.collection_exists(dataset_name): # if not await vector_client.collection_exists(dataset_name):
# try: try:
# logging.error("Collection still not found. Creating collection again.") logging.error("Collection still not found. Creating collection again.")
# await vector_client.create_collection(dataset_name) await vector_client.create_collection(dataset_name)
# except: except:
# pass pass
#
# async def create_collection_retry(dataset_name, dataset_chunks): async def create_collection_retry(dataset_name, dataset_chunks):
# await vector_client.create_data_points( await vector_client.create_data_points(
# dataset_name, dataset_name,
# [ [
# DataPoint( DataPoint(
# id = chunk["chunk_id"], id = chunk["chunk_id"],
# payload = dict(text = chunk["text"]), payload = dict(text = chunk["text"]),
# embed_field = "text" embed_field = "text"
# ) for chunk in dataset_chunks ) for chunk in dataset_chunks
# ], ],
# ) )
#
# try: try:
# await create_collection_retry(dataset_name, dataset_chunks) await create_collection_retry(dataset_name, dataset_chunks)
# except Exception: except Exception:
# logging.error("Collection not found in create data points.") logging.error("Collection not found in create data points.")
# await create_collection_retry(dataset_name, dataset_chunks) await create_collection_retry(dataset_name, dataset_chunks)
return identified_chunks return identified_chunks

21
poetry.lock generated
View file

@ -2107,6 +2107,25 @@ files = [
[package.dependencies] [package.dependencies]
colorama = ">=0.4" colorama = ">=0.4"
[[package]]
name = "groq"
version = "0.5.0"
description = "The official Python library for the groq API"
optional = false
python-versions = ">=3.7"
files = [
{file = "groq-0.5.0-py3-none-any.whl", hash = "sha256:a7e6be1118bcdfea3ed071ec00f505a34d4e6ec28c435adb5a5afd33545683a1"},
{file = "groq-0.5.0.tar.gz", hash = "sha256:d476cdc3383b45d2a4dc1876142a9542e663ea1029f9e07a05de24f895cae48c"},
]
[package.dependencies]
anyio = ">=3.5.0,<5"
distro = ">=1.7.0,<2"
httpx = ">=0.23.0,<1"
pydantic = ">=1.9.0,<3"
sniffio = "*"
typing-extensions = ">=4.7,<5"
[[package]] [[package]]
name = "grpcio" name = "grpcio"
version = "1.63.0" version = "1.63.0"
@ -7837,4 +7856,4 @@ weaviate = ["weaviate-client"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.9.0,<3.12" python-versions = ">=3.9.0,<3.12"
content-hash = "459e5945fb184ecdc2c9a817e89b2a83b9e0c97273d1a7df8f75c9c0b355ba47" content-hash = "4c2e75aba3260da9e4023e5fa1c4d117b06e4f042bda9441664a0e91af4069a1"

View file

@ -66,6 +66,7 @@ lancedb = "^0.6.10"
importlib-metadata = "6.8.0" importlib-metadata = "6.8.0"
deepeval = "^0.21.36" deepeval = "^0.21.36"
litellm = "^1.37.3" litellm = "^1.37.3"
groq = "^0.5.0"