fix: L2-normalize local embedding vectors to fix semantic search (#5332)
* fix: L2-normalize local embedding vectors to fix semantic search * fix: handle non‑finite magnitude in L2 normalization and remove stale test reset * refactor: add braces to l2Normalize guard clause in embeddings * fix: sanitize local embeddings (#5332) (thanks @akramcodez) --------- Co-authored-by: Gustavo Madeira Santana <gumadeiras@gmail.com>
This commit is contained in:
parent
b9910ab037
commit
5020bfa2a9
2 changed files with 165 additions and 2 deletions
|
|
@ -326,3 +326,157 @@ describe("embedding provider local fallback", () => {
|
|||
).rejects.toThrow(/optional dependency node-llama-cpp/i);
|
||||
});
|
||||
});
|
||||
|
||||
describe("local embedding normalization", () => {
|
||||
afterEach(() => {
|
||||
vi.resetAllMocks();
|
||||
vi.resetModules();
|
||||
vi.unstubAllGlobals();
|
||||
vi.doUnmock("./node-llama.js");
|
||||
});
|
||||
|
||||
it("normalizes local embeddings to magnitude ~1.0", async () => {
|
||||
const unnormalizedVector = [2.35, 3.45, 0.63, 4.3, 1.2, 5.1, 2.8, 3.9];
|
||||
|
||||
vi.doMock("./node-llama.js", () => ({
|
||||
importNodeLlamaCpp: async () => ({
|
||||
getLlama: async () => ({
|
||||
loadModel: vi.fn().mockResolvedValue({
|
||||
createEmbeddingContext: vi.fn().mockResolvedValue({
|
||||
getEmbeddingFor: vi.fn().mockResolvedValue({
|
||||
vector: new Float32Array(unnormalizedVector),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
resolveModelFile: async () => "/fake/model.gguf",
|
||||
LlamaLogLevel: { error: 0 },
|
||||
}),
|
||||
}));
|
||||
|
||||
const { createEmbeddingProvider } = await import("./embeddings.js");
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "local",
|
||||
model: "",
|
||||
fallback: "none",
|
||||
});
|
||||
|
||||
const embedding = await result.provider.embedQuery("test query");
|
||||
|
||||
const magnitude = Math.sqrt(embedding.reduce((sum, x) => sum + x * x, 0));
|
||||
|
||||
expect(magnitude).toBeCloseTo(1.0, 5);
|
||||
});
|
||||
|
||||
it("handles zero vector without division by zero", async () => {
|
||||
const zeroVector = [0, 0, 0, 0];
|
||||
|
||||
vi.doMock("./node-llama.js", () => ({
|
||||
importNodeLlamaCpp: async () => ({
|
||||
getLlama: async () => ({
|
||||
loadModel: vi.fn().mockResolvedValue({
|
||||
createEmbeddingContext: vi.fn().mockResolvedValue({
|
||||
getEmbeddingFor: vi.fn().mockResolvedValue({
|
||||
vector: new Float32Array(zeroVector),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
resolveModelFile: async () => "/fake/model.gguf",
|
||||
LlamaLogLevel: { error: 0 },
|
||||
}),
|
||||
}));
|
||||
|
||||
const { createEmbeddingProvider } = await import("./embeddings.js");
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "local",
|
||||
model: "",
|
||||
fallback: "none",
|
||||
});
|
||||
|
||||
const embedding = await result.provider.embedQuery("test");
|
||||
|
||||
expect(embedding).toEqual([0, 0, 0, 0]);
|
||||
expect(embedding.every((value) => Number.isFinite(value))).toBe(true);
|
||||
});
|
||||
|
||||
it("sanitizes non-finite values before normalization", async () => {
|
||||
const nonFiniteVector = [1, Number.NaN, Number.POSITIVE_INFINITY, Number.NEGATIVE_INFINITY];
|
||||
|
||||
vi.doMock("./node-llama.js", () => ({
|
||||
importNodeLlamaCpp: async () => ({
|
||||
getLlama: async () => ({
|
||||
loadModel: vi.fn().mockResolvedValue({
|
||||
createEmbeddingContext: vi.fn().mockResolvedValue({
|
||||
getEmbeddingFor: vi.fn().mockResolvedValue({
|
||||
vector: new Float32Array(nonFiniteVector),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
resolveModelFile: async () => "/fake/model.gguf",
|
||||
LlamaLogLevel: { error: 0 },
|
||||
}),
|
||||
}));
|
||||
|
||||
const { createEmbeddingProvider } = await import("./embeddings.js");
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "local",
|
||||
model: "",
|
||||
fallback: "none",
|
||||
});
|
||||
|
||||
const embedding = await result.provider.embedQuery("test");
|
||||
|
||||
expect(embedding).toEqual([1, 0, 0, 0]);
|
||||
expect(embedding.every((value) => Number.isFinite(value))).toBe(true);
|
||||
});
|
||||
|
||||
it("normalizes batch embeddings to magnitude ~1.0", async () => {
|
||||
const unnormalizedVectors = [
|
||||
[2.35, 3.45, 0.63, 4.3],
|
||||
[10.0, 0.0, 0.0, 0.0],
|
||||
[1.0, 1.0, 1.0, 1.0],
|
||||
];
|
||||
|
||||
vi.doMock("./node-llama.js", () => ({
|
||||
importNodeLlamaCpp: async () => ({
|
||||
getLlama: async () => ({
|
||||
loadModel: vi.fn().mockResolvedValue({
|
||||
createEmbeddingContext: vi.fn().mockResolvedValue({
|
||||
getEmbeddingFor: vi
|
||||
.fn()
|
||||
.mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[0]) })
|
||||
.mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[1]) })
|
||||
.mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[2]) }),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
resolveModelFile: async () => "/fake/model.gguf",
|
||||
LlamaLogLevel: { error: 0 },
|
||||
}),
|
||||
}));
|
||||
|
||||
const { createEmbeddingProvider } = await import("./embeddings.js");
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "local",
|
||||
model: "",
|
||||
fallback: "none",
|
||||
});
|
||||
|
||||
const embeddings = await result.provider.embedBatch(["text1", "text2", "text3"]);
|
||||
|
||||
for (const embedding of embeddings) {
|
||||
const magnitude = Math.sqrt(embedding.reduce((sum, x) => sum + x * x, 0));
|
||||
expect(magnitude).toBeCloseTo(1.0, 5);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -6,6 +6,15 @@ import { createGeminiEmbeddingProvider, type GeminiEmbeddingClient } from "./emb
|
|||
import { createOpenAiEmbeddingProvider, type OpenAiEmbeddingClient } from "./embeddings-openai.js";
|
||||
import { importNodeLlamaCpp } from "./node-llama.js";
|
||||
|
||||
function sanitizeAndNormalizeEmbedding(vec: number[]): number[] {
|
||||
const sanitized = vec.map((value) => (Number.isFinite(value) ? value : 0));
|
||||
const magnitude = Math.sqrt(sanitized.reduce((sum, value) => sum + value * value, 0));
|
||||
if (magnitude < 1e-10) {
|
||||
return sanitized;
|
||||
}
|
||||
return sanitized.map((value) => value / magnitude);
|
||||
}
|
||||
|
||||
export type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
|
||||
export type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
|
||||
|
||||
|
|
@ -98,14 +107,14 @@ async function createLocalEmbeddingProvider(
|
|||
embedQuery: async (text) => {
|
||||
const ctx = await ensureContext();
|
||||
const embedding = await ctx.getEmbeddingFor(text);
|
||||
return Array.from(embedding.vector);
|
||||
return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
|
||||
},
|
||||
embedBatch: async (texts) => {
|
||||
const ctx = await ensureContext();
|
||||
const embeddings = await Promise.all(
|
||||
texts.map(async (text) => {
|
||||
const embedding = await ctx.getEmbeddingFor(text);
|
||||
return Array.from(embedding.vector);
|
||||
return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
|
||||
}),
|
||||
);
|
||||
return embeddings;
|
||||
|
|
|
|||
Loading…
Reference in a new issue