diff --git a/compute_mca.py b/compute_mca.py index 8b99d7c0037b2428e3248a7d068665f50c7ad00a..0389a3496d853da3b859c7fb5038a565af90c3ac 100644 --- a/compute_mca.py +++ b/compute_mca.py @@ -5,7 +5,11 @@ import matplotlib.pyplot as plt ts = np.fromfile('ts.std', dtype=np.float32).reshape(1147, 28, 31) ps = np.fromfile('ps.std', dtype=np.float32).reshape(1147, 28, 31) -svd = supersvd(ts, ps, 4) +w = np.zeros((28, 31), dtype=np.float32) +lat = np.deg2rad(35 + 1.5 * np.arange(31)) +w[:, :] = np.cos(lat).reshape(1, -1) + +svd = supersvd(ts, ps, 4, WX=w, WY=w) first_mode_ts = svd.x_vect[0] diff --git a/main.py b/main.py index 05eba5fae6d95406b68f49dd1e325897688af3cb..d7f4d97362a45da89a45b88d391984a45ac1b42d 100644 --- a/main.py +++ b/main.py @@ -9,6 +9,8 @@ def main(): help="Data type, default is '%(default)s'") parser.add_argument("-x", metavar="X.STD", required=True, help="X data input file name") parser.add_argument("-y", metavar="Y.STD", required=True, help="Y data input file name") + parser.add_argument("-wx", metavar="WX.STD", required=False, help="X data weight file name") + parser.add_argument("-wy", metavar="WY.STD", required=False, help="Y data weight file name") parser.add_argument("-t", "--time", type=int, required=True, help="Length of the time interval") parser.add_argument("-k", type=int, default=3, help="Number of singular values, default is %(default)d") @@ -29,8 +31,14 @@ def main(): X = np.fromfile(args.x, dtype=dtype).reshape(t, -1) Y = np.fromfile(args.y, dtype=dtype).reshape(t, -1) + WX = None + if args.wx is not None: + WX = np.fromfile(args.wx, dtype=dtype).reshape(-1) + WY = None + if args.wy is not None: + WY = np.fromfile(args.wy, dtype=dtype).reshape(-1) - svd = supersvd(X, Y, args.k, args.elim_mean) + svd = supersvd(X, Y, args.k, args.elim_mean, WX, WY) if args.xv is not None: svd.x_vect.tofile(args.xv) diff --git a/supersvd.py b/supersvd.py index fefe47ac3854b3afdeb69550429b3e84817dba94..c14f927630ee96d2417626a3dca569800dda71ca 100644 --- a/supersvd.py +++ b/supersvd.py @@ -9,7 +9,7 @@ SuperSvdResult = namedtuple('SuperSvdResult', [ 'x_vect', 'y_vect', 'eigenvalue_fraction', 'eigenvalues', ]) -def supersvd(X, Y, k=3, eliminate_mean=True): +def supersvd(X, Y, k=3, eliminate_mean=True, WX=None, WY=None): """ X and Y - the input data for which correlation is seeked dim(X) = nT x nX @@ -27,12 +27,14 @@ def supersvd(X, Y, k=3, eliminate_mean=True): ||YC[e, :]||_2 = ||XC[e, :]||_2 = 1 for each e - XV and YV form an orthogonal basis, i.e. + XV and YV form an orthogonal basis w.r.t. weight, i.e. - sum_i XV[e, i] XV[e', i] = 0 - sum_j YV[e, j] YV[e', j] = 0 + sum_i WX[i] XV[e, i] XV[e', i] = 0 + sum_j WY[j] YV[e, j] YV[e', j] = 0 when e != e' + By default WX[i] = WY[i] = 1 + Returns XC: k x nT YC: k x nT @@ -54,6 +56,10 @@ def supersvd(X, Y, k=3, eliminate_mean=True): X = X - X.mean(axis=0) Y = Y - Y.mean(axis=0) + if WX is not None: + X = X * _np.sqrt(WX.reshape(1, -1)) + if WY is not None: + Y = Y * _np.sqrt(WY.reshape(1, -1)) # Norming makes eigenvalues ~O(1) COV = (X.T @ Y) / nT / (X.shape[1] * Y.shape[1])**0.25 @@ -93,6 +99,11 @@ def supersvd(X, Y, k=3, eliminate_mean=True): Xvar_frac /= Xvar Yvar_frac /= Yvar + if WX is not None: + XV = XV / _np.sqrt(WX.reshape(1, -1)) + if WY is not None: + YV = YV / _np.sqrt(WY.reshape(1, -1)) + return SuperSvdResult( x_coeff=XC, y_coeff=YC,