Skip to content

Commit

Permalink
add disable static option in DJL to allow some model running (#1735)
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 authored Jun 21, 2022
1 parent 0208ce9 commit fb61b7c
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
23 changes: 21 additions & 2 deletions docs/mxnet/how_to_convert_your_model_to_symbol.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ net.add(nn.Dense(10))
# initialize and hybridize the block
net.initialize()
net.hybridize()
net.hybridize(static_alloc=True, static_shape=True)
# create sample input and run forward once
x = nd.random.uniform(shape=(2, 20))
Expand All @@ -36,7 +36,7 @@ from gluoncv import model_zoo
# get the pretrained model from the gluon model zoo
net = model_zoo.get_model('resnet18_v1', pretrained=True)
net.hybridize()
net.hybridize(static_alloc=True, static_shape=True)
# create a sample input and run forward once (required for tracing)
x = nd.random.uniform(shape=(1, 3, 224, 224))
Expand All @@ -45,3 +45,22 @@ net(x)
# export your model
net.export("sample_model")
```

### hybridize without `static_alloc=True, static_shape=True`

It is always recommended enabling the static settings when exporting Apache MXNet model. This will ensure DJL to have the best performance for inference.

If you run hybridize without `static_alloc=True, static_shape=True`:

```
net.hybridize()
```

you can enable this Java property with DJL:

```
-Dai.djl.mxnet.static_alloc=False -Dai.djl.mxnet.static_shape=False
```

This will ensure we skip the static settings in the inference model and make DJL produce consistent result with Python.

Original file line number Diff line number Diff line change
Expand Up @@ -1887,8 +1887,18 @@ public static CachedOp createCachedOp(
PointerByReference ref = REFS.acquire();

// static_alloc and static_shape are enabled by default
String staticAlloc = "1";
String staticShape = "1";
if (!Boolean.parseBoolean(System.getProperty("ai.djl.mxnet.static_alloc", "true"))) {
staticAlloc = "0";
}
if (!Boolean.parseBoolean(System.getProperty("ai.djl.mxnet.static_shape", "true"))) {
staticShape = "0";
}
String[] keys = {"data_indices", "param_indices", "static_alloc", "static_shape"};
String[] values = {dataIndices.values().toString(), paramIndices.toString(), "1", "1"};
String[] values = {
dataIndices.values().toString(), paramIndices.toString(), staticAlloc, staticShape
};

checkCall(LIB.MXCreateCachedOpEx(symbolHandle, keys.length, keys, values, ref));
Pointer pointer = ref.getValue();
Expand Down

0 comments on commit fb61b7c

Please sign in to comment.