small model fix (#432)
* updated dedupe nodes operations * updates * Update examples/podcast/podcast_transcript.txt Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> * mypy --------- Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
parent
f13a497239
commit
2ffc58b3da
15 changed files with 102 additions and 48 deletions
|
|
@ -20,7 +20,7 @@ Fordham is a well-regarded private university in New York City, founded in 1841
|
||||||
There's a very daunting hall of portraits outside of my office. You know, all of these priests going back to 1841,
|
There's a very daunting hall of portraits outside of my office. You know, all of these priests going back to 1841,
|
||||||
|
|
||||||
0 (1m 41s):
|
0 (1m 41s):
|
||||||
Tet, LO's own father was in fact a priest. But while getting his psychology PhD at Fordham, he met his Wouldbe wife, another graduate student, so he left the priesthood. Tania was born in New York not long before the family moved to New Orleans, so Fordham is in her genes.
|
Tetlow's own father was in fact a priest. But while getting his psychology PhD at Fordham, he met his Wouldbe wife, another graduate student, so he left the priesthood. Tania was born in New York not long before the family moved to New Orleans, so Fordham is in her genes.
|
||||||
|
|
||||||
1 (2m 0s):
|
1 (2m 0s):
|
||||||
A good way to recruit me is they can tell me you exist because of us.
|
A good way to recruit me is they can tell me you exist because of us.
|
||||||
|
|
|
||||||
|
|
@ -396,7 +396,7 @@ class Graphiti:
|
||||||
episode.content = ''
|
episode.content = ''
|
||||||
|
|
||||||
await add_nodes_and_edges_bulk(
|
await add_nodes_and_edges_bulk(
|
||||||
self.driver, [episode], episodic_edges, nodes, entity_edges
|
self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update any communities
|
# Update any communities
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ from pydantic import BaseModel, ValidationError
|
||||||
|
|
||||||
from ..prompts.models import Message
|
from ..prompts.models import Message
|
||||||
from .client import LLMClient
|
from .client import LLMClient
|
||||||
from .config import DEFAULT_MAX_TOKENS, LLMConfig
|
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
|
||||||
from .errors import RateLimitError, RefusalError
|
from .errors import RateLimitError, RefusalError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -173,6 +173,7 @@ class AnthropicClient(LLMClient):
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
max_tokens: int | None = None,
|
max_tokens: int | None = None,
|
||||||
|
model_size: ModelSize = ModelSize.medium,
|
||||||
) -> dict[str, typing.Any]:
|
) -> dict[str, typing.Any]:
|
||||||
"""
|
"""
|
||||||
Generate a response from the Anthropic LLM using tool-based approach for all requests.
|
Generate a response from the Anthropic LLM using tool-based approach for all requests.
|
||||||
|
|
@ -263,6 +264,7 @@ class AnthropicClient(LLMClient):
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
max_tokens: int | None = None,
|
max_tokens: int | None = None,
|
||||||
|
model_size: ModelSize = ModelSize.medium,
|
||||||
) -> dict[str, typing.Any]:
|
) -> dict[str, typing.Any]:
|
||||||
"""
|
"""
|
||||||
Generate a response from the LLM.
|
Generate a response from the LLM.
|
||||||
|
|
@ -289,7 +291,9 @@ class AnthropicClient(LLMClient):
|
||||||
|
|
||||||
while retry_count <= max_retries:
|
while retry_count <= max_retries:
|
||||||
try:
|
try:
|
||||||
response = await self._generate_response(messages, response_model, max_tokens)
|
response = await self._generate_response(
|
||||||
|
messages, response_model, max_tokens, model_size
|
||||||
|
)
|
||||||
|
|
||||||
# If we have a response_model, attempt to validate the response
|
# If we have a response_model, attempt to validate the response
|
||||||
if response_model is not None:
|
if response_model is not None:
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ from pydantic import BaseModel
|
||||||
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
|
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
|
||||||
|
|
||||||
from ..prompts.models import Message
|
from ..prompts.models import Message
|
||||||
from .config import DEFAULT_MAX_TOKENS, LLMConfig
|
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
|
||||||
from .errors import RateLimitError
|
from .errors import RateLimitError
|
||||||
|
|
||||||
DEFAULT_TEMPERATURE = 0
|
DEFAULT_TEMPERATURE = 0
|
||||||
|
|
@ -55,6 +55,7 @@ class LLMClient(ABC):
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.model = config.model
|
self.model = config.model
|
||||||
|
self.small_model = config.small_model
|
||||||
self.temperature = config.temperature
|
self.temperature = config.temperature
|
||||||
self.max_tokens = config.max_tokens
|
self.max_tokens = config.max_tokens
|
||||||
self.cache_enabled = cache
|
self.cache_enabled = cache
|
||||||
|
|
@ -102,9 +103,10 @@ class LLMClient(ABC):
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
max_tokens: int = DEFAULT_MAX_TOKENS,
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||||
|
model_size: ModelSize = ModelSize.medium,
|
||||||
) -> dict[str, typing.Any]:
|
) -> dict[str, typing.Any]:
|
||||||
try:
|
try:
|
||||||
return await self._generate_response(messages, response_model, max_tokens)
|
return await self._generate_response(messages, response_model, max_tokens, model_size)
|
||||||
except (httpx.HTTPStatusError, RateLimitError) as e:
|
except (httpx.HTTPStatusError, RateLimitError) as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
@ -114,6 +116,7 @@ class LLMClient(ABC):
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
max_tokens: int = DEFAULT_MAX_TOKENS,
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||||
|
model_size: ModelSize = ModelSize.medium,
|
||||||
) -> dict[str, typing.Any]:
|
) -> dict[str, typing.Any]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
@ -128,6 +131,7 @@ class LLMClient(ABC):
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
max_tokens: int | None = None,
|
max_tokens: int | None = None,
|
||||||
|
model_size: ModelSize = ModelSize.medium,
|
||||||
) -> dict[str, typing.Any]:
|
) -> dict[str, typing.Any]:
|
||||||
if max_tokens is None:
|
if max_tokens is None:
|
||||||
max_tokens = self.max_tokens
|
max_tokens = self.max_tokens
|
||||||
|
|
@ -154,7 +158,9 @@ class LLMClient(ABC):
|
||||||
for message in messages:
|
for message in messages:
|
||||||
message.content = self._clean_input(message.content)
|
message.content = self._clean_input(message.content)
|
||||||
|
|
||||||
response = await self._generate_response_with_retry(messages, response_model, max_tokens)
|
response = await self._generate_response_with_retry(
|
||||||
|
messages, response_model, max_tokens, model_size
|
||||||
|
)
|
||||||
|
|
||||||
if self.cache_enabled and self.cache_dir is not None:
|
if self.cache_enabled and self.cache_dir is not None:
|
||||||
cache_key = self._get_cache_key(messages)
|
cache_key = self._get_cache_key(messages)
|
||||||
|
|
|
||||||
|
|
@ -14,10 +14,17 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
DEFAULT_MAX_TOKENS = 8192
|
DEFAULT_MAX_TOKENS = 8192
|
||||||
DEFAULT_TEMPERATURE = 0
|
DEFAULT_TEMPERATURE = 0
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSize(Enum):
|
||||||
|
small = 'small'
|
||||||
|
medium = 'medium'
|
||||||
|
|
||||||
|
|
||||||
class LLMConfig:
|
class LLMConfig:
|
||||||
"""
|
"""
|
||||||
Configuration class for the Language Learning Model (LLM).
|
Configuration class for the Language Learning Model (LLM).
|
||||||
|
|
@ -34,6 +41,7 @@ class LLMConfig:
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
temperature: float = DEFAULT_TEMPERATURE,
|
temperature: float = DEFAULT_TEMPERATURE,
|
||||||
max_tokens: int = DEFAULT_MAX_TOKENS,
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||||
|
small_model: str | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the LLMConfig with the provided parameters.
|
Initialize the LLMConfig with the provided parameters.
|
||||||
|
|
@ -43,15 +51,18 @@ class LLMConfig:
|
||||||
This is required for making authorized requests.
|
This is required for making authorized requests.
|
||||||
|
|
||||||
model (str, optional): The specific LLM model to use for generating responses.
|
model (str, optional): The specific LLM model to use for generating responses.
|
||||||
Defaults to "gpt-4.1-mini", which appears to be a custom model name.
|
Defaults to "gpt-4.1-mini".
|
||||||
Common values might include "gpt-3.5-turbo" or "gpt-4".
|
|
||||||
|
|
||||||
base_url (str, optional): The base URL of the LLM API service.
|
base_url (str, optional): The base URL of the LLM API service.
|
||||||
Defaults to "https://api.openai.com", which is OpenAI's standard API endpoint.
|
Defaults to "https://api.openai.com", which is OpenAI's standard API endpoint.
|
||||||
This can be changed if using a different provider or a custom endpoint.
|
This can be changed if using a different provider or a custom endpoint.
|
||||||
|
|
||||||
|
small_model (str, optional): The specific LLM model to use for generating responses of simpler prompts.
|
||||||
|
Defaults to "gpt-4.1-nano".
|
||||||
"""
|
"""
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.small_model = small_model
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
from ..prompts.models import Message
|
from ..prompts.models import Message
|
||||||
from .client import LLMClient
|
from .client import LLMClient
|
||||||
from .config import DEFAULT_MAX_TOKENS, LLMConfig
|
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
|
||||||
from .errors import RateLimitError
|
from .errors import RateLimitError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -82,6 +82,7 @@ class GeminiClient(LLMClient):
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
max_tokens: int = DEFAULT_MAX_TOKENS,
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||||
|
model_size: ModelSize = ModelSize.medium,
|
||||||
) -> dict[str, typing.Any]:
|
) -> dict[str, typing.Any]:
|
||||||
"""
|
"""
|
||||||
Generate a response from the Gemini language model.
|
Generate a response from the Gemini language model.
|
||||||
|
|
@ -167,6 +168,7 @@ class GeminiClient(LLMClient):
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
max_tokens: int | None = None,
|
max_tokens: int | None = None,
|
||||||
|
model_size: ModelSize = ModelSize.medium,
|
||||||
) -> dict[str, typing.Any]:
|
) -> dict[str, typing.Any]:
|
||||||
"""
|
"""
|
||||||
Generate a response from the Gemini language model.
|
Generate a response from the Gemini language model.
|
||||||
|
|
@ -185,5 +187,8 @@ class GeminiClient(LLMClient):
|
||||||
|
|
||||||
# Call the internal _generate_response method
|
# Call the internal _generate_response method
|
||||||
return await self._generate_response(
|
return await self._generate_response(
|
||||||
messages=messages, response_model=response_model, max_tokens=max_tokens
|
messages=messages,
|
||||||
|
response_model=response_model,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
model_size=model_size,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
from ..prompts.models import Message
|
from ..prompts.models import Message
|
||||||
from .client import LLMClient
|
from .client import LLMClient
|
||||||
from .config import LLMConfig
|
from .config import LLMConfig, ModelSize
|
||||||
from .errors import RateLimitError
|
from .errors import RateLimitError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -49,6 +49,7 @@ class GroqClient(LLMClient):
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
max_tokens: int = DEFAULT_MAX_TOKENS,
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||||
|
model_size: ModelSize = ModelSize.medium,
|
||||||
) -> dict[str, typing.Any]:
|
) -> dict[str, typing.Any]:
|
||||||
msgs: list[ChatCompletionMessageParam] = []
|
msgs: list[ChatCompletionMessageParam] = []
|
||||||
for m in messages:
|
for m in messages:
|
||||||
|
|
|
||||||
|
|
@ -25,12 +25,13 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
from ..prompts.models import Message
|
from ..prompts.models import Message
|
||||||
from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
|
from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
|
||||||
from .config import DEFAULT_MAX_TOKENS, LLMConfig
|
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
|
||||||
from .errors import RateLimitError, RefusalError
|
from .errors import RateLimitError, RefusalError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_MODEL = 'gpt-4.1-mini'
|
DEFAULT_MODEL = 'gpt-4.1-mini'
|
||||||
|
DEFAULT_SMALL_MODEL = 'gpt-4.1-nano'
|
||||||
|
|
||||||
|
|
||||||
class OpenAIClient(LLMClient):
|
class OpenAIClient(LLMClient):
|
||||||
|
|
@ -94,6 +95,7 @@ class OpenAIClient(LLMClient):
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
max_tokens: int = DEFAULT_MAX_TOKENS,
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||||
|
model_size: ModelSize = ModelSize.medium,
|
||||||
) -> dict[str, typing.Any]:
|
) -> dict[str, typing.Any]:
|
||||||
openai_messages: list[ChatCompletionMessageParam] = []
|
openai_messages: list[ChatCompletionMessageParam] = []
|
||||||
for m in messages:
|
for m in messages:
|
||||||
|
|
@ -103,8 +105,13 @@ class OpenAIClient(LLMClient):
|
||||||
elif m.role == 'system':
|
elif m.role == 'system':
|
||||||
openai_messages.append({'role': 'system', 'content': m.content})
|
openai_messages.append({'role': 'system', 'content': m.content})
|
||||||
try:
|
try:
|
||||||
|
if model_size == ModelSize.small:
|
||||||
|
model = self.small_model or DEFAULT_SMALL_MODEL
|
||||||
|
else:
|
||||||
|
model = self.model or DEFAULT_MODEL
|
||||||
|
|
||||||
response = await self.client.beta.chat.completions.parse(
|
response = await self.client.beta.chat.completions.parse(
|
||||||
model=self.model or DEFAULT_MODEL,
|
model=model,
|
||||||
messages=openai_messages,
|
messages=openai_messages,
|
||||||
temperature=self.temperature,
|
temperature=self.temperature,
|
||||||
max_tokens=max_tokens or self.max_tokens,
|
max_tokens=max_tokens or self.max_tokens,
|
||||||
|
|
@ -132,6 +139,7 @@ class OpenAIClient(LLMClient):
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
max_tokens: int | None = None,
|
max_tokens: int | None = None,
|
||||||
|
model_size: ModelSize = ModelSize.medium,
|
||||||
) -> dict[str, typing.Any]:
|
) -> dict[str, typing.Any]:
|
||||||
if max_tokens is None:
|
if max_tokens is None:
|
||||||
max_tokens = self.max_tokens
|
max_tokens = self.max_tokens
|
||||||
|
|
@ -144,7 +152,9 @@ class OpenAIClient(LLMClient):
|
||||||
|
|
||||||
while retry_count <= self.MAX_RETRIES:
|
while retry_count <= self.MAX_RETRIES:
|
||||||
try:
|
try:
|
||||||
response = await self._generate_response(messages, response_model, max_tokens)
|
response = await self._generate_response(
|
||||||
|
messages, response_model, max_tokens, model_size
|
||||||
|
)
|
||||||
return response
|
return response
|
||||||
except (RateLimitError, RefusalError):
|
except (RateLimitError, RefusalError):
|
||||||
# These errors should not trigger retries
|
# These errors should not trigger retries
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
from ..prompts.models import Message
|
from ..prompts.models import Message
|
||||||
from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
|
from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
|
||||||
from .config import DEFAULT_MAX_TOKENS, LLMConfig
|
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
|
||||||
from .errors import RateLimitError, RefusalError
|
from .errors import RateLimitError, RefusalError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -89,6 +89,7 @@ class OpenAIGenericClient(LLMClient):
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
max_tokens: int = DEFAULT_MAX_TOKENS,
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||||
|
model_size: ModelSize = ModelSize.medium,
|
||||||
) -> dict[str, typing.Any]:
|
) -> dict[str, typing.Any]:
|
||||||
openai_messages: list[ChatCompletionMessageParam] = []
|
openai_messages: list[ChatCompletionMessageParam] = []
|
||||||
for m in messages:
|
for m in messages:
|
||||||
|
|
@ -118,6 +119,7 @@ class OpenAIGenericClient(LLMClient):
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
max_tokens: int | None = None,
|
max_tokens: int | None = None,
|
||||||
|
model_size: ModelSize = ModelSize.medium,
|
||||||
) -> dict[str, typing.Any]:
|
) -> dict[str, typing.Any]:
|
||||||
if max_tokens is None:
|
if max_tokens is None:
|
||||||
max_tokens = self.max_tokens
|
max_tokens = self.max_tokens
|
||||||
|
|
@ -139,7 +141,7 @@ class OpenAIGenericClient(LLMClient):
|
||||||
while retry_count <= self.MAX_RETRIES:
|
while retry_count <= self.MAX_RETRIES:
|
||||||
try:
|
try:
|
||||||
response = await self._generate_response(
|
response = await self._generate_response(
|
||||||
messages, response_model, max_tokens=max_tokens
|
messages, response_model, max_tokens=max_tokens, model_size=model_size
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
except (RateLimitError, RefusalError):
|
except (RateLimitError, RefusalError):
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ class NodeDuplicate(BaseModel):
|
||||||
...,
|
...,
|
||||||
description='id of the duplicate node. If no duplicate nodes are found, default to -1.',
|
description='id of the duplicate node. If no duplicate nodes are found, default to -1.',
|
||||||
)
|
)
|
||||||
|
name: str = Field(..., description='Name of the entity.')
|
||||||
|
|
||||||
|
|
||||||
class Prompt(Protocol):
|
class Prompt(Protocol):
|
||||||
|
|
@ -43,7 +44,7 @@ def node(context: dict[str, Any]) -> list[Message]:
|
||||||
return [
|
return [
|
||||||
Message(
|
Message(
|
||||||
role='system',
|
role='system',
|
||||||
content='You are a helpful assistant that de-duplicates nodes from node lists.',
|
content='You are a helpful assistant that de-duplicates entities from entity lists.',
|
||||||
),
|
),
|
||||||
Message(
|
Message(
|
||||||
role='user',
|
role='user',
|
||||||
|
|
@ -54,25 +55,33 @@ def node(context: dict[str, Any]) -> list[Message]:
|
||||||
<CURRENT MESSAGE>
|
<CURRENT MESSAGE>
|
||||||
{context['episode_content']}
|
{context['episode_content']}
|
||||||
</CURRENT MESSAGE>
|
</CURRENT MESSAGE>
|
||||||
|
<NEW ENTITY>
|
||||||
<EXISTING NODES>
|
|
||||||
{json.dumps(context['existing_nodes'], indent=2)}
|
|
||||||
</EXISTING NODES>
|
|
||||||
|
|
||||||
Given the above EXISTING NODES and their attributes, MESSAGE, and PREVIOUS MESSAGES; Determine if the NEW NODE extracted from the conversation
|
|
||||||
is a duplicate entity of one of the EXISTING NODES.
|
|
||||||
|
|
||||||
<NEW NODE>
|
|
||||||
{json.dumps(context['extracted_node'], indent=2)}
|
{json.dumps(context['extracted_node'], indent=2)}
|
||||||
</NEW NODE>
|
</NEW ENTITY>
|
||||||
|
<ENTITY TYPE DESCRIPTION>
|
||||||
|
{json.dumps(context['entity_type_description'], indent=2)}
|
||||||
|
</ENTITY TYPE DESCRIPTION>
|
||||||
|
|
||||||
|
<EXISTING ENTITIES>
|
||||||
|
{json.dumps(context['existing_nodes'], indent=2)}
|
||||||
|
</EXISTING ENTITIES>
|
||||||
|
|
||||||
|
Given the above EXISTING ENTITIES and their attributes, MESSAGE, and PREVIOUS MESSAGES; Determine if the NEW ENTITY extracted from the conversation
|
||||||
|
is a duplicate entity of one of the EXISTING ENTITIES.
|
||||||
|
|
||||||
|
The ENTITY TYPE DESCRIPTION gives more insight into what the entity type means for the NEW ENTITY.
|
||||||
|
|
||||||
Task:
|
Task:
|
||||||
If the NEW NODE is a duplicate of any node in EXISTING NODES, set duplicate_node_id to the
|
If the NEW ENTITY represents a duplicate entity of any entity in EXISTING ENTITIES, set duplicate_entity_id to the
|
||||||
id of the EXISTING NODE that is the duplicate. If the NEW NODE is not a duplicate of any of the EXISTING NODES,
|
id of the EXISTING ENTITY that is the duplicate. If the NEW ENTITY is not a duplicate of any of the EXISTING ENTITIES,
|
||||||
duplicate_node_id should be set to -1.
|
duplicate_entity_id should be set to -1.
|
||||||
|
|
||||||
|
Also return the most complete name for the entity.
|
||||||
|
|
||||||
Guidelines:
|
Guidelines:
|
||||||
1. Use the name, summary, and attributes of nodes to determine if the entities are duplicates,
|
1. Entities with the same name should be considered duplicates
|
||||||
duplicate nodes may have different names
|
2. Duplicate entities may refer to the same real-world entity even if names differ. Use context clues from the MESSAGES
|
||||||
|
to determine if the NEW ENTITY represents a duplicate entity of one of the EXISTING ENTITIES.
|
||||||
""",
|
""",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -256,7 +256,7 @@ def extract_attributes(context: dict[str, Any]) -> list[Message]:
|
||||||
1. Do not hallucinate entity property values if they cannot be found in the current context.
|
1. Do not hallucinate entity property values if they cannot be found in the current context.
|
||||||
2. Only use the provided MESSAGES and ENTITY to set attribute values.
|
2. Only use the provided MESSAGES and ENTITY to set attribute values.
|
||||||
3. The summary attribute represents a summary of the ENTITY, and should be updated with new information about the Entity from the MESSAGES.
|
3. The summary attribute represents a summary of the ENTITY, and should be updated with new information about the Entity from the MESSAGES.
|
||||||
Summaries must be no longer than 200 words.
|
Summaries must be no longer than 500 words.
|
||||||
|
|
||||||
<ENTITY>
|
<ENTITY>
|
||||||
{context['node']}
|
{context['node']}
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ from graphiti_core.edges import (
|
||||||
from graphiti_core.graphiti_types import GraphitiClients
|
from graphiti_core.graphiti_types import GraphitiClients
|
||||||
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
|
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
|
||||||
from graphiti_core.llm_client import LLMClient
|
from graphiti_core.llm_client import LLMClient
|
||||||
|
from graphiti_core.llm_client.config import ModelSize
|
||||||
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
||||||
from graphiti_core.prompts import prompt_library
|
from graphiti_core.prompts import prompt_library
|
||||||
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate, UniqueFacts
|
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate, UniqueFacts
|
||||||
|
|
@ -377,7 +378,9 @@ async def dedupe_extracted_edge(
|
||||||
}
|
}
|
||||||
|
|
||||||
llm_response = await llm_client.generate_response(
|
llm_response = await llm_client.generate_response(
|
||||||
prompt_library.dedupe_edges.edge(context), response_model=EdgeDuplicate
|
prompt_library.dedupe_edges.edge(context),
|
||||||
|
response_model=EdgeDuplicate,
|
||||||
|
model_size=ModelSize.small,
|
||||||
)
|
)
|
||||||
|
|
||||||
duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)
|
duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ from pydantic import BaseModel, Field
|
||||||
from graphiti_core.graphiti_types import GraphitiClients
|
from graphiti_core.graphiti_types import GraphitiClients
|
||||||
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
|
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
|
||||||
from graphiti_core.llm_client import LLMClient
|
from graphiti_core.llm_client import LLMClient
|
||||||
|
from graphiti_core.llm_client.config import ModelSize
|
||||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
|
||||||
from graphiti_core.prompts import prompt_library
|
from graphiti_core.prompts import prompt_library
|
||||||
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
|
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
|
||||||
|
|
@ -281,7 +282,6 @@ async def resolve_extracted_node(
|
||||||
'id': i,
|
'id': i,
|
||||||
'name': node.name,
|
'name': node.name,
|
||||||
'entity_types': node.labels,
|
'entity_types': node.labels,
|
||||||
'summary': node.summary,
|
|
||||||
},
|
},
|
||||||
**node.attributes,
|
**node.attributes,
|
||||||
}
|
}
|
||||||
|
|
@ -291,14 +291,14 @@ async def resolve_extracted_node(
|
||||||
extracted_node_context = {
|
extracted_node_context = {
|
||||||
'name': extracted_node.name,
|
'name': extracted_node.name,
|
||||||
'entity_type': entity_type.__name__ if entity_type is not None else 'Entity', # type: ignore
|
'entity_type': entity_type.__name__ if entity_type is not None else 'Entity', # type: ignore
|
||||||
'entity_type_description': entity_type.__doc__
|
|
||||||
if entity_type is not None
|
|
||||||
else 'Default Entity Type',
|
|
||||||
}
|
}
|
||||||
|
|
||||||
context = {
|
context = {
|
||||||
'existing_nodes': existing_nodes_context,
|
'existing_nodes': existing_nodes_context,
|
||||||
'extracted_node': extracted_node_context,
|
'extracted_node': extracted_node_context,
|
||||||
|
'entity_type_description': entity_type.__doc__
|
||||||
|
if entity_type is not None
|
||||||
|
else 'Default Entity Type',
|
||||||
'episode_content': episode.content if episode is not None else '',
|
'episode_content': episode.content if episode is not None else '',
|
||||||
'previous_episodes': [ep.content for ep in previous_episodes]
|
'previous_episodes': [ep.content for ep in previous_episodes]
|
||||||
if previous_episodes is not None
|
if previous_episodes is not None
|
||||||
|
|
@ -306,7 +306,9 @@ async def resolve_extracted_node(
|
||||||
}
|
}
|
||||||
|
|
||||||
llm_response = await llm_client.generate_response(
|
llm_response = await llm_client.generate_response(
|
||||||
prompt_library.dedupe_nodes.node(context), response_model=NodeDuplicate
|
prompt_library.dedupe_nodes.node(context),
|
||||||
|
response_model=NodeDuplicate,
|
||||||
|
model_size=ModelSize.small,
|
||||||
)
|
)
|
||||||
|
|
||||||
duplicate_id: int = llm_response.get('duplicate_node_id', -1)
|
duplicate_id: int = llm_response.get('duplicate_node_id', -1)
|
||||||
|
|
@ -315,6 +317,8 @@ async def resolve_extracted_node(
|
||||||
existing_nodes[duplicate_id] if 0 <= duplicate_id < len(existing_nodes) else extracted_node
|
existing_nodes[duplicate_id] if 0 <= duplicate_id < len(existing_nodes) else extracted_node
|
||||||
)
|
)
|
||||||
|
|
||||||
|
node.name = llm_response.get('name', '')
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f'Resolved node: {extracted_node.name} is {node.name}, in {(end - start) * 1000} ms'
|
f'Resolved node: {extracted_node.name} is {node.name}, in {(end - start) * 1000} ms'
|
||||||
|
|
@ -371,13 +375,9 @@ async def extract_attributes_from_node(
|
||||||
'summary': (
|
'summary': (
|
||||||
str,
|
str,
|
||||||
Field(
|
Field(
|
||||||
description='Summary containing the important information about the entity. Under 200 words',
|
description='Summary containing the important information about the entity. Under 500 words',
|
||||||
),
|
),
|
||||||
),
|
)
|
||||||
'name': (
|
|
||||||
str,
|
|
||||||
Field(description='Name of the ENTITY'),
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if entity_type is not None:
|
if entity_type is not None:
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ from time import time
|
||||||
|
|
||||||
from graphiti_core.edges import EntityEdge
|
from graphiti_core.edges import EntityEdge
|
||||||
from graphiti_core.llm_client import LLMClient
|
from graphiti_core.llm_client import LLMClient
|
||||||
|
from graphiti_core.llm_client.config import ModelSize
|
||||||
from graphiti_core.nodes import EpisodicNode
|
from graphiti_core.nodes import EpisodicNode
|
||||||
from graphiti_core.prompts import prompt_library
|
from graphiti_core.prompts import prompt_library
|
||||||
from graphiti_core.prompts.extract_edge_dates import EdgeDates
|
from graphiti_core.prompts.extract_edge_dates import EdgeDates
|
||||||
|
|
@ -81,7 +82,9 @@ async def get_edge_contradictions(
|
||||||
context = {'new_edge': new_edge_context, 'existing_edges': existing_edge_context}
|
context = {'new_edge': new_edge_context, 'existing_edges': existing_edge_context}
|
||||||
|
|
||||||
llm_response = await llm_client.generate_response(
|
llm_response = await llm_client.generate_response(
|
||||||
prompt_library.invalidate_edges.v2(context), response_model=InvalidatedEdges
|
prompt_library.invalidate_edges.v2(context),
|
||||||
|
response_model=InvalidatedEdges,
|
||||||
|
model_size=ModelSize.small,
|
||||||
)
|
)
|
||||||
|
|
||||||
contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
|
contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
[project]
|
[project]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
version = "0.11.3"
|
version = "0.11.4"
|
||||||
authors = [
|
authors = [
|
||||||
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
|
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
|
||||||
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
|
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue