import numpy as np
import plotly.graph_objects as go
from scipy.stats import multivariate_normal
# --- Data preparation (df0 is assumed to be loaded earlier) ---
# X holds the two selected columns as a 2-D array. The animation colours the
# points by their current EM assignment, not by a df0['color'] column.
X = df0[['eruptions', 'waiting']].values

# --- Initialisation ---
np.random.seed(42)  # fixes the random draw below
K = 2
n_samples, n_features = X.shape

# Start the two centroids at one randomly drawn observation and at the row
# 20 positions before it. NOTE: this is NOT a min/max ("far apart")
# initialisation; the two starting points are simply 20 rows apart.
ids = np.random.choice(n_samples, 1, replace=False)
# The modulo makes the wrap-around explicit when ids[0] < 20 (it selects the
# same row that negative indexing selects) and keeps the index valid when the
# dataset has fewer than 20 rows.
partner_id = (ids[0] - 20) % n_samples
mu = np.array([X[ids[0], :], X[partner_id, :]])

# Every component starts with the covariance of the whole dataset and with
# equal mixing weights.
global_cov = np.cov(X.T)
cov = [global_cov.copy() for _ in range(K)]
pi = np.array([0.5, 0.5])
# Evaluation grid for the density surface: 60 points per axis, spanning the
# observed range padded by 0.5 (first feature) and by 5 (second feature).
first_feature, second_feature = X[:, 0], X[:, 1]
x_range = np.linspace(first_feature.min() - 0.5, first_feature.max() + 0.5, num=60)
y_range = np.linspace(second_feature.min() - 5, second_feature.max() + 5, num=60)
xx, yy = np.meshgrid(x_range, y_range)
# The last axis holds the (x, y) coordinate of each grid node.
pos = np.stack((xx, yy), axis=-1)
def get_mixture_density(pi, mu, cov, grid_pos):
    """Evaluate a Gaussian mixture density on a grid of points.

    Parameters
    ----------
    pi : sequence of float
        Mixing weight of each component.
    mu : sequence of array-like
        Mean vector of each component.
    cov : sequence of array-like
        Covariance matrix of each component.
    grid_pos : numpy.ndarray
        Evaluation points with the coordinates on the last axis, e.g. an
        array of shape ``(rows, cols, 2)``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``grid_pos.shape[:2]`` holding the weighted sum of the
        component densities at every grid point.
    """
    z = np.zeros(grid_pos.shape[:2])
    # Iterate over the components that were actually passed in rather than
    # over a module-level K, so the function works for any mixture size.
    for weight, mean, covariance in zip(pi, mu, cov):
        z += weight * multivariate_normal(mean, covariance).pdf(grid_pos)
    return z
# EM iterations. Each pass captures one animation frame from the parameters
# as they stand before that pass's M-step.
iterations = 15
frames = []
for i in range(iterations):
    # ---- E-step: weighted component densities, one column per component ----
    probs = np.column_stack(
        [pi[k] * multivariate_normal(mu[k], cov[k]).pdf(X) for k in range(K)]
    )
    # The small constant keeps the division finite if a row sums to zero.
    row_totals = probs.sum(axis=1, keepdims=True) + 1e-10
    resp = probs / row_totals

    # Colour every observation by the component with the largest
    # responsibility (component 0 -> red, otherwise blue).
    current_labels = resp.argmax(axis=1)
    current_colors = np.where(current_labels != 0, 'blue', 'red')

    # Parameter read-out shown as an annotation on this frame.
    param_text = (
        f"<b>Iteration {i}</b><br>"
        f"π: [{pi[0]:.2f}, {pi[1]:.2f}]<br>"
        f"μ1: [{mu[0][0]:.2f}, {mu[0][1]:.2f}]<br>"
        f"μ2: [{mu[1][0]:.2f}, {mu[1][1]:.2f}]"
    )

    z_total = get_mixture_density(pi, mu, cov, pos)

    # Trace 0: mixture density surface.
    density_trace = go.Surface(z=z_total, x=xx, y=yy)
    # Trace 1: observations on the z = 0 plane, coloured by current label.
    observation_trace = go.Scatter3d(
        x=df0['eruptions'],
        y=df0['waiting'],
        z=np.zeros(n_samples),
        mode='markers',
        marker=dict(size=4, color=current_colors, opacity=0.8),
    )
    # Trace 2: current centroids on the z = 0 plane.
    centroid_trace = go.Scatter3d(
        x=mu[:, 0],
        y=mu[:, 1],
        z=[0, 0],
        mode='markers',
        marker=dict(
            size=10, color='black', symbol='diamond',
            line=dict(width=2, color='white'),
        ),
        name='Centroids',
    )
    param_annotation = dict(
        text=param_text, align='left', showarrow=False,
        x=0.05, y=0.95, xref='paper', yref='paper',
        bgcolor="rgba(255, 255, 255, 0.7)", bordercolor="black", borderwidth=1,
    )
    frames.append(
        go.Frame(
            data=[density_trace, observation_trace, centroid_trace],
            name=f"Iteration {i}",
            layout=go.Layout(annotations=[param_annotation]),
        )
    )

    # ---- M-step: re-estimate mean, covariance and weight per component ----
    N_k = resp.sum(axis=0)
    for k in range(K):
        weights = resp[:, k, np.newaxis]
        mu[k] = (weights * X).sum(axis=0) / N_k[k]
        # The covariance is centred on the freshly updated mean.
        diff = X - mu[k]
        outer_products = np.einsum('ni,nj->nij', diff, diff)
        weighted_diff = weights[:, :, np.newaxis] * outer_products
        cov[k] = weighted_diff.sum(axis=0) / N_k[k]
        pi[k] = N_k[k] / n_samples
# Build the figure. The initial traces mirror frame 0 so that nothing changes
# appearance when the animation starts.
first_frame = frames[0]

# Animation controls.
play_button = {
    "args": [None, {"frame": {"duration": 400, "redraw": True}, "fromcurrent": True}],
    "label": "Play", "method": "animate",
}
pause_button = {
    "args": [[None], {"frame": {"duration": 0, "redraw": True}, "mode": "immediate"}],
    "label": "Pause", "method": "animate",
}
# One slider step per captured frame, labelled with the frame's position.
slider_steps = [
    {
        "args": [[frame.name], {"frame": {"duration": 0, "redraw": True}, "mode": "immediate"}],
        "label": str(step_index), "method": "animate",
    }
    for step_index, frame in enumerate(frames)
]

fig = go.Figure(
    data=[
        go.Surface(
            z=first_frame.data[0].z, x=xx, y=yy,
            opacity=0.6, colorscale='Viridis', showscale=False, name='Density'
        ),
        go.Scatter3d(
            x=df0['eruptions'].values, y=df0['waiting'].values, z=np.zeros(n_samples),
            mode='markers',
            marker=dict(size=4, color=first_frame.data[1].marker.color, opacity=0.8),
            name='Observations'
        ),
        go.Scatter3d(
            x=first_frame.data[2].x, y=first_frame.data[2].y, z=[0, 0],
            mode='markers',
            # Take the size from frame 0 (previously a hard-coded 8) so the
            # centroids do not change size when the first frame is applied.
            marker=dict(
                size=first_frame.data[2].marker.size, color='black',
                symbol='diamond', line=dict(width=2, color='white')
            ),
            name='Centroids'
        )
    ],
    layout=go.Layout(
        width=400, height=470,
        title="EM Algorithm Dynamic",
        scene=dict(
            xaxis_title='Eruptions', yaxis_title='Waiting', zaxis_title='Density',
            camera=dict(eye=dict(x=0, y=1.5, z=0.6), center=dict(x=0, y=0, z=-0.2))
        ),
        annotations=first_frame.layout.annotations,
        updatemenus=[{
            "buttons": [play_button, pause_button],
            "type": "buttons", "showactive": False,
            "x": -0.05, "y": 0.1, "xanchor": "left", "yanchor": "top"
        }],
        sliders=[{
            "steps": slider_steps,
            "active": 0, "currentvalue": {"prefix": "Iteration: "}, "pad": {"t": 50}
        }]
    ),
    frames=frames
)
fig.show()