The animint2 Manual by Toby Dylan Hocking


Chapter 11, Lasso

The goal of this chapter is to create an interactive data visualization that explains the Lasso, a machine learning model for regularized linear regression.

Chapter outline:

  • We begin with several static data visualizations of the lasso path.
  • We then create an interactive version with a faceted plot showing the weights and the train/validation error, and a second plot showing residuals.
  • Finally we re-design the interactive data visualization with simplified legends and moving tallrects.

Static plots of the coefficient regularization path

We begin by loading the prostate cancer data set.

if(!file.exists("prostate.data")){
  download.file(
    "https://web.stanford.edu/~hastie/ElemStatLearn/datasets/prostate.data",
    "prostate.data")
}
prostate <- data.table::fread("prostate.data")
head(prostate)
##       V1     lcavol  lweight   age      lbph   svi       lcp gleason pgg45
##    <int>      <num>    <num> <int>     <num> <int>     <num>   <int> <int>
## 1:     1 -0.5798185 2.769459    50 -1.386294     0 -1.386294       6     0
## 2:     2 -0.9942523 3.319626    58 -1.386294     0 -1.386294       6     0
## 3:     3 -0.5108256 2.691243    74 -1.386294     0 -1.386294       7    20
## 4:     4 -1.2039728 3.282789    58 -1.386294     0 -1.386294       6     0
## 5:     5  0.7514161 3.432373    62 -1.386294     0 -1.386294       6     0
## 6:     6 -1.0498221 3.228826    50 -1.386294     0 -1.386294       6     0
##          lpsa  train
##         <num> <char>
## 1: -0.4307829      T
## 2: -0.1625189      T
## 3: -0.1625189      T
## 4: -0.1625189      T
## 5:  0.3715636      T
## 6:  0.7654678      T

We construct the train inputs x and outputs y using the code below.

input.cols <- c(
  "lcavol", "lweight", "age", "lbph", "svi", "lcp", "gleason", 
  "pgg45")
prostate.inputs <- prostate[, ..input.cols]
is.train <- prostate$train == "T"
x <- as.matrix(prostate.inputs[is.train])
head(x)
##          lcavol  lweight age      lbph svi       lcp gleason pgg45
## [1,] -0.5798185 2.769459  50 -1.386294   0 -1.386294       6     0
## [2,] -0.9942523 3.319626  58 -1.386294   0 -1.386294       6     0
## [3,] -0.5108256 2.691243  74 -1.386294   0 -1.386294       7    20
## [4,] -1.2039728 3.282789  58 -1.386294   0 -1.386294       6     0
## [5,]  0.7514161 3.432373  62 -1.386294   0 -1.386294       6     0
## [6,] -1.0498221 3.228826  50 -1.386294   0 -1.386294       6     0
y <- prostate[is.train, lpsa]
head(y)
## [1] -0.4307829 -0.1625189 -0.1625189 -0.1625189  0.3715636  0.7654678

Below we fit the full path of lasso solutions using the lars package.

library(lars)
fit <- lars(x,y,type="lasso")
fit$lambda
## [1] 7.1939462 3.7172742 2.9403866 1.7305064 1.7002813 0.4933166 0.3711651
## [8] 0.0403451

The lambda values along the path are not evenly spaced.
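
As a quick check of the spacing, the sketch below recomputes the gaps between consecutive lambda values. The numbers are copied from the printed output above rather than taken from fit, so that the example runs on its own.

```r
# Lambda values copied from the fit$lambda output above.
lambda <- c(
  7.1939462, 3.7172742, 2.9403866, 1.7305064,
  1.7002813, 0.4933166, 0.3711651, 0.0403451)
# Absolute gap between each pair of consecutive values.
gaps <- abs(diff(lambda))
print(round(gaps, 3))
# Ratio of the largest gap to the smallest gap.
gap.ratio <- max(gaps) / min(gaps)
print(gap.ratio)
```

The largest gap is more than 100 times the smallest.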

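# Coefficients at each step of the path.
pred.nox <- predict(fit, type="coef")
# Multiply each column by its norm (dividing by 1/normx) to get standardized coefficients.
beta <- scale(pred.nox$coefficients, FALSE, 1/fit$normx)
# L1 norm of the standardized coefficients at each step.
arclength <- rowSums(abs(beta))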
path.list <- list()
for(variable in colnames(beta)){
  standardized.coef <- beta[, variable]
  path.list[[variable]] <- data.table::data.table(
    step=seq_along(standardized.coef),
    lambda=c(fit$lambda, 0),
    variable,
    standardized.coef,
    fraction=pred.nox$fraction,
    arclength)
}
path <- do.call(rbind, path.list)
variable.colors <- c(
  "#E41A1C", "#377EB8", "#4DAF4A", "#984EA3", "#FF7F00", "#FFFF33", 
  "#A65628", "#F781BF", "#999999")
library(animint2)
gg.lambda <- ggplot()+
  theme_bw()+
  theme(panel.margin=grid::unit(0, "lines"))+
  scale_color_manual(values=variable.colors)+
  geom_line(aes(
    lambda, standardized.coef, color=variable, group=variable),
    data=path)+
  ggtitle("LASSO path for prostate cancer data calculated using the LARS")
gg.lambda

The plot above shows the entire lasso path, the optimal weights in the L1-regularized least squares regression problem, for every regularization parameter lambda. The path begins at the least squares solution, lambda=0 on the left. It ends at the completely regularized intercept-only model on the right. To see the equivalence with the ordinary least squares solution, we add dots in the plot below.
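
To build intuition for how the weights move from the least squares solution to zero, here is a minimal sketch for a simplified special case, which is an assumption and not a property of the prostate data: when the inputs are orthonormal, minimizing (1/2)*sum((y - X b)^2) + lambda*sum(abs(b)) amounts to soft-thresholding each least squares weight. The function name soft.threshold and the weights are hypothetical, and the lambda scale used by lars may differ from this objective.

```r
# Soft-thresholding: shrink each least squares weight toward zero by
# lambda, and set it to exactly zero once lambda exceeds its size.
soft.threshold <- function(b.ols, lambda){
  sign(b.ols) * pmax(abs(b.ols) - lambda, 0)
}
# Hypothetical least squares weights for two orthonormal inputs.
b.ols <- c(2, -0.5)
lambda.grid <- c(0, 0.5, 1, 2, 3)
# One row per input, one column per lambda value.
path.mat <- sapply(lambda.grid, function(lambda){
  soft.threshold(b.ols, lambda)
})
print(path.mat)
```

Each weight equals its least squares value at lambda=0, shrinks linearly as lambda grows, and stays at exactly zero once lambda reaches its absolute value. The prostate inputs are correlated, so their path is not given by this formula, but it is piecewise linear in the same way.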

x.scaled <- with(fit, scale(x, meanx, normx))
lfit <- lm.fit(x.scaled, y)
lpoints <- data.table::data.table(
  variable=colnames(x),
  standardized.coef=lfit$coefficients,
  arclength=sum(abs(lfit$coefficients)))
gg.lambda+
  geom_point(aes(
    0, standardized.coef, color=variable),
    data=lpoints)
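
The code above relies on the relationship between weights for the original inputs and weights for the scaled inputs. The sketch below checks that relationship on the built-in mtcars data (chosen only so that the example runs without lars). We assume that normx is the L2 norm of each centered input column; the names toy.x and toy.y are hypothetical.

```r
# Inputs and output from a built-in data set (not the prostate data).
toy.x <- as.matrix(mtcars[, c("wt", "hp", "disp")])
toy.y <- mtcars$mpg
# Center each column, then compute its L2 norm (assumed definition of normx).
meanx <- colMeans(toy.x)
normx <- sqrt(colSums(scale(toy.x, meanx, FALSE)^2))
toy.scaled <- scale(toy.x, meanx, normx)
# Least squares weights for the original inputs (intercept dropped).
coef.orig <- unname(coef(lm(toy.y ~ toy.x))[-1])
# Least squares weights for the scaled inputs; no intercept is needed
# because centered columns are orthogonal to a constant.
coef.std <- unname(lm.fit(toy.scaled, toy.y)$coefficients)
# The standardized weights are the original weights times the norms.
print(all.equal(coef.std, coef.orig * unname(normx)))
```

Under that assumption, multiplying the original weights by the column norms gives the standardized weights, which is what the scale call with 1/fit$normx does for the lasso path.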

In the plot below, we show the path as a function of the L1 norm (arclength), with additional points on an evenly spaced grid that we will use later for animation.

fraction <- seq(0, 1, length.out=21)
pred.fraction <- predict(
  fit, prostate.inputs,
  type="coef", mode="fraction", s=fraction)
coef.grid.list <- list()
coef.grid.mat <- scale(pred.fraction$coefficients, FALSE, 1/fit$normx)
for(fraction.i in seq_along(fraction)){
  standardized.coef <- coef.grid.mat[fraction.i,]
  coef.grid.list[[fraction.i]] <- data.table::data.table(
    fraction=fraction[[fraction.i]],
    variable=colnames(x),
    standardized.coef,
    arclength=sum(abs(standardized.coef)))
}
coef.grid <- do.call(rbind, coef.grid.list)
ggplot()+
  ggtitle("LASSO path for prostate cancer data calculated using the LARS")+
  theme_bw()+
  theme(panel.margin=grid::unit(0, "lines"))+
  scale_color_manual(values=variable.colors)+
  geom_line(aes(
    arclength, standardized.coef, color=variable, group=variable),
    data=path)+
  geom_point(aes(
    arclength, standardized.coef, color=variable),
    data=lpoints)+
  geom_point(aes(
    arclength, standardized.coef, color=variable),
    shape=21,
    fill=NA,
    size=3,
    data=coef.grid)

The plot above shows that the weights at the grid points are consistent with the lines that represent the entire path of solutions. The LARS algorithm quickly provides Lasso solutions for as many grid points as you like. More precisely, since the LARS only computes the change-points in the piecewise linear path, its time complexity only depends on the number of change-points (not the number of grid points).
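
Because the path is piecewise linear, weights at any grid point can be obtained by linear interpolation between the change-points. The sketch below illustrates that idea for one variable with the base R approx function. The change-point values are hypothetical, and this is not a description of how predict is implemented in the lars package.

```r
# Hypothetical change-points of a piecewise linear path for one variable.
knot.arclength <- c(0, 1, 3)
knot.coef <- c(0, 1, 1.5)
# Evenly spaced grid on which we want weights.
grid.arclength <- seq(0, 3, length.out=7)
# Linear interpolation between the change-points.
grid.coef <- approx(knot.arclength, knot.coef, xout=grid.arclength)$y
print(data.frame(grid.arclength, grid.coef))
```

Adding more grid points only adds more interpolation work; the change-points themselves do not need to be recomputed.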

Interactive visualization of the regularization path

The plot below combines the lasso weight path with the train/validation error plot.

pred.list <- predict(
  fit, prostate.inputs,
  mode="fraction", s=fraction)
residual.mat <- pred.list$fit - prostate$lpsa
squares.mat <- residual.mat * residual.mat
mean.error.list <- list()
for(set in c("train", "validation")){
  # Logical mask for the rows of this set.
  is.set <- if(set=="train") is.train else !is.train
  mse <- colMeans(squares.mat[is.set, ])
  mean.error.list[[set]] <- data.table::data.table(
    set, mse, fraction,
    arclength=rowSums(abs(coef.grid.mat)))
}
mean.error <- do.call(rbind, mean.error.list)
rect.width <- diff(mean.error$arclength[1:2])/2
addY <- function(dt, y){
  data.table::data.table(dt, y.var=factor(y, c("error", "weights")))
}
tallrect.dt <- coef.grid[variable==variable[1],]
gg.path <- ggplot()+
  theme_bw()+
  theme(panel.margin=grid::unit(0, "lines"))+
  facet_grid(y.var ~ ., scales="free")+
  ylab("")+
  scale_color_manual(values=variable.colors)+
  geom_line(aes(
    arclength, standardized.coef, color=variable, group=variable),
    data=addY(path, "weights"))+
  geom_line(aes(
    arclength, mse, linetype=set, group=set),
    data=addY(mean.error, "error"))+
  geom_tallrect(aes(
    xmin=arclength-rect.width,
    xmax=arclength+rect.width),
    clickSelects="arclength",
    alpha=0.5,
    data=tallrect.dt)
print(gg.path)
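
The error curves above are column means of squared residuals, computed separately for each set. Here is a minimal sketch of that computation on hypothetical data with four observations and two models; all of the toy names and values are made up for illustration.

```r
# Hypothetical responses, and an indicator of which rows are in the train set.
toy.response <- c(1, 2, 3, 4)
toy.is.train <- c(TRUE, TRUE, FALSE, FALSE)
# Hypothetical predictions: one row per observation, one column per model.
toy.pred <- cbind(model1=c(1, 1, 1, 1), model2=c(1, 2, 2, 5))
# The response vector is recycled down each column of the matrix.
toy.residual <- toy.pred - toy.response
toy.squares <- toy.residual^2
# Mean squared error of each model, for each set.
train.mse <- colMeans(toy.squares[toy.is.train, , drop=FALSE])
validation.mse <- colMeans(toy.squares[!toy.is.train, , drop=FALSE])
print(rbind(train.mse, validation.mse))
```

In this toy example model2 has lower error on both sets. In the prostate data, the two curves need not reach their minimum at the same arclength.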

Finally, we add a plot of residuals versus actual values.

lasso.res.list <- list()
for(fraction.i in seq_along(fraction)){
  lasso.res.list[[fraction.i]] <- data.table::data.table(
    observation.i=1:nrow(prostate),
    fraction=fraction[[fraction.i]],
    residual=residual.mat[, fraction.i],
    response=prostate$lpsa,
    arclength=sum(abs(coef.grid.mat[fraction.i,])),
    set=ifelse(is.train, "train", "validation"))
}
lasso.res <- do.call(rbind, lasso.res.list)
hline.dt <- data.table::data.table(residual=0)
gg.res <- ggplot()+
  theme_bw()+
  geom_hline(aes(
    yintercept=residual),
    data=hline.dt,
    color="grey")+
  geom_point(aes(
    response, residual, fill=set, 
    key=observation.i),
    showSelected="arclength",
    shape=21,
    data=lasso.res)
print(gg.res)

Below, we combine the two ggplots above in a single animint. Clicking the first plot changes the regularization parameter, and therefore the residuals that are shown in the second plot.

animint(
  gg.path,
  gg.res,
  duration=list(arclength=2000),
  time=list(variable="arclength", ms=2000))

Re-design with moving tallrects

The re-design below has two changes. First, you may have noticed that there are two different set legends in the previous animint (linetype=set in the first path plot, and fill=set in the second residual plot). It would be easier for the reader to decode if the set variable had just one mapping. So in the re-design below we remove the fill=set mapping from the geom_point in the second plot, and add a geom_segment with linetype=set.

Second, we have replaced the single tallrect in the first plot with two tallrects. The first tallrect has showSelected=arclength and is used to show the selected arclength using a grey rectangle. Since we specify a duration for the arclength variable, and the same key=1 value, we will observe a smooth transition of the selected grey tallrect. The second tallrect has clickSelects=arclength and so clicking it has the effect of changing the selected value of arclength. We specify another data set with more rows, and use a named showSelected variable (the show.var and show.val columns) to indicate that arclength should also be used as a showSelected variable, so that the clickable tallrect at the currently selected arclength is not drawn.

tallrect.show.list <- list()
for(a in tallrect.dt$arclength){
  is.selected <- tallrect.dt$arclength == a
  not.selected <- tallrect.dt[!is.selected]
  tallrect.show.list[[paste(a)]] <- data.table::data.table(
    not.selected, show.val=a, show.var="arclength")
}
tallrect.show <- do.call(rbind, tallrect.show.list)
animint(
  path=ggplot()+
    theme_bw()+
    theme(panel.margin=grid::unit(0, "lines"))+
    facet_grid(y.var ~ ., scales="free")+
    ylab("")+
    scale_color_manual(values=variable.colors)+
    geom_line(aes(
      arclength, standardized.coef, color=variable, group=variable),
      data=addY(path, "weights"))+
    geom_line(aes(
      arclength, mse, linetype=set, group=set),
      data=addY(mean.error, "error"))+
    geom_tallrect(aes(
      xmin=arclength-rect.width,
      xmax=arclength+rect.width,
      key=1),
      showSelected="arclength",
      alpha=0.5,
      data=tallrect.dt)+
    geom_tallrect(aes(
      xmin=arclength-rect.width,
      xmax=arclength+rect.width,
      key=paste(arclength, show.val)),
      clickSelects="arclength",
      showSelected=c("show.var"="show.val"),
      alpha=0.5,
      data=tallrect.show),
  res=ggplot()+
    theme_bw()+
    geom_hline(aes(
      yintercept=residual),
      data=hline.dt,
      color="grey")+
    guides(linetype="none")+
    geom_point(aes(
      response, residual, 
      key=observation.i),
      showSelected=c("set", "arclength"),
      shape=21,
      fill=NA,
      color="black",
      data=lasso.res)+
    geom_text(aes(
      3, 2.5, label=sprintf("L1 arclength = %.1f", arclength),
      key=1),
      showSelected="arclength",
      data=tallrect.dt)+
    geom_text(aes(
      0, -2, label=sprintf("train error = %.3f", mse),
      key=1),
      showSelected=c("set", "arclength"),
      hjust=0,
      data=mean.error[set=="train"])+
    geom_text(aes(
      0, -2.5, label=sprintf("validation error = %.3f", mse),
      key=1),
      showSelected=c("set", "arclength"),
      hjust=0,
      data=mean.error[set=="validation"])+
    geom_segment(aes(
      response, residual,
      xend=response, yend=0,
      linetype=set,
      key=observation.i),
      showSelected=c("set", "arclength"),
      size=1,
      data=lasso.res),
  duration=list(arclength=2000),
  time=list(variable="arclength", ms=2000))
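
The tallrect.show data set built by the for loop above contains, for each value of show.val, one row for every other arclength value. The sketch below builds a table with the same structure using expand.grid and three hypothetical arclength values; the toy names are made up for illustration.

```r
# Three hypothetical arclength values (the real grid has more).
toy.arclength <- c(0, 0.5, 1)
# All pairs of (rectangle position, selected value).
toy.grid <- expand.grid(arclength=toy.arclength, show.val=toy.arclength)
# Keep only rectangles that are not at the selected value.
toy.show <- toy.grid[toy.grid$arclength != toy.grid$show.val, ]
toy.show$show.var <- "arclength"
print(toy.show)
```

With three values there are 3*2=6 rows. For any selected value, the rows shown are the clickable rectangles at the other two positions, which leaves the selected position to the grey showSelected tallrect.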

Chapter summary and exercises

We created a visualization of the Lasso machine learning model, which simultaneously shows the regularization path and error curves. Interactivity was used to show details for different values of the regularization parameter.

Exercises:

  • Re-make this data viz, including the same visual effect for the tallrects, using only one geom_tallrect. Hint: create another data set with expand.grid(arclength.click=arclength, arclength.show=arclength), as in the definition of the make_tallrect_or_widerect function.
  • Add another scatterplot that shows predicted values versus response, with a geom_abline in the background to indicate perfect prediction.
  • How would the error curves look if other train/validation splits were chosen? Perform 4-fold cross-validation and add a plot that can be used to select the test fold.
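
As a possible starting point for the last exercise, the sketch below assigns each observation to one of four folds at random. The number of observations, the seed, and the variable names are hypothetical; this is one common way to create folds, not a prescribed solution.

```r
set.seed(1)
n.obs <- 10   # hypothetical; use nrow(prostate) for the real data
n.folds <- 4
# Repeat fold ids to cover all observations, then shuffle.
fold <- sample(rep(1:n.folds, length.out=n.obs))
print(table(fold))
# Logical mask for the held-out set when fold 1 is selected.
is.held.out <- fold == 1
print(which(is.held.out))
```

Repeating the fit with each fold held out in turn gives one pair of error curves per fold.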

Next, Chapter 12 explains how to visualize the Support Vector Machine.