Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Shared MLflow presigned URL utilities."""

import logging
from typing import Optional

logger = logging.getLogger(__name__)


def _resolve_mlflow_experiment_id(
    mlflow_resource_arn: str,
    mlflow_experiment_name: str,
) -> Optional[str]:
    """Return the ID of the named MLflow experiment, or None if it cannot be resolved."""
    try:
        # Imported lazily so that a missing or broken mlflow installation only
        # disables deep-linking instead of failing the whole URL generation.
        import mlflow
        from mlflow.tracking import MlflowClient

        mlflow.set_tracking_uri(mlflow_resource_arn)
        experiment = MlflowClient(
            tracking_uri=mlflow_resource_arn
        ).get_experiment_by_name(mlflow_experiment_name)
        return experiment.experiment_id if experiment else None
    except Exception as e:
        # Best effort: the caller falls back to the base URL.
        logger.debug(
            "Failed to resolve MLflow experiment '%s': %s", mlflow_experiment_name, e
        )
        return None


def get_presigned_mlflow_experiment_url(
    mlflow_resource_arn: str,
    mlflow_experiment_name: Optional[str] = None,
) -> Optional[str]:
    """Generate a presigned MLflow URL, optionally deep-linked to an experiment.

    This function never raises for service or import failures; they are
    logged at DEBUG level and reported through the return value.

    Args:
        mlflow_resource_arn: MLflow tracking server or app ARN.
        mlflow_experiment_name: Optional experiment name for deep-linking.

    Returns:
        The presigned URL with an ``#/experiments/<id>`` fragment when the
        experiment can be resolved, the plain presigned URL when it cannot
        (or when no experiment name is given), or None when the presigned
        URL itself could not be generated.
    """
    try:
        from sagemaker.core.utils.utils import SageMakerClient

        sm_client = SageMakerClient().sagemaker_client
        response = sm_client.create_presigned_mlflow_app_url(Arn=mlflow_resource_arn)
        base_url = response.get("AuthorizedUrl")
    except Exception as e:
        logger.debug("Failed to generate MLflow experiment URL: %s", e)
        return None

    if not base_url:
        return None
    if not mlflow_experiment_name:
        return base_url

    experiment_id = _resolve_mlflow_experiment_id(
        mlflow_resource_arn, mlflow_experiment_name
    )
    if experiment_id is None:
        # Deep-linking is optional; the base URL is still usable.
        return base_url
    return f"{base_url}#/experiments/{experiment_id}"
32 changes: 7 additions & 25 deletions sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,35 +182,17 @@ def get_mlflow_url(training_job) -> str:
if not hasattr(training_job, 'mlflow_config') or _is_unassigned_attribute(training_job.mlflow_config):
raise ValueError("Training job does not have MLflow configured")

import os
from mlflow.tracking import MlflowClient
import mlflow
from sagemaker.core.utils.utils import SageMakerClient
from sagemaker.train.common_utils.mlflow_url_utils import get_presigned_mlflow_experiment_url

mlflow_arn = training_job.mlflow_config.mlflow_resource_arn
exp_name = training_job.mlflow_config.mlflow_experiment_name
if _is_unassigned_attribute(exp_name):
exp_name = None

# Get presigned base URL
sm_client = SageMakerClient().sagemaker_client
response = sm_client.create_presigned_mlflow_app_url(Arn=mlflow_arn)
base_url = response.get('AuthorizedUrl')

# Try to get experiment ID and append to URL
try:
os.environ['MLFLOW_TRACKING_URI'] = mlflow_arn
mlflow.set_tracking_uri(mlflow_arn)

mlflow_client = MlflowClient(tracking_uri=mlflow_arn)
experiment = mlflow_client.get_experiment_by_name(exp_name)

if experiment:
# Format: base_url#/experiments/{id}
# The base_url already has /auth?authToken=...
return f"{base_url}#/experiments/{experiment.experiment_id}"
except Exception:
pass

return base_url
url = get_presigned_mlflow_experiment_url(mlflow_arn, exp_name)
if url is None:
raise ValueError("Failed to generate presigned MLflow URL")
return url



Expand Down
26 changes: 26 additions & 0 deletions sagemaker-train/src/sagemaker/train/evaluate/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,28 @@ def wait(
# Create console with Jupyter support
console = Console(force_jupyter=True)

# MLflow link caching (presigned URLs expire after 5 min)
mlflow_link_cache = {'url': None, 'timestamp': 0}

def get_cached_mlflow_url():
    """Return the MLflow link for this execution, reusing a recent one when possible.

    A stored URL is reused for up to 240 seconds (the surrounding code notes
    that presigned URLs expire after 5 minutes). When nothing is stored yet,
    or the stored URL is older than that, a new one is requested, provided
    the pipeline execution exposes an MLflow config with a resource ARN.
    Otherwise whatever is currently stored (possibly None) is returned.
    """
    from sagemaker.train.common_utils.trainer_wait import _is_unassigned_attribute
    from sagemaker.train.common_utils.mlflow_url_utils import get_presigned_mlflow_experiment_url

    now = time.time()
    still_fresh = (
        mlflow_link_cache['url'] is not None
        and (now - mlflow_link_cache['timestamp']) <= 240
    )
    if still_fresh:
        return mlflow_link_cache['url']

    execution = self._pipeline_execution
    config = getattr(execution, 'm_lflow_config', None) if execution else None
    if not config or _is_unassigned_attribute(config):
        return mlflow_link_cache['url']

    resource_arn = getattr(config, 'mlflow_resource_arn', None)
    if not resource_arn or _is_unassigned_attribute(resource_arn):
        return mlflow_link_cache['url']

    experiment_name = getattr(config, 'mlflow_experiment_name', None)
    if experiment_name and _is_unassigned_attribute(experiment_name):
        experiment_name = None

    # The timestamp is refreshed only when a new URL was actually requested.
    mlflow_link_cache['url'] = get_presigned_mlflow_experiment_url(resource_arn, experiment_name)
    mlflow_link_cache['timestamp'] = now
    return mlflow_link_cache['url']

while True:
clear_output(wait=True)
self.refresh()
Expand Down Expand Up @@ -960,6 +982,10 @@ def wait(
links.append(f"[bright_blue underline][link={pipeline_url}]🔗 Pipeline Execution (Studio)[/link][/bright_blue underline]")
except Exception:
pass
# Add MLflow experiment link if available
cached_mlflow_url = get_cached_mlflow_url()
if cached_mlflow_url:
links.append(f"[bright_blue underline][link={cached_mlflow_url}]🔗 MLflow Experiment[/link][/bright_blue underline]")
if links:
header_table.add_row("Links", " | ".join(links))

Expand Down
64 changes: 64 additions & 0 deletions sagemaker-train/tests/unit/train/common_utils/test_trainer_wait.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
_is_unassigned_attribute,
_calculate_training_progress,
_calculate_transition_duration,
get_mlflow_url,
wait
)

Expand Down Expand Up @@ -489,3 +490,66 @@ def test_wait_metrics_exception_non_jupyter(self, mock_is_jupyter, mock_setup_ml

# Should complete successfully despite metrics exception
training_job.refresh.assert_called()


class TestGetMlflowUrl:
    """Unit tests for get_mlflow_url."""

    # Values shared by several tests.
    _ARN = "arn:aws:sagemaker:us-west-2:123:mlflow-app/test"
    _HELPER_PATH = (
        "sagemaker.train.common_utils.mlflow_url_utils.get_presigned_mlflow_experiment_url"
    )

    @staticmethod
    def _job_with_mlflow(arn, experiment_name):
        """Build a mock training job whose mlflow_config carries the given values."""
        job = MagicMock()
        job.mlflow_config.mlflow_resource_arn = arn
        job.mlflow_config.mlflow_experiment_name = experiment_name
        return job

    @patch(_HELPER_PATH)
    def test_delegates_to_shared_helper(self, mock_helper):
        """The ARN and experiment name are read from the job and passed to the helper."""
        expected_url = "https://mlflow.example.com/auth?token=abc#/experiments/42"
        mock_helper.return_value = expected_url
        job = self._job_with_mlflow(self._ARN, "my-experiment")

        actual_url = get_mlflow_url(job)

        mock_helper.assert_called_once_with(
            "arn:aws:sagemaker:us-west-2:123:mlflow-app/test",
            "my-experiment",
        )
        assert actual_url == expected_url

    @patch("sagemaker.train.common_utils.trainer_wait.TrainingJob")
    @patch(_HELPER_PATH)
    def test_accepts_job_name_string(self, mock_helper, mock_tj_class):
        """A job name given as a string is looked up through TrainingJob.get()."""
        mock_helper.return_value = "https://mlflow.example.com/auth"
        mock_tj_class.get.return_value = self._job_with_mlflow(self._ARN, None)

        actual_url = get_mlflow_url("my-training-job")

        mock_tj_class.get.assert_called_once_with(training_job_name="my-training-job")
        assert actual_url == "https://mlflow.example.com/auth"

    def test_raises_when_no_mlflow_config(self):
        """An unassigned mlflow_config results in a ValueError."""
        job = MagicMock()
        job.mlflow_config = MockUnassignedAttribute()

        with pytest.raises(ValueError, match="does not have MLflow configured"):
            get_mlflow_url(job)

    def test_raises_when_mlflow_config_missing(self):
        """A job object without any mlflow_config attribute results in a ValueError."""
        job = MagicMock(spec=[])  # spec=[] means the mock exposes no attributes

        with pytest.raises(ValueError, match="does not have MLflow configured"):
            get_mlflow_url(job)

    @patch(_HELPER_PATH)
    def test_raises_when_helper_returns_none(self, mock_helper):
        """A None result from the helper is surfaced as a ValueError."""
        mock_helper.return_value = None
        job = self._job_with_mlflow(self._ARN, "exp")

        with pytest.raises(ValueError, match="Failed to generate presigned MLflow URL"):
            get_mlflow_url(job)
Loading
Loading