diff --git a/client/embedding_cache.go b/client/embedding_cache.go index 8689e2b..9fdaaec 100644 --- a/client/embedding_cache.go +++ b/client/embedding_cache.go @@ -45,6 +45,7 @@ func DefaultSemanticCacheConfig() SemanticCacheConfig { // semanticEntry holds a cached response keyed by its request embedding. type semanticEntry struct { vector []float32 + model string // embedding model that produced vector; gates cross-model reuse response *EyrieResponse createdAt time.Time @@ -213,6 +214,12 @@ func (sp *EmbeddingCachedProvider) lookup(vec []float32) (*EyrieResponse, bool) if now.Sub(e.createdAt) > sp.maxAge { continue } + // Never compare across embedding models: vectors from a different model + // live in an incompatible space, so the cosine score would be meaningless + // (and on a same-dimension model swap could serve a wrong response). + if e.model != sp.model { + continue + } sim := cosineSimilarity(vec, e.vector) if sim >= bestSim { bestSim = sim @@ -241,6 +248,7 @@ func (sp *EmbeddingCachedProvider) store(vec []float32, resp *EyrieResponse) { e := &semanticEntry{ vector: vec, + model: sp.model, response: copyResponse(resp), createdAt: time.Now(), } diff --git a/client/embedding_cache_test.go b/client/embedding_cache_test.go index ffaeca4..ed8767f 100644 --- a/client/embedding_cache_test.go +++ b/client/embedding_cache_test.go @@ -104,6 +104,44 @@ func (errEmbedder) CreateEmbedding(_ context.Context, _ EmbeddingRequest) (*Embe return nil, errEmptyEmbedding } +// TestEmbeddingCache_ModelIsolation pins fix C: an entry stored under one +// embedding model must NOT be served to a request embedded by a different model, +// even when the vectors are identical — they live in incompatible spaces. +func TestEmbeddingCache_ModelIsolation(t *testing.T) { + mock := NewMockProvider(MockModeEcho) + cfg := DefaultSemanticCacheConfig() + cfg.EmbeddingModel = "model-A" + sp := NewEmbeddingCachedProvider(mock, stubEmbedder{}, cfg) + ctx := context.Background() + + // Warm the cache under model-A. + if _, err := sp.Chat(ctx, userMsg("what is the weather today"), ChatOptions{}); err != nil { + t.Fatalf("warm: %v", err) + } + if mock.CallCount() != 1 { + t.Fatalf("setup expected 1 inner call, got %d", mock.CallCount()) + } + + // Simulate the embedding model being swapped under the same cache. The prior + // entry (tagged model-A) must be skipped, forcing a fresh inner call rather + // than a cross-model false hit. + sp.model = "model-B" + if _, err := sp.Chat(ctx, userMsg("give me the weather please"), ChatOptions{}); err != nil { + t.Fatalf("post-swap: %v", err) + } + if mock.CallCount() != 2 { + t.Errorf("model swap must invalidate cross-model reuse; expected 2 inner calls, got %d", mock.CallCount()) + } + + // Requests under model-B should now cache and hit among themselves. + if _, err := sp.Chat(ctx, userMsg("weather report now"), ChatOptions{}); err != nil { + t.Fatalf("model-B hit: %v", err) + } + if mock.CallCount() != 2 { + t.Errorf("expected a within-model-B hit (still 2 inner calls), got %d", mock.CallCount()) + } +} + func TestCosineSimilarity(t *testing.T) { if got := cosineSimilarity([]float32{1, 0}, []float32{1, 0}); got < 0.999 { t.Errorf("identical vectors should be ~1, got %f", got) diff --git a/client/extract.go b/client/extract.go new file mode 100644 index 0000000..1ec8d31 --- /dev/null +++ b/client/extract.go @@ -0,0 +1,146 @@ +package client + +import ( + "context" + "encoding/json" + "fmt" + "strings" +) + +// Relationship is a subject-predicate-object triple extracted from text. It is +// the unit of a knowledge graph: subject and object are entities (nouns), and +// predicate is the typed relation between them. +// +// This mirrors the typed-extraction pattern popularized by CocoIndex's +// ExtractByLlm(output_type=list[Relationship]) — instead of free-form text, the +// model is constrained to emit a list of these triples, validated against a JSON +// schema with automatic retry. +type Relationship struct { + // Subject is the entity the relation originates from. A noun/noun phrase. + Subject string `json:"subject"` + // Predicate is the typed relation (e.g. "depends_on", "authored_by"). + Predicate string `json:"predicate"` + // Object is the entity the relation points to. A noun/noun phrase. + Object string `json:"object"` +} + +// ExtractOptions configures triple extraction. The zero value is valid and uses +// sensible defaults (noun-constrained entities, 2 validation retries). +type ExtractOptions struct { + // Chat carries provider/model/temperature for the extraction call. If Model + // is empty the provider default is used. + Chat ChatOptions + // Instruction overrides the default extraction instruction. Use it to scope + // what relations to extract (e.g. "extract only code dependency relations"). + // When empty, a general noun-constrained instruction is used. + Instruction string + // AllowedPredicates, when non-empty, constrains the predicate vocabulary — + // the model is told to use only these relation types, and triples with other + // predicates are dropped after extraction. This is the lightweight analogue + // of CocoIndex's EntityTypeConfig schema constraint. + AllowedPredicates []string + // MaxRetries is the schema-validation retry budget. Defaults to 2. + MaxRetries int +} + +// relationshipSchema is the JSON schema for a list of Relationship triples, +// passed to ChatWithStructuredOutput for validation + retry. +func relationshipSchema() map[string]interface{} { + return map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "relationships": map[string]interface{}{ + "type": "array", + "items": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "subject": map[string]interface{}{"type": "string"}, + "predicate": map[string]interface{}{"type": "string"}, + "object": map[string]interface{}{"type": "string"}, + }, + "required": []interface{}{"subject", "predicate", "object"}, + }, + }, + }, + "required": []interface{}{"relationships"}, + } +} + +const defaultExtractInstruction = `Extract the relationships expressed in the text as subject-predicate-object triples. + +Rules: +- subject and object MUST be nouns or noun phrases naming concrete entities (people, systems, concepts) — never verbs, sentences, or pronouns. +- predicate is a short typed relation in snake_case (e.g. depends_on, authored_by, part_of). +- Extract only relationships actually stated or clearly implied; do not invent facts. +- Deduplicate: emit each distinct triple at most once.` + +// ExtractRelationships extracts subject-predicate-object triples from text using +// schema-validated structured output with retry. It is a typed convenience layer +// over ChatWithStructuredOutput, modeled on CocoIndex's ExtractByLlm; yaad and +// other knowledge-graph consumers can call it instead of hand-rolling extraction +// prompts and JSON parsing. +func (c *EyrieClient) ExtractRelationships(ctx context.Context, text string, opts ExtractOptions) ([]Relationship, error) { + if strings.TrimSpace(text) == "" { + return nil, fmt.Errorf("eyrie: extract: text must not be empty") + } + + instruction := opts.Instruction + if instruction == "" { + instruction = defaultExtractInstruction + } + if len(opts.AllowedPredicates) > 0 { + instruction += "\n\nUse ONLY these predicates: " + strings.Join(opts.AllowedPredicates, ", ") + "." + } + + maxRetries := opts.MaxRetries + if maxRetries <= 0 { + maxRetries = 2 + } + + messages := []EyrieMessage{ + {Role: "system", Content: instruction}, + {Role: "user", Content: text}, + } + + schema := relationshipSchema() + resp, err := c.ChatWithStructuredOutput(ctx, messages, opts.Chat, SchemaValidation{ + Schema: schema, + MaxRetries: maxRetries, + }) + if err != nil { + return nil, fmt.Errorf("eyrie: extract relationships: %w", err) + } + + var parsed struct { + Relationships []Relationship `json:"relationships"` + } + if err := json.Unmarshal([]byte(resp.Content), &parsed); err != nil { + return nil, fmt.Errorf("eyrie: extract relationships: decode: %w", err) + } + + return filterRelationships(parsed.Relationships, opts.AllowedPredicates), nil +} + +// filterRelationships drops empty triples and, when a predicate allowlist is +// given, any triple whose predicate is not in it. Predicate matching is +// case-insensitive to tolerate model casing drift. +func filterRelationships(rels []Relationship, allowed []string) []Relationship { + allowSet := make(map[string]struct{}, len(allowed)) + for _, p := range allowed { + allowSet[strings.ToLower(strings.TrimSpace(p))] = struct{}{} + } + + out := make([]Relationship, 0, len(rels)) + for _, r := range rels { + if strings.TrimSpace(r.Subject) == "" || strings.TrimSpace(r.Predicate) == "" || strings.TrimSpace(r.Object) == "" { + continue + } + if len(allowSet) > 0 { + if _, ok := allowSet[strings.ToLower(strings.TrimSpace(r.Predicate))]; !ok { + continue + } + } + out = append(out, r) + } + return out +} diff --git a/client/extract_test.go b/client/extract_test.go new file mode 100644 index 0000000..26e7fef --- /dev/null +++ b/client/extract_test.go @@ -0,0 +1,59 @@ +package client + +import ( + "context" + "testing" +) + +func TestFilterRelationships_DropsEmpty(t *testing.T) { + in := []Relationship{ + {Subject: "eyrie", Predicate: "part_of", Object: "hawk-eco"}, + {Subject: "", Predicate: "x", Object: "y"}, // empty subject + {Subject: "a", Predicate: " ", Object: "b"}, // blank predicate + {Subject: "c", Predicate: "rel", Object: ""}, // empty object + } + out := filterRelationships(in, nil) + if len(out) != 1 || out[0].Subject != "eyrie" { + t.Fatalf("expected only the complete triple, got %#v", out) + } +} + +func TestFilterRelationships_PredicateAllowlist(t *testing.T) { + in := []Relationship{ + {Subject: "a", Predicate: "depends_on", Object: "b"}, + {Subject: "a", Predicate: "Depends_On", Object: "c"}, // case-insensitive match + {Subject: "a", Predicate: "authored_by", Object: "d"}, // not allowed + } + out := filterRelationships(in, []string{"depends_on"}) + if len(out) != 2 { + t.Fatalf("expected 2 depends_on triples (case-insensitive), got %d: %#v", len(out), out) + } +} + +func TestExtractRelationships_EmptyText(t *testing.T) { + c := &EyrieClient{} + _, err := c.ExtractRelationships(context.Background(), " ", ExtractOptions{}) + if err == nil { + t.Fatal("expected error for empty text") + } +} + +func TestRelationshipSchema_Shape(t *testing.T) { + s := relationshipSchema() + props, ok := s["properties"].(map[string]interface{}) + if !ok { + t.Fatal("schema missing properties") + } + rels, ok := props["relationships"].(map[string]interface{}) + if !ok || rels["type"] != "array" { + t.Fatalf("relationships should be an array, got %#v", rels) + } + items, ok := rels["items"].(map[string]interface{}) + if !ok { + t.Fatal("array missing items schema") + } + req, ok := items["required"].([]interface{}) + if !ok || len(req) != 3 { + t.Fatalf("each triple should require subject/predicate/object, got %#v", req) + } +}