Skip to content

Commit 224ffcf

Browse files
committed
fixing formatting issues and resolving pr comments
Signed-off-by: Nelesh Singla <117123879+nsingla@users.noreply.github.com>
1 parent de94ca6 commit 224ffcf

17 files changed

Lines changed: 780 additions & 54 deletions

File tree

.github/workflows/kfp-sdk-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ concurrency:
2020
jobs:
2121
sdk-tests:
2222
runs-on: ubuntu-latest
23-
timeout-minutes: 45
23+
timeout-minutes: 50
2424
strategy:
2525
matrix:
2626
python-version: ['3.9', '3.13']

backend/test/testutil/kubernetes_utils.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,12 @@ func ReadPodLogs(client *kubernetes.Clientset, namespace string, podName string,
7777
podLogsRequest := client.CoreV1().Pods(namespace).GetLogs(podFromPodName.Name, podLogOptions)
7878
podLogs, err := podLogsRequest.Stream(context.Background()) // Pass a context for cancellation
7979
if err != nil {
80-
logger.Log("Failed to stream pod logs due to %v", err)
80+
logger.Log("Failed to stream pod logs for container '%s' due to %v", container.Name, err)
81+
continue
82+
}
83+
if podLogs == nil {
84+
logger.Log("Pod log stream is nil for container '%s'", container.Name)
85+
continue
8186
}
8287
defer func(podLogs io.ReadCloser) {
8388
err = podLogs.Close()
@@ -87,7 +92,7 @@ func ReadPodLogs(client *kubernetes.Clientset, namespace string, podName string,
8792
}(podLogs)
8893
_, err = io.Copy(buf, podLogs)
8994
if err != nil {
90-
logger.Log("Failed to add pod logs to buffer due to: %v", err)
95+
logger.Log("Failed to add pod logs to buffer for container '%s' due to: %v", container.Name, err)
9196
}
9297
}
9398
} else {

kubernetes_platform/python/kfp/kubernetes/common.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -47,22 +47,6 @@ def get_existing_kubernetes_config_as_message(
4747
return json_format.ParseDict(cur_k8_config_dict, k8_config_msg)
4848

4949

50-
def ensure_channel_input(task: PipelineTask,
51-
channel: pipeline_channel.PipelineChannel) -> None:
52-
"""Adds a channel to the task's tracked inputs if not already present.
53-
54-
This ensures the compiler propagates the channel through sub-DAG
55-
boundaries (e.g. ParallelFor) even when the channel is only
56-
referenced from Kubernetes platform config and not from normal
57-
task arguments.
58-
"""
59-
existing_channel_patterns = {
60-
existing.pattern for existing in task._channel_inputs
61-
}
62-
if channel.pattern not in existing_channel_patterns:
63-
task._channel_inputs.append(channel)
64-
65-
6650
def parse_k8s_parameter_input(
6751
input_param: Union[pipeline_channel.PipelineParameterChannel, str, dict],
6852
task: PipelineTask,

kubernetes_platform/python/kfp/kubernetes/volume.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def mount_pvc(
8888
pvc_mount.pvc_name_parameter.CopyFrom(pvc_name_parameter)
8989

9090
# deprecated: for backwards compatibility
91-
pvc_name_from_task = _assign_pvc_name_to_msg(task, pvc_mount, pvc_name)
91+
pvc_name_from_task = _assign_pvc_name_to_msg(pvc_mount, pvc_name)
9292

9393
if pvc_name_from_task:
9494
task.after(pvc_name.task)
@@ -109,11 +109,8 @@ def DeletePVC(pvc_name: str):
109109
return dsl.ContainerSpec(image='argostub/deletepvc')
110110

111111

112-
def _assign_pvc_name_to_msg(
113-
task: PipelineTask,
114-
msg: message.Message,
115-
pvc_name: Union[str, 'PipelineChannel'],
116-
) -> bool:
112+
def _assign_pvc_name_to_msg(msg: message.Message,
113+
pvc_name: Union[str, 'PipelineChannel']) -> bool:
117114
"""Assigns pvc_name to the msg's pvc_reference oneof.
118115
119116
Returns True if pvc_name is an upstream task output; otherwise, False.
@@ -122,7 +119,6 @@ def _assign_pvc_name_to_msg(
122119
msg.constant = pvc_name
123120
return False
124121
elif hasattr(pvc_name, 'task_name'):
125-
common.ensure_channel_input(task=task, channel=pvc_name)
126122
if pvc_name.task_name is None:
127123
msg.component_input_parameter = pvc_name.name
128124
return False

kubernetes_platform/python/test/unit/test_secret.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -973,29 +973,29 @@ def my_pipeline():
973973

974974
class TestEnsureChannelInputPropagation:
975975
"""Tests that PipelineParameterChannel values used in Kubernetes platform
976-
config are correctly registered in task._channel_inputs for sub-DAG
976+
config are correctly registered in task.channel_inputs for sub-DAG
977977
propagation."""
978978

979979
def test_pipeline_param_added_to_channel_inputs(self):
980980
"""When a PipelineParameterChannel is passed as secret_name,
981-
it should be appended to the task's _channel_inputs."""
981+
it should be appended to the task's channel_inputs."""
982982

983983
@dsl.pipeline
984984
def my_pipeline(secret_name: str):
985985
task = comp()
986-
initial_count = len(task._channel_inputs)
986+
initial_count = len(task.channel_inputs)
987987
kubernetes.use_secret_as_env(
988988
task,
989989
secret_name=secret_name,
990990
secret_key_to_env={"key": "VAR"},
991991
)
992-
assert len(task._channel_inputs) == initial_count + 1
993-
channel_patterns = {ch.pattern for ch in task._channel_inputs}
992+
assert len(task.channel_inputs) == initial_count + 1
993+
channel_patterns = {ch.pattern for ch in task.channel_inputs}
994994
assert secret_name.pattern in channel_patterns
995995

996996
def test_duplicate_pipeline_param_not_added_twice(self):
997997
"""When the same PipelineParameterChannel is used in both
998-
use_secret_as_env and use_secret_as_volume, _channel_inputs
998+
use_secret_as_env and use_secret_as_volume, channel_inputs
999999
should not contain duplicates."""
10001000

10011001
@dsl.pipeline
@@ -1006,44 +1006,44 @@ def my_pipeline(secret_name: str):
10061006
secret_name=secret_name,
10071007
secret_key_to_env={"key1": "VAR1"},
10081008
)
1009-
count_after_first = len(task._channel_inputs)
1009+
count_after_first = len(task.channel_inputs)
10101010
kubernetes.use_secret_as_volume(
10111011
task,
10121012
secret_name=secret_name,
10131013
mount_path="/mnt/secret",
10141014
)
1015-
assert len(task._channel_inputs) == count_after_first
1015+
assert len(task.channel_inputs) == count_after_first
10161016

10171017
def test_string_input_does_not_add_channel_inputs(self):
1018-
"""When a literal string is passed, _channel_inputs should not
1018+
"""When a literal string is passed, channel_inputs should not
10191019
be modified."""
10201020

10211021
@dsl.pipeline
10221022
def my_pipeline():
10231023
task = comp()
1024-
initial_count = len(task._channel_inputs)
1024+
initial_count = len(task.channel_inputs)
10251025
kubernetes.use_secret_as_env(
10261026
task,
10271027
secret_name="literal-secret",
10281028
secret_key_to_env={"key": "VAR"},
10291029
)
1030-
assert len(task._channel_inputs) == initial_count
1030+
assert len(task.channel_inputs) == initial_count
10311031

10321032
def test_task_output_added_to_channel_inputs(self):
10331033
"""When a task output PipelineChannel is passed as secret_name,
1034-
it should be appended to _channel_inputs."""
1034+
it should be appended to channel_inputs."""
10351035

10361036
@dsl.pipeline
10371037
def my_pipeline():
10381038
name_task = comp_with_output()
10391039
task = comp()
1040-
initial_count = len(task._channel_inputs)
1040+
initial_count = len(task.channel_inputs)
10411041
kubernetes.use_secret_as_env(
10421042
task,
10431043
secret_name=name_task.output,
10441044
secret_key_to_env={"key": "VAR"},
10451045
)
1046-
assert len(task._channel_inputs) == initial_count + 1
1046+
assert len(task.channel_inputs) == initial_count + 1
10471047

10481048

10491049
@dsl.component

sdk/python/kfp/compiler/pipeline_spec_builder.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1698,6 +1698,12 @@ def _rewrite(data: Any) -> Any:
16981698
if surfaced_name in parent_component_inputs.parameters:
16991699
rewritten.pop('taskOutputParameter')
17001700
rewritten['componentInputParameter'] = surfaced_name
1701+
else:
1702+
raise compiler_utils.InvalidTopologyException(
1703+
'Failed to rewrite cross-DAG platform config reference '
1704+
f'for task output {producer_task}.{output_key}. '
1705+
f'Expected surfaced input {surfaced_name!r} in parent '
1706+
'component inputs, but it was not found.')
17011707

17021708
return rewritten
17031709

sdk/python/kfp/compiler/pipeline_spec_builder_test.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from kfp import dsl
2525
from kfp import kubernetes
2626
from kfp.compiler import compiler
27+
from kfp.compiler import compiler_utils
2728
from kfp.compiler import pipeline_spec_builder
2829
from kfp.dsl import TaskConfigField
2930
from kfp.pipeline_spec import pipeline_spec_pb2
@@ -724,6 +725,82 @@ def pipe():
724725
self.assertNotIn('componentInputParameter', secret_name_fields)
725726
self.assertNotIn('taskOutputParameter', secret_name_fields)
726727

728+
def test_parallelfor_pipeline_input_mount_pvc(self):
729+
"""Pipeline pvc_name param inside ParallelFor is correctly surfaced and
730+
rewritten for mount_pvc platform config."""
731+
732+
@dsl.component
733+
def my_comp(item: str):
734+
print(item)
735+
736+
@dsl.pipeline
737+
def pipe(pvc_name: str):
738+
with dsl.ParallelFor(items=['a', 'b'], parallelism=1) as item:
739+
t = my_comp(item=item)
740+
kubernetes.mount_pvc(
741+
t,
742+
pvc_name=pvc_name,
743+
mount_path='/mnt/data',
744+
)
745+
746+
pipeline_spec, platform_spec = self._compile_and_parse(pipe)
747+
748+
loop_component = pipeline_spec.components['comp-for-loop-2']
749+
self.assertIn('pipelinechannel--pvc_name',
750+
loop_component.input_definitions.parameters)
751+
752+
root_task_params = pipeline_spec.root.dag.tasks[
753+
'for-loop-2'].inputs.parameters
754+
self.assertEqual(
755+
root_task_params['pipelinechannel--pvc_name']
756+
.component_input_parameter,
757+
'pvc_name',
758+
)
759+
760+
pvc_param = (
761+
platform_spec.platforms['kubernetes'].deployment_spec
762+
.executors['exec-my-comp'].fields['pvcMount'].list_value.values[0]
763+
.struct_value.fields['pvcNameParameter'].struct_value
764+
.fields['componentInputParameter'].string_value)
765+
self.assertEqual(pvc_param, 'pipelinechannel--pvc_name')
766+
767+
def test_exit_handler_platform_config_rewrite_path(self):
768+
"""Exit handler task platform config uses rewrite path with parent
769+
component context."""
770+
771+
@dsl.component
772+
def cleanup():
773+
print('cleanup')
774+
775+
@dsl.component
776+
def main_task():
777+
print('main')
778+
779+
@dsl.pipeline
780+
def pipe(secret_name: str):
781+
exit_task = cleanup()
782+
kubernetes.use_secret_as_env(
783+
exit_task,
784+
secret_name=secret_name,
785+
secret_key_to_env={'key': 'VAL'},
786+
)
787+
with dsl.ExitHandler(exit_task=exit_task):
788+
main_task()
789+
790+
_, platform_spec = self._compile_and_parse(pipe)
791+
792+
cleanup_executors = [
793+
executor for executor in platform_spec.platforms['kubernetes']
794+
.deployment_spec.executors.values()
795+
if 'secretAsEnv' in executor.fields
796+
]
797+
self.assertEqual(len(cleanup_executors), 1)
798+
secret_param = (
799+
cleanup_executors[0].fields['secretAsEnv'].list_value.values[0]
800+
.struct_value.fields['secretNameParameter'].struct_value
801+
.fields['componentInputParameter'].string_value)
802+
self.assertEqual(secret_param, 'secret_name')
803+
727804

728805
class TestRewritePlatformConfigInputReferences(unittest.TestCase):
729806
"""Unit tests for the _rewrite_platform_config_input_references helper."""
@@ -878,6 +955,29 @@ def test_rewrites_multiple_params(self):
878955
'pipelinechannel--pvc_name',
879956
)
880957

958+
def test_raises_when_cross_dag_output_missing_surfaced_input(self):
959+
platform_config = {
960+
'kubernetes': {
961+
'secretAsEnv': [{
962+
'secretNameParameter': {
963+
'taskOutputParameter': {
964+
'producerTask': 'emit-secret',
965+
'outputParameterKey': 'Output',
966+
}
967+
},
968+
}]
969+
}
970+
}
971+
parent_inputs = pipeline_spec_pb2.ComponentInputsSpec()
972+
973+
with self.assertRaisesRegex(compiler_utils.InvalidTopologyException,
974+
'Expected surfaced input'):
975+
pipeline_spec_builder._rewrite_platform_config_input_references(
976+
platform_config,
977+
parent_inputs,
978+
tasks_in_current_dag=['worker-task'],
979+
)
980+
881981

882982
def pipeline_spec_from_file(filepath: str) -> str:
883983
with open(filepath, 'r') as f:

sdk/python/kfp/dsl/pipeline_task.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ def _register_pipeline_channels(
328328
pipeline_channels: List[pipeline_channel.PipelineChannel]) -> None:
329329
"""Backwards-compatible wrapper for ``register_pipeline_channels``."""
330330
self.register_pipeline_channels(pipeline_channels)
331+
331332
@block_if_final()
332333
def set_caching_options(self,
333334
enable_caching: bool,

test_data/compiled-workflows/nested_parallel_for_secret.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ spec:
2121
\"test-secret-1\"\n\n"],"image":"python:3.11"}'
2222
- name: kubernetes-comp-worker-component
2323
value: '{"secretAsEnv":[{"keyToEnv":[{"envVar":"MY_SECRET_KEY","secretKey":"username"}],"optional":false,"secretNameParameter":{"componentInputParameter":"pipelinechannel--emit-secret-name-Output"}}]}'
24-
- name: components-490eb648ead5e4c6f02bebcc44125df7cf0178498ee770cc67d4bd310900a4fe
24+
- name: components-3b4506509514761c798c91195d4caf336df6411cbc2327c68dd7902b1bcad98c
2525
value: '{"executorLabel":"exec-worker-component","inputDefinitions":{"parameters":{"item":{"parameterType":"STRING"}}},"outputDefinitions":{"parameters":{"Output":{"parameterType":"STRING"}}}}'
26-
- name: implementations-490eb648ead5e4c6f02bebcc44125df7cf0178498ee770cc67d4bd310900a4fe
26+
- name: implementations-3b4506509514761c798c91195d4caf336df6411cbc2327c68dd7902b1bcad98c
2727
value: '{"args":["--executor_input","{{$}}","--function_to_execute","worker_component"],"command":["sh","-c","\nif
2828
! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip || python3
2929
-m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1
@@ -35,7 +35,7 @@ spec:
3535
kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import *\n\ndef
3636
worker_component(item: str) -\u003e str:\n import os\n secret_val =
3737
os.environ.get(\"MY_SECRET_KEY\", \"not-set\")\n print(f\"Item: {item},
38-
Secret value: {secret_val}\")\n return secret_val\n\n"],"image":"python:3.11"}'
38+
Secret set: {secret_val != ''not-set''}\")\n return secret_val\n\n"],"image":"python:3.11"}'
3939
- name: components-comp-for-loop-2
4040
value: '{"dag":{"tasks":{"worker-component":{"cachingOptions":{"enableCache":true},"componentRef":{"name":"comp-worker-component"},"inputs":{"parameters":{"item":{"componentInputParameter":"pipelinechannel--loop-item-param-1"}}},"taskInfo":{"name":"worker-component"}}}},"inputDefinitions":{"parameters":{"pipelinechannel--emit-secret-name-Output":{"parameterType":"STRING"},"pipelinechannel--loop-item-param-1":{"parameterType":"STRING"}}}}'
4141
- name: components-root
@@ -272,11 +272,11 @@ spec:
272272
- arguments:
273273
parameters:
274274
- name: component
275-
value: '{{workflow.parameters.components-490eb648ead5e4c6f02bebcc44125df7cf0178498ee770cc67d4bd310900a4fe}}'
275+
value: '{{workflow.parameters.components-3b4506509514761c798c91195d4caf336df6411cbc2327c68dd7902b1bcad98c}}'
276276
- name: task
277277
value: '{"cachingOptions":{"enableCache":true},"componentRef":{"name":"comp-worker-component"},"inputs":{"parameters":{"item":{"componentInputParameter":"pipelinechannel--loop-item-param-1"}}},"taskInfo":{"name":"worker-component"}}'
278278
- name: container
279-
value: '{{workflow.parameters.implementations-490eb648ead5e4c6f02bebcc44125df7cf0178498ee770cc67d4bd310900a4fe}}'
279+
value: '{{workflow.parameters.implementations-3b4506509514761c798c91195d4caf336df6411cbc2327c68dd7902b1bcad98c}}'
280280
- name: task-name
281281
value: worker-component
282282
- name: parent-dag-id

0 commit comments

Comments (0)