Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added churn-risk/100
Empty file.
63 changes: 40 additions & 23 deletions churn-risk/src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
from .churn_predictor import ChurnPredictor


TRAIN_DATASET_PATH='./data/churnTrain.csv'
TEST_DATASET_PATH='./data/churnTest.csv'
TARGET_COLUMN="Churn?"
CATEGORICAL_COLUMNS=['Area Code']
DROP_COLUMNS=["Phone"]
TRAIN_DATASET_PATH = './data/churnTrain.csv'
TEST_DATASET_PATH = './data/churnTest.csv'
TARGET_COLUMN = "Churn?"
CATEGORICAL_COLUMNS = ['Area Code']
DROP_COLUMNS = ["Phone"]

# Load test data once at startup
df = pd.read_csv(TEST_DATASET_PATH)
df.dropna(inplace=True)
df['Total Charges'] = (df['Day Charges'] + df['Evening Charges'] + df['Night Charges'] + df['Intl Charges'])
Expand Down Expand Up @@ -84,8 +85,8 @@ def render_desc_info(q: Q, selected_row_index: Optional[int]):
)

total_charges = df['Total Charges']
charge = total_charges[selected_row_index] if selected_row_index is not None else total_charges.mean(axis=0)
rank = df['Total Charges'].rank(pct=True).values[selected_row_index] if selected_row_index is not None else df['Total Charges'].rank(pct=True).mean(axis=0)
charge = total_charges[selected_row_index] if selected_row_index is not None else total_charges.mean()
rank = df['Total Charges'].rank(pct=True).values[selected_row_index] if selected_row_index is not None else df['Total Charges'].rank(pct=True).mean()
q.page['total_charges'] = ui.tall_gauge_stat_card(
box='top-stats',
title='Total Charges' if selected_row_index else 'Average Total Charges',
Expand All @@ -104,7 +105,7 @@ def render_charges_breakdown(q: Q, selected_row_index: Optional[int]):
if selected_row_index is not None:
rows.append((label, df[label][selected_row_index]))
else:
rows.append((label, df[label].mean(axis=0)))
rows.append((label, df[label].mean()))
color_range = f'{q.client.primary_color} {q.client.secondary_color} {q.client.tertiary_color} #67dde6'
q.page['bar_chart'] = ui.plot_card(
box=ui.box('top-stats', height='300px'),
Expand All @@ -115,10 +116,27 @@ def render_charges_breakdown(q: Q, selected_row_index: Optional[int]):


def render_analysis(q: Q):
row_phone_no = int(q.args.customers[0]) if q.args.customers else None
q.page['title'].items[0].picker.values = q.args.customers
# Dropdown returns a string (or None), not a list
row_phone_no = int(q.args.customers) if q.args.customers else None

# Update dropdown selection visually
q.page['title'].items[0].dropdown.value = q.args.customers

# Find selected index safely
selected_row_index = None
if row_phone_no is not None:
matching = df[df['Phone'] == row_phone_no]
if not matching.empty:
selected_row_index = int(matching.index[0])
else:
q.page['title'].subtitle = f'Customer: {row_phone_no} (Not found in test data)'
# Clear previous plots to avoid stale data
for card in ['shap_plot', 'top_negative_plot', 'top_positive_plot', 'churn_rate', 'total_charges', 'bar_chart']:
if card in q.page:
del q.page[card]
return

q.page['title'].subtitle = f'Customer: {row_phone_no or "No customer selected"}'
selected_row_index = int(df[df['Phone'] == row_phone_no].index[0]) if row_phone_no else None

shap_rows = churn_predictor.get_shap(selected_row_index)
render_shap_plot(q, shap_rows, selected_row_index)
Expand Down Expand Up @@ -172,13 +190,13 @@ def init(q: Q):
title='Customer profiles from model predictions',
subtitle='Customer: No customer chosen',
items=[
# TODO: Replace with dropdown after https://github.com/h2oai/wave/pull/303 merged.
ui.picker(
ui.dropdown(
name='customers',
label='Customer Phone Number',
choices=[ui.choice(name=str(phone), label=str(phone)) for phone in df['Phone']],
max_choices=1,
trigger=True
trigger=True,
searchable=True,
placeholder='Search by phone number...'
),
ui.toggle(name='theme', label='Dark Theme', trigger=True)
]
Expand Down Expand Up @@ -206,15 +224,14 @@ async def serve(q: Q):
q.page['title'].items[1].toggle.value = dark_theme

if q.args.code:
del q.page['shap_plot']
del q.page['top_negative_plot']
del q.page['top_positive_plot']
del q.page['total_charges']
del q.page['bar_chart']
del q.page['churn_rate']
# Clean up analysis cards
for card in ['shap_plot', 'top_negative_plot', 'top_positive_plot', 'churn_rate', 'total_charges', 'bar_chart']:
if card in q.page:
del q.page[card]
render_code(q)
else:
del q.page['code']
if 'code' in q.page:
del q.page['code']
render_analysis(q)

await q.page.save()
await q.page.save()
Loading