Skip to content

gemdat.plots

This module contains all the plots that Gemdat can generate.

arrhenius(*, fit, show_std=True)

Plot Arrhenius fit.

Parameters:

  • fit (ArrheniusFit) –

    ArrheniusFit instance.

  • show_std (bool, default: True ) –

    If True, show error bars (from diffusivities_std) and a ±1σ fit band (from cov).

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_arrhenius.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
def arrhenius(*, fit: ArrheniusFit, show_std: bool = True) -> go.Figure:
    """Plot Arrhenius fit.

    Shows ln(D) against 1000/T as markers, the fitted line evaluated on
    200 temperatures between the lowest and highest data temperature, and
    optionally the uncertainties of both.

    Parameters
    ----------
    fit
        ArrheniusFit instance.
    show_std
        If True, show error bars (from diffusivities_std) and a ±1σ fit band
        (from cov). Each is only drawn when the corresponding attribute
        exists on `fit` and is not None.

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    T = fit.temperatures
    x = 1000.0 / T
    y = np.log(fit.diffusivities)

    error_y = None
    if show_std and getattr(fit, 'diffusivities_std', None) is not None:
        # First-order error propagation to log space: sigma(ln D) = sigma(D) / D
        sigma_ln = fit.diffusivities_std / fit.diffusivities
        error_y = dict(type='data', array=sigma_ln, visible=True)

    fig = go.Figure()
    # Use the first colour of the active template for data, fit and band
    color_hex = fig.layout['template']['layout']['colorway'][0]
    color_rgba = hex2rgba(color_hex, opacity=0.3)

    fig.add_trace(
        go.Scatter(x=x, y=y, mode='markers', name='data', error_y=error_y, line_color=color_hex)
    )

    # Fit line; the fit is linear in 1/T while the axis shows 1000/T
    t_line = np.linspace(float(T.min()), float(T.max()), 200)
    inv_t_line = 1.0 / t_line
    x_line = 1000.0 / t_line
    ln_line = fit.intercept + fit.slope * inv_t_line
    fig.add_trace(
        go.Scatter(x=x_line, y=ln_line, mode='lines', name='fit', line_color=color_hex)
    )

    # ±1σ band (in ln-space)
    if show_std and getattr(fit, 'cov', None) is not None:
        # Variance of the prediction is v @ cov @ v for every point on the line.
        # NOTE(review): assumes fit.cov is ordered (slope, intercept), matching
        # the column order of `v` -- confirm against ArrheniusFit.
        v = np.column_stack([inv_t_line, np.ones_like(t_line)])
        var = np.einsum('ij,jk,ik->i', v, fit.cov, v)
        # Clip tiny negative values caused by round-off before the sqrt
        std = np.sqrt(np.maximum(var, 0.0))
        upper = ln_line + std
        lower = ln_line - std

        fig.add_trace(
            go.Scatter(
                x=x_line,
                y=upper,
                mode='lines',
                line=dict(width=0),
                showlegend=False,
                fillcolor=color_rgba,
            )
        )
        # 'tonexty' fills the area between this trace and the previous (upper) one
        fig.add_trace(
            go.Scatter(
                x=x_line,
                y=lower,
                mode='lines',
                fill='tonexty',
                line=dict(width=0),
                name='±1σ',
                opacity=0.2,
                fillcolor=color_rgba,
            )
        )

    fig.update_layout(xaxis_title='1000/T (K⁻¹)', yaxis_title='ln(D)')
    return fig

autocorrelation(*, orientations, show_traces=True, show_shaded=True)

Plot the autocorrelation function of the unit vectors series.

Parameters:

  • orientations (Orientations) –

    The unit vector trajectories

  • show_traces (bool, default: True ) –

    If True, show traces of individual trajectories

  • show_shaded (bool, default: True ) –

    If True, show standard deviation as shaded area

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_autocorrelation.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
def autocorrelation(
    *,
    orientations: Orientations,
    show_traces: bool = True,
    show_shaded: bool = True,
) -> go.Figure:
    """Plot the autocorrelation function of the unit vectors series.

    The mean over all trajectories is drawn as a thick line; the standard
    deviation band and the individual trajectories are optional overlays.

    Parameters
    ----------
    orientations : Orientations
        The unit vector trajectories
    show_traces : bool
        If True, show traces of individual trajectories
    show_shaded : bool
        If True, show standard deviation as shaded area

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    acf = orientations.autocorrelation()
    mean_curve = acf.mean(axis=0)
    std_curve = acf.std(axis=0)

    # Time lag axis: sample index times the time step scaled by 1e12
    step_ps = orientations._time_step * 1e12
    lag_times = np.arange(mean_curve.shape[0]) * step_ps

    fig = go.Figure()

    base_color = fig.layout['template']['layout']['colorway'][0]
    band_color = hex2rgba(base_color, opacity=0.3)

    mean_trace = go.Scatter(
        x=lag_times,
        y=mean_curve,
        line_color=base_color,
        name='FFT Autocorrelation',
        mode='lines',
        line={'width': 3},
        legendgroup='autocorr',
        zorder=10,
    )
    fig.add_trace(mean_trace)

    if show_shaded:
        # Settings shared by the two invisible traces bounding the band
        band_kwargs = {
            'x': lag_times,
            'fillcolor': band_color,
            'legendgroup': 'autocorr',
            'showlegend': False,
            'zorder': 0,
        }
        upper_edge = go.Scatter(
            y=mean_curve + std_curve,
            mode='lines',
            line={'width': 0},
            **band_kwargs,
        )
        # 'tonexty' fills from this trace up to the previously added one
        lower_edge = go.Scatter(
            y=mean_curve - std_curve,
            mode='none',
            fill='tonexty',
            **band_kwargs,
        )
        fig.add_trace(upper_edge)
        fig.add_trace(lower_edge)

    if show_traces:
        for index, single in enumerate(acf):
            thin_trace = go.Scatter(
                x=lag_times,
                y=single,
                name=index,
                mode='lines',
                line={'width': 0.25},
                showlegend=False,
                zorder=5,
            )
            fig.add_trace(thin_trace)

    fig.update_layout(
        title='FFT Autocorrelation',
        xaxis_title='Time lag (ps)',
        yaxis_title='mean + std',
    )

    return fig

bond_length_distribution(*, orientations, bins=50)

Plot the bond length probability distribution.

Parameters:

  • orientations (Orientations) –

    The unit vector trajectories

  • bins (int, default: 50 ) –

    The number of bins

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_bond_length_distribution.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def bond_length_distribution(*, orientations: Orientations, bins: int = 50) -> go.Figure:
    """Plot the bond length probability distribution.

    The histogram is drawn as bars and overlaid with a skewed Gaussian
    fit evaluated by `_fit_skewnorm_to_hist`.

    Parameters
    ----------
    orientations : Orientations
        The unit vector trajectories
    bins : int, optional
        The number of bins

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    hist_df = _orientations_to_histogram(orientations, bins=bins)
    x, y = _fit_skewnorm_to_hist(hist_df, steps=100)

    fig = px.bar(
        hist_df,
        x='center',
        y='prob',
    )

    fig.add_trace(
        go.Scatter(x=x, y=y, name='Skewed Gaussian Fit', mode='lines', line={'width': 3})
    )

    # Axis titles use the real angstrom sign (U+00C5) rather than mis-decoded bytes
    fig.update_layout(
        title='Bond length probability distribution',
        xaxis_title='Bond length (Å)',
        yaxis_title='Probability density (Å<sup>-1</sup>)',
    )

    return fig

collective_jumps(*, jumps)

Plot collective jumps per jump-type combination.

Parameters:

  • jumps (Jumps) –

    Input data

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_collective_jumps.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def collective_jumps(*, jumps: Jumps) -> go.Figure:
    """Plot collective jumps per jump-type combination.

    The site-pair count matrix of the collective jumps is rendered as an
    image, with both axes labelled by the matching site-pair labels.

    Parameters
    ----------
    jumps : Jumps
        Input data

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    coll = jumps.collective()
    fig = px.imshow(coll.site_pair_count_matrix())

    tick_text = coll.site_pair_count_matrix_labels()
    tick_positions = list(range(len(tick_text)))

    def _axis_spec() -> dict:
        # Fresh dict per axis so the two axes never share one object
        return {'tickmode': 'array', 'tickvals': tick_positions, 'ticktext': tick_text}

    fig.update_layout(
        xaxis=_axis_spec(),
        yaxis=_axis_spec(),
        title='Cooperative jumps per jump-type combination',
    )

    return fig

density(volume, *, structure=None, force_lattice=None)

Create density plot from volume and structure.

Uses plotly as plotting backend.

Parameters:

  • volume (Volume) –

    Input volume

  • structure (Structure, default: None ) –

    Input structure

  • force_lattice (Lattice | None, default: None ) –

    Plot volume and structure using this lattice as a basis. Overrides the default, which is to use volume.lattice and structure.lattice where applicable.

Returns:

  • fig ( Figure ) –

    Output as plotly figure

Source code in src/gemdat/plots/plotly/_density.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def density(
    volume: Volume,
    *,
    structure: Structure | None = None,
    force_lattice: Lattice | None = None,
) -> go.Figure:
    """Create density plot from volume and structure.

    Thin wrapper around `plot_3d`, using plotly as plotting backend.

    Parameters
    ----------
    volume : Volume
        Input volume
    structure : Structure, optional
        Input structure
    force_lattice : Lattice | None
        Plot volume and structure using this lattice as a basis.
        Overrides the default, which is to use `volume.lattice`
        and `structure.lattice` where applicable.

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output as plotly figure
    """
    # Imported inside the function, as in the sibling 3D wrappers
    from ._plot3d import plot_3d

    fig = plot_3d(volume=volume, structure=structure, lattice=force_lattice)
    return fig

displacement_histogram(trajectory, n_parts=1)

Plot histogram of total displacement at final timestep.

Parameters:

  • trajectory (Trajectory) –

    Input trajectory, i.e. for the diffusing atom

  • n_parts (int, default: 1 ) –

    Plot error bars by dividing data into n parts

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_displacement_histogram.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
def displacement_histogram(trajectory: Trajectory, n_parts: int = 1) -> go.Figure:
    """Plot histogram of total displacement at final timestep.

    With ``n_parts == 1`` the counts are shown as stacked bars per element.
    Otherwise the trajectory is split into `n_parts` parts and the mean
    count per (displacement, element) is shown as grouped bars, with the
    standard deviation across parts as error bars.

    Parameters
    ----------
    trajectory : Trajectory
        Input trajectory, i.e. for the diffusing atom
    n_parts : int
        Plot error bars by dividing data into n parts

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    if n_parts == 1:
        df = _trajectory_to_dataframe(trajectory)

        fig = px.bar(df, x='Displacement', y='count', color='Element', barmode='stack')

        fig.update_layout(
            title='Displacement per element',
            xaxis_title='Displacement (Å)',
            yaxis_title='Nr. of atoms',
        )
    else:
        # Only used to report the number of timesteps per part in the title
        interval = np.linspace(0, len(trajectory) - 1, n_parts + 1)
        dfs = [_trajectory_to_dataframe(part) for part in trajectory.split(n_parts)]

        all_df = pd.concat(dfs)

        # Get the mean and standard deviation
        grouped = all_df.groupby(['Displacement', 'Element'])
        mean = grouped.mean().reset_index().rename(columns={'count': 'mean'})
        std = grouped.std().reset_index().rename(columns={'count': 'std'})
        # Joins on the columns the two frames share
        df = mean.merge(std, how='inner')

        fig = px.bar(
            df,
            x='Displacement',
            y='mean',
            color='Element',
            error_y='std',
            barmode='group',
        )

        fig.update_layout(
            title=(
                f'Displacement per element after {int(interval[1] - interval[0])} timesteps'
            ),
            xaxis_title='Displacement (Å)',
            yaxis_title='Nr. of atoms',
        )

    return fig

displacement_per_atom(*, trajectory)

Plot displacement per atom.

Parameters:

  • trajectory (Trajectory) –

    Input trajectory, i.e. for the diffusing atom

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_displacement_per_atom.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def displacement_per_atom(*, trajectory: Trajectory) -> go.Figure:
    """Plot displacement per atom.

    Adds one thin line per entry yielded by
    `trajectory.distances_from_base_position()`.

    Parameters
    ----------
    trajectory : Trajectory
        Input trajectory, i.e. for the diffusing atom

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    fig = go.Figure()

    # Iterate the distances directly; no intermediate list copy is needed
    for i, distance in enumerate(trajectory.distances_from_base_position()):
        fig.add_trace(
            go.Scatter(y=distance, name=i, mode='lines', line={'width': 1}, showlegend=False)
        )

    fig.update_layout(
        title='Displacement per atom',
        xaxis_title='Time step',
        yaxis_title='Displacement (Å)',
    )

    return fig

displacement_per_element(*, trajectory)

Plot displacement per element.

Parameters:

  • trajectory (Trajectory) –

    Input trajectory

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_displacement_per_element.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def displacement_per_element(*, trajectory: Trajectory) -> go.Figure:
    """Plot displacement per element.

    For every element the mean displacement is drawn as a thick line, with
    a filled band between mean + std and mean - std.

    Parameters
    ----------
    trajectory : Trajectory
        Input trajectory

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    displacements = _mean_displacements_per_element(trajectory)

    fig = go.Figure()

    for symbol, (mean, std) in displacements.items():
        # Visible line carrying the legend entry for this element
        fig.add_trace(
            go.Scatter(
                y=mean,
                name=symbol + ' + std',
                mode='lines',
                line={'width': 3},
                legendgroup=symbol,
            )
        )
        # Invisible upper edge of the band
        fig.add_trace(
            go.Scatter(
                y=mean + std,
                name=symbol + ' + std',
                mode='lines',
                line={'width': 0},
                legendgroup=symbol,
                showlegend=False,
            )
        )
        # Invisible lower edge; 'tonexty' fills up to the upper edge added just before
        fig.add_trace(
            go.Scatter(
                y=mean - std,
                name=symbol + ' + std',
                mode='lines',
                line={'width': 0},
                legendgroup=symbol,
                showlegend=False,
                fill='tonexty',
            )
        )

    fig.update_layout(
        title='Displacement per element',
        xaxis_title='Time step',
        yaxis_title='Displacement (Å)',
    )

    return fig

energy_along_path(path, *, structure=None, other_paths=None)

Plot energy along specified path.

Parameters:

  • path (Pathway) –

    Pathway object containing the energy along the path

  • structure (Structure, default: None ) –

    Structure object to get the site information

  • other_paths (list[Pathway], default: None ) –

    Optional list of alternative paths to plot

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_energy_along_path.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def energy_along_path(
    path: Pathway,
    *,
    structure: Structure | None = None,
    other_paths: list[Pathway] | None = None,
) -> go.Figure:
    """Plot energy along specified path.

    If `structure` is given, the path is divided into sections of
    consecutive steps that share the same nearest site. Every section is
    annotated with the site label and every other section is shaded.

    Parameters
    ----------
    path : Pathway
        Pathway object containing the energy along the path
    structure : Structure, optional
        Structure object to get the site information
    other_paths : list[Pathway], optional
        Optional list of alternative paths to plot

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    fig = go.Figure()

    fig.add_trace(
        go.Scatter(
            x=np.arange(len(path.energy)),
            y=path.energy,
            name='Optimal path',
            mode='lines',
            line={'width': 3},
        )
    )

    if structure:
        nearest_sites = path.path_over_structure(structure)

        # Record (first step index, site) every time the nearest site changes
        prev_site = nearest_sites[0]
        sections = [(0, prev_site)]

        for i, site in enumerate(nearest_sites):
            if site != prev_site:
                sections.append((i, site))

            prev_site = site

        # Sentinel closing the final section at the last step. Without it,
        # `pairwise` never yields the last section, so the site the path
        # ends on would be neither shaded nor labelled.
        sections.append((len(nearest_sites) - 1, None))

        highlight = True

        for (start, site), (stop, _) in pairwise(sections):
            if highlight:
                fig.add_vrect(
                    x0=start,
                    x1=stop,
                    line_width=0,
                    fillcolor='red',
                    opacity=0.1,
                )

            # Label centred on the section, placed at 10% of the plot height
            fig.add_annotation(
                x=(start + stop) / 2,
                y=0.1,
                yref='y domain',
                text=site.label,
                xanchor='center',
                yanchor='middle',
                showarrow=False,
                hovertext=str(site),
            )

            highlight = not highlight

    if other_paths:
        # Separate loop name so the `path` argument is not shadowed
        for idx, other_path in enumerate(other_paths):
            fig.add_trace(
                go.Scatter(
                    x=np.arange(len(other_path.energy)),
                    y=other_path.energy,
                    name=f'Alternative {idx + 1}',
                    mode='lines',
                    line={'width': 1},
                )
            )

    fig.update_layout(title='Pathway', xaxis_title='Steps', yaxis_title='Free energy (eV)')

    return fig

frequency_vs_occurence(*, trajectory)

Plot attempt frequency vs occurrence.

Parameters:

  • trajectory (Trajectory) –

    Input trajectory, i.e. for the diffusing atom

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_frequency_vs_occurence.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
def frequency_vs_occurence(*, trajectory: Trajectory) -> go.Figure:
    """Plot attempt frequency vs occurrence.

    Draws the smoothed one-sided amplitude spectrum of the speed, summed
    over all rows, and marks the attempt frequency (solid red line) and
    attempt frequency ± its standard deviation (dashed red lines).

    Parameters
    ----------
    trajectory : Trajectory
        Input trajectory, i.e. for the diffusing atom

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    metrics = trajectory.metrics()
    speed = metrics.speed()

    n_samples = speed.shape[1]
    n_kept = n_samples // 2 + 1

    # Normalised FFT magnitude, keeping only the first half of the spectrum
    spectrum = np.abs(np.fft.fft(speed) / n_samples)[:, :n_kept]

    fig = go.Figure()

    frequencies = trajectory.sampling_frequency * np.arange(n_kept) / n_samples

    summed = np.sum(spectrum, axis=0)
    # Moving average over a 51-sample window
    window = 51
    smoothed = np.convolve(summed, np.ones(window), 'same') / window

    spectrum_trace = go.Scatter(
        y=smoothed,
        x=frequencies,
        mode='lines',
        line={'width': 3, 'color': 'blue'},
        showlegend=False,
    )
    fig.add_trace(spectrum_trace)

    # The y-range is taken from the unsmoothed sum
    y_max = np.max(summed)

    attempt_freq, attempt_freq_std = map(float, metrics.attempt_frequency())

    if attempt_freq:
        fig.add_vline(x=attempt_freq, line={'width': 2, 'color': 'red'})

        if attempt_freq_std:
            for offset in (attempt_freq_std, -attempt_freq_std):
                fig.add_vline(
                    x=attempt_freq + offset,
                    line={'width': 2, 'color': 'red', 'dash': 'dash'},
                )

    fig.update_layout(
        title='Frequency vs Occurence',
        xaxis_title='Frequency (Hz)',
        yaxis_title='Occurrence (a.u.)',
        xaxis_range=[-0.1e13, 2.5e13],
        yaxis_range=[0, y_max],
        width=600,
        height=500,
    )

    return fig

jumps_3d(*, jumps)

Plot jumps in 3D.

Parameters:

  • jumps (Jumps) –

    Input data

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_jumps_3d.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def jumps_3d(*, jumps: Jumps) -> go.Figure:
    """Plot jumps in 3D.

    Thin wrapper around `plot_3d`, passing `jumps.sites` as the structure.

    Parameters
    ----------
    jumps : Jumps
        Input data

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    # Imported inside the function, as in the sibling 3D wrappers
    from ._plot3d import plot_3d

    fig = plot_3d(jumps=jumps, structure=jumps.sites)
    return fig

jumps_3d_animation(*, jumps, t_start, t_stop, decay=0.05, skip=5, interval=20)

Plot jumps in 3D as an animation over time.

Parameters:

  • jumps (Jumps) –

    Input data

  • t_start (int) –

    Time step to start animation (relative to equilibration time)

  • t_stop (int) –

    Time step to stop animation (relative to equilibration time)

  • decay (float, default: 0.05 ) –

    Controls the decay of the line width (higher = faster decay)

  • skip (int, default: 5 ) –

    Skip frames (increase for faster, but less accurate rendering)

  • interval (int, default: 20 ) –

    Delay between frames in milliseconds.

Returns:

  • anim ( FuncAnimation ) –

    Output animation

Source code in src/gemdat/plots/matplotlib/_jumps_3d_animation.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
def jumps_3d_animation(
    *,
    jumps: Jumps,
    t_start: int,
    t_stop: int,
    decay: float = 0.05,
    skip: int = 5,
    interval: int = 20,
) -> animation.FuncAnimation:
    """Plot jumps in 3D as an animation over time.

    All jumps are first drawn as zero-width lines between their start and
    destination sites. On every frame the lines of the jumps that have
    started are made visible, with a width that decays with the time
    since the jump, and the two sites involved are recoloured.

    Parameters
    ----------
    jumps : Jumps
        Input data
    t_start : int
        Time step to start animation (relative to equilibration time)
    t_stop : int
        Time step to stop animation (relative to equilibration time)
    decay : float, optional
        Controls the decay of the line width (higher = faster decay)
    skip : int, optional
        Skip frames (increase for faster, but less accurate rendering)
    interval : int, optional
        Delay between frames in milliseconds.

    Returns
    -------
    anim : matplotlib.animation.FuncAnimation
        Output animation
    """
    minwidth = 0.2
    maxwidth = 5.0

    trajectory = jumps.trajectory

    class LabelItems:
        """Minimal container exposing `.items()` for `plotter.plot_labels`."""

        def __init__(self, labels, coords):
            self.labels = labels
            self.coords = coords

        def items(self):
            yield from zip(self.labels, self.coords)

    coords = jumps.sites.frac_coords
    lattice = trajectory.get_lattice()

    color_from = colormaps['Set1'].colors  # type: ignore
    color_to = colormaps['Pastel1'].colors  # type: ignore

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')

    # Labels for the origin and the three lattice vectors, placed just
    # outside the unit cell (fractional coordinates)
    xyz_labels = LabelItems(
        'OABC',
        [
            [-0.1, -0.1, -0.1],
            [1.1, -0.1, -0.1],
            [-0.1, 1.1, -0.1],
            [-0.1, -0.1, 1.1],
        ],
    )

    plotter.plot_lattice_vectors(lattice, ax=ax, linewidth=1)

    plotter.plot_labels(xyz_labels, lattice=lattice, ax=ax, color='green', size=12)

    # `points` is indexed by site number below, so the axes must not hold
    # any collection before the sites are drawn.
    # NOTE(review): assumes plot_points adds one collection per site -- confirm.
    assert len(ax.collections) == 0
    plotter.plot_points(coords, lattice=lattice, ax=ax, s=50, color='white', edgecolor='black')
    points = ax.collections

    # Sorted and re-indexed so that row label i matches the i-th jump line
    events = jumps.data.sort_values('start time', ignore_index=True)

    for _, event in events.iterrows():
        site_i = event['start site']
        site_j = event['destination site']

        coord_i = coords[site_i]
        coord_j = coords[site_j]

        # Drawn invisible; `update` sets the width once the jump has started
        lw = 0

        # Draw towards the periodic image of the destination returned by the lattice
        _, image = lattice.get_distance_and_image(coord_i, coord_j)

        line = [coord_i, coord_j + image]

        plotter.plot_path(line, lattice=lattice, ax=ax, color='red', linewidth=lw)

    # Skip the first three lines on the axes.
    # NOTE(review): presumably the lattice vectors drawn above -- confirm.
    lines = ax.lines[3:]

    ax.set(
        title='Jumps between sites',
        xlabel="x' (Å)",
        ylabel="y' (Å)",
        zlabel="z' (Å)",
    )

    # NOTE(review): the original comment here read "only auto is supported";
    # confirm 'equal' works for 3D axes on the supported matplotlib versions.
    ax.set_aspect('equal')

    def update(frame_no):
        """Restyle jump lines and site colours for animation frame `frame_no`."""
        t_frame = t_start + (frame_no * skip)

        # Start time of the last event inspected: the first jump after
        # `t_frame`, or the final jump when all have started. Initialised
        # so that an empty event table no longer raises UnboundLocalError.
        next_jump = None

        for i, event in events.iterrows():
            next_jump = event['start time']

            # Events are sorted by start time, so the rest are in the future
            if next_jump > t_frame:
                break

            lw = max(maxwidth - decay * (t_frame - next_jump), minwidth)

            line = lines[i]
            line.set_color('red')
            line.set_linewidth(lw)

            # Colours are chosen per atom index, cycling through the colormap
            points[event['start site']].set_facecolor(
                color_from[event['atom index'] % len(color_from)]
            )
            points[event['destination site']].set_facecolor(
                color_to[event['atom index'] % len(color_to)]
            )

        ax.set_title(f'T: {t_frame} | Next jump: {next_jump}')

    n_frames = int((t_stop - t_start) / skip)

    return animation.FuncAnimation(
        fig=fig, func=update, frames=n_frames, interval=interval, repeat=False
    )

jumps_vs_distance(*, jumps, jump_res=0.1, n_parts=1)

Plot jumps vs distance histogram.

Parameters:

  • jumps (Jumps) –

    Input jumps data

  • jump_res (float, default: 0.1 ) –

    Resolution of the bins in Angstrom

  • n_parts (int, default: 1 ) –

    Number of parts for error analysis

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_jumps_vs_distance.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def jumps_vs_distance(
    *,
    jumps: Jumps,
    jump_res: float = 0.1,
    n_parts: int = 1,
) -> go.Figure:
    """Plot jumps vs distance histogram.

    Error bars (from the 'std' column) are only shown when the data is
    split into more than one part.

    Parameters
    ----------
    jumps : Jumps
        Input jumps data
    jump_res : float, optional
        Resolution of the bins in Angstrom
    n_parts : int
        Number of parts for error analysis

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    df = _jumps_vs_distance(jumps=jumps, resolution=jump_res, n_parts=n_parts)

    # With a single part no error bars are drawn
    error_column = None if n_parts == 1 else 'std'
    fig = px.bar(df, x='Displacement', y='mean', error_y=error_column, barmode='stack')

    fig.update_layout(
        title='Jumps vs. Distance',
        xaxis_title='Distance (Å)',
        yaxis_title='Number of jumps',
    )

    return fig

jumps_vs_time(*, jumps, bins=8, n_parts=1)

Plot jumps vs time histogram.

Parameters:

  • jumps (Jumps) –

    Input jumps data

  • bins (int, default: 8 ) –

    Number of bins

  • n_parts (int, default: 1 ) –

    Number of parts for error analysis

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_jumps_vs_time.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def jumps_vs_time(*, jumps: Jumps, bins: int = 8, n_parts: int = 1) -> go.Figure:
    """Plot jumps vs time histogram.

    The jump start times of every part are binned into `bins` equal-width
    bins spanning the length of one part. The bars show the mean count per
    bin over all parts, with the standard deviation as error bars when
    ``n_parts > 1``.

    Parameters
    ----------
    jumps : Jumps
        Input jumps data
    bins : int, optional
        Number of bins
    n_parts : int
        Number of parts for error analysis

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    # Length of one part in time steps; upper limit of the histogram range
    maxlen = len(jumps.trajectory) / n_parts

    counts = [
        np.histogram(jumps_part.data['start time'], bins=bins, range=(0.0, maxlen))[0]
        for jumps_part in jumps.split(n_parts=n_parts)
    ]

    # One row per part, one column per bin
    df = pd.DataFrame(data=counts)

    # Place the bars at the centres of the bins np.histogram actually uses:
    # `bins` equal-width bins over (0, maxlen). The previous width of
    # `maxlen / bins + 1` did not match those bins and shifted the bars.
    edges = np.linspace(0.0, maxlen, bins + 1)
    centers = (edges[:-1] + edges[1:]) / 2

    mean = [df[col].mean() for col in df.columns]
    std = [df[col].std() for col in df.columns]

    df = pd.DataFrame(data=zip(centers, mean, std), columns=['time', 'count', 'std'])

    if n_parts > 1:
        fig = px.bar(df, x='time', y='count', error_y='std')
    else:
        fig = px.bar(df, x='time', y='count')

    fig.update_layout(
        bargap=0.2,
        title='Jumps vs. time',
        xaxis_title='Time (steps)',
        yaxis_title='Number of jumps',
    )

    return fig

msd_kinisi(trajectory, specie, *, diffusion_analyzer=None, step_skip=1, dt=None, dimension='xyz', distance_unit='angstrom', specie_indices=None, masses=None, progress=True, save_cache=True, return_cache=True, show_shaded=True)

Plot mean-squared displacement (MSD) with uncertainties from a kinisi DiffusionAnalyzer.

Parameters:

  • trajectory (Trajectory) –

    GEMDAT trajectory

  • specie (str) –

    Specie to calculate diffusivity for, e.g. "Li".

  • diffusion_analyzer (Optional['DiffusionAnalyzer'], default: None ) –

    A kinisi DiffusionAnalyzer instance.

  • step_skip (int, default: 1 ) –

    Number of MD integrator time steps between stored frames.

  • dt ('sc.Variable | None', default: None ) –

    Time intervals to calculate displacements over. Optional; if None, kinisi defaults to a regular grid from the smallest interval (time_step * step_skip) to the full trajectory length.

  • dimension (str, default: 'xyz' ) –

    Subset of "xyz" indicating displacement axes of interest.

  • distance_unit (str, default: 'angstrom' ) –

    Unit of distance in the input structures, as a string understood by scipp.Unit(...) (default: "angstrom").

  • specie_indices ('sc.Variable | None', default: None ) –

    Indices of the specie to calculate the diffusivity for. Optional; if None, kinisi selects indices based on specie.

  • masses ('sc.Variable | None', default: None ) –

    Masses for centre-of-mass handling. Optional.

  • progress (bool, default: True ) –

    Show progress bars during parsing and MSD evaluation.

  • save_cache (bool, default: True ) –

    Cache the populated analyzer on this trajectory instance.

  • return_cache (bool, default: True ) –

    Use cached data.

  • show_shaded (bool, default: True ) –

    If True, plot ±1σ uncertainties as a shaded region.

Returns:

  • fig ( Figure ) –

    Output figure.

Source code in src/gemdat/plots/plotly/_msd_kinisi.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
def msd_kinisi(
    trajectory: Trajectory,
    specie: str,
    *,
    diffusion_analyzer: Optional['DiffusionAnalyzer'] = None,
    step_skip: int = 1,
    dt: 'sc.Variable | None' = None,
    dimension: str = 'xyz',
    distance_unit: str = 'angstrom',
    specie_indices: 'sc.Variable | None' = None,
    masses: 'sc.Variable | None' = None,
    progress: bool = True,
    save_cache: bool = True,
    return_cache: bool = True,
    show_shaded: bool = True,
) -> go.Figure:
    """Plot mean-squared displacement (MSD) with uncertainties from a kinisi
    DiffusionAnalyzer.

    If ``diffusion_analyzer`` is given it is plotted directly and all the
    analyzer-construction arguments (``step_skip`` ... ``return_cache``) are
    ignored. Otherwise they are forwarded unchanged to
    ``trajectory.to_kinisi_diffusion_analyzer``.

    Parameters
    ----------
    trajectory
        GEMDAT trajectory
    specie
        Specie to calculate diffusivity for, e.g. ``"Li"``. Also used for the
        legend entry and legend group.
    diffusion_analyzer
        A kinisi DiffusionAnalyzer instance.
    step_skip
        Number of MD integrator time steps between stored frames.
    dt
        Time intervals to calculate displacements over. Optional; if ``None``,
        kinisi defaults to a regular grid from the smallest interval
        (``time_step * step_skip``) to the full trajectory length.
    dimension
        Subset of ``"xyz"`` indicating displacement axes of interest.
    distance_unit
        Unit of distance in the input structures, as a string understood by
        ``scipp.Unit(...)`` (default: ``"angstrom"``).
    specie_indices
        Indices of the specie to calculate the diffusivity for. Optional; if ``None``,
        kinisi selects indices based on ``specie``.
    masses
        Masses for centre-of-mass handling. Optional.
    progress
        Show progress bars during parsing and MSD evaluation.
    save_cache
        Cache the populated analyzer on this trajectory instance.
    return_cache
        Use cached data.
    show_shaded
        If True, plot ±1σ uncertainties as a shaded region. Has no effect
        when the MSD carries no variances.

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure.
    """
    # Compare against None explicitly so that a caller-supplied analyzer is
    # always honoured, regardless of how the object evaluates as a boolean.
    if diffusion_analyzer is not None:
        analyzer = diffusion_analyzer
    else:
        analyzer = trajectory.to_kinisi_diffusion_analyzer(
            specie=specie,
            step_skip=step_skip,
            dt=dt,
            dimension=dimension,
            distance_unit=distance_unit,
            specie_indices=specie_indices,
            masses=masses,
            progress=progress,
            save_cache=save_cache,
            return_cache=return_cache,
        )

    # Use a separate name for the analyzer's time lags instead of rebinding
    # the `dt` argument.
    time_lags = analyzer.dt
    msd = analyzer.msd

    x = np.asarray(time_lags.values)
    y = np.asarray(msd.values)

    # Standard deviation from the variances, when the MSD provides them.
    variances = msd.variances
    yerr = None if variances is None else np.sqrt(np.asarray(variances))

    fig = go.Figure()

    color_hex = fig.layout['template']['layout']['colorway'][0]
    color_rgba = hex2rgba(color_hex, opacity=0.3)

    name = f'{specie} MSD'

    if (yerr is not None) and show_shaded:
        name = f'{specie} MSD ± 1σ'
        # Invisible upper bound; the next trace fills back up to this one.
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y + yerr,
                fillcolor=color_rgba,
                mode='lines',
                line={'width': 0},
                legendgroup=specie,
                showlegend=False,
                zorder=0,
            )
        )
        # Lower bound, filled towards the previous trace ('tonexty').
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y - yerr,
                fillcolor=color_rgba,
                mode='none',
                legendgroup=specie,
                fill='tonexty',
                showlegend=False,
                zorder=0,
            )
        )

    # MSD curve itself, drawn above the band (zorder=1).
    fig.add_trace(
        go.Scatter(
            x=x,
            y=y,
            name=name,
            mode='lines',
            line={'width': 3, 'color': color_hex},
            legendgroup=specie,
            zorder=1,
        )
    )

    fig.update_layout(
        showlegend=True,
        title='Mean squared displacement',
        xaxis_title=f'Time lag ({time_lags.unit})',
        yaxis_title=f'MSD ({msd.unit})',
    )

    return fig

msd_per_element(*, trajectory)

Plot mean squared displacement per element.

Parameters:

  • trajectory (Trajectory) –

    Input trajectory

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_msd_per_element.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def msd_per_element(*, trajectory: Trajectory) -> go.Figure:
    """Plot mean squared displacement per element.

    For every distinct species in the trajectory, the MSD array returned by
    ``mean_squared_displacement()`` is averaged over its first axis and drawn
    as a line, with a shaded band of plus/minus one standard deviation
    (taken over the same axis).

    Parameters
    ----------
    trajectory : Trajectory
        Input trajectory

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    fig = go.Figure()

    colorway = fig.layout['template']['layout']['colorway']
    time_step_ps = trajectory.time_step_ps

    # De-duplicate while keeping first-seen order, so that colours and legend
    # order are reproducible between runs (a plain set has no stable order).
    species = list(dict.fromkeys(trajectory.species))

    for i, sp in enumerate(species):
        assert isinstance(sp, (Species, Element)), f'got {type(sp)}'

        # Cycle through the palette when there are more species than colours.
        color_hex = colorway[i % len(colorway)]
        color_rgba = hex2rgba(color_hex, opacity=0.3)

        traj = trajectory.filter(sp.symbol)

        msd = traj.mean_squared_displacement()
        msd_mean = np.mean(msd, axis=0)
        msd_std = np.std(msd, axis=0)
        t_values = np.arange(len(msd_mean)) * time_step_ps

        # Invisible upper bound; the next trace fills back up to this one.
        fig.add_trace(
            go.Scatter(
                x=t_values,
                y=msd_mean + msd_std,
                fillcolor=color_rgba,
                mode='lines',
                line={'width': 0},
                legendgroup=sp.symbol,
                showlegend=False,
                zorder=0,
            )
        )
        # Lower bound, filled towards the previous trace ('tonexty').
        fig.add_trace(
            go.Scatter(
                x=t_values,
                y=msd_mean - msd_std,
                fillcolor=color_rgba,
                mode='none',
                legendgroup=sp.symbol,
                showlegend=False,
                fill='tonexty',
                zorder=0,
            )
        )

        # Mean curve, drawn above the band (zorder=1).
        fig.add_trace(
            go.Scatter(
                x=t_values,
                y=msd_mean,
                name=f'{sp.symbol} mean+std',
                line_color=color_hex,
                mode='lines',
                line={'width': 3},
                legendgroup=sp.symbol,
                zorder=1,
            )
        )

    fig.update_layout(
        title='Mean squared displacement per element',
        xaxis_title='Time lag (ps)',
        yaxis_title=r'MSD (Å<sup>2</sup>)',
    )

    return fig

plot_3d(*, volume=None, structure=None, paths=None, jumps=None, lattice=None, title='3D plot')

Plot 3d.

Source code in src/gemdat/plots/plotly/_plot3d.py
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
def plot_3d(
    *,
    volume: Volume | None = None,
    structure: Structure | None = None,
    paths: Pathway | list[Pathway] | None = None,
    jumps: Jumps | None = None,
    lattice: Lattice | None = None,
    title: str = '3D plot',
) -> go.Figure:
    """Plot any combination of volume, structure, paths and jumps in 3D.

    Parameters
    ----------
    volume : Volume | None
        Volume to plot, if given.
    structure : Structure | None
        Structure to plot, if given.
    paths : Pathway | list[Pathway] | None
        Pathway(s) to plot, if given.
    jumps : Jumps | None
        Jumps to plot, if given.
    lattice : Lattice | None
        Lattice used for the lattice vectors and the layout. If not given, it
        is taken from ``volume``, ``structure`` or ``jumps`` (first one
        available, in that order).
    title : str
        Title of the figure.

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure

    Raises
    ------
    ValueError
        If no lattice is passed and none can be derived from ``volume``,
        ``structure`` or ``jumps``.
    """
    fig = go.Figure()

    # Only derive the lattice when the caller did not supply one. An
    # explicitly passed lattice is used as-is.
    if lattice is None:
        if volume:
            lattice = volume.lattice
        elif structure:
            lattice = structure.lattice
        elif jumps:
            lattice = jumps.trajectory.get_lattice()
        else:
            raise ValueError(
                'Lattice cannot be determined from volume, structure, or jumps object.'
            )

    plot_lattice_vectors(lattice, fig=fig)

    if volume:
        plot_volume(volume, lattice=lattice, fig=fig)

    if structure:
        plot_structure(structure=structure, lattice=lattice, fig=fig)

    if paths:
        plot_paths(paths=paths, lattice=lattice, fig=fig)

    if jumps:
        plot_jumps(jumps=jumps, fig=fig)

    update_layout(title=title, lattice=lattice, fig=fig)

    return fig

plot_3d_points(points, labels, *, fig, point_size=5, colors=None)

Plot points using plotly.

Parameters:

  • points (ndarray) –

    Input points

  • labels (Sequence) –

    Labels for points. Length must match points.

  • fig (Figure) –

    Plotly figure to add traces to

  • point_size (int, default: 5 ) –

    Size of the points

  • colors (Optional[dict[str, str]], default: None ) –

    Mapping of colors for the each label. See the following link for a list of accepted colours: https://developer.mozilla.org/en-US/docs/Web/CSS/named-color

Source code in src/gemdat/plots/plotly/_plot3d.py
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
def plot_3d_points(
    points: np.ndarray,
    labels: Sequence,
    *,
    fig: go.Figure,
    point_size: int = 5,
    colors: Optional[dict[str, str]] = None,
) -> None:
    """Plot points using plotly.

    Each point is added to ``fig`` as its own ``Scatter3d`` marker trace,
    named after its label and hidden from the legend.

    Parameters
    ----------
    points : np.ndarray
        Input points; each entry is unpacked as ``(x, y, z)``.
    labels : Sequence
        Labels for points. Length must match points.
    fig : plotly.graph_objects.Figure
        Plotly figure to add traces to
    point_size : int, optional
        Size of the points
    colors: dict, optional
        Mapping of colors for the each label. If not given (or empty), colors
        are sampled evenly from the 'rainbow' colorscale in label order.
        See the following link for a list of accepted colours:
        https://developer.mozilla.org/en-US/docs/Web/CSS/named-color
    """
    assert len(points) == len(labels)

    if not colors:
        # With a single label `len(labels) - 1` is zero; clamp the divisor so
        # that one point maps to the start of the colorscale instead of
        # raising ZeroDivisionError.
        span = max(len(labels) - 1, 1)
        colors = {
            label: px.colors.sample_colorscale('rainbow', [i / span])
            for i, label in enumerate(labels)
        }

    for (x, y, z), label in zip(points, labels):
        fig.add_trace(
            go.Scatter3d(
                x=[x],
                y=[y],
                z=[z],
                mode='markers',
                name=label,
                marker={
                    'size': point_size,
                    'color': colors[label],
                    'line': {'width': 2.5},
                },
                showlegend=False,
            )
        )

polar(*, orientations, shape=(90, 360), normalize_histo=True)

Plot a polar projection of a spherical function.

This function uses the transformed trajectory.

Parameters:

  • orientations (Orientations) –

    The unit vector trajectories

  • shape (tuple, default: (90, 360) ) –

    The shape of the spherical sector in which the trajectory is plotted

  • normalize_histo (bool, default: True ) –

    If True, normalize the histogram by the area of the bins, by default True

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/matplotlib/_polar.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def polar(
    *,
    orientations: Orientations,
    shape: tuple[int, int] = (90, 360),
    normalize_histo: bool = True,
) -> matplotlib.figure.Figure:
    """Plot a polar projection of a spherical function.

    This function uses the transformed trajectory.

    The elevation/azimuth angles are binned into a 2D histogram, which is
    drawn as two side-by-side polar contour plots: one with the radial axis
    running over θ and one over 180 - θ, both clipped to a radius of 90.

    Parameters
    ----------
    orientations : Orientations
        The unit vector trajectories
    shape : tuple
        The shape of the spherical sector in which the trajectory is plotted.
        Used as the number of (elevation, azimuth) histogram bins.
    normalize_histo : bool, optional
        If True, normalize the histogram by the area of the bins, by default True

    Returns
    -------
    fig : matplotlib.figure.Figure
        Output figure
    """
    from gemdat.orientations import calculate_spherical_areas

    az, el, _ = orientations.vectors_spherical.T
    az = az.flatten()
    el = el.flatten()

    hist, *_ = np.histogram2d(el, az, shape)

    if normalize_histo:
        areas = calculate_spherical_areas(shape)
        hist = hist / areas
        # Drop the bins at the poles where normalization is not possible
        hist = hist[1:-1, :]

    axis_theta, axis_phi = hist.shape

    # Polar axes take the angular coordinate in radians.
    phi = np.radians(np.linspace(0, 360, axis_phi))
    theta = np.linspace(0, 180, axis_theta)

    theta, phi = np.meshgrid(theta, phi)

    fig, (ax1, ax2) = plt.subplots(1, 2, subplot_kw=dict(projection='polar'))

    cs1 = ax1.contourf(phi, theta, hist.T)
    ax1.set_title('θ < 90°')
    ax1.set_rmax(90)
    ax1.set_yticklabels([])

    # Mirror the radial coordinate so the other hemisphere fits in r <= 90.
    ax2.contourf(phi, 180 - theta, hist.T)
    ax2.set_title('θ > 90°')
    ax2.set_rmax(90)
    ax2.set_yticklabels([])

    # NOTE(review): the shared colorbar is built from the first contour set
    # only - confirm both panels are meant to share its levels.
    fig.colorbar(cs1, ax=[ax1, ax2], orientation='horizontal', label='Areal Probability')

    # Increase horizontal spacing and leave room for the colorbar. Adjust
    # this figure explicitly rather than pyplot's "current" figure.
    fig.subplots_adjust(wspace=0.5, bottom=0.35)

    return fig

radial_distribution(rdfs)

Plot radial distribution function.

Parameters:

  • rdfs (Iterable[RDFData]) –

    List of RDF data to plot

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_radial_distribution.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def radial_distribution(rdfs: Iterable[RDFData]) -> go.Figure:
    """Plot radial distribution function.

    One line trace is added per RDF (``rdf.x`` against ``rdf.y``, named
    ``rdf.label``). The distinct non-empty ``rdf.state`` values are appended
    to the title in order of first appearance.

    Parameters
    ----------
    rdfs : Iterable[RDFData]
        List of RDF data to plot

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    # The data is traversed twice (traces, then states). Materialise it once
    # so that one-shot iterables such as generators are not exhausted before
    # the title is built.
    rdfs = list(rdfs)

    fig = go.Figure()

    for rdf in rdfs:
        fig.add_trace(
            go.Scatter(
                x=rdf.x,
                y=rdf.y,
                name=rdf.label,
                mode='lines',
            )
        )

    # dict.fromkeys de-duplicates while keeping first-seen order, which makes
    # the title deterministic (a set has no stable order).
    states = ', '.join(dict.fromkeys(rdf.state for rdf in rdfs if rdf.state))
    state_suffix = f' ({states})' if states else ''

    fig.update_layout(
        title=f'Radial distribution function{state_suffix}',
        xaxis_title='Distance (Å)',
        yaxis_title='g(r)',
    )

    return fig

rectilinear(*, orientations, shape=(90, 360), normalize_histo=True)

Plot a rectilinear projection of a spherical function.

This function uses the transformed trajectory.

Parameters:

  • orientations (Orientations) –

    The unit vector trajectories

  • shape (tuple, default: (90, 360) ) –

    The shape of the spherical sector in which the trajectory is plotted

  • normalize_histo (bool, default: True ) –

    If True, normalize the histogram by the area of the bins, by default True

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_rectilinear.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def rectilinear(
    *,
    orientations: Orientations,
    shape: tuple[int, int] = (90, 360),
    normalize_histo: bool = True,
) -> go.Figure:
    """Plot a rectilinear projection of a spherical function.

    This function uses the transformed trajectory.

    The elevation/azimuth angles are binned into a 2D histogram, which is
    drawn as a single contour plot with the azimuth on the x-axis and the
    elevation on the y-axis.

    Parameters
    ----------
    orientations : Orientations
        The unit vector trajectories
    shape : tuple
        The shape of the spherical sector in which the trajectory is plotted.
        Used as the number of (elevation, azimuth) histogram bins.
    normalize_histo : bool, optional
        If True, normalize the histogram by the area of the bins, by default True

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    from gemdat.orientations import calculate_spherical_areas

    az, el, _ = orientations.vectors_spherical.T
    az = az.flatten()
    el = el.flatten()

    hist, *_ = np.histogram2d(el, az, shape)

    if normalize_histo:
        areas = calculate_spherical_areas(shape)
        hist = hist / areas
        # Drop the bins at the poles where normalization is not possible
        hist = hist[1:-1, :]

    # Axis coordinates are sized from the (possibly trimmed) histogram.
    axis_theta, axis_phi = hist.shape

    phi = np.linspace(0, 360, axis_phi)
    theta = np.linspace(0, 180, axis_theta)

    fig = go.Figure(
        data=go.Contour(
            x=phi,
            y=theta,
            z=hist,
            colorbar={
                'title': 'Areal probability',
                'title_side': 'right',
            },
        )
    )

    fig.update_layout(
        title='Rectilinear plot',
        xaxis_title='Azimuthal angle φ (°)',
        yaxis_title='Elevation θ (°)',
    )

    return fig

shape(shape, bins=50, sites=None, cmap=None)

Plot site cluster shapes.

Parameters:

  • shape (ShapeData) –

    Shape data to plot

  • bins (int | Sequence[float], default: 50 ) –

    Number of bins or sequence of bin edges. See hist() for more info.

  • sites (Collection[PeriodicSite] | None, default: None ) –

    Plot these sites on the shape density

  • cmap (Colormap | str, default: None ) –

    Colormap for the 2D histogram

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/matplotlib/_shape.py
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
def shape(
    shape: ShapeData,
    bins: int | Sequence[float] = 50,
    sites: Collection[PeriodicSite] | None = None,
    cmap: Colormap | str | None = None,
) -> matplotlib.figure.Figure:
    """Plot site cluster shapes.

    Draws a 2x3 grid: the top row holds 2D histograms of the shape
    coordinates projected on the XY, YZ and ZX planes, the bottom row holds
    the 1D density along the x-axis of the plot above it. A dashed red
    circle with the mean distance as radius and a red dot at (0, 0) are
    overlaid on every projection.

    Parameters
    ----------
    shape : ShapeData
        Shape data to plot
    bins : int | Sequence[float]
        Number of bins or sequence of bin edges.
        See [hist()][matplotlib.pyplot.hist] for more info.
    sites : Collection[PeriodicSite] | None
        Plot these sites on the shape density. Site coordinates are shifted
        by ``shape.origin`` and annotated with the site label.
    cmap : Colormap | str, optional
        Colormap for the 2D histogram. If None, the colormap returned by
        ``density_matching_cmap_mpl()`` is used.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Output figure
    """
    if cmap is None:
        cmap = density_matching_cmap_mpl()

    x_labels = ('X / Å', 'Y / Å', 'Z / Å')
    y_labels = ('Y / Å', 'Z / Å', 'X / Å')

    fig, axes = plt.subplots(
        nrows=2,
        ncols=3,
        sharex=True,
        figsize=(12, 5),
        gridspec_kw={'height_ratios': (4, 1)},
    )

    distances = shape.distances()

    # The mean distance is used both in the title and as the circle radius.
    mean_dist = np.mean(distances)
    std = np.std(distances)
    title = f'{shape.name}: R = {mean_dist:.3f}$~Å$, std = {std:.3f}'

    # Group the site coordinates by label, then express them relative to the
    # shape origin.
    coords_by_label = defaultdict(list)
    for site in sites or ():
        coords_by_label[site.label].append(site.coords)
    vector_dict = {
        label: np.array(site_coords) - shape.origin
        for label, site_coords in coords_by_label.items()
    }

    coords = shape.coords

    axes[0, 1].set_title(title)  # type: ignore

    # One column per projection plane: (x, y), (y, z), (z, x).
    for col, (i, j) in enumerate(((0, 1), (1, 2), (2, 0))):
        ax0 = axes[0, col]  # type: ignore
        ax1 = axes[1, col]  # type: ignore

        x_coords = coords[:, i]
        y_coords = coords[:, j]

        ax0.hist2d(x=x_coords, y=y_coords, bins=bins, cmap=cmap)
        ax0.set_ylabel(y_labels[col])

        circle = plt.Circle(
            (0, 0),
            mean_dist,
            color='r',
            linestyle='--',
            fill=False,
        )
        ax0.add_patch(circle)

        ax0.scatter(x=[0], y=[0], color='r', marker='.')

        for label, vects in vector_dict.items():
            x_vs = vects[:, i]
            y_vs = vects[:, j]

            for x, y in zip(x_vs, y_vs):
                ax0.text(x, y, s=label, color='r')

            ax0.scatter(x=x_vs, y=y_vs, color='r', marker='.', label=label)

        ax0.axis('equal')

        ax1.hist(x=x_coords, bins=bins, density=True)
        ax1.set_xlabel(x_labels[col])
        ax1.set_ylabel('density')

    fig.tight_layout()

    return fig

vibrational_amplitudes(*, trajectory, bins=50, n_parts=1)

Plot histogram of vibrational amplitudes with fitted Gaussian.

Parameters:

  • trajectory (Trajectory) –

    Input trajectory, i.e. for the diffusing atom

  • bins (int, default: 50 ) –

    Number of bins for the histogram

  • n_parts (int, default: 1 ) –

    Number of parts for error analysis

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_vibrational_amplitudes.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def vibrational_amplitudes(
    *, trajectory: Trajectory, bins: int = 50, n_parts: int = 1
) -> go.Figure:
    """Plot histogram of vibrational amplitudes with fitted Gaussian.

    Parameters
    ----------
    trajectory : Trajectory
        Input trajectory, i.e. for the diffusing atom
    bins : int
        Number of bins, forwarded to the histogram helper.
    n_parts : int
        Number of parts for error analysis. The trajectory is split into
        this many parts; error bars (from the 'std' column) are only drawn
        when ``n_parts`` is not 1.

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    metrics = trajectory.metrics()

    trajectories = trajectory.split(n_parts)

    hist = _get_vibrational_amplitudes_hist(trajectories=trajectories, bins=bins)

    # Error bars are skipped for a single part.
    if n_parts == 1:
        fig = px.bar(hist.dataframe, x='center', y='count')
    else:
        fig = px.bar(hist.dataframe, x='center', y='count', error_y='std')

    # Zero-mean Gaussian with the vibration amplitude as its scale, evaluated
    # over the histogram's amplitude range.
    x = np.linspace(hist.min_amp, hist.max_amp, 100) + hist.offset
    y_gauss = stats.norm.pdf(x, 0, metrics.vibration_amplitude())
    fig.add_trace(go.Scatter(x=x, y=y_gauss, name='Fitted Gaussian'))

    fig.update_layout(
        title='Histogram of vibrational amplitudes with fitted Gaussian',
        xaxis_title='Amplitude (Å)',
        yaxis_title='Occurrence (a.u.)',
    )

    return fig