Entity classification updates (#281)
* node classification updates * update * remove unused code * update
This commit is contained in:
parent
1d2417ec26
commit
6f874730f3
3 changed files with 15 additions and 8 deletions
|
|
@ -31,9 +31,13 @@ class MissedEntities(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class EntityClassification(BaseModel):
|
class EntityClassification(BaseModel):
|
||||||
entity_classification: str = Field(
|
entities: list[str] = Field(
|
||||||
...,
|
...,
|
||||||
description='Dictionary of entity classifications. Key is the entity name and value is the entity type',
|
description='List of entities',
|
||||||
|
)
|
||||||
|
entity_classifications: list[str | None] = Field(
|
||||||
|
...,
|
||||||
|
description='List of entities classifications. The index of the classification should match the index of the entity it corresponds to.',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -180,7 +184,8 @@ def classify_nodes(context: dict[str, Any]) -> list[Message]:
|
||||||
|
|
||||||
Guidelines:
|
Guidelines:
|
||||||
1. Each entity must have exactly one type
|
1. Each entity must have exactly one type
|
||||||
2. If none of the provided entity types accurately classify an extracted node, the type should be set to None
|
2. Only use the provided ENTITY TYPES as types, do not use additional types to classify entities.
|
||||||
|
3. If none of the provided entity types accurately classify an extracted node, the type should be set to None
|
||||||
"""
|
"""
|
||||||
return [
|
return [
|
||||||
Message(role='system', content=sys_prompt),
|
Message(role='system', content=sys_prompt),
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import ast
|
|
||||||
import logging
|
import logging
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
|
|
@ -163,8 +162,9 @@ async def extract_nodes(
|
||||||
prompt_library.extract_nodes.classify_nodes(node_classification_context),
|
prompt_library.extract_nodes.classify_nodes(node_classification_context),
|
||||||
response_model=EntityClassification,
|
response_model=EntityClassification,
|
||||||
)
|
)
|
||||||
response_string = llm_response.get('entity_classification', '{}')
|
entities = llm_response.get('entities', [])
|
||||||
node_classifications.update(ast.literal_eval(response_string))
|
entity_classifications = llm_response.get('entity_classifications', [])
|
||||||
|
node_classifications.update(dict(zip(entities, entity_classifications)))
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
logger.debug(f'Extracted new nodes: {extracted_node_names} in {(end - start) * 1000} ms')
|
logger.debug(f'Extracted new nodes: {extracted_node_names} in {(end - start) * 1000} ms')
|
||||||
|
|
@ -173,7 +173,9 @@ async def extract_nodes(
|
||||||
for name in extracted_node_names:
|
for name in extracted_node_names:
|
||||||
entity_type = node_classifications.get(name)
|
entity_type = node_classifications.get(name)
|
||||||
labels = (
|
labels = (
|
||||||
['Entity'] if entity_type is None or entity_type == 'None' else ['Entity', entity_type]
|
['Entity']
|
||||||
|
if entity_type is None or entity_type == 'None' or entity_type == 'null'
|
||||||
|
else ['Entity', entity_type]
|
||||||
)
|
)
|
||||||
|
|
||||||
new_node = EntityNode(
|
new_node = EntityNode(
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
version = "0.7.4"
|
version = "0.7.5"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
authors = [
|
authors = [
|
||||||
"Paul Paliychuk <paul@getzep.com>",
|
"Paul Paliychuk <paul@getzep.com>",
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue