Skip to content

RMS Permutation Demo


Importing Packages

import numpy as np
import pandas as pd
from src.mheatmap import (
    amc_postprocess, 
    rms_permute, 
    mosaic_heatmap
)
import matplotlib.pyplot as plt
import scipy
import warnings

Load Data

  • Load the ground truth labels
  • Salinas_gt.mat: Ground truth labels for Salinas dataset
  • Load the predicted labels from spectral clustering
# Load the data
# Import the submodule explicitly: a bare `import scipy` does not guarantee
# that `scipy.io` is available as an attribute on every SciPy release.
import scipy.io

# Ground-truth label map from the .mat file, flattened to a 1-D vector.
y_true = scipy.io.loadmat("data/Salinas_gt.mat")["salinas_gt"].reshape(-1)

# Load predicted labels from spectral clustering
# The CSV is read without a header row and its first row is then dropped
# (presumably a header/index line in the file -- TODO confirm). The remaining
# cells are flattened to 1-D; `.flatten()` already returns a fresh ndarray, so
# no extra `np.array(...)` copy is needed.
y_pred = (
    pd.read_csv(
        "data/Salinas_spectralclustering.csv",
        header=None,
        low_memory=False,
    )
    .values[1:]
    .flatten()
)

# Sanity check: both label vectors should have the same number of entries.
print(f"y_true shape: {y_true.shape}")
print(f"y_pred shape: {len(y_pred)}")
    y_true shape: (111104,)
    y_pred shape: 111104

AMC Post-processing

  • Alignment with Hungarian algorithm
  • Masking the zeros (unlabeled pixels) with mask_zeros_from_gt
  • Computing the confusion matrix

See AMC Post-processing for more details.

# AMC post-processing
# As described in this demo's notes, this step aligns the predicted labels to
# the ground truth (Hungarian algorithm), masks the zero (unlabeled) pixels and
# computes the confusion matrix. The first returned value is not used in this
# demo and is discarded; `conf_mat` and `labels` feed the steps below.
_, conf_mat, labels = amc_postprocess(y_pred, y_true)

RMS Permutation

  • Reverse Merge/Split Idea:
  • Merge: \(GT0, GT1 \rightarrow PRD0, PRD0\)
  • Split: \(GT0, GT0 \rightarrow PRD0, PRD1\)
  • Both merges and splits affect the OA (overall accuracy) and AA (average accuracy) metrics but not the ARI (adjusted Rand index); this is a resolution issue
# Demonstrate RMS permutation analysis
# RuntimeWarnings emitted while rms_permute runs are silenced; the previous
# warning filters are restored as soon as the `with` block exits.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)
    (
        rms_C,           # permuted confusion matrix (plotted as "RMS Permuted")
        rms_labels,      # tick labels matching rms_C
        _,               # third return value is not used in this demo
        rms_map_matrix,  # matrix plotted as the "RMS Matrix" heatmap
        rms_map_type,    # row labels for that heatmap
    ) = rms_permute(conf_mat, labels)

Visualize the results

# Visualize original vs RMS permuted matrices
# Two panels side by side: the confusion matrix as computed (left) and the
# RMS-permuted one (right). Both share the colormap and the gray text color.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(15, 6))

# (axes, matrix, tick labels, title, move x ticks to the top?)
for panel_ax, panel_matrix, panel_labels, panel_title, ticks_on_top in (
    (ax1, conf_mat, labels, "Original", False),
    (ax2, rms_C, rms_labels, "RMS Permuted", True),
):
    mosaic_heatmap(
        panel_matrix,
        ax=panel_ax,
        xticklabels=panel_labels,
        yticklabels=panel_labels,
        cmap="YlGnBu",
    )
    panel_ax.set_title(panel_title, fontsize=18, color='#4A4A4A')  # Medium gray
    if ticks_on_top:
        # Only the permuted panel explicitly moves its x ticks to the top edge.
        panel_ax.xaxis.set_ticks_position('top')
    panel_ax.tick_params(colors='#4A4A4A')

plt.tight_layout()
plt.show()

png

RMS Matrix Visualization

import seaborn as sns

# Draw the RMS map matrix as an annotated heatmap and save it to disk.
# Columns are labelled GT1, GT2, PRED1, PRED2; rows are labelled with the
# entries of `rms_map_type` (presumably one merge/split mapping per row --
# TODO confirm against rms_permute).
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(
    rms_map_matrix,
    annot=True,
    annot_kws={"size": 18},
    cbar=False,
    cmap="YlGnBu",
    xticklabels=['GT1', 'GT2', 'PRED1', 'PRED2'],
    yticklabels=rms_map_type,
    # Target the axes created above explicitly instead of relying on
    # pyplot's implicit "current axes" state.
    ax=ax,
)
ax.set_title("RMS Matrix", fontsize=18, color='#4A4A4A')  # Medium gray
ax.tick_params(colors='#4A4A4A', axis='both', which='major', labelsize=18)
# Vertical rule at x=2 separates the two GT columns from the two PRED columns.
ax.axvline(x=2, color='black', linewidth=2)
fig.tight_layout()
# The image file is written before the figure is displayed.
fig.savefig("rms_matrix.png", dpi=300, transparent=True)
plt.show()

png