Source code for embedded_voting.utils.plots

# -*- coding: utf-8 -*-
"""
This file is part of Embedded Voting.
"""
import ternary
import numpy as np


def create_3d_plot(fig, position=None):
    """
    Create the background for a 3D plot on the non-negative orthant.

    Draws the three quarter-circle arcs where the unit sphere meets the
    coordinate planes (z=0, x=0 and y=0), restricted to non-negative values.

    Parameters
    ----------
    fig : matplotlib figure
        The matplotlib figure on which we are drawing.
    position : list
        The position of the subplot on which we are drawing, given as the
        three arguments of ``fig.add_subplot`` (``[n_rows, n_cols, index]``).
        Defaults to ``[1, 1, 1]``.

    Returns
    -------
    matplotlib ax
        The 3D axes, with all limits set to [0, 1] and tick labels hidden.
    """
    if position is None:
        position = [1, 1, 1]
    ax = fig.add_subplot(position[0], position[1], position[2], projection='3d')
    ax.set_xlim(0, 1)
    # The y axis is deliberately reversed (1 -> 0), presumably to get a nicer
    # default viewing orientation; kept exactly as before.
    ax.set_ylim(1, 0)
    ax.set_zlim(0, 1)
    # Quarter circle from 0 to pi/2, computed with vectorized numpy calls.
    angle = np.linspace(0, np.pi / 2, 100)
    cos_x = np.cos(angle)
    sin_x = np.sin(angle)
    zeros = np.zeros_like(angle)
    # Unit-circle arcs lying in the planes z=0, x=0 and y=0 respectively.
    ax.plot(cos_x, sin_x, zeros, color='k', alpha=0.5)
    ax.plot(zeros, cos_x, sin_x, color='k', alpha=0.5)
    ax.plot(cos_x, zeros, sin_x, color='k', alpha=0.5)
    ax.grid(True)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_zticklabels([])
    return ax
def create_ternary_plot(fig, position=None):
    """
    Create the background for a 2D ternary plot of the non-negative orthant.

    Parameters
    ----------
    fig : matplotlib figure
        The matplotlib figure on which we are drawing.
    position : list
        The position of the subplot on which we are drawing, given as the
        three arguments of ``fig.add_subplot`` (``[n_rows, n_cols, index]``).
        Defaults to ``[1, 1, 1]``.

    Returns
    -------
    python-ternary axes
        The ternary axes object (``tax``) returned by ``ternary.figure``,
        wrapping the new subplot. It has scale 1, gridlines and tick labels
        every 0.1, and the axes labelled ``$x$`` (left), ``$y$`` (right) and
        ``$z$`` (bottom).
    """
    if position is None:
        position = [1, 1, 1]
    ax = fig.add_subplot(position[0], position[1], position[2])
    # ternary.figure also returns the matplotlib figure; only tax is needed.
    _, tax = ternary.figure(ax=ax, scale=1)
    tax.boundary(linewidth=2.0)
    tax.gridlines(multiple=0.1, color="blue")
    # Tick labels "0.0", "0.1", ..., "1.0" on the left, bottom and right axes.
    tick_labels = ["%.1f" % (i / 10) for i in range(11)]
    tax.ticks(axis='lbr', linewidth=1, multiple=0.1, ticks=tick_labels)
    # Hide matplotlib's own rectangular ticks and frame; the ternary boundary
    # drawn above replaces them.
    tax.clear_matplotlib_ticks()
    tax.get_axes().axis('off')
    tax.left_axis_label("$x$", fontsize=14)
    tax.right_axis_label("$y$", fontsize=14)
    tax.bottom_axis_label("$z$", fontsize=14)
    return tax
def create_map_plot(fig, image, position=None, title=""):
    """
    Create the background for a map plot.

    Parameters
    ----------
    fig : matplotlib figure
        The matplotlib figure on which we are drawing.
    image : np.ndarray
        The 2D image to plot. It is typically of size ``map_size x map_size``,
        but non-square images are handled too. The colour scale is fixed to
        [0, 1] (``vmin=0``, ``vmax=1``).
    position : list
        The position of the subplot on which we are drawing, given as the
        three arguments of ``fig.add_subplot`` (``[n_rows, n_cols, index]``).
        Defaults to ``[1, 1, 1]``.
    title : str
        Title of the plot.

    Returns
    -------
    matplotlib ax
    """
    if position is None:
        position = [1, 1, 1]
    # x ticks span the columns and y ticks span the rows, so read both
    # dimensions (len(image) alone is only correct for square images).
    n_rows, n_cols = np.shape(image)[:2]
    # noinspection PyUnresolvedReferences
    ax = fig.add_subplot(position[0], position[1], position[2])
    # Flip rows vertically so that row 0 of ``image`` is drawn at the bottom
    # (imshow draws row 0 at the top by default).
    ax.imshow(image[::-1, ::], vmin=0, vmax=1)
    ax.set_xlabel('Correlation')
    ax.set_ylabel('Orthogonality')
    ax.set_title(title)
    # Only label both ends of each axis.
    ax.set_xticks([0, n_cols - 1])
    ax.set_yticks([0, n_rows - 1])
    ax.set_xticklabels([0, 1])
    ax.set_yticklabels([1, 0])
    return ax