Skip to content

Commit

Permalink
[api] Refactor drawMask() for instance segmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Jul 6, 2024
1 parent 90d6019 commit a21cd86
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 29 deletions.
14 changes: 9 additions & 5 deletions api/src/main/java/ai/djl/modality/cv/BufferedImageFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,8 @@ private void drawMask(Mask mask) {
int imageHeight = image.getHeight();
int x = (int) (mask.getX() * imageWidth);
int y = (int) (mask.getY() * imageHeight);
int w = (int) (mask.getWidth() * imageWidth);
int h = (int) (mask.getHeight() * imageHeight);
float[][] probDist = mask.getProbDist();
// Correct some coordinates of box when going out of image
if (x < 0) {
Expand All @@ -416,15 +418,17 @@ private void drawMask(Mask mask) {

BufferedImage maskImage =
new BufferedImage(
probDist.length, probDist[0].length, BufferedImage.TYPE_INT_ARGB);
for (int xCor = 0; xCor < probDist.length; xCor++) {
for (int yCor = 0; yCor < probDist[xCor].length; yCor++) {
float opacity = probDist[xCor][yCor] * 0.8f;
probDist[0].length, probDist.length, BufferedImage.TYPE_INT_ARGB);
for (int yCor = 0; yCor < probDist.length; yCor++) {
for (int xCor = 0; xCor < probDist[0].length; xCor++) {
float opacity = probDist[yCor][xCor] * 0.8f;
maskImage.setRGB(xCor, yCor, new Color(r, g, b, opacity).darker().getRGB());
}
}
java.awt.Image img = maskImage.getScaledInstance(w, h, java.awt.Image.SCALE_SMOOTH);

Graphics2D gR = (Graphics2D) image.getGraphics();
gR.drawImage(maskImage, x, y, null);
gR.drawImage(img, x, y, null);
gR.dispose();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
*/
package ai.djl.modality.cv.translator;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Mask;
Expand Down Expand Up @@ -68,14 +67,6 @@ public void prepare(TranslatorContext ctx) throws IOException {
}
}

/** {@inheritDoc} */
@Override
public NDList processInput(TranslatorContext ctx, Image image) {
ctx.setAttachment("originalHeight", image.getHeight());
ctx.setAttachment("originalWidth", image.getWidth());
return super.processInput(ctx, image);
}

/** {@inheritDoc} */
@Override
public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
Expand All @@ -102,18 +93,15 @@ public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
double w = box[2] / rescaledWidth - x;
double h = box[3] / rescaledHeight - y;

int maskW = (int) (w * (int) ctx.getAttachment("originalWidth"));
int maskH = (int) (h * (int) ctx.getAttachment("originalHeight"));

// Reshape mask to actual image bounding box shape.
NDArray array = masks.get(i);
Shape maskShape = array.getShape();
array = array.reshape(maskShape.addAll(new Shape(1)));
NDArray maskArray = NDImageUtils.resize(array, maskW, maskH).transpose();
float[] flattened = maskArray.toFloatArray();
float[][] maskFloat = new float[maskW][maskH];
for (int j = 0; j < maskW; j++) {
System.arraycopy(flattened, j * maskH, maskFloat[j], 0, maskH);
int maskH = (int) maskShape.get(0);
int maskW = (int) maskShape.get(1);
float[] flattened = array.toFloatArray();
float[][] maskFloat = new float[maskH][maskW];
for (int j = 0; j < maskH; j++) {
System.arraycopy(flattened, j * maskW, maskFloat[j], 0, maskW);
}
Mask mask = new Mask(x, y, w, h, maskFloat);

Expand Down
16 changes: 10 additions & 6 deletions extensions/opencv/src/main/java/ai/djl/opencv/OpenCVImage.java
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,8 @@ private void drawMask(BufferedImage img, Mask mask) {
int imageHeight = img.getHeight();
int x = (int) (mask.getX() * imageWidth);
int y = (int) (mask.getY() * imageHeight);
int w = (int) (mask.getWidth() * imageWidth);
int h = (int) (mask.getHeight() * imageHeight);
float[][] probDist = mask.getProbDist();
// Correct some coordinates of box when going out of image
if (x < 0) {
Expand All @@ -368,15 +370,17 @@ private void drawMask(BufferedImage img, Mask mask) {
}

BufferedImage maskImage =
new BufferedImage(probDist.length, probDist[0].length, BufferedImage.TYPE_INT_ARGB);
for (int xCor = 0; xCor < probDist.length; xCor++) {
for (int yCor = 0; yCor < probDist[xCor].length; yCor++) {
float opacity = probDist[xCor][yCor] * 0.8f;
maskImage.setRGB(xCor, yCor, new Color(r, g, b, opacity).getRGB());
new BufferedImage(probDist[0].length, probDist.length, BufferedImage.TYPE_INT_ARGB);
for (int yCor = 0; yCor < probDist.length; yCor++) {
for (int xCor = 0; xCor < probDist[0].length; xCor++) {
float opacity = probDist[yCor][xCor] * 0.8f;
maskImage.setRGB(xCor, yCor, new Color(r, g, b, opacity).darker().getRGB());
}
}
java.awt.Image scaled = maskImage.getScaledInstance(w, h, java.awt.Image.SCALE_SMOOTH);

Graphics2D gR = (Graphics2D) img.getGraphics();
gR.drawImage(maskImage, x, y, null);
gR.drawImage(scaled, x, y, null);
gR.dispose();
}

Expand Down

0 comments on commit a21cd86

Please sign in to comment.