Merge pull request #158 from topoteretes/COG-170-PGvector-adapter

COG-170 PGVector adapter
This commit is contained in:
Igor Ilic 2024-10-22 15:46:49 +02:00 committed by GitHub
commit c170bbb6e2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 503 additions and 21 deletions

View file

@@ -7,24 +7,26 @@ GRAPHISTRY_PASSWORD=
SENTRY_REPORTING_URL=
GRAPH_DATABASE_PROVIDER="neo4j" # or "networkx"
# "neo4j" or "networkx"
GRAPH_DATABASE_PROVIDER="neo4j"
# Not needed if using networkx
GRAPH_DATABASE_URL=
GRAPH_DATABASE_USERNAME=
GRAPH_DATABASE_PASSWORD=
VECTOR_DB_PROVIDER="qdrant" # or "weaviate" or "lancedb"
# Not needed if using "lancedb"
# "qdrant", "pgvector", "weaviate" or "lancedb"
VECTOR_DB_PROVIDER="qdrant"
# Not needed if using "lancedb" or "pgvector"
VECTOR_DB_URL=
VECTOR_DB_KEY=
# Database provider
DB_PROVIDER="sqlite" # or "postgres"
# Relational Database provider "sqlite" or "postgres"
DB_PROVIDER="sqlite"
# Database name
DB_NAME=cognee_db
# Postgres specific parameters (Only if Postgres is run)
# Postgres specific parameters (Only if Postgres or PGVector is used)
DB_HOST=127.0.0.1
DB_PORT=5432
DB_USERNAME=cognee
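
The template now reuses the relational Postgres settings whenever VECTOR_DB_PROVIDER is set to "pgvector". A minimal sketch of the same selection done from Python, assuming cognee reads these environment variables when it is imported (credentials mirror the docker-compose/test setup and are not a requirement):

import os

# Select pgvector; it piggybacks on the relational Postgres connection settings.
os.environ["VECTOR_DB_PROVIDER"] = "pgvector"
os.environ["DB_PROVIDER"] = "postgres"
os.environ["DB_NAME"] = "cognee_db"
os.environ["DB_HOST"] = "127.0.0.1"
os.environ["DB_PORT"] = "5432"
os.environ["DB_USERNAME"] = "cognee"
os.environ["DB_PASSWORD"] = "cognee"  # assumption: matches the docker-compose/test credentials

import cognee  # import after the variables are set so the settings are picked up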

.github/workflows/test_pgvector.yml vendored Normal file
View file

@@ -0,0 +1,67 @@
name: test | pgvector
on:
pull_request:
branches:
- main
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
env:
RUNTIME__LOG_LEVEL: ERROR
jobs:
get_docs_changes:
name: docs changes
uses: ./.github/workflows/get_docs_changes.yml
run_pgvector_integration_test:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true'
runs-on: ubuntu-latest
defaults:
run:
shell: bash
services:
postgres:
image: pgvector/pgvector:pg17
env:
POSTGRES_USER: cognee
POSTGRES_PASSWORD: cognee
POSTGRES_DB: cognee_db
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
- 5432:5432
steps:
- name: Check out
uses: actions/checkout@master
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.11.x'
- name: Install Poetry
uses: snok/install-poetry@v1.3.2
with:
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true
- name: Install dependencies
run: poetry install -E postgres --no-interaction
- name: Run default PGVector
env:
ENV: 'dev'
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: poetry run python ./cognee/tests/test_pgvector.py
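
The job waits for the pgvector/pgvector:pg17 service via the pg_isready health check before running ./cognee/tests/test_pgvector.py. When reproducing the run locally without a health-checked service container, a small readiness probe can stand in; a sketch using asyncpg (already a project dependency) and the same credentials as the job above:

import asyncio
import asyncpg

async def wait_for_postgres(
    dsn: str = "postgresql://cognee:cognee@127.0.0.1:5432/cognee_db",
    retries: int = 5,
    delay: float = 2.0,
) -> None:
    # Poor man's pg_isready: retry until a connection can be opened.
    for attempt in range(retries):
        try:
            connection = await asyncpg.connect(dsn)
            await connection.close()
            return
        except (OSError, asyncpg.PostgresError):
            if attempt == retries - 1:
                raise
            await asyncio.sleep(delay)

asyncio.run(wait_for_postgres())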

View file

@@ -190,11 +190,11 @@ Cognee supports a variety of tools and services for different operations:
- **Local Setup**: By default, LanceDB runs locally with NetworkX and OpenAI.
- **Vector Stores**: Cognee supports Qdrant and Weaviate for vector storage.
- **Vector Stores**: Cognee supports LanceDB, Qdrant, PGVector and Weaviate for vector storage.
- **Language Models (LLMs)**: You can use either Anyscale or Ollama as your LLM provider.
- **Graph Stores**: In addition to LanceDB, Neo4j is also supported for graph storage.
- **Graph Stores**: In addition to NetworkX, Neo4j is also supported for graph storage.
- **User management**: Create individual user graphs and manage permissions

View file

@@ -374,7 +374,7 @@ class LLMConfigDTO(InDTO):
api_key: str
class VectorDBConfigDTO(InDTO):
provider: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"]]
provider: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"], Literal["pgvector"]]
url: str
api_key: str
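
Adding "pgvector" to the Literal union lets the settings API accept it as a provider. A hypothetical request body for that DTO (field names taken from the class above; url and api_key can stay empty because the pgvector adapter connects through the relational database settings instead):

vector_db_settings = {
    "provider": "pgvector",
    "url": "",      # unused by pgvector; the connection comes from the relational config
    "api_key": "",  # unused by pgvector
}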

View file

@@ -8,15 +8,18 @@ from cognee.infrastructure.files.storage import LocalStorage
from cognee.modules.ingestion import get_matched_datasets, save_data_to_file
from cognee.shared.utils import send_telemetry
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.relational import get_relational_engine, create_db_and_tables
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.users.methods import get_default_user
from cognee.tasks.ingestion import get_dlt_destination
from cognee.modules.users.permissions.methods import give_permission_on_document
from cognee.modules.users.models import User
from cognee.modules.data.methods import create_dataset
from cognee.infrastructure.databases.relational import create_db_and_tables as create_relational_db_and_tables
from cognee.infrastructure.databases.vector.pgvector import create_db_and_tables as create_pgvector_db_and_tables
async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_name: str = "main_dataset", user: User = None):
await create_db_and_tables()
await create_relational_db_and_tables()
await create_pgvector_db_and_tables()
if isinstance(data, str):
if "data://" in data:

View file

@@ -3,10 +3,12 @@ from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.modules.pipelines import run_tasks, Task
from cognee.tasks.ingestion import save_data_to_storage, ingest_data
from cognee.infrastructure.databases.relational import create_db_and_tables
from cognee.infrastructure.databases.relational import create_db_and_tables as create_relational_db_and_tables
from cognee.infrastructure.databases.vector.pgvector import create_db_and_tables as create_pgvector_db_and_tables
async def add(data: Union[BinaryIO, list[BinaryIO], str, list[str]], dataset_name: str = "main_dataset", user: User = None):
await create_db_and_tables()
await create_relational_db_and_tables()
await create_pgvector_db_and_tables()
if user is None:
user = await get_default_user()
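
Both add() entry points now create the relational tables and, when pgvector is the configured provider, the vector tables before ingesting anything. A minimal end-to-end sketch under that assumption (dataset name and text are arbitrary):

import asyncio
import cognee

async def ingest_example():
    # add() calls create_relational_db_and_tables() and
    # create_pgvector_db_and_tables() itself, so no manual table setup is needed.
    await cognee.add(["Natural language processing is a subfield of AI."], "example_dataset")
    await cognee.cognify(["example_dataset"])

asyncio.run(ingest_example())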

View file

@@ -95,6 +95,30 @@ class config():
vector_db_config = get_vectordb_config()
vector_db_config.vector_db_provider = vector_db_provider
@staticmethod
def set_relational_db_config(config_dict: dict):
"""
Updates the relational db config with values from config_dict.
"""
relational_db_config = get_relational_config()
for key, value in config_dict.items():
if hasattr(relational_db_config, key):
object.__setattr__(relational_db_config, key, value)
else:
raise AttributeError(f"'{key}' is not a valid attribute of the config.")
@staticmethod
def set_vector_db_config(config_dict: dict):
"""
Updates the vector db config with values from config_dict.
"""
vector_db_config = get_vectordb_config()
for key, value in config_dict.items():
if hasattr(vector_db_config, key):
object.__setattr__(vector_db_config, key, value)
else:
raise AttributeError(f"'{key}' is not a valid attribute of the config.")
@staticmethod
def set_vector_db_key(db_key: str):
vector_db_config = get_vectordb_config()
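
The two new setters copy recognised keys onto the config objects and raise AttributeError for anything unknown. They are exercised by the integration test further down; in isolation the calls look like this (values mirror the test/docker-compose setup):

import cognee

cognee.config.set_vector_db_config({
    "vector_db_url": "",
    "vector_db_key": "",
    "vector_db_provider": "pgvector",
})
cognee.config.set_relational_db_config({
    "db_provider": "postgres",
    "db_name": "cognee_db",
    "db_host": "127.0.0.1",
    "db_port": "5432",
    "db_username": "cognee",
    "db_password": "cognee",
})

# cognee.config.set_vector_db_config({"not_a_field": 1})  # -> AttributeError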

View file

@@ -119,6 +119,8 @@ class SQLAlchemyAdapter():
self.db_path = None
else:
async with self.engine.begin() as connection:
# Load the schema information into the MetaData object
await connection.run_sync(Base.metadata.reflect)
for table in Base.metadata.sorted_tables:
drop_table_query = text(f"DROP TABLE IF EXISTS {table.name} CASCADE")
await connection.execute(drop_table_query)

View file

@@ -1,5 +1,7 @@
from typing import Dict
from ..relational.config import get_relational_config
class VectorConfig(Dict):
vector_db_url: str
vector_db_key: str
@@ -26,6 +28,25 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
api_key = config["vector_db_key"],
embedding_engine = embedding_engine
)
elif config["vector_db_provider"] == "pgvector":
from .pgvector.PGVectorAdapter import PGVectorAdapter
# Get configuration for postgres database
relational_config = get_relational_config()
db_username = relational_config.db_username
db_password = relational_config.db_password
db_host = relational_config.db_host
db_port = relational_config.db_port
db_name = relational_config.db_name
connection_string: str = (
f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
)
return PGVectorAdapter(connection_string,
config["vector_db_key"],
embedding_engine
)
else:
from .lancedb.LanceDBAdapter import LanceDBAdapter
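
For the "pgvector" provider the factory ignores vector_db_url and builds an asyncpg SQLAlchemy URL from the relational settings instead. With the defaults used elsewhere in this PR, the adapter ends up connecting with:

# db_username="cognee", db_password="cognee", db_host="127.0.0.1",
# db_port="5432", db_name="cognee_db" (test/docker-compose values, not a requirement)
connection_string = "postgresql+asyncpg://cognee:cognee@127.0.0.1:5432/cognee_db"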

View file

@@ -152,7 +152,7 @@ class LanceDBAdapter(VectorDBInterface):
):
query_vectors = await self.embedding_engine.embed_text(query_texts)
return asyncio.gather(
return await asyncio.gather(
*[self.search(
collection_name = collection_name,
query_vector = query_vector,

View file

@@ -0,0 +1,222 @@
import asyncio
from pgvector.sqlalchemy import Vector
from typing import List, Optional, get_type_hints
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import JSON, Column, Table, select, delete
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from .serialize_datetime import serialize_datetime
from ..models.ScoredResult import ScoredResult
from ..vector_db_interface import VectorDBInterface, DataPoint
from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
from ...relational.ModelBase import Base
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
def __init__(
self,
connection_string: str,
api_key: Optional[str],
embedding_engine: EmbeddingEngine,
):
self.api_key = api_key
self.embedding_engine = embedding_engine
self.db_uri: str = connection_string
self.engine = create_async_engine(self.db_uri)
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
async def embed_data(self, data: list[str]) -> list[list[float]]:
return await self.embedding_engine.embed_text(data)
async def has_collection(self, collection_name: str) -> bool:
async with self.engine.begin() as connection:
# Load the schema information into the MetaData object
await connection.run_sync(Base.metadata.reflect)
if collection_name in Base.metadata.tables:
return True
else:
return False
async def create_collection(self, collection_name: str, payload_schema=None):
data_point_types = get_type_hints(DataPoint)
vector_size = self.embedding_engine.get_vector_size()
if not await self.has_collection(collection_name):
class PGVectorDataPoint(Base):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
primary_key: Mapped[int] = mapped_column(
primary_key=True, autoincrement=True
)
id: Mapped[data_point_types["id"]]
payload = Column(JSON)
vector = Column(Vector(vector_size))
def __init__(self, id, payload, vector):
self.id = id
self.payload = payload
self.vector = vector
async with self.engine.begin() as connection:
if len(Base.metadata.tables.keys()) > 0:
await connection.run_sync(
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
)
async def create_data_points(
self, collection_name: str, data_points: List[DataPoint]
):
async with self.get_async_session() as session:
if not await self.has_collection(collection_name):
await self.create_collection(
collection_name=collection_name,
payload_schema=type(data_points[0].payload),
)
data_vectors = await self.embed_data(
[data_point.get_embeddable_data() for data_point in data_points]
)
vector_size = self.embedding_engine.get_vector_size()
class PGVectorDataPoint(Base):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
primary_key: Mapped[int] = mapped_column(
primary_key=True, autoincrement=True
)
id: Mapped[type(data_points[0].id)]
payload = Column(JSON)
vector = Column(Vector(vector_size))
def __init__(self, id, payload, vector):
self.id = id
self.payload = payload
self.vector = vector
pgvector_data_points = [
PGVectorDataPoint(
id=data_point.id,
vector=data_vectors[data_index],
payload=serialize_datetime(data_point.payload.dict()),
)
for (data_index, data_point) in enumerate(data_points)
]
session.add_all(pgvector_data_points)
await session.commit()
async def get_table(self, collection_name: str) -> Table:
"""
Dynamically loads a table using the given collection name
with an async engine.
"""
async with self.engine.begin() as connection:
# Load the schema information into the MetaData object
await connection.run_sync(Base.metadata.reflect)
if collection_name in Base.metadata.tables:
return Base.metadata.tables[collection_name]
else:
raise ValueError(f"Table '{collection_name}' not found.")
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
async with self.get_async_session() as session:
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
results = await session.execute(
select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
)
results = results.all()
return [
ScoredResult(id=result.id, payload=result.payload, score=0)
for result in results
]
async def search(
self,
collection_name: str,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
limit: int = 5,
with_vector: bool = False,
) -> List[ScoredResult]:
if query_text is None and query_vector is None:
raise ValueError("One of query_text or query_vector must be provided!")
if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
# Use async session to connect to the database
async with self.get_async_session() as session:
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
# Find closest vectors to query_vector
closest_items = await session.execute(
select(
PGVectorDataPoint,
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label(
"similarity"
),
)
.order_by("similarity")
.limit(limit)
)
vector_list = []
# Extract distances and find min/max for normalization
for vector in closest_items:
# TODO: Add normalization of similarity score
vector_list.append(vector)
# Create and return ScoredResult objects
return [
ScoredResult(
id=str(row.id), payload=row.payload, score=row.similarity
)
for row in vector_list
]
async def batch_search(
self,
collection_name: str,
query_texts: List[str],
limit: int = None,
with_vectors: bool = False,
):
query_vectors = await self.embedding_engine.embed_text(query_texts)
return await asyncio.gather(
*[
self.search(
collection_name=collection_name,
query_vector=query_vector,
limit=limit,
with_vector=with_vectors,
)
for query_vector in query_vectors
]
)
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
async with self.get_async_session() as session:
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
results = await session.execute(
delete(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
)
await session.commit()
return results
async def prune(self):
# Clean up the database if it was set up as temporary
await self.delete_database()
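
A short usage sketch of the adapter on its own, assuming an EmbeddingEngine instance is available (inside cognee the adapter is normally constructed by create_vector_engine() rather than by hand); the "entities" collection name and credentials mirror the integration test:

from cognee.infrastructure.databases.vector.pgvector import PGVectorAdapter

async def demo(embedding_engine):
    adapter = PGVectorAdapter(
        "postgresql+asyncpg://cognee:cognee@127.0.0.1:5432/cognee_db",
        api_key=None,
        embedding_engine=embedding_engine,
    )
    await adapter.create_collection("entities")
    # search() embeds query_text itself when no query_vector is supplied.
    results = await adapter.search("entities", query_text="AI", limit=5)
    for result in results:
        print(result.id, result.score)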

View file

@@ -0,0 +1,2 @@
from .PGVectorAdapter import PGVectorAdapter
from .create_db_and_tables import create_db_and_tables

View file

@@ -0,0 +1,14 @@
from ...relational.ModelBase import Base
from ..get_vector_engine import get_vector_engine, get_vectordb_config
from sqlalchemy import text
async def create_db_and_tables():
vector_config = get_vectordb_config()
vector_engine = get_vector_engine()
if vector_config.vector_db_provider == "pgvector":
vector_engine.create_database()
async with vector_engine.engine.begin() as connection:
await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
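
create_db_and_tables() only issues CREATE EXTENSION IF NOT EXISTS vector, so the pgvector extension must already be installed on the server (hence the pgvector/pgvector:pg17 image in CI and docker-compose). A small sketch for verifying that the extension is present, reusing the engine exposed by the vector adapter (assumes the pgvector provider is configured):

from sqlalchemy import text

async def pgvector_extension_available(vector_engine) -> bool:
    # Checks the pg_extension catalog for an installed 'vector' extension.
    async with vector_engine.engine.begin() as connection:
        result = await connection.execute(
            text("SELECT 1 FROM pg_extension WHERE extname = 'vector'")
        )
        return result.first() is not None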

View file

@@ -0,0 +1,12 @@
from datetime import datetime
def serialize_datetime(data):
"""Recursively convert datetime objects in dictionaries/lists to ISO format."""
if isinstance(data, dict):
return {key: serialize_datetime(value) for key, value in data.items()}
elif isinstance(data, list):
return [serialize_datetime(item) for item in data]
elif isinstance(data, datetime):
return data.isoformat() # Convert datetime to ISO 8601 string
else:
return data
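
serialize_datetime() is applied to each data point's payload before insertion so the JSON column never receives raw datetime objects. A quick example of its behaviour:

from datetime import datetime

payload = {
    "name": "document",
    "created_at": datetime(2024, 10, 22, 15, 46, 49),
    "chunks": [{"indexed_at": datetime(2024, 10, 22)}],
}

print(serialize_datetime(payload))
# {'name': 'document', 'created_at': '2024-10-22T15:46:49',
#  'chunks': [{'indexed_at': '2024-10-22T00:00:00'}]}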

View file

@@ -41,6 +41,9 @@ def get_settings() -> SettingsDict:
}, {
"value": "lancedb",
"label": "LanceDB",
}, {
"value": "pgvector",
"label": "PGVector",
}]
vector_config = get_vectordb_config()

View file

@@ -5,7 +5,7 @@ from cognee.infrastructure.databases.vector import get_vectordb_config
class VectorDBConfig(BaseModel):
url: str
api_key: str
provider: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"]]
provider: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"], Literal["pgvector"]]
async def save_vector_db_config(vector_db_config: VectorDBConfig):
vector_config = get_vectordb_config()

View file

@@ -0,0 +1,93 @@
import os
import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
logging.basicConfig(level=logging.DEBUG)
async def main():
cognee.config.set_vector_db_config(
{
"vector_db_url": "",
"vector_db_key": "",
"vector_db_provider": "pgvector"
}
)
cognee.config.set_relational_db_config(
{
"db_path": "",
"db_name": "cognee_db",
"db_host": "127.0.0.1",
"db_port": "5432",
"db_username": "cognee",
"db_password": "cognee",
"db_provider": "postgres",
}
)
data_directory_path = str(
pathlib.Path(
os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_pgvector")
).resolve()
)
cognee.config.data_root_directory(data_directory_path)
cognee_directory_path = str(
pathlib.Path(
os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_pgvector")
).resolve()
)
cognee.config.system_root_directory(cognee_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
dataset_name = "cs_explanations"
explanation_file_path = os.path.join(
pathlib.Path(__file__).parent, "test_data/Natural_language_processing.txt"
)
await cognee.add([explanation_file_path], dataset_name)
text = """A quantum computer is a computer that takes advantage of quantum mechanical phenomena.
At small scales, physical matter exhibits properties of both particles and waves, and quantum computing leverages this behavior, specifically quantum superposition and entanglement, using specialized hardware that supports the preparation and manipulation of quantum states.
Classical physics cannot explain the operation of these quantum devices, and a scalable quantum computer could perform some calculations exponentially faster (with respect to input size scaling) than any modern "classical" computer. In particular, a large-scale quantum computer could break widely used encryption schemes and aid physicists in performing physical simulations; however, the current state of the technology is largely experimental and impractical, with several obstacles to useful applications. Moreover, scalable quantum computers do not hold promise for many practical tasks, and for many important tasks quantum speedups are proven impossible.
The basic unit of information in quantum computing is the qubit, similar to the bit in traditional digital electronics. Unlike a classical bit, a qubit can exist in a superposition of its two "basis" states. When measuring a qubit, the result is a probabilistic output of a classical bit, therefore making quantum computers nondeterministic in general. If a quantum computer manipulates the qubit in a particular way, wave interference effects can amplify the desired measurement results. The design of quantum algorithms involves creating procedures that allow a quantum computer to perform calculations efficiently and quickly.
Physically engineering high-quality qubits has proven challenging. If a physical qubit is not sufficiently isolated from its environment, it suffers from quantum decoherence, introducing noise into calculations. Paradoxically, perfectly isolating qubits is also undesirable because quantum computations typically need to initialize qubits, perform controlled qubit interactions, and measure the resulting quantum states. Each of those operations introduces errors and suffers from noise, and such inaccuracies accumulate.
In principle, a non-quantum (classical) computer can solve the same computational problems as a quantum computer, given enough time. Quantum advantage comes in the form of time complexity rather than computability, and quantum complexity theory shows that some quantum algorithms for carefully selected tasks require exponentially fewer computational steps than the best known non-quantum algorithms. Such tasks can in theory be solved on a large-scale quantum computer whereas classical computers would not finish computations in any reasonable amount of time. However, quantum speedup is not universal or even typical across computational tasks, since basic tasks such as sorting are proven to not allow any asymptotic quantum speedup. Claims of quantum supremacy have drawn significant attention to the discipline, but are demonstrated on contrived tasks, while near-term practical use cases remain limited.
"""
await cognee.add([text], dataset_name)
await cognee.cognify([dataset_name])
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("entities", "AI"))[0]
random_node_name = random_node.payload["name"]
search_results = await cognee.search(SearchType.INSIGHTS, query=random_node_name)
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search(SearchType.CHUNKS, query=random_node_name)
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted chunks are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search(SearchType.SUMMARIES, query=random_node_name)
assert len(search_results) != 0, "Query related summaries don't exist."
print("\n\nExtracted summaries are:\n")
for result in search_results:
print(f"{result}\n")
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View file

@@ -62,7 +62,7 @@ services:
- cognee-network
postgres:
image: postgres:latest
image: pgvector/pgvector:pg17
container_name: postgres
environment:
POSTGRES_USER: cognee

poetry.lock generated
View file

@@ -4656,6 +4656,20 @@ files = [
[package.dependencies]
ptyprocess = ">=0.5"
[[package]]
name = "pgvector"
version = "0.3.5"
description = "pgvector support for Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "pgvector-0.3.5-py3-none-any.whl", hash = "sha256:56cca90392e596ea18873c593ec858a1984a77d16d1f82b8d0c180e79ef1018f"},
{file = "pgvector-0.3.5.tar.gz", hash = "sha256:e876c9ee382c4c2f7ee57691a4c4015d688c7222e47448ce310ded03ecfafe2f"},
]
[package.dependencies]
numpy = "*"
[[package]]
name = "pillow"
version = "10.4.0"
@@ -4918,7 +4932,7 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]
name = "psycopg2"
version = "2.9.10"
description = "psycopg2 - Python-PostgreSQL Database Adapter"
optional = false
optional = true
python-versions = ">=3.8"
files = [
{file = "psycopg2-2.9.10-cp310-cp310-win32.whl", hash = "sha256:5df2b672140f95adb453af93a7d669d7a7bf0a56bcd26f1502329166f4a61716"},
@@ -7745,10 +7759,11 @@ cli = []
filesystem = []
neo4j = ["neo4j"]
notebook = ["overrides"]
postgres = ["psycopg2"]
qdrant = ["qdrant-client"]
weaviate = ["weaviate-client"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9.0,<3.12"
content-hash = "4cba654100a455c8691dd3d4e1b588f00bbb2acca89168954037017b3a6aced9"
content-hash = "70a0072dce8de95d64b862f9a9df48aaec84c8d8515ae018fce4426a0dcacf88"

View file

@@ -70,8 +70,8 @@ sentry-sdk = {extras = ["fastapi"], version = "^2.9.0"}
fastapi-users = { version = "*", extras = ["sqlalchemy"] }
asyncpg = "^0.29.0"
alembic = "^1.13.3"
psycopg2 = "^2.9.10"
pgvector = "^0.3.5"
psycopg2 = {version = "^2.9.10", optional = true}
[tool.poetry.extras]
filesystem = ["s3fs", "botocore"]
@@ -79,6 +79,7 @@ cli = ["pipdeptree", "cron-descriptor"]
weaviate = ["weaviate-client"]
qdrant = ["qdrant-client"]
neo4j = ["neo4j"]
postgres = ["psycopg2"]
notebook = ["ipykernel", "overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"]
[tool.poetry.group.dev.dependencies]
@@ -104,7 +105,6 @@ diskcache = "^5.6.3"
pandas = "2.0.3"
tabulate = "^0.9.0"
[tool.ruff] # https://beta.ruff.rs/docs/
line-length = 100
ignore = ["F401"]