From a05bbf105e3b7f996c3f60497b5146c7dbfb68f6 Mon Sep 17 00:00:00 2001 From: netbrah Date: Sat, 22 Nov 2025 16:43:13 -0500 Subject: [PATCH] Add Cohere reranker config, chunking, and tests --- env.example | 3 + examples/rerank_example.py | 13 +- lightrag/api/lightrag_server.py | 30 ++- lightrag/rerank.py | 208 ++++++++++++++++- tests/test_rerank_chunking.py | 386 ++++++++++++++++++++++++++++++++ 5 files changed, 620 insertions(+), 20 deletions(-) create mode 100644 tests/test_rerank_chunking.py diff --git a/env.example b/env.example index fea99953..c8419961 100644 --- a/env.example +++ b/env.example @@ -102,6 +102,9 @@ RERANK_BINDING=null # RERANK_MODEL=rerank-v3.5 # RERANK_BINDING_HOST=https://api.cohere.com/v2/rerank # RERANK_BINDING_API_KEY=your_rerank_api_key_here +### Cohere rerank chunking configuration (useful for models with token limits like ColBERT) +# RERANK_ENABLE_CHUNKING=true +# RERANK_MAX_TOKENS_PER_DOC=480 ### Default value for Jina AI # RERANK_MODEL=jina-reranker-v2-base-multilingual diff --git a/examples/rerank_example.py b/examples/rerank_example.py index da3d0efe..889cffe8 100644 --- a/examples/rerank_example.py +++ b/examples/rerank_example.py @@ -15,9 +15,12 @@ Configuration Required: EMBEDDING_BINDING_HOST EMBEDDING_BINDING_API_KEY 3. Set your vLLM deployed AI rerank model setting with env vars: - RERANK_MODEL - RERANK_BINDING_HOST + RERANK_BINDING=cohere + RERANK_MODEL (e.g., answerai-colbert-small-v1 or rerank-v3.5) + RERANK_BINDING_HOST (e.g., https://api.cohere.com/v2/rerank or LiteLLM proxy) RERANK_BINDING_API_KEY + RERANK_ENABLE_CHUNKING=true (optional, for models with token limits) + RERANK_MAX_TOKENS_PER_DOC=480 (optional, default 4096) Note: Rerank is controlled per query via the 'enable_rerank' parameter (default: True) """ @@ -66,9 +69,11 @@ async def embedding_func(texts: list[str]) -> np.ndarray: rerank_model_func = partial( cohere_rerank, - model=os.getenv("RERANK_MODEL"), + model=os.getenv("RERANK_MODEL", "rerank-v3.5"), api_key=os.getenv("RERANK_BINDING_API_KEY"), - base_url=os.getenv("RERANK_BINDING_HOST"), + base_url=os.getenv("RERANK_BINDING_HOST", "https://api.cohere.com/v2/rerank"), + enable_chunking=os.getenv("RERANK_ENABLE_CHUNKING", "false").lower() == "true", + max_tokens_per_doc=int(os.getenv("RERANK_MAX_TOKENS_PER_DOC", "4096")), ) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index b29e39b2..0be5d9de 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -967,15 +967,27 @@ def create_app(args): query: str, documents: list, top_n: int = None, extra_body: dict = None ): """Server rerank function with configuration from environment variables""" - return await selected_rerank_func( - query=query, - documents=documents, - top_n=top_n, - api_key=args.rerank_binding_api_key, - model=args.rerank_model, - base_url=args.rerank_binding_host, - extra_body=extra_body, - ) + # Prepare kwargs for rerank function + kwargs = { + "query": query, + "documents": documents, + "top_n": top_n, + "api_key": args.rerank_binding_api_key, + "model": args.rerank_model, + "base_url": args.rerank_binding_host, + } + + # Add Cohere-specific parameters if using cohere binding + if args.rerank_binding == "cohere": + # Enable chunking if configured (useful for models with token limits like ColBERT) + kwargs["enable_chunking"] = ( + os.getenv("RERANK_ENABLE_CHUNKING", "false").lower() == "true" + ) + kwargs["max_tokens_per_doc"] = int( + os.getenv("RERANK_MAX_TOKENS_PER_DOC", "4096") + ) + + return await selected_rerank_func(**kwargs, extra_body=extra_body) rerank_model_func = server_rerank_func logger.info( diff --git a/lightrag/rerank.py b/lightrag/rerank.py index 35551f5a..b3892d56 100644 --- a/lightrag/rerank.py +++ b/lightrag/rerank.py @@ -2,7 +2,7 @@ from __future__ import annotations import os import aiohttp -from typing import Any, List, Dict, Optional +from typing import Any, List, Dict, Optional, Tuple from tenacity import ( retry, stop_after_attempt, @@ -19,6 +19,146 @@ from dotenv import load_dotenv load_dotenv(dotenv_path=".env", override=False) +def chunk_documents_for_rerank( + documents: List[str], + max_tokens: int = 480, + overlap_tokens: int = 32, + tokenizer_model: str = "gpt-4o-mini", +) -> Tuple[List[str], List[int]]: + """ + Chunk documents that exceed token limit for reranking. + + Args: + documents: List of document strings to chunk + max_tokens: Maximum tokens per chunk (default 480 to leave margin for 512 limit) + overlap_tokens: Number of tokens to overlap between chunks + tokenizer_model: Model name for tiktoken tokenizer + + Returns: + Tuple of (chunked_documents, original_doc_indices) + - chunked_documents: List of document chunks (may be more than input) + - original_doc_indices: Maps each chunk back to its original document index + """ + try: + from .utils import TiktokenTokenizer + + tokenizer = TiktokenTokenizer(model_name=tokenizer_model) + except Exception as e: + logger.warning( + f"Failed to initialize tokenizer: {e}. Using character-based approximation." + ) + # Fallback: approximate 1 token ≈ 4 characters + max_chars = max_tokens * 4 + overlap_chars = overlap_tokens * 4 + + chunked_docs = [] + doc_indices = [] + + for idx, doc in enumerate(documents): + if len(doc) <= max_chars: + chunked_docs.append(doc) + doc_indices.append(idx) + else: + # Split into overlapping chunks + start = 0 + while start < len(doc): + end = min(start + max_chars, len(doc)) + chunk = doc[start:end] + chunked_docs.append(chunk) + doc_indices.append(idx) + + if end >= len(doc): + break + start = end - overlap_chars + + return chunked_docs, doc_indices + + # Use tokenizer for accurate chunking + chunked_docs = [] + doc_indices = [] + + for idx, doc in enumerate(documents): + tokens = tokenizer.encode(doc) + + if len(tokens) <= max_tokens: + # Document fits in one chunk + chunked_docs.append(doc) + doc_indices.append(idx) + else: + # Split into overlapping chunks + start = 0 + while start < len(tokens): + end = min(start + max_tokens, len(tokens)) + chunk_tokens = tokens[start:end] + chunk_text = tokenizer.decode(chunk_tokens) + chunked_docs.append(chunk_text) + doc_indices.append(idx) + + if end >= len(tokens): + break + start = end - overlap_tokens + + return chunked_docs, doc_indices + + +def aggregate_chunk_scores( + chunk_results: List[Dict[str, Any]], + doc_indices: List[int], + num_original_docs: int, + aggregation: str = "max", +) -> List[Dict[str, Any]]: + """ + Aggregate rerank scores from document chunks back to original documents. + + Args: + chunk_results: Rerank results for chunks [{"index": chunk_idx, "relevance_score": score}, ...] + doc_indices: Maps each chunk index to original document index + num_original_docs: Total number of original documents + aggregation: Strategy for aggregating scores ("max", "mean", "first") + + Returns: + List of results for original documents [{"index": doc_idx, "relevance_score": score}, ...] + """ + # Group scores by original document index + doc_scores: Dict[int, List[float]] = {i: [] for i in range(num_original_docs)} + + for result in chunk_results: + chunk_idx = result["index"] + score = result["relevance_score"] + + if 0 <= chunk_idx < len(doc_indices): + original_doc_idx = doc_indices[chunk_idx] + doc_scores[original_doc_idx].append(score) + + # Aggregate scores + aggregated_results = [] + for doc_idx, scores in doc_scores.items(): + if not scores: + continue + + if aggregation == "max": + final_score = max(scores) + elif aggregation == "mean": + final_score = sum(scores) / len(scores) + elif aggregation == "first": + final_score = scores[0] + else: + logger.warning(f"Unknown aggregation strategy: {aggregation}, using max") + final_score = max(scores) + + aggregated_results.append( + { + "index": doc_idx, + "relevance_score": final_score, + } + ) + + # Sort by relevance score (descending) + aggregated_results.sort(key=lambda x: x["relevance_score"], reverse=True) + + return aggregated_results + + @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=60), @@ -38,6 +178,8 @@ async def generic_rerank_api( extra_body: Optional[Dict[str, Any]] = None, response_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun" request_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun" + enable_chunking: bool = False, + max_tokens_per_doc: int = 480, ) -> List[Dict[str, Any]]: """ Generic rerank API call for Jina/Cohere/Aliyun models. @@ -52,6 +194,9 @@ async def generic_rerank_api( return_documents: Whether to return document text (Jina only) extra_body: Additional body parameters response_format: Response format type ("standard" for Jina/Cohere, "aliyun" for Aliyun) + request_format: Request format type + enable_chunking: Whether to chunk documents exceeding token limit + max_tokens_per_doc: Maximum tokens per document for chunking Returns: List of dictionary of ["index": int, "relevance_score": float] @@ -63,6 +208,17 @@ async def generic_rerank_api( if api_key is not None: headers["Authorization"] = f"Bearer {api_key}" + # Handle document chunking if enabled + original_documents = documents + doc_indices = None + if enable_chunking: + documents, doc_indices = chunk_documents_for_rerank( + documents, max_tokens=max_tokens_per_doc + ) + logger.debug( + f"Chunked {len(original_documents)} documents into {len(documents)} chunks" + ) + # Build request payload based on request format if request_format == "aliyun": # Aliyun format: nested input/parameters structure @@ -86,7 +242,7 @@ async def generic_rerank_api( if extra_body: payload["parameters"].update(extra_body) else: - # Standard format for Jina/Cohere + # Standard format for Jina/Cohere/OpenAI payload = { "model": model, "query": query, @@ -98,7 +254,7 @@ async def generic_rerank_api( payload["top_n"] = top_n # Only Jina API supports return_documents parameter - if return_documents is not None: + if return_documents is not None and response_format in ("standard",): payload["return_documents"] = return_documents # Add extra parameters @@ -147,7 +303,6 @@ async def generic_rerank_api( f"Expected 'output.results' to be list, got {type(results)}: {results}" ) results = [] - elif response_format == "standard": # Standard format: {"results": [...]} results = response_json.get("results", []) @@ -158,16 +313,28 @@ async def generic_rerank_api( results = [] else: raise ValueError(f"Unsupported response format: {response_format}") + if not results: logger.warning("Rerank API returned empty results") return [] # Standardize return format - return [ + standardized_results = [ {"index": result["index"], "relevance_score": result["relevance_score"]} for result in results ] + # Aggregate chunk scores back to original documents if chunking was enabled + if enable_chunking and doc_indices: + standardized_results = aggregate_chunk_scores( + standardized_results, + doc_indices, + len(original_documents), + aggregation="max", + ) + + return standardized_results + async def cohere_rerank( query: str, @@ -177,21 +344,46 @@ async def cohere_rerank( model: str = "rerank-v3.5", base_url: str = "https://api.cohere.com/v2/rerank", extra_body: Optional[Dict[str, Any]] = None, + enable_chunking: bool = False, + max_tokens_per_doc: int = 4096, ) -> List[Dict[str, Any]]: """ Rerank documents using Cohere API. + Supports both standard Cohere API and Cohere-compatible proxies + Args: query: The search query documents: List of strings to rerank top_n: Number of top results to return - api_key: API key - model: rerank model name + api_key: API key for authentication + model: rerank model name (default: rerank-v3.5) base_url: API endpoint extra_body: Additional body for http request(reserved for extra params) + enable_chunking: Whether to chunk documents exceeding max_tokens_per_doc + max_tokens_per_doc: Maximum tokens per document (default: 4096 for Cohere v3.5) Returns: List of dictionary of ["index": int, "relevance_score": float] + + Example: + >>> # Standard Cohere API + >>> results = await cohere_rerank( + ... query="What is the meaning of life?", + ... documents=["Doc1", "Doc2"], + ... api_key="your-cohere-key" + ... ) + + >>> # LiteLLM proxy with user authentication + >>> results = await cohere_rerank( + ... query="What is vector search?", + ... documents=["Doc1", "Doc2"], + ... model="answerai-colbert-small-v1", + ... base_url="https://llm-proxy.example.com/v2/rerank", + ... api_key="your-proxy-key", + ... enable_chunking=True, + ... max_tokens_per_doc=480 + ... ) """ if api_key is None: api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY") @@ -206,6 +398,8 @@ async def cohere_rerank( return_documents=None, # Cohere doesn't support this parameter extra_body=extra_body, response_format="standard", + enable_chunking=enable_chunking, + max_tokens_per_doc=max_tokens_per_doc, ) diff --git a/tests/test_rerank_chunking.py b/tests/test_rerank_chunking.py new file mode 100644 index 00000000..f31331d2 --- /dev/null +++ b/tests/test_rerank_chunking.py @@ -0,0 +1,386 @@ +""" +Unit tests for rerank document chunking functionality. + +Tests the chunk_documents_for_rerank and aggregate_chunk_scores functions +in lightrag/rerank.py to ensure proper document splitting and score aggregation. +""" + +import pytest +from unittest.mock import Mock, patch, AsyncMock +from lightrag.rerank import ( + chunk_documents_for_rerank, + aggregate_chunk_scores, + cohere_rerank, +) + + +class TestChunkDocumentsForRerank: + """Test suite for chunk_documents_for_rerank function""" + + def test_no_chunking_needed_for_short_docs(self): + """Documents shorter than max_tokens should not be chunked""" + documents = [ + "Short doc 1", + "Short doc 2", + "Short doc 3", + ] + + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, max_tokens=100, overlap_tokens=10 + ) + + # No chunking should occur + assert len(chunked_docs) == 3 + assert chunked_docs == documents + assert doc_indices == [0, 1, 2] + + def test_chunking_with_character_fallback(self): + """Test chunking falls back to character-based when tokenizer unavailable""" + # Create a very long document that exceeds character limit + long_doc = "a" * 2000 # 2000 characters + documents = [long_doc, "short doc"] + + with patch("lightrag.rerank.TiktokenTokenizer", side_effect=ImportError): + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, + max_tokens=100, # 100 tokens = ~400 chars + overlap_tokens=10, # 10 tokens = ~40 chars + ) + + # First doc should be split into chunks, second doc stays whole + assert len(chunked_docs) > 2 # At least one chunk from first doc + second doc + assert chunked_docs[-1] == "short doc" # Last chunk is the short doc + # Verify doc_indices maps chunks to correct original document + assert doc_indices[-1] == 1 # Last chunk maps to document 1 + + def test_chunking_with_tiktoken_tokenizer(self): + """Test chunking with actual tokenizer""" + # Create document with known token count + # Approximate: "word " = ~1 token, so 200 words ~ 200 tokens + long_doc = " ".join([f"word{i}" for i in range(200)]) + documents = [long_doc, "short"] + + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, max_tokens=50, overlap_tokens=10 + ) + + # Long doc should be split, short doc should remain + assert len(chunked_docs) > 2 + assert doc_indices[-1] == 1 # Last chunk is from second document + + # Verify overlapping chunks contain overlapping content + if len(chunked_docs) > 2: + # Check that consecutive chunks from same doc have some overlap + for i in range(len(doc_indices) - 1): + if doc_indices[i] == doc_indices[i + 1] == 0: + # Both chunks from first doc, should have overlap + chunk1_words = chunked_docs[i].split() + chunk2_words = chunked_docs[i + 1].split() + # At least one word should be common due to overlap + assert any(word in chunk2_words for word in chunk1_words[-5:]) + + def test_empty_documents(self): + """Test handling of empty document list""" + documents = [] + chunked_docs, doc_indices = chunk_documents_for_rerank(documents) + + assert chunked_docs == [] + assert doc_indices == [] + + def test_single_document_chunking(self): + """Test chunking of a single long document""" + # Create document with ~100 tokens + long_doc = " ".join([f"token{i}" for i in range(100)]) + documents = [long_doc] + + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, max_tokens=30, overlap_tokens=5 + ) + + # Should create multiple chunks + assert len(chunked_docs) > 1 + # All chunks should map to document 0 + assert all(idx == 0 for idx in doc_indices) + + +class TestAggregateChunkScores: + """Test suite for aggregate_chunk_scores function""" + + def test_no_chunking_simple_aggregation(self): + """Test aggregation when no chunking occurred (1:1 mapping)""" + chunk_results = [ + {"index": 0, "relevance_score": 0.9}, + {"index": 1, "relevance_score": 0.7}, + {"index": 2, "relevance_score": 0.5}, + ] + doc_indices = [0, 1, 2] # 1:1 mapping + num_original_docs = 3 + + aggregated = aggregate_chunk_scores( + chunk_results, doc_indices, num_original_docs, aggregation="max" + ) + + # Results should be sorted by score + assert len(aggregated) == 3 + assert aggregated[0]["index"] == 0 + assert aggregated[0]["relevance_score"] == 0.9 + assert aggregated[1]["index"] == 1 + assert aggregated[1]["relevance_score"] == 0.7 + assert aggregated[2]["index"] == 2 + assert aggregated[2]["relevance_score"] == 0.5 + + def test_max_aggregation_with_chunks(self): + """Test max aggregation strategy with multiple chunks per document""" + # 5 chunks: first 3 from doc 0, last 2 from doc 1 + chunk_results = [ + {"index": 0, "relevance_score": 0.5}, + {"index": 1, "relevance_score": 0.8}, + {"index": 2, "relevance_score": 0.6}, + {"index": 3, "relevance_score": 0.7}, + {"index": 4, "relevance_score": 0.4}, + ] + doc_indices = [0, 0, 0, 1, 1] + num_original_docs = 2 + + aggregated = aggregate_chunk_scores( + chunk_results, doc_indices, num_original_docs, aggregation="max" + ) + + # Should take max score for each document + assert len(aggregated) == 2 + assert aggregated[0]["index"] == 0 + assert aggregated[0]["relevance_score"] == 0.8 # max of 0.5, 0.8, 0.6 + assert aggregated[1]["index"] == 1 + assert aggregated[1]["relevance_score"] == 0.7 # max of 0.7, 0.4 + + def test_mean_aggregation_with_chunks(self): + """Test mean aggregation strategy""" + chunk_results = [ + {"index": 0, "relevance_score": 0.6}, + {"index": 1, "relevance_score": 0.8}, + {"index": 2, "relevance_score": 0.4}, + ] + doc_indices = [0, 0, 1] # First two chunks from doc 0, last from doc 1 + num_original_docs = 2 + + aggregated = aggregate_chunk_scores( + chunk_results, doc_indices, num_original_docs, aggregation="mean" + ) + + assert len(aggregated) == 2 + assert aggregated[0]["index"] == 0 + assert aggregated[0]["relevance_score"] == pytest.approx(0.7) # (0.6 + 0.8) / 2 + assert aggregated[1]["index"] == 1 + assert aggregated[1]["relevance_score"] == 0.4 + + def test_first_aggregation_with_chunks(self): + """Test first aggregation strategy""" + chunk_results = [ + {"index": 0, "relevance_score": 0.6}, + {"index": 1, "relevance_score": 0.8}, + {"index": 2, "relevance_score": 0.4}, + ] + doc_indices = [0, 0, 1] + num_original_docs = 2 + + aggregated = aggregate_chunk_scores( + chunk_results, doc_indices, num_original_docs, aggregation="first" + ) + + assert len(aggregated) == 2 + # First should use first score seen for each doc + assert aggregated[0]["index"] == 0 + assert aggregated[0]["relevance_score"] == 0.6 # First score for doc 0 + assert aggregated[1]["index"] == 1 + assert aggregated[1]["relevance_score"] == 0.4 + + def test_empty_chunk_results(self): + """Test handling of empty results""" + aggregated = aggregate_chunk_scores([], [], 3, aggregation="max") + assert aggregated == [] + + def test_documents_with_no_scores(self): + """Test when some documents have no chunks/scores""" + chunk_results = [ + {"index": 0, "relevance_score": 0.9}, + {"index": 1, "relevance_score": 0.7}, + ] + doc_indices = [0, 0] # Both chunks from document 0 + num_original_docs = 3 # But we have 3 documents total + + aggregated = aggregate_chunk_scores( + chunk_results, doc_indices, num_original_docs, aggregation="max" + ) + + # Only doc 0 should appear in results + assert len(aggregated) == 1 + assert aggregated[0]["index"] == 0 + + def test_unknown_aggregation_strategy(self): + """Test that unknown strategy falls back to max""" + chunk_results = [ + {"index": 0, "relevance_score": 0.6}, + {"index": 1, "relevance_score": 0.8}, + ] + doc_indices = [0, 0] + num_original_docs = 1 + + # Use invalid strategy + aggregated = aggregate_chunk_scores( + chunk_results, doc_indices, num_original_docs, aggregation="invalid" + ) + + # Should fall back to max + assert aggregated[0]["relevance_score"] == 0.8 + + +@pytest.mark.offline +class TestCohereRerankChunking: + """Integration tests for cohere_rerank with chunking enabled""" + + @pytest.mark.asyncio + async def test_cohere_rerank_with_chunking_disabled(self): + """Test that chunking can be disabled""" + documents = ["doc1", "doc2"] + query = "test query" + + # Mock the generic_rerank_api + with patch( + "lightrag.rerank.generic_rerank_api", new_callable=AsyncMock + ) as mock_api: + mock_api.return_value = [ + {"index": 0, "relevance_score": 0.9}, + {"index": 1, "relevance_score": 0.7}, + ] + + result = await cohere_rerank( + query=query, + documents=documents, + api_key="test-key", + enable_chunking=False, + max_tokens_per_doc=100, + ) + + # Verify generic_rerank_api was called with correct parameters + mock_api.assert_called_once() + call_kwargs = mock_api.call_args[1] + assert call_kwargs["enable_chunking"] is False + assert call_kwargs["max_tokens_per_doc"] == 100 + # Result should mirror mocked scores + assert len(result) == 2 + assert result[0]["index"] == 0 + assert result[0]["relevance_score"] == 0.9 + assert result[1]["index"] == 1 + assert result[1]["relevance_score"] == 0.7 + + @pytest.mark.asyncio + async def test_cohere_rerank_with_chunking_enabled(self): + """Test that chunking parameters are passed through""" + documents = ["doc1", "doc2"] + query = "test query" + + with patch( + "lightrag.rerank.generic_rerank_api", new_callable=AsyncMock + ) as mock_api: + mock_api.return_value = [ + {"index": 0, "relevance_score": 0.9}, + {"index": 1, "relevance_score": 0.7}, + ] + + result = await cohere_rerank( + query=query, + documents=documents, + api_key="test-key", + enable_chunking=True, + max_tokens_per_doc=480, + ) + + # Verify parameters were passed + call_kwargs = mock_api.call_args[1] + assert call_kwargs["enable_chunking"] is True + assert call_kwargs["max_tokens_per_doc"] == 480 + # Result should mirror mocked scores + assert len(result) == 2 + assert result[0]["index"] == 0 + assert result[0]["relevance_score"] == 0.9 + assert result[1]["index"] == 1 + assert result[1]["relevance_score"] == 0.7 + + @pytest.mark.asyncio + async def test_cohere_rerank_default_parameters(self): + """Test default parameter values for cohere_rerank""" + documents = ["doc1"] + query = "test" + + with patch( + "lightrag.rerank.generic_rerank_api", new_callable=AsyncMock + ) as mock_api: + mock_api.return_value = [{"index": 0, "relevance_score": 0.9}] + + result = await cohere_rerank( + query=query, documents=documents, api_key="test-key" + ) + + # Verify default values + call_kwargs = mock_api.call_args[1] + assert call_kwargs["enable_chunking"] is False + assert call_kwargs["max_tokens_per_doc"] == 4096 + assert call_kwargs["model"] == "rerank-v3.5" + # Result should mirror mocked scores + assert len(result) == 1 + assert result[0]["index"] == 0 + assert result[0]["relevance_score"] == 0.9 + + +@pytest.mark.offline +class TestEndToEndChunking: + """End-to-end tests for chunking workflow""" + + @pytest.mark.asyncio + async def test_end_to_end_chunking_workflow(self): + """Test complete chunking workflow from documents to aggregated results""" + # Create documents where first one needs chunking + long_doc = " ".join([f"word{i}" for i in range(100)]) + documents = [long_doc, "short doc"] + query = "test query" + + # Mock the HTTP call inside generic_rerank_api + mock_response = Mock() + mock_response.status = 200 + mock_response.json = AsyncMock( + return_value={ + "results": [ + {"index": 0, "relevance_score": 0.5}, # chunk 0 from doc 0 + {"index": 1, "relevance_score": 0.8}, # chunk 1 from doc 0 + {"index": 2, "relevance_score": 0.6}, # chunk 2 from doc 0 + {"index": 3, "relevance_score": 0.7}, # doc 1 (short) + ] + } + ) + mock_response.request_info = None + mock_response.history = None + mock_response.headers = {} + + mock_session = Mock() + mock_session.post = AsyncMock(return_value=mock_response) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock() + + with patch("aiohttp.ClientSession", return_value=mock_session): + result = await cohere_rerank( + query=query, + documents=documents, + api_key="test-key", + base_url="http://test.com/rerank", + enable_chunking=True, + max_tokens_per_doc=30, # Force chunking of long doc + ) + + # Should get 2 results (one per original document) + # The long doc's chunks should be aggregated + assert len(result) <= len(documents) + # Results should be sorted by score + assert all( + result[i]["relevance_score"] >= result[i + 1]["relevance_score"] + for i in range(len(result) - 1) + )