''' random functions that are helpful
'''
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.pyplot import plot, scatter, hist, figure, clf, cla, xlabel, ylabel, xlim, ylim,\
gcf, gca, close, title, legend, grid, bar, suptitle, show,\
xticks, yticks, axis, boxplot
from sklearn.metrics import roc_curve, auc
import numpy as np
import datetime
from datetime import datetime as dt
# std library
from itertools import chain
import math
import sys
from queue import PriorityQueue
import pandas as pd
# typing
from typing import List, Dict, Union, Any, Iterable
from scipy import stats
import seaborn as sns
from .etl import nan_smooth
from .etl import round_time
from .histogram_utils import nhist
from scipy import interpolate
# mini function to change a [a, b] into [[a,b]] for use in the for loop later
[docs]
def list_ize(w):
if len(w) > 0:
if hasattr(w[0], "__len__"):
return w
else:
return [w]
else:
return []
[docs]
def plot_range(events, color='#0093e7', y_offset='none', height='none', zorder=None, alpha=0.5, **varargs):
"""
Shade vertical regions on the current axes.
Parameters
----------
events : list of [start, end] pairs
X positions defining the regions to shade.
color : str, optional
Fill color. Default is ``'#0093e7'``.
y_offset : float, optional
Bottom of the shaded region. Default uses current y-axis lower limit.
height : float, optional
Height of the shaded region. Default spans the full y-axis.
zorder : int, optional
Drawing order.
alpha : float, optional
Transparency. Default is 0.5.
"""
events = list_ize(events)
yy = ylim()
if y_offset == 'none':
y_offset = yy[0]
if height == 'none':
height = yy[1] - yy[0]
to_label = 'none'
# Fill registered cur_event times
for cur_event in events:
plt.fill_between([cur_event[0], cur_event[1]],
[height + y_offset, height + y_offset],
[y_offset, y_offset],
color=color, alpha=alpha, zorder=zorder, label=to_label, **varargs)
# make sure only one legend item appears for this event series
to_label = '_nolegend_'
[docs]
def plot_range_idx(t, events, **varargs):
"""
Shade regions by index into a time series.
Converts index pairs into time-domain ranges and calls
`plot_range`.
Parameters
----------
t : array-like
Time series (x-axis values).
events : list of [start_idx, end_idx] pairs
Index pairs into *t* defining the regions.
**varargs
Passed to `plot_range`.
"""
# if it is a series, get the values
if str(type(t)) == "<class 'pandas.core.series.Series'>":
t = t.values
events = list_ize(events)
# last element + ds
t = np.append(t, [t[-1] + (t[-1] - t[-2])])
event_times = [[t[a], t[b]] for a, b in events]
plot_range(event_times, **varargs)
[docs]
def plot_cdf(data, *args, **kargs):
"""
Plot the empirical cumulative distribution function (CDF).
Parameters
----------
data : array-like
Input data. NaN values are removed.
*args, **kargs
Passed to ``plt.plot``.
"""
data = np.asarray(data)
data = data[~np.isnan(data)]
x = np.sort(data)
y = 100 * np.arange(len(x)) / len(x)
plt.plot(x, y, *args, **kargs)
plt.ylim([0,100])
[docs]
def set_fontsize(f_size=15):
"""
Set the font size of the current axes' title, labels, and tick labels.
Parameters
----------
f_size : int, optional
Font size in points. Default is 15.
Returns
-------
ax : matplotlib.axes.Axes
The current axes.
"""
ax = plt.gca()
[w.set_fontsize(f_size) for w in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
ax.get_xticklabels() +
ax.get_yticklabels())]
return ax
[docs]
def print_fig_fancy(fpathh, dpi=300, **kwargs):
plt.savefig(fpathh, dpi=dpi, facecolor=[0, 0, 0, 0], **kwargs)
[docs]
def make_title(title_str,format=False):
if title_str:
lower_case_words = {'in','the','a','and','of'}
upper_case_words = {}
# title_str = 'a_time_in_bed and fish'
# split by
tmp = title_str.replace("_"," ").split()
if format:
upper_words = [w[0].upper()+w[1:] if len(w)>1 else w[0].upper() for w in tmp]
lower_words = [w[0].lower()+w[1:] if w.lower() in lower_case_words else w for w in upper_words]
else :
lower_words = [w.lower() for w in tmp]
# lower_words[0][0]=lower_words[0][0].upper()
lower_words[0] = lower_words[0][0].upper() + lower_words[0][1:]
upper_words = [w.upper() if w.upper() in upper_case_words else w for w in lower_words]
combined_words = ' '.join(upper_words)
else:
combined_words = ''
return combined_words
[docs]
def large_fig(fig_num=1):
"""
Create a large figure (15 x 8 inches).
Parameters
----------
fig_num : int, optional
Figure number. Default is 1.
"""
plt.figure(fig_num,figsize=(15,8))
[docs]
def fancy_plotter(x,y,marker_style='o',line_styles=None):
"""
Plot x vs y with a linear trend line overlay.
Parameters
----------
x : array-like
X data.
y : array-like
Y data.
marker_style : str, optional
Marker format string. Default is ``'o'``.
line_styles : dict, optional
Keyword arguments for the trend line. Default is
``{'color': '0.4', 'lw': 3}``.
"""
if line_styles is None:
line_styles = {'color':'0.4','lw':3}
# check that these are not regular lists, so that we can use logical indexing later
if type(y)!=np.ndarray:
y=np.array(y)
if type(x)!=np.ndarray:
x = np.array(x)
plt.plot(x,y,marker_style)
# if
I1 = np.logical_not(np.isnan(y))
I2 = np.logical_not(np.isnan(x))
I = np.logical_and(I1,I2)
if sum(I) > 3:
m, b = np.polyfit(x[I], y[I], 1)
plt.plot(x, m*np.array(x) + b,**line_styles)
[docs]
def cplot(z, *args, **kargs):
"""
Plot complex numbers on the real/imaginary plane.
Parameters
----------
z : complex array-like
Complex data to plot (real -> x, imag -> y).
*args, **kargs
Passed to ``plt.plot``.
"""
plot(np.real(z), np.imag(z), *args, **kargs)
[docs]
def cplot_circle(z_center, r):
t = np.linspace(0, 2 * np.pi, 100)
z = r * np.exp(1j * t) + z_center
cplot(z)
[docs]
def ctext(z, *args, **kargs):
"""
Place text at a position given by a complex number.
Parameters
----------
z : complex
Position (real -> x, imag -> y).
*args, **kargs
Passed to ``plt.text``.
"""
plt.text(np.real(z), np.imag(z), *args, **kargs)
[docs]
def cscatter(z, *args, **kargs):
"""
Scatter plot of complex numbers on the real/imaginary plane.
Parameters
----------
z : complex array-like
Complex data (real -> x, imag -> y).
*args, **kargs
Passed to ``plt.scatter``.
"""
scatter(np.real(z), np.imag(z), *args, **kargs)
[docs]
def polar_grid(lw=1, r=False, linecolor='.3', style=':', nrings=2, nrays = 6):
"""
Overlay a polar grid (rings and rays) on the current axes.
Useful for complex-plane plots.
Parameters
----------
lw : float, optional
Line width. Default is 1.
r : float, optional
Radius of the grid. Default is auto-detected from axis limits.
linecolor : str, optional
Color of grid lines. Default is ``'.3'``.
style : str, optional
Line style. Default is ``':'``.
nrings : int, optional
Number of concentric rings. Default is 2.
nrays : int, optional
Number of radial rays. Default is 6.
Returns
-------
line : list of Line2D
The plotted grid lines.
"""
ax = plt.gca()
axis('equal')
ex = ax.get_xlim()
yy = ax.get_ylim()
to_plot = []
if not r:
r = np.min(np.abs(np.concatenate([ex, yy])))
t = np.linspace(0, 2 * np.pi, nrays +1)[:-1]
z = 100 * r * np.exp(1j * t)
for cur in z:
to_plot.append(0)
to_plot.append(cur)
to_plot.append(np.nan)
t = np.linspace(0, 2 * np.pi, 200)
for cur_r in np.linspace(r/nrings, r, nrings):
circz = cur_r * np.exp(1j * t)
to_plot = to_plot + [np.nan] + list(circz)
cur_plt = cplot(np.array(to_plot), style, color=linecolor, lw=lw)
xlim(ex)
ylim(yy)
return cur_plt
[docs]
def plot_diag(lw=1, color='.5', reverse=False):
"""
Plot a diagonal x=y reference line on the current axes.
Parameters
----------
lw : float, optional
Line width. Default is 1.
color : str, optional
Line color. Default is ``'.5'``.
reverse : bool, optional
If True, plot x = -y instead. Default is False.
"""
ax = plt.gca()
ex = ax.get_xlim()
yy = ax.get_ylim()
if np.diff(ex)[0] > np.diff(yy)[0]:
y = yy
x = yy
else:
y = ex
x = ex
if reverse:
x = np.flip(x)
plt.plot(x, y, '--', color=color, lw=lw, label='_nolegend_')
[docs]
def plot_zero(lineheight=0, axx='x', **kwargs):
"""
Plot a horizontal or vertical reference line.
Parameters
----------
lineheight : float, optional
Position of the line. Default is 0.
axx : {'x', 'y'}, optional
``'x'`` for a horizontal line, ``'y'`` for vertical. Default is ``'x'``.
**kwargs
Passed to ``plt.plot``. Defaults to a gray dashed line.
"""
if len(kwargs) == 0:
kwargs = {'color' : '.5',
'linestyle' : '--',
'lw' : 1,
'label': '_nolegend_'
}
ax = plt.gca()
if axx == 'x':
ex = ax.get_xlim()
x = ex
y = [lineheight, lineheight]
elif axx == 'y':
yy = ax.get_ylim()
x = [lineheight, lineheight]
y = yy
else:
raise Exception("you can't plot zero on no axis!")
return plt.plot(x,y, **kwargs)
# return plt.plot(x,y,style, lw=lw, **kwargs)
[docs]
def plot_axes(color='.5'):
"""
Plot both horizontal and vertical reference lines at zero.
Parameters
----------
color : str, optional
Line color. Default is ``'.5'``.
"""
plot_zero(axx='x', color=color)
plot_zero(axx='y', color=color)
[docs]
def plot_pin(x, y, color='k'):
"""
Plot a pin (vertical line with dot) at a specific x position.
Parameters
----------
x : float
X position of the pin.
y : float
Height of the pin.
color : str, optional
Color of the line and marker. Default is ``'k'``.
"""
plot([x, x], [0, y], linewidth=3, color=color)
plot([x], [y], 'o', color=color, markersize=10)
[docs]
def bar_centered(y, **kwargs):
"""
Bar plot centered on integers 1 through N.
Parameters
----------
y : array-like
Bar heights.
**kwargs
Passed to ``plt.bar``.
Returns
-------
container : BarContainer
The bar container.
"""
x = np.arange(len(y))+1
h = plt.bar(x, y, align='center', **kwargs)
plt.xticks(x)
return h
[docs]
def errorb(cur_series, serror=True):
"""
Bar plot with error bars from a pandas Series of arrays.
Parameters
----------
cur_series : pandas.Series
Each element is an array-like of values. The index provides
x-tick labels.
serror : bool, optional
If True (default), use standard error. If False, use standard
deviation.
"""
means = cur_series.apply(np.nanmean)
errors = cur_series.apply(np.nanstd)
if serror:
errors = errors / np.sqrt(len(errors))
x_pos = np.arange(len(cur_series)) + 1
bar(x_pos, means,
yerr=errors,
align='center',
alpha=0.5,
color='gray',
capsize=4)
ax = gca()
ax.set_xticks(x_pos)
ax.set_xticklabels(cur_series.index)
format_axis_date()
[docs]
def bplot(X):
if type(X) == dict:
keys = X.keys()
S = len(keys)
for idx, k in enumerate(keys):
x = X[k]
x = x.astype(float)
x = x[~np.isnan(x)]
boxplot(x, positions=[idx], patch_artist=True, boxprops=dict(facecolor='lightblue', lw=1, color='k'),
widths=0.6, whiskerprops=dict(linewidth=2, color='k', linestyle='-'),
medianprops=dict(linewidth=2))
xticks(range(S))
xticklabels(keys)
if type(keys) == str:
format_axis_date()
nicefy()
[docs]
def subplotter_auto(n, ii, **kwargs):
"""
Create a subplot with automatically chosen grid dimensions.
Computes a near-square grid that fits *n* subplots and activates
the *ii*-th one.
Parameters
----------
n : int
Total number of subplots.
ii : int
Index of the subplot to activate (0-based).
**kwargs
Passed to `subplotter`.
"""
y = int(np.ceil(np.sqrt(n)))
x = int(np.ceil(n / y))
subplotter(x, y, ii, **kwargs)
[docs]
def xticklabels(all_lbl):
gca().set_xticklabels(all_lbl)
[docs]
def yticklabels(all_lbl):
gca().set_yticklabels(all_lbl)
[docs]
def subplotter(x, y=None, nth=None, xlbl=None, ylbl=None, y_ticks=None):
"""
MATLAB-style subplot with 0-based indexing and spanning support.
Parameters
----------
x : int
Number of rows, or a 3-digit integer like ``220`` meaning
2 rows, 2 columns, 0th subplot.
y : int, optional
Number of columns.
nth : int or list of int, optional
Subplot index (0-based). Pass a list to span multiple cells.
xlbl : str, optional
X-axis label, shown only on the bottom row.
ylbl : str, optional
Y-axis label, shown only on the left column.
y_ticks : list or False, optional
Y-tick labels for non-left columns. False hides them.
Returns
-------
ax : matplotlib.axes.Axes
The created subplot axes.
"""
# allow you to enter input like (220) instead of x=2, y=2, nth =0
if x > 99:
nth = int(x % 10)
y = int((x - nth) % 100)
x = int((x - y) / 100)
y = int(y / 10)
kwargs = {}
if type(nth) != int:
# note special case y == 1, where rowspan should be used
if len(nth) > y:
kwargs = {'colspan': y,
'rowspan': len(nth) / y}
if int(kwargs['rowspan']) != kwargs['rowspan']:
raise Exception("this isn't supported yet")
else:
kwargs['rowspan'] = int(kwargs['rowspan'])
else:
if nth[1] == nth[0] + 1 and y > 1:
kwargs = {'colspan': len(nth)}
else:
kwargs = {'rowspan': len(nth)}
nth = nth[0]
tupp = (x, y)
cnt = 0
for ii in range(tupp[0]):
for jj in range(tupp[1]):
if cnt==nth:
ax = plt.subplot2grid(tupp, (ii, jj), **kwargs)
if jj == 0:
if ylbl:
ylabel(ylbl)
else:
if y_ticks is not None:
if y_ticks == False:
yticklabels([])
else:
yticklabels(y_ticks)
if ii + 1 == x:
if xlbl is not None:
xlabel(xlbl)
return ax
cnt+=1
raise Exception("You have only " + str(x*y) + " subplots, but you asked for the " + str(nth) + "'th")
[docs]
def pop_all():
"""
Bring all matplotlib figures to the foreground.
Useful when using IPython in the terminal.
Returns
-------
int
Number of figures shown.
"""
all_figures=[manager.canvas.figure for manager in matplotlib.\
_pylab_helpers.Gcf.get_all_fig_managers()]
[fig.canvas.manager.show() for fig in all_figures]
return len(all_figures)
[docs]
def suplabel(axis,label,label_prop=None,
labelpad=5,
ha='center',va='center'):
"""
Add a shared xlabel or ylabel to the figure, similar to ``suptitle``.
Parameters
----------
axis : {'x', 'y'}
Which axis to label.
label : str
The label text.
label_prop : dict, optional
Keyword arguments for ``plt.text``.
labelpad : float, optional
Padding from the axis in points. Default is 5.
ha : str, optional
Horizontal alignment. Default is ``'center'``.
va : str, optional
Vertical alignment. Default is ``'center'``.
"""
fig = plt.gcf()
xmin = []
ymin = []
for ax in fig.axes:
xmin.append(ax.get_position().xmin)
ymin.append(ax.get_position().ymin)
xmin,ymin = min(xmin),min(ymin)
dpi = fig.dpi
if axis.lower() == "y":
rotation=90.
x = xmin-float(labelpad)/dpi
y = 0.5
elif axis.lower() == 'x':
rotation = 0.
x = 0.5
y = ymin - float(labelpad)/dpi
else:
raise Exception("Unexpected axis: x or y")
if label_prop is None:
label_prop = dict()
plt.text(x,y,label,rotation=rotation,
transform=fig.transFigure,
ha=ha,va=va,
**label_prop)
# Note: this is a half-complete port of my former matlab code
# https://www.mathworks.com/matlabcentral/fileexchange/29545-power-law-exponential-and-logarithmic-fit?s_tid=prof_contriblnk
[docs]
def logfit(x, y=None, graph_type='linear', ftir=.05, marker_style='.k', line_style='--g',
skip_begin = 0, skip_end = 0):
"""
Fit and plot a line through data on linear, semi-log, or log-log axes.
Parameters
----------
x : array-like
X data (or y data if *y* is None, or complex with real=x, imag=y).
y : array-like, optional
Y data.
graph_type : {'linear', 'logy', 'loglog'}, optional
Fit type. Default is ``'linear'``.
ftir : float, optional
Fraction to extend the fit line beyond the data. Default is 0.05.
marker_style : str or dict, optional
Marker style for data points. Default is ``'.k'``.
line_style : str or dict, optional
Line style for the fit. Default is ``'--g'``.
skip_begin, skip_end : int, optional
Number of points to skip at the beginning/end of the fit.
Returns
-------
slope : float
Slope of the fit.
intercept : float
Intercept of the fit.
"""
# check if you only passes one var in
if y is None:
if np.iscomplex(x[0]):
y = np.imag(x)
x = np.real(x)
else:
y = x
x = np.array(range(len(y)))
# convert to floats
x = np.array(x).astype(float)
y = np.array(y).astype(float)
# remove any nans
I = np.logical_not(np.isnan(x)) & np.logical_not(np.isnan(y))
x = x[I]
y = y[I]
def linearfit(x2fit, y2fit):
cur_range = np.array([x2fit.min(), x2fit.max()])
tot_range = cur_range[1] - cur_range[0]
exp_range = cur_range + ftir * tot_range * np.array([- 1, 1])
ex = np.linspace(exp_range[0], exp_range[1], 100)
slope, intercept, r_value, p_value, std_err = stats.linregress(x2fit, y2fit)
yy = slope * ex + intercept
return slope, intercept, ex, yy
def logyfit(x2fit, y2fit):
gca().set_yscale('log')
y2fit = np.log10(y2fit)
idxs = np.isfinite(y2fit)
y2fit = y2fit[idxs]
x2fit = x2fit[idxs]
slope, intercept, ex, yy = linearfit(x2fit, y2fit)
return slope, intercept, ex, np.power(10,yy)
def loglogfit(x2fit, y2fit):
gca().set_xscale('log')
gca().set_yscale('log')
x2fit = np.log10(x2fit)
y2fit = np.log10(y2fit)
idxs = np.isfinite(y2fit) & np.isfinite(x2fit)
y2fit = y2fit[idxs]
x2fit = x2fit[idxs]
slope, intercept, ex, yy = linearfit(x2fit, y2fit)
return slope, intercept, np.power(10,ex), np.power(10,yy)
if len(x) < 3:
return np.nan, np.nan
graph_calc = {'linear': linearfit,
'logy': logyfit,
'loglog': loglogfit}
if graph_type in graph_calc:
if skip_end > 0:
slope, intercept, ex, yy = graph_calc[graph_type](x[skip_begin:-skip_end], y[skip_begin:-skip_end])
else:
slope, intercept, ex, yy = graph_calc[graph_type](x[skip_begin:], y[skip_begin:])
else:
raise Exception("we can't do that graph type yet")
if marker_style:
if type(marker_style) == str:
plot(x, y, marker_style)
else:
plot(x, y, **marker_style)
if line_style:
if type(line_style) == str:
plot(ex, yy, line_style, linewidth=3)
else:
plot(ex, yy, '--g', **line_style)
return slope, intercept
[docs]
def test_logfit():
logfit(np.arange(10), 2 * np.random.randn(10) + np.arange(10),
marker_style={'color': 'r', 'lw':0, 'marker': 'o'},
line_style={'color': 'g', 'lw': 3, 'linestyle':'--'})
[docs]
def streamgraph(df, smooth=None, normalize=None,
wiggle=None, label_dict=None, color=None,
order=True, linewidth=0.5, round_time=False, legend_flag=True):
"""
Create a streamgraph (stacked area chart with wiggle baseline).
Parameters
----------
df : DataFrame
Two-column DataFrame: first column is the time/x axis,
second column is the categorical variable to stack.
smooth : int, optional
Smoothing window width. Default is None (no smoothing).
normalize : bool, optional
If True, normalize to 100%. Default is None.
wiggle : bool or str, optional
Baseline mode. True for wiggle, False for zero baseline,
or ``'stream'``/``'river'`` for weighted wiggle. Default is auto.
color : str, dict, or palette, optional
Seaborn palette name, dict of label->color, or palette object.
order : bool or list, optional
True to sort by peak height, or a list of labels. Default is True.
linewidth : float, optional
Edge line width. Default is 0.5.
round_time : str, optional
Pandas frequency string to round the time column. Default is False.
legend_flag : bool, optional
Show legend. Default is True.
Returns
-------
ax : matplotlib.axes.Axes
The axes with the streamgraph.
"""
# reorder a list from inside out, for use with streamgraph
def reorder_idx(idxs):
idxs = list(np.flip(idxs))
for ii in np.flip(np.arange(1, len(idxs), 2)):
idxs.append(idxs.pop(ii))
return np.array(idxs)
# normalize for use with streamgraph so everything sums to 1
def normalize_y(y):
return 100 * y / y.sum(axis=0)
# interpret wiggle and normalize params for cool defaults
# if not wiggle:
# if normalize:
# curate the data:
# add stupid col for who knows why - maybe can leave it out ..
cur_cols = df.columns
if len(cur_cols) == 1:
df['fake_col'] = 'blank'
legend_flag = False
if round_time:
df[cur_cols[0]] = df[cur_cols[0]].dt.round(round_time)
df['dumb12345'] = 'stoopid'
cur_cols = df.columns
df = df[[cur_cols[2], cur_cols[0], cur_cols[1]]]
cur_datas_grouped = df.groupby([cur_cols[0], cur_cols[1]], sort=False).count().unstack()
cur_clean = cur_datas_grouped.fillna(0).transpose().sort_index(axis=1)
cur_labels = np.array([w[1] for w in list(cur_clean.index)])
if len(cur_labels) > 200:
raise Exception("Whoa dude - you are stacking too many things (" + str(len(cur_labels)) + ")")
# prepare that sweet sweet data
x = list(cur_clean.columns)
y = np.array(cur_clean)
# smooth it so it isn't so choppy
if smooth:
y = np.array([nan_smooth(w, smooth) for w in y])
# normalize it so it takes up the whole graph
# remember to set the y'lims below - or unset the autoscale y
if normalize:
y = normalize_y(y)
if wiggle is None:
wiggle = False
# prepare the colors for plotting
if not color:
# Maybe do something depending on "N here?"
color = 'muted'
if type(color) == str:
pallet_list = sns.color_palette(color, n_colors=len(y))
pallet_dict = {k: pallet_list[ii] for ii, k in enumerate(cur_labels)}
elif type(color) == sns.palettes._ColorPalette:
pallet_list = color
pallet_dict = {k: pallet_list[ii] for ii, k in enumerate(cur_labels)}
elif type(color) == dict:
pallet_dict = color
if wiggle is None:
wiggle = True
# prepare the idxs (if you'll wiggle)
if order:
if type(order) == list:
idxs = [cur_labels.tolist().index(w) for w in order]
# print(idxs)
else:
idxs = np.argsort(np.amax(y, axis=1))
if wiggle in [True, 'wiggle', 'streamgraph', 'stream', 'themeriver', 'river']:
idxs = reorder_idx(idxs)[::-1]
else:
idxs = np.arange(len(y))
labels_reord = cur_labels[idxs]
cur_palette = [pallet_dict[w] for w in labels_reord]
if wiggle == True:
cur_base = 'wiggle'
elif wiggle == False:
cur_base = 'zero'
elif wiggle in ['streamgraph', 'stream', 'themeriver', 'river']:
cur_base = 'weighted_wiggle'
else:
cur_base = wiggle
# do the plot!
ax = plt.stackplot(x, y[idxs], labels=labels_reord,
baseline=cur_base, colors=cur_palette, lw=linewidth, edgecolor='k')
gca().set_yticklabels(np.abs(gca().get_yticks()).astype(int))
if legend_flag:
lgnd = legend(labels_reord)
gca().legend(loc='center left', bbox_to_anchor=(1, 0.5), framealpha=0.0)
# nicefy()
plt.style.use('classic')
return gca()
## example data for use in streamgraph:
# index date column, grouped column
# 11623 2011-06-01 United States
# 13117 2011-07-01 Europe
# 13118 2011-06-01 United States
# 13120 2011-07-01 Oceania
# 13121 2011-08-01 Canada
# 13122 2011-12-01 United States
# 13123 2011-08-01 United States
# 13125 2011-10-01 Canada
# 65169 2011-08-01 Canada
# 13126 2011-09-01 United States
# 65170 2011-09-01 Canada
# Does this work? can it be used inside nicefy?
# def set_fontsize(f_size=15, ax='none'):
# if ax == 'none':
# ax = plt.gca()
#
# [w.set_fontsize(f_size) for w in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
# ax.get_xticklabels() +
# ax.get_yticklabels())]
# return ax
[docs]
def nicefy(fsize=15, f_size=False, clean_legend=False, cur_fig=None, background = 'white', resize=False, legend_outside=False,
expand_y=False, expand_x=False, touch_limits=False, touch_text=False):
"""
Make the current figure publication-ready.
Adjusts font sizes, removes top/right spines, applies tight layout,
and optionally sets background color and cleans up legends.
Parameters
----------
fsize : int, optional
Font size for all text elements. Default is 15.
f_size : int, optional
Deprecated alias for *fsize*.
clean_legend : bool, optional
If True, make the legend background transparent. Default is False.
cur_fig : Figure, optional
Figure to nicefy. Default is the current figure.
background : str, optional
``'white'``, ``'black'``, or a matplotlib style name.
Default is ``'white'``.
touch_limits : bool, optional
If True, enable tight autoscaling. Default is False.
expand_y, expand_x : bool or str, optional
Expand axis limits slightly. ``True`` expands both directions,
``'top'`` expands only the upper end. Default is False.
touch_text : bool, optional
If True, auto-format axis label and title text. Default is False.
"""
# backwards compatability for this change
if f_size != False:
fsize = f_size
#todo: check if you are log scale, and do expandx or expandy to be top
if cur_fig:
fig = cur_fig
else:
fig = plt.gcf()
# fig.set_tight_layout(True)
if background == 'black':
fig.set_facecolor('black')
plt.style.use('dark_background')
elif background == 'white':
fig.set_facecolor('white')
plt.style.use('classic')
else:
plt.style.use(background)
# ['seaborn-darkgrid', 'Solarize_Light2', 'seaborn-notebook', 'classic', 'seaborn-ticks', 'grayscale', 'bmh',
# 'seaborn-talk', 'dark_background', 'ggplot', 'fivethirtyeight', '_classic_test', 'seaborn-colorblind',
# 'seaborn-deep', 'seaborn-whitegrid', 'seaborn-bright', 'seaborn-poster', 'seaborn-muted', 'seaborn-paper',
# 'seaborn-white', 'fast', 'seaborn-pastel', 'seaborn-dark', 'tableau-colorblind10', 'seaborn',
# 'seaborn-dark-palette']
axes = fig.get_axes()
# remove the grids
# [ax.grid(False) for ax in axes]
for ax in axes:
if touch_text:
ax.xaxis.set_label_text(make_title(ax.xaxis.label.get_text()))
ax.yaxis.set_label_text(make_title(ax.yaxis.label.get_text()))
ax.set_title(make_title(ax.title.get_text()))
[w.set_fontsize(fsize) for w in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
ax.get_xticklabels() +
ax.get_yticklabels())]
[plt.setp(ax.get_legend().get_texts(), fontsize=fsize) for ax in axes if ax.get_legend()]
ax = gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')
def expand_the_bounds(cur_bounds):
cur_range = np.diff(cur_bounds)
return cur_bounds + cur_range * 0.02 * np.array([-1, 1])
def expand_the_top(cur_bounds):
cur_range = np.diff(cur_bounds)
return cur_bounds + cur_range * 0.02 * np.array([0, 1])
if touch_limits:
plt.autoscale(enable=True, axis='x', tight=True)
plt.autoscale(enable=True, axis='y', tight=True)
if expand_y == True:
ylim(expand_the_bounds(ylim()))
elif str(expand_y) == 'top':
ylim(expand_the_top(ylim()))
if expand_x == True:
xlim(expand_the_bounds(xlim()))
elif str(expand_x) == 'top':
xlim(expand_the_top(xlim()))
# thanks
# https://stackoverflow.com/questions/925024/how-can-i-remove-the-top-and-right-axis-in-matplotlib
if resize and (not touch_limits):
plt.rcParams["figure.figsize"] = [7, 5]
plt.rcParams['image.cmap'] = 'viridis'
if clean_legend:
plt.legend(framealpha=0.0)
# fix titles so they all appear
plt.tight_layout()
# if legend_outside:
# gca().legend(loc='center left', bbox_to_anchor=(1, 0.5), framealpha=0.0)
[docs]
def xylim(w):
"""
Set both x and y axis limits to the same range.
Parameters
----------
w : float or array-like
If scalar, sets limits to ``[-w, w]``. If a pair, sets limits
to ``[w[0], w[1]]``.
"""
if not hasattr(w, '__len__'):
w = [-w, w]
xlim(w)
ylim(w)
[docs]
def xyscale(w):
plt.xscale(w)
plt.yscale(w)
[docs]
def axis_robust(AX):
"""
Set axis limits selectively, leaving ``None`` entries unchanged.
Parameters
----------
AX : list of 4 floats or None
``[xmin, xmax, ymin, ymax]``. Use ``None`` for any limit you
want to keep at its current value.
Examples
--------
>>> axis_robust([0, None, None, None]) # only set xmin to 0
"""
AX_orig = axis()
for ii in range(4):
if AX[ii] is None:
AX[ii] = AX_orig[ii]
axis(AX)
[docs]
def xlim_robust(AX):
axis_robust([AX, *ylim()])
[docs]
def ylim_robust(AX):
axis_robust([*xlim(), AX])
[docs]
def set_axis_ticks_pctn(cur_axis = 'x'):
fmt = '%.0f%%' # Format you want the ticks, e.g. '40%'
ticker_obj = matplotlib.ticker.FormatStrFormatter(fmt)
if cur_axis.lower() == 'x':
cur_axis_h = gca().xaxis
elif cur_axis.lower() == 'y':
cur_axis_h = gca().yaxis
else:
raise Exception("You must pass either x or y, you passed: " + str(cur_axis))
cur_axis_h.set_major_formatter(ticker_obj)
[docs]
def plot_endpoints(endpoints, color='#0093e7'):
x_starts = [w[0] for w in endpoints]
x_ends = [w[1] for w in endpoints]
# Set figure size
plt.figure(figsize=(16, 8))
# Make a subplot for each day highlighting registered, filtered and
# Get timestamp for star and end of each day
x_start = min(x_starts)
x_end = max(x_ends)
# Make a subplot for each day
# plt.ylabel(day.strftime('%Y-%m-%d'), rotation=0, labelpad= 30)
# Get event for each date
for idx, event in enumerate(endpoints):
plt.fill_between(
[event[0], event[1]],
[1, 1],
color=color, label='Event')
plt.ylim((0, 1))
plt.xlim((x_start, x_end))
# Eliminate spaces between subplots
# plt.subplots_adjust(hspace=0)
# Create color patches for the legend
# Is this right???
# red_patch = plt.patches.Patch(color='#0093e7', label='Detected')
# Plot legend on the bottom subplot
# plt.legend(loc='lower left', borderaxespad=0., prop={'size': 12},
# ncol=3, fancybox=True, shadow=True, handles=[red_patch])
[docs]
def linspecer(n, color='muted'):
"""
Generate *n* distinguishable colors from a seaborn palette.
Parameters
----------
n : int
Number of colors.
color : str, optional
Seaborn palette name. Default is ``'muted'``.
Returns
-------
colors : ndarray of shape (n, 3)
RGB color values.
"""
return np.array(sns.color_palette(color, n_colors=n))
# ax.format_xdata = mdates.DateFormatter('%Y-%m-%d')
# plt.xticks(x,tick_labels)
# plt.xticks(rotation=70)
# def x
# gca().set_xscale('log')
# gca().set_yscale('log')
# pcolor on a log scale, with zero values marked
[docs]
def logpcolor(x, y, C):
C = np.log10(C)
C[C == -np.inf] = -np.max(C)
return plt.pcolor(x, y, C)
[docs]
def brighten(c, frac = .5):
"""
Brighten an RGB color by blending it toward white.
Parameters
----------
c : array-like
RGB color (values 0-1).
frac : float, optional
Blend fraction. Smaller values produce brighter colors.
Default is 0.5.
Returns
-------
color : ndarray
Brightened RGB color.
"""
return np.array(c) * frac + 1 - frac
# manually add a colorbar, definitely not the best way to go about this
# but was really tricky to do otherwise
[docs]
def example_add_colobar():
Y = np.linspace(0, 10.5, 100)
scale_func = lambda w: (w * 100 / 10.5).astype(int)
C = linspecer(101, "coolwarm")
add_colorbar(Y, C, scale_func)
[docs]
def add_colorbar(Y, C, scale_func):
# Y = np.linspace(0, .8, 100)
c = [C[w] for w in scale_func(Y)]
for ii, y in enumerate(Y):
plot([0, 1], [y, y], lw=4, color=c[ii])
xticks([])
gca().yaxis.tick_right()
[docs]
def legend_helper(fig: Union[plt.Figure, plt.Axes],
*args: Iterable[plt.Axes]) -> Dict[str, Any]:
"""
Collect legend handles and labels from multiple axes.
Parameters
----------
fig : Figure or Axes
A matplotlib Figure (collects from all its axes) or a single Axes.
*args : Axes
Additional axes to collect from (when *fig* is an Axes).
Returns
-------
dict
Dict with ``'handles'`` and ``'labels'`` lists.
"""
if isinstance(fig, plt.Figure):
handles, labels = [list(chain.from_iterable(seq)) for seq in zip(*(
ax.get_legend_handles_labels() for ax in fig.axes
))]
else:
handles, labels = [list(chain.from_iterable(seq)) for seq in zip(*(
ax.get_legend_handles_labels() for ax in chain([fig], args)
))]
return {
'handles': handles,
'labels': labels,
}
[docs]
def calc_plot_ROC(y1, y2):
"""
Plot an ROC curve from two distributions used as a binary classifier.
Parameters
----------
y1 : array-like
Scores for the negative class.
y2 : array-like
Scores for the positive class.
Returns
-------
auc : float
Area under the ROC curve.
"""
y_score = np.concatenate([y1, y2])
y_true = np.concatenate([np.zeros(len(y1)), np.ones(len(y2))])
return plot_ROC(y_true, y_score)
[docs]
def plot_ROC_hist(y_true, y_score):
y_true = np.array(y_true)
y_score = np.array(y_score)
return nhist({'true': y_score[y_true], 'false': y_score[~y_true]}, normalize='number')
[docs]
def plot_ROC(y_true, y_score, c='k'):
"""
Plot an ROC curve from true labels and predicted scores.
Parameters
----------
y_true : array-like
Binary ground-truth labels.
y_score : array-like
Predicted scores (higher = more likely positive).
c : str, optional
Line color. Default is ``'k'``.
Returns
-------
auc : float
Area under the ROC curve.
"""
I = np.logical_not(np.isnan(y_score))
y_true = y_true[I]
y_score = y_score[I]
fpr, tpr, _ = roc_curve(y_true, y_score)
cur_auc = auc(fpr, tpr)
plot(fpr, tpr, c)
plot_diag()
xylim([-.01, 1.01])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve, AUC = ' + '{:.2f}'.format(cur_auc))
nicefy()
return cur_auc
[docs]
def jitter(xx, yy, maxn=4, xscale=None):
"""
Horizontally jitter overlapping points so they are visible.
Parameters
----------
xx : array-like
X positions.
yy : array-like
Y positions (used to detect overlap).
maxn : int, optional
Maximum number of dots wide per bin. Default is 4.
xscale : float, optional
Horizontal spacing between jittered points. Default is auto.
Returns
-------
xx : ndarray
Jittered x positions.
"""
if xscale is None:
# split 50 percent of a bar width 1 by the max number of elements to fit
xscale = 0.5 / maxn
def get_bins(yy, bin_width):
"""
Get the bins if you know the bin_width
Then center it
:param yy:
:param bin_width:
:return:
"""
bins = np.arange(min(yy), max(yy) + bin_width, bin_width)
# print(bins)
# bins = bins - (max(bins) - max(yy)) / 2
return bins
def get_bin_width(yy, maxn, n_iter=5):
"""
Get the right bin width, do a binary search on histogram bin width
checking various bin widths until you've found one where at most 'n' items fall in one bin
if everything is equal then the maxn parameter doesn't matter
todo: handle the case where y values are equal more elegantly
:param yy:
:return:
"""
# if you know the bin width - what is the most that fall in one bin
get_max_per_bin = lambda yy, bin_width: np.max(np.histogram(yy, bins=get_bins(yy, bin_width))[0])
# initialize bin width options to be the most and least it could be
if np.std(yy) == 0:
bin_width = 1
else:
yy_sort = sorted(yy)
bin_width_min = np.min(np.diff(yy_sort))
bin_width_max = yy_sort[-1] - yy_sort[0]
for cnt in range(n_iter):
bin_width = (bin_width_max + bin_width_min) / 2
max_per_bin = get_max_per_bin(yy, bin_width)
if max_per_bin >= maxn:
bin_width_max = bin_width
else:
bin_width_min = bin_width
return bin_width
if len(yy) < 2:
# if there is only one point, then you don't need to jitter it
return xx
else:
xx = np.array(xx).astype(float)
bin_width = get_bin_width(yy, maxn, n_iter=5)
bins = get_bins(yy, bin_width)
idxs = np.digitize(yy, bins, right=True)
for bin_idx in range(len(bins)):
I = idxs == bin_idx
n = sum(I)
push = -(n - 1) / 2 # so it will be centered if there is only one.
xx[I] = xx[I] + xscale * (push + np.arange(n))
return xx
[docs]
def interp_plot(x, y, *args, **kargs):
"""
Plot with PCHIP interpolation for smooth curves from sparse data.
Parameters
----------
x : array-like
X data (supports pandas Timestamps).
y : array-like
Y data. NaN values are removed before interpolation.
*args, **kargs
Passed to ``plt.plot``.
Returns
-------
x_i : ndarray
Interpolated x values.
y_i : ndarray
Interpolated y values.
"""
y = np.array(y)
x = np.array(x)
if len(x) == 0:
return
date_flag = type(x[0]) == pd._libs.tslibs.timestamps.Timestamp
if date_flag:
x = np.array([w.value for w in x])
n = 400
I = np.logical_not(pd.isnull(y))
x = x[I]
y = y[I]
x_i = np.linspace(np.min(x), np.max(x), n)
f = interpolate.PchipInterpolator(x, y)
y_i = f(x_i)
if date_flag:
x_i = np.array([pd.Timestamp(w) for w in x_i])
plot(x_i, y_i, *args, **kargs)
return x_i, y_i