# npb_data_app / plotting.py
# Commit d1369a2 by patrickramos: "Add general pitch classification"
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import transforms
from matplotlib.colors import LinearSegmentedColormap
import polars as pl
from pyfonts import load_google_font
from scipy.stats import gaussian_kde
import numpy as np
from types import SimpleNamespace
from datetime import date
from data import data_df
from convert import ball_kind_code_to_color, get_text_color_from_color
from stats import filter_data_by_date_and_game_kind, compute_team_games, compute_pitch_stats
# Render with the non-interactive Agg backend so figures can be produced without a display.
mpl.use(backend='Agg')
def get_pitcher_stats(id, lr=None, game_kind=None, start_date=None, end_date=None, min_ip=1, min_pitches=1, pitch_class_type='specific'):
    """Compute summary, per-pitch and pitch-level data for one pitcher.

    Rates and percentiles are computed over every pitcher in the filtered data
    and only then narrowed to ``pitId == id``. The ``*_pctl`` columns
    therefore rank this pitcher against the rest of the sample.

    Args:
        id: pitcher id, compared with the ``pitId`` column.
        lr: if not None, keep only rows whose ``batLR`` equals it (callers in
            this module pass 'l' or 'r'). The filter is applied after
            ``qualified`` is computed, so qualification uses all batters.
        game_kind, start_date, end_date: forwarded to
            ``filter_data_by_date_and_game_kind``.
        min_ip: minimum ``IP`` for a pitcher to be ``qualified``, or the
            string 'qualified' to require ``IP >= games``. ``IP`` here is the
            number of distinct ``inning_code`` values the pitcher appears in.
            ``games`` is ``home_games`` or ``visitor_games`` from
            ``compute_team_games``, chosen from the pitcher's first
            ``half_inning`` row.
        min_pitches, pitch_class_type: forwarded to ``compute_pitch_stats``.

    Returns:
        SimpleNamespace with:
          pitcher_stats: row(s) for ``id`` with ``pitcher_name``, K%, BB%,
            CSW%, GB%, FB%, LD%, ``qualified`` and ``<stat>_pctl`` for CSW%,
            K%, BB% and GB%. Percentiles rank qualified pitchers only (others
            get null). BB% is ranked so that a lower rate gives a higher
            percentile.
          pitch_stats: ``compute_pitch_stats`` output filtered to ``id``.
          pitch_shapes: the pitcher's rows with non-null ``x``/``y`` and a
            positive ``ballSpeed``.
    """
    source_data = data_df.filter(pl.col('ballKind_code') != '-')
    source_data = filter_data_by_date_and_game_kind(source_data, start_date=start_date, end_date=end_date, game_kind=game_kind)
    source_data = (
        compute_team_games(source_data)
        .with_columns(
            # Strings passed to then/otherwise are parsed by polars as column names.
            pl.when(pl.col('half_inning').str.ends_with('1')).then('home_games').otherwise('visitor_games').first().over('pitId').alias('games'),
            # "IP" = number of distinct innings the pitcher appeared in (not outs / 3).
            pl.col('inning_code').unique().len().over('pitId').alias('IP'),
        )
    )
    ip_threshold = pl.col('games') if min_ip == 'qualified' else min_ip
    source_data = source_data.with_columns((pl.col('IP') >= ip_threshold).alias('qualified'))

    if lr is not None:
        source_data = source_data.filter(pl.col('batLR') == lr)

    pitch_stats = compute_pitch_stats(source_data, player_type='pitcher', pitch_class_type=pitch_class_type, min_pitches=min_pitches).filter(pl.col('pitId') == id)

    pitch_shapes = (
        source_data
        .filter(
            (pl.col('pitId') == id) &
            pl.col('x').is_not_null() &
            pl.col('y').is_not_null() &
            (pl.col('ballSpeed') > 0)
        )
        [['pitId', 'general_ballKind_code', 'ballKind_code', 'ballSpeed', 'x', 'y']]
    )

    batted_ball_types = ['G', 'F', 'B', 'P', 'L']

    pitcher_stats = (
        source_data
        .group_by('pitId')
        .agg(
            pl.col('pitcher_name').first(),
            (pl.when(pl.col('presult').str.contains('strikeout')).then(1).otherwise(0).sum() / pl.col('pa_code').unique().len()).alias('K%'),
            (pl.when(pl.col('presult') == 'Walk').then(1).otherwise(0).sum() / pl.col('pa_code').unique().len()).alias('BB%'),
            (pl.col('csw').sum() / pl.col('pitch').sum()).alias('CSW%'),
            pl.col('aux_bresult').struct.field('batType').drop_nulls().value_counts(normalize=True),
            pl.first('qualified')
        )
        .explode('batType')
        .unnest('batType')
        .pivot(on='batType', values='proportion')
        .fill_null(0)
    )
    # pivot only creates columns for batted-ball types that actually occur in the
    # filtered data. Add any missing type as 0 so the GB%/FB%/LD% sums below
    # don't fail with a missing-column error.
    missing_types = [bt for bt in batted_ball_types if bt not in pitcher_stats.columns]
    if missing_types:
        pitcher_stats = pitcher_stats.with_columns(pl.lit(0.0).alias(bt) for bt in missing_types)

    def percentile(stat):
        # Only qualified pitchers enter the ranking; everyone else gets null.
        # BB% is ranked descending so a lower walk rate maps to a higher percentile.
        qualified_value = pl.when(pl.col('qualified')).then(pl.col(stat))
        return (qualified_value.rank(descending=(stat == 'BB%')) / qualified_value.count()).alias(f'{stat}_pctl')

    pitcher_stats = (
        pitcher_stats
        .with_columns(
            (pl.col('G') + pl.col('B')).alias('GB%'),
            (pl.col('F') + pl.col('P')).alias('FB%'),
            pl.col('L').alias('LD%'),
        )
        .drop(*batted_ball_types)
        .with_columns(percentile(stat) for stat in ['CSW%', 'K%', 'BB%', 'GB%'])
        .filter(pl.col('pitId') == id)
    )

    return SimpleNamespace(pitcher_stats=pitcher_stats, pitch_stats=pitch_stats, pitch_shapes=pitch_shapes)
def get_card_data(id, **kwargs):
    """Collect stats for pitcher ``id`` overall and split by batter handedness.

    ``kwargs`` are forwarded to ``get_pitcher_stats`` for all three splits:
    all batters, ``lr='l'`` and ``lr='r'``.

    Returns:
        SimpleNamespace with:
          pitcher_stats: the all-batters pitcher row. The left/right split
            columns are joined on ``pitId`` with suffixes ``_left`` /
            ``_right``. These columns are null if a split has no row.
          pitch_stats: per-pitch stats for the three splits, full-joined on
            ``ballKind_code``. Numeric gaps (a pitch missing from a split) are
            filled with 0.
          both_pitch_shapes / left_pitch_shapes / right_pitch_shapes:
            pitch-level rows for each split.
    """
    both = get_pitcher_stats(id, **kwargs)
    left = get_pitcher_stats(id, 'l', **kwargs)
    right = get_pitcher_stats(id, 'r', **kwargs)

    # Left join: a pitcher who faced no batters of one handedness in the window
    # still gets a row, instead of being dropped by an inner join.
    pitcher_stats = (
        both.pitcher_stats
        .join(left.pitcher_stats, on='pitId', how='left', suffix='_left')
        .join(right.pitcher_stats, on='pitId', how='left', suffix='_right')
    )
    # coalesce=True: since polars 1.0 a full join keeps both key columns by
    # default. That would leave `ballKind_code` null for right-only rows and
    # make the second join on that key miss them.
    pitch_stats = (
        both.pitch_stats
        .join(left.pitch_stats, on='ballKind_code', how='full', suffix='_left', coalesce=True)
        .join(right.pitch_stats, on='ballKind_code', how='full', suffix='_right', coalesce=True)
        .fill_null(0)
    )
    return SimpleNamespace(
        pitcher_stats=pitcher_stats,
        pitch_stats=pitch_stats,
        both_pitch_shapes=both.pitch_shapes,
        left_pitch_shapes=left.pitch_shapes,
        right_pitch_shapes=right.pitch_shapes
    )
def plot_arsenal(ax, pitches):
    """Draw one colored marker per pitch code in a single row, labelled with the code.

    Markers sit at x = 0.5, 1.5, ... on an axis fixed to [0, 11]. Codes without
    a color mapping fall back to 'C0'.
    """
    ax.set_xlim(0, 11)
    colors = [ball_kind_code_to_color.get(code, 'C0') for code in pitches]
    centers = np.arange(len(pitches)) + 0.5
    ax.scatter(centers, np.zeros(len(pitches)), c=colors, s=170)
    for center, code, color in zip(centers, pitches, colors):
        ax.text(
            x=center,
            y=0,
            s=code,
            horizontalalignment='center',
            verticalalignment='center',
            font=font,
            color=get_text_color_from_color(color),
        )
def plot_usage(ax, usages):
    """Draw pitch usage as one stacked horizontal bar.

    Args:
        ax: target axes.
        usages: two-column frame iterated row by row as (pitch code, usage
            share). Shares are stacked left to right on an x-axis of [0, 1].

    Segments wider than 10% get a percentage label. Pitch codes without a
    color mapping fall back to 'C0', matching ``plot_arsenal``.
    """
    height = 0.8
    left = 0
    for pitch, usage in usages.iter_rows():
        color = ball_kind_code_to_color.get(pitch, 'C0')
        ax.barh(0, usage, height=height, left=left, color=color)
        if usage > 0.1:
            ax.text(left + usage / 2, 0, f'{usage:.0%}', horizontalalignment='center', verticalalignment='center', size=8, font=font, color=get_text_color_from_color(color))
        left += usage
    ax.set_xlim(0, 1)
    # Extra headroom above the bar leaves space for the panel label.
    ax.set_ylim(-height / 2, height / 2 * 2.75)
# Integer evaluation grid for the location KDEs: x in [-100, 100], y in [0, 250].
# X and Y have shape (len(y_range), len(x_range)) = (251, 201).
x_range, y_range = np.arange(-100, 101), np.arange(251)
X, Y = np.meshgrid(x_range, y_range)
def fit_pred_kde(data, grid_x=None, grid_y=None):
    """Fit a 2-D Gaussian KDE to ``data`` and evaluate it on a grid.

    Args:
        data: array of shape (2, n). Row 0 holds x values and row 1 holds y
            values, as passed to ``scipy.stats.gaussian_kde``.
        grid_x, grid_y: same-shaped coordinate arrays (e.g. from
            ``np.meshgrid``). They default to the module-level ``X`` and ``Y``.

    Returns:
        Density values with the same shape as ``grid_x``.

    Raises:
        ValueError: if ``grid_x`` and ``grid_y`` have different shapes.
    """
    if grid_x is None:
        grid_x = X
    if grid_y is None:
        grid_y = Y
    grid_x = np.asarray(grid_x)
    grid_y = np.asarray(grid_y)
    if grid_x.shape != grid_y.shape:
        raise ValueError(f'grid_x and grid_y must have the same shape, got {grid_x.shape} and {grid_y.shape}')
    kde = gaussian_kde(data)
    positions = np.vstack((grid_x.ravel(), grid_y.ravel()))
    return kde(positions).reshape(grid_x.shape)
def plot_loc(ax, locs):
    """Draw the zone backdrop and a ~68% highest-density location region per pitch type.

    Args:
        ax: target axes.
        locs: frame with ``general_ballKind_code``, ``x`` and ``y`` columns.

    Pitch types are visited from most to least frequent. Types with fewer than
    three located pitches are skipped. For each type, a KDE is evaluated on the
    module grid ``X``/``Y`` and normalized to sum to 1. The filled region is
    the set of grid cells whose density is at or above the threshold ``t``;
    those cells hold roughly 68% of the mass. Codes without a color mapping
    fall back to 'C0'.
    """
    ax.set_aspect('equal', adjustable='datalim')
    ax.set_ylim(-52, 252)
    ax.add_patch(plt.Rectangle((-100, 0), width=200, height=250, facecolor='darkgray', edgecolor='dimgray'))
    ax.add_patch(plt.Rectangle((-80, 25), width=160, height=200, facecolor='gainsboro', edgecolor='dimgray'))
    ax.add_patch(plt.Rectangle((-60, 50), width=120, height=150, fill=False, edgecolor='yellowgreen', linestyle=':'))
    ax.add_patch(plt.Rectangle((-40, 75), width=80, height=100, facecolor='ivory', edgecolor='darkgray'))
    ax.add_patch(plt.Polygon([(0, -10), (45, -30), (51, -50), (-51, -50), (-45, -30), (0, -10)], facecolor='snow', edgecolor='darkgray'))

    by_pitch = (
        locs
        .sort(pl.len().over('general_ballKind_code'), descending=True)
        .group_by('general_ballKind_code', maintain_order=True)
    )
    for (pitch,), _locs in by_pitch:
        if len(_locs) <= 2:
            continue
        color = ball_kind_code_to_color.get(pitch, 'C0')
        Z = fit_pred_kde(_locs[['x', 'y']].to_numpy().T)
        Z = Z / Z.sum()
        # Ascending cumulative mass reaches 1 - 0.68 at the threshold t, so the
        # cells at or above t carry ~68% of the mass. One sort is enough:
        # sorted_Z[i] is the i-th smallest density.
        sorted_Z = np.sort(Z, axis=None)
        t = sorted_Z[np.argmin(np.abs(sorted_Z.cumsum() - (1 - 0.68)))]
        ax.contourf(X, Y, Z, levels=[t, 1], colors=color, alpha=0.5)
        ax.contour(X, Y, Z, levels=[t], colors=color, alpha=0.75)
def plot_velo(ax, velos):
    """Draw one horizontal half-violin of ``ballSpeed`` per pitch type, with its mean.

    Args:
        ax: target axes.
        velos: frame with ``general_ballKind_code`` and ``ballSpeed`` columns.

    Types with a single pitch are skipped. The rounded mean is written at the
    mean speed (data x) and mid-height (axes y). Codes without a color mapping
    fall back to 'C0'.
    """
    # x in data coordinates, y in axes coordinates (0..1).
    trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)
    # maintain_order=True makes the group order deterministic. Without it the
    # overlapping violins and labels are layered differently from run to run.
    for (pitch,), _velos in velos.group_by('general_ballKind_code', maintain_order=True):
        if len(_velos) <= 1:
            continue
        color = ball_kind_code_to_color.get(pitch, 'C0')
        speeds = _velos['ballSpeed']
        violin = ax.violinplot(speeds, orientation='horizontal', side='high', showextrema=False)
        for body in violin['bodies']:
            body.set_facecolor(color)
        mean = speeds.mean()
        ax.text(mean, 0.5, round(mean), horizontalalignment='center', verticalalignment='center', color='gray', alpha=0.75, font=font, transform=trans)
# Diverging colormap for percentile shading in the stat tables: 0 -> dodgerblue, 0.5 -> snow, 1 -> crimson.
stat_cmap = LinearSegmentedColormap.from_list(
    'stat',
    ['dodgerblue', 'snow', 'crimson'],
)
def plot_pitch_stats(ax, stats, stat_names):
    """Render a per-pitch stat table: one column per pitch code, one row per stat.

    The header row holds pitch-code cells in the pitch color. Column 0 holds
    the stat names. A value cell is shaded by its ``<stat>_pctl`` through
    ``stat_cmap`` when the pitch row's ``qualified`` flag is set. Otherwise it
    is grey with grey text. Values are formatted as whole percentages.
    """
    ax.set_aspect('equal', adjustable='datalim')
    table = mpl.table.Table(ax)
    n_rows = len(stat_names) + 1
    n_cols = len(stats) + 1
    cell_kw = dict(width=1 / n_cols, height=1 / n_rows, loc='center', fontproperties=font, edgecolor='white')

    for row, name in enumerate(stat_names, start=1):
        table.add_cell(row=row, col=0, text=name, **cell_kw)

    for col, pitch in enumerate(stats['ballKind_code'], start=1):
        header_color = ball_kind_code_to_color.get(pitch, 'C0')
        header = table.add_cell(row=0, col=col, text=pitch, facecolor=header_color, **cell_kw)
        header.get_text().set_color(get_text_color_from_color(header_color))

        pitch_row = stats.filter(pl.col('ballKind_code') == pitch)
        qualified = pitch_row['qualified'].item()
        for row, name in enumerate(stat_names, start=1):
            value = pitch_row[name].item()
            pctl = pitch_row[f'{name}_pctl'].item()
            label = f'{value:.0%}'
            if qualified:
                # Sampling [0, pctl, 1] forces a float lookup into the colormap.
                shade = stat_cmap([0, pctl, 1])[1]
            else:
                shade = 'gainsboro'
            cell = table.add_cell(row=row, col=col, text=label, facecolor=shade, **cell_kw)
            if not qualified:
                cell.get_text().set_color('gray')

    ax.add_artist(table)
def plot_pitcher_stats(ax, stats, stat_names):
    """Render a one-row table of alternating (stat name, stat value) cells.

    Values are whole percentages. They are shaded by ``<stat>_pctl`` through
    ``stat_cmap`` when ``stats['qualified']`` is true, otherwise grey with grey
    text.
    """
    ax.set_aspect('equal', adjustable='datalim')
    table = mpl.table.Table(ax)
    cell_kw = dict(width=1 / (len(stat_names) * 2), height=1, loc='center', fontproperties=font, edgecolor='white')
    qualified = stats['qualified'].item()
    for i, name in enumerate(stat_names):
        value = stats[name].item()
        pctl = stats[f'{name}_pctl'].item()
        name_col, value_col = 2 * i, 2 * i + 1
        table.add_cell(row=0, col=name_col, text=name, **cell_kw)
        label = f'{value:.0%}'
        shade = stat_cmap([0, pctl, 1])[1] if qualified else 'gainsboro'
        value_cell = table.add_cell(row=0, col=value_col, text=label, facecolor=shade, **cell_kw)
        if not qualified:
            value_cell.get_text().set_color('gray')
    ax.add_artist(table)
# Card typeface, loaded with pyfonts' load_google_font. It is passed as
# font=/fontproperties= to the text and table cells drawn by this module.
font = load_google_font(
    'Saira Extra Condensed',
    weight='medium',
)
def create_pitcher_overview_card(id, season, dpi=300):
    """Build the overview card figure for pitcher ``id`` in one regular season.

    Data comes from ``get_card_data`` with dates 1 Jan to 31 Dec of
    ``season``, game_kind 'Regular Season', ``min_pitches=100`` and
    ``pitch_class_type='general'``.

    Layout (grid rows, top to bottom): name / season title, arsenal markers,
    LHH / RHH usage bars, LHH / RHH location regions, velocity violins,
    per-pitch stat table, pitcher stat table, credits. All axes are hidden.

    Args:
        id: pitcher id (``pitId``).
        season: year; sets the date range and is printed in the title.
        dpi: figure resolution. The figure is 1080/300 x 1350/300 inches, i.e.
            1080x1350 px at the default dpi=300.

    Returns:
        A ``matplotlib.figure.Figure`` that is not registered with pyplot.
    """
    data = get_card_data(id, start_date=date(season, 1, 1), end_date=date(season, 12, 31), game_kind='Regular Season', min_pitches=100, pitch_class_type='general')

    # Build the Figure directly instead of via plt.figure(). pyplot keeps a
    # reference to every figure it creates until plt.close() is called, so
    # generating cards repeatedly would accumulate figures in memory. A plain
    # Figure is freed once the caller drops it and still supports savefig().
    fig = mpl.figure.Figure(figsize=(1080/300, 1350/300), dpi=dpi)
    gs = fig.add_gridspec(8, 6, height_ratios=[1, 1, 1.5, 6, 1, 3, 1, 0.5])

    # Row 0: pitcher name (left), season (right) and a vertical 'REG' tag.
    title_ax = fig.add_subplot(gs[0, :])
    title_ax.text(x=0, y=0, s=data.pitcher_stats['pitcher_name'].item().upper(), verticalalignment='baseline', font=font, size=20)
    title_ax.text(x=0.95, y=0, s=season, horizontalalignment='right', verticalalignment='baseline', font=font, size=20)
    title_ax.text(x=1, y=0.5, s='REG', horizontalalignment='right', verticalalignment='center', font=font, size=10, rotation='vertical')

    # Row 1: one marker per pitch type.
    arsenal_ax = fig.add_subplot(gs[1, :])
    plot_arsenal(arsenal_ax, data.pitch_stats['ballKind_code'])

    # Row 2: usage bars against left- and right-handed hitters.
    usage_l_ax = fig.add_subplot(gs[2, :3])
    plot_usage(usage_l_ax, data.pitch_stats[['ballKind_code', 'usage_left']])
    usage_l_ax.text(0, 1, 'LHH usage', horizontalalignment='left', verticalalignment='top', linespacing=0.5, color='gray', font=font, size=10, transform=usage_l_ax.transAxes)
    usage_r_ax = fig.add_subplot(gs[2, 3:])
    plot_usage(usage_r_ax, data.pitch_stats[['ballKind_code', 'usage_right']])
    usage_r_ax.text(0, 1, 'RHH usage', horizontalalignment='left', verticalalignment='top', linespacing=0.5, color='gray', font=font, size=10, transform=usage_r_ax.transAxes)

    # Row 3: location regions against LHH / RHH.
    loc_l_ax = fig.add_subplot(gs[3, :3])
    loc_l_ax.text(0, 1, 'LHH\nloc', verticalalignment='top', horizontalalignment='left', color='gray', font=font, size=10, transform=loc_l_ax.transAxes)
    plot_loc(loc_l_ax, data.left_pitch_shapes)
    loc_r_ax = fig.add_subplot(gs[3, 3:])
    loc_r_ax.text(0, 1, 'RHH\nloc', verticalalignment='top', horizontalalignment='left', color='gray', font=font, size=10, transform=loc_r_ax.transAxes)
    plot_loc(loc_r_ax, data.right_pitch_shapes)

    # Row 4: velocity distributions against all batters.
    velo_ax = fig.add_subplot(gs[4, :])
    plot_velo(velo_ax, data.both_pitch_shapes)
    velo_ax.text(0, 1, 'Velo', verticalalignment='top', horizontalalignment='left', color='gray', font=font, size=10, transform=velo_ax.transAxes)

    # Rows 5-6: per-pitch and pitcher-level stat tables.
    pitch_stats_ax = fig.add_subplot(gs[5, :])
    plot_pitch_stats(pitch_stats_ax, data.pitch_stats, ['CSW%', 'GB%'])
    pitcher_stats_ax = fig.add_subplot(gs[6, :])
    plot_pitcher_stats(pitcher_stats_ax, data.pitcher_stats, ['CSW%', 'K%', 'BB%', 'GB%'])

    # Row 7: credits.
    credits_ax = fig.add_subplot(gs[7, :])
    credits_ax.text(x=0, y=0.5, s='Data: SPAIA, Sanspo', verticalalignment='center', font=font, size=7)
    credits_ax.text(x=1, y=0.5, s='@yakyucosmo', horizontalalignment='right', verticalalignment='center', font=font, size=7)

    for ax in [
        title_ax,
        arsenal_ax,
        usage_l_ax, usage_r_ax,
        loc_l_ax, loc_r_ax,
        velo_ax,
        pitch_stats_ax,
        pitcher_stats_ax,
        credits_ax
    ]:
        ax.axis('off')
        ax.tick_params(
            axis='both',
            which='both',
            length=0,
            labelbottom=False,
            labelleft=False
        )
    return fig
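# fig = create_pitcher_overview_card('1600153', season=2023, dpi=300)
# fig.savefig('card.png')  # the module forces the non-interactive Agg backend, so plt.show() cannot display it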