diff --git a/go.mod b/go.mod index e9497ee..97a1105 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,9 @@ module go.kenn.io/kit go 1.26.3 require ( + github.com/asg017/sqlite-vec-go-bindings v0.1.6 github.com/gofrs/flock v0.13.0 + github.com/mattn/go-sqlite3 v1.14.44 github.com/posthog/posthog-go v1.12.6 github.com/stretchr/testify v1.11.1 go.opentelemetry.io/otel v1.43.0 @@ -14,6 +16,7 @@ require ( golang.org/x/sys v0.44.0 golang.org/x/term v0.43.0 golang.org/x/tools v0.45.0 + modernc.org/sqlite v1.53.0 ) require ( @@ -21,6 +24,7 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dnephin/pflag v1.0.7 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect github.com/fatih/color v1.18.0 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/go-logr/logr v1.4.3 // indirect @@ -32,7 +36,9 @@ require ( github.com/klauspost/compress v1.18.6 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/otel/sdk v1.43.0 // indirect go.opentelemetry.io/otel/trace v1.43.0 // indirect @@ -40,6 +46,9 @@ require ( golang.org/x/text v0.17.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect gotest.tools/gotestsum v1.13.0 // indirect + modernc.org/libc v1.73.4 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect ) tool ( diff --git a/go.sum b/go.sum index f82d09d..ef6eaa8 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/asg017/sqlite-vec-go-bindings v0.1.6 h1:Nx0jAzyS38XpkKznJ9xQjFXz2X9tI7KqjwVxV8RNoww= +github.com/asg017/sqlite-vec-go-bindings v0.1.6/go.mod h1:A8+cTt/nKFsYCQF6OgzSNpKZrzNo5gQsXBTfsXHXY0Q= github.com/bitfield/gotestdox v0.2.2 h1:x6RcPAbBbErKLnapz1QeAlf3ospg8efBsedU93CDsnE= github.com/bitfield/gotestdox v0.2.2/go.mod h1:D+gwtS0urjBrzguAkTM2wodsTQYFHdpx8eqRJ3N+9pY= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= @@ -6,6 +8,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dnephin/pflag v1.0.7 h1:oxONGlWxhmUct0YzKTgrpQv9AUA1wtPBn7zuSjJqptk= github.com/dnephin/pflag v1.0.7/go.mod h1:uxE91IoWURlOiTUIA8Mq5ZZkAv3dPUfZNaT80Zm7OQE= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= @@ -21,6 +25,8 @@ github.com/gofrs/flock v0.13.0 h1:95JolYOvGMqeH31+FC7D2+uULf6mG61mEZ/A8dRYMzw= github.com/gofrs/flock v0.13.0/go.mod h1:jxeyy9R1auM5S6JYDBhDt+E2TCo7DkratH4Pgi8P+Z0= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -38,10 +44,16 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.44 h1:3VSe+xafpbzsLbdr2AWlAZk9yRHiBhTBakioXaCKTF8= +github.com/mattn/go-sqlite3 v1.14.44/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/posthog/posthog-go v1.12.6 h1:N+FrKWY6DOuDhV2OMgvtKAKDYGTdtS9/nuvr0BTyBp0= github.com/posthog/posthog-go v1.12.6/go.mod h1:xsVOW9YImilUcazwPNEq4PJDqEZf2KeCS758zXjwkPg= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= @@ -87,3 +99,31 @@ gotest.tools/gotestsum v1.13.0 h1:+Lh454O9mu9AMG1APV4o0y7oDYKyik/3kBOiCqiEpRo= gotest.tools/gotestsum v1.13.0/go.mod h1:7f0NS5hFb0dWr4NtcsAsF0y1kzjEFfAil0HiBQJE03Q= gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA= +modernc.org/cc/v4 v4.28.4 h1:Hd/4Es+MBj+/7hSdZaisNyu6bv3V0Dp2MdllyfqaH+c= +modernc.org/cc/v4 v4.28.4/go.mod h1:OnovgIhbbMXMu1aISnJ0wvVD1KnW+cAUJkIrAWh+kVI= +modernc.org/ccgo/v4 v4.34.4 h1:OVnSOWQjVKOYkFxoHYB+qQmSHK5gqMqARM+K9DpR/Ws= +modernc.org/ccgo/v4 v4.34.4/go.mod h1:qdKqE8FNIYyysougB1RX9MxCzp5oJOcQXSobANJ4TuE= +modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM= +modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.3 h1:6QAplYyVO+KdPW3pGnqmJDUxtkec8ooEWvks/hhU3lc= +modernc.org/gc/v3 v3.1.3/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= +modernc.org/libc v1.73.4 h1:+ra4Ui8ngyt8HDcO1FTDPWlkAh6yOdaO2yAoh8MddQA= +modernc.org/libc v1.73.4/go.mod h1:DXZ3eO8qMCNn2SnmTNCiC71nJ9Rcq3PsnpU6Vc4rWK8= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.2.0 h1:tGyef5ApycA7FSEOMraay9SaTk5zmbx7Tu+cJs4QKZg= +modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.53.0 h1:20WG8N9q4ji/dEqGk4uiI0c6OPjSeLTNYGFCc3+7c1M= +modernc.org/sqlite v1.53.0/go.mod h1:xoEpOIpGrgT48H5iiyt/YXPCZPEzlfmfFwtk8Lklw8s= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/vector/AGENTS.md b/vector/AGENTS.md new file mode 100644 index 0000000..6c86733 --- /dev/null +++ b/vector/AGENTS.md @@ -0,0 +1,41 @@ +# vector package invariants + +`go.kenn.io/kit/vector` owns the backend-neutral parts of an embedding +pipeline. Preserve these invariants when changing it. + +## The storage boundary is the point of this package + +- The core `vector` package must not import `database/sql`, a driver, or + any backend client, and must not construct backend SQL. The `Fill` and + `Search` flows reach storage only through the `Store[K, G]` interface. +- Persistence is a function of the caller's source system. Backends live + in their own subpackages (e.g. `vector/sqlitevec`) so a caller wiring + one backend never pulls another backend's driver. New backends + (pgvector, duckdb) go in sibling subpackages, not into the core. +- Backends own query construction. The differences between sqlite-vec + `vec0 MATCH`, pgvector `<=>`, and duckdb `array_distance` belong behind + `QueryGeneration`, never in the core flows. + +## Keys and generations are opaque + +- Document identity is the caller's type `K` and generation identity its + type `G`. msgvault uses `int64`; kata uses UUIDs. Compare them for + equality only; never assume a type, a single id namespace, or an + ordering. Backends additionally require `K`/`G` to be types + `database/sql` can bind and scan. + +## Merge semantics + +- `Merge` takes per-generation lists in descending preference and keeps + the earliest list's hit on overlap (prefer the newer generation during + a migration). Coverage is a union — never drop a document that only one + generation covers, and never emit duplicates. +- Cross-generation scores are not comparable. Default to + `MergeNormalizedScore`; raw-score merging is opt-in. + +## Generations during migration + +- The mid-migration union exists because new documents land only in the + building generation while the active generation still serves the bulk. + `Search` must keep querying every generation `LiveGenerations` returns, + in the order it returns them. diff --git a/vector/chunk.go b/vector/chunk.go new file mode 100644 index 0000000..78239dd --- /dev/null +++ b/vector/chunk.go @@ -0,0 +1,48 @@ +package vector + +// Chunk is a window of text encoded as a single vector. Index is the +// chunk's position within the source content, starting at zero. +type Chunk struct { + Index int + Text string +} + +// SplitOptions controls how Split windows content into chunks. +type SplitOptions struct { + // MaxRunes bounds the number of runes in each chunk. Values <= 0 + // disable splitting and return the content as a single chunk. + MaxRunes int + // Overlap is the number of runes shared between consecutive chunks. + // It is clamped to the range [0, MaxRunes-1]. + Overlap int +} + +// Split windows content into overlapping chunks of at most MaxRunes runes. +// It splits on runes rather than bytes so multi-byte characters are never +// torn apart. Empty content yields no chunks. +// +// Split measures size in runes, not model tokens. Callers that budget by +// tokens should convert their token budget to an approximate rune count. +func Split(content string, o SplitOptions) []Chunk { + if content == "" { + return nil + } + runes := []rune(content) + if o.MaxRunes <= 0 || len(runes) <= o.MaxRunes { + return []Chunk{{Index: 0, Text: content}} + } + + overlap := min(max(o.Overlap, 0), o.MaxRunes-1) + stride := o.MaxRunes - overlap + + var chunks []Chunk + for start, idx := 0, 0; start < len(runes); start, idx = start+stride, idx+1 { + end := start + o.MaxRunes + if end >= len(runes) { + chunks = append(chunks, Chunk{Index: idx, Text: string(runes[start:])}) + break + } + chunks = append(chunks, Chunk{Index: idx, Text: string(runes[start:end])}) + } + return chunks +} diff --git a/vector/chunk_test.go b/vector/chunk_test.go new file mode 100644 index 0000000..e2af4f9 --- /dev/null +++ b/vector/chunk_test.go @@ -0,0 +1,84 @@ +package vector_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "go.kenn.io/kit/vector" +) + +func TestSplit(t *testing.T) { + tests := []struct { + name string + content string + opts vector.SplitOptions + want []vector.Chunk + }{ + { + name: "empty yields no chunks", + content: "", + opts: vector.SplitOptions{MaxRunes: 4}, + want: nil, + }, + { + name: "non-positive max returns single chunk", + content: "hello world", + opts: vector.SplitOptions{MaxRunes: 0}, + want: []vector.Chunk{{Index: 0, Text: "hello world"}}, + }, + { + name: "content shorter than max is one chunk", + content: "abcd", + opts: vector.SplitOptions{MaxRunes: 8}, + want: []vector.Chunk{{Index: 0, Text: "abcd"}}, + }, + { + name: "windows without overlap", + content: "abcdefghij", + opts: vector.SplitOptions{MaxRunes: 5}, + want: []vector.Chunk{ + {Index: 0, Text: "abcde"}, + {Index: 1, Text: "fghij"}, + }, + }, + { + name: "windows with overlap", + content: "abcdefghij", + opts: vector.SplitOptions{MaxRunes: 4, Overlap: 1}, + want: []vector.Chunk{ + {Index: 0, Text: "abcd"}, + {Index: 1, Text: "defg"}, + {Index: 2, Text: "ghij"}, + }, + }, + { + name: "overlap at or above max clamps to max-1", + content: "abcdef", + opts: vector.SplitOptions{MaxRunes: 3, Overlap: 9}, + want: []vector.Chunk{ + {Index: 0, Text: "abc"}, + {Index: 1, Text: "bcd"}, + {Index: 2, Text: "cde"}, + {Index: 3, Text: "def"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, vector.Split(tt.content, tt.opts)) + }) + } +} + +func TestSplitDoesNotTearMultiByteRunes(t *testing.T) { + assert := assert.New(t) + // Each emoji is multiple bytes but one rune. + chunks := vector.Split("😀😁😂🤣", vector.SplitOptions{MaxRunes: 2}) + + assert.Equal([]vector.Chunk{ + {Index: 0, Text: "😀😁"}, + {Index: 1, Text: "😂🤣"}, + }, chunks) +} diff --git a/vector/doc.go b/vector/doc.go new file mode 100644 index 0000000..6d96ae3 --- /dev/null +++ b/vector/doc.go @@ -0,0 +1,24 @@ +// Package vector provides backend-neutral building blocks for embedding +// content and searching the resulting vectors. +// +// It is organized in three layers: +// +// - Transforms and value types: Split windows content into chunks, +// Generation identifies an embedding model, EncodeBatched batches +// encode calls, and RollupByDocument and Merge reduce and combine +// search results across generations. These are pure functions. +// +// - The Store contract: Store[K, G] is the persistence interface the +// flows depend on. Implementations are a function of the caller's +// source system and own all backend SQL and query construction; see +// the sqlitevec subpackage for a worked example. +// +// - Flows: Fill runs the scan-and-fill embedding loop and Search runs +// the cross-generation query-and-merge, both over a Store. +// +// Nothing in this package opens a database, holds an index, or constructs +// backend SQL — the flows delegate every storage operation to the Store. +// Document identity is the caller's own key type K, and generation +// identity its type G; the package compares both for equality but never +// interprets them. +package vector diff --git a/vector/encode.go b/vector/encode.go new file mode 100644 index 0000000..59458bf --- /dev/null +++ b/vector/encode.go @@ -0,0 +1,111 @@ +package vector + +import ( + "context" + "fmt" + "sync" +) + +// Vector is a single embedding. +type Vector []float32 + +// EncodeFunc turns a batch of texts into one vector each, in the same +// order. Implementations own the model or API client and any retry or +// backoff policy, since retryability is provider-specific. +type EncodeFunc func(ctx context.Context, texts []string) ([][]float32, error) + +// BatchOptions controls how EncodeBatched groups and parallelizes calls. +type BatchOptions struct { + // BatchSize is the maximum number of chunks passed to EncodeFunc in a + // single call. Values <= 0 send every chunk in one call. + BatchSize int + // Concurrency bounds how many EncodeFunc calls run at once. Values + // <= 0 mean one call at a time. + Concurrency int +} + +// EncodeBatched splits chunks into batches, invokes enc with bounded +// concurrency, and returns one Vector per input chunk in input order. It +// stops launching work at the first error or when ctx is cancelled, and +// reports the first error encountered. +func EncodeBatched(ctx context.Context, enc EncodeFunc, chunks []Chunk, o BatchOptions) ([]Vector, error) { + if enc == nil { + return nil, fmt.Errorf("encode func is nil") + } + if len(chunks) == 0 { + return nil, nil + } + + batchSize := o.BatchSize + if batchSize <= 0 { + batchSize = len(chunks) + } + concurrency := o.Concurrency + if concurrency <= 0 { + concurrency = 1 + } + + out := make([]Vector, len(chunks)) + sem := make(chan struct{}, concurrency) + var ( + wg sync.WaitGroup + mu sync.Mutex + firstErr error + ) + failed := func() bool { + mu.Lock() + defer mu.Unlock() + return firstErr != nil + } + setErr := func(err error) { + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + } + + for start := 0; start < len(chunks); start += batchSize { + if ctx.Err() != nil { + setErr(ctx.Err()) + break + } + if failed() { + break + } + + end := min(start+batchSize, len(chunks)) + texts := make([]string, end-start) + for i, c := range chunks[start:end] { + texts[i] = c.Text + } + + sem <- struct{}{} + wg.Add(1) + go func(start int, texts []string) { + defer wg.Done() + defer func() { <-sem }() + + vecs, err := enc(ctx, texts) + if err != nil { + setErr(fmt.Errorf("encode batch at %d: %w", start, err)) + return + } + if len(vecs) != len(texts) { + setErr(fmt.Errorf("encode batch at %d: got %d vectors for %d texts", start, len(vecs), len(texts))) + return + } + // Each batch owns a disjoint index range, so writes to out + // never overlap across goroutines. + for i, v := range vecs { + out[start+i] = Vector(v) + } + }(start, texts) + } + + wg.Wait() + if firstErr != nil { + return nil, firstErr + } + return out, nil +} diff --git a/vector/encode_test.go b/vector/encode_test.go new file mode 100644 index 0000000..f783493 --- /dev/null +++ b/vector/encode_test.go @@ -0,0 +1,124 @@ +package vector_test + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "go.kenn.io/kit/vector" +) + +func chunks(texts ...string) []vector.Chunk { + out := make([]vector.Chunk, len(texts)) + for i, txt := range texts { + out[i] = vector.Chunk{Index: i, Text: txt} + } + return out +} + +// echoEncoder returns one vector per text whose single component encodes +// the text length, so results can be matched back to their input order. +func echoEncoder(record func(batch []string)) vector.EncodeFunc { + return func(_ context.Context, texts []string) ([][]float32, error) { + if record != nil { + record(texts) + } + out := make([][]float32, len(texts)) + for i, txt := range texts { + out[i] = []float32{float32(len(txt))} + } + return out, nil + } +} + +func TestEncodeBatchedPreservesOrderAcrossBatches(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + + var mu sync.Mutex + var sizes []int + enc := echoEncoder(func(batch []string) { + mu.Lock() + sizes = append(sizes, len(batch)) + mu.Unlock() + }) + + in := chunks("a", "bb", "ccc", "dddd", "eeeee") + out, err := vector.EncodeBatched(context.Background(), enc, in, vector.BatchOptions{BatchSize: 2, Concurrency: 3}) + require.NoError(err) + require.Len(out, len(in)) + for i, c := range in { + assert.Equal(float32(len(c.Text)), out[i][0], "vector %d matches its input", i) + } + + mu.Lock() + defer mu.Unlock() + assert.ElementsMatch([]int{2, 2, 1}, sizes, "batches are sized by BatchSize") +} + +func TestEncodeBatchedRespectsConcurrencyBound(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + + var inFlight, maxInFlight atomic.Int64 + enc := func(_ context.Context, texts []string) ([][]float32, error) { + cur := inFlight.Add(1) + for { + prev := maxInFlight.Load() + if cur <= prev || maxInFlight.CompareAndSwap(prev, cur) { + break + } + } + defer inFlight.Add(-1) + out := make([][]float32, len(texts)) + return out, nil + } + + in := chunks("a", "b", "c", "d", "e", "f", "g", "h") + _, err := vector.EncodeBatched(context.Background(), enc, in, vector.BatchOptions{BatchSize: 1, Concurrency: 2}) + require.NoError(err) + assert.LessOrEqual(maxInFlight.Load(), int64(2), "never exceeds the concurrency bound") +} + +func TestEncodeBatchedSurfacesEncodeError(t *testing.T) { + assert := assert.New(t) + sentinel := errors.New("boom") + enc := func(_ context.Context, _ []string) ([][]float32, error) { return nil, sentinel } + + _, err := vector.EncodeBatched(context.Background(), enc, chunks("a", "b"), vector.BatchOptions{BatchSize: 1}) + assert.ErrorIs(err, sentinel) +} + +func TestEncodeBatchedRejectsCountMismatch(t *testing.T) { + assert := assert.New(t) + enc := func(_ context.Context, _ []string) ([][]float32, error) { + return [][]float32{{1}}, nil // one vector for two texts + } + + _, err := vector.EncodeBatched(context.Background(), enc, chunks("a", "b"), vector.BatchOptions{}) + assert.ErrorContains(err, "vectors for") +} + +func TestEncodeBatchedNilEncoder(t *testing.T) { + _, err := vector.EncodeBatched(context.Background(), nil, chunks("a"), vector.BatchOptions{}) + assert.Error(t, err) +} + +func TestEncodeBatchedEmptyInput(t *testing.T) { + assert := assert.New(t) + out, err := vector.EncodeBatched(context.Background(), echoEncoder(nil), nil, vector.BatchOptions{}) + assert.NoError(err) + assert.Empty(out) +} + +func TestEncodeBatchedStopsOnCancelledContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := vector.EncodeBatched(ctx, echoEncoder(nil), chunks("a", "b"), vector.BatchOptions{BatchSize: 1}) + assert.ErrorIs(t, err, context.Canceled) +} diff --git a/vector/flow.go b/vector/flow.go new file mode 100644 index 0000000..65c73c9 --- /dev/null +++ b/vector/flow.go @@ -0,0 +1,115 @@ +package vector + +import ( + "context" + "fmt" +) + +// FillOptions configures Fill. +type FillOptions struct { + // ScanBatch is the number of pending documents fetched per scan. + // Values <= 0 use 128. + ScanBatch int + // Split controls how each document's content is windowed into chunks. + Split SplitOptions + // Batch controls how chunks are batched into encode calls. + Batch BatchOptions +} + +// FillStats reports what a Fill run embedded. +type FillStats struct { + Documents int + Chunks int +} + +// Fill embeds every document that still needs the target generation: it +// scans the store for pending documents, splits and encodes each, and +// saves the resulting vectors, repeating until no documents remain. It is +// the generic scan-and-fill loop; the store decides what counts as +// pending and persists the results. +func Fill[K, G comparable](ctx context.Context, store Store[K, G], gen G, enc EncodeFunc, o FillOptions) (FillStats, error) { + scanBatch := o.ScanBatch + if scanBatch <= 0 { + scanBatch = 128 + } + + var stats FillStats + for { + if err := ctx.Err(); err != nil { + return stats, err + } + pending, err := store.PendingForGeneration(ctx, gen, scanBatch) + if err != nil { + return stats, fmt.Errorf("scan pending: %w", err) + } + if len(pending) == 0 { + return stats, nil + } + + for _, p := range pending { + chunks := Split(p.Content, o.Split) + vectors, err := EncodeBatched(ctx, enc, chunks, o.Batch) + if err != nil { + return stats, fmt.Errorf("encode document %v: %w", p.Doc, err) + } + cvs := make([]ChunkVector, len(chunks)) + for i, c := range chunks { + cvs[i] = ChunkVector{ChunkIndex: c.Index, Vector: vectors[i]} + } + if err := store.SaveVectors(ctx, gen, p.Doc, cvs); err != nil { + return stats, fmt.Errorf("save document %v: %w", p.Doc, err) + } + stats.Documents++ + stats.Chunks += len(cvs) + } + } +} + +// SearchOptions configures Search. +type SearchOptions struct { + // PerGeneration caps how many hits are fetched from each generation + // before merging. Values <= 0 use 50. + PerGeneration int + // Merge configures how per-generation results are combined. + Merge MergeOptions +} + +// Search embeds queryText once per live generation (each may use a +// different model), queries each generation, rolls the chunk hits up to +// documents, and merges the per-generation results into one ranking. +// encFor maps a generation to the encoder for that generation's model. +func Search[K, G comparable]( + ctx context.Context, + store Store[K, G], + queryText string, + encFor func(gen G) EncodeFunc, + o SearchOptions, +) ([]Hit[K], error) { + perGen := o.PerGeneration + if perGen <= 0 { + perGen = 50 + } + + gens, err := store.LiveGenerations(ctx) + if err != nil { + return nil, fmt.Errorf("live generations: %w", err) + } + + lists := make([][]Hit[K], 0, len(gens)) + for _, gen := range gens { + enc := encFor(gen) + if enc == nil { + return nil, fmt.Errorf("no encoder for generation %v", gen) + } + vectors, err := EncodeBatched(ctx, enc, []Chunk{{Index: 0, Text: queryText}}, BatchOptions{}) + if err != nil { + return nil, fmt.Errorf("embed query for generation %v: %w", gen, err) + } + hits, err := store.QueryGeneration(ctx, gen, vectors[0], perGen) + if err != nil { + return nil, fmt.Errorf("query generation %v: %w", gen, err) + } + lists = append(lists, RollupByDocument(hits)) + } + return Merge(lists, o.Merge), nil +} diff --git a/vector/flow_test.go b/vector/flow_test.go new file mode 100644 index 0000000..a7135d5 --- /dev/null +++ b/vector/flow_test.go @@ -0,0 +1,180 @@ +package vector_test + +import ( + "context" + "math" + "slices" + "sort" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "go.kenn.io/kit/vector" +) + +// memStore is an in-memory Store[int64, int] used to exercise the flows +// without any real backend. Documents are keyed by int64; generations by +// int. QueryGeneration ranks by cosine similarity over stored vectors. +type memStore struct { + content map[int64]string + embedded map[int64]map[int]bool // doc -> gen -> done + vectors map[int]map[int64][]vector.ChunkVector // gen -> doc -> chunks + live []int // descending preference +} + +func newMemStore() *memStore { + return &memStore{ + content: map[int64]string{}, + embedded: map[int64]map[int]bool{}, + vectors: map[int]map[int64][]vector.ChunkVector{}, + } +} + +func (m *memStore) PendingForGeneration(_ context.Context, gen int, limit int) ([]vector.Pending[int64], error) { + keys := make([]int64, 0, len(m.content)) + for doc := range m.content { + if !m.embedded[doc][gen] { + keys = append(keys, doc) + } + } + slices.Sort(keys) + if limit > 0 && len(keys) > limit { + keys = keys[:limit] + } + out := make([]vector.Pending[int64], len(keys)) + for i, doc := range keys { + out[i] = vector.Pending[int64]{Doc: doc, Content: m.content[doc]} + } + return out, nil +} + +func (m *memStore) SaveVectors(_ context.Context, gen int, doc int64, vecs []vector.ChunkVector) error { + if m.vectors[gen] == nil { + m.vectors[gen] = map[int64][]vector.ChunkVector{} + } + m.vectors[gen][doc] = vecs + if m.embedded[doc] == nil { + m.embedded[doc] = map[int]bool{} + } + m.embedded[doc][gen] = true + return nil +} + +func (m *memStore) LiveGenerations(_ context.Context) ([]int, error) { + return m.live, nil +} + +func (m *memStore) QueryGeneration(_ context.Context, gen int, query vector.Vector, limit int) ([]vector.Hit[int64], error) { + var hits []vector.Hit[int64] + for doc, chunks := range m.vectors[gen] { + for _, cv := range chunks { + hits = append(hits, vector.Hit[int64]{Doc: doc, ChunkIndex: cv.ChunkIndex, Score: cosine(query, cv.Vector)}) + } + } + sort.SliceStable(hits, func(i, j int) bool { return hits[i].Score > hits[j].Score }) + if limit > 0 && len(hits) > limit { + hits = hits[:limit] + } + return hits, nil +} + +func cosine(a, b vector.Vector) float32 { + var dot, na, nb float64 + for i := range a { + dot += float64(a[i]) * float64(b[i]) + na += float64(a[i]) * float64(a[i]) + nb += float64(b[i]) * float64(b[i]) + } + if na == 0 || nb == 0 { + return 0 + } + return float32(dot / (math.Sqrt(na) * math.Sqrt(nb))) +} + +// lenEncoder embeds each text as a 1-D vector of its rune length, enough +// to confirm Fill wired chunk content through to SaveVectors. +func lenEncoder() vector.EncodeFunc { + return func(_ context.Context, texts []string) ([][]float32, error) { + out := make([][]float32, len(texts)) + for i, txt := range texts { + out[i] = []float32{float32(len([]rune(txt)))} + } + return out, nil + } +} + +func TestFillEmbedsAllPendingThenStops(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + ctx := context.Background() + + store := newMemStore() + store.content[1] = "alpha" + store.content[2] = "beta gamma delta" + + stats, err := vector.Fill(ctx, store, 7, lenEncoder(), vector.FillOptions{ + ScanBatch: 1, // force multiple scan rounds + Split: vector.SplitOptions{MaxRunes: 4, Overlap: 0}, + }) + require.NoError(err) + + assert.Equal(2, stats.Documents) + assert.True(store.embedded[1][7] && store.embedded[2][7], "both docs stamped for gen 7") + require.Len(store.vectors[7][1], 2, "alpha -> 2 chunks of <=4 runes") + assert.InDelta(4, store.vectors[7][1][0].Vector[0], 1e-6, "first chunk carries its rune length") + + // A second run finds nothing pending and embeds zero documents. + again, err := vector.Fill(ctx, store, 7, lenEncoder(), vector.FillOptions{}) + require.NoError(err) + assert.Equal(0, again.Documents) +} + +func TestSearchRollsUpAndPrefersBuildingGeneration(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + ctx := context.Background() + + const active, building = 7, 9 + store := newMemStore() + store.live = []int{building, active} // descending preference + + // Doc 1 is shared; active stored it at chunk 0, building at chunk 5. + store.SaveVectors(ctx, active, 1, []vector.ChunkVector{{ChunkIndex: 0, Vector: vector.Vector{1, 0}}}) + store.SaveVectors(ctx, active, 2, []vector.ChunkVector{{ChunkIndex: 0, Vector: vector.Vector{0, 1}}}) + store.SaveVectors(ctx, building, 1, []vector.ChunkVector{{ChunkIndex: 5, Vector: vector.Vector{1, 0}}}) + store.SaveVectors(ctx, building, 3, []vector.ChunkVector{{ChunkIndex: 0, Vector: vector.Vector{1, 0}}}) // new, building-only + + // Query vector [1,0] points at docs 1 and 3. + queryEnc := func(int) vector.EncodeFunc { + return func(_ context.Context, texts []string) ([][]float32, error) { + out := make([][]float32, len(texts)) + for i := range texts { + out[i] = []float32{1, 0} + } + return out, nil + } + } + + got, err := vector.Search(ctx, store, "q", queryEnc, vector.SearchOptions{}) + require.NoError(err) + + byDoc := map[int64]vector.Hit[int64]{} + for _, h := range got { + byDoc[h.Doc] = h + } + assert.Contains(byDoc, int64(1)) + assert.Contains(byDoc, int64(2), "active-only doc is not dropped (union coverage)") + assert.Contains(byDoc, int64(3), "building-only new doc is searchable mid-migration") + assert.Equal(5, byDoc[1].ChunkIndex, "shared doc keeps the building generation's hit") +} + +func TestSearchErrorsWhenNoEncoderForGeneration(t *testing.T) { + ctx := context.Background() + store := newMemStore() + store.live = []int{1} + store.SaveVectors(ctx, 1, 1, []vector.ChunkVector{{ChunkIndex: 0, Vector: vector.Vector{1}}}) + + _, err := vector.Search(ctx, store, "q", func(int) vector.EncodeFunc { return nil }, vector.SearchOptions{}) + assert.ErrorContains(t, err, "no encoder") +} diff --git a/vector/generation.go b/vector/generation.go new file mode 100644 index 0000000..50c896e --- /dev/null +++ b/vector/generation.go @@ -0,0 +1,58 @@ +package vector + +import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "encoding/json" +) + +// Generation identifies an embedding model configuration. Two pieces of +// content embedded under generations with the same Fingerprint share a +// vector space; a different fingerprint means the caller should treat the +// vectors as a new generation and re-embed. +// +// Every field that affects the vector space must be exported and JSON +// encodable so Fingerprint accounts for it automatically. A field that +// must not affect identity has to be tagged json:"-". +type Generation struct { + // Model names the embedding model, e.g. "text-embedding-3-small". + Model string `json:"model,omitempty"` + // Dimensions is the length of the vectors the model emits. + Dimensions int `json:"dimensions,omitempty"` + // Params holds any additional knobs that change the vector space, + // such as a pooling mode or prompt template. + Params map[string]string `json:"params,omitempty"` +} + +// Fingerprint returns a stable identifier derived from every field that +// affects the vector space. Callers persist it alongside stored vectors +// and compare it to decide whether a new generation is required. +// +// It is built to be stable across future changes to this type: +// +// - It encodes the struct itself, so a field added later participates +// automatically rather than being silently excluded — the failure +// mode that would let two distinct vector spaces share a fingerprint. +// - It then re-encodes through a generic value, and encoding/json sorts +// object keys at every level, so neither struct field order nor map +// insertion order affects the hash. +// - Decoding with UseNumber preserves numeric tokens exactly, so no +// field loses precision through float64. +// - omitempty drops zero-valued fields, so adding an unused field never +// shifts an existing generation's fingerprint. +// +// All values are JSON encodable, so the marshal and decode errors are +// unreachable. +func (g Generation) Fingerprint() string { + raw, _ := json.Marshal(g) + + dec := json.NewDecoder(bytes.NewReader(raw)) + dec.UseNumber() + var generic any + _ = dec.Decode(&generic) + + canonical, _ := json.Marshal(generic) + sum := sha256.Sum256(canonical) + return hex.EncodeToString(sum[:8]) +} diff --git a/vector/generation_test.go b/vector/generation_test.go new file mode 100644 index 0000000..75e71ac --- /dev/null +++ b/vector/generation_test.go @@ -0,0 +1,93 @@ +package vector_test + +import ( + "crypto/sha256" + "encoding/hex" + "reflect" + "sort" + "testing" + + "github.com/stretchr/testify/assert" + + "go.kenn.io/kit/vector" +) + +func TestGenerationFingerprintIsStableAndOrderIndependent(t *testing.T) { + assert := assert.New(t) + a := vector.Generation{ + Model: "text-embedding-3-small", + Dimensions: 1536, + Params: map[string]string{"pooling": "mean", "prompt": "search"}, + } + b := vector.Generation{ + Model: "text-embedding-3-small", + Dimensions: 1536, + Params: map[string]string{"prompt": "search", "pooling": "mean"}, + } + + assert.Equal(a.Fingerprint(), a.Fingerprint(), "same value fingerprints identically") + assert.Equal(a.Fingerprint(), b.Fingerprint(), "map order does not change fingerprint") +} + +func TestGenerationFingerprintIsNotAmbiguousAcrossParams(t *testing.T) { + assert := assert.New(t) + // Two params vs a single param whose value embeds what used to be the + // key/value separator. A naive "key=value\n" join hashes both the + // same; the JSON encoding keeps them distinct. + two := vector.Generation{Model: "m", Dimensions: 3, Params: map[string]string{"pooling": "mean", "prompt": "x"}} + one := vector.Generation{Model: "m", Dimensions: 3, Params: map[string]string{"pooling": "mean\nprompt=x"}} + + assert.NotEqual(two.Fingerprint(), one.Fingerprint()) +} + +func TestGenerationFingerprintChangesWithSpace(t *testing.T) { + assert := assert.New(t) + base := vector.Generation{Model: "m", Dimensions: 768, Params: map[string]string{"pooling": "mean"}} + + cases := map[string]vector.Generation{ + "model": {Model: "other", Dimensions: 768, Params: map[string]string{"pooling": "mean"}}, + "dimensions": {Model: "m", Dimensions: 1024, Params: map[string]string{"pooling": "mean"}}, + "param value": {Model: "m", Dimensions: 768, Params: map[string]string{"pooling": "cls"}}, + "extra param": {Model: "m", Dimensions: 768, Params: map[string]string{"pooling": "mean", "prompt": "x"}}, + } + for name, g := range cases { + t.Run(name, func(t *testing.T) { + assert.NotEqual(base.Fingerprint(), g.Fingerprint()) + }) + } +} + +// TestGenerationFingerprintPinsCanonicalEncoding locks the exact hash +// preimage. If the canonical form ever changes — sorting, omit behavior, +// number formatting, or a field added to the struct's encoding — this +// fails, forcing a conscious decision rather than a silent shift of every +// persisted fingerprint. +func TestGenerationFingerprintPinsCanonicalEncoding(t *testing.T) { + g := vector.Generation{Model: "m", Dimensions: 3, Params: map[string]string{"b": "2", "a": "1"}} + + // Keys sorted at every level, zero fields omitted, numbers verbatim. + const canonical = `{"dimensions":3,"model":"m","params":{"a":"1","b":"2"}}` + sum := sha256.Sum256([]byte(canonical)) + want := hex.EncodeToString(sum[:8]) + + assert.Equal(t, want, g.Fingerprint()) +} + +// TestGenerationFieldsAreTracked is a tripwire: adding, removing, or +// renaming a Generation field changes this set. When it fails, decide +// whether the new field affects the vector space. If it does, Fingerprint +// already includes it (it encodes the whole struct); if it must not, tag +// the field json:"-". Then update this expectation and the pinned +// encoding above. +func TestGenerationFieldsAreTracked(t *testing.T) { + want := []string{"Dimensions", "Model", "Params"} + + fields := reflect.VisibleFields(reflect.TypeFor[vector.Generation]()) + got := make([]string, 0, len(fields)) + for _, f := range fields { + got = append(got, f.Name) + } + sort.Strings(got) + + assert.Equal(t, want, got, "Generation fields changed: review fingerprint impact before updating this tripwire") +} diff --git a/vector/search.go b/vector/search.go new file mode 100644 index 0000000..41c8ff2 --- /dev/null +++ b/vector/search.go @@ -0,0 +1,158 @@ +package vector + +import "sort" + +// Hit is a single search result identifying the document it belongs to. K +// is the caller's document key type (for example int64 or a UUID); this +// package compares keys for equality but never interprets them. +type Hit[K comparable] struct { + // Doc identifies the source document. + Doc K + // ChunkIndex is the chunk within Doc that matched. + ChunkIndex int + // Score is the backend's similarity score for this chunk. Merge + // overwrites it with the merged score under the chosen strategy. + Score float32 +} + +// RollupByDocument reduces chunk-level hits to one hit per document, +// keeping the highest-scoring chunk for each, and returns them sorted by +// score descending. It is the chunk->document step a caller applies to a +// single generation's results before merging across generations. +func RollupByDocument[K comparable](hits []Hit[K]) []Hit[K] { + if len(hits) == 0 { + return nil + } + best := make(map[K]Hit[K], len(hits)) + order := make([]K, 0, len(hits)) + for _, h := range hits { + cur, ok := best[h.Doc] + if !ok { + order = append(order, h.Doc) + best[h.Doc] = h + continue + } + if h.Score > cur.Score { + best[h.Doc] = h + } + } + out := make([]Hit[K], 0, len(order)) + for _, k := range order { + out = append(out, best[k]) + } + sort.SliceStable(out, func(i, j int) bool { return out[i].Score > out[j].Score }) + return out +} + +// MergeStrategy selects how Merge orders documents drawn from different +// generations, whose raw scores are not directly comparable. +type MergeStrategy int + +const ( + // MergeNormalizedScore min-max normalizes each generation's scores to + // [0,1] before ordering. It is the default: it keeps score signal + // without letting one generation's score scale dominate. + MergeNormalizedScore MergeStrategy = iota + // MergeRawScore orders by raw score. Use it only when the generations + // share a model family and comparable score distributions. + MergeRawScore + // MergeReciprocalRank ignores absolute scores and fuses by rank. Use + // it when score distributions differ sharply between generations. + MergeReciprocalRank +) + +// MergeOptions configures Merge. +type MergeOptions struct { + // Strategy selects the ordering policy. The zero value is + // MergeNormalizedScore. + Strategy MergeStrategy + // RankConstant is the k term in reciprocal-rank fusion. Values <= 0 + // use 60. + RankConstant float64 + // Limit caps the number of returned hits. Values <= 0 return all. + Limit int +} + +// Merge unions per-generation, document-level result lists into one +// ranking. The lists are given in descending preference: when a document +// appears in more than one list, the hit from the earliest list is kept, +// which is how a caller expresses "prefer the newer generation" during a +// migration. Coverage is a union, so a document present in only one +// generation is never dropped. +// +// Each surviving hit's Score is set to the merged score under the chosen +// strategy, and the result is ordered by that score descending. +func Merge[K comparable](perGeneration [][]Hit[K], o MergeOptions) []Hit[K] { + rep := make(map[K]Hit[K]) + order := make([]K, 0) + score := make(map[K]float64) + + switch o.Strategy { + case MergeReciprocalRank: + k := o.RankConstant + if k <= 0 { + k = 60 + } + for _, list := range perGeneration { + for rank, h := range list { + if _, ok := rep[h.Doc]; !ok { + rep[h.Doc] = h + order = append(order, h.Doc) + } + score[h.Doc] += 1.0 / (k + float64(rank) + 1.0) + } + } + case MergeRawScore: + for _, list := range perGeneration { + for _, h := range list { + if _, ok := rep[h.Doc]; ok { + continue + } + rep[h.Doc] = h + order = append(order, h.Doc) + score[h.Doc] = float64(h.Score) + } + } + default: // MergeNormalizedScore + for _, list := range perGeneration { + lo, hi := scoreRange(list) + span := hi - lo + for _, h := range list { + if _, ok := rep[h.Doc]; ok { + continue + } + rep[h.Doc] = h + order = append(order, h.Doc) + if span > 0 { + score[h.Doc] = float64(h.Score-lo) / float64(span) + } else { + score[h.Doc] = 1 + } + } + } + } + + out := make([]Hit[K], 0, len(order)) + for _, doc := range order { + h := rep[doc] + h.Score = float32(score[doc]) + out = append(out, h) + } + sort.SliceStable(out, func(i, j int) bool { return out[i].Score > out[j].Score }) + if o.Limit > 0 && len(out) > o.Limit { + out = out[:o.Limit] + } + return out +} + +func scoreRange[K comparable](hits []Hit[K]) (lo, hi float32) { + for i, h := range hits { + if i == 0 || h.Score < lo { + lo = h.Score + } + if i == 0 || h.Score > hi { + hi = h.Score + } + } + return lo, hi +} diff --git a/vector/search_test.go b/vector/search_test.go new file mode 100644 index 0000000..1137818 --- /dev/null +++ b/vector/search_test.go @@ -0,0 +1,109 @@ +package vector_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "go.kenn.io/kit/vector" +) + +func docs[K comparable](hits []vector.Hit[K]) []K { + out := make([]K, len(hits)) + for i, h := range hits { + out[i] = h.Doc + } + return out +} + +func TestRollupByDocumentKeepsBestChunkPerDoc(t *testing.T) { + assert := assert.New(t) + hits := []vector.Hit[int64]{ + {Doc: 1, ChunkIndex: 0, Score: 0.2}, + {Doc: 2, ChunkIndex: 0, Score: 0.9}, + {Doc: 1, ChunkIndex: 3, Score: 0.7}, // better chunk for doc 1 + {Doc: 2, ChunkIndex: 1, Score: 0.4}, + } + + got := vector.RollupByDocument(hits) + + assert.Equal([]int64{2, 1}, docs(got), "one hit per doc, ordered by score desc") + assert.Equal(3, got[1].ChunkIndex, "doc 1 keeps its highest-scoring chunk") + assert.InDelta(0.7, got[1].Score, 1e-6) +} + +func TestMergeUnionsAndPrefersEarlierGeneration(t *testing.T) { + assert := assert.New(t) + // String keys stand in for kata's UUIDs; building generation first. + building := []vector.Hit[string]{ + {Doc: "shared", Score: 0.50}, + {Doc: "new-only", Score: 0.40}, + } + active := []vector.Hit[string]{ + {Doc: "shared", Score: 0.99}, // higher raw score, but less preferred + {Doc: "old-only", Score: 0.80}, + } + + got := vector.Merge([][]vector.Hit[string]{building, active}, vector.MergeOptions{Strategy: vector.MergeRawScore}) + + assert.ElementsMatch([]string{"shared", "new-only", "old-only"}, docs(got), "coverage is a union") + for _, h := range got { + if h.Doc == "shared" { + assert.InDelta(0.50, h.Score, 1e-6, "shared doc keeps the preferred (building) hit, not the higher raw score") + } + } +} + +func TestMergeNormalizedScoreIsDefault(t *testing.T) { + assert := assert.New(t) + // Active generation scores live in a compressed high band; building + // generation in a low band. Raw merge would let active dominate; + // normalization puts each generation's top hit at 1.0. + active := []vector.Hit[int]{ + {Doc: 1, Score: 0.90}, + {Doc: 2, Score: 0.85}, + } + building := []vector.Hit[int]{ + {Doc: 3, Score: 0.20}, + {Doc: 4, Score: 0.10}, + } + + got := vector.Merge([][]vector.Hit[int]{building, active}, vector.MergeOptions{}) + + // Each generation's best-normalized hit should reach the top band. + top := got[0] + assert.Contains([]int{1, 3}, top.Doc, "a normalized top hit leads, not just the raw-highest") + assert.InDelta(1.0, float64(top.Score), 1e-6) +} + +func TestMergeReciprocalRankFusesAcrossGenerations(t *testing.T) { + assert := assert.New(t) + // "shared" is rank 1 in one list and rank 2 in the other, so its + // fused score should beat docs that appear in only one list. + a := []vector.Hit[int]{ + {Doc: 10, Score: 0.99}, + {Doc: 99, Score: 0.98}, + } + b := []vector.Hit[int]{ + {Doc: 99, Score: 0.50}, + {Doc: 20, Score: 0.49}, + } + + got := vector.Merge([][]vector.Hit[int]{a, b}, vector.MergeOptions{Strategy: vector.MergeReciprocalRank}) + + assert.Equal(99, got[0].Doc, "the doc found in both generations ranks first") +} + +func TestMergeRespectsLimit(t *testing.T) { + assert := assert.New(t) + list := []vector.Hit[int]{{Doc: 1, Score: 0.9}, {Doc: 2, Score: 0.8}, {Doc: 3, Score: 0.7}} + + got := vector.Merge([][]vector.Hit[int]{list}, vector.MergeOptions{Strategy: vector.MergeRawScore, Limit: 2}) + + assert.Len(got, 2) + assert.Equal([]int{1, 2}, docs(got)) +} + +func TestMergeEmpty(t *testing.T) { + assert.Empty(t, vector.Merge[int](nil, vector.MergeOptions{})) +} diff --git a/vector/sqlitevec/extension_cgo.go b/vector/sqlitevec/extension_cgo.go new file mode 100644 index 0000000..97c0c78 --- /dev/null +++ b/vector/sqlitevec/extension_cgo.go @@ -0,0 +1,18 @@ +//go:build !windows && cgo + +package sqlitevec + +import vecext "github.com/asg017/sqlite-vec-go-bindings/cgo" + +// Register loads the sqlite-vec extension into every SQLite connection +// opened afterwards in this process. It must be called before opening the +// database the store will use. +func Register() { vecext.Auto() } + +func vectorValue(vector []float32) (string, any, error) { + blob, err := vecext.SerializeFloat32(vector) + if err != nil { + return "", nil, err + } + return "?", blob, nil +} diff --git a/vector/sqlitevec/extension_modernc.go b/vector/sqlitevec/extension_modernc.go new file mode 100644 index 0000000..3a411df --- /dev/null +++ b/vector/sqlitevec/extension_modernc.go @@ -0,0 +1,31 @@ +//go:build windows || !cgo + +package sqlitevec + +import ( + "strconv" + "strings" + + _ "modernc.org/sqlite/vec" +) + +// Register is kept as an explicit setup hook for callers. The modernc sqlite-vec +// extension is registered by package initialization, so no runtime work is needed. +func Register() {} + +func vectorValue(vector []float32) (string, any, error) { + return "vec_f32(?)", vectorLiteral(vector), nil +} + +func vectorLiteral(vector []float32) string { + var b strings.Builder + b.WriteByte('[') + for i, v := range vector { + if i > 0 { + b.WriteByte(',') + } + b.WriteString(strconv.FormatFloat(float64(v), 'g', -1, 32)) + } + b.WriteByte(']') + return b.String() +} diff --git a/vector/sqlitevec/sqlite_driver_cgo_test.go b/vector/sqlitevec/sqlite_driver_cgo_test.go new file mode 100644 index 0000000..529de19 --- /dev/null +++ b/vector/sqlitevec/sqlite_driver_cgo_test.go @@ -0,0 +1,18 @@ +//go:build !windows && cgo + +package sqlitevec_test + +import ( + "database/sql" + "testing" + + _ "github.com/mattn/go-sqlite3" + + "go.kenn.io/kit/vector/sqlitevec" +) + +func openSQLiteTestDB(t testing.TB, dsn string) (*sql.DB, error) { + t.Helper() + sqlitevec.Register() + return sql.Open("sqlite3", dsn) +} diff --git a/vector/sqlitevec/sqlite_driver_modernc_test.go b/vector/sqlitevec/sqlite_driver_modernc_test.go new file mode 100644 index 0000000..a1afef6 --- /dev/null +++ b/vector/sqlitevec/sqlite_driver_modernc_test.go @@ -0,0 +1,18 @@ +//go:build windows || !cgo + +package sqlitevec_test + +import ( + "database/sql" + "testing" + + _ "modernc.org/sqlite" + + "go.kenn.io/kit/vector/sqlitevec" +) + +func openSQLiteTestDB(t testing.TB, dsn string) (*sql.DB, error) { + t.Helper() + sqlitevec.Register() + return sql.Open("sqlite", dsn) +} diff --git a/vector/sqlitevec/sqlite_drivers_bench_test.go b/vector/sqlitevec/sqlite_drivers_bench_test.go new file mode 100644 index 0000000..0655c7c --- /dev/null +++ b/vector/sqlitevec/sqlite_drivers_bench_test.go @@ -0,0 +1,130 @@ +//go:build !windows && cgo + +package sqlitevec_test + +import ( + "context" + "database/sql" + "fmt" + "path/filepath" + "testing" + + cgosqlitevec "github.com/asg017/sqlite-vec-go-bindings/cgo" + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/require" + _ "modernc.org/sqlite" + _ "modernc.org/sqlite/vec" + + "go.kenn.io/kit/vector" + "go.kenn.io/kit/vector/sqlitevec" +) + +type sqliteDriverBench struct { + name string + driverName string + setup func() +} + +var sqliteDriverBenches = []sqliteDriverBench{ + {name: "modernc", driverName: "sqlite"}, + {name: "mattn", driverName: "sqlite3", setup: cgosqlitevec.Auto}, +} + +func BenchmarkSQLiteDriverQueryGeneration(b *testing.B) { + for _, driver := range sqliteDriverBenches { + b.Run(driver.name, func(b *testing.B) { + require := require.New(b) + ctx := context.Background() + _, store := setupBenchmarkStore(b, driver, 1000, 16) + query := benchVector(0, 16) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hits, err := store.QueryGeneration(ctx, int64(1), query, 10) + if err != nil { + b.StopTimer() + require.NoError(err) + } + if len(hits) == 0 { + b.StopTimer() + require.NotEmpty(hits) + } + } + }) + } +} + +func BenchmarkSQLiteDriverSaveVectors(b *testing.B) { + for _, driver := range sqliteDriverBenches { + b.Run(driver.name, func(b *testing.B) { + require := require.New(b) + ctx := context.Background() + documents := 1000 + _, store := setupBenchmarkStore(b, driver, documents, 16) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + doc := int64(i%documents + 1) + err := store.SaveVectors(ctx, int64(1), doc, []vector.ChunkVector{{ChunkIndex: 0, Vector: benchVector(i, 16)}}) + if err != nil { + b.StopTimer() + require.NoError(err) + } + } + }) + } +} + +func setupBenchmarkStore(b *testing.B, driver sqliteDriverBench, documents, dimensions int) (*sql.DB, *sqlitevec.Store[int64, int64]) { + b.Helper() + require := require.New(b) + ctx := context.Background() + if driver.setup != nil { + driver.setup() + } + + db, err := sql.Open(driver.driverName, filepath.Join(b.TempDir(), "bench.db")) + require.NoError(err) + b.Cleanup(func() { require.NoError(db.Close()) }) + db.SetMaxOpenConns(1) + + _, err = db.ExecContext(ctx, `CREATE TABLE messages (id INTEGER PRIMARY KEY, body TEXT, embed_gen INTEGER)`) + require.NoError(err) + + tx, err := db.BeginTx(ctx, nil) + require.NoError(err) + stmt, err := tx.PrepareContext(ctx, `INSERT INTO messages (id, body) VALUES (?, ?)`) + require.NoError(err) + for i := 1; i <= documents; i++ { + _, err = stmt.ExecContext(ctx, i, fmt.Sprintf("document %d", i)) + require.NoError(err) + } + require.NoError(stmt.Close()) + require.NoError(tx.Commit()) + + store, err := sqlitevec.New[int64, int64](ctx, db, sqlitevec.Schema{ + DocsTable: "messages", + IDColumn: "id", + ContentColumn: "body", + EmbedGenColumn: "embed_gen", + VectorsPrefix: "message_vectors", + }) + require.NoError(err) + require.NoError(store.EnsureGeneration(ctx, int64(1), vector.Generation{Model: "bench", Dimensions: dimensions}, sqlitevec.StateActive)) + + for i := 1; i <= documents; i++ { + err = store.SaveVectors(ctx, int64(1), int64(i), []vector.ChunkVector{{ChunkIndex: 0, Vector: benchVector(i, dimensions)}}) + require.NoError(err) + } + return db, store +} + +func benchVector(seed, dimensions int) vector.Vector { + out := make(vector.Vector, dimensions) + for i := range out { + out[i] = float32((seed*(i+3))%17 + 1) + } + return out +} diff --git a/vector/sqlitevec/sqlitevec.go b/vector/sqlitevec/sqlitevec.go new file mode 100644 index 0000000..1537c06 --- /dev/null +++ b/vector/sqlitevec/sqlitevec.go @@ -0,0 +1,185 @@ +// Package sqlitevec implements vector.Store on top of SQLite with the +// sqlite-vec extension. It is a reference backend: a worked example of the +// storage contract the vector flows depend on, built against sqlite-vec. +// +// On Unix platforms with cgo, call Register before opening a mattn/go-sqlite3 +// database: +// +// import _ "github.com/mattn/go-sqlite3" +// sqlitevec.Register() +// db, _ := sql.Open("sqlite3", path) +// +// On Windows or without cgo, import modernc.org/sqlite and open databases with +// the "sqlite" driver. The sqlite-vec extension is registered during package +// initialization in that build. +// +// The caller owns the documents table; this package owns a small set of +// vector tables derived from VectorsPrefix. Each generation gets its own +// vec0 virtual table sized to that generation's dimension, so generations +// with different model dimensions coexist during a migration. +package sqlitevec + +import ( + "context" + "database/sql" + "fmt" + "regexp" + + "go.kenn.io/kit/vector" +) + +var identifierPattern = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`) + +// State is a generation's role in the active/building lifecycle. Only +// building and active generations are searched; building sorts ahead of +// active so Merge keeps the newer generation's hit on overlap. +type State string + +const ( + StatePending State = "pending" + StateBuilding State = "building" + StateActive State = "active" + StateRetired State = "retired" +) + +// Schema names the caller's documents table and the prefix for the +// vector tables this package manages. Every field must be a bare SQL +// identifier; values are validated before being interpolated into SQL. +type Schema struct { + DocsTable string // caller's documents table, e.g. "messages" + IDColumn string // primary key column, e.g. "id" + ContentColumn string // text to embed, e.g. "body" + EmbedGenColumn string // nullable generation stamp, e.g. "embed_gen" + VectorsPrefix string // prefix for managed tables, e.g. "message_vectors" +} + +func (s Schema) validate() error { + for name, value := range map[string]string{ + "docs table": s.DocsTable, + "id column": s.IDColumn, + "content column": s.ContentColumn, + "embed gen column": s.EmbedGenColumn, + "vectors prefix": s.VectorsPrefix, + } { + if !identifierPattern.MatchString(value) { + return fmt.Errorf("invalid %s %q", name, value) + } + } + return nil +} + +// Store implements vector.Store[K, G] against SQLite + sqlite-vec. K is the +// caller's document key type and G its generation key type; both must be +// types database/sql can bind and scan (for example int64 or string). +type Store[K, G comparable] struct { + db *sql.DB + schema Schema +} + +// New returns a Store bound to db. The caller retains ownership of db. New +// creates the generations and chunks bookkeeping tables if they do not +// exist; per-generation vec0 tables are created by EnsureGeneration. +func New[K, G comparable](ctx context.Context, db *sql.DB, schema Schema) (*Store[K, G], error) { + if err := schema.validate(); err != nil { + return nil, err + } + if db == nil { + return nil, fmt.Errorf("db is nil") + } + s := &Store[K, G]{db: db, schema: schema} + if _, err := db.ExecContext(ctx, fmt.Sprintf(` +CREATE TABLE IF NOT EXISTS %s ( + ordinal INTEGER PRIMARY KEY, + gen_key UNIQUE, + dimension INTEGER NOT NULL, + state TEXT NOT NULL +); +CREATE TABLE IF NOT EXISTS %s ( + ordinal INTEGER NOT NULL, + doc_key NOT NULL, + chunk_index INTEGER NOT NULL, + vec_rowid INTEGER NOT NULL, + PRIMARY KEY (ordinal, doc_key, chunk_index) +);`, s.generationsTable(), s.chunksTable())); err != nil { + return nil, fmt.Errorf("create bookkeeping tables: %w", err) + } + return s, nil +} + +func (s *Store[K, G]) generationsTable() string { return s.schema.VectorsPrefix + "_generations" } +func (s *Store[K, G]) chunksTable() string { return s.schema.VectorsPrefix + "_chunks" } +func (s *Store[K, G]) vecTable(ordinal int64) string { + return fmt.Sprintf("%s_v%d", s.schema.VectorsPrefix, ordinal) +} + +// EnsureGeneration registers gen with model's dimension and the given +// state, creating its vec0 table on first use. Calling it again updates +// only the state; a generation's dimension is fixed once created. +func (s *Store[K, G]) EnsureGeneration(ctx context.Context, gen G, model vector.Generation, state State) error { + if model.Dimensions <= 0 { + return fmt.Errorf("generation dimension must be positive, got %d", model.Dimensions) + } + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("begin ensure generation: %w", err) + } + defer func() { _ = tx.Rollback() }() + + if _, err := tx.ExecContext(ctx, fmt.Sprintf(` +INSERT INTO %s (gen_key, dimension, state) VALUES (?, ?, ?) +ON CONFLICT(gen_key) DO UPDATE SET state = excluded.state`, s.generationsTable()), + gen, model.Dimensions, string(state)); err != nil { + return fmt.Errorf("upsert generation: %w", err) + } + + ordinal, dimension, err := s.lookupGenerationTx(ctx, tx, gen) + if err != nil { + return err + } + if _, err := tx.ExecContext(ctx, fmt.Sprintf( + `CREATE VIRTUAL TABLE IF NOT EXISTS %s USING vec0(embedding float[%d] distance_metric=cosine)`, + s.vecTable(ordinal), dimension)); err != nil { + return fmt.Errorf("create vec0 table: %w", err) + } + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit ensure generation: %w", err) + } + return nil +} + +// SetGenerationState transitions gen to state. The caller owns the +// active/building lifecycle; this only records the decision. +func (s *Store[K, G]) SetGenerationState(ctx context.Context, gen G, state State) error { + res, err := s.db.ExecContext(ctx, + fmt.Sprintf(`UPDATE %s SET state = ? WHERE gen_key = ?`, s.generationsTable()), + string(state), gen) + if err != nil { + return fmt.Errorf("set generation state: %w", err) + } + if n, _ := res.RowsAffected(); n == 0 { + return fmt.Errorf("generation %v not found", gen) + } + return nil +} + +func (s *Store[K, G]) lookupGeneration(ctx context.Context, gen G) (ordinal int64, dimension int, err error) { + return s.scanGeneration(s.db.QueryRowContext(ctx, + fmt.Sprintf(`SELECT ordinal, dimension FROM %s WHERE gen_key = ?`, s.generationsTable()), gen), gen) +} + +func (s *Store[K, G]) lookupGenerationTx(ctx context.Context, tx *sql.Tx, gen G) (int64, int, error) { + return s.scanGeneration(tx.QueryRowContext(ctx, + fmt.Sprintf(`SELECT ordinal, dimension FROM %s WHERE gen_key = ?`, s.generationsTable()), gen), gen) +} + +func (s *Store[K, G]) scanGeneration(row *sql.Row, gen G) (int64, int, error) { + var ordinal int64 + var dimension int + if err := row.Scan(&ordinal, &dimension); err != nil { + if err == sql.ErrNoRows { + return 0, 0, fmt.Errorf("generation %v not ensured", gen) + } + return 0, 0, fmt.Errorf("lookup generation %v: %w", gen, err) + } + return ordinal, dimension, nil +} diff --git a/vector/sqlitevec/sqlitevec_test.go b/vector/sqlitevec/sqlitevec_test.go new file mode 100644 index 0000000..c212607 --- /dev/null +++ b/vector/sqlitevec/sqlitevec_test.go @@ -0,0 +1,154 @@ +package sqlitevec_test + +import ( + "context" + "database/sql" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "go.kenn.io/kit/vector" + "go.kenn.io/kit/vector/sqlitevec" +) + +// topicEncoder maps text to a one-hot 3-D vector by keyword, so queries +// match documents deterministically. +func topicEncoder() vector.EncodeFunc { + return func(_ context.Context, texts []string) ([][]float32, error) { + out := make([][]float32, len(texts)) + for i, text := range texts { + switch { + case strings.Contains(text, "cat"): + out[i] = []float32{1, 0, 0} + case strings.Contains(text, "dog"): + out[i] = []float32{0, 1, 0} + default: + out[i] = []float32{0, 0, 1} + } + } + return out, nil + } +} + +func setup(t *testing.T) (*sql.DB, *sqlitevec.Store[int64, int64]) { + t.Helper() + db, err := openSQLiteTestDB(t, filepath.Join(t.TempDir(), "vec.db")) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, db.Close()) }) + + _, err = db.Exec(`CREATE TABLE messages (id INTEGER PRIMARY KEY, body TEXT, embed_gen INTEGER)`) + require.NoError(t, err) + + store, err := sqlitevec.New[int64, int64](context.Background(), db, sqlitevec.Schema{ + DocsTable: "messages", + IDColumn: "id", + ContentColumn: "body", + EmbedGenColumn: "embed_gen", + VectorsPrefix: "message_vectors", + }) + require.NoError(t, err) + return db, store +} + +func TestStoreFillThenSearch(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + ctx := context.Background() + db, store := setup(t) + + _, err := db.ExecContext(ctx, `INSERT INTO messages (id, body) VALUES (1, 'a cat sat'), (2, 'a dog ran')`) + require.NoError(err) + require.NoError(store.EnsureGeneration(ctx, 1, vector.Generation{Model: "m", Dimensions: 3}, sqlitevec.StateActive)) + + stats, err := vector.Fill(ctx, store, 1, topicEncoder(), vector.FillOptions{}) + require.NoError(err) + assert.Equal(2, stats.Documents) + + pending, err := store.PendingForGeneration(ctx, 1, 10) + require.NoError(err) + assert.Empty(pending, "nothing pending once every document is stamped") + + enc := func(int64) vector.EncodeFunc { return topicEncoder() } + hits, err := vector.Search(ctx, store, "a cat", enc, vector.SearchOptions{}) + require.NoError(err) + require.NotEmpty(hits) + assert.Equal(int64(1), hits[0].Doc, "the cat query ranks the cat document first") +} + +func TestStoreReembeddingReplacesVectors(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + ctx := context.Background() + db, store := setup(t) + + _, err := db.ExecContext(ctx, `INSERT INTO messages (id, body) VALUES (1, 'a cat sat')`) + require.NoError(err) + require.NoError(store.EnsureGeneration(ctx, 1, vector.Generation{Model: "m", Dimensions: 3}, sqlitevec.StateActive)) + + require.NoError(store.SaveVectors(ctx, 1, 1, []vector.ChunkVector{{ChunkIndex: 0, Vector: vector.Vector{1, 0, 0}}})) + require.NoError(store.SaveVectors(ctx, 1, 1, []vector.ChunkVector{{ChunkIndex: 0, Vector: vector.Vector{0, 1, 0}}})) + + hits, err := store.QueryGeneration(ctx, 1, vector.Vector{0, 1, 0}, 10) + require.NoError(err) + require.Len(hits, 1, "re-embedding replaces the prior vector rather than duplicating it") + assert.InDelta(1.0, hits[0].Score, 1e-6, "stored vector now matches the new query") +} + +func TestStoreSearchUnionsLiveGenerations(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + ctx := context.Background() + db, store := setup(t) + + _, err := db.ExecContext(ctx, `INSERT INTO messages (id, body) VALUES (1, 'a cat'), (2, 'a dog')`) + require.NoError(err) + + require.NoError(store.EnsureGeneration(ctx, 1, vector.Generation{Model: "v1", Dimensions: 3}, sqlitevec.StateActive)) + _, err = vector.Fill(ctx, store, 1, topicEncoder(), vector.FillOptions{}) + require.NoError(err) + + // The building generation has covered only doc 1 so far. + require.NoError(store.EnsureGeneration(ctx, 2, vector.Generation{Model: "v2", Dimensions: 3}, sqlitevec.StateBuilding)) + require.NoError(store.SaveVectors(ctx, 2, 1, []vector.ChunkVector{{ChunkIndex: 0, Vector: vector.Vector{1, 0, 0}}})) + + gens, err := store.LiveGenerations(ctx) + require.NoError(err) + assert.Equal([]int64{2, 1}, gens, "building precedes active in preference order") + + enc := func(int64) vector.EncodeFunc { return topicEncoder() } + hits, err := vector.Search(ctx, store, "a cat", enc, vector.SearchOptions{}) + require.NoError(err) + + found := map[int64]bool{} + for _, h := range hits { + found[h.Doc] = true + } + assert.True(found[1], "shared doc is searchable") + assert.True(found[2], "active-only doc is not dropped mid-migration (union coverage)") +} + +func TestStoreSaveVectorsRejectsMissingDocument(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + ctx := context.Background() + _, store := setup(t) + + require.NoError(store.EnsureGeneration(ctx, 1, vector.Generation{Model: "m", Dimensions: 3}, sqlitevec.StateActive)) + + err := store.SaveVectors(ctx, 1, 999, []vector.ChunkVector{{ChunkIndex: 0, Vector: vector.Vector{1, 0, 0}}}) + require.Error(err, "saving vectors for a document not in the source table fails") + + hits, err := store.QueryGeneration(ctx, 1, vector.Vector{1, 0, 0}, 10) + require.NoError(err) + assert.Empty(hits, "no orphan vectors are committed when the source row is missing") +} + +func TestNewRejectsUnsafeIdentifiers(t *testing.T) { + _, err := sqlitevec.New[int64, int64](context.Background(), nil, sqlitevec.Schema{ + DocsTable: "messages; DROP TABLE messages", + }) + require.Error(t, err) +} diff --git a/vector/sqlitevec/store.go b/vector/sqlitevec/store.go new file mode 100644 index 0000000..7754567 --- /dev/null +++ b/vector/sqlitevec/store.go @@ -0,0 +1,204 @@ +package sqlitevec + +import ( + "context" + "database/sql" + "fmt" + + "go.kenn.io/kit/vector" +) + +// PendingForGeneration scans the caller's documents table for rows whose +// stamp does not yet match gen, ordered by primary key for stable paging. +func (s *Store[K, G]) PendingForGeneration(ctx context.Context, gen G, limit int) ([]vector.Pending[K], error) { + query := fmt.Sprintf( + `SELECT %s, %s FROM %s WHERE %s IS NULL OR %s <> ? ORDER BY %s LIMIT ?`, + s.schema.IDColumn, s.schema.ContentColumn, s.schema.DocsTable, + s.schema.EmbedGenColumn, s.schema.EmbedGenColumn, s.schema.IDColumn) + rows, err := s.db.QueryContext(ctx, query, gen, limit) + if err != nil { + return nil, fmt.Errorf("scan pending: %w", err) + } + defer func() { _ = rows.Close() }() + + var pending []vector.Pending[K] + for rows.Next() { + var p vector.Pending[K] + if err := rows.Scan(&p.Doc, &p.Content); err != nil { + return nil, fmt.Errorf("scan pending row: %w", err) + } + pending = append(pending, p) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("scan pending rows: %w", err) + } + return pending, nil +} + +// SaveVectors replaces doc's chunk vectors for gen and stamps the document +// as embedded for gen, all in one transaction. +func (s *Store[K, G]) SaveVectors(ctx context.Context, gen G, doc K, vectors []vector.ChunkVector) error { + ordinal, dimension, err := s.lookupGeneration(ctx, gen) + if err != nil { + return err + } + for _, cv := range vectors { + if len(cv.Vector) != dimension { + return fmt.Errorf("chunk %d has %d dimensions, generation expects %d", cv.ChunkIndex, len(cv.Vector), dimension) + } + } + + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("begin save vectors: %w", err) + } + defer func() { _ = tx.Rollback() }() + + // Drop any prior vectors for this document so re-embedding is clean. + rowids, err := s.docRowids(ctx, tx, ordinal, doc) + if err != nil { + return err + } + for _, rowid := range rowids { + if _, err := tx.ExecContext(ctx, fmt.Sprintf(`DELETE FROM %s WHERE rowid = ?`, s.vecTable(ordinal)), rowid); err != nil { + return fmt.Errorf("delete stale vector: %w", err) + } + } + if _, err := tx.ExecContext(ctx, fmt.Sprintf(`DELETE FROM %s WHERE ordinal = ? AND doc_key = ?`, s.chunksTable()), ordinal, doc); err != nil { + return fmt.Errorf("delete stale chunk map: %w", err) + } + + for _, cv := range vectors { + expr, value, err := vectorValue(cv.Vector) + if err != nil { + return fmt.Errorf("serialize vector: %w", err) + } + res, err := tx.ExecContext(ctx, fmt.Sprintf(`INSERT INTO %s (embedding) VALUES (%s)`, s.vecTable(ordinal), expr), value) + if err != nil { + return fmt.Errorf("insert vector: %w", err) + } + rowid, err := res.LastInsertId() + if err != nil { + return fmt.Errorf("vector rowid: %w", err) + } + if _, err := tx.ExecContext(ctx, + fmt.Sprintf(`INSERT INTO %s (ordinal, doc_key, chunk_index, vec_rowid) VALUES (?, ?, ?, ?)`, s.chunksTable()), + ordinal, doc, cv.ChunkIndex, rowid); err != nil { + return fmt.Errorf("insert chunk map: %w", err) + } + } + + res, err := tx.ExecContext(ctx, + fmt.Sprintf(`UPDATE %s SET %s = ? WHERE %s = ?`, s.schema.DocsTable, s.schema.EmbedGenColumn, s.schema.IDColumn), + gen, doc) + if err != nil { + return fmt.Errorf("stamp embed generation: %w", err) + } + stamped, err := res.RowsAffected() + if err != nil { + return fmt.Errorf("stamp embed generation rows: %w", err) + } + if stamped == 0 { + // The source row vanished between scan and save (or the key is + // wrong). Roll back rather than commit vectors with no document, + // which QueryGeneration would otherwise surface as orphan hits. + return fmt.Errorf("document %v not present in %s; vectors not persisted", doc, s.schema.DocsTable) + } + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit save vectors: %w", err) + } + return nil +} + +func (s *Store[K, G]) docRowids(ctx context.Context, tx txQuerier, ordinal int64, doc K) ([]int64, error) { + rows, err := tx.QueryContext(ctx, + fmt.Sprintf(`SELECT vec_rowid FROM %s WHERE ordinal = ? AND doc_key = ?`, s.chunksTable()), ordinal, doc) + if err != nil { + return nil, fmt.Errorf("read chunk map: %w", err) + } + defer func() { _ = rows.Close() }() + + var rowids []int64 + for rows.Next() { + var rowid int64 + if err := rows.Scan(&rowid); err != nil { + return nil, fmt.Errorf("scan chunk rowid: %w", err) + } + rowids = append(rowids, rowid) + } + return rowids, rows.Err() +} + +// LiveGenerations returns building and active generations, building first, +// so Merge prefers the newer generation when a document is in both. +func (s *Store[K, G]) LiveGenerations(ctx context.Context) ([]G, error) { + rows, err := s.db.QueryContext(ctx, fmt.Sprintf(` +SELECT gen_key FROM %s + WHERE state IN (?, ?) + ORDER BY CASE state WHEN ? THEN 0 ELSE 1 END, ordinal`, s.generationsTable()), + string(StateBuilding), string(StateActive), string(StateBuilding)) + if err != nil { + return nil, fmt.Errorf("list live generations: %w", err) + } + defer func() { _ = rows.Close() }() + + var gens []G + for rows.Next() { + var gen G + if err := rows.Scan(&gen); err != nil { + return nil, fmt.Errorf("scan generation key: %w", err) + } + gens = append(gens, gen) + } + return gens, rows.Err() +} + +// QueryGeneration runs a cosine KNN search within gen's vec0 table and +// maps each neighbor back to its document and chunk. Score is the cosine +// similarity (1 - cosine distance), so higher is more similar. +func (s *Store[K, G]) QueryGeneration(ctx context.Context, gen G, query vector.Vector, limit int) ([]vector.Hit[K], error) { + ordinal, dimension, err := s.lookupGeneration(ctx, gen) + if err != nil { + return nil, err + } + if len(query) != dimension { + return nil, fmt.Errorf("query has %d dimensions, generation expects %d", len(query), dimension) + } + expr, value, err := vectorValue(query) + if err != nil { + return nil, fmt.Errorf("serialize query: %w", err) + } + // The KNN runs against the vec0 table alone (its required form), then + // joins to the chunk map to recover document keys. + sqlText := fmt.Sprintf(` +WITH knn AS ( + SELECT rowid, distance FROM %s WHERE embedding MATCH %s ORDER BY distance LIMIT ? +) +SELECT c.doc_key, c.chunk_index, knn.distance + FROM knn JOIN %s c ON c.ordinal = ? AND c.vec_rowid = knn.rowid + ORDER BY knn.distance`, s.vecTable(ordinal), expr, s.chunksTable()) + rows, err := s.db.QueryContext(ctx, sqlText, value, limit, ordinal) + if err != nil { + return nil, fmt.Errorf("query generation: %w", err) + } + defer func() { _ = rows.Close() }() + + var hits []vector.Hit[K] + for rows.Next() { + var ( + doc K + chunkIndex int + distance float64 + ) + if err := rows.Scan(&doc, &chunkIndex, &distance); err != nil { + return nil, fmt.Errorf("scan hit: %w", err) + } + hits = append(hits, vector.Hit[K]{Doc: doc, ChunkIndex: chunkIndex, Score: float32(1 - distance)}) + } + return hits, rows.Err() +} + +// txQuerier is the read surface shared by *sql.DB and *sql.Tx. +type txQuerier interface { + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) +} diff --git a/vector/sqlitevec/vec0_smoke_test.go b/vector/sqlitevec/vec0_smoke_test.go new file mode 100644 index 0000000..4955e88 --- /dev/null +++ b/vector/sqlitevec/vec0_smoke_test.go @@ -0,0 +1,34 @@ +package sqlitevec_test + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// TestVec0LoadsHermetically confirms the sqlite-vec extension is available +// through the modernc SQLite driver, so the backend's tests need no external setup. +func TestVec0LoadsHermetically(t *testing.T) { + require := require.New(t) + + db, err := openSQLiteTestDB(t, ":memory:") + require.NoError(err) + t.Cleanup(func() { require.NoError(db.Close()) }) + + _, err = db.Exec(`CREATE VIRTUAL TABLE v USING vec0(embedding float[3])`) + require.NoError(err) + + _, err = db.Exec(`INSERT INTO v(rowid, embedding) VALUES (1, vec_f32(?))`, `[1,2,3]`) + require.NoError(err) + + var rowid int64 + var distance float64 + err = db.QueryRow( + `SELECT rowid, distance FROM v WHERE embedding MATCH vec_f32(?) ORDER BY distance LIMIT 1`, + `[1,2,3]`, + ).Scan(&rowid, &distance) + require.NoError(err) + + require.Equal(int64(1), rowid) + require.InDelta(0, distance, 1e-6, "identical vectors have zero distance") +} diff --git a/vector/store.go b/vector/store.go new file mode 100644 index 0000000..28facc9 --- /dev/null +++ b/vector/store.go @@ -0,0 +1,47 @@ +package vector + +import "context" + +// Pending is one document that still needs embedding for a generation, +// paired with the text to embed. +type Pending[K comparable] struct { + Doc K + Content string +} + +// ChunkVector is a single chunk's embedding, ready to persist. +type ChunkVector struct { + ChunkIndex int + Vector Vector +} + +// Store is the persistence contract the Fill and Search flows depend on. +// Implementations are a function of the caller's source system — a SQLite, +// pgvector, or DuckDB table — and own all backend SQL and query +// construction. The flows never open a database or build SQL themselves. +// +// K is the caller's document key type and G its generation id type; the +// package compares both for equality but never interprets them. +type Store[K, G comparable] interface { + // PendingForGeneration returns up to limit documents that are not yet + // embedded for gen, in a stable order. Implementations typically scan + // for "embed_gen IS NULL OR embed_gen <> gen". A document must stop + // being reported once SaveVectors has persisted it for gen, so that a + // fill loop terminates. + PendingForGeneration(ctx context.Context, gen G, limit int) ([]Pending[K], error) + + // SaveVectors persists every chunk vector for doc under gen and marks + // doc as embedded for gen (the scan-and-fill stamp). + SaveVectors(ctx context.Context, gen G, doc K, vectors []ChunkVector) error + + // LiveGenerations returns the generations a search should query, in + // descending preference. During a migration the building generation + // precedes the active one, so Merge keeps the newer generation's hit + // when a document appears in both. + LiveGenerations(ctx context.Context) ([]G, error) + + // QueryGeneration returns chunk-level hits for query within gen, + // ranked best first and capped at limit. This is where each backend's + // vector query construction lives. + QueryGeneration(ctx context.Context, gen G, query Vector, limit int) ([]Hit[K], error) +}