diff --git a/compute_mca.py b/compute_mca.py
index 8b99d7c0037b2428e3248a7d068665f50c7ad00a..0389a3496d853da3b859c7fb5038a565af90c3ac 100644
--- a/compute_mca.py
+++ b/compute_mca.py
@@ -5,7 +5,11 @@ import matplotlib.pyplot as plt
 ts = np.fromfile('ts.std', dtype=np.float32).reshape(1147, 28, 31)
 ps = np.fromfile('ps.std', dtype=np.float32).reshape(1147, 28, 31)
 
-svd = supersvd(ts, ps, 4)
+w = np.zeros((28, 31), dtype=np.float32)
+lat = np.deg2rad(35 + 1.5 * np.arange(31))
+w[:, :] = np.cos(lat).reshape(1, -1)
+
+svd = supersvd(ts, ps, 4, WX=w, WY=w)
 
 first_mode_ts = svd.x_vect[0]
 
diff --git a/main.py b/main.py
index 05eba5fae6d95406b68f49dd1e325897688af3cb..d7f4d97362a45da89a45b88d391984a45ac1b42d 100644
--- a/main.py
+++ b/main.py
@@ -9,6 +9,8 @@ def main():
                         help="Data type, default is '%(default)s'")
     parser.add_argument("-x", metavar="X.STD", required=True, help="X data input file name")
     parser.add_argument("-y", metavar="Y.STD", required=True, help="Y data input file name")
+    parser.add_argument("-wx", metavar="WX.STD", required=False, help="X data weight file name")
+    parser.add_argument("-wy", metavar="WY.STD", required=False, help="Y data weight file name")
     parser.add_argument("-t", "--time", type=int, required=True, help="Length of the time interval")
     parser.add_argument("-k", type=int, default=3,
                         help="Number of singular values, default is %(default)d")
@@ -29,8 +31,14 @@ def main():
 
     X = np.fromfile(args.x, dtype=dtype).reshape(t, -1)
     Y = np.fromfile(args.y, dtype=dtype).reshape(t, -1)
+    WX = None
+    if args.wx is not None:
+        WX = np.fromfile(args.wx, dtype=dtype).reshape(-1)
+    WY = None
+    if args.wy is not None:
+        WY = np.fromfile(args.wy, dtype=dtype).reshape(-1)
 
-    svd = supersvd(X, Y, args.k, args.elim_mean)
+    svd = supersvd(X, Y, args.k, args.elim_mean, WX, WY)
     
     if args.xv is not None:
         svd.x_vect.tofile(args.xv)
diff --git a/supersvd.py b/supersvd.py
index fefe47ac3854b3afdeb69550429b3e84817dba94..c14f927630ee96d2417626a3dca569800dda71ca 100644
--- a/supersvd.py
+++ b/supersvd.py
@@ -9,7 +9,7 @@ SuperSvdResult = namedtuple('SuperSvdResult', [
     'x_vect', 'y_vect', 'eigenvalue_fraction', 'eigenvalues',
     ])
 
-def supersvd(X, Y, k=3, eliminate_mean=True):
+def supersvd(X, Y, k=3, eliminate_mean=True, WX=None, WY=None):
     """
     X and Y - the input data for which correlation is seeked
     dim(X) = nT x nX
@@ -27,12 +27,14 @@ def supersvd(X, Y, k=3, eliminate_mean=True):
 
     ||YC[e, :]||_2 = ||XC[e, :]||_2 = 1 for each e
 
-    XV and YV form an orthogonal basis, i.e.
+    XV and YV form an orthogonal basis w.r.t. weight, i.e.
 
-    sum_i XV[e, i] XV[e', i] = 0 
-    sum_j YV[e, j] YV[e', j] = 0
+    sum_i WX[i] XV[e, i] XV[e', i] = 0
+    sum_j WY[j] YV[e, j] YV[e', j] = 0
     when e != e'
 
+    By default WX[i] = WY[i] = 1
+
     Returns 
         XC: k x nT
         YC: k x nT
@@ -54,6 +56,10 @@ def supersvd(X, Y, k=3, eliminate_mean=True):
         X = X - X.mean(axis=0)
         Y = Y - Y.mean(axis=0)
 
+    if WX is not None:
+        X = X * _np.sqrt(WX.reshape(1, -1))
+    if WY is not None:
+        Y = Y * _np.sqrt(WY.reshape(1, -1))
     # Norming makes eigenvalues ~O(1)
     COV = (X.T @ Y) / nT / (X.shape[1] * Y.shape[1])**0.25
 
@@ -93,6 +99,11 @@ def supersvd(X, Y, k=3, eliminate_mean=True):
     Xvar_frac /= Xvar
     Yvar_frac /= Yvar
 
+    if WX is not None:
+        XV = XV / _np.sqrt(WX.reshape(1, -1))
+    if WY is not None:
+        YV = YV / _np.sqrt(WY.reshape(1, -1))
+
     return SuperSvdResult(
         x_coeff=XC,
         y_coeff=YC,