diff --git a/api/src/main/java/ai/djl/modality/cv/BufferedImageFactory.java b/api/src/main/java/ai/djl/modality/cv/BufferedImageFactory.java index 40735ddeca7f..b3d9b47fd5b2 100644 --- a/api/src/main/java/ai/djl/modality/cv/BufferedImageFactory.java +++ b/api/src/main/java/ai/djl/modality/cv/BufferedImageFactory.java @@ -405,6 +405,8 @@ private void drawMask(Mask mask) { int imageHeight = image.getHeight(); int x = (int) (mask.getX() * imageWidth); int y = (int) (mask.getY() * imageHeight); + int w = (int) (mask.getWidth() * imageWidth); + int h = (int) (mask.getHeight() * imageHeight); float[][] probDist = mask.getProbDist(); // Correct some coordinates of box when going out of image if (x < 0) { @@ -416,15 +418,17 @@ private void drawMask(Mask mask) { BufferedImage maskImage = new BufferedImage( - probDist.length, probDist[0].length, BufferedImage.TYPE_INT_ARGB); - for (int xCor = 0; xCor < probDist.length; xCor++) { - for (int yCor = 0; yCor < probDist[xCor].length; yCor++) { - float opacity = probDist[xCor][yCor] * 0.8f; + probDist[0].length, probDist.length, BufferedImage.TYPE_INT_ARGB); + for (int yCor = 0; yCor < probDist.length; yCor++) { + for (int xCor = 0; xCor < probDist[0].length; xCor++) { + float opacity = probDist[yCor][xCor] * 0.8f; maskImage.setRGB(xCor, yCor, new Color(r, g, b, opacity).darker().getRGB()); } } + java.awt.Image img = maskImage.getScaledInstance(w, h, java.awt.Image.SCALE_SMOOTH); + Graphics2D gR = (Graphics2D) image.getGraphics(); - gR.drawImage(maskImage, x, y, null); + gR.drawImage(img, x, y, null); gR.dispose(); } diff --git a/api/src/main/java/ai/djl/modality/cv/translator/InstanceSegmentationTranslator.java b/api/src/main/java/ai/djl/modality/cv/translator/InstanceSegmentationTranslator.java index efd05d7be269..458122bbc0b2 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/InstanceSegmentationTranslator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/InstanceSegmentationTranslator.java @@ -12,7 +12,6 @@ */ package ai.djl.modality.cv.translator; -import ai.djl.modality.cv.Image; import ai.djl.modality.cv.output.BoundingBox; import ai.djl.modality.cv.output.DetectedObjects; import ai.djl.modality.cv.output.Mask; @@ -68,14 +67,6 @@ public void prepare(TranslatorContext ctx) throws IOException { } } - /** {@inheritDoc} */ - @Override - public NDList processInput(TranslatorContext ctx, Image image) { - ctx.setAttachment("originalHeight", image.getHeight()); - ctx.setAttachment("originalWidth", image.getWidth()); - return super.processInput(ctx, image); - } - /** {@inheritDoc} */ @Override public DetectedObjects processOutput(TranslatorContext ctx, NDList list) { @@ -102,18 +93,15 @@ public DetectedObjects processOutput(TranslatorContext ctx, NDList list) { double w = box[2] / rescaledWidth - x; double h = box[3] / rescaledHeight - y; - int maskW = (int) (w * (int) ctx.getAttachment("originalWidth")); - int maskH = (int) (h * (int) ctx.getAttachment("originalHeight")); - // Reshape mask to actual image bounding box shape. NDArray array = masks.get(i); Shape maskShape = array.getShape(); - array = array.reshape(maskShape.addAll(new Shape(1))); - NDArray maskArray = NDImageUtils.resize(array, maskW, maskH).transpose(); - float[] flattened = maskArray.toFloatArray(); - float[][] maskFloat = new float[maskW][maskH]; - for (int j = 0; j < maskW; j++) { - System.arraycopy(flattened, j * maskH, maskFloat[j], 0, maskH); + int maskH = (int) maskShape.get(0); + int maskW = (int) maskShape.get(1); + float[] flattened = array.toFloatArray(); + float[][] maskFloat = new float[maskH][maskW]; + for (int j = 0; j < maskH; j++) { + System.arraycopy(flattened, j * maskW, maskFloat[j], 0, maskW); } Mask mask = new Mask(x, y, w, h, maskFloat); diff --git a/extensions/opencv/src/main/java/ai/djl/opencv/OpenCVImage.java b/extensions/opencv/src/main/java/ai/djl/opencv/OpenCVImage.java index b6a2292714e6..2c664380595a 100644 --- a/extensions/opencv/src/main/java/ai/djl/opencv/OpenCVImage.java +++ b/extensions/opencv/src/main/java/ai/djl/opencv/OpenCVImage.java @@ -358,6 +358,8 @@ private void drawMask(BufferedImage img, Mask mask) { int imageHeight = img.getHeight(); int x = (int) (mask.getX() * imageWidth); int y = (int) (mask.getY() * imageHeight); + int w = (int) (mask.getWidth() * imageWidth); + int h = (int) (mask.getHeight() * imageHeight); float[][] probDist = mask.getProbDist(); // Correct some coordinates of box when going out of image if (x < 0) { @@ -368,15 +370,17 @@ private void drawMask(BufferedImage img, Mask mask) { } BufferedImage maskImage = - new BufferedImage(probDist.length, probDist[0].length, BufferedImage.TYPE_INT_ARGB); - for (int xCor = 0; xCor < probDist.length; xCor++) { - for (int yCor = 0; yCor < probDist[xCor].length; yCor++) { - float opacity = probDist[xCor][yCor] * 0.8f; - maskImage.setRGB(xCor, yCor, new Color(r, g, b, opacity).getRGB()); + new BufferedImage(probDist[0].length, probDist.length, BufferedImage.TYPE_INT_ARGB); + for (int yCor = 0; yCor < probDist.length; yCor++) { + for (int xCor = 0; xCor < probDist[0].length; xCor++) { + float opacity = probDist[yCor][xCor] * 0.8f; + maskImage.setRGB(xCor, yCor, new Color(r, g, b, opacity).darker().getRGB()); } } + java.awt.Image scaled = maskImage.getScaledInstance(w, h, java.awt.Image.SCALE_SMOOTH); + Graphics2D gR = (Graphics2D) img.getGraphics(); - gR.drawImage(maskImage, x, y, null); + gR.drawImage(scaled, x, y, null); gR.dispose(); }