Tuesday, August 03, 2010

Implementing Smart Blur in Java


Original image.


Image with Smart Blur applied. Notice that outlines are
preserved, even where the oranges overlap.


One of my favorite Photoshop effects is Smart Blur, which provides a seemingly effortless way to smooth out JPEG artifacts, remove blemishes from skin in portraits, and so on. Its value is that, despite the considerable blurriness it imparts to many regions of an image, it preserves outlines and fine details (usually the more important parts of an image). The effect is of magically blurring only those parts of the image that you want blurred.

The key to how Smart Blur works is that it preferentially blurs parts of an image that are sparse in detail (dominated by low-frequency information) while leaving untouched the parts that are comparatively rich in detail (rich in high-frequency information). Abrupt transitions in tone are left alone; areas of subtle change are smoothed (and thus made even more subtle).

The algorithm is quite straightforward:

1. March through the image pixel by pixel.
2. For each pixel, analyze an adjacent region (say, the adjoining 5 pixel by 5 pixel square).
3. Calculate some metric of pixel variance for that region.
4. Compare the variance to some predetermined threshold value.
5. If the variance exceeds the threshold, leave the pixel unchanged.
6. Otherwise, blur the source pixel, varying the amount of blurring with the variance: the lower the variance, the stronger the blur.
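
Before bringing real images into it, the six steps can be sketched on a one-dimensional row of gray values. This is only an illustration, not the implementation presented below: the names SmartBlur1D and smartBlur, the window radius, and the threshold are made up for the example, and for simplicity it blurs with the same window it analyzes.

```java
import java.util.Arrays;

class SmartBlur1D {

    // Mean of values[lo..hi], inclusive.
    static double mean(double[] values, int lo, int hi) {
        double sum = 0;
        for (int i = lo; i <= hi; i++)
            sum += values[i];
        return sum / (hi - lo + 1);
    }

    // Standard deviation of values[lo..hi]: the variance metric of step 3.
    static double deviation(double[] values, int lo, int hi) {
        double m = mean(values, lo, hi);
        double acc = 0;
        for (int i = lo; i <= hi; i++)
            acc += (values[i] - m) * (values[i] - m);
        return Math.sqrt(acc / (hi - lo + 1));
    }

    static double[] smartBlur(double[] src, int radius, double threshold) {
        double[] out = src.clone();
        for (int i = 0; i < src.length; i++) {               // step 1
            int lo = Math.max(0, i - radius);                // step 2
            int hi = Math.min(src.length - 1, i + radius);
            double dev = deviation(src, lo, hi);             // step 3
            if (dev >= threshold)                            // steps 4 and 5
                continue;
            double keep = dev / threshold;                   // step 6: 0 means fully blurred
            double blurred = mean(src, lo, hi);
            out[i] = blurred + keep * (src[i] - blurred);
        }
        return out;
    }

    public static void main(String[] args) {
        // A slightly rippled dark stretch, a hard jump, then a rippled bright stretch.
        double[] row = {10, 12, 10, 12, 10, 200, 202, 200, 202, 200};
        System.out.println(Arrays.toString(smartBlur(row, 1, 5.0)));
    }
}
```

With the sample row in main, the two values on either side of the jump (indices 4 and 5) come back unchanged, while the small ripple on the flat stretches is pulled toward its local average.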

In the implementation presented below, I start by cloning the current image and massively blurring the entire (cloned) image. Then I march through the pixels of the original image, analyzing the region around each one. Wherever blurring is called for, I derive the new pixel by linear interpolation between the original pixel and the corresponding pixel of the blurred clone.

So the first thing we need is a routine for linear interpolation between two values, and a corresponding routine for linear interpolation between two pixel values.

Linear interpolation is easy:

// Returns a when amt is 0, b when amt is 1, and a proportional blend in between.
public double lerp( double a, double b, double amt) {
    return a + amt * ( b - a );
}

Linear interpolation between pixels is tedious-looking but straightforward:

int lerpPixel( int oldpixel, int newpixel, double amt ) {

    // Unpack each 8-bit channel, interpolate it, and mask back to 0..255.
    int oldRed = ( oldpixel >> 16 ) & 255;
    int newRed = ( newpixel >> 16 ) & 255;
    int red = (int) lerp( (double)oldRed, (double)newRed, amt ) & 255;

    int oldGreen = ( oldpixel >> 8 ) & 255;
    int newGreen = ( newpixel >> 8 ) & 255;
    int green = (int) lerp( (double)oldGreen, (double)newGreen, amt ) & 255;

    int oldBlue = oldpixel & 255;
    int newBlue = newpixel & 255;
    int blue = (int) lerp( (double)oldBlue, (double)newBlue, amt ) & 255;

    // Repack as 0x00RRGGBB; the alpha byte is not carried over.
    return ( red << 16 ) | ( green << 8 ) | blue;
}
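
As written, lerpPixel interpolates red, green, and blue, truncates the fractional part, and leaves the alpha byte at zero. That is fine for opaque images such as JPEGs. If you wanted to handle images with transparency, one possible variant is sketched below. This is not the routine the filter uses: lerpChannel and lerpPixelArgb are hypothetical names, and rounding instead of truncating is a choice made for this sketch only.

```java
class LerpPixelSketch {

    static double lerp(double a, double b, double amt) {
        return a + amt * (b - a);
    }

    // Interpolates the single 8-bit channel found at the given bit offset.
    static int lerpChannel(int oldPixel, int newPixel, int shift, double amt) {
        int oldValue = (oldPixel >> shift) & 255;
        int newValue = (newPixel >> shift) & 255;
        return (int) Math.round(lerp(oldValue, newValue, amt)) & 255;
    }

    // Variant of lerpPixel that also interpolates alpha (bits 24-31).
    static int lerpPixelArgb(int oldPixel, int newPixel, double amt) {
        int result = 0;
        for (int shift = 24; shift >= 0; shift -= 8)
            result |= lerpChannel(oldPixel, newPixel, shift, amt) << shift;
        return result;
    }

    public static void main(String[] args) {
        int black = 0xFF000000;   // opaque black
        int orange = 0xFFFF8000;  // opaque orange
        // Halfway between the two, printed as AARRGGBB.
        System.out.printf("%08X%n", lerpPixelArgb(black, orange, 0.5));
    }
}
```

An amt of 0 returns the first pixel and an amt of 1 returns the second, exactly as with lerp itself.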
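
The other essential routine analyzes the pixel variance in a region. For this, I use a root-mean-square error: the RMS deviation of the region's pixels from their mean (in other words, the standard deviation). To keep things simple, only the green channel is examined: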

public double rmsError( int [] pixels ) {

    // mean of the green channel (bits 8-15 of each pixel)
    double ave = 0;

    for ( int i = 0; i < pixels.length; i++ )
        ave += ( pixels[ i ] >> 8 ) & 255;

    ave /= pixels.length;

    // sum of squared deviations from that mean
    double diff = 0;
    double accumulator = 0;

    for ( int i = 0; i < pixels.length; i++ ) {
        diff = ( ( pixels[ i ] >> 8 ) & 255 ) - ave;
        diff *= diff;
        accumulator += diff;
    }

    // root of the mean squared deviation
    double rms = accumulator / pixels.length;

    rms = Math.sqrt( rms );

    return rms;
}
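
To get a feel for the numbers this measure produces, here is a small sketch that applies the same calculation to three hand-made regions. The class name, the helper greenDeviation (a compact restatement of rmsError using streams), and the sample values are invented for illustration and are not part of the filter.

```java
import java.util.Arrays;

class RegionVarianceDemo {

    // Same measure as rmsError: the standard deviation of the green channel.
    static double greenDeviation(int[] pixels) {
        double mean = Arrays.stream(pixels)
                .map(p -> (p >> 8) & 255)
                .average()
                .orElse(0);
        double meanSquare = Arrays.stream(pixels)
                .mapToDouble(p -> ((p >> 8) & 255) - mean)
                .map(d -> d * d)
                .average()
                .orElse(0);
        return Math.sqrt(meanSquare);
    }

    public static void main(String[] args) {
        int[] flat = {0x808080, 0x808080, 0x808080, 0x808080};  // uniform gray
        int[] ramp = {100 << 8, 102 << 8, 104 << 8, 106 << 8};  // gentle gradient in green
        int[] edge = {0x000000, 0x00FF00, 0x000000, 0x00FF00};  // hard light/dark alternation

        System.out.println("flat: " + greenDeviation(flat));
        System.out.println("ramp: " + greenDeviation(ramp));
        System.out.println("edge: " + greenDeviation(edge));
    }
}
```

The flat region scores 0, the gentle ramp about 2.24, and the hard edge 127.5. Against the SENSITIVITY value of 10 used in the full listing, the first two regions would be blurred and the third left alone.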

Before we transform an image, we need code that opens it and displays it in a JFrame. The following class takes the image whose path is supplied as the first command-line argument, opens it, and displays it in a JComponent inside a JFrame:

import java.awt.Graphics;
import java.awt.image.BufferedImage;
import java.io.File;
import javax.imageio.ImageIO;
import javax.swing.JComponent;
import javax.swing.JFrame;

public class ImageWindow {

    // This inner class is our canvas.
    // We draw the image on it.
    class ImagePanel extends JComponent {

        BufferedImage theImage = null;

        ImagePanel( BufferedImage image ) {
            super();
            theImage = image;
        }

        public BufferedImage getImage( ) {
            return theImage;
        }

        public void setImage( BufferedImage image ) {
            theImage = image;
            this.updatePanel();
        }

        public void updatePanel() {

            invalidate();
            getParent().doLayout();
            repaint();
        }

        public void paintComponent( Graphics g ) {

            // nothing to draw until an image has been set
            if ( theImage == null )
                return;

            int w = theImage.getWidth( );
            int h = theImage.getHeight( );

            g.drawImage( theImage, 0,0, w,h, this );
        }
    } // end ImagePanel inner class

    // Constructor
    public ImageWindow( String [] args ) {

        // open image
        BufferedImage image = openImageFile( args[0] );
        if ( image == null )
            throw new IllegalArgumentException( "Cannot read image: " + args[0] );

        // create a panel for it
        ImagePanel theImagePanel = new ImagePanel( image );

        // display the panel in a JFrame
        createWindowForPanel( theImagePanel, args[0] );

        // filter the image
        filterImage( theImagePanel );
    }

    public void filterImage( ImagePanel panel ) {

        SmartBlurFilter filter = new SmartBlurFilter( );

        BufferedImage newImage = filter.filter( panel.getImage( ) );

        panel.setImage( newImage );
    }

    public void createWindowForPanel( ImagePanel theImagePanel, String name ) {

        BufferedImage image = theImagePanel.getImage();
        JFrame mainFrame = new JFrame();
        mainFrame.setTitle( name );
        mainFrame.setBounds(50,80,image.getWidth( )+10, image.getHeight( )+10);
        mainFrame.setDefaultCloseOperation( JFrame.EXIT_ON_CLOSE );
        mainFrame.getContentPane().add( theImagePanel );
        mainFrame.setVisible(true);
    }

    // Returns null if the file is missing or cannot be decoded.
    BufferedImage openImageFile( String fname ) {

        BufferedImage img = null;

        try {
            File f = new File( fname );
            if ( f.exists( ) )
                img = ImageIO.read(f);
        }
        catch (Exception e) {
            e.printStackTrace();
        }

        return img;
    }

    public static void main( String[] args ) {

        if ( args.length < 1 ) {
            System.err.println( "Usage: java ImageWindow <image file>" );
            return;
        }

        new ImageWindow( args );
    }
}


Note the method filterImage(), where we instantiate a SmartBlurFilter. Without further ado, here's the full code for SmartBlurFilter:

import java.awt.image.Kernel;
import java.awt.image.BufferedImage;
import java.awt.image.ConvolveOp;
import java.awt.Graphics;

public class SmartBlurFilter {

    // RMS threshold (on the 0-255 scale) above which a pixel is left untouched
    double SENSITIVITY = 10;
    // side, in pixels, of the square region analyzed for each pixel
    int REGION_SIZE = 5;

    // 9 x 9 box-blur kernel, normalized below so the weights sum to 1
    float [] kernelArray = {
        1,1,1,1,1,1,1,1,1,
        1,1,1,1,1,1,1,1,1,
        1,1,1,1,1,1,1,1,1,
        1,1,1,1,1,1,1,1,1,
        1,1,1,1,1,1,1,1,1,
        1,1,1,1,1,1,1,1,1,
        1,1,1,1,1,1,1,1,1,
        1,1,1,1,1,1,1,1,1,
        1,1,1,1,1,1,1,1,1
    };

    Kernel kernel = new Kernel( 9,9, normalizeKernel( kernelArray ) );

    float [] normalizeKernel( float [] ar ) {
        // accumulate in a float so fractional weights are not truncated
        float n = 0;
        for (int i = 0; i < ar.length; i++)
            n += ar[i];
        for (int i = 0; i < ar.length; i++)
            ar[i] /= n;

        return ar;
    }

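    public double lerp( double a,double b, double amt) {
        return a + amt * ( b - a );
    }

    // Maps a region's variance to a blend weight: 1.0 at or above the cutoff
    // (keep the original pixel), falling linearly to 0.0 for a flat region.
    public double getLerpAmount( double a, double cutoff ) {

        if ( a > cutoff )
            return 1.0;

        return a / cutoff;
    }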

    public double rmsError( int [] pixels ) {

        // mean of the green channel (bits 8-15 of each pixel)
        double ave = 0;

        for ( int i = 0; i < pixels.length; i++ )
            ave += ( pixels[ i ] >> 8 ) & 255;

        ave /= pixels.length;

        // sum of squared deviations from that mean
        double diff = 0;
        double accumulator = 0;

        for ( int i = 0; i < pixels.length; i++ ) {
            diff = ( ( pixels[ i ] >> 8 ) & 255 ) - ave;
            diff *= diff;
            accumulator += diff;
        }

        // root of the mean squared deviation
        double rms = accumulator / pixels.length;

        rms = Math.sqrt( rms );

        return rms;
    }

    int [] getSample( BufferedImage image, int x, int y, int size ) {

        // The region extends right and down from (x,y). Near the right and
        // bottom edges it would fall outside the image, so return an empty
        // sample there; the caller leaves those pixels unblurred.
        if ( x + size > image.getWidth() || y + size > image.getHeight() )
            return new int[ 0 ];

        return image.getRGB( x,y, size,size, null, 0, size );
    }

    int lerpPixel( int oldpixel, int newpixel, double amt ) {

        // Unpack each 8-bit channel, interpolate it, and mask back to 0..255.
        int oldRed = ( oldpixel >> 16 ) & 255;
        int newRed = ( newpixel >> 16 ) & 255;
        int red = (int) lerp( (double)oldRed, (double)newRed, amt ) & 255;

        int oldGreen = ( oldpixel >> 8 ) & 255;
        int newGreen = ( newpixel >> 8 ) & 255;
        int green = (int) lerp( (double)oldGreen, (double)newGreen, amt ) & 255;

        int oldBlue = oldpixel & 255;
        int newBlue = newpixel & 255;
        int blue = (int) lerp( (double)oldBlue, (double)newBlue, amt ) & 255;

        // Repack as 0x00RRGGBB; the alpha byte is not carried over.
        return ( red << 16 ) | ( green << 8 ) | blue;
    }

    int [] blurImage( BufferedImage image,
            int [] orig, int [] blur, double sensitivity ) {

        int newPixel = 0;
        double amt = 0;
        int size = REGION_SIZE;
        int w = image.getWidth();

        for ( int i = 0; i < orig.length; i++ ) {
            // sample the region whose top-left corner is this pixel
            int [] pix = getSample( image, i % w, i / w, size );
            if ( pix.length == 0 )
                continue;

            // amt = 1 keeps the original pixel; amt = 0 takes the blurred one
            amt = getLerpAmount ( rmsError( pix ), sensitivity );
            newPixel = lerpPixel( blur[ i ], orig[ i ], amt );
            orig[ i ] = newPixel;
        }

        return orig;
    }


    public BufferedImage filter( BufferedImage image ) {

        ConvolveOp convolver = new ConvolveOp(kernel, ConvolveOp.EDGE_NO_OP,
                null);

        // clone image into target
        BufferedImage target = new BufferedImage(image.getWidth(),
                image.getHeight(), image.getType());
        Graphics g = target.createGraphics();
        g.drawImage(image, 0, 0, null);
        g.dispose();

        int w = target.getWidth();
        int h = target.getHeight();

        // get source pixels
        int [] pixels = image.getRGB(0, 0, w, h, null, 0, w);

        // blur the clone into a new image; passing the original as the
        // destination would overwrite it before its regions are analyzed
        target = convolver.filter(target, null);

        // get the blurred pixels
        int [] blurryPixels = target.getRGB(0, 0, w, h, null, 0, w);

        // go thru the image and interpolate values
        pixels = blurImage(image, pixels, blurryPixels, SENSITIVITY);

        // replace original pixels with new ones
        image.setRGB(0, 0, w, h, pixels, 0, w);
        return image;
    }
}

Despite all the intensive image analysis, the routine is fairly fast: on my machine, it takes about one second to process a 640x480 image. That's slower than Photoshop by a factor of five or more, but still not bad (given that it's "only Java").

Ideas for further development:
  • Substitute a directional blur for the non-directional blur.
  • Substitute a Sobel kernel for the blur kernel.
  • Try other sorts of kernels as well.
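
As a starting point for the Sobel experiment, here is one way to build such a kernel and run it through ConvolveOp. This is only a sketch: the class and method names are made up, and how the result would be combined with the rest of the filter (for instance, as an alternative measure of local detail) is left open. Keep in mind that ConvolveOp, as far as I know, clamps each channel to the 0-255 range, so gradients of one sign come out as zero.

```java
import java.awt.image.BufferedImage;
import java.awt.image.ConvolveOp;
import java.awt.image.Kernel;

class SobelKernelSketch {

    // Horizontal Sobel kernel: responds to left/right changes in brightness.
    // The weights sum to zero, so flat areas produce no response.
    static final float[] SOBEL_X = {
        -1, 0, 1,
        -2, 0, 2,
        -1, 0, 1
    };

    static BufferedImage horizontalGradient(BufferedImage src) {
        Kernel kernel = new Kernel(3, 3, SOBEL_X);
        ConvolveOp op = new ConvolveOp(kernel, ConvolveOp.EDGE_ZERO_FILL, null);
        // A null destination asks ConvolveOp to allocate a new image.
        return op.filter(src, null);
    }

    public static void main(String[] args) {
        // 7 x 7 black test image with a white vertical stripe in column 3.
        BufferedImage img = new BufferedImage(7, 7, BufferedImage.TYPE_INT_RGB);
        for (int y = 0; y < 7; y++)
            img.setRGB(3, y, 0xFFFFFF);

        BufferedImage out = horizontalGradient(img);
        // Print the blue channel of the middle row.
        for (int x = 0; x < 7; x++)
            System.out.print((out.getRGB(x, 3) & 255) + " ");
        System.out.println();
    }
}
```

The vertical Sobel kernel is the same array transposed.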