diff --git a/CMakeLists.txt b/CMakeLists.txt index 0858e290..62e555be 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -87,4 +87,5 @@ else() endif(LLVM_MLBRIDGE) install(DIRECTORY include/ DESTINATION include) -install(DIRECTORY CompilerInterface DESTINATION include/python/MLCompilerBridge) +install(DIRECTORY CompilerInterface DESTINATION MLModelRunner/CompilerInterface) +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/CompilerInterface DESTINATION ${CMAKE_BINARY_DIR}/MLModelRunner/) diff --git a/CompilerInterface/GrpcCompilerInterface.py b/CompilerInterface/GrpcCompilerInterface.py index 217af66f..6cb82466 100644 --- a/CompilerInterface/GrpcCompilerInterface.py +++ b/CompilerInterface/GrpcCompilerInterface.py @@ -91,7 +91,7 @@ def start_server(self): "{}:{}".format(self.host, self.server_port) ) - if str(added_port) == self.server_port: + if added_port == self.server_port: server.start() print("Server Running") server.wait_for_termination() @@ -100,7 +100,7 @@ def start_server(self): retries += 1 print( "The port", - self.port, + self.server_port, "is already in use retrying! attempt: ", retries, ) diff --git a/MLModelRunner/gRPCModelRunner/CMakeLists.txt b/MLModelRunner/gRPCModelRunner/CMakeLists.txt index d9fb94a6..6ce52d6f 100755 --- a/MLModelRunner/gRPCModelRunner/CMakeLists.txt +++ b/MLModelRunner/gRPCModelRunner/CMakeLists.txt @@ -84,7 +84,7 @@ if(LLVM_MLBRIDGE) ${proto_python_srcs_list} ) else() - add_library(gRPCModelRunnerLib OBJECT + add_library(gRPCModelRunnerLib ${cc_files} ${proto_srcs_list} ${grpc_srcs_list} diff --git a/SerDes/protobufSerDes.cpp b/SerDes/protobufSerDes.cpp index 9990b7e0..38c68a83 100644 --- a/SerDes/protobufSerDes.cpp +++ b/SerDes/protobufSerDes.cpp @@ -156,6 +156,13 @@ void *ProtobufSerDes::deserializeUntyped(void *data) { this->MessageLength = ref.size() * sizeof(int32_t); return ret->data(); } + if (field->type() == FieldDescriptor::Type::TYPE_INT64) { + auto &ref = reflection->GetRepeatedField(*Response, field); + std::vector *ret = + new std::vector(ref.begin(), ref.end()); + this->MessageLength = ref.size() * sizeof(int64_t); + return ret->data(); + } if (field->type() == FieldDescriptor::Type::TYPE_FLOAT) { auto ref = reflection->GetRepeatedField(*Response, field); std::vector *ret = new std::vector(ref.begin(), ref.end()); @@ -199,6 +206,12 @@ void *ProtobufSerDes::deserializeUntyped(void *data) { this->MessageLength = sizeof(int32_t); return ptr; } + if (field->type() == FieldDescriptor::Type::TYPE_INT64) { + int64_t value = reflection->GetInt64(*Response, field); + int64_t *ptr = new int64_t(value); + this->MessageLength = sizeof(int64_t); + return ptr; + } if (field->type() == FieldDescriptor::Type::TYPE_FLOAT) { float value = reflection->GetFloat(*Response, field); float *ptr = new float(value); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index dccdc25b..f4a83a3f 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -6,5 +6,6 @@ file(GLOB MODEL_OBJECTS ${CMAKE_CURRENT_SOURCE_DIR}/tf_models/*.o) foreach(MODEL_OBJECT ${MODEL_OBJECTS}) target_link_libraries(MLBridgeCPPTest PRIVATE ${MODEL_OBJECT}) endforeach() -target_link_libraries(MLBridgeCPPTest PRIVATE ModelRunnerUtils) -target_include_directories(MLBridgeCPPTest PRIVATE ${CMAKE_BINARY_DIR}/include ${TENSORFLOW_AOT_PATH}/include) +target_link_libraries(MLBridgeCPPTest PRIVATE MLCompilerBridge ) +target_include_directories(MLBridgeCPPTest PRIVATE ${CMAKE_BINARY_DIR}/include ${TENSORFLOW_AOT_PATH}/include ${CMAKE_CURRENT_SOURCE_DIR}/include) +target_link_libraries(MLBridgeCPPTest PRIVATE tf_xla_runtime) diff --git a/test/MLBridgeTest.cpp b/test/MLBridgeTest.cpp index 794b1eee..19153833 100644 --- a/test/MLBridgeTest.cpp +++ b/test/MLBridgeTest.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "HelloMLBridge_Env.h" #include "MLModelRunner/MLModelRunner.h" #include "MLModelRunner/ONNXModelRunner/ONNXModelRunner.h" #include "MLModelRunner/PipeModelRunner.h" @@ -13,15 +14,10 @@ #include "MLModelRunner/Utils/DataTypes.h" #include "MLModelRunner/Utils/MLConfig.h" #include "MLModelRunner/gRPCModelRunner.h" -// #include "grpc/helloMLBridgeTest/helloMLBridgeTest.grpc.pb.h" -// #include "grpc/helloMLBridgeTest/helloMLBridgeTest.pb.h" -#include "grpcpp/impl/codegen/status.h" -#include "inference/HelloMLBridge_Env.h" +#include "ProtosInclude.h" #include "llvm/Support/CommandLine.h" -// #include "llvm/Support/raw_ostream.h" #include #include -#include #include #include #include @@ -31,6 +27,20 @@ #define debug_out \ if (!silent) \ std::cout +using namespace grpc; + +#define gRPCModelRunnerInit(datatype) \ + increment_port(1); \ + MLBridgeTestgRPC_##datatype::Reply response; \ + MLBridgeTestgRPC_##datatype::Request request; \ + MLRunner = std::make_unique< \ + gRPCModelRunner>( \ + server_address, &request, &response, nullptr); \ + MLRunner->setRequest(&request); \ + MLRunner->setResponse(&response) static llvm::cl::opt cl_server_address("test-server-address", llvm::cl::Hidden, @@ -41,6 +51,10 @@ static llvm::cl::opt cl_pipe_name("test-pipe-name", llvm::cl::Hidden, llvm::cl::init(""), llvm::cl::desc("Name for pipe file")); +static llvm::cl::opt + cl_onnx_path("onnx-model-path", llvm::cl::Hidden, llvm::cl::init(""), + llvm::cl::desc("Path to onnx model")); + static llvm::cl::opt cl_test_config( "test-config", llvm::cl::Hidden, llvm::cl::desc("Method for communication with python model")); @@ -55,9 +69,9 @@ std::string basename; BaseSerDes::Kind SerDesType; std::string test_config; -std::string data_format; std::string pipe_name; std::string server_address; +std::string onnx_path; // send value of type T1. Test received value of type T2 against expected value template @@ -96,77 +110,149 @@ void testVector(std::string label, std::vector value, debug_out << "\n"; } -void runTests() { - if (data_format != "json") { - testPrimitive("int", 11, 12); - testPrimitive("long", 1234567890, 1234567891); - testPrimitive("float", 3.14, 4.14); - testPrimitive("double", 0.123456789123456789, - 1.123456789123456789); - testPrimitive("char", 'a', 'b'); - testPrimitive("bool", true, false); - testVector("vec_int", {11, 22, 33}, {12, 23, 34}); - testVector("vec_long", {123456780, 222, 333}, - {123456780, 123456781, 123456782}); - testVector("vec_float", {11.1, 22.2, 33.3}, - {1.11, 2.22, -3.33, 0}); - testVector("vec_double", - {-1.1111111111, -2.2222222222, -3.3333333333}, - {1.12345678912345670, -1.12345678912345671}); - } else if (data_format == "json") { - testPrimitive("int", 11, 12); - testPrimitive("long", 1234567890, 12345); - testPrimitive("float", 3.14, 4.14); - testPrimitive("double", 0.123456789123456789, - 1.123456789123456789); - testPrimitive("char", 'a', 'b'); - testPrimitive("bool", true, false); - testVector("vec_int", {11, 22, 33}, {12, 23, 34}); - testVector("vec_long", {123456780, 222, 333}, - {6780, 6781, 6782}); - testVector("vec_float", {11.1, 22.2, 33.3}, - {1.11, 2.22, -3.33, 0}); - testVector("vec_double", - {-1.1111111111, -2.2222222222, -3.3333333333}, - {1.12345678912345670, -1.12345678912345671}); +int testPipeBytes() { + if (pipe_name == "") { + std::cerr + << "Pipe name must be specified via --test-pipe-name=\n"; + exit(1); } + basename = "./" + pipe_name; + SerDesType = BaseSerDes::Kind::Bitstream; + MLRunner = std::make_unique( + basename + ".out", basename + ".in", SerDesType, nullptr); + testPrimitive("int", 11, 12); + testPrimitive("long", 1234567890, 1234567891); + testPrimitive("float", 3.14, 4.14); + testPrimitive("double", 0.123456789123456789, + 1.123456789123456789); + testPrimitive("char", 'a', 'b'); + testPrimitive("bool", true, false); + testVector("vec_int", {11, 22, 33}, {12, 23, 34}); + testVector("vec_long", {123456780, 222, 333}, + {123456780, 123456781, 123456782}); + testVector("vec_float", {11.1, 22.2, 33.3}, + {1.11, 2.22, -3.33, 0}); + testVector("vec_double", + {-1.1111111111, -2.2222222222, -3.3333333333}, + {1.12345678912345670, -1.12345678912345671}); + return 0; } -int testPipes() { +int testPipeJSON() { if (pipe_name == "") { std::cerr << "Pipe name must be specified via --test-pipe-name=\n"; exit(1); } - basename = "/tmp/" + pipe_name; - if (data_format == "json") - SerDesType = BaseSerDes::Kind::Json; - else if (data_format == "protobuf") - SerDesType = BaseSerDes::Kind::Protobuf; - else if (data_format == "bytes") - SerDesType = BaseSerDes::Kind::Bitstream; - else { - std::cout << "Invalid data format\n"; - exit(1); - } - + basename = "./" + pipe_name; + SerDesType = BaseSerDes::Kind::Json; MLRunner = std::make_unique( basename + ".out", basename + ".in", SerDesType, nullptr); - - runTests(); + testPrimitive("int", 11, 12); + testPrimitive("long", 1234567890, 12345); + testPrimitive("float", 3.14, 4.14); + testPrimitive("double", 0.123456789123456789, + 1.123456789123456789); + testPrimitive("char", 'a', 'b'); + testPrimitive("bool", true, false); + testVector("vec_int", {11, 22, 33}, {12, 23, 34}); + testVector("vec_long", {123456780, 222, 333}, + {6780, 6781, 6782}); + testVector("vec_float", {11.1, 22.2, 33.3}, + {1.11, 2.22, -3.33, 0}); + testVector("vec_double", + {-1.1111111111, -2.2222222222, -3.3333333333}, + {1.12345678912345670, -1.12345678912345671}); return 0; } +void increment_port(int delta) { + int split = server_address.find(":"); + int port = stoi(server_address.substr(split + 1)); + server_address = + server_address.substr(0, split) + ":" + to_string(port + delta); +} + int testGRPC() { if (server_address == "") { std::cerr << "Server Address must be specified via " - "--test-server-address=:\n"; + "--test-server-address=\":\"\n"; exit(1); } + { + gRPCModelRunnerInit(int); + testPrimitive("int", 11, 12); + } + { + gRPCModelRunnerInit(long); + testPrimitive("long", 1234567890, 1234567891); + } + { + gRPCModelRunnerInit(float); + testPrimitive("float", 3.14, 4.14); + } + { + gRPCModelRunnerInit(double); + testPrimitive("double", 0.123456789123456789, + 1.123456789123456789); + } + increment_port(1); + { + gRPCModelRunnerInit(bool); + testPrimitive("bool", true, false); + } + { + gRPCModelRunnerInit(vec_int); + testVector("vec_int", {11, 22, 33}, {12, 23, 34}); + } + { + gRPCModelRunnerInit(vec_long); + testVector("vec_long", {123456780, 222, 333}, + {123456780, 123456781, 123456782}); + } + { + gRPCModelRunnerInit(vec_float); + testVector("vec_float", {11.1, 22.2, 33.3}, + {1.11, 2.22, -3.33, 0}); + } + { + gRPCModelRunnerInit(vec_double); + testVector("vec_double", + {-1.1111111111, -2.2222222222, -3.3333333333}, + {1.12345678912345670, -1.12345678912345671}); + } return 0; } -int testONNX() { return 0; } +class ONNXTest : public MLBridgeTestEnv { +public: + int run(int expectedAction) { + onnx_path = cl_onnx_path.getValue(); + if (onnx_path == "") { + std::cerr << "ONNX model path must be specified via " + "--onnx-model-path=\n"; + exit(1); + } + FeatureVector.clear(); + int n = 100; + for (int i = 0; i < n; i++) { + float delta = (float)(i - expectedAction) / n; + FeatureVector.push_back(delta * delta); + } + + Agent *agent = new Agent(onnx_path); + std::map agents; + agents["agent"] = agent; + MLRunner = std::make_unique(this, agents, nullptr); + MLRunner->evaluate(); + if (lastAction != expectedAction) { + std::cerr << "Error: Expected action: " << expectedAction + << ", Computed action: " << lastAction << "\n"; + exit(1); + } + return 0; + } +}; } // namespace @@ -176,18 +262,17 @@ int main(int argc, char **argv) { if (test_config == "pipe-bytes") { pipe_name = cl_pipe_name.getValue(); - data_format = "bytes"; - testPipes(); + testPipeBytes(); } else if (test_config == "pipe-json") { pipe_name = cl_pipe_name.getValue(); - data_format = "json"; - testPipes(); + testPipeJSON(); } else if (test_config == "grpc") { server_address = cl_server_address.getValue(); testGRPC(); - } else if (test_config == "onnx") - testONNX(); - else + } else if (test_config == "onnx") { + ONNXTest t; + t.run(20); + } else std::cerr << "--test-config must be provided from [pipe-bytes, pipe-json, " "grpc, onnx]\n"; return 0; diff --git a/test/inference/HelloMLBridge_Env.h b/test/include/HelloMLBridge_Env.h similarity index 80% rename from test/inference/HelloMLBridge_Env.h rename to test/include/HelloMLBridge_Env.h index 0faa328f..243013ff 100644 --- a/test/inference/HelloMLBridge_Env.h +++ b/test/include/HelloMLBridge_Env.h @@ -12,28 +12,29 @@ #include "llvm/Support/raw_ostream.h" using namespace MLBridge; -class HelloMLBridgeEnv : public Environment { +class MLBridgeTestEnv : public Environment { Observation CurrObs; public: - HelloMLBridgeEnv() { setNextAgent("agent"); }; + MLBridgeTestEnv() { setNextAgent("agent"); }; Observation &reset() override; Observation &step(Action) override; + Action lastAction; protected: std::vector FeatureVector; }; -Observation &HelloMLBridgeEnv::step(Action Action) { +Observation &MLBridgeTestEnv::step(Action Action) { CurrObs.clear(); std::copy(FeatureVector.begin(), FeatureVector.end(), std::back_inserter(CurrObs)); - llvm::outs() << "Action: " << Action << "\n"; + lastAction = Action; setDone(); return CurrObs; } -Observation &HelloMLBridgeEnv::reset() { +Observation &MLBridgeTestEnv::reset() { std::copy(FeatureVector.begin(), FeatureVector.end(), std::back_inserter(CurrObs)); return CurrObs; diff --git a/test/include/ProtosInclude.h b/test/include/ProtosInclude.h new file mode 100644 index 00000000..da4cdc2d --- /dev/null +++ b/test/include/ProtosInclude.h @@ -0,0 +1,28 @@ +//===----------------------------------------------------------------------===// +// +// Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM +// Exceptions. See the LICENSE file for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "grpc/MLBridgeTest_bool/MLBridgeTest_bool.grpc.pb.h" +#include "grpc/MLBridgeTest_bool/MLBridgeTest_bool.pb.h" +#include "grpc/MLBridgeTest_char/MLBridgeTest_char.grpc.pb.h" +#include "grpc/MLBridgeTest_char/MLBridgeTest_char.pb.h" +#include "grpc/MLBridgeTest_double/MLBridgeTest_double.grpc.pb.h" +#include "grpc/MLBridgeTest_double/MLBridgeTest_double.pb.h" +#include "grpc/MLBridgeTest_float/MLBridgeTest_float.grpc.pb.h" +#include "grpc/MLBridgeTest_float/MLBridgeTest_float.pb.h" +#include "grpc/MLBridgeTest_int/MLBridgeTest_int.grpc.pb.h" +#include "grpc/MLBridgeTest_int/MLBridgeTest_int.pb.h" +#include "grpc/MLBridgeTest_long/MLBridgeTest_long.grpc.pb.h" +#include "grpc/MLBridgeTest_long/MLBridgeTest_long.pb.h" +#include "grpc/MLBridgeTest_vec_double/MLBridgeTest_vec_double.grpc.pb.h" +#include "grpc/MLBridgeTest_vec_double/MLBridgeTest_vec_double.pb.h" +#include "grpc/MLBridgeTest_vec_float/MLBridgeTest_vec_float.grpc.pb.h" +#include "grpc/MLBridgeTest_vec_float/MLBridgeTest_vec_float.pb.h" +#include "grpc/MLBridgeTest_vec_int/MLBridgeTest_vec_int.grpc.pb.h" +#include "grpc/MLBridgeTest_vec_int/MLBridgeTest_vec_int.pb.h" +#include "grpc/MLBridgeTest_vec_long/MLBridgeTest_vec_long.grpc.pb.h" +#include "grpc/MLBridgeTest_vec_long/MLBridgeTest_vec_long.pb.h" diff --git a/test/mlbridge-test.py b/test/mlbridge-test.py index 4cb8466d..e1ecc6be 100644 --- a/test/mlbridge-test.py +++ b/test/mlbridge-test.py @@ -5,29 +5,32 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # # ------------------------------------------------------------------------------ - import argparse import numpy as np import ctypes import sys +import os import torch, torch.nn as nn - -sys.path.append("../CompilerInterface") +import torch.onnx +import subprocess +import time + +BUILD_DIR = "../build_release" +sys.path.extend( + [ + f"{BUILD_DIR}/MLModelRunner/CompilerInterface", + f"{BUILD_DIR}/MLModelRunner/gRPCModelRunner/Python-Utilities", + ] +) from PipeCompilerInterface import PipeCompilerInterface from GrpcCompilerInterface import GrpcCompilerInterface -sys.path.append("./Python-Utilities") -import helloMLBridge_pb2, helloMLBridge_pb2_grpc, grpc -from concurrent import futures - FAIL = 1 SUCCESS = 0 parser = argparse.ArgumentParser() -parser.add_argument( - "--use_pipe", type=bool, default=False, help="Use pipe or not", required=False -) +parser.add_argument("--use_pipe", default=False, help="Use pipe or not", required=False) parser.add_argument( "--data_format", type=str, @@ -53,19 +56,113 @@ "--server_port", type=int, help="Server Port", - default=5050, +) +parser.add_argument( + "--test_number", + type=int, + help="Datatype number for test", + default=0, +) +parser.add_argument( + "--export_onnx", + help="Export onnx test model", + required=False, + default=False, ) args = parser.parse_args() +if args.test_number <= 1: + import MLBridgeTest_int_pb2, MLBridgeTest_int_pb2_grpc + + MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = ( + MLBridgeTest_int_pb2, + MLBridgeTest_int_pb2_grpc, + ) +elif args.test_number == 2: + import MLBridgeTest_long_pb2, MLBridgeTest_long_pb2_grpc + + MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = ( + MLBridgeTest_long_pb2, + MLBridgeTest_long_pb2_grpc, + ) +elif args.test_number == 3: + import MLBridgeTest_float_pb2, MLBridgeTest_float_pb2_grpc + + MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = ( + MLBridgeTest_float_pb2, + MLBridgeTest_float_pb2_grpc, + ) +elif args.test_number == 4: + import MLBridgeTest_double_pb2, MLBridgeTest_double_pb2_grpc + + MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = ( + MLBridgeTest_double_pb2, + MLBridgeTest_double_pb2_grpc, + ) +elif args.test_number == 5: + import MLBridgeTest_char_pb2, MLBridgeTest_char_pb2_grpc + + MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = ( + MLBridgeTest_char_pb2, + MLBridgeTest_char_pb2_grpc, + ) +elif args.test_number == 6: + import MLBridgeTest_bool_pb2, MLBridgeTest_bool_pb2_grpc + + MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = ( + MLBridgeTest_bool_pb2, + MLBridgeTest_bool_pb2_grpc, + ) +elif args.test_number == 7: + import MLBridgeTest_vec_int_pb2, MLBridgeTest_vec_int_pb2_grpc + + MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = ( + MLBridgeTest_vec_int_pb2, + MLBridgeTest_vec_int_pb2_grpc, + ) +elif args.test_number == 8: + import MLBridgeTest_vec_long_pb2, MLBridgeTest_vec_long_pb2_grpc + + MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = ( + MLBridgeTest_vec_long_pb2, + MLBridgeTest_vec_long_pb2_grpc, + ) +elif args.test_number == 9: + import MLBridgeTest_vec_float_pb2, MLBridgeTest_vec_float_pb2_grpc + + MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = ( + MLBridgeTest_vec_float_pb2, + MLBridgeTest_vec_float_pb2_grpc, + ) +elif args.test_number == 10: + import MLBridgeTest_vec_double_pb2, MLBridgeTest_vec_double_pb2_grpc + + MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = ( + MLBridgeTest_vec_double_pb2, + MLBridgeTest_vec_double_pb2_grpc, + ) + class DummyModel(nn.Module): - def __init__(self, input_dim=10): + def __init__(self): nn.Module.__init__(self) - self.fc1 = nn.Linear(input_dim, 1) def forward(self, input): - x = self.fc1(input) - return x + return 2 - input + + +def export_onnx_model(input_dim=100): + onnx_filename = "./onnx/dummy_model.onnx" + dummy_value = torch.randn(1, input_dim) + torch.onnx.export( + DummyModel(), + dummy_value, + onnx_filename, + input_names=["obs"], + verbose=True, + export_params=True, + ) + print(f"Model exported to {onnx_filename}") expected_type = { @@ -96,34 +193,77 @@ def forward(self, input): returned_data = { 1: 12, - 2: ctypes.c_long(1234567891), + 2: 1234567891, 3: 4.14, - 4: ctypes.c_double(1.123456789123456789), + 4: 1.123456789123456789, 5: ord("b"), 6: False, 7: [12, 23, 34], - 8: [ctypes.c_long(123456780), ctypes.c_long(123456781), ctypes.c_long(123456782)], + 8: [123456780, 123456781, 123456782], 9: [1.11, 2.22, -3.33, 0], - 10: [ctypes.c_double(1.12345678912345670), ctypes.c_double(-1.12345678912345671)], + 10: [1.12345678912345670, -1.12345678912345671], } -# may not be configured for extended types -if args.data_format == "json": - returned_data[2] = ctypes.c_long(12345) - returned_data[8] = [ - ctypes.c_long(6780), - ctypes.c_long(6781), - ctypes.c_long(6782), - ] # [ctypes.c_long(6780),ctypes.c_long(6781),ctypes.c_long(6782)], +if args.use_pipe and args.data_format == "bytes": + returned_data.update( + { + 2: ctypes.c_long(1234567891), + 4: ctypes.c_double(1.123456789123456789), + 8: [ + ctypes.c_long(123456780), + ctypes.c_long(123456781), + ctypes.c_long(123456782), + ], + 10: [ + ctypes.c_double(1.12345678912345670), + ctypes.c_double(-1.12345678912345671), + ], + } + ) +if args.use_pipe and args.data_format == "json": + returned_data.update( + { + 2: ctypes.c_long(12345), + 4: ctypes.c_double(1.123456789123456789), + 8: [ctypes.c_long(6780), ctypes.c_long(6781), ctypes.c_long(6782)], + 10: [ + ctypes.c_double(1.12345678912345670), + ctypes.c_double(-1.12345678912345671), + ], + } + ) + +status = SUCCESS + +# test index vs received data +def checkData(index, data): + global status + if not args.silent: + print(" ", expected_type[index], "request:", data) + + if isinstance(expected_data[index], list): + for e, d in zip(expected_data[index], data): + if abs(e - d) > 10e-6: + print( + f"Error: Expected {expected_type[index]} request: {expected_data[index]}, Received: {data}" + ) + status = FAIL + # raise Exception(f"Mismatch in {expected_type[i]}") + + elif abs(data - expected_data[index]) > 10e-6: + print( + f"Error: Expected {expected_type[index]} request: {expected_data[index]}, Received: {data}" + ) + status = FAIL + # raise Exception(f"Mismatch in {expected_type[i]}") def run_pipe_communication(data_format, pipe_name): - compiler_interface = PipeCompilerInterface(data_format, "/tmp/" + pipe_name) + compiler_interface = PipeCompilerInterface(data_format, "./" + pipe_name) if not args.silent: print("PipeCompilerInterface init...") compiler_interface.reset_pipes() - status = SUCCESS i = 0 while True: i += 1 @@ -137,24 +277,7 @@ def run_pipe_communication(data_format, pipe_name): if len(data) == 1: data = data[0] - if not args.silent: - print(" ", expected_type[i], "request:", data) - - if isinstance(expected_data[i], list): - for e, d in zip(expected_data[i], data): - if abs(e - d) > 10e-6: - print( - f"Error: Expected {expected_type[i]} request: {expected_data[i]}, Received: {data}" - ) - status = FAIL - # raise Exception(f"Mismatch in {expected_type[i]}") - - elif abs(data - expected_data[i]) > 10e-6: - print( - f"Error: Expected {expected_type[i]} request: {expected_data[i]}, Received: {data}" - ) - status = FAIL - # raise Exception(f"Mismatch in {expected_type[i]}") + checkData(i, data) compiler_interface.populate_buffer(returned_data[i]) @@ -166,46 +289,58 @@ def run_pipe_communication(data_format, pipe_name): compiler_interface.reset_pipes() -class service_server(helloMLBridge_pb2_grpc.HelloMLBridgeService): - def __init__(self, data_format, pipe_name): - # self.serdes = SerDes.SerDes(data_format, pipe_name) - # self.serdes.init() +class service_server(MLBridgeTest_pb2_grpc.MLBridgeTestService): + def __init__(self): pass def getAdvice(self, request, context): try: - print(request) - print("Entered getAdvice") - print("Data: ", request.tensor) - reply = helloMLBridge_pb2.ActionRequest(action=1) + request_type = [var for var in dir(request) if "request" in var] + data = getattr(request, request_type[0]) + checkData(args.test_number, data) + if status == FAIL: + os.system("touch mlbridge-grpc-fail.txt") + reply = MLBridgeTest_pb2.Reply(action=returned_data[args.test_number]) return reply except: - reply = helloMLBridge_pb2.ActionRequest(action=-1) + reply = MLBridgeTest_pb2.Reply(action=-1) return reply -def test_func(): - data = 3.24 - import struct - - print(data, type(data)) - byte_data = struct.pack("f", data) - print(byte_data, len(byte_data)) - - print("decoding...") - decoded = float(byte_data) +def run_grpc_communication(): + # parent with test_number 0 spawns different servers + if args.test_number == 0: + process_list = [] + for i in range(1, len(expected_type) + 1): + p = subprocess.Popen( + f"python mlbridge-test.py --use_grpc --server_port={args.server_port} --silent={args.silent} --test_number={i}".split(), + ) + process_list.append(p) + + time.sleep(10) + global status + for p in process_list: + if os.path.isfile("mlbridge-grpc-fail.txt"): + status = FAIL + os.system("rm mlbridge-grpc-fail.txt") + p.terminate() + exit(status) - print(decoded, type(decoded)) + # servers serve different datatypes + else: + compiler_interface = GrpcCompilerInterface( + mode="server", + add_server_method=MLBridgeTest_pb2_grpc.add_MLBridgeTestServiceServicer_to_server, + grpc_service_obj=service_server(), + hostport=args.server_port + args.test_number, + ) + compiler_interface.start_server() if __name__ == "__main__": if args.use_pipe: run_pipe_communication(args.data_format, args.pipe_name) elif args.use_grpc: - compiler_interface = GrpcCompilerInterface( - mode="server", - add_server_method=helloMLBridge_pb2_grpc.add_HelloMLBridgeServiceServicer_to_server, - grpc_service_obj=service_server(), - hostport=args.server_port, - ) - compiler_interface.start_server() + run_grpc_communication() + elif args.export_onnx: + export_onnx_model() diff --git a/test/mlbridge-test.sh b/test/mlbridge-test.sh index 9ac1d49a..68d59a34 100644 --- a/test/mlbridge-test.sh +++ b/test/mlbridge-test.sh @@ -52,11 +52,12 @@ python $SERVER_FILE --use_pipe=True --data_format=json --pipe_name=mlbridgepipe2 SERVER_PID=$! run_test $BUILD_DIR/bin/MLCompilerBridgeTest --test-config=pipe-json --test-pipe-name=mlbridgepipe2 --silent -exit $STATUS +echo -e "${BLUE}${BOLD}Testing MLBridge [grpc]${NC}" +python $SERVER_FILE --use_grpc --server_port=50155 --silent=True & +SERVER_PID=$! +run_test $BUILD_DIR/bin/MLCompilerBridgeTest --test-config=grpc --test-server-address="0.0.0.0:50155" --silent -# python $SERVER_FILE --use_grpc --server_port=50065 & -# echo "Test [grpc]:" -# run_test $BUILD_DIR/MLCompilerBridgeTest --test-config=grpc --test-server-address="0.0.0.0:50065" +echo -e "${BLUE}${BOLD}Testing MLBridge [onnx]${NC}" +run_test $BUILD_DIR/bin/MLCompilerBridgeTest --test-config=onnx --onnx-model-path=$REPO_DIR/test/onnx/dummy_model.onnx -# echo "Test [onnx]:" -# $BUILD_DIR/MLCompilerBridgeTest --test-config=onnx +exit $STATUS diff --git a/test/onnx/dummy_model.onnx b/test/onnx/dummy_model.onnx new file mode 100644 index 00000000..09f46370 Binary files /dev/null and b/test/onnx/dummy_model.onnx differ diff --git a/test/protos/MLBridgeTest_bool.proto b/test/protos/MLBridgeTest_bool.proto index 614294ac..3cea3b72 100644 --- a/test/protos/MLBridgeTest_bool.proto +++ b/test/protos/MLBridgeTest_bool.proto @@ -6,5 +6,5 @@ service MLBridgeTestService { rpc getAdvice(Request) returns (Reply) {} } -message Request { bool data = 1; } +message Request { bool request_bool = 1; } message Reply { bool action = 1; } diff --git a/test/protos/MLBridgeTest_char.proto b/test/protos/MLBridgeTest_char.proto index d164cbd0..49827602 100644 --- a/test/protos/MLBridgeTest_char.proto +++ b/test/protos/MLBridgeTest_char.proto @@ -6,5 +6,5 @@ service MLBridgeTestService { rpc getAdvice(Request) returns (Reply) {} } -message Request { string data = 1; } +message Request { string request_char = 1; } message Reply { string action = 1; } diff --git a/test/protos/MLBridgeTest_double.proto b/test/protos/MLBridgeTest_double.proto index d7528eef..8ad7c130 100644 --- a/test/protos/MLBridgeTest_double.proto +++ b/test/protos/MLBridgeTest_double.proto @@ -6,5 +6,5 @@ service MLBridgeTestService { rpc getAdvice(Request) returns (Reply) {} } -message Request { double data = 1; } +message Request { double request_double = 1; } message Reply { double action = 1; } diff --git a/test/protos/MLBridgeTest_float.proto b/test/protos/MLBridgeTest_float.proto index de88aed9..d933fdf6 100644 --- a/test/protos/MLBridgeTest_float.proto +++ b/test/protos/MLBridgeTest_float.proto @@ -6,5 +6,5 @@ service MLBridgeTestService { rpc getAdvice(Request) returns (Reply) {} } -message Request { float data = 1; } +message Request { float request_float = 1; } message Reply { float action = 1; } diff --git a/test/protos/MLBridgeTest_int.proto b/test/protos/MLBridgeTest_int.proto index 23148449..d21ff81a 100644 --- a/test/protos/MLBridgeTest_int.proto +++ b/test/protos/MLBridgeTest_int.proto @@ -6,5 +6,5 @@ service MLBridgeTestService { rpc getAdvice(Request) returns (Reply) {} } -message Request { int32 data = 1; } +message Request { int32 request_int = 1; } message Reply { int32 action = 1; } diff --git a/test/protos/MLBridgeTest_long.proto b/test/protos/MLBridgeTest_long.proto index 6bbfce86..41a14194 100644 --- a/test/protos/MLBridgeTest_long.proto +++ b/test/protos/MLBridgeTest_long.proto @@ -6,5 +6,5 @@ service MLBridgeTestService { rpc getAdvice(Request) returns (Reply) {} } -message Request { int64 data = 1; } +message Request { int64 request_long = 1; } message Reply { int64 action = 1; } diff --git a/test/protos/MLBridgeTest_vec_double.proto b/test/protos/MLBridgeTest_vec_double.proto index 669512d4..7911efdd 100644 --- a/test/protos/MLBridgeTest_vec_double.proto +++ b/test/protos/MLBridgeTest_vec_double.proto @@ -6,5 +6,5 @@ service MLBridgeTestService { rpc getAdvice(Request) returns (Reply) {} } -message Request { repeated double data = 1; } +message Request { repeated double request_vec_double = 1; } message Reply { repeated double action = 1; } diff --git a/test/protos/MLBridgeTest_vec_float.proto b/test/protos/MLBridgeTest_vec_float.proto index c6da7604..9bcc6753 100644 --- a/test/protos/MLBridgeTest_vec_float.proto +++ b/test/protos/MLBridgeTest_vec_float.proto @@ -6,5 +6,5 @@ service MLBridgeTestService { rpc getAdvice(Request) returns (Reply) {} } -message Request { repeated float data = 1; } +message Request { repeated float request_vec_float = 1; } message Reply { repeated float action = 1; } diff --git a/test/protos/MLBridgeTest_vec_int.proto b/test/protos/MLBridgeTest_vec_int.proto index 9bb741be..2fb22981 100644 --- a/test/protos/MLBridgeTest_vec_int.proto +++ b/test/protos/MLBridgeTest_vec_int.proto @@ -6,5 +6,5 @@ service MLBridgeTestService { rpc getAdvice(Request) returns (Reply) {} } -message Request { repeated int32 data = 1; } +message Request { repeated int32 request_vec_int = 1; } message Reply { repeated int32 action = 1; } diff --git a/test/protos/MLBridgeTest_vec_long.proto b/test/protos/MLBridgeTest_vec_long.proto index d82bf92e..9757d1d9 100644 --- a/test/protos/MLBridgeTest_vec_long.proto +++ b/test/protos/MLBridgeTest_vec_long.proto @@ -6,5 +6,5 @@ service MLBridgeTestService { rpc getAdvice(Request) returns (Reply) {} } -message Request { repeated int64 data = 1; } +message Request { repeated int64 request_vec_long = 1; } message Reply { repeated int64 action = 1; }