Skip to content

Commit

Permalink
Supports adapters in download dir (#1690)
Browse files Browse the repository at this point in the history
  • Loading branch information
zachgk authored Apr 1, 2024
1 parent 91a8736 commit 0711820
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 29 deletions.
1 change: 1 addition & 0 deletions serving/docs/adapters.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ There are several options to choose between for managing your set of adapters.

The easiest option is to use an adapters local directory.
This is as easy as adding a directory of adapters alongside your model files.
The `adapters` directory can be placed either in the same directory as the `serving.properties` file or in the directory downloaded from an S3 `model_id`.
That location should contain one overarching `adapters` directory, with a separate artifact directory inside it for each adapter to add.
This works best for having a manageable set of adapters as they are all loaded on startup.
It can be used in conjunction with services like [Amazon SageMaker Single Model Endpoint](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-single-model.html).
Expand Down
36 changes: 33 additions & 3 deletions serving/src/test/java/ai/djl/serving/ModelServerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,15 @@ public void beforeSuite() throws IOException {
ZipUtils.zip(Paths.get("build/classes/java/test/"), dest, true);
String engineCacheDir = Utils.getEngineCacheDir().toString();
System.setProperty("DJL_CACHE_DIR", "build/cache");
System.setProperty("DJL_TEST_S3_NO_CREDENTIALS", "true");
System.setProperty("ENGINE_CACHE_DIR", engineCacheDir);
}

/** Clears the suite-wide system properties once the whole test suite has finished. */
@AfterSuite
public void afterSuite() {
    // Cleared in the same order as before: cache dirs first, then the S3 test flag.
    String[] suiteProperties = {
        "DJL_CACHE_DIR", "ENGINE_CACHE_DIR", "DJL_TEST_S3_NO_CREDENTIALS"
    };
    for (String key : suiteProperties) {
        System.clearProperty(key);
    }
}

@AfterMethod
Expand Down Expand Up @@ -331,6 +333,27 @@ public void testAdapterWorkflows()
}
}

/**
 * Starts a server from the {@code adaptersInModelDir} test configuration and runs an adapter
 * prediction against model {@code modelProp} using adapter {@code a} with an empty adapter
 * option.
 *
 * <p>Test disabled as unsigned s3 downloading fails on GitHub Actions.
 */
@Test(enabled = false)
public void testAdaptersInModelDir()
        throws ServerStartupException, GeneralSecurityException, ParseException, IOException,
                InterruptedException, ReflectiveOperationException {
    String configFile = "src/test/resources/adaptersInModelDir/config.properties";
    ModelServer modelServer = initTestServer(configFile);
    try {
        assertTrue(modelServer.isRunning());

        Channel conn = initTestChannel();
        testAdapterPredict(conn, "modelProp", "a", "");
        conn.close().sync();

        ConfigManagerTest.testSsl();
    } finally {
        // Always stop the server, even when an assertion above fails.
        modelServer.stop();
    }
}

@Test
public void testAdapterSME()
throws ServerStartupException, GeneralSecurityException, ParseException, IOException,
Expand Down Expand Up @@ -917,16 +940,23 @@ private void testAdapterNoPredictRegister() throws InterruptedException {
}

/**
 * Runs the adapter prediction check against model {@code adaptecho} with adapter
 * {@code adaptable} and adapter option {@code opt}.
 *
 * @param channel the channel forwarded to the four-argument overload
 * @throws InterruptedException propagated from the four-argument overload
 */
private void testAdapterPredict(Channel channel) throws InterruptedException {
testAdapterPredict(channel, "adaptecho", "adaptable", "opt");
}

/**
 * Builds an HTTP POST to {@code /predictions/<model>?adapter=<adapter>} carrying a fixed
 * plain-text payload, sends it through {@code request(channel, req)}, and asserts that the
 * response is OK and that {@code result} equals {@code adapter + adaptOption + payload}.
 *
 * @param channel the channel handed to {@code request}
 * @param model the model name placed in the prediction URL path
 * @param adapter the adapter name passed as the {@code adapter} query parameter
 * @param adaptOption the text expected between the adapter name and the payload in the result
 * @throws InterruptedException declared by this method; NOTE(review): presumably raised while
 *     waiting on the request — confirm against {@code request}
 */
private void testAdapterPredict(
Channel channel, String model, String adapter, String adaptOption)
throws InterruptedException {
logTestFunction();
// NOTE(review): `url` is declared on both of the next two lines. The first looks like the
// pre-change line shown by the diff view, superseded by the parameterized one — confirm.
String url = "/predictions/adaptecho?adapter=adaptable";
String url = "/predictions/" + model + "?adapter=" + adapter;
String payload = "testPredictAdapter";
DefaultFullHttpRequest req =
new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, url);
// NOTE(review): the next two lines both write the request body. The literal-string version
// looks like the pre-change line; the `payload` version appears to replace it — confirm.
req.content().writeBytes("testPredictAdapter".getBytes(StandardCharsets.UTF_8));
req.content().writeBytes(payload.getBytes(StandardCharsets.UTF_8));
// Content length is taken from the bytes actually written to the request body.
HttpUtil.setContentLength(req, req.content().readableBytes());
req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.TEXT_PLAIN);
request(channel, req);
assertHttpOk();
// NOTE(review): two assertions on `result` follow. The hard-coded expectation looks like the
// pre-change line; the one built from the parameters appears to replace it — confirm.
assertEquals(result, "adaptableopttestPredictAdapter");
assertEquals(result, adapter + adaptOption + payload);
}

private void testAdapterDirPredict(Channel channel) throws InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
inference_address=https://127.0.0.1:8443
management_address=https://127.0.0.1:8443
model_store=build/models
load_models=src/test/resources/adaptersInModelDir/modelProp
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
option.pythonExecutable=python3
option.model_id=s3://djl-ai/resources/test-models/s3AdaptersInModelDir
option.entryPoint=model.py
option.handler=handle
retry_threshold=0
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Adapter A placeholder
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Adapter B placeholder
66 changes: 40 additions & 26 deletions wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -272,30 +272,38 @@ public void load(Device device) throws ModelException, IOException {

if (models.isEmpty()) {
// Check for adapters on first load
if (Files.isDirectory(modelDir.resolve("adapters"))) {
Files.list(modelDir.resolve("adapters"))
.forEach(
adapterDir -> {
eventManager.onAdapterLoading(this, adapterDir);
long start = System.nanoTime();
String adapterName = adapterDir.getFileName().toString();
Adapter adapter =
Adapter.newInstance(
this,
adapterName,
adapterDir.toAbsolutePath().toString(),
Collections.emptyMap());
registerAdapter(adapter);
long d = (System.nanoTime() - start) / 1000;
Metric me =
new Metric(
"LoadAdapter",
d,
Unit.MICROSECONDS,
dimension);
MODEL_METRIC.info("{}", me);
eventManager.onAdapterLoaded(this, adapter);
});
List<Path> possibleAdapterDirs = new ArrayList<>(2);
possibleAdapterDirs.add(modelDir);
if (downloadDir != null && !modelDir.equals(downloadDir)) {
possibleAdapterDirs.add(downloadDir);
}
for (Path parentDir : possibleAdapterDirs) {
if (Files.isDirectory(parentDir.resolve("adapters"))) {
Files.list(parentDir.resolve("adapters"))
.forEach(
adapterDir -> {
eventManager.onAdapterLoading(this, adapterDir);
long start = System.nanoTime();
String adapterName =
adapterDir.getFileName().toString();
Adapter adapter =
Adapter.newInstance(
this,
adapterName,
adapterDir.toAbsolutePath().toString(),
Collections.emptyMap());
registerAdapter(adapter);
long d = (System.nanoTime() - start) / 1000;
Metric me =
new Metric(
"LoadAdapter",
d,
Unit.MICROSECONDS,
dimension);
MODEL_METRIC.info("{}", me);
eventManager.onAdapterLoaded(this, adapter);
});
}
}
}

Expand Down Expand Up @@ -1083,7 +1091,13 @@ private void runS3cmd(String src, String dest) throws ModelException {
};
} else {
logger.info("s5cmd is not installed, using aws cli");
commands = new String[] {"aws", "s3", "sync", src, dest};
if (Boolean.parseBoolean(
Utils.getEnvOrSystemProperty("DJL_TEST_S3_NO_CREDENTIALS"))) {
logger.info("Skipping s3 credentials");
commands = new String[] {"aws", "s3", "sync", "--no-sign-request", src, dest};
} else {
commands = new String[] {"aws", "s3", "sync", src, dest};
}
}
Process exec = new ProcessBuilder(commands).redirectErrorStream(true).start();
String logOutput;
Expand All @@ -1092,7 +1106,7 @@ private void runS3cmd(String src, String dest) throws ModelException {
}
int exitCode = exec.waitFor();
if (0 != exitCode || logOutput.startsWith("ERROR ")) {
logger.error(logOutput);
logger.error("Download error: {}", logOutput);
throw new EngineException("Download model failed.");
} else {
logger.debug(logOutput);
Expand Down

0 comments on commit 0711820

Please sign in to comment.