search update (#426)
This commit is contained in:
parent
c7f1db9974
commit
8b19771d86
3 changed files with 21 additions and 16 deletions
|
|
@ -18,6 +18,7 @@ import asyncio
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
from uuid import uuid4
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, Field
|
||||
|
|
@ -68,6 +69,7 @@ async def main():
|
|||
await clear_data(client.driver)
|
||||
await client.build_indices_and_constraints()
|
||||
messages = parse_podcast_messages()
|
||||
group_id = str(uuid4())
|
||||
|
||||
for i, message in enumerate(messages[3:14]):
|
||||
episodes = await client.retrieve_episodes(
|
||||
|
|
@ -80,7 +82,7 @@ async def main():
|
|||
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
|
||||
reference_time=message.actual_timestamp,
|
||||
source_description='Podcast Transcript',
|
||||
group_id='podcast',
|
||||
group_id=group_id,
|
||||
entity_types={'Person': Person},
|
||||
previous_episode_uuids=episode_uuids,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -341,10 +341,10 @@ async def node_fulltext_search(
|
|||
|
||||
query = (
|
||||
"""
|
||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
||||
YIELD node AS n, score
|
||||
WHERE n:Entity
|
||||
"""
|
||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
||||
YIELD node AS n, score
|
||||
WHERE n:Entity
|
||||
"""
|
||||
+ filter_query
|
||||
+ ENTITY_NODE_RETURN
|
||||
+ """
|
||||
|
|
@ -676,7 +676,7 @@ async def get_relevant_nodes(
|
|||
WHERE score > $min_score
|
||||
WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
|
||||
|
||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", 'group_id:"' + $group_id + '" AND ' + node.name, {limit: $limit})
|
||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", node.fulltext_query, {limit: $limit})
|
||||
YIELD node AS m
|
||||
WHERE m.group_id = $group_id
|
||||
WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
|
||||
|
|
@ -705,18 +705,21 @@ async def get_relevant_nodes(
|
|||
"""
|
||||
)
|
||||
|
||||
query_nodes = [
|
||||
{
|
||||
'uuid': node.uuid,
|
||||
'name': node.name,
|
||||
'name_embedding': node.name_embedding,
|
||||
'fulltext_query': fulltext_query(node.name, [node.group_id]),
|
||||
}
|
||||
for node in nodes
|
||||
]
|
||||
|
||||
results, _, _ = await driver.execute_query(
|
||||
query,
|
||||
query_params,
|
||||
nodes=[
|
||||
{
|
||||
'uuid': node.uuid,
|
||||
'name': lucene_sanitize(node.name),
|
||||
'name_embedding': node.name_embedding,
|
||||
}
|
||||
for node in nodes
|
||||
],
|
||||
group_id=lucene_sanitize(group_id),
|
||||
nodes=query_nodes,
|
||||
group_id=group_id,
|
||||
limit=limit,
|
||||
min_score=min_score,
|
||||
database_=DEFAULT_DATABASE,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
[project]
|
||||
name = "graphiti-core"
|
||||
description = "A temporal graph building library"
|
||||
version = "0.11.1"
|
||||
version = "0.11.2"
|
||||
authors = [
|
||||
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
|
||||
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue