LightRAG/examples/unofficial-sample/lightrag_nvidia_demo.py
clssck 69358d830d test(lightrag,examples,api): comprehensive ruff formatting and type hints
Format entire codebase with ruff and add type hints across all modules:
- Apply ruff formatting to all Python files (121 files, 17K insertions)
- Add type hints to function signatures throughout lightrag core and API
- Update test suite with improved type annotations and docstrings
- Add pyrightconfig.json for static type checking configuration
- Create prompt_optimized.py and test_extraction_prompt_ab.py test files
- Update ruff.toml and .gitignore for improved linting configuration
- Standardize code style across examples, reproduce scripts, and utilities
2025-12-05 15:17:06 +01:00

160 lines
4.9 KiB
Python

import asyncio
import os
import nest_asyncio
import numpy as np
from lightrag import LightRAG, QueryParam
from lightrag.llm import (
nvidia_openai_embed,
openai_complete_if_cache,
)
# for custom llm_model_func
from lightrag.utils import EmbeddingFunc, locate_json_string_body_from_string
# Patch asyncio so nested event loops are allowed (presumably so this demo can
# also run inside notebooks/IDEs that already own a running loop).
nest_asyncio.apply()

WORKING_DIR = './dickens'
# makedirs(..., exist_ok=True) replaces the racy "check exists, then mkdir"
# pattern with a single call that is a no-op when the directory already exists.
os.makedirs(WORKING_DIR, exist_ok=True)

# API key for the NVIDIA API. It is taken from the NVIDIA_OPENAI_API_KEY
# environment variable; the placeholder below is used only when it is unset,
# so real keys never need to be hard-coded in this file.
NVIDIA_OPENAI_API_KEY = os.getenv('NVIDIA_OPENAI_API_KEY', 'nvapi-xxxx')

# To use the pre-defined function for the NVIDIA LLM API (OpenAI compatible):
# llm_model_func = nvidia_openai_complete
# To write a custom llm_model_func for an LLM served on the NVIDIA API (as in
# the other examples), define it like this:
async def llm_model_func(
    prompt,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    **kwargs,
) -> str:
    """Send *prompt* to the Nemotron 70B instruct model on NVIDIA's OpenAI-compatible API.

    Args:
        prompt: The user prompt.
        system_prompt: Optional system message, forwarded unchanged.
        history_messages: Prior chat messages; ``None`` is treated as an empty list.
        keyword_extraction: When truthy, the reply is passed through
            ``locate_json_string_body_from_string`` and that result is returned.
        **kwargs: Forwarded verbatim to ``openai_complete_if_cache``.

    Returns:
        The model reply, or its located JSON body when ``keyword_extraction`` is set.
    """
    messages = [] if history_messages is None else history_messages
    completion = await openai_complete_if_cache(
        'nvidia/llama-3.1-nemotron-70b-instruct',
        prompt,
        system_prompt=system_prompt,
        history_messages=messages,
        api_key=NVIDIA_OPENAI_API_KEY,
        base_url='https://integrate.api.nvidia.com/v1',
        **kwargs,
    )
    return locate_json_string_body_from_string(completion) if keyword_extraction else completion
# Custom embedding. The NVIDIA endpoint takes an ``input_type`` that differs
# between documents being indexed ("passage") and search queries ("query");
# both public wrappers below delegate to one helper so the shared settings
# live in a single place.
nvidia_embed_model = 'nvidia/nv-embedqa-e5-v5'  # maximum 512 tokens
# Alternative model: "nvidia/llama-3.2-nv-embedqa-1b-v1"


async def _nvidia_embed(texts: list[str], input_type: str) -> np.ndarray:
    """Embed *texts* with ``nvidia_embed_model`` using the given NVIDIA ``input_type``."""
    return await nvidia_openai_embed(
        texts,
        model=nvidia_embed_model,
        api_key=NVIDIA_OPENAI_API_KEY,
        base_url='https://integrate.api.nvidia.com/v1',
        input_type=input_type,
        # Ask the server to truncate inputs longer than the model's maximum
        # token count (from the end) instead of failing the request.
        trunc='END',
        encode='float',
    )


async def indexing_embedding_func(texts: list[str]) -> np.ndarray:
    """Embed documents for indexing (``input_type='passage'``)."""
    return await _nvidia_embed(texts, 'passage')


async def query_embedding_func(texts: list[str]) -> np.ndarray:
    """Embed search queries (``input_type='query'``)."""
    return await _nvidia_embed(texts, 'query')
# The indexing and query embeddings are assumed to share one dimension, so
# probing the indexing function alone is enough.
async def get_embedding_dim():
    """Return the embedding width by embedding a single probe sentence.

    Makes one real call to ``indexing_embedding_func`` and returns the size of
    the second axis of the resulting array.
    """
    probe = await indexing_embedding_func(['This is a test sentence.'])
    return probe.shape[1]
# Manual smoke test for the custom wrappers.
async def test_funcs():
    """Call the LLM and indexing-embedding wrappers once each and print the results."""
    reply = await llm_model_func('How are you?')
    print('llm_model_func: ', reply)

    vectors = await indexing_embedding_func(['How are you?'])
    print('embedding_func: ', vectors)


# Uncomment to run the smoke test:
# asyncio.run(test_funcs())
async def initialize_rag():
    """Create a LightRAG instance wired to the NVIDIA LLM and embedding wrappers.

    The embedding dimension is detected at runtime (one real embedding call)
    instead of being hard-coded. Storages are initialized before the instance
    is returned.
    """
    dim = await get_embedding_dim()
    print(f'Detected embedding dimension: {dim}')

    embedding = EmbeddingFunc(
        embedding_dim=dim,
        # The model limit is 512 tokens, yet inputs can still exceed it as
        # counted by LightRAG's tokenizer. The server-side ``trunc`` option in
        # the embedding function absorbs the overflow. Matching LightRAG's
        # tokenizer to the NVIDIA model is left as future work.
        max_token_size=512,
        func=indexing_embedding_func,
    )

    # LightRAG instance used for indexing and querying.
    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=llm_model_func,
        # llm_model_name="meta/llama3-70b-instruct",  # uncomment if needed
        embedding_func=embedding,
    )
    await rag.initialize_storages()  # Auto-initializes pipeline_status
    return rag
async def main():
    """Index ``./book.txt`` and print the answer to one question in every query mode.

    Errors are reported on stdout (message) and stderr (traceback) rather than
    propagated. Whenever a RAG instance was created, its storages are finalized
    in ``finally``, including after a failure.
    """
    import traceback

    rag = None
    try:
        # Initialize RAG instance
        rag = await initialize_rag()

        # Read the source document and index it.
        with open('./book.txt', encoding='utf-8') as f:
            await rag.ainsert(f.read())

        question = 'What are the top themes in this story?'
        # (banner, query mode) pairs; banners are printed exactly as before.
        searches = (
            ('==============Naive===============', 'naive'),
            ('==============local===============', 'local'),
            ('==============global===============', 'global'),
            ('==============hybrid===============', 'hybrid'),
        )
        for banner, mode in searches:
            print(banner)
            print(await rag.aquery(question, param=QueryParam(mode=mode)))
    except Exception as e:
        print(f'An error occurred: {e}')
        # The one-line message alone hides where the failure happened.
        traceback.print_exc()
    finally:
        # Let LightRAG release its storages even if indexing/querying failed.
        if rag is not None:
            await rag.finalize_storages()
if __name__ == '__main__':
    # Script entry point: run the async demo (index book.txt, then query it).
    asyncio.run(main())