TransWikia.com

Solving $E[ab^TXab^T]=Y$ for Gaussian $(a,b)$

Mathematica Asked on December 7, 2020

I have $d/2$-dimensional random vectors $a, b$ that are jointly Gaussian, $x = (a, b) \sim \mathcal N(\mu, \Sigma)$ in $d$ dimensions, and I need to solve the following equation for $X$:

$$E[ab^TXab^T]=Y$$

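By Wick's (Isserlis') theorem, and after transposing both sides, this is equivalent to solving the following for $X$:

$$BXA + CX^TC + C\,\operatorname{Tr}(X^TC) - 2\,(\bar b^TX\bar a)\,\bar b\,\bar a^T = Y^T$$

where $A, B, C$ are the $d/2 \times d/2$ blocks of the Gaussian second-moment matrix and $\bar a, \bar b$ are the two halves of the mean:

$$(\bar a, \bar b) = E[x] = \mu$$

$$
\begin{pmatrix}
A & C^T \\
C & B
\end{pmatrix}
= E[xx^T] = \Sigma + \mu\mu^T
$$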
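For reference, here is a sketch of the step behind that expansion, written out entry by entry. It relies on the non-central Gaussian fourth-moment identity in terms of raw second moments $M = E[xx^T] = \Sigma + \mu\mu^T$ (a standard consequence of Isserlis' theorem, stated here rather than taken from the post), with $\bar a, \bar b$ the halves of $\mu$ and $A = E[aa^T]$, $B = E[bb^T]$, $C = E[ba^T]$:

$$E[x_ix_jx_kx_l] = M_{ij}M_{kl} + M_{ik}M_{jl} + M_{il}M_{jk} - 2\,\mu_i\mu_j\mu_k\mu_l$$

$$\big(E[ab^TXab^T]\big)_{ij} = \sum_{k,l} X_{kl}\,E[a_ib_jb_ka_l] = C_{ji}\operatorname{Tr}(C^TX) + (C^TXC^T)_{ij} + (AX^TB)_{ij} - 2\,\bar a_i\bar b_j\,(\bar b^TX\bar a)$$

Transposing, and using $A = A^T$, $B = B^T$ and $\operatorname{Tr}(C^TX) = \operatorname{Tr}(X^TC)$, turns $E[ab^TXab^T]^T$ into the left-hand side of the expanded equation above.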

Below is a somewhat brute-force solution (apply the Sherman-Morrison formula twice to peel off the two rank-one terms, then solve the remaining generalized T-Sylvester equation), which takes about 12 seconds for $d=1024$. There are some repeated expressions, but simplifying by hand gets quite error-prone. Can someone see a way to speed this up?
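One way to see where the two Sherman-Morrison steps come from: write $\mathcal L(X) = BXA + CX^TC$ for the generalized T-Sylvester part, $\langle P, Q\rangle = \operatorname{Tr}(P^TQ)$ for the Frobenius inner product (what `dot` computes in the code), and set $V = \sqrt2\,\bar b\,\bar a^T$, with $\bar a, \bar b$ the mean halves. Then $V\langle V, X\rangle = 2(\bar b^TX\bar a)\,\bar b\,\bar a^T$, so with $R$ denoting the right-hand side the equation takes the form

$$\mathcal L(X) + C\,\langle C, X\rangle - V\,\langle V, X\rangle = R$$

Each Sherman-Morrison step removes one rank-one term, so only $\mathcal L^{-1}$ is ever applied, to $R$, $C$ and $V$. As written, the code applies it to $C$ twice, i.e. four generalized T-Sylvester solves that share the same coefficient matrices and differ only in the right-hand side.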

(*Expectation of expression*)
Ex[expr_] := Expectation[expr, x \[Distributed] dist];
split[vec_] := ArrayReshape[vec, {2, Length[vec]/2}];

CircleTimes = KroneckerProduct;

(*Solve AX+XB=C. Equivalent to LyapunovSolve[A,B,C] but faster, more 
stable, and works when ill-posed*)

sylvester[A_, B_, C0_] := 
  Module[{da, db, DA, T, DB, U, denom, cutoff, sdiv, Y}, {da, db} = 
    Length /@ {A, B};
   {DA, T} = Eigensystem[A + $MachineEpsilon IdentityMatrix[da]];
   T = Transpose[T];
   {DB, U} = Eigensystem[B + $MachineEpsilon IdentityMatrix[db]];
   U = Transpose[U];
   denom = Outer[Plus, DA, DB];
   cutoff = Max@Abs[denom]*10^6*$MachineEpsilon;
   sdiv = Map[If[Abs[#] > cutoff, 1/#, #] &, denom, {2}];
   Y = Inverse[T].C0.U*sdiv;
   T.Y.Inverse[U]];
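(* Why sylvester works (sketch, assuming A and B are diagonalizable):
   Eigensystem returns eigenvectors as rows, so after transposing,
   A = T.DiagonalMatrix[DA].T^-1 and B = U.DiagonalMatrix[DB].U^-1.
   Substituting Z = T^-1.X.U turns A.X + X.B == C0 into
   DA_i Z_ij + Z_ij DB_j == (T^-1.C0.U)_ij, so
   Z_ij = (T^-1.C0.U)_ij/(DA_i + DB_j) and X = T.Z.U^-1.
   Entries of denom at or below cutoff are multiplied by rather than
   divided by, which damps those components toward zero instead of
   blowing them up. *)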

(*Solve T-Sylvester equation AX+X^T B=C by reducing to a 
Sylvester equation. Eddy's recipe from
https://mathematica.stackexchange.com/a/207044/217*)

tsylvester[a_, b_, c_] := 
  Module[{g, ig, h, u, x}, g = a + Transpose[b];
   ig = Inverse[g];
   h = (c + Transpose[c])/2;
   u = sylvester[a.ig, -Transpose[ig].b, 
     c - a.ig.h - h.Transpose[ig].b];
   u = (u - Transpose[u])/2;
   x = ig.(h + u);
   If[ValueQ[debug], 
    Print["tsylvester error is ", Norm[a.x + Transpose[x].b - c]]];
   x];
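(* How the reduction works (sketch, assuming g = a + b^T is invertible):
   with W = g.x, the equation a.x + x^T.b == c becomes
   P.W + W^T.Q^T == c, where P = a.g^-1, Q = b^T.g^-1 and P + Q == I.
   The symmetric part of the left side is then (W + W^T)/2, so
   sym(W) = h = (c + c^T)/2. Writing W = h + u with u antisymmetric
   leaves the Sylvester equation
   P.u - u.(g^-T.b) == c - P.h - h.(g^-T.b)
   for u, which is the sylvester call above; the antisymmetrization
   keeps only the antisymmetric part, and x = g^-1.(h + u). *)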


(*Solve generalized T-Sylvester equation BXA+CX^T D=E by reducing to a 
T-Sylvester equation*)

generalizedTSylvester[a_, b_, c_, d_, 
   e_] := (tsylvester[Inverse[c].b, d.Inverse[a], 
    Inverse[c].e.Inverse[a]]);
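(* Reduction used above (sketch, assuming a and c are invertible):
   multiplying b.X.a + c.X^T.d == e by c^-1 on the left and a^-1 on the
   right gives (c^-1.b).X + X^T.(d.a^-1) == c^-1.e.a^-1, which is the
   T-Sylvester form that tsylvester solves. *)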

(*Solve generalized T-Sylvester equation with rank-1 
correction: BXA+CX^T D+U tr(X^T U)=Y*)

generalizedTSylvesterRank1[A_, B_, C_, D_, U_, Y_] := 
  Module[{X, divAU, divAX, Y2}, 
   divAU = generalizedTSylvester[A, B, C, D, U];
   divAX = generalizedTSylvester[A, B, C, D, Y];
   X = divAX - dot[U, divAX]/(1 + dot[U, divAU]) divAU;
   If[ValueQ[debug], Y2 = B.X.A + C.Transpose[X].D + U dot[U, X];
    Print["generalizedTSylvesterRank1 error: ", Norm[Y - Y2]]];
   X];
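(* Sherman-Morrison step (sketch): write L for the operator
   X -> B.X.A + C.X^T.D and <P, Q> for dot[P, Q]. From
   L[X] + U <U, X> == Y and linearity of L^-1,
   X == L^-1[Y] - <U, X> L^-1[U] == divAX - <U, X> divAU.
   Taking <U, .> of both sides gives
   <U, X> == <U, divAX>/(1 + <U, divAU>), which is the update above.
   It needs 1 + <U, divAU> != 0. *)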

(*Solve generalized T-Sylvester equation with rank-2 
correction: BXA+CX^T D+U tr(X^T U)-V tr(X^T V)=Y*)

generalizedTSylvesterRank2[A_, B_, C_, D_, U_, V_, Y_] := 
  Module[{divAU, divAX, X, Y2}, 
   divAU = generalizedTSylvesterRank1[A, B, C, D, U, V];
   divAX = generalizedTSylvesterRank1[A, B, C, D, U, Y];
   X = divAX + dot[V, divAX]/(1 - dot[V, divAU]) divAU;
   If[ValueQ[debug], 
    Y2 = B.X.A + C.Transpose[X].D + U dot[U, X] - V dot[V, X];
    Print["generalizedTSylvesterRank2 error: ", Norm[Y - Y2]]];
   X];
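(* Same trick with the sign flipped (sketch): write M for the rank-1
   corrected operator X -> B.X.A + C.X^T.D + U <U, X>, so the equation
   is M[X] - V <V, X> == Y. Then X == M^-1[Y] + <V, X> M^-1[V], and
   taking <V, .> gives <V, X> == <V, divAX>/(1 - <V, divAU>).
   Note that here divAU holds M^-1[V], not M^-1[U]. *)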



partitionMatrix[mat_, {a_, b_}] := 
  Module[{}, Assert[a + b == Length@mat];
   Assert[a + b == Length@Transpose[mat]];
   Internal`PartitionRagged[mat, {{a, b}, {a, b}}]];

setupProblem[d0_] := (d = d0;
   x = Array[xx, d];
   {a, b} = split[x];(*{a1,a2,...},{b1,b2,...}*)
   mu = RandomReal[{-1, 1}, {d}];
   diag = DiagonalMatrix@Table[1/k, {k, 1, d}];
   rot = RandomVariate[CircularRealMatrixDistribution[d]];
   sigma = rot.diag.Transpose[rot];
   dist = MultinormalDistribution[mu, sigma];
   X = RandomReal[{-1, 1}, {d/2, d/2}];);
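(* Test covariance (sketch of what setupProblem builds, assuming rot is
   a random orthogonal matrix): sigma = rot.diag.rot^T is then symmetric
   positive definite with eigenvalues 1, 1/2, ..., 1/d, so its
   condition number is d. *)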

(*Frobenius inner product: dot[P, Q] == Tr[Transpose[P].Q]*)
dot[mat1_, mat2_] := Total[mat1*mat2, 2];

(***Modify this to change problem size***)
SeedRandom[1];
setupProblem[1024]

{{AA, AB}, {BA, BB}} = 
  partitionMatrix[sigma + Outer[Times, mu, mu], {d/2, d/2}];
{A, B} = split[mu];

wicksForward[X_] := 
  BB.X.AA + Transpose[AB.X.AB] + BA dot[X, BA] - 
   2 Outer[Times, B, A] B.X.A;
wicksBackward[Y_] := 
  With[{BA0 = Sqrt[2] Outer[Times, B, A]}, 
   generalizedTSylvesterRank2[AA, BB, BA, BA, BA, BA0, Y]];
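(* Name mapping to the equations in the question (sketch): AA, BB, BA
   are the blocks A, B, C of E[x x^T] (AB == Transpose[BA] == C^T), and
   the vectors A, B here are the mean halves a-bar, b-bar.
   Transpose[AB.X.AB] equals BA.Transpose[X].BA = C.X^T.C, and
   BA0 dot[BA0, X] with BA0 = Sqrt[2] Outer[Times, B, A] equals
   2 (B.X.A) Outer[Times, B, A], so the equation handed to
   generalizedTSylvesterRank2 is exactly the one wicksForward evaluates. *)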

(*correctness check for small d*)
If[d < 8, On[Assert];
  Y1 = Transpose[Ex[(a\[CircleTimes]b).X.(a\[CircleTimes]b)]];
  Y2 = wicksForward[X];
  Assert[Y1 == Y2]];
Y = wicksForward[X];
Norm[wicksBackward[Y] - X] // Timing (*{12.1, 0.00005}*)
