Skip to content
Snippets Groups Projects
Commit b7264c53 authored by Maria Tarasevich's avatar Maria Tarasevich
Browse files

Merge branch 'weighted-fields' into 'master'

Weighted fields

See merge request !2
parents e4340429 737ddaef
No related branches found
No related tags found
1 merge request!2Weighted fields
...@@ -5,7 +5,11 @@ import matplotlib.pyplot as plt ...@@ -5,7 +5,11 @@ import matplotlib.pyplot as plt
ts = np.fromfile('ts.std', dtype=np.float32).reshape(1147, 28, 31) ts = np.fromfile('ts.std', dtype=np.float32).reshape(1147, 28, 31)
ps = np.fromfile('ps.std', dtype=np.float32).reshape(1147, 28, 31) ps = np.fromfile('ps.std', dtype=np.float32).reshape(1147, 28, 31)
svd = supersvd(ts, ps, 4) w = np.zeros((28, 31), dtype=np.float32)
lat = np.deg2rad(35 + 1.5 * np.arange(31))
w[:, :] = np.cos(lat).reshape(1, -1)
svd = supersvd(ts, ps, 4, WX=w, WY=w)
first_mode_ts = svd.x_vect[0] first_mode_ts = svd.x_vect[0]
......
...@@ -9,6 +9,8 @@ def main(): ...@@ -9,6 +9,8 @@ def main():
help="Data type, default is '%(default)s'") help="Data type, default is '%(default)s'")
parser.add_argument("-x", metavar="X.STD", required=True, help="X data input file name") parser.add_argument("-x", metavar="X.STD", required=True, help="X data input file name")
parser.add_argument("-y", metavar="Y.STD", required=True, help="Y data input file name") parser.add_argument("-y", metavar="Y.STD", required=True, help="Y data input file name")
parser.add_argument("-wx", metavar="WX.STD", required=False, help="X data weight file name")
parser.add_argument("-wy", metavar="WY.STD", required=False, help="Y data weight file name")
parser.add_argument("-t", "--time", type=int, required=True, help="Length of the time interval") parser.add_argument("-t", "--time", type=int, required=True, help="Length of the time interval")
parser.add_argument("-k", type=int, default=3, parser.add_argument("-k", type=int, default=3,
help="Number of singular values, default is %(default)d") help="Number of singular values, default is %(default)d")
...@@ -29,8 +31,14 @@ def main(): ...@@ -29,8 +31,14 @@ def main():
X = np.fromfile(args.x, dtype=dtype).reshape(t, -1) X = np.fromfile(args.x, dtype=dtype).reshape(t, -1)
Y = np.fromfile(args.y, dtype=dtype).reshape(t, -1) Y = np.fromfile(args.y, dtype=dtype).reshape(t, -1)
WX = None
if args.wx is not None:
WX = np.fromfile(args.wx, dtype=dtype).reshape(-1)
WY = None
if args.wy is not None:
WY = np.fromfile(args.wy, dtype=dtype).reshape(-1)
svd = supersvd(X, Y, args.k, args.elim_mean) svd = supersvd(X, Y, args.k, args.elim_mean, WX, WY)
if args.xv is not None: if args.xv is not None:
svd.x_vect.tofile(args.xv) svd.x_vect.tofile(args.xv)
......
...@@ -9,7 +9,7 @@ SuperSvdResult = namedtuple('SuperSvdResult', [ ...@@ -9,7 +9,7 @@ SuperSvdResult = namedtuple('SuperSvdResult', [
'x_vect', 'y_vect', 'eigenvalue_fraction', 'eigenvalues', 'x_vect', 'y_vect', 'eigenvalue_fraction', 'eigenvalues',
]) ])
def supersvd(X, Y, k=3, eliminate_mean=True): def supersvd(X, Y, k=3, eliminate_mean=True, WX=None, WY=None):
""" """
X and Y - the input data for which correlation is seeked X and Y - the input data for which correlation is seeked
dim(X) = nT x nX dim(X) = nT x nX
...@@ -27,12 +27,14 @@ def supersvd(X, Y, k=3, eliminate_mean=True): ...@@ -27,12 +27,14 @@ def supersvd(X, Y, k=3, eliminate_mean=True):
||YC[e, :]||_2 = ||XC[e, :]||_2 = 1 for each e ||YC[e, :]||_2 = ||XC[e, :]||_2 = 1 for each e
XV and YV form an orthogonal basis, i.e. XV and YV form an orthogonal basis w.r.t. weight, i.e.
sum_i XV[e, i] XV[e', i] = 0 sum_i WX[i] XV[e, i] XV[e', i] = 0
sum_j YV[e, j] YV[e', j] = 0 sum_j WY[j] YV[e, j] YV[e', j] = 0
when e != e' when e != e'
By default WX[i] = WY[i] = 1
Returns Returns
XC: k x nT XC: k x nT
YC: k x nT YC: k x nT
...@@ -54,6 +56,10 @@ def supersvd(X, Y, k=3, eliminate_mean=True): ...@@ -54,6 +56,10 @@ def supersvd(X, Y, k=3, eliminate_mean=True):
X = X - X.mean(axis=0) X = X - X.mean(axis=0)
Y = Y - Y.mean(axis=0) Y = Y - Y.mean(axis=0)
if WX is not None:
X = X * _np.sqrt(WX.reshape(1, -1))
if WY is not None:
Y = Y * _np.sqrt(WY.reshape(1, -1))
# Norming makes eigenvalues ~O(1) # Norming makes eigenvalues ~O(1)
COV = (X.T @ Y) / nT / (X.shape[1] * Y.shape[1])**0.25 COV = (X.T @ Y) / nT / (X.shape[1] * Y.shape[1])**0.25
...@@ -93,6 +99,11 @@ def supersvd(X, Y, k=3, eliminate_mean=True): ...@@ -93,6 +99,11 @@ def supersvd(X, Y, k=3, eliminate_mean=True):
Xvar_frac /= Xvar Xvar_frac /= Xvar
Yvar_frac /= Yvar Yvar_frac /= Yvar
if WX is not None:
XV = XV / _np.sqrt(WX.reshape(1, -1))
if WY is not None:
YV = YV / _np.sqrt(WY.reshape(1, -1))
return SuperSvdResult( return SuperSvdResult(
x_coeff=XC, x_coeff=XC,
y_coeff=YC, y_coeff=YC,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment