diff --git a/tfjs-core/benchmarks/index.html b/tfjs-core/benchmarks/index.html
index 73afdccff83..24844a97b8e 100644
--- a/tfjs-core/benchmarks/index.html
+++ b/tfjs-core/benchmarks/index.html
@@ -73,7 +73,7 @@
TensorFlow.js Model Benchmark
const state = {
numRuns: 50,
- benchmark: 'mobilenet_v2',
+ benchmark: 'custom_forward',
run: (v) => {
runBenchmark();
},
diff --git a/tfjs-core/benchmarks/modelConfig.js b/tfjs-core/benchmarks/modelConfig.js
index d75caec96d4..4e6e26e2eeb 100644
--- a/tfjs-core/benchmarks/modelConfig.js
+++ b/tfjs-core/benchmarks/modelConfig.js
@@ -72,6 +72,21 @@ const sentences = [
];
const benchmarks = {
+ 'custom_forward': {
+ load: async () => {
+ return {};
+ },
+ predictFunc: () => {
+ // Setup code for the forward pass. Only gets called once.
+ const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
+ const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
+
+ return () => {
+ // Forward pass.
+ return tf.matMul(a, b);
+ }
+ }
+ },
'mobilenet_v2': {
load: async () => {
const url =