Fixes to database manager

This commit is contained in:
Vasilije 2024-02-20 17:10:14 +01:00
parent bf9c80653e
commit 35426c3354
5 changed files with 43 additions and 36 deletions

View file

@ -72,6 +72,7 @@ class Config:
or os.getenv("AWS_ENV") == "dev" or os.getenv("AWS_ENV") == "dev"
or os.getenv("AWS_ENV") == "prd" or os.getenv("AWS_ENV") == "prd"
): ):
db_type = 'postgresql'
db_host: str = os.getenv("POSTGRES_PROD_HOST") db_host: str = os.getenv("POSTGRES_PROD_HOST")
db_user: str = os.getenv("POSTGRES_USER") db_user: str = os.getenv("POSTGRES_USER")
db_password: str = os.getenv("POSTGRES_PASSWORD") db_password: str = os.getenv("POSTGRES_PASSWORD")

View file

@ -1,38 +1,38 @@
""" This module contains the functions that are used to query the language model. """
import os import os
from ..shared.data_models import Node, Edge, KnowledgeGraph, GraphQLQuery, MemorySummary
from ..config import Config
import instructor import instructor
from openai import OpenAI from openai import OpenAI
import logging
from ..shared.data_models import KnowledgeGraph, MemorySummary
from ..config import Config
config = Config() config = Config()
config.load() config.load()
print(config.model)
print(config.openai_key)
OPENAI_API_KEY = config.openai_key OPENAI_API_KEY = config.openai_key
aclient = instructor.patch(OpenAI()) aclient = instructor.patch(OpenAI())
import logging
# Function to read query prompts from files # Function to read query prompts from files
def read_query_prompt(filename): def read_query_prompt(filename):
"""Read a query prompt from a file."""
try: try:
with open(filename, "r") as file: with open(filename, "r") as file:
return file.read() return file.read()
except FileNotFoundError: except FileNotFoundError:
logging.info(f"Error: File not found. Attempted to read: {filename}") logging.info(f"Error: File not found. Attempted to read: %s {filename}")
logging.info(f"Current working directory: {os.getcwd()}") logging.info(f"Current working directory: %s {os.getcwd()}")
return None return None
except Exception as e: except Exception as e:
logging.info(f"An error occurred: {e}") logging.info(f"An error occurred: %s {e}")
return None return None
def generate_graph(input) -> KnowledgeGraph: def generate_graph(input) -> KnowledgeGraph:
"""Generate a knowledge graph from a user query."""
model = "gpt-4-1106-preview" model = "gpt-4-1106-preview"
user_prompt = f"Use the given format to extract information from the following input: {input}." user_prompt = f"Use the given format to extract information from the following input: {input}."
system_prompt = read_query_prompt( system_prompt = read_query_prompt(
@ -57,20 +57,26 @@ def generate_graph(input) -> KnowledgeGraph:
async def generate_summary(input) -> MemorySummary: async def generate_summary(input) -> MemorySummary:
"""Generate a summary from a user query."""
out = aclient.chat.completions.create( out = aclient.chat.completions.create(
model=config.model, model=config.model,
messages=[ messages=[
{ {
"role": "user", "role": "user",
"content": f"""Use the given format summarize and reduce the following input: {input}. """, "content": f"""Use the given format summarize
and reduce the following input: {input}. """,
}, },
{ {
"role": "system", "role": "system",
"content": """You are a top-tier algorithm "content": """You are a top-tier algorithm
designed for summarizing existing knowledge graphs in structured formats based on a knowledge graph. designed for summarizing existing knowledge
graphs in structured formats based on a knowledge graph.
## 1. Strict Compliance ## 1. Strict Compliance
Adhere to the rules strictly. Non-compliance will result in termination. Adhere to the rules strictly.
## 2. Don't forget your main goal is to reduce the number of nodes in the knowledge graph while preserving the information contained in it.""", Non-compliance will result in termination.
## 2. Don't forget your main goal
is to reduce the number of nodes in the knowledge graph
while preserving the information contained in it.""",
}, },
], ],
response_model=MemorySummary, response_model=MemorySummary,
@ -79,6 +85,7 @@ async def generate_summary(input) -> MemorySummary:
def user_query_to_edges_and_nodes(input: str) -> KnowledgeGraph: def user_query_to_edges_and_nodes(input: str) -> KnowledgeGraph:
"""Generate a knowledge graph from a user query."""
system_prompt = read_query_prompt( system_prompt = read_query_prompt(
"cognitive_architecture/llm/prompts/generate_graph_prompt.txt" "cognitive_architecture/llm/prompts/generate_graph_prompt.txt"
) )
@ -87,7 +94,8 @@ def user_query_to_edges_and_nodes(input: str) -> KnowledgeGraph:
messages=[ messages=[
{ {
"role": "user", "role": "user",
"content": f"""Use the given format to extract information from the following input: {input}. """, "content": f"""Use the given format to
extract information from the following input: {input}. """,
}, },
{"role": "system", "content": system_prompt}, {"role": "system", "content": system_prompt},
], ],

View file

@ -1,3 +1,4 @@
"""Tools for interacting with OpenAI's GPT-3, GPT-4 API"""
import asyncio import asyncio
import random import random
import os import os

View file

@ -1,9 +1,10 @@
"""Data models for the cognitive architecture."""
from typing import Optional, List from typing import Optional, List
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class Node(BaseModel): class Node(BaseModel):
"""Node in a knowledge graph."""
id: int id: int
description: str description: str
category: str category: str
@ -14,6 +15,7 @@ class Node(BaseModel):
class Edge(BaseModel): class Edge(BaseModel):
"""Edge in a knowledge graph."""
source: int source: int
target: int target: int
description: str description: str
@ -23,14 +25,17 @@ class Edge(BaseModel):
class KnowledgeGraph(BaseModel): class KnowledgeGraph(BaseModel):
"""Knowledge graph."""
nodes: List[Node] = Field(..., default_factory=list) nodes: List[Node] = Field(..., default_factory=list)
edges: List[Edge] = Field(..., default_factory=list) edges: List[Edge] = Field(..., default_factory=list)
class GraphQLQuery(BaseModel): class GraphQLQuery(BaseModel):
"""GraphQL query."""
query: str query: str
class MemorySummary(BaseModel): class MemorySummary(BaseModel):
""" Memory summary. """
nodes: List[Node] = Field(..., default_factory=list) nodes: List[Node] = Field(..., default_factory=list)
edges: List[Edge] = Field(..., default_factory=list) edges: List[Edge] = Field(..., default_factory=list)

View file

@ -1,9 +1,10 @@
""" This module provides language processing functions for language detection and translation. """
import logging
import boto3 import boto3
from botocore.exceptions import BotoCoreError, ClientError from botocore.exceptions import BotoCoreError, ClientError
from langdetect import detect, LangDetectException from langdetect import detect, LangDetectException
import iso639 import iso639
import logging
# Basic configuration of the logging system # Basic configuration of the logging system
logging.basicConfig( logging.basicConfig(
@ -30,7 +31,7 @@ def detect_language(text):
try: try:
# Detect the language using langdetect # Detect the language using langdetect
detected_lang_iso639_1 = detect(trimmed_text) detected_lang_iso639_1 = detect(trimmed_text)
logging.info(f"Detected ISO 639-1 code: {detected_lang_iso639_1}") logging.info(f"Detected ISO 639-1 code: %s {detected_lang_iso639_1}")
# Special case: map 'hr' (Croatian) to 'sr' (Serbian ISO 639-2) # Special case: map 'hr' (Croatian) to 'sr' (Serbian ISO 639-2)
if detected_lang_iso639_1 == "hr": if detected_lang_iso639_1 == "hr":
@ -38,9 +39,9 @@ def detect_language(text):
return detected_lang_iso639_1 return detected_lang_iso639_1
except LangDetectException as e: except LangDetectException as e:
logging.error(f"Language detection error: {e}") logging.error(f"Language detection error: %s {e}")
except Exception as e: except Exception as e:
logging.error(f"Unexpected error: {e}") logging.error(f"Unexpected error: %s {e}")
return -1 return -1
@ -57,8 +58,10 @@ def translate_text(
Parameters: Parameters:
text (str): The text to be translated. text (str): The text to be translated.
source_language (str): The source language code (e.g., 'sr' for Serbian). ISO 639-2 Code https://www.loc.gov/standards/iso639-2/php/code_list.php source_language (str): The source language code (e.g., 'sr' for Serbian).
target_language (str): The target language code (e.g., 'en' for English). ISO 639-2 Code https://www.loc.gov/standards/iso639-2/php/code_list.php ISO 639-2 Code https://www.loc.gov/standards/iso639-2/php/code_list.php
target_language (str): The target language code (e.g., 'en' for English).
ISO 639-2 Code https://www.loc.gov/standards/iso639-2/php/code_list.php
region_name (str): AWS region name. region_name (str): AWS region name.
Returns: Returns:
@ -82,20 +85,9 @@ def translate_text(
return result.get("TranslatedText", "No translation found.") return result.get("TranslatedText", "No translation found.")
except BotoCoreError as e: except BotoCoreError as e:
logging.info(f"BotoCoreError occurred: {e}") logging.info(f"BotoCoreError occurred: %s {e}")
return "Error with AWS Translate service configuration or request." return "Error with AWS Translate service configuration or request."
except ClientError as e: except ClientError as e:
logging.info(f"ClientError occurred: {e}") logging.info(f"ClientError occurred: %s {e}")
return "Error with AWS client or network issue." return "Error with AWS client or network issue."
source_language = "sr"
target_language = "en"
text_to_translate = "Ja volim da pecam i idem na reku da šetam pored nje ponekad"
translated_text = translate_text(text_to_translate, source_language, target_language)
print(translated_text)
# print(detect_language("Koliko krava ide u setnju?"))