From 07faa08b8622144e656d45cf015a665bb9a2a722 Mon Sep 17 00:00:00 2001 From: Alexey Volkov Date: Thu, 18 Apr 2019 22:01:55 -0700 Subject: [PATCH] Updated the "Basic - Recursive loop" sample (#1113) Modernized the sample pipeline code. --- samples/basic/recursion.py | 63 ++++++++++++++++++++------------------ 1 file changed, 33 insertions(+), 30 deletions(-) diff --git a/samples/basic/recursion.py b/samples/basic/recursion.py index 5bd6c652049..a5f219dac3f 100644 --- a/samples/basic/recursion.py +++ b/samples/basic/recursion.py @@ -13,54 +13,57 @@ # See the License for the specific language governing permissions and # limitations under the License. -import kfp.dsl as dsl -class FlipCoinOp(dsl.ContainerOp): - """Flip a coin and output heads or tails randomly.""" +import kfp +from kfp import dsl - def __init__(self): - super(FlipCoinOp, self).__init__( - name='Flip', + +def flip_coin_op(): + """Flip a coin and output heads or tails randomly.""" + return dsl.ContainerOp( + name='Flip coin', image='python:alpine3.6', command=['sh', '-c'], arguments=['python -c "import random; result = \'heads\' if random.randint(0,1) == 0 ' - 'else \'tails\'; print(result)" | tee /tmp/output'], - file_outputs={'output': '/tmp/output'}) + 'else \'tails\'; print(result)" | tee /tmp/output'], + file_outputs={'output': '/tmp/output'} + ) -class PrintOp(dsl.ContainerOp): - """Print a message.""" - def __init__(self, msg): - super(PrintOp, self).__init__( +def print_op(msg): + """Print a message.""" + return dsl.ContainerOp( name='Print', image='alpine:3.6', command=['echo', msg], ) -# Use the dsl.graph_component to decorate functions that are + +# Use the dsl.graph_component to decorate pipeline functions that can be # recursively called. @dsl.graph_component def flip_component(flip_result): - print_flip = PrintOp(flip_result) - flipA = FlipCoinOp().after(print_flip) - with dsl.Condition(flipA.output == 'heads'): - # When the flip_component is called recursively, the flipA.output - # from inside the graph component will be passed to the next flip_component - # as the input whereas the flip_result in the current graph component - # comes from the flipA.output in the flipcoin function. - flip_component(flipA.output) + print_flip = print_op(flip_result) + flipA = flip_coin_op().after(print_flip) + with dsl.Condition(flipA.output == 'heads'): + # When the flip_component is called recursively, the flipA.output + # from inside the graph component will be passed to the next flip_component + # as the input whereas the flip_result in the current graph component + # comes from the flipA.output in the flipcoin function. + flip_component(flipA.output) + @dsl.pipeline( - name='pipeline flip coin', - description='shows how to use dsl.Condition.' + name='Recursive loop pipeline', + description='Shows how to create recursive loops.' ) def flipcoin(): - flipA = FlipCoinOp() - flip_loop = flip_component(flipA.output) - # flip_loop is a graph_component with the outputs field - # filled with the returned dictionary. - PrintOp('cool, it is over. %s' % flipA.output).after(flip_loop) + first_flip = flip_coin_op() + flip_loop = flip_component(first_flip.output) + # flip_loop is a graph_component with the outputs field + # filled with the returned dictionary. + print_op('cool, it is over.').after(flip_loop) + if __name__ == '__main__': - import kfp.compiler as compiler - compiler.Compiler().compile(flipcoin, __file__ + '.tar.gz') + kfp.compiler.Compiler().compile(flipcoin, __file__ + '.tar.gz')