TransWikia.com

Cross-Entropy Minimization - Extreme Code Performance

Code Review Asked by Tommaso Belluzzo on October 27, 2021

I’m working on a multivariate cross-entropy minimization model (for more details, see this paper, pp. 32–33). Its purpose is to adjust a prior multivariate distribution (in this case, a Gaussian) using information on the marginals that comes from real observations.

The code at the end of the post represents my current implementation. The maths should have been correctly reproduced, unless I missed something critical during the review. The real problem I’m struggling to deal with is the performance of the code.

In the first part of the model, cumulative probabilities have to be computed over all the orthants of the distribution density. This step has time complexity O(2^N), where N is the number of entities included in the dataset. As long as there are fewer than 12 entities, everything runs fast enough on my PC. With 20 entities, which is my current target, the model has to run mvncdf over 1,048,576 orthants, and this takes forever to finish.

I already improved the code a little by replacing the main for loop with a parfor loop. I gained a large speedup by replacing the built-in mvncdf function with a hand-written one.

I’m not very familiar with cross-entropy minimization models, so maybe there are math tricks I can use to simplify this calculation. Maybe the code can be vectorized even more. Well… any help or suggestion to improve the calculations speed is more than welcome!

clc();
clear();

% DATA
%
% cns = [1; pods] below: the solver matches the total posterior mass to 1
% and, for each entity j, the posterior mass of the orthants with
% X_j >= dts(j) to pods(j).
% NOTE(review): the names suggest probabilities of default and default
% thresholds — confirm.

pods = [0.015; 0.02; 0.013; 0.007; 0.054; 0.034; 0.009; 0.065; 0.029; 0.205];
dts = [2.1; 2; 2.2; 2.4; 1.5; 1.8; 2.3; 1.5; 1.8; 0.8];

% Test of time complexity:
% pods = [pods; pods];
% dts = [dts; dts];

n = numel(pods);
c = eye(n);

k = 2^n;
kh = k / 2;

% G / BOUNDS FOR 1
%
% Row i of g1 encodes one orthant: g1(i,j) == 1 means X_j in [dts(j), Inf)
% and g1(i,j) == 0 means X_j in (-Inf, dts(j)]. bounds_1(i) is the prior
% probability of that orthant.

g1 = combn([0 1],n);

if (isdiag(c) && all(diag(c) > 0))
    % Independent components: every orthant probability is a product of
    % univariate normal probabilities, so the 2^n multivariate integrations
    % collapse into one matrix-vector product in log space. lnpr keeps the
    % tail log-probabilities finite and accurate. This equals what mvncdf2
    % computes for a diagonal covariance, without calling fsolve 2^n times.
    z = dts ./ sqrt(diag(c));
    ln_upper = lnpr(z,Inf(n,1));    % log P(X_j >= dts(j))
    ln_lower = lnpr(-Inf(n,1),z);   % log P(X_j <= dts(j))
    bounds_1 = exp((g1 * (ln_upper - ln_lower)) + sum(ln_lower));
else
    bounds_1 = zeros(k,1);

    parfor i = 1:k
        upper_j = (g1(i,:).' == 1);

        % Build the orthant explicitly rather than relying on min/max
        % ignoring the NaN produced by Inf * 0.
        lb = dts;
        lb(~upper_j) = -Inf;
        ub = dts;
        ub(upper_j) = Inf;

        bounds_1(i) = mvncdf2(c,lb,ub);
    end
end

% G / BOUNDS FOR 2:N
%
% g2{j} holds the rows of g1 (in ascending order) whose j-th component is 1,
% and bounds_2(j,:) the matching orthant probabilities. Logical indexing
% builds each group in one step instead of copying the whole kh-by-n matrix
% out of the cell and back for every inserted row.

g2 = cell(n,1);
bounds_2 = zeros(n,kh);

for j = 1:n
    rows = (g1(:,j) == 1);
    g2{j} = g1(rows,:);
    bounds_2(j,:) = bounds_1(rows).';
end

% SOLUTION

options = optimset(optimset(@fsolve),'Display','iter','TolFun',1e-08,'TolX',1e-08);
cns = [1; pods];
x0 = zeros(size(pods,1)+1,1);
lm = fsolve(@(x)objective(x,n,g1,bounds_1,g2,bounds_2,cns),x0,options);

% Objective function of the model.
%
% p = objective(x,n,g1,bounds_1,g2,bounds_2,cns) returns the n+1 residuals
% driven to zero by fsolve, with x = [mu; lambda]:
%   p(1)   = exp(-1-mu) * sum_i exp(-g1(i,:)*lambda) * bounds_1(i)       - cns(1)
%   p(j+1) = exp(-1-mu) * sum_r exp(-g2{j}(r,:)*lambda) * bounds_2(j,r)  - cns(j+1)
% Assumes g2{j} has size(bounds_2,2) rows, as built by the script.
function p = objective(x,n,g1,bounds_1,g2,bounds_2,cns)

    mu = x(1);
    lambda = x(2:end);

    p = zeros(n + 1,1);

    % Each sum is a single matrix-vector product instead of a scalar loop
    % over up to 2^n rows.
    p(1) = exp(-g1 * lambda).' * bounds_1(:);

    for i = 1:n
        p(i+1) = exp(-g2{i,1} * lambda).' * bounds_2(i,:).';
    end

    p = (exp(-1-mu) * p) - cns;

end

% All combinations of elements.
%
% [m,i] = combn(v,n) returns every length-n sequence drawn, with
% repetition, from the elements of v. There is one sequence per row
% (numel(v)^n rows, n columns), and rows are ordered lexicographically by
% element position: the first column varies slowest, the last fastest.
% i holds the matching indices into v, so m = v(i). An empty v yields
% m = [] and i = [].
function [m,i] = combn(v,n)

    % Test scalarity first: || requires scalar operands, so a non-scalar n
    % would otherwise fail with an unrelated error before this message.
    if (~isscalar(n) || (fix(n) ~= n) || (n < 1))
        error('Parameter N must be a scalar positive integer.');
    end

    if (isempty(v))
        m = [];
        i = [];
    elseif (n == 1)
        m = v(:); 
        i = (1:numel(v)).';
    else
        i = combn_local(1:numel(v),n);
        m = v(i);
    end
    
    % Cartesian power of v as an n-column matrix. ndgrid outputs are
    % assigned in reverse order so that the last column varies fastest.
    function y = combn_local(v,n)

        if (n > 1)
            [y{n:-1:1}] = ndgrid(v);
            y = reshape(cat(n+1,y{:}),[],n);
        else
            y = v(:);
        end

    end
    
end

% Multivariate normal cumulative distribution function.
%
% y = mvncdf2(c,lb,ub) approximates P(lb <= X <= ub) for X ~ N(0,c) with
% the minimax exponential-tilting construction:
%   1. cholperm reorders the variables and builds the lower-triangular
%      factor cp of c (lb and ub are permuted accordingly);
%   2. fsolve finds the root of gradpsi, i.e. the tilting parameters
%      [x; mu], whose last components are fixed at 0;
%   3. y = exp(psi) is evaluated at that root.
% No Monte Carlo correction is applied.
% NOTE(review): in Botev (2017) this deterministic value is an upper bound
% on the probability, not an unbiased estimate. Confirm the accuracy is
% acceptable for non-diagonal c.
%
% Returns NaN when the reordered factor has a pivot below eps() or when
% fsolve does not report exitflag 1.
function y = mvncdf2(c,lb,ub)

    % Built once and reused across calls to avoid repeating optimset.
    persistent options;

    if (isempty(options))
        options = optimset(optimset(@fsolve),'Algorithm','trust-region-dogleg','Diagnostics','off','Display','off','Jacobian','on');
    end
    
    n = size(c,1);

    [cp,lb,ub] = cholperm(n,c,lb,ub);
    d = diag(cp);

    if any(d < eps())
        y = NaN;
        return;
    end

    % Standardize the bounds by the factor's diagonal.
    lb = lb ./ d;
    ub = ub ./ d;

    if (n == 1)
        % A single variable has no free tilting parameters (fsolve would be
        % handed an empty start point); the probability is available in
        % closed form.
        y = exp(lnpr(lb,ub));
        return;
    end

    % Unit-diagonal factor with the diagonal removed.
    cp = (cp ./ repmat(d,1,n)) - eye(n);

    [sol,~,exitflag] = fsolve(@(x)gradpsi(x,cp,lb,ub),zeros(2 * (n - 1),1),options);

    if (exitflag ~= 1)
        y = NaN;
        return;
    end

    % sol = [x(1:n-1); mu(1:n-1)]; pad both with the fixed trailing zero.
    x = sol(1:(n - 1));
    x(n) = 0;
    x = x(:);
    
    mu = sol(n:((2 * n) - 2));
    mu(n) = 0;
    mu = mu(:);
    
    shift = cp * x;
    lb = lb - mu - shift;
    ub = ub - mu - shift;

    y = exp(sum(lnpr(lb,ub) + (0.5 * mu.^2) - (x .* mu)));

end

% Cholesky factorization with variable reordering.
%
% [cp,l,u] = cholperm(n,c,l,u) builds a lower-triangular factor cp of the
% n-by-n matrix c one column at a time.
%
% At step j, it considers the not-yet-placed variables j..n. For each, it
% standardizes the bounds conditionally on the earlier ones (shifting by
% cp*z and scaling by the remaining standard deviation). It then picks the
% variable whose interval has the smallest log-probability (lnpr) and swaps
% it into position j. The swap is applied consistently to the rows and
% columns of c, the rows of cp, and l and u.
%
% The returned l and u are in the new variable order. Negative pivots are
% clamped to eps() before taking square roots.
function [cp,l,u] = cholperm(n,c,l,u)

    s2p = sqrt(2 * pi());

    cp = zeros(n,n);
    % z(j): mean of a standard normal truncated to the standardized
    % interval of the variable placed at position j.
    z = zeros(n,1);

    for j = 1:n
        j_seq = 1:(j - 1);
        jn_seq = j:n;
        j1n_seq = (j + 1):n;

        % Conditional shift of each candidate given the placed variables.
        cp_off = cp(jn_seq,j_seq);
        z_off = z(j_seq);
        cpz = cp_off * z_off;

        % Re-read the diagonal every step because c is permuted below.
        d = diag(c);
        s = d(jn_seq) - sum(cp_off.^2,2);
        s(s < 0) = eps();
        s = sqrt(s);

        lt = (l(jn_seq) - cpz) ./ s;
        ut = (u(jn_seq) - cpz) ./ s;

        % Already-placed positions get Inf so min never selects them.
        p = Inf(n,1);
        p(jn_seq) = lnpr(lt,ut);

        [~,k] = min(p);
        jk = [j k];
        kj = [k j];

        % Swap variable k into position j everywhere.
        c(jk,:) = c(kj,:);
        c(:,jk) = c(:,kj);

        cp(jk,:) = cp(kj,:);
        l(jk) = l(kj);
        u(jk) = u(kj);

        % Standard Cholesky update for column j.
        s = c(j,j) - sum(cp(j,j_seq).^2);
        s(s < 0) = eps();

        cp(j,j) = sqrt(s);
        cp(j1n_seq,j) = (c(j1n_seq,j) - (cp(j1n_seq,j_seq) * (cp(j,j_seq)).')) / cp(j,j);

        % Standardized interval of the chosen variable and its truncated-
        % normal mean (phi(lt) - phi(ut)) / P(lt < Z < ut).
        cp_jj = cp(j,j);
        cpz = cp(j,j_seq) * z(j_seq);
        lt = (l(j) - cpz) / cp_jj;
        ut = (u(j) - cpz) / cp_jj;

        w = lnpr(lt,ut);
        z(j) = (exp((-0.5 * lt.^2) - w) - exp((-0.5 * ut.^2) - w)) / s2p;
    end

end

% Residual and Jacobian of the tilting system solved in mvncdf2.
%
% [g,j] = gradpsi(y,L,l,u):
%   y    - [x(1:d-1); mu(1:d-1)], the free tilting parameters (x(d) and
%          mu(d) are fixed at 0);
%   L    - factor with the diagonal removed, as built by mvncdf2 from
%          cholperm's lower-triangular factor;
%   l, u - standardized lower/upper bounds (d-by-1).
% Returns g, the 2*(d-1) residual vector, and j, its Jacobian. j is computed
% only when the caller requests a second output.
function [g,j] = gradpsi(y,L,l,u)

    d = length(u);
    d_seq = 1:(d - 1);

    x = zeros(d,1);
    x(d_seq) = y(d_seq);

    mu = zeros(d,1);
    mu(d_seq) = y(d:end);

    c = zeros(d,1);
    c(2:d) = L(2:d,:) * x;

    lt = l - mu - c;
    ut = u - mu - c;

    % p = (phi(lt) - phi(ut)) / P(lt < Z < ut), with w the log-probability.
    w = lnpr(lt,ut);
    pd = sqrt(2 * pi());
    pl = exp((-0.5 * lt.^2) - w) / pd;
    pu = exp((-0.5 * ut.^2) - w) / pd;
    p = pl - pu;

    dfdx = -mu(d_seq) + (p.' * L(:,d_seq)).';
    dfdm = mu - x + p;
    g = [dfdx; dfdm(d_seq)];

    % fsolve does not always ask for the Jacobian; skip the d-by-d work
    % when only the residual is needed.
    if (nargout < 2)
        return;
    end

    % Infinite bounds contribute nothing to the derivative terms below.
    lt(isinf(lt)) = 0;
    ut(isinf(ut)) = 0;

    dp = -p.^2 + (lt .* pl) - (ut .* pu);
    dl = repmat(dp,1,d) .* L;

    mx = -eye(d) + dl;
    mx = mx(d_seq,d_seq);
    
    xx = L.' * dl;
    xx = xx(d_seq,d_seq);

    j = [xx mx.'; mx diag(1 + dp(d_seq))];

end

% Log-probability of a standard normal interval.
%
% p = lnpr(a,b) returns log(P(a < Z < b)) for Z ~ N(0,1), elementwise over
% same-sized a and b.
%   - Intervals entirely in the right tail (a > 0) or the left tail (b < 0)
%     go through the scaled complementary error function erfcx, which
%     avoids underflow far out in the tails.
%   - Intervals straddling zero use erfc directly.
function p = lnpr(a,b)

    % log P(Z > t), evaluated stably through erfcx.
    log_tail = @(t) (-0.5 * t.^2) - log(2) + reallog(erfcx(t / sqrt(2)));

    p = zeros(size(a));

    right = a > 0;
    if (any(right))
        la = log_tail(a(right));
        lb = log_tail(b(right));
        p(right) = la + log1p(-exp(lb - la));
    end

    % Left tail: reflect about zero, P(a < Z < b) = P(-b < Z < -a).
    left = b < 0;
    if (any(left))
        la = log_tail(-a(left));
        lb = log_tail(-b(left));
        p(left) = lb + log1p(-exp(la - lb));
    end

    % Straddling intervals: log(1 - P(Z < a) - P(Z > b)).
    centre = ~right & ~left;
    if (any(centre))
        below = erfc(-a(centre) / sqrt(2)) / 2;
        above = erfc(b(centre) / sqrt(2)) / 2;
        p(centre) = log1p(-below - above);
    end

end

Add your own answers!

Ask a Question

Get help from others!

© 2024 TransWikia.com. All rights reserved. Sites we Love: PCI Database, UKBizDB, Menu Kuliner, Sharing RPP