%% Contents
function [z, history] = group_lasso(A, b, lambda, p, rho, alpha)
% group_lasso  Solve the group lasso problem via ADMM.
%
%   [z, history] = group_lasso(A, b, lambda, p, rho, alpha)
%
%   Solves
%
%       minimize  1/2*||A*x - b||_2^2 + lambda * sum_i ||x_i||_2
%
%   where x (length n = size(A, 2)) is split into consecutive blocks x_i
%   whose lengths are the entries of p.
%
%   Inputs:
%     A       m-by-n data matrix
%     b       right-hand side, length m
%     lambda  weight of the group penalty
%     p       vector of block lengths; nonnegative integers summing to n
%     rho     ADMM penalty parameter (fixed into the cached factorization)
%     alpha   over-relaxation parameter used to form x_hat
%
%   Outputs:
%     z        final z-iterate, returned as the solution
%     history  struct of per-iteration diagnostics with fields
%              objval, r_norm, s_norm, eps_pri, eps_dual
%
%   Iteration stops as soon as r_norm < eps_pri and s_norm < eps_dual,
%   or after MAX_ITER iterations.

    t_start = tic;

    %% Global constants and defaults
    QUIET    = 0;
    MAX_ITER = 1000;
    ABSTOL   = 1e-4;
    RELTOL   = 1e-2;

    %% Data preprocessing
    [m, n] = size(A);

    % Cache A'*b: it is reused in every x-update.
    Atb = A'*b;

    % The blocks must tile x exactly: nonnegative integer sizes summing to n.
    if sum(p) ~= n || any(p < 0) || any(p ~= fix(p))
        error('invalid partition');
    end

    % cum_part(i) is the last index of block i.
    cum_part = cumsum(p);

    %% ADMM solver
    x = zeros(n, 1);
    z = zeros(n, 1);
    u = zeros(n, 1);   % scaled dual variable

    % Factor once up front; rho does not change during the run.
    [L, U] = factor(A, rho);

    if ~QUIET
        fprintf('%3s\t%10s\t%10s\t%10s\t%10s\t%10s\n', 'iter', ...
            'r norm', 'eps pri', 's norm', 'eps dual', 'objective');
    end

    for k = 1:MAX_ITER
        % x-update: solve (A'*A + rho*I) x = A'*b + rho*(z - u).
        q = Atb + rho*(z - u);
        if m >= n
            % skinny A: direct solve with the n-by-n Cholesky factor
            x = U \ (L \ q);
        else
            % fat A: matrix inversion lemma with the m-by-m factor
            x = q/rho - (A'*(U \ (L \ (A*q))))/rho^2;
        end

        % z-update: blockwise shrinkage of the over-relaxed point.
        zold = z;
        x_hat = alpha*x + (1 - alpha)*zold;
        start_ind = 1;
        for i = 1:length(p)
            sel = start_ind:cum_part(i);
            z(sel) = shrinkage(x_hat(sel) + u(sel), lambda/rho);
            start_ind = cum_part(i) + 1;
        end

        % u-update: accumulate the (relaxed) primal residual.
        u = u + (x_hat - z);

        % Diagnostics, reporting, termination checks.
        history.objval(k)   = objective(A, b, lambda, cum_part, x, z);
        history.r_norm(k)   = norm(x - z);
        history.s_norm(k)   = norm(-rho*(z - zold));
        history.eps_pri(k)  = sqrt(n)*ABSTOL + RELTOL*max(norm(x), norm(-z));
        history.eps_dual(k) = sqrt(n)*ABSTOL + RELTOL*norm(rho*u);

        if ~QUIET
            fprintf('%3d\t%10.4f\t%10.4f\t%10.4f\t%10.4f\t%10.2f\n', k, ...
                history.r_norm(k), history.eps_pri(k), ...
                history.s_norm(k), history.eps_dual(k), history.objval(k));
        end

        if history.r_norm(k) < history.eps_pri(k) && ...
                history.s_norm(k) < history.eps_dual(k)
            break;
        end
    end

    if ~QUIET
        toc(t_start);
    end
end
function p = objective(A, b, lambda, cum_part, x, z)
% objective  Group lasso objective value for the current ADMM iterates.
%   p = objective(A, b, lambda, cum_part, x, z) returns
%       1/2*||A*x - b||_2^2 + lambda * sum_g ||z_g||_2
%   where group g spans indices group_starts(g):group_ends(g), derived from
%   the cumulative block sizes cum_part. The loss is evaluated at x and the
%   penalty at z.
    group_ends = cum_part(:);
    group_starts = [1; group_ends(1:end-1) + 1];

    penalty = 0;
    for g = 1:numel(group_ends)
        penalty = penalty + norm(z(group_starts(g):group_ends(g)));
    end

    residual = A*x - b;
    p = 0.5*sum(residual.^2) + lambda*penalty;
end
function z = shrinkage(x, kappa)
% shrinkage  Block soft-thresholding: proximal operator of kappa*||.||_2.
%   z = shrinkage(x, kappa) returns max(0, 1 - kappa/||x||_2) * x. This
%   scales x toward the origin by kappa in 2-norm, and it returns an
%   all-zero block whenever ||x||_2 <= kappa.
%
%   Uses base-MATLAB max() instead of CVX's pos(), so the solver runs
%   without the CVX toolbox on the path. If norm(x) is 0, kappa/0 is Inf
%   (or NaN when kappa is also 0). max(0, .) maps both cases to 0, so the
%   result is a zero block rather than NaNs.
    z = max(0, 1 - kappa/norm(x))*x;
end
function [L, U] = factor(A, rho)
% factor  Sparse Cholesky factors reused by every ADMM x-update.
%   [L, U] = factor(A, rho) returns lower/upper triangular factors with
%   L*U equal to:
%     A'*A + rho*I            (n-by-n) when A has at least as many rows as columns
%     I + (1/rho)*(A*A')      (m-by-m) when A has fewer rows than columns
%   Both outputs are stored as sparse so that backslash recognizes the
%   triangular structure.
    [m, n] = size(A);
    if m < n
        M = speye(m) + 1/rho*(A*A');
    else
        M = A'*A + rho*speye(n);
    end
    L = sparse(chol(M, 'lower'));
    U = sparse(L');
end