Skip to content

pydantic_ai.tools

AgentDeps module-attribute

AgentDeps = TypeVar('AgentDeps')

Type variable for agent dependencies.

RunContext dataclass

Bases: Generic[AgentDeps]

Information about the current call.

Source code in pydantic_ai_slim/pydantic_ai/tools.py
37
38
39
40
41
42
43
44
45
46
@dataclass
class RunContext(Generic[AgentDeps]):
    """Information about the current call.

    `Tool._call_args` builds one positionally as `RunContext(deps, retry, tool_name)` and passes it as the
    first argument to tool functions registered with `takes_ctx=True`, so the field order below matters.
    """

    deps: AgentDeps
    """Dependencies for the agent."""
    retry: int  # the tool's `_current_retry` value at the time of the call
    """Number of retries so far."""
    tool_name: str | None  # taken from the `ToolCall` message being handled
    """Name of the tool being called."""

deps instance-attribute

deps: AgentDeps

Dependencies for the agent.

retry instance-attribute

retry: int

Number of retries so far.

tool_name instance-attribute

tool_name: str | None

Name of the tool being called.

ToolParams module-attribute

ToolParams = ParamSpec('ToolParams')

Parameter specification (ParamSpec) of a tool function.

SystemPromptFunc module-attribute

A function that may or may not take RunContext as an argument, and may or may not be async.

Usage SystemPromptFunc[AgentDeps].

ResultValidatorFunc module-attribute

A function that always takes ResultData and returns ResultData, but may or may not take RunContext as a first argument, and may or may not be async.

Usage ResultValidatorFunc[AgentDeps, ResultData].

ToolFuncContext module-attribute

A tool function that takes RunContext as the first argument.

Usage ToolFuncContext[AgentDeps, ToolParams].

ToolFuncPlain module-attribute

ToolFuncPlain = Callable[ToolParams, Any]

A tool function that does not take RunContext as the first argument.

Usage ToolFuncPlain[ToolParams].

ToolFuncEither module-attribute

Either kind of tool function.

This is just a union of ToolFuncContext and ToolFuncPlain.

Usage ToolFuncEither[AgentDeps, ToolParams].

Tool dataclass

Bases: Generic[AgentDeps]

A tool function for an agent.

Source code in pydantic_ai_slim/pydantic_ai/tools.py
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
@final
@dataclass(init=False)
class Tool(Generic[AgentDeps]):
    """A tool function for an agent.

    Wraps a Python function together with the validator and JSON schema generated for its arguments, and
    counts failed calls (argument validation errors or `ModelRetry`) since the last success so the model
    can be asked to retry up to `max_retries` times.
    """

    function: ToolFuncEither[AgentDeps, ...]
    takes_ctx: bool
    max_retries: int | None
    name: str
    description: str
    # Private state derived from `function` by `__init__`.
    _is_async: bool = field(init=False)
    _single_arg_name: str | None = field(init=False)
    _positional_fields: list[str] = field(init=False)
    _var_positional_field: str | None = field(init=False)
    _validator: SchemaValidator = field(init=False, repr=False)
    _json_schema: _utils.ObjectJsonSchema = field(init=False)
    # Failed calls since the last success; cleared by `run` on success and by `reset`.
    _current_retry: int = field(default=0, init=False)

    def __init__(
        self,
        function: ToolFuncEither[AgentDeps, ...],
        takes_ctx: bool,
        *,
        max_retries: int | None = None,
        name: str | None = None,
        description: str | None = None,
    ):
        """Create a new tool instance.

        Example usage:

        ```py
        from pydantic_ai import Agent, RunContext, Tool

        async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
            return f'{ctx.deps} {x} {y}'

        agent = Agent('test', tools=[Tool(my_tool, True)])
        ```

        Args:
            function: The Python function to call as the tool.
            takes_ctx: Whether the function takes a [`RunContext`][pydantic_ai.tools.RunContext] first argument.
            max_retries: Maximum number of retries allowed for this tool, set to the agent default if `None`.
            name: Name of the tool, inferred from the function if `None`.
            description: Description of the tool, inferred from the function if `None`.
        """
        schema = _pydantic.function_schema(function, takes_ctx)

        # Caller-facing configuration; empty/`None` name and description fall back to the function's own.
        self.function = function
        self.takes_ctx = takes_ctx
        self.max_retries = max_retries
        self.name = name if name else function.__name__
        self.description = description if description else schema['description']

        # Signature-derived state consumed by `run` / `_call_args`.
        self._is_async = inspect.iscoroutinefunction(function)
        self._single_arg_name = schema['single_arg_name']
        self._positional_fields = schema['positional_fields']
        self._var_positional_field = schema['var_positional_field']
        self._validator = schema['validator']
        self._json_schema = schema['json_schema']

    @staticmethod
    def infer(function: ToolFuncEither[A, ...] | Tool[A]) -> Tool[A]:
        """Create a tool from a pure function, inferring whether it takes `RunContext` as its first argument.

        Args:
            function: The tool function to wrap; or for convenience, a `Tool` instance.

        Returns:
            The given `Tool` unchanged if one was passed, otherwise a new `Tool` wrapping `function`.
        """
        if isinstance(function, Tool):
            return function
        ctx_first = _pydantic.takes_ctx(function)
        return Tool(function, takes_ctx=ctx_first)

    def reset(self) -> None:
        """Set the retry counter back to zero, as `run` also does after a successful call."""
        self._current_retry = 0

    async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Message:
        """Run the tool function asynchronously.

        Validates the call's arguments, invokes the function (awaiting it directly when it is a coroutine
        function, otherwise through `_utils.run_in_executor`) and wraps its result in a `ToolReturn`.

        Args:
            deps: Agent dependencies, handed to the function inside a `RunContext` when `takes_ctx` is set.
            message: The model's tool call, carrying either JSON or dict arguments.

        Returns:
            A `ToolReturn` on success, or a `RetryPrompt` when validation fails or the function raises
            `ModelRetry`.

        Raises:
            UnexpectedModelBehavior: Via `_on_error`, once the retry budget is exhausted.
        """
        try:
            if isinstance(message.args, messages.ArgsJson):
                args_dict = self._validator.validate_json(message.args.args_json)
            else:
                args_dict = self._validator.validate_python(message.args.args_dict)
        except ValidationError as e:
            return self._on_error(e, message)

        args, kwargs = self._call_args(deps, args_dict, message)
        try:
            # Tools accept arbitrary positional/keyword arguments and may return any value.
            if self._is_async:
                async_function = cast(Callable[..., Awaitable[Any]], self.function)
                response_content = await async_function(*args, **kwargs)
            else:
                sync_function = cast(Callable[..., Any], self.function)
                # NOTE(review): presumably keeps blocking sync tools off the event loop — confirm in `_utils`.
                response_content = await _utils.run_in_executor(sync_function, *args, **kwargs)
        except ModelRetry as e:
            return self._on_error(e, message)

        # A successful call clears the failure count.
        self._current_retry = 0
        return messages.ToolReturn(
            tool_name=message.tool_name,
            content=response_content,
            tool_id=message.tool_id,
        )

    @property
    def json_schema(self) -> _utils.ObjectJsonSchema:
        """The JSON schema generated for the tool function's arguments by `_pydantic.function_schema`."""
        return self._json_schema

    @property
    def outer_typed_dict_key(self) -> str | None:
        """Always `None` for function tools.

        NOTE(review): presumably part of an interface shared with other schema-bearing objects — confirm.
        """
        return None

    def _call_args(
        self, deps: AgentDeps, args_dict: dict[str, Any], message: messages.ToolCall
    ) -> tuple[list[Any], dict[str, Any]]:
        """Split validated arguments into the positional and keyword arguments for `function`.

        Entries are popped from `args_dict` (or from a wrapper dict when `_single_arg_name` is set);
        whatever remains is returned as the keyword arguments.
        """
        if self._single_arg_name:
            # The whole validated value is the value of the function's single parameter.
            args_dict = {self._single_arg_name: args_dict}

        # The run context, when requested, is always the first positional argument.
        args = [RunContext(deps, self._current_retry, message.tool_name)] if self.takes_ctx else []
        for positional_field in self._positional_fields:
            args.append(args_dict.pop(positional_field))
        if self._var_positional_field:
            # `*args` values follow the named positional ones.
            args.extend(args_dict.pop(self._var_positional_field))

        return args, args_dict

    def _on_error(self, exc: ValidationError | ModelRetry, call_message: messages.ToolCall) -> messages.RetryPrompt:
        """Record a failed call and build a prompt asking the model to retry.

        Args:
            exc: The argument validation error, or the `ModelRetry` raised by the tool function.
            call_message: The tool call that failed.

        Returns:
            A `RetryPrompt` whose content is the validation error list or the `ModelRetry` message.

        Raises:
            UnexpectedModelBehavior: When the failure count exceeds `max_retries`, or `max_retries` is `None`.
        """
        self._current_retry += 1
        # `None` is documented as "use the agent default"; if it was never filled in, no retries are allowed.
        if self.max_retries is None or self._current_retry > self.max_retries:
            raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc
        else:
            if isinstance(exc, ValidationError):
                content = exc.errors(include_url=False)
            else:
                content = exc.message
            return messages.RetryPrompt(
                tool_name=call_message.tool_name,
                content=content,
                tool_id=call_message.tool_id,
            )

__init__

__init__(
    function: ToolFuncEither[AgentDeps, ...],
    takes_ctx: bool,
    *,
    max_retries: int | None = None,
    name: str | None = None,
    description: str | None = None
)

Create a new tool instance.

Example usage:

from pydantic_ai import Agent, RunContext, Tool

async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
    return f'{ctx.deps} {x} {y}'

agent = Agent('test', tools=[Tool(my_tool, True)])

Parameters:

Name Type Description Default
function ToolFuncEither[AgentDeps, ...]

The Python function to call as the tool.

required
takes_ctx bool

Whether the function takes a RunContext first argument.

required
max_retries int | None

Maximum number of retries allowed for this tool, set to the agent default if None.

None
name str | None

Name of the tool, inferred from the function if None.

None
description str | None

Description of the tool, inferred from the function if None.

None
Source code in pydantic_ai_slim/pydantic_ai/tools.py
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
def __init__(
    self,
    function: ToolFuncEither[AgentDeps, ...],
    takes_ctx: bool,
    *,
    max_retries: int | None = None,
    name: str | None = None,
    description: str | None = None,
):
    """Create a new tool instance.

    Example usage:

    ```py
    from pydantic_ai import Agent, RunContext, Tool

    async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
        return f'{ctx.deps} {x} {y}'

    agent = Agent('test', tools=[Tool(my_tool, True)])
    ```

    Args:
        function: The Python function to call as the tool.
        takes_ctx: Whether the function takes a [`RunContext`][pydantic_ai.tools.RunContext] first argument.
        max_retries: Maximum number of retries allowed for this tool, set to the agent default if `None`.
        name: Name of the tool, inferred from the function if `None`.
        description: Description of the tool, inferred from the function if `None`.
    """
    schema = _pydantic.function_schema(function, takes_ctx)

    # Caller-facing configuration; empty/`None` name and description fall back to the function's own.
    self.function = function
    self.takes_ctx = takes_ctx
    self.max_retries = max_retries
    self.name = name if name else function.__name__
    self.description = description if description else schema['description']

    # Signature-derived state consumed by `run` / `_call_args`.
    self._is_async = inspect.iscoroutinefunction(function)
    self._single_arg_name = schema['single_arg_name']
    self._positional_fields = schema['positional_fields']
    self._var_positional_field = schema['var_positional_field']
    self._validator = schema['validator']
    self._json_schema = schema['json_schema']

infer staticmethod

infer(
    function: ToolFuncEither[A, ...] | Tool[A]
) -> Tool[A]

Create a tool from a pure function, inferring whether it takes RunContext as its first argument.

Parameters:

Name Type Description Default
function ToolFuncEither[A, ...] | Tool[A]

The tool function to wrap; or for convenience, a Tool instance.

required

Returns:

Type Description
Tool[A]

A new Tool instance.

Source code in pydantic_ai_slim/pydantic_ai/tools.py
158
159
160
161
162
163
164
165
166
167
168
169
170
171
@staticmethod
def infer(function: ToolFuncEither[A, ...] | Tool[A]) -> Tool[A]:
    """Create a tool from a pure function, inferring whether it takes `RunContext` as its first argument.

    Args:
        function: The tool function to wrap; or for convenience, a `Tool` instance.

    Returns:
        The given `Tool` unchanged if one was passed, otherwise a new `Tool` wrapping `function`.
    """
    if isinstance(function, Tool):
        return function
    ctx_first = _pydantic.takes_ctx(function)
    return Tool(function, takes_ctx=ctx_first)

reset

reset() -> None

Reset the current retry count.

Source code in pydantic_ai_slim/pydantic_ai/tools.py
173
174
175
def reset(self) -> None:
    """Set the retry counter back to zero, as `run` also does after a successful call."""
    self._current_retry = 0

run async

run(deps: AgentDeps, message: ToolCall) -> Message

Run the tool function asynchronously.

Source code in pydantic_ai_slim/pydantic_ai/tools.py
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Message:
    """Run the tool function asynchronously.

    Validates the call's arguments, invokes the function (awaiting it directly when it is a coroutine
    function, otherwise through `_utils.run_in_executor`) and wraps its result in a `ToolReturn`.

    Args:
        deps: Agent dependencies, handed to the function inside a `RunContext` when `takes_ctx` is set.
        message: The model's tool call, carrying either JSON or dict arguments.

    Returns:
        A `ToolReturn` on success, or a `RetryPrompt` when validation fails or the function raises
        `ModelRetry`.

    Raises:
        UnexpectedModelBehavior: Via `_on_error`, once the retry budget is exhausted.
    """
    try:
        if isinstance(message.args, messages.ArgsJson):
            args_dict = self._validator.validate_json(message.args.args_json)
        else:
            args_dict = self._validator.validate_python(message.args.args_dict)
    except ValidationError as e:
        return self._on_error(e, message)

    args, kwargs = self._call_args(deps, args_dict, message)
    try:
        # Tools accept arbitrary positional/keyword arguments and may return any value.
        if self._is_async:
            async_function = cast(Callable[..., Awaitable[Any]], self.function)
            response_content = await async_function(*args, **kwargs)
        else:
            sync_function = cast(Callable[..., Any], self.function)
            # NOTE(review): presumably keeps blocking sync tools off the event loop — confirm in `_utils`.
            response_content = await _utils.run_in_executor(sync_function, *args, **kwargs)
    except ModelRetry as e:
        return self._on_error(e, message)

    # A successful call clears the failure count.
    self._current_retry = 0
    return messages.ToolReturn(
        tool_name=message.tool_name,
        content=response_content,
        tool_id=message.tool_id,
    )