# Source code for apple_fm_sdk.tool
# For licensing see accompanying LICENSE file.
# Copyright (C) 2026 Apple Inc. All Rights Reserved.
import asyncio
import ctypes
import inspect
import logging
import threading
from abc import ABC, abstractmethod

from .generation_schema import GenerationSchema
from .generable import GeneratedContent
from .c_helpers import _ManagedObject, _get_error_string
from .errors import _status_code_to_exception
logger = logging.getLogger(__name__)
# The generated ctypes bindings are required; without them nothing in this
# module can work, so fail at import time with an actionable message.
try:
    from . import _ctypes_bindings as lib
except ImportError as exc:
    # Chain the original error so the underlying import failure stays visible
    # in the traceback as the explicit cause.
    raise ImportError(
        "Foundation Models C bindings not found. Please ensure _foundationmodels_ctypes.py is available."
    ) from exc
# [docs]
class Tool(_ManagedObject, ABC):
    """Base class for creating tools that foundation models can invoke during generation.

    A ``Tool`` bridges Python async functions to the Foundation Models API, enabling
    foundation models to perform actions like calculations, API calls, database queries,
    or any other programmatic operations during generation.

    **Tool Lifecycle:**

    1. **Definition**: Subclass Tool and implement required methods/properties
    2. **Registration**: Pass tool instances to LanguageModelSession
    3. **Invocation**: Model automatically calls tools when appropriate
    4. **Execution**: Your async ``call()`` method executes with parsed arguments
    5. **Response**: Tool result is returned to the model to continue generation

    **Callback Mechanism:**

    Tools use an async callback system that:

    - Automatically handles argument parsing from GeneratedContent
    - Executes your ``call()`` method in the appropriate async context
    - Manages threading and event loops transparently
    - Returns results or errors back to the model

    **Async Requirements:**

    The ``call()`` method MUST be an async function (coroutine). This allows tools to:

    - Make async API calls without blocking
    - Perform I/O operations efficiently
    - Run concurrent operations when needed
    - Integrate with async frameworks

    **Error Handling:**

    - Exceptions in ``call()`` are caught and reported to the model
    - The model receives error messages and can adapt its response
    - Tools should raise descriptive exceptions for better model understanding

    Examples:
        Simple calculator tool::

            import apple_fm_sdk as fm

            @fm.generable("Calculator parameters")
            class CalculatorParams:
                operation: str = fm.guide("The operation to perform")
                a: float = fm.guide("First number")
                b: float = fm.guide("Second number")

            class CalculatorTool(fm.Tool):
                name = "calculator"
                description = "Performs basic arithmetic operations"

                @property
                def arguments_schema(self) -> fm.GenerationSchema:
                    return CalculatorParams.generation_schema()

                async def call(self, args: fm.GeneratedContent) -> str:
                    op = args.value(str, for_property="operation")
                    a = args.value(float, for_property="a")
                    b = args.value(float, for_property="b")
                    if op == "add":
                        result = a + b
                    elif op == "multiply":
                        result = a * b
                    else:
                        raise ValueError(f"Unknown operation: {op}")
                    return str(result)

        Tool with async API call::

            import aiohttp
            import apple_fm_sdk as fm

            @fm.generable("Weather parameters")
            class WeatherParams:
                city: str = fm.guide("The city to get weather for")
                units: str = fm.guide("Temperature units (metric or imperial)")

            class WeatherTool(fm.Tool):
                name = "get_weather"
                description = "Gets current weather for a city"

                @property
                def arguments_schema(self) -> fm.GenerationSchema:
                    return WeatherParams.generation_schema()

                async def call(self, args: fm.GeneratedContent) -> str:
                    city = args.value(str, for_property="city")
                    try:
                        units = args.value(str, for_property="units")
                    except Exception:
                        units = "metric"
                    # Implement async API call to fetch weather here
                    return "Sunny, 25°C"  # Placeholder response

        Tool with error handling::

            import apple_fm_sdk as fm

            @fm.generable("Database query parameters")
            class DatabaseParams:
                user_id: int = fm.guide("The user ID to query")

            class DatabaseTool(fm.Tool):
                name = "query_database"
                description = "Queries the user database"

                @property
                def arguments_schema(self) -> fm.GenerationSchema:
                    return DatabaseParams.generation_schema()

                async def call(self, args: fm.GeneratedContent) -> str:
                    user_id = args.value(int, for_property="user_id")
                    # Implement database query with error handling here
                    return f"User data for ID {user_id}"  # Placeholder response

        Using tools in a session::

            from apple_fm_sdk import LanguageModelSession

            session = LanguageModelSession(
                instructions="You are a helpful assistant with access to tools.",
                tools=[CalculatorTool(), WeatherTool(), DatabaseTool()]
            )
            # Model will automatically use tools when appropriate
            response = await session.respond("What's 15% of 240?")
            # Model invokes CalculatorTool internally

    Attributes:
        name: The tool's name (must be set by subclass)
        description: Human-readable description of what the tool does (must be set by subclass)

    Note:
        - Tool names should be descriptive and follow snake_case convention
        - Descriptions should explain the tool's purpose and when to use it
        - The ``call()`` method must be async even if it doesn't perform async operations
        - Tools are automatically managed by the session's lifecycle
        - Multiple tools can be registered with a single session

    See Also:
        - :class:`~apple_fm_sdk.session.LanguageModelSession`: For using tools in sessions
        - :class:`~apple_fm_sdk.generation_schema.GenerationSchema`: For defining argument schemas
        - :class:`~apple_fm_sdk.generable.GeneratedContent`: For accessing parsed arguments
    """

    name: str
    description: str

    @property
    @abstractmethod
    def arguments_schema(self) -> GenerationSchema:
        """Define the schema for tool arguments.

        This property must return a GenerationSchema that describes the structure
        and types of arguments the tool expects. The model uses this schema to
        generate properly formatted arguments when invoking the tool.

        :return: Schema defining the tool's expected arguments
        :rtype: GenerationSchema

        Example:
            ::

                import apple_fm_sdk as fm

                @fm.generable("Search parameters")
                class SearchParams:
                    query: str = fm.guide("The search query")
                    limit: int = fm.guide("Maximum number of results")

                @property
                def arguments_schema(self) -> fm.GenerationSchema:
                    return SearchParams.generation_schema()
        """

    @abstractmethod
    async def call(self, args: GeneratedContent) -> str:
        """Execute the tool's functionality with the provided arguments.

        This async method is invoked when the model decides to use the tool.
        The arguments are automatically parsed according to the ``arguments_schema``
        and provided as a GeneratedContent object.

        :param args: Parsed arguments as GeneratedContent. Read individual values with
                     ``args.value(<type>, for_property="<name>")``, as shown in the
                     example below.
        :type args: GeneratedContent
        :return: The tool's result as a string. This result is provided back to the
                 model to inform its continued generation.
        :rtype: str
        :raises Exception: Any exception raised will be caught and reported to the model
                           as an error message. Use descriptive exceptions to help the model
                           understand what went wrong.

        Example:
            ::

                async def call(self, args: fm.GeneratedContent) -> str:
                    query = args.value(str, for_property="query")
                    try:
                        limit = args.value(int, for_property="limit")
                    except Exception:
                        limit = 10
                    # Perform async operation, for example, database search or another session call here
                    return f"Results for '{query}' with limit {limit}"  # Placeholder response

        Note:
            - Must be an async function even if no async operations are performed
            - A non-string return value is converted with ``str()`` before it is reported
            - Exceptions are automatically handled and reported to the model
        """

    def __init__(self):
        """Validate the subclass and create the native bridged tool.

        :raises AssertionError: If ``name``, ``description``, ``arguments_schema`` or
            ``call`` is missing on the subclass.
        :raises TypeError: If one of those members has the wrong type.
        :raises Exception: The exception built by ``_status_code_to_exception`` when
            ``FMBridgedToolCreate`` returns a null pointer.
        """
        # Set _ptr before anything that can raise (including validation), so that
        # cleanup code never hits an AttributeError on a half-built instance.
        self._ptr = None

        # Verify the subclass implementation
        self._verify_subclass_()

        # Store the async callable
        self._async_callable = self.call
        # Maps call_id to the asyncio.Task running that call. Holding the task here
        # keeps a strong reference so it cannot be garbage-collected mid-execution.
        self._pending_calls = {}
        self._call_lock = threading.Lock()

        # Create the C callback function type matching the bindings
        # UNCHECKED(None) in the bindings returns ctypes.c_void_p
        CallbackType = ctypes.CFUNCTYPE(
            ctypes.c_void_p, lib.FMGeneratedContentRef, ctypes.c_uint
        )
        # Keep the wrapped callback on the instance to prevent garbage collection
        # while native code may still invoke it.
        self._c_callback = CallbackType(self._handle_native_call)

        # Create the bridged tool using the C API
        name_bytes = self.name.encode("utf-8")
        description_bytes = self.description.encode("utf-8")
        # Store the schema to keep it alive (prevents deallocation before FMBridgedToolCreate completes)
        # This is necessary because arguments_schema is a property that returns a new object each time
        self._arguments_schema = self.arguments_schema

        # Prepare error handling parameters
        error_code = ctypes.c_int()
        error_description = ctypes.POINTER(ctypes.c_char)()
        ptr = lib.FMBridgedToolCreate(
            name_bytes,
            description_bytes,
            self._arguments_schema._ptr,
            self._c_callback,
            ctypes.byref(error_code),
            ctypes.byref(error_description),
        )

        # Check for errors
        if not ptr:
            err_code, err_desc = _get_error_string(error_code, error_description)
            error_msg = "Failed to create bridged tool"
            if err_desc:
                error_msg = error_msg + ": " + err_desc
            raise _status_code_to_exception(err_code or error_code.value, error_msg)

        super().__init__(ptr)

    def _finish_call(self, call_id, text):
        """Report ``text`` to the native bridge as the outcome of call ``call_id``."""
        lib.FMBridgedToolFinishCall(self._ptr, call_id, text.encode("utf-8"))

    async def _run_tool_call(self, generated_content, call_id):
        """Await the subclass's ``call()`` and report its result or error."""
        try:
            # Call the tool subclass's async function
            result = await self._async_callable(generated_content)
            # Convert result to string if needed
            if not isinstance(result, str):
                result = str(result)
            self._finish_call(call_id, result)
        except Exception as e:
            # On error, finish with the error message so the model can react to it
            logger.debug("Tool call %s failed", call_id, exc_info=True)
            self._finish_call(call_id, f"Tool error: {str(e)}")

    def _run_call_in_new_loop(self, generated_content, call_id):
        """Thread target: run one tool call to completion on a private event loop."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(self._run_tool_call(generated_content, call_id))
        finally:
            loop.close()

    def _track_task(self, call_id, task):
        """Hold a strong reference to ``task`` until it finishes.

        The event loop only keeps weak references to tasks, so a task nobody
        else references may be collected before it completes.
        """
        with self._call_lock:
            self._pending_calls[call_id] = task

        def _forget(done):
            with self._call_lock:
                # Only drop the entry if it still belongs to this task.
                if self._pending_calls.get(call_id) is done:
                    del self._pending_calls[call_id]
            if not done.cancelled() and done.exception() is not None:
                logger.error(
                    "Tool call %s could not be completed",
                    call_id,
                    exc_info=done.exception(),
                )

        task.add_done_callback(_forget)

    def _handle_native_call(self, content_ref, call_id):
        """C callback that gets invoked when the tool is called.

        Schedules ``call()`` on the running event loop of the current thread, or,
        when there is none, on a fresh event loop in a daemon thread.
        """
        try:
            # Create GeneratedContent from the C pointer
            # Swift passes the pointer with ownership already transferred (passRetained)
            # so we don't need to manually retain it here
            generated_content = GeneratedContent(_ptr=content_ref)

            try:
                # Raises RuntimeError when this thread has no running loop.
                asyncio.get_running_loop()
            except RuntimeError:
                # No running loop - create a new thread with event loop
                worker = threading.Thread(
                    target=self._run_call_in_new_loop,
                    args=(generated_content, call_id),
                    daemon=True,
                )
                worker.start()
            else:
                task = asyncio.create_task(
                    self._run_tool_call(generated_content, call_id)
                )
                self._track_task(call_id, task)
        except Exception as e:
            # Catch-all error handler: still try to complete the call so the
            # native side receives an outcome for it.
            logger.exception("Failed to dispatch tool call %s", call_id)
            try:
                self._finish_call(call_id, f"Callback error: {str(e)}")
            except Exception:
                # An exception cannot usefully propagate out of a ctypes
                # callback; record it instead.
                logger.exception(
                    "Failed to report callback error for tool call %s", call_id
                )

    def _verify_subclass_(self):
        """Check that the subclass defines every member a tool needs.

        Explicit raises are used instead of ``assert`` so the checks still run
        when Python is started with ``-O``.

        :raises AssertionError: If ``name``, ``description``, ``arguments_schema`` or
            ``call`` is missing.
        :raises TypeError: If ``name`` or ``description`` is not a ``str``, if
            ``arguments_schema`` is not a ``GenerationSchema``, or if ``call`` is not
            an async function.
        """
        missing = object()
        # Read each member once; arguments_schema is a property, so every access
        # builds a new schema object.
        name = getattr(self, "name", missing)
        description = getattr(self, "description", missing)
        schema = getattr(self, "arguments_schema", missing)
        call = getattr(self, "call", missing)

        if name is missing:
            raise AssertionError("Tool subclass must have a 'name' property.")
        if description is missing:
            raise AssertionError("Tool subclass must have a 'description' property.")
        if schema is missing:
            raise AssertionError(
                "Tool subclass must have an 'arguments_schema' property."
            )
        if call is missing:
            raise AssertionError("Tool subclass must implement the 'call' method.")

        if not isinstance(name, str):
            raise TypeError("Tool name must be a string.")
        if not isinstance(description, str):
            raise TypeError("Tool description must be a string.")
        if not isinstance(schema, GenerationSchema):
            raise TypeError(
                "Tool arguments_schema must be a GenerationSchema instance."
            )
        # asyncio.iscoroutinefunction is deprecated; inspect provides the
        # supported equivalent.
        if not inspect.iscoroutinefunction(call):
            raise TypeError("Tool call method must be an async function.")