From ae31527e1e83ceadb3e2f0be2abd2f9404e1e053 Mon Sep 17 00:00:00 2001 From: Stephen Belanger Date: Tue, 13 Jan 2026 06:15:53 +0800 Subject: [PATCH] Fix ContextRelevancy score Fixes #80 --- js/ragas.test.ts | 152 ++++++++++++++++++++++++++++++++++++- js/ragas.ts | 9 ++- py/autoevals/ragas.py | 6 +- py/autoevals/test_ragas.py | 75 ++++++++++++++++++ 4 files changed, 236 insertions(+), 6 deletions(-) diff --git a/js/ragas.test.ts b/js/ragas.test.ts index 9c321ef..a74342e 100644 --- a/js/ragas.test.ts +++ b/js/ragas.test.ts @@ -1,4 +1,7 @@ -import { expect, test } from "vitest"; +import { http, HttpResponse } from "msw"; +import { setupServer } from "msw/node"; +import { OpenAI } from "openai"; +import { afterAll, afterEach, beforeAll, describe, expect, test } from "vitest"; import { AnswerCorrectness, AnswerRelevancy, @@ -9,6 +12,7 @@ import { ContextRelevancy, Faithfulness, } from "./ragas"; +import { init } from "./oai"; const data = { input: "Can starred docs from different workspaces be accessed in one place?", @@ -84,3 +88,149 @@ test("Ragas end-to-end test", async () => { } } }, 600000); + +// Tests for ContextRelevancy score clamping (#80) +describe("ContextRelevancy score clamping", () => { + const server = setupServer(); + + beforeAll(() => { + server.listen({ + onUnhandledRequest: (req) => { + throw new Error(`Unhandled request ${req.method}, ${req.url}`); + }, + }); + }); + + afterEach(() => { + server.resetHandlers(); + init(); + }); + + afterAll(() => { + server.close(); + }); + + test("clamps score to 1.0 when LLM returns sentences longer than context", async () => { + // Mock response where extracted sentences are LONGER than the context + // This would produce a raw score > 1.0 without clamping + server.use( + http.post("https://api.openai.com/v1/chat/completions", () => { + return HttpResponse.json({ + id: "chatcmpl-test", + object: "chat.completion", + created: Date.now(), + model: "gpt-4o", + choices: [ + { + index: 0, + message: { + role: "assistant", + content: null, + tool_calls: [ + { + id: "call_test", + type: "function", + function: { + name: "extract_sentences", + arguments: JSON.stringify({ + sentences: [ + { + sentence: + "Hello world, this is a much longer sentence than the original context that was provided", + reasons: ["This is a test reason"], + }, + ], + }), + }, + }, + ], + }, + finish_reason: "tool_calls", + }, + ], + usage: { prompt_tokens: 10, completion_tokens: 20, total_tokens: 30 }, + }); + }), + ); + + init({ + client: new OpenAI({ + apiKey: "test-api-key", + baseURL: "https://api.openai.com/v1", + }), + }); + + // Short context that would cause score > 1.0 without clamping + const result = await ContextRelevancy({ + input: "What is hello?", + output: "Hello world", + context: "Hello world", // 11 chars, but mock returns 88 chars + }); + + // Score should be clamped to 1.0, not exceed it + expect(result.score).toBe(1); + expect(result.score).toBeLessThanOrEqual(1); + expect(result.score).toBeGreaterThanOrEqual(0); + }); + + test("returns expected score for normal case", async () => { + const context = + "Hello world, this is a test context with some content that is reasonably long."; + + // Mock response where extracted sentences are shorter than the context + server.use( + http.post("https://api.openai.com/v1/chat/completions", () => { + return HttpResponse.json({ + id: "chatcmpl-test", + object: "chat.completion", + created: Date.now(), + model: "gpt-4o", + choices: [ + { + index: 0, + message: { + role: "assistant", + content: null, + tool_calls: [ + { + id: "call_test", + type: "function", + function: { + name: "extract_sentences", + arguments: JSON.stringify({ + sentences: [ + { sentence: "Hello world", reasons: ["Test reason"] }, + ], + }), + }, + }, + ], + }, + finish_reason: "tool_calls", + }, + ], + usage: { prompt_tokens: 10, completion_tokens: 20, total_tokens: 30 }, + }); + }), + ); + + init({ + client: new OpenAI({ + apiKey: "test-api-key", + baseURL: "https://api.openai.com/v1", + }), + }); + + const result = await ContextRelevancy({ + input: "What is hello?", + output: "Hello world", + context, + }); + + // Score should be len("Hello world") / len(context) = 11 / 79 ≈ 0.139 + const expectedScore = "Hello world".length / context.length; + expect(result.score).toBeCloseTo(expectedScore, 2); + expect(result.score).toBeLessThanOrEqual(1); + expect(result.score).toBeGreaterThanOrEqual(0); + }); +}); diff --git a/js/ragas.ts b/js/ragas.ts index d5a5285..41071ca 100644 --- a/js/ragas.ts +++ b/js/ragas.ts @@ -183,11 +183,14 @@ export const ContextRelevancy: ScorerWithPartial = }); const sentences = relevantSentencesSchema.parse(mustParseArgs(response)); + // Clamp score to [0, 1] - the LLM may return sentences longer than the + // original context due to paraphrasing or hallucination (#80) + const rawScore = + sentences.sentences.map((s) => s.sentence).join("").length / + context.length; return { name: "ContextRelevancy", - score: - sentences.sentences.map((s) => s.sentence).join("").length / - context.length, + score: Math.min(Math.max(rawScore, 0), 1), metadata: { relevantSentences: sentences.sentences, }, diff --git a/py/autoevals/ragas.py b/py/autoevals/ragas.py index 2e432fe..14ce16c 100644 --- a/py/autoevals/ragas.py +++ b/py/autoevals/ragas.py @@ -320,10 +320,12 @@ def __init__(self, pairwise_scorer=None, model=DEFAULT_RAGAS_MODEL, client: Clie def _postprocess(self, context, response): sentences = json.loads(response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"]) + # Clamp score to [0, 1] - the LLM may return sentences longer than the + # original context due to paraphrasing or hallucination (#80) + raw_score = len("".join([s["sentence"] for s in sentences["sentences"]])) / len(context) return Score( name=self._name(), - # Simplify this by just using the string length, rather than the number of sentences. - score=len("".join([s["sentence"] for s in sentences["sentences"]])) / len(context), + score=min(max(raw_score, 0), 1), metadata={ "relevant_sentences": sentences["sentences"], }, diff --git a/py/autoevals/test_ragas.py b/py/autoevals/test_ragas.py index a224a4d..556b3ae 100644 --- a/py/autoevals/test_ragas.py +++ b/py/autoevals/test_ragas.py @@ -1,4 +1,5 @@ import asyncio +import json import pytest @@ -44,3 +45,77 @@ def test_ragas_retrieval(metric: OpenAILLMScorer, expected_score: float, is_asyn pytest.xfail(f"Expected score {expected_score} but got {score}") else: raise e + + +def test_context_relevancy_score_clamping(): + """Test that ContextRelevancy clamps scores to [0, 1] range (#80). + + When the LLM returns sentences longer than the original context + (due to paraphrasing or hallucination), the raw score would exceed 1.0. + This test verifies the score is properly clamped. + """ + scorer = ContextRelevancy() + + # Short context + context = "Hello world" + + # Mock response where extracted sentences are LONGER than the context + # This would produce a raw score > 1.0 without clamping + mock_response = { + "choices": [ + { + "message": { + "tool_calls": [ + { + "function": { + "arguments": json.dumps( + { + "sentences": [ + { + "sentence": "Hello world, this is a much longer sentence than the original context" + } + ] + } + ) + } + } + ] + } + } + ] + } + + result = scorer._postprocess(context, mock_response) + + # Score should be clamped to 1.0, not exceed it + assert result.score == 1.0 + assert result.score <= 1.0 + assert result.score >= 0.0 + + +def test_context_relevancy_score_normal_case(): + """Test that ContextRelevancy returns expected score for normal case.""" + scorer = ContextRelevancy() + + context = "Hello world, this is a test context with some content." + + # Mock response where extracted sentences are shorter than the context + mock_response = { + "choices": [ + { + "message": { + "tool_calls": [ + {"function": {"arguments": json.dumps({"sentences": [{"sentence": "Hello world"}]})}} + ] + } + } + ] + } + + result = scorer._postprocess(context, mock_response) + + # Score should be len("Hello world") / len(context) = 11 / 54 ≈ 0.204 + expected_score = len("Hello world") / len(context) + assert result.score == pytest.approx(expected_score, rel=1e-3) + assert result.score <= 1.0 + assert result.score >= 0.0