From 6c6f1e9660bd81424a25a9a4d1a9c66b9bae5a6c Mon Sep 17 00:00:00 2001 From: Ryan Nelson Date: Sun, 1 Feb 2026 14:49:14 -0800 Subject: [PATCH] Fix missing before_tool_call hook integration (#6570) * Fix missing before_tool_call hook integration - Add hook call in handleToolExecutionStart before tool execution begins - Support parameter modification via hookResult.params - Support tool call blocking via hookResult.block with custom blockReason - Fix try/catch logic to properly re-throw blocking errors using __isHookBlocking flag - Maintain tool event consistency by emitting start/end events when blocked - Addresses GitHub issue #6535 (1 of 8 unimplemented hooks now working) Co-Authored-By: Claude Sonnet 4 * Add comprehensive test suite for before_tool_call hook - 9 tests covering all hook scenarios: no hooks, parameter passing, modification, blocking, error handling - Tests tool name normalization and different argument types - Verifies proper error re-throwing and logging behavior - Maintained in fork for regression testing * Fix all issues identified by Greptile code review Address P0/P1/P3 bugs: P0 - Fix parameter mutation crash for non-object args: - Normalize args to objects before passing to hooks (maintains hook contract) - Handle parameter merging safely for both object and non-object args P1 - Add missing internal state updates when blocking tools: - Set toolMetaById metadata like normal flow - Call onAgentEvent callback to maintain consistency - Emit events in same order as normal tool execution P1 - Fix test expectations to match implementation reality: - Non-object args normalized to {} for hook params (not passed as-is) - Add test for safe parameter modification with various arg types - Update mocks to verify state updates when blocking P3 - Replace magic __isHookBlocking property with dedicated ToolBlockedError class: - More robust error handling without property collision risk - Cleaner control flow that's serialization-safe Co-Authored-By: Claude Sonnet 4 --------- Co-authored-by: Claude Sonnet 4 --- ...ndlers.tools.before-tool-call-hook.test.ts | 351 ++++++++++++++++++ .../pi-embedded-subscribe.handlers.tools.ts | 98 ++++- 2 files changed, 448 insertions(+), 1 deletion(-) create mode 100644 src/agents/pi-embedded-subscribe.handlers.tools.before-tool-call-hook.test.ts diff --git a/src/agents/pi-embedded-subscribe.handlers.tools.before-tool-call-hook.test.ts b/src/agents/pi-embedded-subscribe.handlers.tools.before-tool-call-hook.test.ts new file mode 100644 index 000000000..02e93c964 --- /dev/null +++ b/src/agents/pi-embedded-subscribe.handlers.tools.before-tool-call-hook.test.ts @@ -0,0 +1,351 @@ +import type { AgentEvent } from "@mariozechner/pi-agent-core"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import type { EmbeddedPiSubscribeContext } from "./pi-embedded-subscribe.handlers.types.js"; +import { getGlobalHookRunner } from "../plugins/hook-runner-global.js"; +import { handleToolExecutionStart } from "./pi-embedded-subscribe.handlers.tools.js"; + +// Mock dependencies +vi.mock("../plugins/hook-runner-global.js"); +vi.mock("../infra/agent-events.js", () => ({ + emitAgentEvent: vi.fn(), +})); +vi.mock("./pi-embedded-helpers.js"); +vi.mock("./pi-embedded-messaging.js"); +vi.mock("./pi-embedded-subscribe.tools.js"); +vi.mock("./pi-embedded-utils.js", () => ({ + inferToolMetaFromArgs: vi.fn(() => undefined), +})); +vi.mock("./tool-policy.js", () => ({ + normalizeToolName: vi.fn((name: string) => name.toLowerCase()), +})); + +const mockGetGlobalHookRunner = vi.mocked(getGlobalHookRunner); + +describe("before_tool_call hook integration", () => { + let mockContext: EmbeddedPiSubscribeContext; + let mockHookRunner: any; + + beforeEach(() => { + // Reset mocks + vi.clearAllMocks(); + + // Mock context + mockContext = { + params: { + runId: "test-run-123", + session: { key: "test-session" }, + onBlockReplyFlush: vi.fn(), + onAgentEvent: vi.fn(), + }, + state: { + toolMetaById: { + set: vi.fn(), + get: vi.fn(), + has: vi.fn(), + }, + }, + log: { + debug: vi.fn(), + warn: vi.fn(), + }, + flushBlockReplyBuffer: vi.fn(), + shouldEmitToolResult: vi.fn().mockReturnValue(true), + } as any; + + // Mock hook runner + mockHookRunner = { + hasHooks: vi.fn(), + runBeforeToolCall: vi.fn(), + }; + + mockGetGlobalHookRunner.mockReturnValue(mockHookRunner); + }); + + describe("when no hooks are registered", () => { + beforeEach(() => { + mockHookRunner.hasHooks.mockReturnValue(false); + }); + + it("should proceed with tool execution normally", async () => { + const event: AgentEvent & { toolName: string; toolCallId: string; args: unknown } = { + type: "tool_start", + toolName: "TestTool", + toolCallId: "tool-call-123", + args: { param: "value" }, + }; + + // Should not throw + await expect(handleToolExecutionStart(mockContext, event)).resolves.toBeUndefined(); + + // Hook runner should check for hooks but not run them + expect(mockHookRunner.hasHooks).toHaveBeenCalledWith("before_tool_call"); + expect(mockHookRunner.runBeforeToolCall).not.toHaveBeenCalled(); + }); + }); + + describe("when hooks are registered", () => { + beforeEach(() => { + mockHookRunner.hasHooks.mockReturnValue(true); + }); + + it("should call the hook with correct parameters", async () => { + mockHookRunner.runBeforeToolCall.mockResolvedValue(undefined); + + const event: AgentEvent & { toolName: string; toolCallId: string; args: unknown } = { + type: "tool_start", + toolName: "TestTool", + toolCallId: "tool-call-123", + args: { param: "value" }, + }; + + await handleToolExecutionStart(mockContext, event); + + expect(mockHookRunner.runBeforeToolCall).toHaveBeenCalledWith( + { + toolName: "testtool", // normalized + params: { param: "value" }, + }, + { + toolName: "testtool", + }, + ); + }); + + it("should allow hook to modify parameters", async () => { + const modifiedParams = { param: "modified_value", newParam: "added" }; + mockHookRunner.runBeforeToolCall.mockResolvedValue({ + params: modifiedParams, + }); + + const event: AgentEvent & { toolName: string; toolCallId: string; args: unknown } = { + type: "tool_start", + toolName: "TestTool", + toolCallId: "tool-call-123", + args: { param: "value" }, + }; + + // The function should complete without error + await expect(handleToolExecutionStart(mockContext, event)).resolves.toBeUndefined(); + + expect(mockHookRunner.runBeforeToolCall).toHaveBeenCalledWith( + { + toolName: "testtool", + params: { param: "value" }, + }, + { + toolName: "testtool", + }, + ); + + // Hook should be called and parameter modification should work + expect(mockHookRunner.runBeforeToolCall).toHaveBeenCalled(); + }); + + it("should handle parameter modification with non-object args safely", async () => { + const modifiedParams = { newParam: "replaced" }; + mockHookRunner.runBeforeToolCall.mockResolvedValue({ + params: modifiedParams, + }); + + const testCases = [ + { args: null, description: "null args" }, + { args: "string", description: "string args" }, + { args: 123, description: "number args" }, + { args: [1, 2, 3], description: "array args" }, + ]; + + for (const { args, description } of testCases) { + mockHookRunner.runBeforeToolCall.mockClear(); + + const event: AgentEvent & { toolName: string; toolCallId: string; args: unknown } = { + type: "tool_start", + toolName: "TestTool", + toolCallId: `call-${description}`, + args, + }; + + // Should not crash even with non-object args + await expect(handleToolExecutionStart(mockContext, event)).resolves.toBeUndefined(); + + // Hook should be called with normalized empty params + expect(mockHookRunner.runBeforeToolCall).toHaveBeenCalledWith( + { + toolName: "testtool", + params: {}, // Non-objects normalized to empty object + }, + { + toolName: "testtool", + }, + ); + } + }); + + it("should block tool call when hook returns block=true", async () => { + const blockReason = "Tool blocked by security policy"; + const mockResult = { + block: true, + blockReason, + }; + + mockHookRunner.runBeforeToolCall.mockResolvedValue(mockResult); + + const event: AgentEvent & { toolName: string; toolCallId: string; args: unknown } = { + type: "tool_start", + toolName: "BlockedTool", + toolCallId: "tool-call-456", + args: { dangerous: "payload" }, + }; + + // Should throw an error with the block reason + await expect(handleToolExecutionStart(mockContext, event)).rejects.toThrow(blockReason); + + // Should log the block + expect(mockContext.log.debug).toHaveBeenCalledWith( + expect.stringContaining("Tool call blocked by plugin hook"), + ); + expect(mockContext.log.debug).toHaveBeenCalledWith(expect.stringContaining(blockReason)); + + // Should update internal state like normal tool flow + expect(mockContext.state.toolMetaById.set).toHaveBeenCalled(); + expect(mockContext.params.onAgentEvent).toHaveBeenCalledWith({ + stream: "tool", + data: { phase: "start", name: "blockedtool", toolCallId: "tool-call-456" }, + }); + }); + + it("should block tool call with default reason when no blockReason provided", async () => { + mockHookRunner.runBeforeToolCall.mockResolvedValue({ + block: true, + // no blockReason + }); + + const event: AgentEvent & { toolName: string; toolCallId: string; args: unknown } = { + type: "tool_start", + toolName: "BlockedTool", + toolCallId: "tool-call-789", + args: {}, + }; + + // Should throw with default message + await expect(handleToolExecutionStart(mockContext, event)).rejects.toThrow( + "Tool call blocked by plugin hook", + ); + }); + + it("should handle hook errors gracefully and continue execution", async () => { + const hookError = new Error("Hook implementation error"); + mockHookRunner.runBeforeToolCall.mockRejectedValue(hookError); + + const event: AgentEvent & { toolName: string; toolCallId: string; args: unknown } = { + type: "tool_start", + toolName: "TestTool", + toolCallId: "tool-call-999", + args: { param: "value" }, + }; + + // Should not throw - hook errors should be caught + await expect(handleToolExecutionStart(mockContext, event)).resolves.toBeUndefined(); + + // Should log the hook error + expect(mockContext.log.warn).toHaveBeenCalledWith( + expect.stringContaining("before_tool_call hook failed"), + ); + expect(mockContext.log.warn).toHaveBeenCalledWith( + expect.stringContaining("Hook implementation error"), + ); + }); + + it("should re-throw blocking errors even when caught", async () => { + const blockReason = "Blocked by security"; + mockHookRunner.runBeforeToolCall.mockResolvedValue({ + block: true, + blockReason, + }); + + const event: AgentEvent & { toolName: string; toolCallId: string; args: unknown } = { + type: "tool_start", + toolName: "TestTool", + toolCallId: "tool-call-000", + args: {}, + }; + + // The blocking error should still be thrown + await expect(handleToolExecutionStart(mockContext, event)).rejects.toThrow(blockReason); + }); + }); + + describe("hook context handling", () => { + beforeEach(() => { + mockHookRunner.hasHooks.mockReturnValue(true); + mockHookRunner.runBeforeToolCall.mockResolvedValue(undefined); + }); + + it("should handle various tool name formats", async () => { + const testCases = [ + { input: "ReadFile", expected: "readfile" }, + { input: "EXEC", expected: "exec" }, + { input: "bash-command", expected: "bash-command" }, + { input: " SpacedTool ", expected: " spacedtool " }, + ]; + + for (const { input, expected } of testCases) { + mockHookRunner.runBeforeToolCall.mockClear(); + + const event: AgentEvent & { toolName: string; toolCallId: string; args: unknown } = { + type: "tool_start", + toolName: input, + toolCallId: `call-${input}`, + args: {}, + }; + + await handleToolExecutionStart(mockContext, event); + + expect(mockHookRunner.runBeforeToolCall).toHaveBeenCalledWith( + { + toolName: expected, + params: {}, + }, + { + toolName: expected, + }, + ); + } + }); + + it("should handle different argument types", async () => { + const testCases = [ + // Non-objects get normalized to {} for hook params (to maintain hook contract) + { args: null, expectedParams: {} }, + { args: undefined, expectedParams: {} }, + { args: "string", expectedParams: {} }, + { args: 123, expectedParams: {} }, + { args: [1, 2, 3], expectedParams: {} }, // arrays are not plain objects + // Only plain objects are passed through + { args: { key: "value" }, expectedParams: { key: "value" } }, + ]; + + for (const { args, expectedParams } of testCases) { + mockHookRunner.runBeforeToolCall.mockClear(); + + const event: AgentEvent & { toolName: string; toolCallId: string; args: unknown } = { + type: "tool_start", + toolName: "TestTool", + toolCallId: `call-${typeof args}`, + args, + }; + + await handleToolExecutionStart(mockContext, event); + + expect(mockHookRunner.runBeforeToolCall).toHaveBeenCalledWith( + { + toolName: "testtool", + params: expectedParams, + }, + { + toolName: "testtool", + }, + ); + } + }); + }); +}); diff --git a/src/agents/pi-embedded-subscribe.handlers.tools.ts b/src/agents/pi-embedded-subscribe.handlers.tools.ts index 39dc8d8fa..b73109c72 100644 --- a/src/agents/pi-embedded-subscribe.handlers.tools.ts +++ b/src/agents/pi-embedded-subscribe.handlers.tools.ts @@ -1,7 +1,16 @@ import type { AgentEvent } from "@mariozechner/pi-agent-core"; import type { EmbeddedPiSubscribeContext } from "./pi-embedded-subscribe.handlers.types.js"; import { emitAgentEvent } from "../infra/agent-events.js"; +import { getGlobalHookRunner } from "../plugins/hook-runner-global.js"; import { normalizeTextForComparison } from "./pi-embedded-helpers.js"; + +// Dedicated error class for hook blocking to avoid magic property issues +class ToolBlockedError extends Error { + constructor(message: string) { + super(message); + this.name = "ToolBlockedError"; + } +} import { isMessagingTool, isMessagingToolSendAction } from "./pi-embedded-messaging.js"; import { extractToolErrorMessage, @@ -49,7 +58,94 @@ export async function handleToolExecutionStart( const rawToolName = String(evt.toolName); const toolName = normalizeToolName(rawToolName); const toolCallId = String(evt.toolCallId); - const args = evt.args; + let args = evt.args; + + // Run before_tool_call hook - allows plugins to modify or block tool calls + const hookRunner = getGlobalHookRunner(); + if (hookRunner?.hasHooks("before_tool_call")) { + try { + // Normalize args to object for hook contract - plugins expect params to be an object + const normalizedParams = + args && typeof args === "object" && !Array.isArray(args) + ? (args as Record) + : {}; + + const hookResult = await hookRunner.runBeforeToolCall( + { + toolName, + params: normalizedParams, + }, + { + toolName, + }, + ); + + // Check if hook blocked the tool call + if (hookResult?.block) { + const blockReason = hookResult.blockReason || "Tool call blocked by plugin hook"; + + // Update internal state to match normal tool execution flow + const meta = extendExecMeta(toolName, args, inferToolMetaFromArgs(toolName, args)); + ctx.state.toolMetaById.set(toolCallId, meta); + + ctx.log.debug( + `Tool call blocked by plugin hook: runId=${ctx.params.runId} tool=${toolName} toolCallId=${toolCallId} reason=${blockReason}`, + ); + + // Emit tool start/end events with error to maintain event consistency + emitAgentEvent({ + runId: ctx.params.runId, + stream: "tool", + data: { + phase: "start", + name: toolName, + toolCallId, + args: args as Record, + }, + }); + + // Call onAgentEvent callback to match normal flow + void ctx.params.onAgentEvent?.({ + stream: "tool", + data: { phase: "start", name: toolName, toolCallId }, + }); + + emitAgentEvent({ + runId: ctx.params.runId, + stream: "tool", + data: { + phase: "end", + name: toolName, + toolCallId, + error: blockReason, + }, + }); + + // Throw dedicated error class instead of using magic properties + throw new ToolBlockedError(blockReason); + } + + // If hook modified params, update args safely + if (hookResult?.params) { + if (args && typeof args === "object" && !Array.isArray(args)) { + // Safe to merge with existing object args + args = { ...(args as Record), ...hookResult.params }; + } else { + // For non-object args, replace entirely with hook params + args = hookResult.params; + } + } + } catch (err) { + // If it's our blocking error, re-throw it + if (err instanceof ToolBlockedError) { + throw err; + } + // For other hook errors, log but don't block the tool call + ctx.log.warn( + `before_tool_call hook failed: runId=${ctx.params.runId} tool=${toolName} toolCallId=${toolCallId} error=${String(err)}`, + ); + } + } if (toolName === "read") { const record = args && typeof args === "object" ? (args as Record) : {};