fix: expand SSRF guard coverage

This commit is contained in:
Peter Steinberger 2026-02-02 04:57:09 -08:00
parent c429ccb64f
commit 9bd64c8a1f
14 changed files with 214 additions and 96 deletions

View file

@ -2,6 +2,13 @@
Docs: https://docs.openclaw.ai Docs: https://docs.openclaw.ai
## 2026.2.2
### Fixes
- Security: guard skill installer downloads with SSRF checks (block private/localhost URLs).
- Media understanding: apply SSRF guardrails to provider fetches; allow private baseUrl overrides explicitly.
## 2026.2.1 ## 2026.2.1
### Changes ### Changes

View file

@ -5,6 +5,7 @@ import { Readable } from "node:stream";
import { pipeline } from "node:stream/promises"; import { pipeline } from "node:stream/promises";
import type { OpenClawConfig } from "../config/config.js"; import type { OpenClawConfig } from "../config/config.js";
import { resolveBrewExecutable } from "../infra/brew.js"; import { resolveBrewExecutable } from "../infra/brew.js";
import { fetchWithSsrFGuard } from "../infra/net/fetch-guard.js";
import { runCommandWithTimeout } from "../process/exec.js"; import { runCommandWithTimeout } from "../process/exec.js";
import { CONFIG_DIR, ensureDir, resolveUserPath } from "../utils.js"; import { CONFIG_DIR, ensureDir, resolveUserPath } from "../utils.js";
import { import {
@ -176,10 +177,11 @@ async function downloadFile(
destPath: string, destPath: string,
timeoutMs: number, timeoutMs: number,
): Promise<{ bytes: number }> { ): Promise<{ bytes: number }> {
const controller = new AbortController(); const { response, release } = await fetchWithSsrFGuard({
const timeout = setTimeout(() => controller.abort(), Math.max(1_000, timeoutMs)); url,
timeoutMs: Math.max(1_000, timeoutMs),
});
try { try {
const response = await fetch(url, { signal: controller.signal });
if (!response.ok || !response.body) { if (!response.ok || !response.body) {
throw new Error(`Download failed (${response.status} ${response.statusText})`); throw new Error(`Download failed (${response.status} ${response.statusText})`);
} }
@ -193,7 +195,7 @@ async function downloadFile(
const stat = await fs.promises.stat(destPath); const stat = await fs.promises.stat(destPath);
return { bytes: stat.size }; return { bytes: stat.size };
} finally { } finally {
clearTimeout(timeout); await release();
} }
} }

View file

@ -394,10 +394,12 @@ async function runWebFetch(params: {
url: params.url, url: params.url,
maxRedirects: params.maxRedirects, maxRedirects: params.maxRedirects,
timeoutMs: params.timeoutSeconds * 1000, timeoutMs: params.timeoutSeconds * 1000,
headers: { init: {
Accept: "*/*", headers: {
"User-Agent": params.userAgent, Accept: "*/*",
"Accept-Language": "en-US,en;q=0.9", "User-Agent": params.userAgent,
"Accept-Language": "en-US,en;q=0.9",
},
}, },
}); });
res = result.response; res = result.response;

View file

@ -13,13 +13,13 @@ type FetchLike = (input: RequestInfo | URL, init?: RequestInit) => Promise<Respo
export type GuardedFetchOptions = { export type GuardedFetchOptions = {
url: string; url: string;
fetchImpl?: FetchLike; fetchImpl?: FetchLike;
method?: string; init?: RequestInit;
headers?: HeadersInit;
maxRedirects?: number; maxRedirects?: number;
timeoutMs?: number; timeoutMs?: number;
signal?: AbortSignal; signal?: AbortSignal;
policy?: SsrFPolicy; policy?: SsrFPolicy;
lookupFn?: LookupFn; lookupFn?: LookupFn;
pinDns?: boolean;
}; };
export type GuardedFetchResult = { export type GuardedFetchResult = {
@ -122,13 +122,14 @@ export async function fetchWithSsrFGuard(params: GuardedFetchOptions): Promise<G
policy: params.policy, policy: params.policy,
}) })
: await resolvePinnedHostname(parsedUrl.hostname, params.lookupFn); : await resolvePinnedHostname(parsedUrl.hostname, params.lookupFn);
dispatcher = createPinnedDispatcher(pinned); if (params.pinDns !== false) {
dispatcher = createPinnedDispatcher(pinned);
}
const init: RequestInit & { dispatcher?: Dispatcher } = { const init: RequestInit & { dispatcher?: Dispatcher } = {
...(params.init ? { ...params.init } : {}),
redirect: "manual", redirect: "manual",
dispatcher, ...(dispatcher ? { dispatcher } : {}),
...(params.method ? { method: params.method } : {}),
...(params.headers ? { headers: params.headers } : {}),
...(signal ? { signal } : {}), ...(signal ? { signal } : {}),
}; };

View file

@ -1,6 +1,13 @@
import { describe, expect, it } from "vitest"; import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import * as ssrf from "../../../infra/net/ssrf.js";
import { transcribeDeepgramAudio } from "./audio.js"; import { transcribeDeepgramAudio } from "./audio.js";
const resolvePinnedHostname = ssrf.resolvePinnedHostname;
const resolvePinnedHostnameWithPolicy = ssrf.resolvePinnedHostnameWithPolicy;
const lookupMock = vi.fn();
let resolvePinnedHostnameSpy: ReturnType<typeof vi.spyOn> | null = null;
let resolvePinnedHostnameWithPolicySpy: ReturnType<typeof vi.spyOn> | null = null;
const resolveRequestUrl = (input: RequestInfo | URL) => { const resolveRequestUrl = (input: RequestInfo | URL) => {
if (typeof input === "string") { if (typeof input === "string") {
return input; return input;
@ -12,6 +19,26 @@ const resolveRequestUrl = (input: RequestInfo | URL) => {
}; };
describe("transcribeDeepgramAudio", () => { describe("transcribeDeepgramAudio", () => {
beforeEach(() => {
lookupMock.mockResolvedValue([{ address: "93.184.216.34", family: 4 }]);
resolvePinnedHostnameSpy = vi
.spyOn(ssrf, "resolvePinnedHostname")
.mockImplementation((hostname) => resolvePinnedHostname(hostname, lookupMock));
resolvePinnedHostnameWithPolicySpy = vi
.spyOn(ssrf, "resolvePinnedHostnameWithPolicy")
.mockImplementation((hostname, params) =>
resolvePinnedHostnameWithPolicy(hostname, { ...params, lookupFn: lookupMock }),
);
});
afterEach(() => {
lookupMock.mockReset();
resolvePinnedHostnameSpy?.mockRestore();
resolvePinnedHostnameWithPolicySpy?.mockRestore();
resolvePinnedHostnameSpy = null;
resolvePinnedHostnameWithPolicySpy = null;
});
it("respects lowercase authorization header overrides", async () => { it("respects lowercase authorization header overrides", async () => {
let seenAuth: string | null = null; let seenAuth: string | null = null;
const fetchFn = async (_input: RequestInfo | URL, init?: RequestInit) => { const fetchFn = async (_input: RequestInfo | URL, init?: RequestInit) => {

View file

@ -1,5 +1,5 @@
import type { AudioTranscriptionRequest, AudioTranscriptionResult } from "../../types.js"; import type { AudioTranscriptionRequest, AudioTranscriptionResult } from "../../types.js";
import { fetchWithTimeout, normalizeBaseUrl, readErrorResponse } from "../shared.js"; import { fetchWithTimeoutGuarded, normalizeBaseUrl, readErrorResponse } from "../shared.js";
export const DEFAULT_DEEPGRAM_AUDIO_BASE_URL = "https://api.deepgram.com/v1"; export const DEFAULT_DEEPGRAM_AUDIO_BASE_URL = "https://api.deepgram.com/v1";
export const DEFAULT_DEEPGRAM_AUDIO_MODEL = "nova-3"; export const DEFAULT_DEEPGRAM_AUDIO_MODEL = "nova-3";
@ -24,6 +24,7 @@ export async function transcribeDeepgramAudio(
): Promise<AudioTranscriptionResult> { ): Promise<AudioTranscriptionResult> {
const fetchFn = params.fetchFn ?? fetch; const fetchFn = params.fetchFn ?? fetch;
const baseUrl = normalizeBaseUrl(params.baseUrl, DEFAULT_DEEPGRAM_AUDIO_BASE_URL); const baseUrl = normalizeBaseUrl(params.baseUrl, DEFAULT_DEEPGRAM_AUDIO_BASE_URL);
const allowPrivate = Boolean(params.baseUrl?.trim());
const model = resolveModel(params.model); const model = resolveModel(params.model);
const url = new URL(`${baseUrl}/listen`); const url = new URL(`${baseUrl}/listen`);
@ -49,7 +50,7 @@ export async function transcribeDeepgramAudio(
} }
const body = new Uint8Array(params.buffer); const body = new Uint8Array(params.buffer);
const res = await fetchWithTimeout( const { response: res, release } = await fetchWithTimeoutGuarded(
url.toString(), url.toString(),
{ {
method: "POST", method: "POST",
@ -58,18 +59,23 @@ export async function transcribeDeepgramAudio(
}, },
params.timeoutMs, params.timeoutMs,
fetchFn, fetchFn,
allowPrivate ? { ssrfPolicy: { allowPrivateNetwork: true } } : undefined,
); );
if (!res.ok) { try {
const detail = await readErrorResponse(res); if (!res.ok) {
const suffix = detail ? `: ${detail}` : ""; const detail = await readErrorResponse(res);
throw new Error(`Audio transcription failed (HTTP ${res.status})${suffix}`); const suffix = detail ? `: ${detail}` : "";
} throw new Error(`Audio transcription failed (HTTP ${res.status})${suffix}`);
}
const payload = (await res.json()) as DeepgramTranscriptResponse; const payload = (await res.json()) as DeepgramTranscriptResponse;
const transcript = payload.results?.channels?.[0]?.alternatives?.[0]?.transcript?.trim(); const transcript = payload.results?.channels?.[0]?.alternatives?.[0]?.transcript?.trim();
if (!transcript) { if (!transcript) {
throw new Error("Audio transcription response missing transcript"); throw new Error("Audio transcription response missing transcript");
}
return { text: transcript, model };
} finally {
await release();
} }
return { text: transcript, model };
} }

View file

@ -1,6 +1,6 @@
import type { AudioTranscriptionRequest, AudioTranscriptionResult } from "../../types.js"; import type { AudioTranscriptionRequest, AudioTranscriptionResult } from "../../types.js";
import { normalizeGoogleModelId } from "../../../agents/models-config.providers.js"; import { normalizeGoogleModelId } from "../../../agents/models-config.providers.js";
import { fetchWithTimeout, normalizeBaseUrl, readErrorResponse } from "../shared.js"; import { fetchWithTimeoutGuarded, normalizeBaseUrl, readErrorResponse } from "../shared.js";
export const DEFAULT_GOOGLE_AUDIO_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"; export const DEFAULT_GOOGLE_AUDIO_BASE_URL = "https://generativelanguage.googleapis.com/v1beta";
const DEFAULT_GOOGLE_AUDIO_MODEL = "gemini-3-flash-preview"; const DEFAULT_GOOGLE_AUDIO_MODEL = "gemini-3-flash-preview";
@ -24,6 +24,7 @@ export async function transcribeGeminiAudio(
): Promise<AudioTranscriptionResult> { ): Promise<AudioTranscriptionResult> {
const fetchFn = params.fetchFn ?? fetch; const fetchFn = params.fetchFn ?? fetch;
const baseUrl = normalizeBaseUrl(params.baseUrl, DEFAULT_GOOGLE_AUDIO_BASE_URL); const baseUrl = normalizeBaseUrl(params.baseUrl, DEFAULT_GOOGLE_AUDIO_BASE_URL);
const allowPrivate = Boolean(params.baseUrl?.trim());
const model = resolveModel(params.model); const model = resolveModel(params.model);
const url = `${baseUrl}/models/${model}:generateContent`; const url = `${baseUrl}/models/${model}:generateContent`;
@ -52,7 +53,7 @@ export async function transcribeGeminiAudio(
], ],
}; };
const res = await fetchWithTimeout( const { response: res, release } = await fetchWithTimeoutGuarded(
url, url,
{ {
method: "POST", method: "POST",
@ -61,26 +62,31 @@ export async function transcribeGeminiAudio(
}, },
params.timeoutMs, params.timeoutMs,
fetchFn, fetchFn,
allowPrivate ? { ssrfPolicy: { allowPrivateNetwork: true } } : undefined,
); );
if (!res.ok) { try {
const detail = await readErrorResponse(res); if (!res.ok) {
const suffix = detail ? `: ${detail}` : ""; const detail = await readErrorResponse(res);
throw new Error(`Audio transcription failed (HTTP ${res.status})${suffix}`); const suffix = detail ? `: ${detail}` : "";
} throw new Error(`Audio transcription failed (HTTP ${res.status})${suffix}`);
}
const payload = (await res.json()) as { const payload = (await res.json()) as {
candidates?: Array<{ candidates?: Array<{
content?: { parts?: Array<{ text?: string }> }; content?: { parts?: Array<{ text?: string }> };
}>; }>;
}; };
const parts = payload.candidates?.[0]?.content?.parts ?? []; const parts = payload.candidates?.[0]?.content?.parts ?? [];
const text = parts const text = parts
.map((part) => part?.text?.trim()) .map((part) => part?.text?.trim())
.filter(Boolean) .filter(Boolean)
.join("\n"); .join("\n");
if (!text) { if (!text) {
throw new Error("Audio transcription response missing text"); throw new Error("Audio transcription response missing text");
}
return { text, model };
} finally {
await release();
} }
return { text, model };
} }

View file

@ -1,6 +1,6 @@
import type { VideoDescriptionRequest, VideoDescriptionResult } from "../../types.js"; import type { VideoDescriptionRequest, VideoDescriptionResult } from "../../types.js";
import { normalizeGoogleModelId } from "../../../agents/models-config.providers.js"; import { normalizeGoogleModelId } from "../../../agents/models-config.providers.js";
import { fetchWithTimeout, normalizeBaseUrl, readErrorResponse } from "../shared.js"; import { fetchWithTimeoutGuarded, normalizeBaseUrl, readErrorResponse } from "../shared.js";
export const DEFAULT_GOOGLE_VIDEO_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"; export const DEFAULT_GOOGLE_VIDEO_BASE_URL = "https://generativelanguage.googleapis.com/v1beta";
const DEFAULT_GOOGLE_VIDEO_MODEL = "gemini-3-flash-preview"; const DEFAULT_GOOGLE_VIDEO_MODEL = "gemini-3-flash-preview";
@ -24,6 +24,7 @@ export async function describeGeminiVideo(
): Promise<VideoDescriptionResult> { ): Promise<VideoDescriptionResult> {
const fetchFn = params.fetchFn ?? fetch; const fetchFn = params.fetchFn ?? fetch;
const baseUrl = normalizeBaseUrl(params.baseUrl, DEFAULT_GOOGLE_VIDEO_BASE_URL); const baseUrl = normalizeBaseUrl(params.baseUrl, DEFAULT_GOOGLE_VIDEO_BASE_URL);
const allowPrivate = Boolean(params.baseUrl?.trim());
const model = resolveModel(params.model); const model = resolveModel(params.model);
const url = `${baseUrl}/models/${model}:generateContent`; const url = `${baseUrl}/models/${model}:generateContent`;
@ -52,7 +53,7 @@ export async function describeGeminiVideo(
], ],
}; };
const res = await fetchWithTimeout( const { response: res, release } = await fetchWithTimeoutGuarded(
url, url,
{ {
method: "POST", method: "POST",
@ -61,26 +62,31 @@ export async function describeGeminiVideo(
}, },
params.timeoutMs, params.timeoutMs,
fetchFn, fetchFn,
allowPrivate ? { ssrfPolicy: { allowPrivateNetwork: true } } : undefined,
); );
if (!res.ok) { try {
const detail = await readErrorResponse(res); if (!res.ok) {
const suffix = detail ? `: ${detail}` : ""; const detail = await readErrorResponse(res);
throw new Error(`Video description failed (HTTP ${res.status})${suffix}`); const suffix = detail ? `: ${detail}` : "";
} throw new Error(`Video description failed (HTTP ${res.status})${suffix}`);
}
const payload = (await res.json()) as { const payload = (await res.json()) as {
candidates?: Array<{ candidates?: Array<{
content?: { parts?: Array<{ text?: string }> }; content?: { parts?: Array<{ text?: string }> };
}>; }>;
}; };
const parts = payload.candidates?.[0]?.content?.parts ?? []; const parts = payload.candidates?.[0]?.content?.parts ?? [];
const text = parts const text = parts
.map((part) => part?.text?.trim()) .map((part) => part?.text?.trim())
.filter(Boolean) .filter(Boolean)
.join("\n"); .join("\n");
if (!text) { if (!text) {
throw new Error("Video description response missing text"); throw new Error("Video description response missing text");
}
return { text, model };
} finally {
await release();
} }
return { text, model };
} }

View file

@ -1,6 +1,13 @@
import { describe, expect, it } from "vitest"; import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import * as ssrf from "../../../infra/net/ssrf.js";
import { transcribeOpenAiCompatibleAudio } from "./audio.js"; import { transcribeOpenAiCompatibleAudio } from "./audio.js";
const resolvePinnedHostname = ssrf.resolvePinnedHostname;
const resolvePinnedHostnameWithPolicy = ssrf.resolvePinnedHostnameWithPolicy;
const lookupMock = vi.fn();
let resolvePinnedHostnameSpy: ReturnType<typeof vi.spyOn> | null = null;
let resolvePinnedHostnameWithPolicySpy: ReturnType<typeof vi.spyOn> | null = null;
const resolveRequestUrl = (input: RequestInfo | URL) => { const resolveRequestUrl = (input: RequestInfo | URL) => {
if (typeof input === "string") { if (typeof input === "string") {
return input; return input;
@ -12,6 +19,26 @@ const resolveRequestUrl = (input: RequestInfo | URL) => {
}; };
describe("transcribeOpenAiCompatibleAudio", () => { describe("transcribeOpenAiCompatibleAudio", () => {
beforeEach(() => {
lookupMock.mockResolvedValue([{ address: "93.184.216.34", family: 4 }]);
resolvePinnedHostnameSpy = vi
.spyOn(ssrf, "resolvePinnedHostname")
.mockImplementation((hostname) => resolvePinnedHostname(hostname, lookupMock));
resolvePinnedHostnameWithPolicySpy = vi
.spyOn(ssrf, "resolvePinnedHostnameWithPolicy")
.mockImplementation((hostname, params) =>
resolvePinnedHostnameWithPolicy(hostname, { ...params, lookupFn: lookupMock }),
);
});
afterEach(() => {
lookupMock.mockReset();
resolvePinnedHostnameSpy?.mockRestore();
resolvePinnedHostnameWithPolicySpy?.mockRestore();
resolvePinnedHostnameSpy = null;
resolvePinnedHostnameWithPolicySpy = null;
});
it("respects lowercase authorization header overrides", async () => { it("respects lowercase authorization header overrides", async () => {
let seenAuth: string | null = null; let seenAuth: string | null = null;
const fetchFn = async (_input: RequestInfo | URL, init?: RequestInit) => { const fetchFn = async (_input: RequestInfo | URL, init?: RequestInit) => {

View file

@ -1,6 +1,6 @@
import path from "node:path"; import path from "node:path";
import type { AudioTranscriptionRequest, AudioTranscriptionResult } from "../../types.js"; import type { AudioTranscriptionRequest, AudioTranscriptionResult } from "../../types.js";
import { fetchWithTimeout, normalizeBaseUrl, readErrorResponse } from "../shared.js"; import { fetchWithTimeoutGuarded, normalizeBaseUrl, readErrorResponse } from "../shared.js";
export const DEFAULT_OPENAI_AUDIO_BASE_URL = "https://api.openai.com/v1"; export const DEFAULT_OPENAI_AUDIO_BASE_URL = "https://api.openai.com/v1";
const DEFAULT_OPENAI_AUDIO_MODEL = "gpt-4o-mini-transcribe"; const DEFAULT_OPENAI_AUDIO_MODEL = "gpt-4o-mini-transcribe";
@ -15,6 +15,7 @@ export async function transcribeOpenAiCompatibleAudio(
): Promise<AudioTranscriptionResult> { ): Promise<AudioTranscriptionResult> {
const fetchFn = params.fetchFn ?? fetch; const fetchFn = params.fetchFn ?? fetch;
const baseUrl = normalizeBaseUrl(params.baseUrl, DEFAULT_OPENAI_AUDIO_BASE_URL); const baseUrl = normalizeBaseUrl(params.baseUrl, DEFAULT_OPENAI_AUDIO_BASE_URL);
const allowPrivate = Boolean(params.baseUrl?.trim());
const url = `${baseUrl}/audio/transcriptions`; const url = `${baseUrl}/audio/transcriptions`;
const model = resolveModel(params.model); const model = resolveModel(params.model);
@ -38,7 +39,7 @@ export async function transcribeOpenAiCompatibleAudio(
headers.set("authorization", `Bearer ${params.apiKey}`); headers.set("authorization", `Bearer ${params.apiKey}`);
} }
const res = await fetchWithTimeout( const { response: res, release } = await fetchWithTimeoutGuarded(
url, url,
{ {
method: "POST", method: "POST",
@ -47,18 +48,23 @@ export async function transcribeOpenAiCompatibleAudio(
}, },
params.timeoutMs, params.timeoutMs,
fetchFn, fetchFn,
allowPrivate ? { ssrfPolicy: { allowPrivateNetwork: true } } : undefined,
); );
if (!res.ok) { try {
const detail = await readErrorResponse(res); if (!res.ok) {
const suffix = detail ? `: ${detail}` : ""; const detail = await readErrorResponse(res);
throw new Error(`Audio transcription failed (HTTP ${res.status})${suffix}`); const suffix = detail ? `: ${detail}` : "";
} throw new Error(`Audio transcription failed (HTTP ${res.status})${suffix}`);
}
const payload = (await res.json()) as { text?: string }; const payload = (await res.json()) as { text?: string };
const text = payload.text?.trim(); const text = payload.text?.trim();
if (!text) { if (!text) {
throw new Error("Audio transcription response missing text"); throw new Error("Audio transcription response missing text");
}
return { text, model };
} finally {
await release();
} }
return { text, model };
} }

View file

@ -1,3 +1,7 @@
import type { GuardedFetchResult } from "../../infra/net/fetch-guard.js";
import type { LookupFn, SsrFPolicy } from "../../infra/net/ssrf.js";
import { fetchWithSsrFGuard } from "../../infra/net/fetch-guard.js";
const MAX_ERROR_CHARS = 300; const MAX_ERROR_CHARS = 300;
export function normalizeBaseUrl(baseUrl: string | undefined, fallback: string): string { export function normalizeBaseUrl(baseUrl: string | undefined, fallback: string): string {
@ -20,6 +24,28 @@ export async function fetchWithTimeout(
} }
} }
export async function fetchWithTimeoutGuarded(
url: string,
init: RequestInit,
timeoutMs: number,
fetchFn: typeof fetch,
options?: {
ssrfPolicy?: SsrFPolicy;
lookupFn?: LookupFn;
pinDns?: boolean;
},
): Promise<GuardedFetchResult> {
return await fetchWithSsrFGuard({
url,
fetchImpl: fetchFn,
init,
timeoutMs,
policy: options?.ssrfPolicy,
lookupFn: options?.lookupFn,
pinDns: options?.pinDns,
});
}
export async function readErrorResponse(res: Response): Promise<string | undefined> { export async function readErrorResponse(res: Response): Promise<string | undefined> {
try { try {
const text = await res.text(); const text = await res.text();

View file

@ -146,7 +146,7 @@ export async function fetchWithGuard(params: {
url: params.url, url: params.url,
maxRedirects: params.maxRedirects, maxRedirects: params.maxRedirects,
timeoutMs: params.timeoutMs, timeoutMs: params.timeoutMs,
headers: { "User-Agent": "OpenClaw-Gateway/1.0" }, init: { headers: { "User-Agent": "OpenClaw-Gateway/1.0" } },
}); });
try { try {

View file

@ -54,8 +54,7 @@ function resolveRequestUrl(input: RequestInfo | URL): string {
if ("url" in input && typeof input.url === "string") { if ("url" in input && typeof input.url === "string") {
return input.url; return input.url;
} }
throw new Error("Unsupported fetch input: expected string, URL, or Request");
throw new Error(`Unable to resolve request URL from input: ${JSON.stringify(input, null, 2)}`);
} }
function createSlackMediaFetch(token: string): FetchLike { function createSlackMediaFetch(token: string): FetchLike {

View file

@ -11,7 +11,9 @@ const sendChatActionSpy = vi.fn();
const cacheStickerSpy = vi.fn(); const cacheStickerSpy = vi.fn();
const getCachedStickerSpy = vi.fn(); const getCachedStickerSpy = vi.fn();
const describeStickerImageSpy = vi.fn(); const describeStickerImageSpy = vi.fn();
const ssrfResolveSpy = vi.spyOn(ssrf, "resolvePinnedHostname"); const resolvePinnedHostname = ssrf.resolvePinnedHostname;
const lookupMock = vi.fn();
let resolvePinnedHostnameSpy: ReturnType<typeof vi.spyOn> | null = null;
type ApiStub = { type ApiStub = {
config: { use: (arg: unknown) => void }; config: { use: (arg: unknown) => void };
@ -28,15 +30,16 @@ const apiStub: ApiStub = {
beforeEach(() => { beforeEach(() => {
vi.useRealTimers(); vi.useRealTimers();
resetInboundDedupe(); resetInboundDedupe();
ssrfResolveSpy.mockImplementation(async (hostname) => { lookupMock.mockResolvedValue([{ address: "93.184.216.34", family: 4 }]);
const normalized = hostname.trim().toLowerCase().replace(/\.$/, ""); resolvePinnedHostnameSpy = vi
const addresses = ["93.184.216.34"]; .spyOn(ssrf, "resolvePinnedHostname")
return { .mockImplementation((hostname) => resolvePinnedHostname(hostname, lookupMock));
hostname: normalized, });
addresses,
lookup: ssrf.createPinnedLookup({ hostname: normalized, addresses }), afterEach(() => {
}; lookupMock.mockReset();
}); resolvePinnedHostnameSpy?.mockRestore();
resolvePinnedHostnameSpy = null;
}); });
vi.mock("grammy", () => ({ vi.mock("grammy", () => ({
@ -169,7 +172,7 @@ describe("telegram inbound media", () => {
expect(runtimeError).not.toHaveBeenCalled(); expect(runtimeError).not.toHaveBeenCalled();
expect(fetchSpy).toHaveBeenCalledWith( expect(fetchSpy).toHaveBeenCalledWith(
"https://api.telegram.org/file/bottok/photos/1.jpg", "https://api.telegram.org/file/bottok/photos/1.jpg",
expect.any(Object), expect.objectContaining({ redirect: "manual" }),
); );
expect(replySpy).toHaveBeenCalledTimes(1); expect(replySpy).toHaveBeenCalledTimes(1);
const payload = replySpy.mock.calls[0][0]; const payload = replySpy.mock.calls[0][0];
@ -227,7 +230,7 @@ describe("telegram inbound media", () => {
expect(runtimeError).not.toHaveBeenCalled(); expect(runtimeError).not.toHaveBeenCalled();
expect(proxyFetch).toHaveBeenCalledWith( expect(proxyFetch).toHaveBeenCalledWith(
"https://api.telegram.org/file/bottok/photos/2.jpg", "https://api.telegram.org/file/bottok/photos/2.jpg",
expect.any(Object), expect.objectContaining({ redirect: "manual" }),
); );
globalFetchSpy.mockRestore(); globalFetchSpy.mockRestore();
@ -501,7 +504,7 @@ describe("telegram stickers", () => {
expect(runtimeError).not.toHaveBeenCalled(); expect(runtimeError).not.toHaveBeenCalled();
expect(fetchSpy).toHaveBeenCalledWith( expect(fetchSpy).toHaveBeenCalledWith(
"https://api.telegram.org/file/bottok/stickers/sticker.webp", "https://api.telegram.org/file/bottok/stickers/sticker.webp",
expect.any(Object), expect.objectContaining({ redirect: "manual" }),
); );
expect(replySpy).toHaveBeenCalledTimes(1); expect(replySpy).toHaveBeenCalledTimes(1);
const payload = replySpy.mock.calls[0][0]; const payload = replySpy.mock.calls[0][0];