This vignette visualizes classification results from discriminant analysis, using tools from the package.
As a first small example, we consider the Iris data. We first load the data and inspect it.
## [1] TRUE
## y
## setosa versicolor virginica
## 50 50 50
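The code that produced this output is not shown above. A minimal sketch, assuming the built-in iris data frame and the names X and y that later code in this vignette relies on:

```r
# Assumption: X holds the four measurements and y the species labels.
data(iris)
X <- iris[, 1:4]
y <- iris[, 5]
is.factor(y)   # the class labels should be a factor
table(y)       # number of cases per class
```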
Now we carry out quadratic discriminant analysis and inspect the output. Note that we can also do linear discriminant analysis by choosing rule = "LDA".
## [1] "yint" "y" "levels" "predint" "pred" "altint"
## [7] "altlab" "PAC" "figparams" "fig" "farness" "ofarness"
## [13] "classMS" "lCurrent" "lPred" "lAlt"
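The fitting call itself does not appear above. As a hedged sketch, assuming that the classmap package is installed and that its vcr.da.train() function created the vcr.train object used throughout this vignette, the fit could look like this. The guard only skips the call when the package is missing.

```r
X <- iris[, 1:4]
y <- iris[, 5]
# Assumption: classmap provides vcr.da.train(); skip the fit if it is absent.
if (requireNamespace("classmap", quietly = TRUE)) {
  vcr.train <- classmap::vcr.da.train(X, y, rule = "QDA")
  print(names(vcr.train))
}
```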
We now inspect the output in detail. First look at the prediction as integer, the prediction as label, the alternative label as integer and the alternative label (the labels are printed for a subset of 30 objects):
## [1] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
## [38] 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3 2 2 2
## [75] 2 2 2 2 2 2 2 2 2 3 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3 3
## [112] 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3
## [149] 3 3
## [1] "setosa" "setosa" "setosa" "setosa" "setosa"
## [6] "setosa" "setosa" "setosa" "setosa" "setosa"
## [11] "versicolor" "versicolor" "versicolor" "versicolor" "versicolor"
## [16] "versicolor" "versicolor" "versicolor" "versicolor" "versicolor"
## [21] "virginica" "virginica" "virginica" "virginica" "virginica"
## [26] "virginica" "virginica" "virginica" "virginica" "virginica"
## [1] 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
## [38] 2 2 2 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
## [75] 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 2 2 2 2 2 2 2 2 2 2 2
## [112] 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
## [149] 2 2
## [1] "versicolor" "versicolor" "versicolor" "versicolor" "versicolor"
## [6] "versicolor" "versicolor" "versicolor" "versicolor" "versicolor"
## [11] "virginica" "virginica" "virginica" "virginica" "virginica"
## [16] "virginica" "virginica" "virginica" "virginica" "virginica"
## [21] "versicolor" "versicolor" "versicolor" "versicolor" "versicolor"
## [26] "versicolor" "versicolor" "versicolor" "versicolor" "versicolor"
The Probability of Alternative Class (PAC) of each object is found in the $PAC element of the output:
## [1] 4.918517e-26 7.655808e-19 1.552279e-21
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 0.0000000 0.0000000 0.0000081 0.0237098 0.0010938 0.8456517
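To make the PAC more concrete, the following sketch recomputes it from scratch in base R. It is not the package's internal code. It assumes classical class means and covariance matrices, priors equal to the class proportions, and the definition PAC(i) = p(i, alt) / (p(i, given) + p(i, alt)), where alt is the class with the highest posterior probability among the classes other than the given one.

```r
X <- as.matrix(iris[, 1:4])
y <- iris[, 5]
n <- nrow(X)

# Log of prior times Gaussian density per class, up to a common constant.
logpost <- sapply(levels(y), function(g) {
  Xg <- X[y == g, , drop = FALSE]
  S <- cov(Xg)
  logdet <- as.numeric(determinant(S, logarithm = TRUE)$modulus)
  log(nrow(Xg) / n) - 0.5 * logdet - 0.5 * mahalanobis(X, colMeans(Xg), S)
})
# Turn each row into posterior probabilities (subtract the row maximum first
# for numerical stability).
post <- exp(logpost - apply(logpost, 1, max))
post <- post / rowSums(post)

own <- cbind(seq_len(n), as.integer(y))   # (case, given class) index pairs
pgiven <- post[own]                       # posterior of the given class
masked <- post
masked[own] <- -1                         # exclude the given class
altint <- apply(masked, 1, which.max)     # best alternative class
palt <- post[cbind(seq_len(n), altint)]
PAC <- palt / (pgiven + palt)
summary(PAC)
```

A PAC above 0.5 means that the alternative class has a higher posterior probability than the given class, so the object is misclassified.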
The $fig element of the output is a matrix whose entry (i, g) is f(i, g), a measure between 0 and 1 of the distance from case i to class g. Let’s look at it for the first 5 objects:
## [,1] [,2] [,3]
## [1,] 0.02675535 1 1
## [2,] 0.33639794 1 1
## [3,] 0.16134074 1 1
## [4,] 0.25293196 1 1
## [5,] 0.06600114 1 1
From fig, the farness of each object can be computed. The farness of an object i is its value f(i, g) for its own (given) class g:
## [1] 0.02675535 0.33639794 0.16134074 0.25293196 0.06600114
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 0.0153 0.2396 0.5159 0.4996 0.7617 0.9862
The “overall farness” of an object is defined as the lowest f(i, g) it has to any class g (including its own):
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 0.0153 0.2396 0.5145 0.4957 0.7543 0.9862
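The two definitions above can be sketched on a small made-up fig matrix. The numbers and the vector yint below are hypothetical and only serve as an illustration.

```r
# Hypothetical f(i, g) values for 3 objects (rows) and 3 classes (columns).
fig <- matrix(c(0.03, 1.00, 1.00,
                0.99, 0.995, 1.00,
                1.00, 0.40, 0.20), nrow = 3, byrow = TRUE)
yint <- c(1, 1, 2)   # given class of each object, as integers

# Farness: f(i, g) for the given class g of object i.
farness <- fig[cbind(seq_along(yint), yint)]
# Overall farness: the lowest f(i, g) over all classes g.
ofarness <- apply(fig, 1, min)
farness
ofarness
```

The third object is far from its given class 2 (0.40) but closer to class 3 (0.20), so its overall farness is lower than its farness.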
Objects with ofarness > cutoff are flagged as “outliers”. These can be included in a separate column in the confusion matrix. This confusion matrix can be computed using confmat.vcr, which also returns the accuracy.
To illustrate this we choose a rather low cutoff:
##
## Confusion matrix:
## predicted
## given setosa versicolor virginica outl
## setosa 48 0 0 2
## versicolor 0 48 2 0
## virginica 0 1 48 1
##
## The accuracy is 98%.
With the default cutoff = 0.99 no objects are flagged in this example:
##
## Confusion matrix:
## predicted
## given setosa versicolor virginica
## setosa 50 0 0
## versicolor 0 48 2
## virginica 0 1 49
##
## The accuracy is 98%.
Note that the accuracy is computed before any objects are flagged, so it does not depend on the cutoff.
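The logic of the extra outlier column can be sketched in base R with table(). The small vectors below are made up for illustration. This mimics the layout of the confusion matrices above, not the implementation of confmat.vcr().

```r
lev <- c("setosa", "versicolor", "virginica")
given <- factor(c("setosa", "setosa", "versicolor",
                  "versicolor", "virginica", "virginica"), levels = lev)
pred <- factor(c("setosa", "setosa", "versicolor",
                 "virginica", "virginica", "virginica"), levels = lev)
ofarness <- c(0.10, 0.985, 0.50, 0.30, 0.995, 0.70)  # hypothetical values
cutoff <- 0.98

# The accuracy only uses the predictions, so it does not depend on the cutoff.
accuracy <- mean(pred == given)

# Flagged objects are moved to a separate "outl" column.
shown <- factor(ifelse(ofarness > cutoff, "outl", as.character(pred)),
                levels = c(lev, "outl"))
confmat <- table(given = given, predicted = shown)
confmat
accuracy
```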
The confusion matrix can also be constructed showing class numbers instead of labels. This option can be useful for long level names.
##
## Confusion matrix:
## predicted
## given 1 2 3
## 1 50 0 0
## 2 0 48 2
## 3 0 1 49
##
## The accuracy is 98%.
A stacked mosaic plot made with the stackedplot() function can be used to visualize the confusion matrix. The outliers, if there are any, appear as grey areas on top.
cols <- c("red", "darkgreen", "blue")
stackedplot(vcr.train, classCols = cols, separSize = 1.5,
minSize = 1, showLegend = TRUE)
stackedplot(vcr.train, classCols = cols, separSize = 1.5,
minSize = 1, showLegend = TRUE, cutoff = 0.98)
The default stacked mosaic plot has no legend:
stplot <- stackedplot(vcr.train, classCols = cols,
separSize = 1.5, minSize = 1,
main = "QDA on iris data")
stplot
We also make the silhouette plot using the silplot() function:
# pdf("Iris_QDA_silhouettes.pdf", width=5.0, height=4.6)
silplot(vcr.train, classCols = cols,
main = "Silhouette plot of QDA on iris data")
## classNumber classLabel classSize classAveSi
## 1 setosa 50 1.00
## 2 versicolor 50 0.91
## 3 virginica 50 0.95
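The class averages in this table can be related to the PAC. The sketch below assumes that the silhouette width of object i is s(i) = 1 - 2 PAC(i). This assumption is consistent with the numbers above: the mean PAC of 0.0237 gives 1 - 2 * 0.0237, which is about 0.95, the average of the three class values. The PAC values and labels in the sketch are made up.

```r
# Hypothetical PAC values and given classes for six objects.
PAC <- c(0.00, 0.01, 0.10, 0.60, 0.02, 0.30)
y <- factor(c("a", "a", "b", "b", "c", "c"))

si <- 1 - 2 * PAC                    # assumed silhouette width, in [-1, 1]
classAveSi <- tapply(si, y, mean)    # average silhouette width per class
round(classAveSi, 2)
```

A misclassified object (PAC above 0.5) gets a negative silhouette width and pulls down the average of its class.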
We now make the class maps based on the vcr object. This can be done using the classmap() function. We make a separate class map for each of the three classes. We see that class 1 is a very tight class (low PAC, no high farness). Class 2 is not so tight, and has two points which are predicted as virginica. Class 3 has one point predicted as versicolor.
# Now one point is to the right of the vertical line.
# It also has a black border, meaning that it is flagged
# as an outlier, in the sense that its farness to _all_
# classes is above 0.98.
To illustrate the use of new data we create a fake dataset which is a subset of the training data, where not all classes occur, and ynew has NA’s.
Xnew <- X[c(1:50, 101:150), ]
ynew <- y[c(1:50, 101:150)]
ynew[c(1:10, 51:60)] <- NA
pairs(X, col = as.numeric(y) + 1, pch = 19) # 3 colors
Now we build the vcr object for the new data, based on the fit obtained from the training data.
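The call that creates the vcr.test object used below is not shown. A hedged sketch, assuming that the classmap package is installed and that its vcr.da.newdata() function takes the new data, the new labels and the output of vcr.da.train():

```r
X <- iris[, 1:4]
y <- iris[, 5]
Xnew <- X[c(1:50, 101:150), ]
ynew <- y[c(1:50, 101:150)]
ynew[c(1:10, 51:60)] <- NA

# Assumption: classmap provides vcr.da.train() and vcr.da.newdata().
if (requireNamespace("classmap", quietly = TRUE)) {
  vcr.train <- classmap::vcr.da.train(X, y, rule = "QDA")
  vcr.test <- classmap::vcr.da.newdata(Xnew, ynew, vcr.train)
  # Farness is NA where ynew is NA; elsewhere it should agree with the
  # training farness of the same cases.
  print(head(vcr.test$farness, 12))
  print(head(vcr.train$farness[c(1:50, 101:150)], 12))
}
```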
Inspect some of the output to confirm that it corresponds with what we would expect:
## [1] NA NA NA NA NA NA
## [7] NA NA NA NA 0.29421328 0.32178116
## [13] 0.51150351 0.89366298 0.96650511 0.91516067 0.82782724 0.04831270
## [19] 0.78801603 0.23207165 0.80057706 0.46966655 0.97498614 0.90086633
## [25] 0.96031572 0.63996116 0.43078605 0.07648892 0.16940167 0.35680444
## [31] 0.31726541 0.76311424 0.91656430 0.79289643 0.15775639 0.57139387
## [37] 0.82646629 0.53574220 0.56635148 0.04234576 0.24816964 0.98405396
## [43] 0.69347810 0.98395599 0.93996829 0.36120046 0.47605905 0.20490701
## [49] 0.15482827 0.03145084 NA NA NA NA
## [55] NA NA NA NA NA NA
## [61] 0.44870413 0.07632847 0.07073258 0.58646621 0.79167886 0.26259281
## [67] 0.11731858 0.94491559 0.98620090 0.84267322 0.14000705 0.41260605
## [73] 0.85509019 0.51571738 0.14623030 0.51605016 0.53499547 0.40202972
## [79] 0.10552473 0.74488824 0.53885697 0.96493110 0.23632013 0.66887310
## [85] 0.93026914 0.83628519 0.68987941 0.32345397 0.49557670 0.35256862
## [91] 0.32403877 0.91811245 0.26632690 0.15321938 0.50795409 0.69766299
## [97] 0.63534800 0.11247730 0.62691061 0.42737442
## [1] 0.02675535 0.33639794 0.16134074 0.25293196 0.06600114 0.63210603
## [7] 0.59041424 0.01732745 0.52024594 0.55494759 0.29421328 0.32178116
## [13] 0.51150351 0.89366298 0.96650511 0.91516067 0.82782724 0.04831270
## [19] 0.78801603 0.23207165 0.80057706 0.46966655 0.97498614 0.90086633
## [25] 0.96031572 0.63996116 0.43078605 0.07648892 0.16940167 0.35680444
## [31] 0.31726541 0.76311424 0.91656430 0.79289643 0.15775639 0.57139387
## [37] 0.82646629 0.53574220 0.56635148 0.04234576 0.24816964 0.98405396
## [43] 0.69347810 0.98395599 0.93996829 0.36120046 0.47605905 0.20490701
## [49] 0.15482827 0.03145084 0.93068594 0.26632690 0.06831897 0.31105631
## [55] 0.20258388 0.65479346 0.90867003 0.58975324 0.56295693 0.68728265
## [61] 0.44870413 0.07632847 0.07073258 0.58646621 0.79167886 0.26259281
## [67] 0.11731858 0.94491559 0.98620090 0.84267322 0.14000705 0.41260605
## [73] 0.85509019 0.51571738 0.14623030 0.51605016 0.53499547 0.40202972
## [79] 0.10552473 0.74488824 0.53885697 0.96493110 0.23632013 0.66887310
## [85] 0.93026914 0.83628519 0.68987941 0.32345397 0.49557670 0.35256862
## [91] 0.32403877 0.91811245 0.26632690 0.15321938 0.50795409 0.69766299
## [97] 0.63534800 0.11247730 0.62691061 0.42737442
The confusion matrix for the test data, as for the training data, can be constructed by the confmat.vcr() function. A cutoff of 0.98 flags three outliers in this example.
##
## Confusion matrix:
## predicted
## given setosa versicolor virginica
## setosa 40 0 0
## virginica 0 1 39
##
## The accuracy is 98.75%.
##
## Confusion matrix:
## predicted
## given setosa versicolor virginica outl
## setosa 38 0 0 2
## virginica 0 1 38 1
##
## The accuracy is 98.75%.
Also the stacked mosaic plot can be constructed on the test data:
##
## Not all classes occur in these data. The classes to plot are:
## [1] 1 3
We now make the silhouette plot on the test data:
#pdf("Iris_test_QDA_silhouettes.pdf", width=5.0, height=4.3)
silplot(vcr.test, classCols = cols,
main = "Silhouette plot of QDA on iris subset")
## classNumber classLabel classSize classAveSi
## 1 setosa 40 1.00
## 3 virginica 40 0.94
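Finally, we construct the class maps for the test data and compare them, class by class, with the class maps of the training data. Class 2 (versicolor) does not occur in the test data, so classmap() stops with the error shown below for that class.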
## Error in classmap(vcr.test, 2, classCols = cols): Class number 2 with label versicolor has no objects to visualize.
We now analyze the floral buds data, which was also used as an illustration in the paper. First load and inspect the data.
## [1] 550 6
## [1] 550
## y
## branch bud scales support
## 49 363 94 44
# branch bud scales support
# 49 363 94 44
# Pairs plot
cols <- c("saddlebrown", "orange", "olivedrab4", "royalblue3")
pairs(X, gap = 0, col = cols[as.numeric(y)]) # hard to separate visually
Now we perform quadratic discriminant analysis:
Construct the confusion matrix without and with outliers shown:
##
## Confusion matrix:
## predicted
## given branch bud scales support
## branch 45 1 1 2
## bud 0 358 1 4
## scales 2 0 90 2
## support 6 3 0 35
##
## The accuracy is 96%.
##
## Confusion matrix:
## predicted
## given branch bud scales support outl
## branch 45 1 1 2 0
## bud 0 353 1 4 5
## scales 2 0 86 2 4
## support 6 3 0 35 0
##
## The accuracy is 96%.
Construct the stacked mosaic plot:
stackedplot(vcr.obj, classCols = cols, separSize = 0.6,
minSize = 1.5, main = "stacked plot of QDA on floral buds")
# Version in paper:
# pdf("Floralbuds_QDA_stackplot_without_outliers.pdf",
# width=5, height=4.3)
# stackedplot(vcr.obj, classCols = cols, separSize = 0.6,
# minSize = 1.5, showOutliers = FALSE,
# htitle = "given class", vtitle = "predicted class")
# dev.off()
Now make the silhouette plot:
#pdf("Floralbuds_QDA_silhouettes.pdf", width=5.0, height=4.3)
silplot(vcr.obj, classCols = cols,
main = "Silhouette plot of QDA on floral bud data")
## classNumber classLabel classSize classAveSi
## 1 branch 49 0.75
## 2 bud 363 0.96
## 3 scales 94 0.93
## 4 support 44 0.57
The quasi residual plot can be made with the qresplot() function. We illustrate this below by making the quasi residual plot against the sum of the variables. A correlation test confirms that the images with higher sums are significantly easier to classify:
PAC <- vcr.obj$PAC
feat <- rowSums(X); xlab <- "rowSums(X)"
# pdf("Floralbuds_QDA_quasi_residual_plot.pdf", width=5, height=4.8)
qresplot(PAC, feat, xlab = xlab, plotErrorBars = TRUE, fac = 2,
main = "Floral buds: quasi residual plot")
##
## Spearman's rank correlation rho
##
## data: feat and PAC
## S = 39255896, p-value < 2.2e-16
## alternative hypothesis: true rho is not equal to 0
## sample estimates:
## rho
## -0.4156944
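The test output above has the form produced by cor.test() with method = "spearman". Since the call is not shown, here is a sketch on synthetic data. The variables feat and PAC below are simulated and are not the floral buds values.

```r
set.seed(1)
n <- 200
feat <- runif(n)                                   # synthetic feature
PAC <- plogis(2 - 4 * feat + rnorm(n, sd = 0.5))   # synthetic PAC in (0, 1)

# exact = FALSE skips the exact p-value computation, which is not available
# when the data contain ties.
ct <- cor.test(feat, PAC, method = "spearman", exact = FALSE)
ct$estimate   # Spearman's rho
ct$p.value
```

A negative rho means that objects with a higher feature value tend to have a lower PAC, that is, they are easier to classify.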
Construct the class maps, as shown in the paper:
labels <- c("branch", "bud", "scale", "support")
# classmap of class "bud"
#
# To identify the points that stand out:
# classmap(vcr.obj, 2, classCols = cols, identify = TRUE)
# Press "Esc" to get out.
#
# pdf("Floralbuds_QDA_classmap_bud.pdf", width=7, height=7)
par(mar = c(3.6, 3.5, 2.4, 3.5))
coords <- classmap(vcr.obj, 2, classCols = cols,
main = "predictions of buds",
cex = 1.5, cex.lab = 1.5, cex.axis = 1.5,
cex.main = 1.5)
# For marking points:
indstomark <- c(294, 70, 69, 152, 204) # from identify = TRUE above
labs <- letters[seq_len(5)]
xvals <- coords[indstomark, 1] +
c(0, 0.10, 0.14, 0.10, 0.08) # visual finetuning
yvals <- coords[indstomark, 2] +
c(0.04, 0.04, 0, -0.03, +0.04)
text(x = xvals, y = yvals, labels = labs, cex = 1.5)
legend("topleft", fill = cols[1:4], legend = labels,
cex = 1, ncol = 1, bg = "white")
All class maps:
#
# pdf(file = "Floralbuds_all_class_maps.pdf", width = 7, height = 7)
par(mfrow = c(2, 2))
par(mar = c(3.3, 3.2, 2.7, 1.0))
classmap(vcr.obj, 1, classCols = cols,
main = "predictions of branches")
legend("topright", fill = cols, legend = labels,
cex = 1, ncol = 1, bg = "white")
#
par(mar = c(3.3, 0.5, 2.7, 0.3))
classmap(vcr.obj, 2, classCols = cols,
main = "predictions of buds")
labs <- letters[seq_len(5)]
xvals <- coords[indstomark, 1] +
c(0, 0.10, 0.14, 0.10, 0.08) # visual finetuning
yvals <- coords[indstomark, 2] +
c(0.04, 0.04, 0, -0.03, 0.04)
# xvals <- c( 1.75, 1.68, 1.25, 3.25, 4.00)
# yvals <- c(0.045, 0.92, 0.54, 0.97, 0.045)
text(x = xvals, y = yvals, labels = labs, cex = 1.0)
legend("topleft", fill = cols, legend = labels,
cex = 1, ncol = 1, bg = "white")
#
par(mar = c(3.3, 3.2, 2.7, 1.0))
classmap(vcr.obj, 3, classCols = cols,
main = "predictions of scales")
legend("left", fill = cols, legend = labels,
cex = 1, ncol = 1, bg = "white")
#
par(mar = c(3.3, 0.5, 2.7, 0.3))
classmap(vcr.obj, 4, classCols = cols,
main = "predictions of supports")
legend("topright", fill = cols, legend = labels,
cex = 1, ncol = 1, bg = "white")
We now analyze the MNIST data, originally from the website of Yann LeCun. As the link on his website is currently down, we use a different source. Note that downloading the data may take a minute or two, depending on the speed of the internet connection.
mnist_url <- "https://wis.kuleuven.be/statdatascience/robust/data/mnist-rdata"
con <- url(mnist_url)
# open.connection() returns NULL on success; a failed try() returns an error message
url.exists <- suppressWarnings(try(open.connection(con, open = "rt", timeout = 2), silent = TRUE)[1], classes = "warning")
try(close(con), silent = TRUE) # close the probe connection that was just opened
if (is.null(url.exists)) {load(url(mnist_url))} else {
print(paste("The data source ", mnist_url, "is not active at the moment. The example can nevertheless be reproduced by downloading the mnist data from another source, formatting the training data to dimensions 60000 x 28 x 28, and running the code below."))
}
X_train <- mnist$train$x
y_train <- as.factor(mnist$train$y)
head(y_train)
## [1] 5 0 4 1 9 2
## Levels: 0 1 2 3 4 5 6 7 8 9
## [1] 60000 28 28
## [1] 60000
We now inspect the data by plotting a few images:
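library(ggplot2) # ggplot(), geom_raster(), theme()
library(gridExtra) # grid.arrange() and arrangeGrob() are used further on
plotImage <- function(tempImage) {
# Reverse the row order so that the first row ends up at the top of the plot,
# then melt the matrix into long format (Var1 = row, Var2 = column).
tdm <- reshape2::melt(apply(tempImage, 2, rev))
p <- ggplot(tdm, aes(x = Var2, y = Var1, fill = value)) +
geom_raster() +
guides(color = "none", size = "none", fill = "none") +
theme(axis.title.x = element_blank(),
axis.title.y = element_blank(),
axis.text.x = element_blank(),
axis.text.y = element_blank(),
axis.ticks.x = element_blank(),
axis.ticks.y = element_blank()) +
scale_fill_gradient(low = "white", high = "black")
p
}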
plotImage(X_train[1, , ])
We now unfold the array containing the data to a matrix, and inspect some sample images as well as the average image per digit:
# Change the dimensions of X for the sequel:
dim(X_train) <- c(60000, 28 * 28)
dim(X_train) # 60000 784
## [1] 60000 784
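Setting dim() only reinterprets the stored values in column-major order, so each row of the unfolded matrix can be folded back into its image. A small base R check on a made-up 3 x 2 x 2 array illustrates this:

```r
# Three hypothetical "images" of size 2 x 2, stored as a 3 x 2 x 2 array.
arr <- array(seq_len(3 * 2 * 2), dim = c(3, 2, 2))
img2 <- arr[2, , ]          # second image, before unfolding

flat <- arr
dim(flat) <- c(3, 2 * 2)    # unfold: one row per image

# Folding row 2 back into a 2 x 2 matrix recovers the second image.
identical(matrix(flat[2, ], 2, 2), img2)
```

This is why later code can rebuild a digit with matrix(unlist(X_train[idx, ]), 28, 28).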
# Sampled digit images:
set.seed(123)
sampledigits <- list()
for (i in 0:9) {
digit <- i
idx <- sample(which(y_train == digit), size = 1)
tempImage <- matrix(unlist(X_train[idx, ]), 28, 28)
sampledigits[[i + 1]] <- plotImage(tempImage)
}
psampledigits <- grid.arrange(grobs = sampledigits, ncol = 5)
# ggsave("MNIST_sampled_images.pdf", plot = psampledigits,
# width = 10, height = 1)
# Averaged digit images:
meanPlots <- list()
for (j in 0:9) {
m.out <- colMeans(X_train[which(y_train == j), ])
dim(m.out) <- c(28, 28)
meanPlots[[j + 1]] <- plotImage(m.out)
}
meanplot <- grid.arrange(grobs = meanPlots, ncol = 5)
Before performing discriminant analysis, we reduce the dimension of the data by PCA.
library(svd)
ptm <- proc.time()
svd.out <- svd::propack.svd(X_train, neig = 50)
(proc.time() - ptm)[3]
## elapsed
## 6.504
## [1] 60000 50
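The projection step that yields the 60000 x 50 matrix is not shown. As a sketch of the idea, the code below uses base R's svd() instead of svd::propack.svd() on a small random matrix. All names in it (Xtrain, Xnew, loadings, trainProj, newProj) are hypothetical. As in the call above, the data are not centered first.

```r
set.seed(1)
Xtrain <- matrix(rnorm(200 * 10), nrow = 200)   # stand-in for the training data
Xnew <- matrix(rnorm(50 * 10), nrow = 50)       # stand-in for new data
k <- 3                                          # number of components to keep

# Keep only the first k right singular vectors as loadings.
sv <- svd(Xtrain, nu = 0, nv = k)
loadings <- sv$v                  # 10 x k matrix with orthonormal columns

trainProj <- Xtrain %*% loadings  # scores of the training data
newProj <- Xnew %*% loadings      # new data projected on the same subspace
dim(trainProj)
dim(newProj)
```

Reusing the training loadings for new data keeps both sets in the same coordinate system, which is needed when a classifier trained on the projected data is applied to new cases.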
Now we perform discriminant analysis, which takes roughly 5 seconds.
We compute the confusion matrix and make the stacked mosaic plot:
##
## Confusion matrix:
## predicted
## given 0 1 2 3 4 5 6 7 8 9
## 0 5833 0 22 6 1 14 2 0 42 3
## 1 0 6436 104 14 32 0 2 13 138 3
## 2 14 1 5807 26 14 0 9 12 69 6
## 3 3 1 88 5821 4 52 0 18 120 24
## 4 6 1 21 3 5704 1 12 14 30 50
## 5 14 0 4 71 2 5222 17 0 78 13
## 6 27 2 6 2 8 114 5703 0 56 0
## 7 13 8 94 14 34 14 0 5936 54 98
## 8 10 24 40 72 8 40 2 4 5625 26
## 9 17 2 23 65 59 14 1 77 93 5598
##
## The accuracy is 96.14%.
cols <- c("red3", "darkorange", "gold2", "darkolivegreen3",
"darkolivegreen4", "cadetblue3", "deepskyblue4",
"darkslateblue", "darkorchid3", "deeppink4")
# stacked plot in paper:
# pdf("MNIST_stackplot_with_outliers.pdf", width=5, height=4.3)
stackedplot(vcr.train, classCols = cols, separSize = 0.6,
minSize = 1.5, htitle = "given class", vtitle = "predicted class",
main = "Stacked plot of QDA on MNIST training data")
The silhouette plot:
# pdf("MNIST_QDA_silhouettes.pdf", width=5.0, height=4.6)
silplot(vcr.train, classCols = cols,
main = "Silhouette plot of QDA on MNIST training data")
## classNumber classLabel classSize classAveSi
## 1 0 5923 0.97
## 2 1 6742 0.91
## 3 2 5958 0.95
## 4 3 6131 0.90
## 5 4 5842 0.95
## 6 5 5421 0.92
## 7 6 5918 0.93
## 8 7 6265 0.89
## 9 8 5851 0.92
## 10 9 5949 0.88
Now we make the class maps.
wnq <- function(string, qwrite=TRUE) { # auxiliary function
# writes a line without quotes
if (qwrite) write(noquote(string), file = "", ncolumns = 100)
}
showdigit <- function(digit, i, plotIt = TRUE) {
# i is the position of the image among the training images of this digit
idx <- which(y_train == digit)[i]
# wnq(paste("Estimated digit: ", as.numeric(vcr.train$pred[idx]), sep=""))
tempImage <- matrix(unlist(X_train[idx, ]), 28, 28)
p <- plotImage(tempImage)
if (plotIt) {plot(p)}
return(p)
}
Class map of digit 0, shown in paper:
digit <- 0
#
# To identify outliers:
# classmap(vcr.train, digit+1, classCols = cols, identify = TRUE)
# Press "Esc" to get out.
#
# pdf(paste0("MNIST_classmap_digit", digit, ".pdf"), width = 7, height = 7)
par(mar = c(3.6, 3.5, 2.4, 3.5))
coords <- classmap(vcr.train, digit + 1, classCols = cols,
main = paste0("predictions of digit ",digit),
cex = 1.5, cex.lab = 1.5, cex.axis = 1.5,
cex.main = 1.5)
indstomark <- c(4000, 3964, 5891, 2485, 822,
2280, 2504, 3906, 5869, 1034) # from identify = TRUE
labs <- letters[1:length(indstomark)]
xvals <- coords[indstomark, 1] +
c(-0.04, -0.01, 0, -0.11, 0.06,
0.07, 0.06, 0.10, 0.06, 0.09)
yvals <- coords[indstomark, 2] +
c(-0.03, -0.03, -0.03, 0.022, -0.025,
-0.025, -0.035, -0.025, 0.03, 0.03)
text(x = xvals, y = yvals, labels = labs, cex = 1.5)
legend("left", fill = cols,
legend = 0:9, cex = 1, ncol = 2, bg = "white")
pred <- vcr.train$pred # needed for discussion plots
tempPreds <- (pred[which(y_train == digit)])[indstomark]
discussionPlots <- list()
for (i in 1:length(indstomark)) {
idx <- indstomark[i]
tempplot <- showdigit(digit, idx, plotIt = FALSE)
tempplot <- arrangeGrob(tempplot,
bottom = paste0("(", labs[i], ") \"", tempPreds[i], "\""))
discussionPlots[[i]] = tempplot
}
discussionPlot <- grid.arrange(grobs = discussionPlots, ncol = 5)
# ggsave(paste0("MNIST_discussionplot_digit", digit, ".pdf"),
# plot = discussionPlot, width = 5,
# height = (length(indstomark) %/% 5 +
# (length(indstomark) %% 5 > 0)))
Class map of digit 1, shown in paper:
digit <- 1
# pdf(paste0("MNIST_classmap_digit", digit, ".pdf"), width = 7, height = 7)
par(mar = c(3.6, 3.5, 2.4, 3.5))
classmap(vcr.train, digit + 1, classCols = cols,
main = paste0("predictions of digit ", digit),
cex = 1.5, cex.lab = 1.5, cex.axis = 1.5,
cex.main = 1.5)
legend("left", fill = cols,
legend = 0:9, cex = 1, ncol = 2, bg = "white")
# indices of the 1s predicted as 2 (takes a while):
#
indstomark <- which(vcr.train$predint[which(y_train == digit)] == 3)
length(indstomark) # 104
## [1] 104
labs <- letters[1:length(indstomark)]
pred <- vcr.train$pred # needed for discussion plots
tempPreds <- (pred[which(y_train == digit)])[indstomark]
discussionPlots <- list()
for (i in 1:length(indstomark)) {
idx <- indstomark[i]
tempplot <- showdigit(digit, idx, FALSE)
tempplot <- arrangeGrob(tempplot,
bottom = paste0("\"", tempPreds[i], "\""))
discussionPlots[[i]] <- tempplot
}
discussionPlot <- grid.arrange(grobs = discussionPlots,
ncol = 8)
# ggsave(paste0("MNIST_discussionplot_digit", digit, "predictedAs2b.pdf"),
# plot = discussionPlot, width = 10,
# height = (length(indstomark) %/% 10 +
# (length(indstomark) %% 10 > 0)))
# The digits 1 predicted as a 2 are mostly ones written with
# a horizontal line at the bottom.
Class map of digit 2:
digit <- 2
# To identify outliers:
# classmap(vcr.train, digit + 1, classCols = cols, identify = TRUE)
# Press "Esc" to get out.
#
# pdf(paste0("MNIST_classmap_digit", digit,".pdf"), width = 7, height = 7)
par(mar = c(3.6, 3.5, 2.4, 3.5))
coords <- classmap(vcr.train, digit + 1, classCols = cols,
main = paste0("predictions of digit ", digit),
cex = 1.5, cex.lab = 1.5, cex.axis = 1.5,
cex.main = 1.5)
indstomark <- c(3164, 5434, 2319, 4224, 3682,
2642, 4920, 1233, 3741, 3993) # from identify = TRUE
labs <- letters[1:length(indstomark)]
xvals <- coords[indstomark, 1] +
c(0, 0.08, 0, 0, 0, 0, 0, 0, 0, 0)
yvals <- coords[indstomark, 2] +
c(-0.03, -0.03, -0.03, -0.03, -0.03,
-0.03, -0.03, -0.03, 0.03, 0.03)
text(x = xvals, y = yvals, labels = labs, cex = 1.5)
legend("right", fill = cols,
legend = 0:9, cex = 1, ncol = 2, bg = "white")
pred <- vcr.train$pred # needed for discussion plots
tempPreds <- (pred[which(y_train == digit)])[indstomark]
discussionPlots <- list()
for (i in 1:length(indstomark)) {
idx <- indstomark[i]
tempplot <- showdigit(digit, idx, FALSE)
tempplot <- arrangeGrob(tempplot,
bottom = paste0("(", labs[i], ") \"", tempPreds[i], "\""))
discussionPlots[[i]] <- tempplot
}
discussionPlot <- grid.arrange(grobs = discussionPlots,
ncol = 5)
# ggsave(paste0("MNIST_discussionplot_digit", digit, ".pdf"),
# plot = discussionPlot, width = 5,
# height = (length(indstomark) %/% 5 +
# (length(indstomark) %% 5 > 0)))
Now we analyze the MNIST test data. First load and inspect the data, and project it onto the PCA subspace extracted from the training data.
## [1] 7 2 1 0 4 1
## Levels: 0 1 2 3 4 5 6 7 8 9
## [1] 10000 28 28
## [1] 10000
## [1] 10000 784
Now prepare the VCR object:
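Build the confusion matrix and plot a stacked mosaic plot of the classification performance on the test data. The confusion matrix below shows class numbers, so classes 1 to 10 correspond to the digits 0 to 9: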
##
## Confusion matrix:
## predicted
## given 1 2 3 4 5 6 7 8 9 10
## 1 970 0 1 0 0 2 1 1 5 0
## 2 0 1097 11 3 2 1 1 0 20 0
## 3 2 0 1002 3 3 0 2 1 19 0
## 4 1 0 9 972 0 5 0 2 17 4
## 5 0 0 4 0 965 0 3 2 2 6
## 6 2 0 1 18 0 859 1 1 10 0
## 7 8 1 2 0 4 12 924 0 7 0
## 8 1 2 28 1 3 2 0 958 14 19
## 9 3 0 9 12 1 5 1 2 935 6
## 10 5 1 11 6 10 2 0 6 18 950
##
## The accuracy is 96.32%.
# In supplementary material:
# pdf("MNISTtest_stackplot_with_outliers.pdf", width = 5, height = 4.3)
stackedplot(vcr.test, classCols = cols, separSize = 0.6,
main = "Stacked plot of QDA on MNIST test data",
minSize = 1.5)
Silhouette plot:
#pdf("MNIST_test_QDA_silhouettes.pdf", width = 5.0, height = 4.6)
silplot(vcr.test, classCols = cols,
main = "Silhouette plot of QDA on MNIST test data")
## classNumber classLabel classSize classAveSi
## 1 0 980 0.98
## 2 1 1135 0.93
## 3 2 1032 0.94
## 4 3 1010 0.92
## 5 4 982 0.96
## 6 5 892 0.92
## 7 6 958 0.93
## 8 7 1028 0.86
## 9 8 974 0.92
## 10 9 1009 0.88
Now we can construct the class maps on the test data. First for digit 0:
showdigit_test <- function(digit, i, plotIt = TRUE) {
# i is the position of the image among the test images of this digit
idx <- which(y_test == digit)[i]
# wnq(paste("Estimated digit: ", as.numeric(vcr.test$pred[idx]), sep = ""))
tempImage <- matrix(unlist(X_test[idx, ]), 28, 28)
p <- plotImage(tempImage)
if (plotIt) {plot(p)}
return(p)
}
digit <- 0
# classmap(vcr.test, digit+1, classCols = cols, identify = TRUE)
# pdf(paste0("MNISTtest_classmap_digit", digit,".pdf"))
par(mar = c(3.6, 3.5, 2.4, 3.5))
coords <- classmap(vcr.test, digit + 1, classCols = cols,
main = paste0("predictions of digit ", digit),
cex = 1.5, cex.lab = 1.5, cex.axis = 1.5,
cex.main = 1.5)
indstomark <- c(140, 630, 241, 967, 189,
377, 78, 943, 64, 354)
labs <- letters[1:length(indstomark)]
xvals <- coords[indstomark, 1] +
c(0.08, 0.07, -0.07, 0.06, 0,
0.04, 0.05, 0.09, -0.04, 0.09)
yvals <- coords[indstomark, 2] +
c(-0.025, -0.03, -0.024, -0.025, -0.03,
-0.03, -0.03, 0.022, 0.035, 0.03)
text(x = xvals, y = yvals, labels = labs, cex = 1.5)
legend("left", fill = cols,
legend = 0:9, cex = 1, ncol = 2, bg = "white")
pred <- vcr.test$pred # needed for discussion plots
tempPreds <- (pred[which(y_test == digit)])[indstomark]
discussionPlots <- list()
for (i in 1:length(indstomark)) {
idx <- indstomark[i]
tempplot <- showdigit_test(digit, idx, FALSE)
tempplot <- arrangeGrob(tempplot,
bottom = paste0("(", labs[i], ") \"", tempPreds[i], "\""))
discussionPlots[[i]] <- tempplot
}
discussionPlot <- grid.arrange(grobs = discussionPlots,
ncol = 5)
# ggsave(paste0("MNISTtest_discussionplot_digit", digit, ".pdf"),
# plot = discussionPlot, width = 5,
# height = (length(indstomark) %/% 5 +
# (length(indstomark) %% 5 > 0)))
Now for digit 3:
digit <- 3
# classmap(vcr.test, digit + 1, classCols = cols, identify = TRUE)
# pdf(paste0("MNISTtest_classmap_digit", digit, ".pdf"))
par(mar = c(3.6, 3.5, 2.4, 3.5))
coords <- classmap(vcr.test, digit + 1, classCols = cols,
main = paste0("predictions of digit ", digit),
cex = 1.5, cex.lab = 1.5, cex.axis = 1.5,
cex.main = 1.5)
indstomark <- c(883, 659, 262, 60, 310,
832, 223, 784, 835, 289)
labs <- letters[1:length(indstomark)]
xvals <- coords[indstomark, 1] +
c(-0.01, 0.08, -0.10, 0.06, 0.07,
0.06, 0.03, 0.11, 0.02, 0.06)
yvals <- coords[indstomark, 2] +
c(0.035, 0.033, -0.017, -0.022, -0.025,
-0.025, -0.033, -0.022, 0.035, 0.038)
text(x = xvals, y = yvals, labels = labs, cex = 1.5)
legend("right", fill = cols,
legend = 0:9, cex = 1, ncol = 2, bg = "white")
pred <- vcr.test$pred # needed for discussion plots
tempPreds <- (pred[which(y_test == digit)])[indstomark]
discussionPlots <- list()
for (i in 1:length(indstomark)) {
idx <- indstomark[i]
tempplot <- showdigit_test(digit, idx, FALSE)
tempplot <- arrangeGrob(tempplot,
bottom = paste0("(", labs[i], ") \"", tempPreds[i], "\""))
discussionPlots[[i]] <- tempplot
}
discussionPlot <- grid.arrange(grobs = discussionPlots,
ncol = 5)
# ggsave(paste0("MNISTtest_discussionplot_digit", digit, ".pdf"),
# plot = discussionPlot, width = 5,
# height = (length(indstomark) %/% 5 +
# (length(indstomark) %% 5 > 0)))