|
24 | 24 | from kfp import dsl |
25 | 25 | from kfp import kubernetes |
26 | 26 | from kfp.compiler import compiler |
| 27 | +from kfp.compiler import compiler_utils |
27 | 28 | from kfp.compiler import pipeline_spec_builder |
28 | 29 | from kfp.dsl import TaskConfigField |
29 | 30 | from kfp.pipeline_spec import pipeline_spec_pb2 |
@@ -724,6 +725,82 @@ def pipe(): |
724 | 725 | self.assertNotIn('componentInputParameter', secret_name_fields) |
725 | 726 | self.assertNotIn('taskOutputParameter', secret_name_fields) |
726 | 727 |
|
| 728 | + def test_parallelfor_pipeline_input_mount_pvc(self): |
| 729 | + """Pipeline pvc_name param inside ParallelFor is correctly surfaced and |
| 730 | + rewritten for mount_pvc platform config.""" |
| 731 | + |
| 732 | + @dsl.component |
| 733 | + def my_comp(item: str): |
| 734 | + print(item) |
| 735 | + |
| 736 | + @dsl.pipeline |
| 737 | + def pipe(pvc_name: str): |
| 738 | + with dsl.ParallelFor(items=['a', 'b'], parallelism=1) as item: |
| 739 | + t = my_comp(item=item) |
| 740 | + kubernetes.mount_pvc( |
| 741 | + t, |
| 742 | + pvc_name=pvc_name, |
| 743 | + mount_path='/mnt/data', |
| 744 | + ) |
| 745 | + |
| 746 | + pipeline_spec, platform_spec = self._compile_and_parse(pipe) |
| 747 | + |
| 748 | + loop_component = pipeline_spec.components['comp-for-loop-2'] |
| 749 | + self.assertIn('pipelinechannel--pvc_name', |
| 750 | + loop_component.input_definitions.parameters) |
| 751 | + |
| 752 | + root_task_params = pipeline_spec.root.dag.tasks[ |
| 753 | + 'for-loop-2'].inputs.parameters |
| 754 | + self.assertEqual( |
| 755 | + root_task_params['pipelinechannel--pvc_name'] |
| 756 | + .component_input_parameter, |
| 757 | + 'pvc_name', |
| 758 | + ) |
| 759 | + |
| 760 | + pvc_param = ( |
| 761 | + platform_spec.platforms['kubernetes'].deployment_spec |
| 762 | + .executors['exec-my-comp'].fields['pvcMount'].list_value.values[0] |
| 763 | + .struct_value.fields['pvcNameParameter'].struct_value |
| 764 | + .fields['componentInputParameter'].string_value) |
| 765 | + self.assertEqual(pvc_param, 'pipelinechannel--pvc_name') |
| 766 | + |
| 767 | + def test_exit_handler_platform_config_rewrite_path(self): |
| 768 | + """Exit handler task platform config uses rewrite path with parent |
| 769 | + component context.""" |
| 770 | + |
| 771 | + @dsl.component |
| 772 | + def cleanup(): |
| 773 | + print('cleanup') |
| 774 | + |
| 775 | + @dsl.component |
| 776 | + def main_task(): |
| 777 | + print('main') |
| 778 | + |
| 779 | + @dsl.pipeline |
| 780 | + def pipe(secret_name: str): |
| 781 | + exit_task = cleanup() |
| 782 | + kubernetes.use_secret_as_env( |
| 783 | + exit_task, |
| 784 | + secret_name=secret_name, |
| 785 | + secret_key_to_env={'key': 'VAL'}, |
| 786 | + ) |
| 787 | + with dsl.ExitHandler(exit_task=exit_task): |
| 788 | + main_task() |
| 789 | + |
| 790 | + _, platform_spec = self._compile_and_parse(pipe) |
| 791 | + |
| 792 | + cleanup_executors = [ |
| 793 | + executor for executor in platform_spec.platforms['kubernetes'] |
| 794 | + .deployment_spec.executors.values() |
| 795 | + if 'secretAsEnv' in executor.fields |
| 796 | + ] |
| 797 | + self.assertEqual(len(cleanup_executors), 1) |
| 798 | + secret_param = ( |
| 799 | + cleanup_executors[0].fields['secretAsEnv'].list_value.values[0] |
| 800 | + .struct_value.fields['secretNameParameter'].struct_value |
| 801 | + .fields['componentInputParameter'].string_value) |
| 802 | + self.assertEqual(secret_param, 'secret_name') |
| 803 | + |
727 | 804 |
|
728 | 805 | class TestRewritePlatformConfigInputReferences(unittest.TestCase): |
729 | 806 | """Unit tests for the _rewrite_platform_config_input_references helper.""" |
@@ -878,6 +955,29 @@ def test_rewrites_multiple_params(self): |
878 | 955 | 'pipelinechannel--pvc_name', |
879 | 956 | ) |
880 | 957 |
|
| 958 | + def test_raises_when_cross_dag_output_missing_surfaced_input(self): |
| 959 | + platform_config = { |
| 960 | + 'kubernetes': { |
| 961 | + 'secretAsEnv': [{ |
| 962 | + 'secretNameParameter': { |
| 963 | + 'taskOutputParameter': { |
| 964 | + 'producerTask': 'emit-secret', |
| 965 | + 'outputParameterKey': 'Output', |
| 966 | + } |
| 967 | + }, |
| 968 | + }] |
| 969 | + } |
| 970 | + } |
| 971 | + parent_inputs = pipeline_spec_pb2.ComponentInputsSpec() |
| 972 | + |
| 973 | + with self.assertRaisesRegex(compiler_utils.InvalidTopologyException, |
| 974 | + 'Expected surfaced input'): |
| 975 | + pipeline_spec_builder._rewrite_platform_config_input_references( |
| 976 | + platform_config, |
| 977 | + parent_inputs, |
| 978 | + tasks_in_current_dag=['worker-task'], |
| 979 | + ) |
| 980 | + |
881 | 981 |
|
882 | 982 | def pipeline_spec_from_file(filepath: str) -> str: |
883 | 983 | with open(filepath, 'r') as f: |
|
0 commit comments