| import numpy as np
|
| from tqdm import tqdm
|
| import rasterio
|
| from rasterio.windows import Window
|
| import os
|
| from glob import glob
|
|
|
def compute_band_statistics(
    root,
    locations,
    modalities=("S1", "S2", "DEM", "Hillshade", "Cloudmask"),
    patch_size=256,
    stride=256,
):
    """Compute per-band mean and std for each modality over the given locations.

    Every ``*.tif`` directly inside ``os.path.join(root, loc)`` is assigned to
    a modality by its filename suffix (``_s1.tif``, ``_s2.tif``, ``_dem.tif``,
    ``_hillshade.tif``, ``_cloud_mask.tif``). Files with any other suffix, or
    whose modality was not requested, are skipped. Matching rasters are read
    in ``patch_size`` x ``patch_size`` windows stepped by ``stride`` and the
    per-band sum, sum of squares and valid-pixel count are accumulated in
    64-bit precision. Progress is reported with tqdm progress bars.

    Non-finite pixels (NaN and +/-inf) are excluded from the statistics.
    NOTE: the raster's declared nodata value is not consulted, so a finite
    nodata sentinel is counted as regular data.

    Args:
        root: Directory that contains one sub-directory per location.
        locations: Iterable of location sub-directory names to include
            (pass only the training locations to avoid leaking statistics).
        modalities: Modality keys to compute statistics for. Keys without a
            known filename suffix never match a file and are omitted from
            the result.
        patch_size: Side length, in pixels, of the square read window.
        stride: Step, in pixels, between consecutive windows. Only windows
            that fit entirely inside the raster are read, so trailing rows
            and columns that do not fill a whole patch are ignored, and
            rasters smaller than one patch contribute nothing. With
            ``stride < patch_size`` the overlapping pixels are counted once
            per window that covers them.

    Returns:
        Dict mapping each modality that had at least one valid pixel to
        ``{"mean": [...], "std": [...]}`` with one float per band. ``std``
        is the population standard deviation. A band with no valid pixels
        (while other bands of the same modality have some) is reported as
        NaN so that list positions still line up with band indices.

    Raises:
        ValueError: If two contributing files of the same modality have a
            different number of bands.
    """
    suffix_to_modality = {
        "_s2.tif": "S2",
        "_s1.tif": "S1",
        "_dem.tif": "DEM",
        "_hillshade.tif": "Hillshade",
        "_cloud_mask.tif": "Cloudmask",
    }
    # Resolve the requested modalities once instead of re-testing per file.
    active_suffixes = {
        suffix: key
        for suffix, key in suffix_to_modality.items()
        if key in modalities
    }

    # Accumulators are allocated lazily because the band count of a modality
    # is only known once its first contributing file has been opened.
    sums = {m: None for m in modalities}
    sq_sums = {m: None for m in modalities}
    counts = {m: None for m in modalities}

    for loc in tqdm(locations, desc="Locations"):
        loc_dir = os.path.join(root, loc)
        files = sorted(glob(os.path.join(loc_dir, "*.tif")))

        for path in tqdm(files, desc=f"Files in {loc}", leave=False):
            key = next(
                (
                    modality
                    for suffix, modality in active_suffixes.items()
                    if path.endswith(suffix)
                ),
                None,
            )
            if key is None:
                continue

            with rasterio.open(path) as src:
                n_bands = src.count
                row_starts = range(0, src.height - patch_size + 1, stride)
                col_starts = range(0, src.width - patch_size + 1, stride)

                # A raster smaller than one patch yields no windows; skip it
                # before it can fix the band count or allocate accumulators.
                if not row_starts or not col_starts:
                    continue

                if sums[key] is None:
                    sums[key] = np.zeros(n_bands, dtype=np.float64)
                    sq_sums[key] = np.zeros(n_bands, dtype=np.float64)
                    counts[key] = np.zeros(n_bands, dtype=np.int64)
                elif sums[key].shape[0] != n_bands:
                    # Fail with a message naming the file instead of an
                    # opaque NumPy broadcasting error in the loop below.
                    raise ValueError(
                        f"Band count mismatch for modality {key!r}: "
                        f"expected {sums[key].shape[0]} bands, "
                        f"but {path} has {n_bands}"
                    )

                for y in row_starts:
                    for x in col_starts:
                        window = Window(x, y, patch_size, patch_size)
                        patch = src.read(window=window).astype(np.float64)
                        # Flatten the spatial dims: one row of pixels per band.
                        patch = patch.reshape(n_bands, -1)

                        valid_mask = np.isfinite(patch)
                        # Zero out invalid pixels so they add nothing to
                        # either sum; the count below excludes them too.
                        valid_values = np.where(valid_mask, patch, 0.0)

                        sums[key] += valid_values.sum(axis=1)
                        sq_sums[key] += np.square(valid_values).sum(axis=1)
                        counts[key] += valid_mask.sum(axis=1)

    band_stats = {}

    for m in modalities:
        n_valid = counts[m]
        if n_valid is None or not n_valid.any():
            continue

        # Bands with zero valid pixels divide 0 by 0; the NaN result is
        # intentional (see docstring), so silence the warning here only.
        with np.errstate(divide="ignore", invalid="ignore"):
            mean = sums[m] / n_valid
            sq_mean = sq_sums[m] / n_valid
            # Var = E[x^2] - E[x]^2; rounding can push it slightly below
            # zero, so clamp before the square root (NaN is propagated).
            var = np.maximum(sq_mean - mean ** 2, 0.0)
            std = np.sqrt(var)

        band_stats[m] = {
            "mean": mean.tolist(),
            "std": std.tolist(),
        }

    return band_stats
|
|
|