Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Supports adapters in download dir #1690

Merged
merged 3 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions serving/docs/adapters.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ There are several options to choose between for managing your set of adapters.

The easiest option is to use a local adapters directory.
This is as easy as adding a directory of adapters alongside your model files.
The adapters can be placed either in the same directory as `serving.properties` or in the directory referenced by an S3 `model_id`.
That location should contain an overarching `adapters` directory with an artifact directory for each adapter to add.
This works best for having a manageable set of adapters as they are all loaded on startup.
It can be used in conjunction with services like [Amazon SageMaker Single Model Endpoint](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-single-model.html).
Expand Down
35 changes: 32 additions & 3 deletions serving/src/test/java/ai/djl/serving/ModelServerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,15 @@ public void beforeSuite() throws IOException {
ZipUtils.zip(Paths.get("build/classes/java/test/"), dest, true);
String engineCacheDir = Utils.getEngineCacheDir().toString();
System.setProperty("DJL_CACHE_DIR", "build/cache");
System.setProperty("DJL_TEST_S3_NO_CREDENTIALS", "true");
System.setProperty("ENGINE_CACHE_DIR", engineCacheDir);
}

/** Clears the system properties that were set for the duration of the test suite. */
@AfterSuite
public void afterSuite() {
    String[] suiteProperties = {
        "DJL_CACHE_DIR", "ENGINE_CACHE_DIR", "DJL_TEST_S3_NO_CREDENTIALS"
    };
    for (String key : suiteProperties) {
        System.clearProperty(key);
    }
}

@AfterMethod
Expand Down Expand Up @@ -331,6 +333,26 @@ public void testAdapterWorkflows()
}
}

/**
 * Starts a server from the {@code adaptersInModelDir} test configuration and checks that a
 * prediction against model {@code modelProp} with adapter {@code a} succeeds.
 *
 * <p>The adapter is never registered explicitly by this test, so it is presumably picked up
 * from the model's {@code adapters} directory at load time (NOTE(review): confirm against the
 * test resources).
 */
@Test
public void testAdaptersInModelDir()
        throws ServerStartupException, GeneralSecurityException, ParseException, IOException,
                InterruptedException, ReflectiveOperationException {
    String configFile = "src/test/resources/adaptersInModelDir/config.properties";
    ModelServer modelServer = initTestServer(configFile);
    try {
        assertTrue(modelServer.isRunning());
        Channel conn = initTestChannel();

        // Empty adapt option: the expected response is just the adapter name plus the payload.
        testAdapterPredict(conn, "modelProp", "a", "");

        conn.close().sync();

        ConfigManagerTest.testSsl();
    } finally {
        // Always shut the server down, even when an assertion above fails.
        modelServer.stop();
    }
}

@Test
public void testAdapterSME()
throws ServerStartupException, GeneralSecurityException, ParseException, IOException,
Expand Down Expand Up @@ -917,16 +939,23 @@ private void testAdapterNoPredictRegister() throws InterruptedException {
}

/**
 * Runs the adapter prediction check against the default {@code adaptecho} model with the
 * {@code adaptable} adapter and the {@code opt} adapt option.
 *
 * @param channel the channel used to send the request
 * @throws InterruptedException if the calling thread is interrupted
 */
private void testAdapterPredict(Channel channel) throws InterruptedException {
    String defaultModel = "adaptecho";
    String defaultAdapter = "adaptable";
    String defaultOption = "opt";
    testAdapterPredict(channel, defaultModel, defaultAdapter, defaultOption);
}

/**
 * Posts a plain-text prediction request to {@code model} with the given {@code adapter} and
 * verifies the response.
 *
 * <p>The response body is expected to equal {@code adapter + adaptOption + payload}.
 *
 * @param channel the channel used to send the request
 * @param model the model name used in the {@code /predictions/} URL
 * @param adapter the adapter name passed as the {@code adapter} query parameter
 * @param adaptOption the text expected between the adapter name and the payload in the response
 * @throws InterruptedException if the calling thread is interrupted
 */
private void testAdapterPredict(
        Channel channel, String model, String adapter, String adaptOption)
        throws InterruptedException {
    logTestFunction();
    String url = "/predictions/" + model + "?adapter=" + adapter;
    String payload = "testPredictAdapter";
    DefaultFullHttpRequest req =
            new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, url);
    req.content().writeBytes(payload.getBytes(StandardCharsets.UTF_8));
    // Content length must reflect the bytes written above.
    HttpUtil.setContentLength(req, req.content().readableBytes());
    req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.TEXT_PLAIN);
    request(channel, req);
    assertHttpOk();
    assertEquals(result, adapter + adaptOption + payload);
}

private void testAdapterDirPredict(Channel channel) throws InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
inference_address=https://127.0.0.1:8443
management_address=https://127.0.0.1:8443
model_store=build/models
load_models=src/test/resources/adaptersInModelDir/modelProp
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
option.pythonExecutable=python3
option.model_id=s3://djl-ai/resources/test-models/s3AdaptersInModelDir
option.entryPoint=model.py
option.handler=handle
retry_threshold=0
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Adapter A placeholder
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Adapter B placeholder
66 changes: 40 additions & 26 deletions wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -272,30 +272,38 @@

if (models.isEmpty()) {
// Check for adapters on first load
if (Files.isDirectory(modelDir.resolve("adapters"))) {
Files.list(modelDir.resolve("adapters"))
.forEach(
adapterDir -> {
eventManager.onAdapterLoading(this, adapterDir);
long start = System.nanoTime();
String adapterName = adapterDir.getFileName().toString();
Adapter adapter =
Adapter.newInstance(
this,
adapterName,
adapterDir.toAbsolutePath().toString(),
Collections.emptyMap());
registerAdapter(adapter);
long d = (System.nanoTime() - start) / 1000;
Metric me =
new Metric(
"LoadAdapter",
d,
Unit.MICROSECONDS,
dimension);
MODEL_METRIC.info("{}", me);
eventManager.onAdapterLoaded(this, adapter);
});
List<Path> possibleAdapterDirs = new ArrayList<>(2);
possibleAdapterDirs.add(modelDir);
if (downloadDir != null && !modelDir.equals(downloadDir)) {
possibleAdapterDirs.add(downloadDir);
}
for (Path parentDir : possibleAdapterDirs) {
if (Files.isDirectory(parentDir.resolve("adapters"))) {
Files.list(parentDir.resolve("adapters"))
.forEach(
adapterDir -> {
eventManager.onAdapterLoading(this, adapterDir);
long start = System.nanoTime();
String adapterName =
adapterDir.getFileName().toString();
Adapter adapter =
Adapter.newInstance(
this,
adapterName,
adapterDir.toAbsolutePath().toString(),
Collections.emptyMap());
registerAdapter(adapter);
long d = (System.nanoTime() - start) / 1000;
Metric me =
new Metric(
"LoadAdapter",
d,
Unit.MICROSECONDS,
dimension);
MODEL_METRIC.info("{}", me);
eventManager.onAdapterLoaded(this, adapter);
});
}
}
}

Expand Down Expand Up @@ -1088,7 +1096,13 @@
};
} else {
logger.info("s5cmd is not installed, using aws cli");
commands = new String[] {"aws", "s3", "sync", src, dest};
if (Boolean.parseBoolean(
Utils.getEnvOrSystemProperty("DJL_TEST_S3_NO_CREDENTIALS"))) {
logger.info("Skipping s3 credentials");
commands = new String[] {"aws", "s3", "sync", "--no-sign-request", src, dest};
github-advanced-security[bot] marked this conversation as resolved.
Dismissed
Show resolved Hide resolved
} else {
commands = new String[] {"aws", "s3", "sync", src, dest};
}
}
Process exec = new ProcessBuilder(commands).redirectErrorStream(true).start();
String logOutput;
Expand All @@ -1097,7 +1111,7 @@
}
int exitCode = exec.waitFor();
if (0 != exitCode || logOutput.startsWith("ERROR ")) {
logger.error(logOutput);
logger.error("Download error: {}", logOutput);
throw new EngineException("Download model failed.");
} else {
logger.debug(logOutput);
Expand Down
Loading