diff --git a/test/MLBridgeTest.cpp b/test/MLBridgeTest.cpp index 07062008..19153833 100644 --- a/test/MLBridgeTest.cpp +++ b/test/MLBridgeTest.cpp @@ -51,6 +51,10 @@ static llvm::cl::opt cl_pipe_name("test-pipe-name", llvm::cl::Hidden, llvm::cl::init(""), llvm::cl::desc("Name for pipe file")); +static llvm::cl::opt + cl_onnx_path("onnx-model-path", llvm::cl::Hidden, llvm::cl::init(""), + llvm::cl::desc("Path to onnx model")); + static llvm::cl::opt cl_test_config( "test-config", llvm::cl::Hidden, llvm::cl::desc("Method for communication with python model")); @@ -67,6 +71,7 @@ BaseSerDes::Kind SerDesType; std::string test_config; std::string pipe_name; std::string server_address; +std::string onnx_path; // send value of type T1. Test received value of type T2 against expected value template @@ -219,7 +224,35 @@ int testGRPC() { return 0; } -int testONNX() { return 0; } +class ONNXTest : public MLBridgeTestEnv { +public: + int run(int expectedAction) { + onnx_path = cl_onnx_path.getValue(); + if (onnx_path == "") { + std::cerr << "ONNX model path must be specified via " + "--onnx-model-path=\n"; + exit(1); + } + FeatureVector.clear(); + int n = 100; + for (int i = 0; i < n; i++) { + float delta = (float)(i - expectedAction) / n; + FeatureVector.push_back(delta * delta); + } + + Agent *agent = new Agent(onnx_path); + std::map agents; + agents["agent"] = agent; + MLRunner = std::make_unique(this, agents, nullptr); + MLRunner->evaluate(); + if (lastAction != expectedAction) { + std::cerr << "Error: Expected action: " << expectedAction + << ", Computed action: " << lastAction << "\n"; + exit(1); + } + return 0; + } +}; } // namespace @@ -236,9 +269,10 @@ int main(int argc, char **argv) { } else if (test_config == "grpc") { server_address = cl_server_address.getValue(); testGRPC(); - } else if (test_config == "onnx") - testONNX(); - else + } else if (test_config == "onnx") { + ONNXTest t; + t.run(20); + } else std::cerr << "--test-config must be provided from [pipe-bytes, pipe-json, " "grpc, onnx]\n"; return 0; diff --git a/test/include/HelloMLBridge_Env.h b/test/include/HelloMLBridge_Env.h index 36f4faa4..243013ff 100644 --- a/test/include/HelloMLBridge_Env.h +++ b/test/include/HelloMLBridge_Env.h @@ -19,6 +19,7 @@ class MLBridgeTestEnv : public Environment { MLBridgeTestEnv() { setNextAgent("agent"); }; Observation &reset() override; Observation &step(Action) override; + Action lastAction; protected: std::vector FeatureVector; @@ -28,7 +29,7 @@ Observation &MLBridgeTestEnv::step(Action Action) { CurrObs.clear(); std::copy(FeatureVector.begin(), FeatureVector.end(), std::back_inserter(CurrObs)); - llvm::outs() << "Action: " << Action << "\n"; + lastAction = Action; setDone(); return CurrObs; } diff --git a/test/inference/HelloMLBridge_Env.h b/test/inference/HelloMLBridge_Env.h deleted file mode 100644 index 0faa328f..00000000 --- a/test/inference/HelloMLBridge_Env.h +++ /dev/null @@ -1,40 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM -// Exceptions. See the LICENSE file for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "MLModelRunner/ONNXModelRunner/environment.h" -#include "MLModelRunner/ONNXModelRunner/utils.h" -#include "llvm/IR/Module.h" -#include "llvm/Support/raw_ostream.h" - -using namespace MLBridge; -class HelloMLBridgeEnv : public Environment { - Observation CurrObs; - -public: - HelloMLBridgeEnv() { setNextAgent("agent"); }; - Observation &reset() override; - Observation &step(Action) override; - -protected: - std::vector FeatureVector; -}; - -Observation &HelloMLBridgeEnv::step(Action Action) { - CurrObs.clear(); - std::copy(FeatureVector.begin(), FeatureVector.end(), - std::back_inserter(CurrObs)); - llvm::outs() << "Action: " << Action << "\n"; - setDone(); - return CurrObs; -} - -Observation &HelloMLBridgeEnv::reset() { - std::copy(FeatureVector.begin(), FeatureVector.end(), - std::back_inserter(CurrObs)); - return CurrObs; -} diff --git a/test/mlbridge-test.py b/test/mlbridge-test.py index 91664604..e1ecc6be 100644 --- a/test/mlbridge-test.py +++ b/test/mlbridge-test.py @@ -12,6 +12,7 @@ import sys import os import torch, torch.nn as nn +import torch.onnx import subprocess import time @@ -29,9 +30,7 @@ SUCCESS = 0 parser = argparse.ArgumentParser() -parser.add_argument( - "--use_pipe", type=bool, default=False, help="Use pipe or not", required=False -) +parser.add_argument("--use_pipe", default=False, help="Use pipe or not", required=False) parser.add_argument( "--data_format", type=str, @@ -64,6 +63,12 @@ help="Datatype number for test", default=0, ) +parser.add_argument( + "--export_onnx", + help="Export onnx test model", + required=False, + default=False, +) args = parser.parse_args() if args.test_number <= 1: @@ -139,13 +144,25 @@ class DummyModel(nn.Module): - def __init__(self, input_dim=10): + def __init__(self): nn.Module.__init__(self) - self.fc1 = nn.Linear(input_dim, 1) def forward(self, input): - x = self.fc1(input) - return x + return 2 - input + + +def export_onnx_model(input_dim=100): + onnx_filename = "./onnx/dummy_model.onnx" + dummy_value = torch.randn(1, input_dim) + torch.onnx.export( + DummyModel(), + dummy_value, + onnx_filename, + input_names=["obs"], + verbose=True, + export_params=True, + ) + print(f"Model exported to {onnx_filename}") expected_type = { @@ -325,3 +342,5 @@ def run_grpc_communication(): run_pipe_communication(args.data_format, args.pipe_name) elif args.use_grpc: run_grpc_communication() + elif args.export_onnx: + export_onnx_model() diff --git a/test/mlbridge-test.sh b/test/mlbridge-test.sh index a491f515..68d59a34 100644 --- a/test/mlbridge-test.sh +++ b/test/mlbridge-test.sh @@ -57,6 +57,7 @@ python $SERVER_FILE --use_grpc --server_port=50155 --silent=True & SERVER_PID=$! run_test $BUILD_DIR/bin/MLCompilerBridgeTest --test-config=grpc --test-server-address="0.0.0.0:50155" --silent +echo -e "${BLUE}${BOLD}Testing MLBridge [onnx]${NC}" +run_test $BUILD_DIR/bin/MLCompilerBridgeTest --test-config=onnx --onnx-model-path=$REPO_DIR/test/onnx/dummy_model.onnx + exit $STATUS -# echo "Test [onnx]:" -# $BUILD_DIR/MLCompilerBridgeTest --test-config=onnx diff --git a/test/onnx/dummy_model.onnx b/test/onnx/dummy_model.onnx new file mode 100644 index 00000000..09f46370 Binary files /dev/null and b/test/onnx/dummy_model.onnx differ