Skip to content

Plots

Gemdat contains several built-in plots for visualizing trajectories, jumps, transitions, and radial distribution functions.

These are collected in the plots module. The intended usage is that you import gemdat.plots like this:

from gemdat import plots

# `trajectory` is a gemdat.Trajectory, `jumps` a gemdat.Jumps and `rdfs` an
# iterable of gemdat.rdf.RDFData. Most plot arguments are keyword-only.
plots.displacement_per_element(trajectory=trajectory)
plots.jumps_vs_distance(jumps=jumps)
plots.radial_distribution(rdfs)

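All plotting functions take a gemdat.Trajectory, gemdat.Jumps, gemdat.Transitions, one or more gemdat.rdf.RDFData objects, or a combination of these as input. Most arguments are keyword-only, so pass them by name (e.g. trajectory=trajectory). Some plots also accept a few extra parameters to tune the output.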

Trajectory and displacements plots

This module contains all the plots that Gemdat can generate.

displacement_per_atom(*, trajectory)

Plot displacement per atom.

Parameters:

  • trajectory (Trajectory) –

    Input trajectory, i.e. for the diffusing atom

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_displacement_per_atom.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def displacement_per_atom(*, trajectory: Trajectory) -> go.Figure:
    """Plot displacement per atom.

    One line is drawn per entry yielded by
    ``trajectory.distances_from_base_position()``, plotted against the
    time step.

    Parameters
    ----------
    trajectory : Trajectory
        Input trajectory, i.e. for the diffusing atom

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    fig = go.Figure()

    # One thin trace per atom; the legend is hidden because it would list
    # every single atom.
    for i, distance in enumerate(trajectory.distances_from_base_position()):
        fig.add_trace(
            go.Scatter(y=distance, name=str(i), mode='lines', line={'width': 1}, showlegend=False)
        )

    fig.update_layout(
        title='Displacement per atom',
        xaxis_title='Time step',
        yaxis_title='Displacement (Å)',
    )

    return fig

displacement_per_element(*, trajectory)

Plot displacement per element.

Parameters:

  • trajectory (Trajectory) –

    Input trajectory

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_displacement_per_element.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def displacement_per_element(*, trajectory: Trajectory) -> go.Figure:
    """Plot displacement per element.

    For every element the mean displacement is drawn as a thick line, with a
    shaded band spanning ``mean - std`` to ``mean + std``. The three traces of
    an element share a legend group; only the mean line is shown in the legend.

    Parameters
    ----------
    trajectory : Trajectory
        Input trajectory

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    displacements = _mean_displacements_per_element(trajectory)

    fig = go.Figure()

    for symbol, (mean, std) in displacements.items():
        # Mean line: the only trace of this group that appears in the legend.
        fig.add_trace(
            go.Scatter(
                y=mean,
                name=symbol + ' + std',
                mode='lines',
                line={'width': 3},
                legendgroup=symbol,
            )
        )
        # Invisible upper bound of the std band.
        fig.add_trace(
            go.Scatter(
                y=mean + std,
                name=symbol + ' + std',
                mode='lines',
                line={'width': 0},
                legendgroup=symbol,
                showlegend=False,
            )
        )
        # Invisible lower bound; ``fill='tonexty'`` shades the area between this
        # trace and the upper bound added just before it.
        fig.add_trace(
            go.Scatter(
                y=mean - std,
                name=symbol + ' - std',
                mode='lines',
                line={'width': 0},
                legendgroup=symbol,
                showlegend=False,
                fill='tonexty',
            )
        )

    fig.update_layout(
        title='Displacement per element',
        xaxis_title='Time step',
        yaxis_title='Displacement (Å)',
    )

    return fig

displacement_histogram(trajectory, n_parts=1)

Plot histogram of total displacement at final timestep.

Parameters:

  • trajectory (Trajectory) –

    Input trajectory, i.e. for the diffusing atom

  • n_parts (int, default: 1 ) –

    Plot error bars by dividing data into n parts

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_displacement_histogram.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
def displacement_histogram(trajectory: Trajectory, n_parts: int = 1) -> go.Figure:
    """Plot histogram of total displacement at final timestep.

    With ``n_parts == 1`` a stacked bar chart of atom counts per displacement
    and element is drawn. Otherwise the trajectory is split into ``n_parts``
    parts and, for every (displacement, element) pair, the mean count over the
    parts is plotted with its standard deviation as error bar.

    Parameters
    ----------
    trajectory : Trajectory
        Input trajectory, i.e. for the diffusing atom
    n_parts : int
        Plot error bars by dividing data into n parts

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    if n_parts == 1:
        df = _trajectory_to_dataframe(trajectory)

        fig = px.bar(df, x='Displacement', y='count', color='Element', barmode='stack')

        fig.update_layout(
            title='Displacement per element',
            xaxis_title='Displacement (Å)',
            yaxis_title='Nr. of atoms',
        )
    else:
        # Part boundaries; only the spacing is used, to report the part length
        # in the title.
        interval = np.linspace(0, len(trajectory) - 1, n_parts + 1)
        dfs = [_trajectory_to_dataframe(part) for part in trajectory.split(n_parts)]

        all_df = pd.concat(dfs)

        # Mean and standard deviation of the counts over the parts, per
        # (displacement, element) pair.
        grouped = all_df.groupby(['Displacement', 'Element'])
        mean = grouped.mean().reset_index().rename(columns={'count': 'mean'})
        std = grouped.std().reset_index().rename(columns={'count': 'std'})
        df = mean.merge(std, how='inner')

        fig = px.bar(
            df,
            x='Displacement',
            y='mean',
            color='Element',
            error_y='std',
            barmode='group',
        )

        fig.update_layout(
            title=(
                f'Displacement per element after {int(interval[1] - interval[0])} timesteps'
            ),
            xaxis_title='Displacement (Å)',
            yaxis_title='Nr. of atoms',
        )

    return fig

Simulation metrics plots

This module contains all the plots that Gemdat can generate.

frequency_vs_occurence(*, trajectory)

Plot attempt frequency vs. occurrence.

Parameters:

  • trajectory (Trajectory) –

    Input trajectory, i.e. for the diffusing atom

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_frequency_vs_occurence.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
def frequency_vs_occurence(*, trajectory: Trajectory) -> go.Figure:
    """Plot attempt frequency vs occurence.

    The speed from ``trajectory.metrics()`` is Fourier transformed along its
    last axis. The one-sided amplitude spectra are then summed over the first
    axis and smoothed with a moving average. The attempt frequency (and its
    standard deviation, when non-zero) is marked with vertical red lines.

    Parameters
    ----------
    trajectory : Trajectory
        Input trajectory, i.e. for the diffusing atom

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    # Number of points in the moving-average window used to smooth the spectrum.
    window = 51

    metrics = trajectory.metrics()
    speed = metrics.speed()

    length = speed.shape[1]
    half_length = length // 2 + 1

    trans = np.fft.fft(speed)

    # Normalised amplitude spectrum; keep only the non-negative frequencies.
    two_sided = np.abs(trans / length)
    one_sided = two_sided[:, :half_length]

    fig = go.Figure()

    f = trajectory.sampling_frequency * np.arange(half_length) / length

    sum_freqs = np.sum(one_sided, axis=0)
    smoothed = np.convolve(sum_freqs, np.ones(window), 'same') / window
    fig.add_trace(
        go.Scatter(
            y=smoothed,
            x=f,
            mode='lines',
            line={'width': 3, 'color': 'blue'},
            showlegend=False,
        )
    )

    # The y range is taken from the unsmoothed summed spectrum.
    y_max = np.max(sum_freqs)

    attempt_freq, attempt_freq_std = (float(i) for i in metrics.attempt_frequency())

    if attempt_freq:
        fig.add_vline(x=attempt_freq, line={'width': 2, 'color': 'red'})
    if attempt_freq and attempt_freq_std:
        fig.add_vline(
            x=attempt_freq + attempt_freq_std,
            line={'width': 2, 'color': 'red', 'dash': 'dash'},
        )
        fig.add_vline(
            x=attempt_freq - attempt_freq_std,
            line={'width': 2, 'color': 'red', 'dash': 'dash'},
        )

    fig.update_layout(
        title='Frequency vs Occurence',
        xaxis_title='Frequency (Hz)',
        yaxis_title='Occurrence (a.u.)',
        xaxis_range=[-0.1e13, 2.5e13],
        yaxis_range=[0, y_max],
        width=600,
        height=500,
    )

    return fig

vibrational_amplitudes(*, trajectory, bins=50, n_parts=1)

Plot histogram of vibrational amplitudes with fitted Gaussian.

Parameters:

  • trajectory (Trajectory) –

    Input trajectory, i.e. for the diffusing atom

  • bins (int, default: 50 ) –

    Number of bins of the amplitude histogram

  • n_parts (int, default: 1 ) –

    Number of parts for error analysis

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_vibrational_amplitudes.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def vibrational_amplitudes(
    *, trajectory: Trajectory, bins: int = 50, n_parts: int = 1
) -> go.Figure:
    """Plot histogram of vibrational amplitudes with fitted Gaussian.

    Parameters
    ----------
    trajectory : Trajectory
        Input trajectory, i.e. for the diffusing atom
    bins : int, optional
        Number of bins of the amplitude histogram
    n_parts : int
        Number of parts for error analysis

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    metrics = trajectory.metrics()

    trajectories = trajectory.split(n_parts)

    hist = _get_vibrational_amplitudes_hist(trajectories=trajectories, bins=bins)

    # Error bars only make sense when there is more than one part.
    if n_parts == 1:
        fig = px.bar(hist.dataframe, x='center', y='count')
    else:
        fig = px.bar(hist.dataframe, x='center', y='count', error_y='std')

    # Zero-mean normal distribution with the vibration amplitude as its
    # standard deviation, evaluated over the histogram's amplitude range.
    x = np.linspace(hist.min_amp, hist.max_amp, 100) + hist.offset
    y_gauss = stats.norm.pdf(x, 0, metrics.vibration_amplitude())
    fig.add_trace(go.Scatter(x=x, y=y_gauss, name='Fitted Gaussian'))

    fig.update_layout(
        title='Histogram of vibrational amplitudes with fitted Gaussian',
        xaxis_title='Amplitude (Å)',
        yaxis_title='Occurrence (a.u.)',
    )

    return fig

Jumps and transition plots

This module contains all the plots that Gemdat can generate.

jumps_vs_distance(*, jumps, jump_res=0.1, n_parts=1)

Plot jumps vs distance histogram.

Parameters:

  • jumps (Jumps) –

    Input jumps data

  • jump_res (float, default: 0.1 ) –

    Resolution of the bins in Angstrom

  • n_parts (int, default: 1 ) –

    Number of parts for error analysis

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_jumps_vs_distance.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def jumps_vs_distance(
    *,
    jumps: Jumps,
    jump_res: float = 0.1,
    n_parts: int = 1,
) -> go.Figure:
    """Plot jumps vs distance histogram.

    Parameters
    ----------
    jumps : Jumps
        Input jumps data
    jump_res : float, optional
        Resolution of the bins in Angstrom
    n_parts : int
        Number of parts for error analysis

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    df = _jumps_vs_distance(jumps=jumps, resolution=jump_res, n_parts=n_parts)

    # The 'std' column is only meaningful (and only plotted) with several parts.
    if n_parts == 1:
        fig = px.bar(df, x='Displacement', y='mean', barmode='stack')
    else:
        fig = px.bar(df, x='Displacement', y='mean', error_y='std', barmode='stack')

    fig.update_layout(
        title='Jumps vs. Distance',
        xaxis_title='Distance (Å)',
        yaxis_title='Number of jumps',
    )

    return fig

jumps_vs_time(*, jumps, bins=8, n_parts=1)

Plot jumps vs time histogram.

Parameters:

  • jumps (Jumps) –

    Input jumps data

  • bins (int, default: 8 ) –

    Number of bins

  • n_parts (int, default: 1 ) –

    Number of parts for error analysis

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_jumps_vs_time.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def jumps_vs_time(*, jumps: Jumps, bins: int = 8, n_parts: int = 1) -> go.Figure:
    """Plot jumps vs time histogram.

    The jumps are split into ``n_parts`` parts with ``jumps.split``. For each
    part, the jump start times are binned into ``bins`` equal-width bins over
    ``[0, len(jumps.trajectory) / n_parts]``. Each bar is the mean count over
    the parts, with the standard deviation as error bar when ``n_parts > 1``.

    Parameters
    ----------
    jumps : Jumps
        Input jumps data
    bins : int, optional
        Number of bins
    n_parts : int
        Number of parts for error analysis

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    maxlen = len(jumps.trajectory) / n_parts
    # Width of one histogram bin. This must match the equal-width bins that
    # np.histogram builds over ``range=(0.0, maxlen)`` so that the bars sit on
    # the true bin centres; the gap between bars is handled by ``bargap``.
    binsize = maxlen / bins

    data = [
        np.histogram(jumps_part.data['start time'], bins=bins, range=(0.0, maxlen))[0]
        for jumps_part in jumps.split(n_parts=n_parts)
    ]

    # Rows are parts, columns are bins.
    df = pd.DataFrame(data=data)
    centers = [binsize / 2 + binsize * col for col in range(bins)]

    mean = [df[col].mean() for col in df.columns]
    std = [df[col].std() for col in df.columns]

    df = pd.DataFrame(data=zip(centers, mean, std), columns=['time', 'count', 'std'])

    # The std is undefined (NaN) for a single part, so only draw error bars
    # when there are several parts.
    if n_parts > 1:
        fig = px.bar(df, x='time', y='count', error_y='std')
    else:
        fig = px.bar(df, x='time', y='count')

    fig.update_layout(
        bargap=0.2,
        title='Jumps vs. time',
        xaxis_title='Time (steps)',
        yaxis_title='Number of jumps',
    )

    return fig

collective_jumps(*, jumps)

Plot collective jumps per jump-type combination.

Parameters:

  • jumps (Jumps) –

    Input data

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_collective_jumps.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def collective_jumps(*, jumps: Jumps) -> go.Figure:
    """Plot collective jumps per jump-type combination.

    The site-pair count matrix of ``jumps.collective()`` is rendered with
    ``px.imshow``. Both axes are labelled with the matching site-pair labels.

    Parameters
    ----------
    jumps : Jumps
        Input data

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    coll = jumps.collective()

    fig = px.imshow(coll.site_pair_count_matrix())

    tick_text = coll.site_pair_count_matrix_labels()
    tick_positions = list(range(len(tick_text)))

    def _axis_spec():
        # Explicit ticks so every matrix row/column shows its label.
        return {'tickmode': 'array', 'tickvals': tick_positions, 'ticktext': tick_text}

    fig.update_layout(
        xaxis=_axis_spec(),
        yaxis=_axis_spec(),
        title='Cooperative jumps per jump-type combination',
    )

    return fig

jumps_3d(*, jumps)

Plot jumps in 3D.

Parameters:

  • jumps (Jumps) –

    Input data

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_jumps_3d.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def jumps_3d(*, jumps: Jumps) -> go.Figure:
    """Plot jumps in 3D.

    Delegates to ``plot_3d``, passing ``jumps.sites`` as the structure.

    Parameters
    ----------
    jumps : Jumps
        Input data

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    # Imported inside the function rather than at module level.
    from ._plot3d import plot_3d

    site_structure = jumps.sites
    figure = plot_3d(jumps=jumps, structure=site_structure)
    return figure

jumps_3d_animation(*, jumps, t_start, t_stop, decay=0.05, skip=5, interval=20)

Plot jumps in 3D as an animation over time.

Parameters:

  • jumps (Jumps) –

    Input data

  • t_start (int) –

    Time step to start animation (relative to equilibration time)

  • t_stop (int) –

    Time step to stop animation (relative to equilibration time)

  • decay (float, default: 0.05 ) –

    Controls the decay of the line width (higher = faster decay)

  • skip (int, default: 5 ) –

    Skip frames (increase for faster, but less accurate rendering)

  • interval (int, default: 20 ) –

    Delay between frames in milliseconds.

Returns:

  • anim ( FuncAnimation ) –

    Output animation

Source code in src/gemdat/plots/matplotlib/_jumps_3d_animation.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
def jumps_3d_animation(
    *,
    jumps: Jumps,
    t_start: int,
    t_stop: int,
    decay: float = 0.05,
    skip: int = 5,
    interval: int = 20,
) -> animation.FuncAnimation:
    """Plot jumps in 3D as an animation over time.

    Every jump is first drawn once, with zero line width, as a path between
    its start and destination site. Each animation frame then updates all
    jumps that have started by that frame:
    - their line width decays linearly (by ``decay`` per time step) from
      ``maxwidth`` down to ``minwidth``;
    - their start and destination sites are colored by atom index.

    Parameters
    ----------
    jumps : Jumps
        Input data
    t_start : int
        Time step to start animation (relative to equilibration time)
    t_stop : int
        Time step to stop animation (relative to equilibration time)
    decay : float, optional
        Controls the decay of the line width (higher = faster decay)
    skip : int, optional
        Skip frames (increase for faster, but less accurate rendering)
    interval : int, optional
        Delay between frames in milliseconds.

    Returns
    -------
    anim : matplotlib.animation.FuncAnimation
        Output animation
    """
    minwidth = 0.2
    maxwidth = 5.0

    trajectory = jumps.trajectory

    class LabelItems:
        """Container exposing ``items()`` as (label, coord) pairs for ``plot_labels``."""

        def __init__(self, labels, coords):
            self.labels = labels
            self.coords = coords

        def items(self):
            yield from zip(self.labels, self.coords)

    coords = jumps.sites.frac_coords
    lattice = trajectory.get_lattice()

    # Palettes for the start and destination sites, indexed by atom index.
    color_from = colormaps['Set1'].colors  # type: ignore
    color_to = colormaps['Pastel1'].colors  # type: ignore

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')

    xyz_labels = LabelItems(
        'OABC',
        [
            [-0.1, -0.1, -0.1],
            [1.1, -0.1, -0.1],
            [-0.1, 1.1, -0.1],
            [-0.1, -0.1, 1.1],
        ],
    )

    plotter.plot_lattice_vectors(lattice, ax=ax, linewidth=1)

    plotter.plot_labels(xyz_labels, lattice=lattice, ax=ax, color='green', size=12)

    # ``points[k]`` is used as the artist for site ``k`` below, so no other
    # collection may exist before the sites are plotted.
    assert len(ax.collections) == 0
    plotter.plot_points(coords, lattice=lattice, ax=ax, s=50, color='white', edgecolor='black')
    points = ax.collections

    events = jumps.data.sort_values('start time', ignore_index=True)

    # Lines already on the axes (lattice vectors etc.) are not jump paths;
    # remember how many there are so ``lines[i]`` below is the path of event
    # ``i`` (assumes ``plot_path`` adds exactly one line per call).
    n_static_lines = len(ax.lines)

    for _, event in events.iterrows():
        site_i = event['start site']
        site_j = event['destination site']

        coord_i = coords[site_i]
        coord_j = coords[site_j]

        # Use the periodic image of the destination closest to the start site.
        _, image = lattice.get_distance_and_image(coord_i, coord_j)

        line = [coord_i, coord_j + image]

        # Drawn with zero width; ``update`` makes it visible once the jump starts.
        plotter.plot_path(line, lattice=lattice, ax=ax, color='red', linewidth=0)

    lines = ax.lines[n_static_lines:]

    ax.set(
        title='Jumps between sites',
        xlabel="x' (Å)",
        ylabel="y' (Å)",
        zlabel="z' (Å)",
    )

    # NOTE(review): older matplotlib versions only accepted 'auto' for 3D axes.
    ax.set_aspect('equal')

    def update(frame_no):
        t_frame = t_start + (frame_no * skip)
        # Start time of the last event inspected (the first one after
        # ``t_frame`` if any); stays None when there are no jumps at all.
        start_time = None

        # Events are sorted by start time, so stop at the first future jump.
        for i, event in events.iterrows():
            start_time = event['start time']
            if start_time > t_frame:
                break

            lw = max(maxwidth - decay * (t_frame - start_time), minwidth)

            line = lines[i]
            line.set_color('red')
            line.set_linewidth(lw)

            points[event['start site']].set_facecolor(
                color_from[event['atom index'] % len(color_from)]
            )
            points[event['destination site']].set_facecolor(
                color_to[event['atom index'] % len(color_to)]
            )

        ax.set_title(f'T: {t_frame} | Next jump: {start_time}')

    n_frames = int((t_stop - t_start) / skip)

    return animation.FuncAnimation(
        fig=fig, func=update, frames=n_frames, interval=interval, repeat=False
    )

Radial distribution plots

This module contains all the plots that Gemdat can generate.

radial_distribution(rdfs)

Plot radial distribution function.

Parameters:

  • rdfs (Iterable[RDFData]) –

    List of RDF data to plot

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_radial_distribution.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def radial_distribution(rdfs: Iterable[RDFData]) -> go.Figure:
    """Plot radial distribution function.

    One line is drawn per RDF. Distinct non-empty ``state`` values are
    appended to the title in sorted order.

    Parameters
    ----------
    rdfs : Iterable[RDFData]
        List of RDF data to plot

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    # ``rdfs`` is traversed twice (traces, then states). Materialise it once so
    # a one-shot iterator such as a generator does not lose its states.
    rdfs = list(rdfs)

    fig = go.Figure()

    for rdf in rdfs:
        fig.add_trace(
            go.Scatter(
                x=rdf.x,
                y=rdf.y,
                name=rdf.label,
                mode='lines',
            )
        )

    # Sorted so the title does not depend on set iteration order.
    states = ', '.join(sorted({rdf.state for rdf in rdfs if rdf.state}))
    state_suffix = f' ({states})' if states else ''

    fig.update_layout(
        title=f'Radial distribution function{state_suffix}',
        xaxis_title='Distance (Å)',
        yaxis_title='g(r)',
    )

    return fig