diff --git a/pyproject.toml b/pyproject.toml index 5815b971f5..9a37dc492c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,8 +123,9 @@ test = [ "a2a-sdk>=0.3.0,<0.4.0", "anthropic>=0.43.0", # For anthropic model tests "crewai[tools];python_version>='3.11' and python_version<'3.12'", # For CrewaiTool tests; chromadb/pypika fail on 3.12+ - "google-cloud-firestore>=2.11.0, <3.0.0", "google-cloud-iamconnectorcredentials>=0.1.0, <0.2.0", + "docker>=7.0.0", + "google-cloud-firestore>=2.11.0", "google-cloud-parametermanager>=0.4.0, <1.0.0", "kubernetes>=29.0.0", # For GkeCodeExecutor "langchain-community>=0.3.17", diff --git a/src/google/adk/code_executors/container_code_executor.py b/src/google/adk/code_executors/container_code_executor.py index d6a78d4d26..ccfeebe253 100644 --- a/src/google/adk/code_executors/container_code_executor.py +++ b/src/google/adk/code_executors/container_code_executor.py @@ -15,8 +15,10 @@ from __future__ import annotations import atexit +import io import logging import os +import tarfile from typing import Optional import docker @@ -29,9 +31,10 @@ from .base_code_executor import BaseCodeExecutor from .code_execution_utils import CodeExecutionInput from .code_execution_utils import CodeExecutionResult +from .code_execution_utils import File -logger = logging.getLogger('google_adk.' + __name__) -DEFAULT_IMAGE_TAG = 'adk-code-executor:latest' +logger = logging.getLogger("google_adk." + __name__) +DEFAULT_IMAGE_TAG = "adk-code-executor:latest" class ContainerCodeExecutor(BaseCodeExecutor): @@ -44,6 +47,9 @@ class ContainerCodeExecutor(BaseCodeExecutor): docker_path: The path to the directory containing the Dockerfile. If set, build the image from the dockerfile path instead of using the predefined image. Either docker_path or image must be set. + input_dir: The directory in the container where input files will be placed. + output_dir: The directory in the container where output files will be + retrieved from.
""" base_url: Optional[str] = None @@ -61,15 +67,22 @@ class ContainerCodeExecutor(BaseCodeExecutor): """ The path to the directory containing the Dockerfile. If set, build the image from the dockerfile path instead of using the predefined image. Either docker_path or image must be set. + """ + + input_dir: str = "/tmp/inputs" + """ + The directory in the container where input files will be placed. + """ + + output_dir: str = "/tmp/outputs" + """ + The directory in the container where output files will be retrieved from. """ - # Overrides the BaseCodeExecutor attribute: this executor cannot be stateful. stateful: bool = Field(default=False, frozen=True, exclude=True) - # Overrides the BaseCodeExecutor attribute: this executor cannot - # optimize_data_file. - optimize_data_file: bool = Field(default=False, frozen=True, exclude=True) + optimize_data_file: bool = Field(default=True, frozen=True, exclude=True) _client: DockerClient = None _container: Container = None @@ -79,6 +92,8 @@ def __init__( base_url: Optional[str] = None, image: Optional[str] = None, docker_path: Optional[str] = None, + input_dir: Optional[str] = None, + output_dir: Optional[str] = None, **data, ): """Initializes the ContainerCodeExecutor. @@ -90,33 +105,33 @@ def __init__( docker_path: The path to the directory containing the Dockerfile. If set, build the image from the dockerfile path instead of using the predefined image. Either docker_path or image must be set. + input_dir: The directory in the container where input files will be placed. + Defaults to '/tmp/inputs'. + output_dir: The directory in the container where output files will be + retrieved from. Defaults to '/tmp/outputs'. **data: The data to initialize the ContainerCodeExecutor. """ if not image and not docker_path: raise ValueError( - 'Either image or docker_path must be set for ContainerCodeExecutor.'
- ) - if 'stateful' in data and data['stateful']: - raise ValueError('Cannot set `stateful=True` in ContainerCodeExecutor.') - if 'optimize_data_file' in data and data['optimize_data_file']: - raise ValueError( - 'Cannot set `optimize_data_file=True` in ContainerCodeExecutor.' + "Either image or docker_path must be set for ContainerCodeExecutor." ) + if "stateful" in data and data["stateful"]: + raise ValueError("Cannot set `stateful=True` in ContainerCodeExecutor.") super().__init__(**data) self.base_url = base_url self.image = image if image else DEFAULT_IMAGE_TAG self.docker_path = os.path.abspath(docker_path) if docker_path else None + self.input_dir = input_dir if input_dir else "/tmp/inputs" + self.output_dir = output_dir if output_dir else "/tmp/outputs" self._client = ( docker.from_env() if not self.base_url else docker.DockerClient(base_url=self.base_url) ) - # Initialize the container. self.__init_container() - # Close the container when the on exit. atexit.register(self.__cleanup_container) @override @@ -125,68 +140,165 @@ def execute_code( invocation_context: InvocationContext, code_execution_input: CodeExecutionInput, ) -> CodeExecutionResult: - output = '' - error = '' + if code_execution_input.input_files: + self._put_input_files(code_execution_input.input_files) + + self._create_output_directory() + + output = "" + error = "" exec_result = self._container.exec_run( - ['python3', '-c', code_execution_input.code], + ["python3", "-c", code_execution_input.code], demux=True, ) - logger.debug('Executed code:\n```\n%s\n```', code_execution_input.code) + logger.debug("Executed code:\n```\n%s\n```", code_execution_input.code) if exec_result.output and exec_result.output[0]: - output = exec_result.output[0].decode('utf-8') + output = exec_result.output[0].decode("utf-8") if ( exec_result.output and len(exec_result.output) > 1 and exec_result.output[1] ): - error = exec_result.output[1].decode('utf-8') + error = exec_result.output[1].decode("utf-8") + + 
output_files = self._get_output_files() - # Collect the final result. return CodeExecutionResult( stdout=output, stderr=error, - output_files=[], + output_files=output_files, ) def _build_docker_image(self): """Builds the Docker image.""" if not self.docker_path: - raise ValueError('Docker path is not set.') + raise ValueError("Docker path is not set.") if not os.path.exists(self.docker_path): - raise FileNotFoundError(f'Invalid Docker path: {self.docker_path}') + raise FileNotFoundError(f"Invalid Docker path: {self.docker_path}") - logger.info('Building Docker image...') + logger.info("Building Docker image...") self._client.images.build( path=self.docker_path, tag=self.image, rm=True, ) - logger.info('Docker image: %s built.', self.image) + logger.info("Docker image: %s built.", self.image) def _verify_python_installation(self): """Verifies the container has python3 installed.""" - exec_result = self._container.exec_run(['which', 'python3']) + exec_result = self._container.exec_run(["which", "python3"]) if exec_result.exit_code != 0: - raise ValueError('python3 is not installed in the container.') + raise ValueError("python3 is not installed in the container.") + + def _put_input_files(self, input_files: list[File]) -> None: + """Puts input files into the container using put_archive. + + Args: + input_files: The list of input files to copy into the container. 
+ """ + tar_buffer = io.BytesIO() + with tarfile.open(fileobj=tar_buffer, mode="w") as tar: + for file in input_files: + content = file.content + if isinstance(content, str): + content = content.encode("utf-8") + tarinfo = tarfile.TarInfo(name=file.name) + tarinfo.size = len(content) + tar.addfile(tarinfo, io.BytesIO(content)) + + tar_buffer.seek(0) + self._container.put_archive( + self.input_dir, + tar_buffer.read(), + ) + logger.debug( + "Copied %d input files to %s", len(input_files), self.input_dir + ) + + def _create_output_directory(self) -> None: + """Creates the output directory in the container if it doesn't exist.""" + exec_result = self._container.exec_run( + ["mkdir", "-p", self.output_dir], + ) + if exec_result.exit_code != 0: + logger.warning( + "Failed to create output directory %s: %s", + self.output_dir, + exec_result.output, + ) + + def _get_output_files(self) -> list[File]: + """Gets output files from the container. + + Returns: + The list of output files retrieved from the container. 
+ """ + try: + tar_stream, _ = self._container.get_archive(self.output_dir) + except docker.errors.APIError as e: + if e.response.status_code == 404: + logger.debug("No output files found at %s", self.output_dir) + return [] + raise + if isinstance(tar_stream, bytes): + tar_bytes = tar_stream + else: + tar_bytes = b"".join(tar_stream) + tar_buffer = io.BytesIO(tar_bytes) + output_files = [] + + with tarfile.open(fileobj=tar_buffer, mode="r") as tar: + for member in tar.getmembers(): + if member.isfile(): + file_obj = tar.extractfile(member) + if file_obj: + content = file_obj.read() + file_name = os.path.basename(member.name) + if file_name: + output_files.append( + File( + name=file_name, + content=content, + mime_type=self._guess_mime_type(file_name), + ) + ) + + logger.debug( + "Retrieved %d output files from %s", len(output_files), self.output_dir + ) + return output_files + + def _guess_mime_type(self, file_name: str) -> str: + """Guesses the MIME type based on the file extension. + + Args: + file_name: The name of the file. + + Returns: + The guessed MIME type, or 'application/octet-stream' if unknown. + """ + import mimetypes + + mime_type, _ = mimetypes.guess_type(file_name) + return mime_type if mime_type else "application/octet-stream" def __init_container(self): """Initializes the container.""" if not self._client: - raise RuntimeError('Docker client is not initialized.') + raise RuntimeError("Docker client is not initialized.") if self.docker_path: self._build_docker_image() - logger.info('Starting container for ContainerCodeExecutor...') + logger.info("Starting container for ContainerCodeExecutor...") self._container = self._client.containers.run( image=self.image, detach=True, tty=True, ) - logger.info('Container %s started.', self._container.id) + logger.info("Container %s started.", self._container.id) - # Verify the container is able to run python3. 
self._verify_python_installation() def __cleanup_container(self): @@ -194,7 +306,7 @@ def __cleanup_container(self): if not self._container: return - logger.info('[Cleanup] Stopping the container...') + logger.info("[Cleanup] Stopping the container...") self._container.stop() self._container.remove() - logger.info('Container %s stopped and removed.', self._container.id) + logger.info("Container %s stopped and removed.", self._container.id) diff --git a/tests/unittests/code_executors/test_container_code_executor.py b/tests/unittests/code_executors/test_container_code_executor.py new file mode 100644 index 0000000000..fabfaf62e1 --- /dev/null +++ b/tests/unittests/code_executors/test_container_code_executor.py @@ -0,0 +1,286 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Unit tests for ContainerCodeExecutor."""

import io
import tarfile
from unittest import mock

import docker
from google.adk.agents.invocation_context import InvocationContext
from google.adk.code_executors.code_execution_utils import CodeExecutionInput
from google.adk.code_executors.code_execution_utils import CodeExecutionResult
from google.adk.code_executors.code_execution_utils import File
from google.adk.code_executors.container_code_executor import ContainerCodeExecutor
import pytest


def _make_tar(files: dict[str, bytes]) -> bytes:
  """Returns an uncompressed tar archive built from a name -> content map."""
  buffer = io.BytesIO()
  with tarfile.open(fileobj=buffer, mode="w") as tar:
    for name, content in files.items():
      info = tarfile.TarInfo(name=name)
      info.size = len(content)
      tar.addfile(info, io.BytesIO(content))
  return buffer.getvalue()


def _read_tar(tar_bytes: bytes) -> dict[str, bytes]:
  """Returns the name -> content map of the regular files in a tar archive."""
  with tarfile.open(fileobj=io.BytesIO(tar_bytes), mode="r") as tar:
    return {
        member.name: tar.extractfile(member).read()
        for member in tar.getmembers()
        if member.isfile()
    }


def _not_found_error() -> docker.errors.APIError:
  """Returns the 404 APIError the tests use for a missing output directory."""
  return docker.errors.APIError(
      "Not found", response=mock.MagicMock(status_code=404)
  )


@pytest.fixture
def mock_container():
  """A fake container whose exec_run succeeds and prints "Hello World"."""
  container = mock.MagicMock()
  container.id = "test-container-id"
  container.exec_run.return_value = mock.MagicMock(
      output=(b"Hello World", b""),
      exit_code=0,
  )
  container.get_archive.return_value = (b"", mock.MagicMock())
  return container


@pytest.fixture
def mock_docker_client(mock_container):
  """A fake Docker client whose containers.run() yields `mock_container`."""
  client = mock.MagicMock()
  client.containers.run.return_value = mock_container
  return client


class TestContainerCodeExecutorInit:
  """Constructor validation and defaults."""

  def test_init_with_image(self, mock_docker_client):
    with mock.patch("docker.from_env", return_value=mock_docker_client):
      executor = ContainerCodeExecutor(image="test-image:latest")
      assert executor.image == "test-image:latest"
      assert executor.input_dir == "/tmp/inputs"
      assert executor.output_dir == "/tmp/outputs"

  def test_init_with_custom_dirs(self, mock_docker_client):
    with mock.patch("docker.from_env", return_value=mock_docker_client):
      executor = ContainerCodeExecutor(
          image="test-image:latest",
          input_dir="/custom/inputs",
          output_dir="/custom/outputs",
      )
      assert executor.input_dir == "/custom/inputs"
      assert executor.output_dir == "/custom/outputs"

  def test_init_requires_image_or_docker_path(self):
    with pytest.raises(
        ValueError, match="Either image or docker_path must be set"
    ):
      ContainerCodeExecutor()

  def test_init_rejects_stateful(self):
    with pytest.raises(ValueError, match="Cannot set `stateful=True`"):
      ContainerCodeExecutor(image="test", stateful=True)

  def test_init_allows_optimize_data_file(self, mock_docker_client):
    with mock.patch("docker.from_env", return_value=mock_docker_client):
      executor = ContainerCodeExecutor(
          image="test-image:latest", optimize_data_file=True
      )
      assert executor.optimize_data_file is True


class TestExecuteCode:
  """End-to-end behaviour of execute_code against a mocked container."""

  def test_execute_code_basic(self, mock_docker_client, mock_container):
    mock_container.get_archive.side_effect = _not_found_error()
    with mock.patch("docker.from_env", return_value=mock_docker_client):
      executor = ContainerCodeExecutor(image="test-image:latest")
      context = mock.MagicMock(spec=InvocationContext)
      code_input = CodeExecutionInput(code='print("Hello World")')

      result = executor.execute_code(context, code_input)

      assert isinstance(result, CodeExecutionResult)
      assert result.stdout == "Hello World"
      assert result.stderr == ""
      assert result.output_files == []

  def test_execute_code_with_error(self, mock_docker_client, mock_container):
    def exec_run_side_effect(cmd, demux=False):
      # Key on the command itself rather than on how many exec_run calls
      # preceded it, so the test does not depend on the executor's internal
      # ordering of setup commands.
      if cmd[0] == "python3":
        return mock.MagicMock(
            exit_code=1,
            output=(b"", b"Some error"),
        )
      return mock.MagicMock(exit_code=0, output=(b"", b""))

    mock_container.exec_run.side_effect = exec_run_side_effect
    mock_container.get_archive.side_effect = _not_found_error()
    with mock.patch("docker.from_env", return_value=mock_docker_client):
      executor = ContainerCodeExecutor(image="test-image:latest")
      context = mock.MagicMock(spec=InvocationContext)
      code_input = CodeExecutionInput(code='raise Error("test")')

      result = executor.execute_code(context, code_input)

      assert result.stdout == ""
      assert result.stderr == "Some error"

  def test_execute_code_with_input_files(
      self, mock_docker_client, mock_container
  ):
    mock_container.exec_run.return_value = mock.MagicMock(
        output=(b"", b""),
        exit_code=0,
    )
    mock_container.get_archive.side_effect = _not_found_error()
    with mock.patch("docker.from_env", return_value=mock_docker_client):
      executor = ContainerCodeExecutor(image="test-image:latest")
      context = mock.MagicMock(spec=InvocationContext)
      code_input = CodeExecutionInput(
          code='print("test")',
          input_files=[File(name="test.txt", content="test content")],
      )

      executor.execute_code(context, code_input)

      mock_container.put_archive.assert_called_once()
      target_dir, archive = mock_container.put_archive.call_args[0]
      assert target_dir == "/tmp/inputs"
      # Check what was actually shipped, not just that a call happened.
      assert _read_tar(archive) == {"test.txt": b"test content"}

  def test_execute_code_with_output_files(
      self, mock_docker_client, mock_container
  ):
    mock_container.exec_run.return_value = mock.MagicMock(
        output=(b"", b""),
        exit_code=0,
    )
    content = b"output content"
    mock_container.get_archive.return_value = (
        _make_tar({"output.txt": content}),
        mock.MagicMock(),
    )

    with mock.patch("docker.from_env", return_value=mock_docker_client):
      executor = ContainerCodeExecutor(image="test-image:latest")
      context = mock.MagicMock(spec=InvocationContext)
      code_input = CodeExecutionInput(code='print("test")')

      result = executor.execute_code(context, code_input)

      assert len(result.output_files) == 1
      assert result.output_files[0].name == "output.txt"
      assert result.output_files[0].content == content


class TestPutInputFiles:
  """Archive construction in _put_input_files."""

  def test_put_archive_called(self, mock_docker_client, mock_container):
    with mock.patch("docker.from_env", return_value=mock_docker_client):
      executor = ContainerCodeExecutor(image="test-image:latest")
      input_files = [
          File(name="file1.txt", content="content1"),
          File(name="file2.txt", content="content2"),
      ]

      executor._put_input_files(input_files)

      mock_container.put_archive.assert_called_once()
      target_dir, archive = mock_container.put_archive.call_args[0]
      assert target_dir == "/tmp/inputs"
      assert _read_tar(archive) == {
          "file1.txt": b"content1",
          "file2.txt": b"content2",
      }

  def test_handles_string_content(self, mock_docker_client, mock_container):
    with mock.patch("docker.from_env", return_value=mock_docker_client):
      executor = ContainerCodeExecutor(image="test-image:latest")
      input_files = [File(name="test.txt", content="string content")]

      executor._put_input_files(input_files)

      mock_container.put_archive.assert_called_once()
      _, archive = mock_container.put_archive.call_args[0]
      # str content is expected to reach the container UTF-8 encoded.
      assert _read_tar(archive) == {"test.txt": "string content".encode("utf-8")}


class TestGetOutputFiles:
  """Archive extraction in _get_output_files."""

  def test_no_output_files(self, mock_docker_client, mock_container):
    mock_container.get_archive.side_effect = _not_found_error()
    with mock.patch("docker.from_env", return_value=mock_docker_client):
      executor = ContainerCodeExecutor(image="test-image:latest")

      output_files = executor._get_output_files()

      assert output_files == []

  def test_reraises_non_404_api_errors(
      self, mock_docker_client, mock_container
  ):
    mock_container.get_archive.side_effect = docker.errors.APIError(
        "Server error", response=mock.MagicMock(status_code=500)
    )
    with mock.patch("docker.from_env", return_value=mock_docker_client):
      executor = ContainerCodeExecutor(image="test-image:latest")

      # Only a 404 means "no outputs"; any other API failure must surface.
      with pytest.raises(docker.errors.APIError):
        executor._get_output_files()

  def test_extracts_files_from_archive(
      self, mock_docker_client, mock_container
  ):
    content = b"output content"
    mock_container.get_archive.return_value = (
        _make_tar({"output.txt": content}),
        mock.MagicMock(),
    )

    with mock.patch("docker.from_env", return_value=mock_docker_client):
      executor = ContainerCodeExecutor(image="test-image:latest")

      output_files = executor._get_output_files()

      assert len(output_files) == 1
      assert output_files[0].name == "output.txt"
      assert output_files[0].content == content
class TestMimeTypeGuessing:
  """Tests for ContainerCodeExecutor._guess_mime_type."""

  def test_guess_txt(self, mock_docker_client):
    with mock.patch("docker.from_env", return_value=mock_docker_client):
      executor = ContainerCodeExecutor(image="test-image:latest")

      mime_type = executor._guess_mime_type("test.txt")

      assert mime_type == "text/plain"

  def test_guess_csv(self, mock_docker_client):
    with mock.patch("docker.from_env", return_value=mock_docker_client):
      executor = ContainerCodeExecutor(image="test-image:latest")

      # mock.patch restores the real function even if the test fails, and the
      # mock accepts any call signature (a hand-written one-argument lambda
      # would not). Pinning the lookup keeps the test independent of the
      # host's MIME database.
      with mock.patch(
          "mimetypes.guess_type", return_value=("text/csv", None)
      ) as guess_type:
        mime_type = executor._guess_mime_type("data.csv")

      guess_type.assert_called_once_with("data.csv")
      assert mime_type == "text/csv"

  def test_default_for_unknown(self, mock_docker_client):
    with mock.patch("docker.from_env", return_value=mock_docker_client):
      executor = ContainerCodeExecutor(image="test-image:latest")

      # Force the "no match" result instead of relying on an extension being
      # unknown: system MIME tables can map even unusual suffixes.
      with mock.patch("mimetypes.guess_type", return_value=(None, None)):
        mime_type = executor._guess_mime_type("unknown.xyz")

      assert mime_type == "application/octet-stream"