fix: L2-normalize local embedding vectors to fix semantic search (#5332)

* fix: L2-normalize local embedding vectors to fix semantic search

* fix: handle non-finite magnitude in L2 normalization and remove stale test reset

* refactor: add braces to l2Normalize guard clause in embeddings

* fix: sanitize local embeddings (#5332) (thanks @akramcodez)

---------

Co-authored-by: Gustavo Madeira Santana <gumadeiras@gmail.com>
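
Why normalization matters here: cosine similarity is the dot product divided by the product of the two magnitudes, so a vector store that ranks by raw dot product conflates magnitude with relevance unless every embedding is unit-length. A minimal TypeScript sketch of that relationship (illustration only; the `dot`, `norm`, and `cosine` helpers below are hypothetical and not part of this diff):

// Illustration only (hypothetical helpers, not part of this commit):
// cosine(a, b) = dot(a, b) / (|a| * |b|). For unit-length vectors the
// denominator is 1, so ranking by raw dot product is only correct when
// embeddings are L2-normalized first.
function dot(a: number[], b: number[]): number {
  return a.reduce((sum, value, i) => sum + value * b[i], 0);
}

function norm(a: number[]): number {
  return Math.sqrt(dot(a, a));
}

function cosine(a: number[], b: number[]): number {
  return dot(a, b) / (norm(a) * norm(b));
}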

@@ -326,3 +326,157 @@ describe("embedding provider local fallback", () => {
     ).rejects.toThrow(/optional dependency node-llama-cpp/i);
   });
 });
+
+describe("local embedding normalization", () => {
+  afterEach(() => {
+    vi.resetAllMocks();
+    vi.resetModules();
+    vi.unstubAllGlobals();
+    vi.doUnmock("./node-llama.js");
+  });
+
+  it("normalizes local embeddings to magnitude ~1.0", async () => {
+    const unnormalizedVector = [2.35, 3.45, 0.63, 4.3, 1.2, 5.1, 2.8, 3.9];
+    vi.doMock("./node-llama.js", () => ({
+      importNodeLlamaCpp: async () => ({
+        getLlama: async () => ({
+          loadModel: vi.fn().mockResolvedValue({
+            createEmbeddingContext: vi.fn().mockResolvedValue({
+              getEmbeddingFor: vi.fn().mockResolvedValue({
+                vector: new Float32Array(unnormalizedVector),
+              }),
+            }),
+          }),
+        }),
+        resolveModelFile: async () => "/fake/model.gguf",
+        LlamaLogLevel: { error: 0 },
+      }),
+    }));
+
+    const { createEmbeddingProvider } = await import("./embeddings.js");
+    const result = await createEmbeddingProvider({
+      config: {} as never,
+      provider: "local",
+      model: "",
+      fallback: "none",
+    });
+
+    const embedding = await result.provider.embedQuery("test query");
+    const magnitude = Math.sqrt(embedding.reduce((sum, x) => sum + x * x, 0));
+    expect(magnitude).toBeCloseTo(1.0, 5);
+  });
+
+  it("handles zero vector without division by zero", async () => {
+    const zeroVector = [0, 0, 0, 0];
+    vi.doMock("./node-llama.js", () => ({
+      importNodeLlamaCpp: async () => ({
+        getLlama: async () => ({
+          loadModel: vi.fn().mockResolvedValue({
+            createEmbeddingContext: vi.fn().mockResolvedValue({
+              getEmbeddingFor: vi.fn().mockResolvedValue({
+                vector: new Float32Array(zeroVector),
+              }),
+            }),
+          }),
+        }),
+        resolveModelFile: async () => "/fake/model.gguf",
+        LlamaLogLevel: { error: 0 },
+      }),
+    }));
+
+    const { createEmbeddingProvider } = await import("./embeddings.js");
+    const result = await createEmbeddingProvider({
+      config: {} as never,
+      provider: "local",
+      model: "",
+      fallback: "none",
+    });
+
+    const embedding = await result.provider.embedQuery("test");
+    expect(embedding).toEqual([0, 0, 0, 0]);
+    expect(embedding.every((value) => Number.isFinite(value))).toBe(true);
+  });
+
+  it("sanitizes non-finite values before normalization", async () => {
+    const nonFiniteVector = [1, Number.NaN, Number.POSITIVE_INFINITY, Number.NEGATIVE_INFINITY];
+    vi.doMock("./node-llama.js", () => ({
+      importNodeLlamaCpp: async () => ({
+        getLlama: async () => ({
+          loadModel: vi.fn().mockResolvedValue({
+            createEmbeddingContext: vi.fn().mockResolvedValue({
+              getEmbeddingFor: vi.fn().mockResolvedValue({
+                vector: new Float32Array(nonFiniteVector),
+              }),
+            }),
+          }),
+        }),
+        resolveModelFile: async () => "/fake/model.gguf",
+        LlamaLogLevel: { error: 0 },
+      }),
+    }));
+
+    const { createEmbeddingProvider } = await import("./embeddings.js");
+    const result = await createEmbeddingProvider({
+      config: {} as never,
+      provider: "local",
+      model: "",
+      fallback: "none",
+    });
+
+    const embedding = await result.provider.embedQuery("test");
+    expect(embedding).toEqual([1, 0, 0, 0]);
+    expect(embedding.every((value) => Number.isFinite(value))).toBe(true);
+  });
+
+  it("normalizes batch embeddings to magnitude ~1.0", async () => {
+    const unnormalizedVectors = [
+      [2.35, 3.45, 0.63, 4.3],
+      [10.0, 0.0, 0.0, 0.0],
+      [1.0, 1.0, 1.0, 1.0],
+    ];
+    vi.doMock("./node-llama.js", () => ({
+      importNodeLlamaCpp: async () => ({
+        getLlama: async () => ({
+          loadModel: vi.fn().mockResolvedValue({
+            createEmbeddingContext: vi.fn().mockResolvedValue({
+              getEmbeddingFor: vi
+                .fn()
+                .mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[0]) })
+                .mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[1]) })
+                .mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[2]) }),
+            }),
+          }),
+        }),
+        resolveModelFile: async () => "/fake/model.gguf",
+        LlamaLogLevel: { error: 0 },
+      }),
+    }));
+
+    const { createEmbeddingProvider } = await import("./embeddings.js");
+    const result = await createEmbeddingProvider({
+      config: {} as never,
+      provider: "local",
+      model: "",
+      fallback: "none",
+    });
+
+    const embeddings = await result.provider.embedBatch(["text1", "text2", "text3"]);
+    for (const embedding of embeddings) {
+      const magnitude = Math.sqrt(embedding.reduce((sum, x) => sum + x * x, 0));
+      expect(magnitude).toBeCloseTo(1.0, 5);
+    }
+  });
+});

@@ -6,6 +6,15 @@ import { createGeminiEmbeddingProvider, type GeminiEmbeddingClient } from "./embeddings-gemini.js";
 import { createOpenAiEmbeddingProvider, type OpenAiEmbeddingClient } from "./embeddings-openai.js";
 import { importNodeLlamaCpp } from "./node-llama.js";
 
+function sanitizeAndNormalizeEmbedding(vec: number[]): number[] {
+  const sanitized = vec.map((value) => (Number.isFinite(value) ? value : 0));
+  const magnitude = Math.sqrt(sanitized.reduce((sum, value) => sum + value * value, 0));
+  if (magnitude < 1e-10) {
+    return sanitized;
+  }
+  return sanitized.map((value) => value / magnitude);
+}
+
 export type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
 export type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
@@ -98,14 +107,14 @@ async function createLocalEmbeddingProvider(
     embedQuery: async (text) => {
       const ctx = await ensureContext();
       const embedding = await ctx.getEmbeddingFor(text);
-      return Array.from(embedding.vector);
+      return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
     },
     embedBatch: async (texts) => {
      const ctx = await ensureContext();
       const embeddings = await Promise.all(
         texts.map(async (text) => {
           const embedding = await ctx.getEmbeddingFor(text);
-          return Array.from(embedding.vector);
+          return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
         }),
       );
       return embeddings;
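
For reference, a sketch of how the helper behaves on a pathological input (hypothetical standalone usage; `sanitizeAndNormalizeEmbedding` is module-private in this diff, so this assumes it were exported):

// Hypothetical usage; assumes sanitizeAndNormalizeEmbedding were exported
// from ./embeddings.js (in this diff it is module-private).
import { sanitizeAndNormalizeEmbedding } from "./embeddings.js";

const raw = [2.35, Number.NaN, 0.63, Number.POSITIVE_INFINITY];
const unit = sanitizeAndNormalizeEmbedding(raw);

// Non-finite entries are zeroed before the magnitude is computed, so the
// result is finite and unit-length: sqrt(2.35^2 + 0.63^2) ≈ 2.433 divides out.
const magnitude = Math.sqrt(unit.reduce((sum, value) => sum + value * value, 0));
console.log(magnitude.toFixed(6)); // "1.000000"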