Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tokenizers] Download jni lib files for cuda #3326

Merged
merged 2 commits into from
Jul 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/** Utilities for finding the Huggingface tokenizer native binary on the System. */
@SuppressWarnings("MissingJavadocMethod")
Expand All @@ -33,6 +35,9 @@ public final class LibUtils {
private static final Logger logger = LoggerFactory.getLogger(LibUtils.class);

private static final String LIB_NAME = System.mapLibraryName("tokenizers");
private static final Pattern VERSION_PATTERN =
Pattern.compile(
"(\\d+\\.\\d+\\.\\d+(-[a-z]+)?)-(\\d+\\.\\d+\\.\\d+)(-SNAPSHOT)?(-\\d+)?");

private static EngineException exception;

Expand Down Expand Up @@ -63,7 +68,7 @@ private static void loadLibrary() {
libs = new String[] {LIB_NAME};
}

Path dir = copyJniLibraryFromClasspath(libs);
Path dir = copyJniLibrary(libs);
logger.debug("Loading huggingface library from: {}", dir);

for (String libName : libs) {
Expand All @@ -78,9 +83,10 @@ private static void loadLibrary() {
}
}

private static Path copyJniLibraryFromClasspath(String[] libs) {
private static Path copyJniLibrary(String[] libs) {
Path cacheDir = Utils.getEngineCacheDir("tokenizers");
Platform platform = Platform.detectPlatform("tokenizers");
String os = platform.getOsPrefix();
String classifier = platform.getClassifier();
String flavor = platform.getFlavor();
String version = platform.getVersion();
Expand All @@ -91,6 +97,20 @@ private static Path copyJniLibraryFromClasspath(String[] libs) {
return dir.toAbsolutePath();
}

// For Linux cuda 12.x, download JNI library
if (flavor.startsWith("cu12") && !"win".equals(os)) {
Matcher matcher = VERSION_PATTERN.matcher(version);
if (!matcher.matches()) {
throw new EngineException("Unexpected version: " + version);
}
String jniVersion = matcher.group(1);
String djlVersion = matcher.group(3);

downloadJniLib(dir, path, djlVersion, jniVersion, classifier, flavor);
return dir.toAbsolutePath();
}

// Extract JNI library from classpath
Path tmp = null;
try {
Files.createDirectories(cacheDir);
Expand All @@ -114,4 +134,38 @@ private static Path copyJniLibraryFromClasspath(String[] libs) {
}
}
}

private static void downloadJniLib(
Path cacheDir,
Path path,
String djlVersion,
String version,
String classifier,
String flavor) {
String url =
"https://publish.djl.ai/tokenizers/"
+ version
+ "/jnilib/"
+ djlVersion
+ '/'
+ classifier
+ '/'
+ flavor
+ '/'
+ LIB_NAME;
logger.info("Downloading jni {} to cache ...", url);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. We don't have cuda for windows
  2. We may only want to support cuda12

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed.

Path tmp = null;
try (InputStream is = Utils.openUrl(url)) {
Files.createDirectories(cacheDir);
tmp = Files.createTempFile(cacheDir, "jni", "tmp");
Files.copy(is, tmp, StandardCopyOption.REPLACE_EXISTING);
Utils.moveQuietly(tmp, path);
} catch (IOException e) {
throw new EngineException("Cannot download jni files: " + url, e);
} finally {
if (tmp != null) {
Utils.deleteQuietly(tmp);
}
}
}
}
Loading