fix: L2-normalize local embedding vectors to fix semantic search (#5332)
* fix: L2-normalize local embedding vectors to fix semantic search * fix: handle non-finite magnitude in L2 normalization and remove stale test reset * refactor: add braces to l2Normalize guard clause in embeddings * fix: sanitize local embeddings (#5332) (thanks @akramcodez) --------- Co-authored-by: Gustavo Madeira Santana <gumadeiras@gmail.com>
This commit is contained in:
parent
b9910ab037
commit
5020bfa2a9
2 changed files with 165 additions and 2 deletions
|
|
@ -326,3 +326,157 @@ describe("embedding provider local fallback", () => {
|
||||||
).rejects.toThrow(/optional dependency node-llama-cpp/i);
|
).rejects.toThrow(/optional dependency node-llama-cpp/i);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe("local embedding normalization", () => {
|
||||||
|
afterEach(() => {
|
||||||
|
vi.resetAllMocks();
|
||||||
|
vi.resetModules();
|
||||||
|
vi.unstubAllGlobals();
|
||||||
|
vi.doUnmock("./node-llama.js");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("normalizes local embeddings to magnitude ~1.0", async () => {
|
||||||
|
const unnormalizedVector = [2.35, 3.45, 0.63, 4.3, 1.2, 5.1, 2.8, 3.9];
|
||||||
|
|
||||||
|
vi.doMock("./node-llama.js", () => ({
|
||||||
|
importNodeLlamaCpp: async () => ({
|
||||||
|
getLlama: async () => ({
|
||||||
|
loadModel: vi.fn().mockResolvedValue({
|
||||||
|
createEmbeddingContext: vi.fn().mockResolvedValue({
|
||||||
|
getEmbeddingFor: vi.fn().mockResolvedValue({
|
||||||
|
vector: new Float32Array(unnormalizedVector),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
resolveModelFile: async () => "/fake/model.gguf",
|
||||||
|
LlamaLogLevel: { error: 0 },
|
||||||
|
}),
|
||||||
|
}));
|
||||||
|
|
||||||
|
const { createEmbeddingProvider } = await import("./embeddings.js");
|
||||||
|
|
||||||
|
const result = await createEmbeddingProvider({
|
||||||
|
config: {} as never,
|
||||||
|
provider: "local",
|
||||||
|
model: "",
|
||||||
|
fallback: "none",
|
||||||
|
});
|
||||||
|
|
||||||
|
const embedding = await result.provider.embedQuery("test query");
|
||||||
|
|
||||||
|
const magnitude = Math.sqrt(embedding.reduce((sum, x) => sum + x * x, 0));
|
||||||
|
|
||||||
|
expect(magnitude).toBeCloseTo(1.0, 5);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("handles zero vector without division by zero", async () => {
|
||||||
|
const zeroVector = [0, 0, 0, 0];
|
||||||
|
|
||||||
|
vi.doMock("./node-llama.js", () => ({
|
||||||
|
importNodeLlamaCpp: async () => ({
|
||||||
|
getLlama: async () => ({
|
||||||
|
loadModel: vi.fn().mockResolvedValue({
|
||||||
|
createEmbeddingContext: vi.fn().mockResolvedValue({
|
||||||
|
getEmbeddingFor: vi.fn().mockResolvedValue({
|
||||||
|
vector: new Float32Array(zeroVector),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
resolveModelFile: async () => "/fake/model.gguf",
|
||||||
|
LlamaLogLevel: { error: 0 },
|
||||||
|
}),
|
||||||
|
}));
|
||||||
|
|
||||||
|
const { createEmbeddingProvider } = await import("./embeddings.js");
|
||||||
|
|
||||||
|
const result = await createEmbeddingProvider({
|
||||||
|
config: {} as never,
|
||||||
|
provider: "local",
|
||||||
|
model: "",
|
||||||
|
fallback: "none",
|
||||||
|
});
|
||||||
|
|
||||||
|
const embedding = await result.provider.embedQuery("test");
|
||||||
|
|
||||||
|
expect(embedding).toEqual([0, 0, 0, 0]);
|
||||||
|
expect(embedding.every((value) => Number.isFinite(value))).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("sanitizes non-finite values before normalization", async () => {
|
||||||
|
const nonFiniteVector = [1, Number.NaN, Number.POSITIVE_INFINITY, Number.NEGATIVE_INFINITY];
|
||||||
|
|
||||||
|
vi.doMock("./node-llama.js", () => ({
|
||||||
|
importNodeLlamaCpp: async () => ({
|
||||||
|
getLlama: async () => ({
|
||||||
|
loadModel: vi.fn().mockResolvedValue({
|
||||||
|
createEmbeddingContext: vi.fn().mockResolvedValue({
|
||||||
|
getEmbeddingFor: vi.fn().mockResolvedValue({
|
||||||
|
vector: new Float32Array(nonFiniteVector),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
resolveModelFile: async () => "/fake/model.gguf",
|
||||||
|
LlamaLogLevel: { error: 0 },
|
||||||
|
}),
|
||||||
|
}));
|
||||||
|
|
||||||
|
const { createEmbeddingProvider } = await import("./embeddings.js");
|
||||||
|
|
||||||
|
const result = await createEmbeddingProvider({
|
||||||
|
config: {} as never,
|
||||||
|
provider: "local",
|
||||||
|
model: "",
|
||||||
|
fallback: "none",
|
||||||
|
});
|
||||||
|
|
||||||
|
const embedding = await result.provider.embedQuery("test");
|
||||||
|
|
||||||
|
expect(embedding).toEqual([1, 0, 0, 0]);
|
||||||
|
expect(embedding.every((value) => Number.isFinite(value))).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("normalizes batch embeddings to magnitude ~1.0", async () => {
|
||||||
|
const unnormalizedVectors = [
|
||||||
|
[2.35, 3.45, 0.63, 4.3],
|
||||||
|
[10.0, 0.0, 0.0, 0.0],
|
||||||
|
[1.0, 1.0, 1.0, 1.0],
|
||||||
|
];
|
||||||
|
|
||||||
|
vi.doMock("./node-llama.js", () => ({
|
||||||
|
importNodeLlamaCpp: async () => ({
|
||||||
|
getLlama: async () => ({
|
||||||
|
loadModel: vi.fn().mockResolvedValue({
|
||||||
|
createEmbeddingContext: vi.fn().mockResolvedValue({
|
||||||
|
getEmbeddingFor: vi
|
||||||
|
.fn()
|
||||||
|
.mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[0]) })
|
||||||
|
.mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[1]) })
|
||||||
|
.mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[2]) }),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
resolveModelFile: async () => "/fake/model.gguf",
|
||||||
|
LlamaLogLevel: { error: 0 },
|
||||||
|
}),
|
||||||
|
}));
|
||||||
|
|
||||||
|
const { createEmbeddingProvider } = await import("./embeddings.js");
|
||||||
|
|
||||||
|
const result = await createEmbeddingProvider({
|
||||||
|
config: {} as never,
|
||||||
|
provider: "local",
|
||||||
|
model: "",
|
||||||
|
fallback: "none",
|
||||||
|
});
|
||||||
|
|
||||||
|
const embeddings = await result.provider.embedBatch(["text1", "text2", "text3"]);
|
||||||
|
|
||||||
|
for (const embedding of embeddings) {
|
||||||
|
const magnitude = Math.sqrt(embedding.reduce((sum, x) => sum + x * x, 0));
|
||||||
|
expect(magnitude).toBeCloseTo(1.0, 5);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,15 @@ import { createGeminiEmbeddingProvider, type GeminiEmbeddingClient } from "./emb
|
||||||
import { createOpenAiEmbeddingProvider, type OpenAiEmbeddingClient } from "./embeddings-openai.js";
|
import { createOpenAiEmbeddingProvider, type OpenAiEmbeddingClient } from "./embeddings-openai.js";
|
||||||
import { importNodeLlamaCpp } from "./node-llama.js";
|
import { importNodeLlamaCpp } from "./node-llama.js";
|
||||||
|
|
||||||
|
function sanitizeAndNormalizeEmbedding(vec: number[]): number[] {
|
||||||
|
const sanitized = vec.map((value) => (Number.isFinite(value) ? value : 0));
|
||||||
|
const magnitude = Math.sqrt(sanitized.reduce((sum, value) => sum + value * value, 0));
|
||||||
|
if (magnitude < 1e-10) {
|
||||||
|
return sanitized;
|
||||||
|
}
|
||||||
|
return sanitized.map((value) => value / magnitude);
|
||||||
|
}
|
||||||
|
|
||||||
export type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
|
export type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
|
||||||
export type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
|
export type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
|
||||||
|
|
||||||
|
|
@ -98,14 +107,14 @@ async function createLocalEmbeddingProvider(
|
||||||
// Embed a single query string. The raw llama vector is sanitized and
// L2-normalized so downstream cosine-similarity search behaves correctly.
embedQuery: async (text) => {
  const context = await ensureContext();
  const { vector } = await context.getEmbeddingFor(text);
  return sanitizeAndNormalizeEmbedding([...vector]);
},
|
||||||
embedBatch: async (texts) => {
|
embedBatch: async (texts) => {
|
||||||
const ctx = await ensureContext();
|
const ctx = await ensureContext();
|
||||||
const embeddings = await Promise.all(
|
const embeddings = await Promise.all(
|
||||||
texts.map(async (text) => {
|
texts.map(async (text) => {
|
||||||
const embedding = await ctx.getEmbeddingFor(text);
|
const embedding = await ctx.getEmbeddingFor(text);
|
||||||
return Array.from(embedding.vector);
|
return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
return embeddings;
|
return embeddings;
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue