RMS Permutation Demo
Importing Packages
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import scipy.io
import seaborn as sns

from src.mheatmap import (
    amc_postprocess,
    rms_permute,
    mosaic_heatmap,
)
Load Data
- Load the ground truth labels from `Salinas_gt.mat` (ground truth labels for the Salinas dataset).
- Load the predicted labels produced by spectral clustering.
# Load the ground-truth labels (Salinas_gt.mat) and flatten them into a
# 1-D vector of per-pixel labels.
y_true = scipy.io.loadmat("data/Salinas_gt.mat")["salinas_gt"].reshape(-1)

# Load the predicted labels from spectral clustering. The first row of the
# CSV is discarded (presumably a header line -- TODO confirm against the file).
# Skipping it at parse time with ``skiprows=1``, instead of slicing it off
# after parsing, keeps a non-numeric header from forcing the whole column to
# be parsed as strings, so the labels can be read as numbers like y_true.
y_pred = (
    pd.read_csv(
        "data/Salinas_spectralclustering.csv",
        header=None,
        skiprows=1,
        low_memory=False,
    )
    .to_numpy()
    .flatten()
)

print(f"y_true shape: {y_true.shape}")
print(f"y_pred shape: {y_pred.shape}")

# Ground-truth and predicted labels are paired pixel-by-pixel to build the
# confusion matrix (see AMC Post-processing), so fail early with a clear
# message if the two vectors do not line up.
if y_true.shape != y_pred.shape:
    raise ValueError(
        f"y_true and y_pred differ in shape: {y_true.shape} vs {y_pred.shape}"
    )
AMC Post-processing
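- Alignment of predicted labels to ground truth with the Hungarian algorithm
- Masking the zeros (unlabeled pixels) with `mask_zeros_from_gt`
- Computing the confusion matrix

This step produces the confusion matrix `conf_mat` and its class `labels`, which the RMS permutation below uses (note: the code cell for this step does not appear here). See AMC Post-processing for more details.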
RMS Permutation
- Reverse Merge/Split idea:
  - Merge: \(GT0, GT1 \rightarrow PRD0, PRD0\)
  - Split: \(GT0, GT0 \rightarrow PRD0, PRD1\)
- These patterns affect the OA and AA metrics but not ARI (a resolution issue).
# Run the reverse merge/split (RMS) permutation on the confusion matrix.
# Any RuntimeWarning emitted during the call is suppressed; the filter is
# scoped by catch_warnings so it does not leak into the rest of the script.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)
    rms_result = rms_permute(conf_mat, labels)

# Permuted matrix, its tick labels, (unused third output), the RMS map
# matrix and the per-row map types consumed by the plots below.
rms_C, rms_labels, _, rms_map_matrix, rms_map_type = rms_result
Visualize the results
# Draw the original and the RMS-permuted confusion matrices side by side.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
panel_color = "#4A4A4A"  # medium gray for titles and tick labels

panels = (
    (ax1, conf_mat, labels, "Original"),
    (ax2, rms_C, rms_labels, "RMS Permuted"),
)
for panel_ax, matrix, tick_labels, title in panels:
    mosaic_heatmap(
        matrix,
        ax=panel_ax,
        xticklabels=tick_labels,
        yticklabels=tick_labels,
        cmap="YlGnBu",
    )
    panel_ax.set_title(title, fontsize=18, color=panel_color)
    if panel_ax is ax2:
        # Only the permuted panel shows its x tick labels along the top edge.
        panel_ax.xaxis.set_ticks_position("top")
    panel_ax.tick_params(colors=panel_color)

plt.tight_layout()
plt.show()
RMS Matrix Visualization
# Plot the RMS map matrix returned by rms_permute as an annotated heatmap.
# Column labels are hard-coded as GT1, GT2, PRED1, PRED2, so this assumes
# rms_map_matrix has exactly four columns; rows are labelled by rms_map_type.
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(
    rms_map_matrix,
    ax=ax,  # draw on this figure's axes explicitly, not the implicit current axes
    annot=True,
    annot_kws={"size": 18},
    cbar=False,
    cmap="YlGnBu",
    xticklabels=["GT1", "GT2", "PRED1", "PRED2"],
    yticklabels=rms_map_type,
)
ax.set_title("RMS Matrix", fontsize=18, color="#4A4A4A")  # Medium gray
ax.tick_params(colors="#4A4A4A", axis="both", which="major", labelsize=18)
# Vertical separator between the two GT columns and the two PRED columns.
ax.axvline(x=2, color="black", linewidth=2)
fig.tight_layout()
fig.savefig("rms_matrix.png", dpi=300, transparent=True)
plt.show()