Skip to content
Snippets Groups Projects
Commit 737ddaef authored by Ivan Tsybulin's avatar Ivan Tsybulin Committed by Maria Tarasevich
Browse files

Weighted fields

parent e4340429
No related branches found
No related tags found
1 merge request!2Weighted fields
......@@ -5,7 +5,11 @@ import matplotlib.pyplot as plt
ts = np.fromfile('ts.std', dtype=np.float32).reshape(1147, 28, 31)
ps = np.fromfile('ps.std', dtype=np.float32).reshape(1147, 28, 31)
svd = supersvd(ts, ps, 4)
w = np.zeros((28, 31), dtype=np.float32)
lat = np.deg2rad(35 + 1.5 * np.arange(31))
w[:, :] = np.cos(lat).reshape(1, -1)
svd = supersvd(ts, ps, 4, WX=w, WY=w)
first_mode_ts = svd.x_vect[0]
......
......@@ -9,6 +9,8 @@ def main():
help="Data type, default is '%(default)s'")
parser.add_argument("-x", metavar="X.STD", required=True, help="X data input file name")
parser.add_argument("-y", metavar="Y.STD", required=True, help="Y data input file name")
parser.add_argument("-wx", metavar="WX.STD", required=False, help="X data weight file name")
parser.add_argument("-wy", metavar="WY.STD", required=False, help="Y data weight file name")
parser.add_argument("-t", "--time", type=int, required=True, help="Length of the time interval")
parser.add_argument("-k", type=int, default=3,
help="Number of singular values, default is %(default)d")
......@@ -29,8 +31,14 @@ def main():
X = np.fromfile(args.x, dtype=dtype).reshape(t, -1)
Y = np.fromfile(args.y, dtype=dtype).reshape(t, -1)
WX = None
if args.wx is not None:
WX = np.fromfile(args.wx, dtype=dtype).reshape(-1)
WY = None
if args.wy is not None:
WY = np.fromfile(args.wy, dtype=dtype).reshape(-1)
svd = supersvd(X, Y, args.k, args.elim_mean)
svd = supersvd(X, Y, args.k, args.elim_mean, WX, WY)
if args.xv is not None:
svd.x_vect.tofile(args.xv)
......
......@@ -9,7 +9,7 @@ SuperSvdResult = namedtuple('SuperSvdResult', [
'x_vect', 'y_vect', 'eigenvalue_fraction', 'eigenvalues',
])
def supersvd(X, Y, k=3, eliminate_mean=True):
def supersvd(X, Y, k=3, eliminate_mean=True, WX=None, WY=None):
"""
X and Y - the input data for which correlation is seeked
dim(X) = nT x nX
......@@ -27,12 +27,14 @@ def supersvd(X, Y, k=3, eliminate_mean=True):
||YC[e, :]||_2 = ||XC[e, :]||_2 = 1 for each e
XV and YV form an orthogonal basis, i.e.
XV and YV form an orthogonal basis w.r.t. weight, i.e.
sum_i XV[e, i] XV[e', i] = 0
sum_j YV[e, j] YV[e', j] = 0
sum_i WX[i] XV[e, i] XV[e', i] = 0
sum_j WY[j] YV[e, j] YV[e', j] = 0
when e != e'
By default WX[i] = WY[i] = 1
Returns
XC: k x nT
YC: k x nT
......@@ -54,6 +56,10 @@ def supersvd(X, Y, k=3, eliminate_mean=True):
X = X - X.mean(axis=0)
Y = Y - Y.mean(axis=0)
if WX is not None:
X = X * _np.sqrt(WX.reshape(1, -1))
if WY is not None:
Y = Y * _np.sqrt(WY.reshape(1, -1))
# Norming makes eigenvalues ~O(1)
COV = (X.T @ Y) / nT / (X.shape[1] * Y.shape[1])**0.25
......@@ -93,6 +99,11 @@ def supersvd(X, Y, k=3, eliminate_mean=True):
Xvar_frac /= Xvar
Yvar_frac /= Yvar
if WX is not None:
XV = XV / _np.sqrt(WX.reshape(1, -1))
if WY is not None:
YV = YV / _np.sqrt(WY.reshape(1, -1))
return SuperSvdResult(
x_coeff=XC,
y_coeff=YC,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment