%% Contents
function [x, history] = group_lasso_feat_split(A, b, lambda, ni, RHO, ALPHA)
% GROUP_LASSO_FEAT_SPLIT  Group lasso via ADMM with the features split into blocks.
%
%   [x, history] = group_lasso_feat_split(A, b, lambda, ni, RHO, ALPHA)
%
%   The n columns of A are partitioned into N = n/ni consecutive blocks of
%   ni columns each; block i is updated independently by x_update and the
%   blocks are coupled through the averaged prediction Axbar (sharing form).
%
%   Inputs:
%     A      - m-by-n data matrix.
%     b      - m-by-1 right-hand side.
%     lambda - group-lasso regularization weight.
%     ni     - block size; must divide n exactly, otherwise an error is raised.
%     RHO    - ADMM penalty parameter.
%     ALPHA  - over-relaxation parameter (ALPHA = 1 means no relaxation).
%
%   Outputs:
%     x       - ni-by-N matrix; column i holds the coefficients of block i.
%     history - struct with per-iteration fields objval, r_norm, s_norm,
%               eps_pri and eps_dual (primal/dual residuals and tolerances).

t_start = tic;

%% Global constants and defaults
QUIET    = 0;
MAX_ITER = 100;
RELTOL   = 1e-2;
ABSTOL   = 1e-4;

%% Data preprocessing
[m, n] = size(A);

if (rem(n,ni) ~= 0)
    error('invalid block size');
end
N = n/ni;

%% ADMM solver
rho = RHO;
alpha = ALPHA;

x = zeros(ni,N);
z = zeros(m,1);
u = zeros(m,1);
Axbar = zeros(m,1);

zs = zeros(m,N);
Aixi = zeros(m,N);

if ~QUIET
    fprintf('%3s\t%10s\t%10s\t%10s\t%10s\t%10s\n', 'iter', ...
      'r norm', 'eps pri', 's norm', 'eps dual', 'objective');
end

% Precompute everything that depends only on A: the column block A_i, its
% transpose (for the dual residual), and the eigendecomposition of A_i'*A_i
% that x_update uses to solve its shifted linear systems cheaply.
Ablk = cell(1,N);
At   = cell(1,N);
V    = cell(1,N);
D    = cell(1,N);
for i = 1:N
    Ai = A(:,(i-1)*ni + 1:i*ni);
    [Vi,Di] = eig(Ai'*Ai);
    Ablk{i} = Ai;
    At{i}   = Ai';
    V{i}    = Vi;
    D{i}    = diag(Di);
end

for k = 1:MAX_ITER

    % x-update: every block uses the previous iterate's z, Axbar and u,
    % so the block updates are independent of one another.
    for i = 1:N
        x(:,i) = x_update(Ablk{i}, Aixi(:,i) + z - Axbar - u, lambda/rho, V{i}, D{i});
        Aixi(:,i) = Ablk{i}*x(:,i);
    end

    % z-update with over-relaxation, then scaled dual update.
    zold = z;
    % Average of A_i*x_i over the blocks, i.e. (A*x(:))/N, computed from the
    % already-available products instead of another full matrix-vector
    % product (and without CVX's vec()).
    Axbar = sum(Aixi, 2)/N;
    Axbar_hat = alpha*Axbar + (1-alpha)*zold;
    z = (b + rho*(Axbar_hat + u))/(N+rho);
    u = u + Axbar_hat - z;

    % Per-block z_i values; their change between iterations drives the
    % dual residual.
    zsold = zs;
    zs = z*ones(1,N) + Aixi - Axbar*ones(1,N);

    s = 0; q = 0;
    for i = 1:N
        s = s + norm(rho*At{i}*(zs(:,i) - zsold(:,i)))^2;
        q = q + norm(rho*At{i}*u)^2;
    end

    % diagnostics, reporting, termination checks
    history.objval(k)  = objective(A, b, lambda, N, x, z);

    history.r_norm(k)  = sqrt(N)*norm(z - Axbar);
    history.s_norm(k)  = sqrt(s);

    % NOTE(review): the primal residual has m*N entries (it is
    % sqrt(N)*norm(z - Axbar)), so the absolute term arguably should scale
    % with sqrt(m*N) rather than sqrt(n). Kept as-is to preserve the
    % existing stopping behavior.
    history.eps_pri(k) = sqrt(n)*ABSTOL + RELTOL*max(norm(Aixi,'fro'), norm(zs, 'fro'));
    history.eps_dual(k)= sqrt(n)*ABSTOL + RELTOL*sqrt(q);

    if ~QUIET
        fprintf('%3d\t%10.4f\t%10.4f\t%10.4f\t%10.4f\t%10.2f\n', k, ...
            history.r_norm(k), history.eps_pri(k), ...
            history.s_norm(k), history.eps_dual(k), history.objval(k));
    end

    if history.r_norm(k) < history.eps_pri(k) && ...
       history.s_norm(k) < history.eps_dual(k)
         break
    end

end

if ~QUIET
    toc(t_start);
end

end
function p = objective(A, b, lambda, N, x, z)
% OBJECTIVE  Group-lasso objective tracked by the ADMM solver.
%
%   p = objective(A, b, lambda, N, x, z) returns
%       1/2*||N*z - b||_2^2 + lambda * sum_i ||x(:,i)||_2
%   where x is ni-by-N (one column per feature block) and N*z is the solver's
%   current estimate of A*x(:). A is accepted for interface compatibility
%   but is not used.
%
%   Uses only built-in MATLAB functions (no CVX). The block norms are taken
%   explicitly along dimension 1, so a 1-by-N x (block size 1) yields
%   sum(abs(x)) rather than a single 2-norm of the whole row.
residual = N*z - b;
block_norms = sqrt(sum(x.^2, 1));
p = ( 1/2*sum(residual.^2) + lambda*sum(block_norms) );
end
function x = x_update(A, b, kappa, V, D)
% X_UPDATE  Solve one feature-block subproblem of the group lasso.
%
%   x = x_update(A, b, kappa, V, D) targets the minimizer of
%       1/2*||A*x - b||_2^2 + kappa*||x||_2.
%   V and D are expected to be the eigenvectors and eigenvalues (as a column
%   vector) of A'*A, so that (A'*A + t*I) \ q == V*((V'*q)./(D + t)).
%
%   If ||A'*b|| <= kappa the minimizer is the zero vector. Otherwise x is
%   (A'*A + t*I) \ (A'*b) for the scalar t satisfying t = kappa/||x||, which is
%   located by bisection on [0, 1e10] (at most 100 steps, stopping once the
%   bracket is narrower than 1e-6).
n = size(A, 2);
q = A'*b;

if (norm(q) <= kappa)
    x = zeros(n,1);
else
    % Bracket for t; named to avoid shadowing the built-ins lower/upper.
    t_lo = 0; t_hi = 1e10;
    for iter = 1:100
        t = (t_hi + t_lo)/2;
        x = V*((V'*q)./(D + t));
        % Larger t shrinks x, so kappa/||x|| grows with t: if t already
        % exceeds it, the root lies below t.
        if t > kappa/norm(x)
            t_hi = t;
        else
            t_lo = t;
        end
        if (t_hi - t_lo <= 1e-6)
            break;
        end
    end
end
end