Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions client/embedding_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func DefaultSemanticCacheConfig() SemanticCacheConfig {
// semanticEntry holds a cached response keyed by its request embedding.
type semanticEntry struct {
vector []float32
model string // embedding model that produced vector; gates cross-model reuse
response *EyrieResponse
createdAt time.Time

Expand Down Expand Up @@ -213,6 +214,12 @@ func (sp *EmbeddingCachedProvider) lookup(vec []float32) (*EyrieResponse, bool)
if now.Sub(e.createdAt) > sp.maxAge {
continue
}
// Never compare across embedding models: vectors from a different model
// live in an incompatible space, so the cosine score would be meaningless
// (and on a same-dimension model swap could serve a wrong response).
if e.model != sp.model {
continue
}
sim := cosineSimilarity(vec, e.vector)
if sim >= bestSim {
bestSim = sim
Expand Down Expand Up @@ -241,6 +248,7 @@ func (sp *EmbeddingCachedProvider) store(vec []float32, resp *EyrieResponse) {

e := &semanticEntry{
vector: vec,
model: sp.model,
response: copyResponse(resp),
createdAt: time.Now(),
}
Expand Down
38 changes: 38 additions & 0 deletions client/embedding_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,44 @@ func (errEmbedder) CreateEmbedding(_ context.Context, _ EmbeddingRequest) (*Embe
return nil, errEmptyEmbedding
}

// TestEmbeddingCache_ModelIsolation is the regression test for fix C: a cache
// entry written under one embedding model must never be returned for a request
// embedded by another model — even if the vectors happen to be identical —
// because the two models' vector spaces are not comparable.
func TestEmbeddingCache_ModelIsolation(t *testing.T) {
	inner := NewMockProvider(MockModeEcho)
	cacheCfg := DefaultSemanticCacheConfig()
	cacheCfg.EmbeddingModel = "model-A"
	sp := NewEmbeddingCachedProvider(inner, stubEmbedder{}, cacheCfg)
	ctx := context.Background()

	// ask pushes a single user prompt through the cached provider and reports
	// only the error; the response itself is irrelevant to this test.
	ask := func(prompt string) error {
		_, err := sp.Chat(ctx, userMsg(prompt), ChatOptions{})
		return err
	}

	// Step 1: populate the cache while the provider is on model-A.
	if err := ask("what is the weather today"); err != nil {
		t.Fatalf("warm: %v", err)
	}
	if calls := inner.CallCount(); calls != 1 {
		t.Fatalf("setup expected 1 inner call, got %d", calls)
	}

	// Step 2: swap the embedding model underneath the same cache. The entry
	// tagged model-A has to be ignored, so the inner provider is called again
	// instead of a cross-model false hit being served.
	sp.model = "model-B"
	if err := ask("give me the weather please"); err != nil {
		t.Fatalf("post-swap: %v", err)
	}
	if calls := inner.CallCount(); calls != 2 {
		t.Errorf("model swap must invalidate cross-model reuse; expected 2 inner calls, got %d", calls)
	}

	// Step 3: a further model-B request should be answered from the model-B
	// entry stored in step 2, leaving the inner call count untouched.
	if err := ask("weather report now"); err != nil {
		t.Fatalf("model-B hit: %v", err)
	}
	if calls := inner.CallCount(); calls != 2 {
		t.Errorf("expected a within-model-B hit (still 2 inner calls), got %d", calls)
	}
}

func TestCosineSimilarity(t *testing.T) {
if got := cosineSimilarity([]float32{1, 0}, []float32{1, 0}); got < 0.999 {
t.Errorf("identical vectors should be ~1, got %f", got)
Expand Down
146 changes: 146 additions & 0 deletions client/extract.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package client

import (
"context"
"encoding/json"
"fmt"
"strings"
)

// Relationship is one edge of a knowledge graph, expressed as a
// subject-predicate-object triple pulled out of text. Subject and Object name
// entities (nouns); Predicate names the typed relation that links them.
//
// The shape follows the typed-extraction pattern popularized by CocoIndex's
// ExtractByLlm(output_type=list[Relationship]): rather than returning free-form
// text, the model is constrained to emit a list of these triples, which is
// validated against a JSON schema with automatic retry.
type Relationship struct {
	// Subject is the source entity of the relation: a noun or noun phrase.
	Subject string `json:"subject"`

	// Predicate is the relation type, e.g. "depends_on" or "authored_by".
	Predicate string `json:"predicate"`

	// Object is the target entity of the relation: a noun or noun phrase.
	Object string `json:"object"`
}

// ExtractOptions tunes a triple-extraction call. Its zero value is usable as-is:
// the general noun-constrained instruction is applied and schema validation is
// retried up to 2 times.
type ExtractOptions struct {
	// Chat holds the provider/model/temperature settings for the extraction
	// call and is forwarded to the structured-output chat request. An empty
	// Model means the provider default is used.
	Chat ChatOptions

	// Instruction replaces the built-in extraction instruction, which is handy
	// for narrowing the kind of relations wanted (e.g. "extract only code
	// dependency relations"). Leave it empty to get the general
	// noun-constrained instruction.
	Instruction string

	// AllowedPredicates restricts the predicate vocabulary when non-empty: the
	// model is instructed to use only these relation types, and any returned
	// triple with a different predicate is discarded after extraction. It is
	// the lightweight counterpart of CocoIndex's EntityTypeConfig schema
	// constraint.
	AllowedPredicates []string

	// MaxRetries is the retry budget for schema validation. Values <= 0 fall
	// back to the default of 2.
	MaxRetries int
}

// relationshipSchema builds the JSON schema describing the expected model
// output: an object whose required "relationships" property is an array of
// triples, each requiring string "subject", "predicate" and "object" fields.
// The result is handed to ChatWithStructuredOutput for validation + retry.
// Every call returns a freshly allocated map tree.
func relationshipSchema() map[string]interface{} {
	// stringField yields a new {"type": "string"} node on each call so the
	// three triple fields never share one underlying map.
	stringField := func() map[string]interface{} {
		return map[string]interface{}{"type": "string"}
	}

	triple := map[string]interface{}{
		"type": "object",
		"properties": map[string]interface{}{
			"subject":   stringField(),
			"predicate": stringField(),
			"object":    stringField(),
		},
		"required": []interface{}{"subject", "predicate", "object"},
	}

	tripleList := map[string]interface{}{
		"type":  "array",
		"items": triple,
	}

	root := map[string]interface{}{
		"type":       "object",
		"properties": map[string]interface{}{"relationships": tripleList},
		"required":   []interface{}{"relationships"},
	}
	return root
}

// defaultExtractInstruction is the system prompt ExtractRelationships sends when
// ExtractOptions.Instruction is empty. It asks for subject-predicate-object
// triples with noun-only entities, snake_case predicates, no invented facts, and
// no duplicates. The text is sent to the model verbatim, so edits here change
// extraction behavior.
const defaultExtractInstruction = `Extract the relationships expressed in the text as subject-predicate-object triples.

Rules:
- subject and object MUST be nouns or noun phrases naming concrete entities (people, systems, concepts) — never verbs, sentences, or pronouns.
- predicate is a short typed relation in snake_case (e.g. depends_on, authored_by, part_of).
- Extract only relationships actually stated or clearly implied; do not invent facts.
- Deduplicate: emit each distinct triple at most once.`

// ExtractRelationships pulls subject-predicate-object triples out of text via
// schema-validated structured output with retry. It is a typed convenience
// wrapper around ChatWithStructuredOutput, modeled on CocoIndex's ExtractByLlm,
// so that yaad and other knowledge-graph consumers do not have to hand-roll
// extraction prompts and JSON parsing.
//
// The system prompt is opts.Instruction (or the package default when empty),
// followed by a "use only these predicates" clause when opts.AllowedPredicates
// is non-empty. opts.MaxRetries <= 0 is treated as 2.
//
// It returns an error when text is blank, when the structured-output call
// fails (wrapped with %w), or when the response content cannot be decoded into
// the expected {"relationships": [...]} shape. On success the decoded triples
// are passed through filterRelationships before being returned.
func (c *EyrieClient) ExtractRelationships(ctx context.Context, text string, opts ExtractOptions) ([]Relationship, error) {
	if len(strings.TrimSpace(text)) == 0 {
		return nil, fmt.Errorf("eyrie: extract: text must not be empty")
	}

	// Assemble the system prompt: caller override or default, then the
	// optional predicate-vocabulary clause.
	var prompt strings.Builder
	if opts.Instruction != "" {
		prompt.WriteString(opts.Instruction)
	} else {
		prompt.WriteString(defaultExtractInstruction)
	}
	if len(opts.AllowedPredicates) != 0 {
		prompt.WriteString("\n\nUse ONLY these predicates: ")
		prompt.WriteString(strings.Join(opts.AllowedPredicates, ", "))
		prompt.WriteString(".")
	}

	// Start from the default retry budget and only override it with a
	// positive caller-supplied value.
	validation := SchemaValidation{
		Schema:     relationshipSchema(),
		MaxRetries: 2,
	}
	if opts.MaxRetries > 0 {
		validation.MaxRetries = opts.MaxRetries
	}

	conversation := []EyrieMessage{
		{Role: "system", Content: prompt.String()},
		{Role: "user", Content: text},
	}

	resp, err := c.ChatWithStructuredOutput(ctx, conversation, opts.Chat, validation)
	if err != nil {
		return nil, fmt.Errorf("eyrie: extract relationships: %w", err)
	}

	// The schema wraps the triple list in a top-level object; unwrap it here.
	var payload struct {
		Relationships []Relationship `json:"relationships"`
	}
	if decodeErr := json.Unmarshal([]byte(resp.Content), &payload); decodeErr != nil {
		return nil, fmt.Errorf("eyrie: extract relationships: decode: %w", decodeErr)
	}

	return filterRelationships(payload.Relationships, opts.AllowedPredicates), nil
}

// filterRelationships returns the triples from rels that are complete and, when
// a predicate allowlist is in effect, whose predicate is on it.
//
// A triple is dropped when its subject, predicate, or object is empty or
// whitespace-only. Predicate matching against allowed is case-insensitive and
// ignores surrounding whitespace, to tolerate model casing drift.
//
// Blank entries in allowed are ignored. If allowed contains no non-blank entry
// (nil, empty, or e.g. the [""] that strings.Split("", ",") produces), no
// predicate constraint is applied — a blank entry can never match a real
// predicate, so honoring it would silently discard every triple.
//
// Kept triples are returned unmodified and in input order. The result is always
// a non-nil slice.
func filterRelationships(rels []Relationship, allowed []string) []Relationship {
	allowSet := make(map[string]struct{}, len(allowed))
	for _, p := range allowed {
		key := strings.ToLower(strings.TrimSpace(p))
		if key == "" {
			// Skip blanks so they neither enable filtering nor act as a key.
			continue
		}
		allowSet[key] = struct{}{}
	}

	out := make([]Relationship, 0, len(rels))
	for _, r := range rels {
		pred := strings.TrimSpace(r.Predicate)
		if strings.TrimSpace(r.Subject) == "" || pred == "" || strings.TrimSpace(r.Object) == "" {
			continue
		}
		if len(allowSet) > 0 {
			if _, ok := allowSet[strings.ToLower(pred)]; !ok {
				continue
			}
		}
		out = append(out, r)
	}
	return out
}
59 changes: 59 additions & 0 deletions client/extract_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package client

import (
"context"
"testing"
)

// TestFilterRelationships_DropsEmpty checks that, with no allowlist, only the
// triple whose subject, predicate, and object are all non-blank survives.
func TestFilterRelationships_DropsEmpty(t *testing.T) {
	triples := []Relationship{
		{Subject: "eyrie", Predicate: "part_of", Object: "hawk-eco"}, // complete: kept
		{Subject: "", Predicate: "x", Object: "y"},                   // empty subject
		{Subject: "a", Predicate: " ", Object: "b"},                  // blank predicate
		{Subject: "c", Predicate: "rel", Object: ""},                 // empty object
	}

	kept := filterRelationships(triples, nil)

	onlyComplete := len(kept) == 1 && kept[0].Subject == "eyrie"
	if !onlyComplete {
		t.Fatalf("expected only the complete triple, got %#v", kept)
	}
}

// TestFilterRelationships_PredicateAllowlist checks that an allowlist keeps
// matching predicates regardless of case and drops everything else.
func TestFilterRelationships_PredicateAllowlist(t *testing.T) {
	allow := []string{"depends_on"}
	triples := []Relationship{
		{Subject: "a", Predicate: "depends_on", Object: "b"},  // exact match
		{Subject: "a", Predicate: "Depends_On", Object: "c"},  // case-insensitive match
		{Subject: "a", Predicate: "authored_by", Object: "d"}, // not allowed
	}

	kept := filterRelationships(triples, allow)

	if n := len(kept); n != 2 {
		t.Fatalf("expected 2 depends_on triples (case-insensitive), got %d: %#v", n, kept)
	}
}

// TestExtractRelationships_EmptyText checks that whitespace-only input is
// rejected with an error. A zero-value client suffices because the text check
// runs before any provider call.
func TestExtractRelationships_EmptyText(t *testing.T) {
	client := new(EyrieClient)
	if _, err := client.ExtractRelationships(context.Background(), " ", ExtractOptions{}); err == nil {
		t.Fatal("expected error for empty text")
	}
}

// TestRelationshipSchema_Shape walks the schema from the root down to the
// per-triple "required" list, checking that "relationships" is an array whose
// items demand exactly three fields.
func TestRelationshipSchema_Shape(t *testing.T) {
	properties, hasProps := relationshipSchema()["properties"].(map[string]interface{})
	if !hasProps {
		t.Fatal("schema missing properties")
	}

	list, isMap := properties["relationships"].(map[string]interface{})
	if !(isMap && list["type"] == "array") {
		t.Fatalf("relationships should be an array, got %#v", list)
	}

	element, hasItems := list["items"].(map[string]interface{})
	if !hasItems {
		t.Fatal("array missing items schema")
	}

	required, isList := element["required"].([]interface{})
	if !(isList && len(required) == 3) {
		t.Fatalf("each triple should require subject/predicate/object, got %#v", required)
	}
}
Loading