<!-- Load TensorFlow.js. This is required to use MobileNet. -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.1"></script>
<!-- Load the MobileNet model. -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet@1.0.1"></script>
<!-- Load D3 for the visualization of predicted image classifications -->
<script src="https://d3js.org/d3.v4.min.js"></script>
Webcam Image Classification
As I mentioned in my “Hello World” post, I’m fascinated with computer vision. I’m in the process of writing up a post about the image classification mobile app that I made to detect whether my iPhone camera is pointing at my dog or not. While trying to figure out how to host a live ML model within a browser, I learned how to call pre-trained TensorFlow models from the HTML `<script>` tag. This option is pretty straightforward, but it required learning a bit of JavaScript. If I had tried to do this strictly in R, I anticipate I would have had to embed a Shiny application with an upload photo UI, pre-loaded model, image display, etc.
In this post, I create a webcam image classifier using a pre-trained TensorFlow.js MobileNet model hosted on jsDelivr. MobileNets are relatively small neural networks that are optimized to work well on less powerful computers (e.g. cell phones, tablets). This increased efficiency comes with a tradeoff in accuracy. Depending on the exact architecture chosen, MobileNets tend to top out around 60-70% accuracy. This particular model can predict 1,000 different classes, mostly commonplace objects and animals. I’ve found that it gets confused when the background is non-monochromatic, so for best results try to position the object in front of something simple. Before diving into the code, a caveat: I am by no means a front-end developer, so forgive the weak UI and the delay in classification and display of the results. Below is the final product, followed by the instructions to recreate it. Have fun!
The final product!
How it’s made:
The first step is to load the necessary libraries:
- TensorFlow.js to run the model
- The MobileNet model to classify the images
- D3.js for visualization
I add some (bare minimum) CSS so that the UI is somewhat presentable.
<style>
/* Bars of the prediction chart */
.bar {
  fill: #27A822;
}
.axis {
  font-size: 10px;
}
/* Hide the axis domain path and tick lines; only the tick labels stay visible */
.axis path,
.axis line {
  fill: none;
  display: none;
}
.label {
  font-size: 13px;
}
/* Three equal-width floated columns: video, screenshot, chart */
.column {
  float: left;
  width: 33.33%;
}
/* Clear the floats after the columns */
.row:after {
  content: "";
  display: table;
  clear: both;
}
</style>
Next, I add a basic UI with “Capture video” and “Take screenshot” buttons that will start the webcam and screenshot a single frame to initiate object classification, respectively.
<!-- Controls: the screenshot button starts out disabled -->
<p align="center">
<button id="capture-button" class="capture-button">Capture video</button>
<button id="screenshot-button" disabled>Take screenshot</button>
</p>
<br>
<!-- One row with three columns: video element, screenshot image, chart container -->
<div class="row">
<div class="column">
<video class="videostream" width="240" height="180" autoplay></video>
</div>
<div class="column">
<img id="screenshot-img" width="240" height="180">
</div>
<div class="column">
<!-- Container selected by id ("#graphic") when the chart SVG is created -->
<div id="graphic" align="center" width="240" height="180"></div>
</div>
</div>
Now that the UI is all set up, you need to make it react to input. To start, I initialize several variables that will be used later by the webcam. I also set up the size and axes of the graph that will ultimately display the top 3 classes and corresponding probabilities that the MobileNet predicts for the object in the still screenshot.
// Initialize webcam and image settings.
// Only a video track is requested from getUserMedia().
const constraints = {
  video: true
};
const captureVideoButton = document.getElementById('capture-button');
const screenshotButton = document.getElementById('screenshot-button');
const img = document.getElementById('screenshot-img');
const video = document.querySelector('video');
// Canvas that is never attached to the page; it is only used to grab a still
// frame from the video element.
const canvas = document.createElement('canvas');

// Set the dimensions and margins of the graph.
// Margins plus the inner width/height add up to 240x180, the same size as the
// <video> and <img> elements.
const margin = {top: 0, right: 0, bottom: 25, left: 110};
const width = 240 - margin.left - margin.right;
const height = 180 - margin.top - margin.bottom;

// Create the graph with the dimensions set above.
// `svg` refers to the inner <g>, already translated by the left/top margins.
const svg = d3.select("#graphic").append("svg")
  .attr("width", width + margin.left + margin.right)
  .attr("height", height + margin.top + margin.bottom)
  .append("g")
  .attr("transform",
        "translate(" + margin.left + "," + margin.top + ")");

// Set the range of y and x
const y = d3.scaleBand()
  .range([height, 0])
  .padding(0.1);
const x = d3.scaleLinear()
  .range([0, width]);

// Initialize the axes with the x-axis formatted as a percentage
const xAxis = d3.axisBottom(x)
  .ticks(4)
  .tickFormat(d => Math.round(d*100) + "%");
const yAxis = d3.axisLeft(y);

// Attach the settings for the axes.
// The "x" and "y" classes are what later code uses to select these groups.
svg.append("g")
  .attr("class", "y axis")
  .call(yAxis);
svg.append("g")
  .attr("class", "x axis");
Because I decided to put the videostream, screenshot, and graph all on the same row of the site, the y-axis labels tended to get cut off, especially when the predicted class is something like “cellular telephone, cellular phone, cellphone, cell, mobile phone”. This next function was taken almost directly from here with minor tweaks to work on a horizontal bar chart rather than a vertical one. It takes text and a fixed width and splits it over several lines.
/**
 * Split the content of each selected text element over several lines.
 *
 * The element's text is emptied and re-added word by word inside <tspan>
 * children. A new <tspan> is started whenever the current line's computed
 * text length exceeds `width`.
 *
 * @param {object} text - D3 selection of text elements (must support `.each`).
 * @param {number} width - Maximum computed text length allowed per line.
 */
function wrap(text, width) {
  text.each(function () {
    const label = d3.select(this);
    // Reversed so that pop() yields the words in reading order.
    const words = label.text().split(/\s+/).reverse();
    const lineHeight = 0.5; // ems
    const y = label.attr("y");
    // Fall back to 0 when the element has no parsable "dy" attribute;
    // otherwise every tspan would receive the invalid value "NaNem".
    const parsedDy = parseFloat(label.attr("dy"));
    const dy = Number.isNaN(parsedDy) ? 0 : parsedDy;

    let line = [];
    let lineNumber = 0;
    let tspan = label.text(null)
      .append("tspan")
      .attr("x", -10)
      .attr("y", y)
      .attr("dy", `${dy}em`);

    let word;
    while ((word = words.pop())) {
      line.push(word);
      tspan.text(line.join(" "));
      if (tspan.node().getComputedTextLength() > width) {
        // The last word made the line too long: move it to a new tspan.
        line.pop();
        tspan.text(line.join(" "));
        line = [word];
        tspan = label.append("tspan")
          .attr("x", -10)
          .attr("y", y)
          .attr("dy", `${++lineNumber * lineHeight + dy}em`)
          .text(word);
      }
    }
  });
}
The last function related to the D3.js graph is below. It updates the existing graph with the data it receives anytime it is called. It’s all pretty self-explanatory, so see the comments within the function.
/**
 * Redraw the bar chart from a list of predictions.
 *
 * Existing bars are removed and one bar is drawn per entry; both axes are
 * re-rendered and the y-axis tick labels are passed through wrap().
 *
 * @param {Array<{className: string, probability: (number|string)}>} data
 *   Entries to plot. The array and its objects are not modified.
 */
function update(data) {
  // Coerce probabilities to numbers on copies first, then sort ascending, so
  // the comparison is numeric and the caller's array is left untouched.
  const rows = data
    .map(function (d) {
      return Object.assign({}, d, { probability: +d.probability });
    })
    .sort(function (a, b) {
      return d3.ascending(a.probability, b.probability);
    });

  // Scale the range of the data in the domains.
  // d3.max() returns undefined for an empty array; use 0 in that case.
  x.domain([0, d3.max(rows, function (d) { return d.probability; }) || 0]);
  y.domain(rows.map(function (d) { return d.className; }));

  // Remove any existing bars on the graph
  svg.selectAll(".bar").remove();

  // Add new bars using the new data
  svg.selectAll(".bar")
    .data(rows)
    .enter()
    .append("rect")
    .attr("class", "bar")
    .attr("width", function (d) { return x(d.probability); })
    .attr("y", function (d) { return y(d.className); })
    .attr("height", y.bandwidth());

  // Add the x Axis
  svg.select(".x")
    .attr("transform", "translate(0," + height + ")")
    .call(d3.axisBottom(x)
      .ticks(4)
      .tickFormat(d => Math.round(d * 100) + "%"));

  // Add the y Axis, wrapping its tick labels
  svg.select(".y")
    .call(d3.axisLeft(y))
    .selectAll(".tick text")
    .call(wrap, 100);
}
So far, I’ve set up the UI and the code to create the graph, but have not set up either the webcam or screenshot mechanism. In the next code chunk, I capture the video stream on button click using the `getUserMedia()` API. HTML5rocks.com has a very helpful article on video and audio capturing, if you’re interested in learning more. Next, I update the canvas element acting as the placeholder for the screenshot. That screenshot image then gets fed into the MobileNet model and the resulting predictions are passed in as the data to the `update()` function I wrote above.
// When the video button is clicked, open the video webcam and begin streaming
captureVideoButton.onclick = function() {
  // navigator.mediaDevices is undefined in insecure contexts and in browsers
  // without getUserMedia support; report that instead of throwing.
  if (!navigator.mediaDevices || !navigator.mediaDevices.getUserMedia) {
    console.error('getUserMedia() is not available in this browser or context.');
    return;
  }
  navigator.mediaDevices.getUserMedia(constraints)
    .then(handleSuccess)
    .catch(function (error) {
      // Log the failure so the rejection is not left unhandled.
      console.error('Unable to start the webcam stream:', error);
    });
};
/**
 * Success callback for getUserMedia(): enable the screenshot button and
 * attach the stream to the video element.
 *
 * @param {MediaStream} stream - Stream resolved by getUserMedia().
 */
function handleSuccess(stream) {
  screenshotButton.disabled = false;
  video.srcObject = stream;
}
// Promise for the loaded MobileNet model. It is created on the first
// screenshot and reused afterwards so the model is not loaded on every click.
let mobilenetModelPromise = null;

// When the screenshot button or the video itself is clicked,
// grab a still image of the stream and replace the blank canvas
screenshotButton.onclick = video.onclick = function() {
  // The video can be clicked before a stream is attached; without frame
  // dimensions there is nothing to capture.
  if (!video.videoWidth || !video.videoHeight) {
    return;
  }

  canvas.width = video.videoWidth;
  canvas.height = video.videoHeight;
  canvas.getContext('2d').drawImage(video, 0, 0);
  // Other browsers will fall back to image/png
  img.src = canvas.toDataURL('image/webp');

  // Load the model (first click only). A failed load clears the cache so the
  // next click can retry.
  if (!mobilenetModelPromise) {
    mobilenetModelPromise = mobilenet.load().catch(function (error) {
      mobilenetModelPromise = null;
      throw error;
    });
  }

  mobilenetModelPromise
    // Classify the captured frame straight from the canvas: the <img> loads
    // its new src asynchronously and may not hold the new frame yet.
    .then(model => model.classify(canvas))
    // Update the bar chart
    .then(predictions => {
      update(predictions);
    })
    .catch(error => {
      console.error('Image classification failed:', error);
    });
};