Skip to content

Plots

Gemdat contains several built-in plots for visualizing trajectories, jumps, transitions, and radial distribution functions.

These are collected in the plots module. The intended usage is that you import gemdat.plots like this:

from gemdat import plots

plots.displacement_per_element(trajectory=trajectory)
plots.jumps_vs_distance(jumps=jumps)
plots.radial_distribution(rdfs)

All plotting functions take a gemdat.Trajectory, gemdat.Jumps, gemdat.Transitions, gemdat.rdf.RDFData or a combination as input. In addition, for some plots you have a few parameters to tune the output.

Trajectory and displacements plots

This module contains all the plots that Gemdat can generate.

displacement_per_atom(*, trajectory)

Plot displacement per atom.

Parameters:

  • trajectory (Trajectory) –

    Input trajectory, i.e. for the diffusing atom

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_displacements.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def displacement_per_atom(*, trajectory: Trajectory) -> go.Figure:
    """Plot displacement per atom.

    One thin line is drawn per entry yielded by
    `trajectory.distances_from_base_position()`; the traces are named by
    their running index and hidden from the legend.

    Parameters
    ----------
    trajectory : Trajectory
        Input trajectory, i.e. for the diffusing atom

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    fig = go.Figure()

    # Iterate the distances directly; an intermediate list copy is not needed.
    for i, distance in enumerate(trajectory.distances_from_base_position()):
        fig.add_trace(
            go.Scatter(y=distance,
                       name=i,
                       mode='lines',
                       line={'width': 1},
                       showlegend=False))

    fig.update_layout(title='Displacement per atom',
                      xaxis_title='Time step',
                      yaxis_title='Displacement (Angstrom)')

    return fig

displacement_per_element(*, trajectory)

Plot displacement per element.

Parameters:

  • trajectory (Trajectory) –

    Input trajectory

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_displacements.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def displacement_per_element(*, trajectory: Trajectory) -> go.Figure:
    """Plot displacement per element.

    Distances from the base position are grouped by species symbol. For
    every symbol the mean over the group is drawn as a thick line, and the
    band between mean + std and mean - std is filled.

    Parameters
    ----------
    trajectory : Trajectory
        Input trajectory

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    elements = trajectory.species
    all_distances = trajectory.distances_from_base_position()

    # Collect the distance series of every atom under its element symbol.
    per_element = defaultdict(list)
    for specie, atom_distances in zip(elements, all_distances):
        per_element[specie.symbol].append(atom_distances)

    fig = go.Figure()

    for symbol, rows in per_element.items():
        mean_disp = np.mean(rows, axis=0)
        std_disp = np.std(rows, axis=0)

        # Settings shared by the mean line and both band edges.
        shared = {
            'name': symbol + ' + std',
            'mode': 'lines',
            'legendgroup': symbol,
        }

        # Mean line: the only trace of the group shown in the legend.
        fig.add_trace(go.Scatter(y=mean_disp, line={'width': 3}, **shared))

        # Invisible upper edge of the band.
        fig.add_trace(
            go.Scatter(y=mean_disp + std_disp,
                       line={'width': 0},
                       showlegend=False,
                       **shared))

        # Invisible lower edge, filled up to the previous (upper) trace.
        fig.add_trace(
            go.Scatter(y=mean_disp - std_disp,
                       line={'width': 0},
                       showlegend=False,
                       fill='tonexty',
                       **shared))

    fig.update_layout(title='Displacement per element',
                      xaxis_title='Time step',
                      yaxis_title='Displacement (Angstrom)')

    return fig

displacement_histogram(trajectory, n_parts=1)

Plot histogram of total displacement at final timestep.

Parameters:

  • trajectory (Trajectory) –

    Input trajectory, i.e. for the diffusing atom

  • n_parts (int, default: 1 ) –

    Plot error bars by dividing data into n parts

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_displacements.py
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
def displacement_histogram(trajectory: Trajectory,
                           n_parts: int = 1) -> go.Figure:
    """Plot histogram of total displacement at final timestep.

    With a single part, the counts are shown as stacked bars per element.
    With several parts, the trajectory is split, the per-part histograms
    are combined, and the mean count is drawn with the standard deviation
    as error bar.

    Parameters
    ----------
    trajectory : Trajectory
        Input trajectory, i.e. for the diffusing atom
    n_parts : int
        Plot error bars by dividing data into n parts

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    # Simple case: one histogram for the whole trajectory, no error bars.
    if n_parts == 1:
        frame = _trajectory_to_dataframe(trajectory)

        fig = px.bar(frame,
                     x='Displacement',
                     y='count',
                     color='Element',
                     barmode='stack')

        fig.update_layout(title='Displacement per element',
                          xaxis_title='Displacement (Angstrom)',
                          yaxis_title='Nr. of atoms')
        return fig

    # Boundaries of the parts; only their spacing is used (for the title).
    boundaries = np.linspace(0, len(trajectory) - 1, n_parts + 1)

    part_frames = [
        _trajectory_to_dataframe(part) for part in trajectory.split(n_parts)
    ]
    combined = pd.concat(part_frames)

    # Mean and standard deviation of the counts across the parts
    by_bin = combined.groupby(['Displacement', 'Element'])
    means = by_bin.mean().reset_index().rename(columns={'count': 'mean'})
    stds = by_bin.std().reset_index().rename(columns={'count': 'std'})
    summary = means.merge(stds, how='inner')

    fig = px.bar(summary,
                 x='Displacement',
                 y='mean',
                 color='Element',
                 error_y='std',
                 barmode='group')

    steps_per_part = int(boundaries[1] - boundaries[0])
    fig.update_layout(
        title=f'Displacement per element after {steps_per_part} timesteps',
        xaxis_title='Displacement (Angstrom)',
        yaxis_title='Nr. of atoms')

    return fig

Simulation metrics plots

This module contains all the plots that Gemdat can generate.

frequency_vs_occurence(*, trajectory)

Plot attempt frequency vs occurrence.

Parameters:

  • trajectory (Trajectory) –

    Input trajectory, i.e. for the diffusing atom

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_vibration.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def frequency_vs_occurence(*, trajectory: Trajectory) -> go.Figure:
    """Plot attempt frequency vs occurence.

    The one-sided amplitude spectrum of the speed is summed over the first
    axis and smoothed with a 51-point moving average. If available, the
    attempt frequency is marked with a solid red line and its standard
    deviation with dashed red lines on either side.

    Parameters
    ----------
    trajectory : Trajectory
        Input trajectory, i.e. for the diffusing atom

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    metrics = SimulationMetrics(trajectory)
    speed = metrics.speed()

    n_steps = speed.shape[1]
    n_positive = n_steps // 2 + 1

    # Normalised two-sided spectrum, truncated to the non-negative half.
    spectrum = np.abs(np.fft.fft(speed) / n_steps)[:, :n_positive]

    frequencies = (trajectory.sampling_frequency * np.arange(n_positive) /
                   n_steps)

    summed = np.sum(spectrum, axis=0)

    # Moving average to smooth the summed spectrum.
    window = 51
    smoothed = np.convolve(summed, np.ones(window), 'same') / window

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(x=frequencies,
                   y=smoothed,
                   mode='lines',
                   line={
                       'width': 3,
                       'color': 'blue'
                   },
                   showlegend=False))

    # The y-range is taken from the unsmoothed data.
    y_max = np.max(summed)

    attempt_freq, attempt_freq_std = metrics.attempt_frequency()

    if attempt_freq:
        fig.add_vline(x=attempt_freq, line={'width': 2, 'color': 'red'})

        if attempt_freq_std:
            # One dashed line on each side of the attempt frequency.
            for bound in (attempt_freq + attempt_freq_std,
                          attempt_freq - attempt_freq_std):
                fig.add_vline(x=bound,
                              line={
                                  'width': 2,
                                  'color': 'red',
                                  'dash': 'dash'
                              })

    fig.update_layout(title='Frequency vs Occurence',
                      xaxis_title='Frequency (Hz)',
                      yaxis_title='Occurrence (a.u.)',
                      xaxis_range=[-0.1e13, 2.5e13],
                      yaxis_range=[0, y_max],
                      width=600,
                      height=500)

    return fig

vibrational_amplitudes(*, trajectory, bins=50, n_parts=1)

Plot histogram of vibrational amplitudes with fitted Gaussian.

Parameters:

  • trajectory (Trajectory) –

    Input trajectory, i.e. for the diffusing atom

  • bins (int, default: 50 ) –

    Number of bins in the histogram

  • n_parts (int, default: 1 ) –

    Number of parts for error analysis

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_vibration.py
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
def vibrational_amplitudes(*,
                           trajectory: Trajectory,
                           bins: int = 50,
                           n_parts: int = 1) -> go.Figure:
    """Plot histogram of vibrational amplitudes with fitted Gaussian.

    The trajectory is split into `n_parts` parts and a density histogram
    of the amplitudes is computed for each part over a common range that
    is symmetric around zero. The bars show the mean over the parts, and
    for more than one part the standard deviation is added as error bar.
    The overlaid curve is a zero-centred normal distribution whose width
    is the vibration amplitude of the full trajectory.

    Parameters
    ----------
    trajectory : Trajectory
        Input trajectory, i.e. for the diffusing atom
    bins : int
        Number of bins in the histogram
    n_parts : int
        Number of parts for error analysis

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    parts = trajectory.split(n_parts)

    # The Gaussian is derived from the complete trajectory, not the parts.
    single_metrics = SimulationMetrics(trajectory)

    amplitudes = [SimulationMetrics(part).amplitudes() for part in parts]

    # Symmetric range around zero wide enough for the largest excursion.
    max_amp = max(np.max(np.abs(part_amps)) for part_amps in amplitudes)
    min_amp = -max_amp

    histograms = [
        np.histogram(part_amps,
                     bins=bins,
                     range=(min_amp, max_amp),
                     density=True)[0] for part_amps in amplitudes
    ]

    # Rows are parts, columns are bins.
    hist_df = pd.DataFrame(data=histograms)

    # offset to middle of bar
    offset = (max_amp - min_amp) / (bins * 2)

    centers = np.linspace(min_amp + offset,
                          max_amp + offset,
                          bins,
                          endpoint=False)

    df = pd.DataFrame({
        'amplitude': centers,
        'count': hist_df.mean(axis=0).to_numpy(),
        # NaN for a single part; only used when n_parts != 1.
        'std': hist_df.std(axis=0).to_numpy(),
    })

    fig = px.bar(df,
                 x='amplitude',
                 y='count',
                 error_y=None if n_parts == 1 else 'std')

    x = np.linspace(min_amp, max_amp, 100)
    y_gauss = stats.norm.pdf(x, 0, single_metrics.vibration_amplitude())
    fig.add_trace(go.Scatter(x=x, y=y_gauss, name='Fitted Gaussian'))

    fig.update_layout(
        title='Histogram of vibrational amplitudes with fitted Gaussian',
        xaxis_title='Amplitude (Ångstrom)',
        yaxis_title='Occurrence (a.u.)')

    return fig

Jumps and transition plots

This module contains all the plots that Gemdat can generate.

jumps_vs_distance(*, jumps, jump_res=0.1, n_parts=1)

Plot jumps vs. distance histogram.

Parameters:

  • jumps (Jumps) –

    Input jumps data

  • jump_res (float, default: 0.1 ) –

    Resolution of the bins in Angstrom

  • n_parts (int, default: 1 ) –

    Number of parts for error analysis

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_jumps.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def jumps_vs_distance(*,
                      jumps: Jumps,
                      jump_res: float = 0.1,
                      n_parts: int = 1) -> go.Figure:
    """Plot jumps vs. distance histogram.

    The pairwise distances between all sites are binned with resolution
    `jump_res`, and for every part of the jumps data the entries of the
    part's matrix are accumulated per distance bin. The bars show the mean
    count over the parts; for more than one part the standard deviation is
    shown as error bar. Bins without any jumps in any part are left out.

    Parameters
    ----------
    jumps : Jumps
        Input jumps data
    jump_res : float, optional
        Resolution of the bins in Angstrom
    n_parts : int
        Number of parts for error analysis

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    sites = jumps.sites
    trajectory = jumps.trajectory
    lattice = trajectory.get_lattice()

    pdist = lattice.get_all_distances(sites.frac_coords, sites.frac_coords)

    bin_max = (1 + pdist.max() // jump_res) * jump_res
    n_bins = int(bin_max / jump_res) + 1
    x = np.linspace(0, bin_max, n_bins)

    # Bin index of every site pair, flattened to match the flattened matrix.
    bin_idx = np.digitize(pdist, bins=x).ravel()

    # One row of per-bin totals for each part (shape: n_parts x n_bins).
    part_counts = np.array([
        np.bincount(bin_idx,
                    weights=np.asarray(part.matrix()).ravel(),
                    minlength=n_bins)
        for part in jumps.split(n_parts=n_parts)
    ])

    # Keep a bin if any part has jumps in it. Parts with zero jumps in such
    # a bin still enter the statistics; dropping them would bias the mean
    # upwards and underestimate the spread.
    occupied = (part_counts > 0).any(axis=0)
    counts_df = pd.DataFrame(part_counts[:, occupied])

    df = pd.DataFrame({
        'Displacement': x[occupied],
        'mean': counts_df.mean(axis=0).to_numpy(),
        # Sample standard deviation; NaN (and unused) for a single part.
        'std': counts_df.std(axis=0).to_numpy(),
    })

    if n_parts == 1:
        fig = px.bar(df, x='Displacement', y='mean', barmode='stack')
    else:
        fig = px.bar(df,
                     x='Displacement',
                     y='mean',
                     error_y='std',
                     barmode='stack')

    fig.update_layout(title='Jumps vs. Distance',
                      xaxis_title='Distance (Angstrom)',
                      yaxis_title='Number of jumps')

    return fig

jumps_vs_time(*, jumps, bins=8, n_parts=1)

Plot jumps vs. time histogram.

Parameters:

  • jumps (Jumps) –

    Input jumps data

  • bins (int, default: 8 ) –

    Number of bins

  • n_parts (int, default: 1 ) –

    Number of parts for error analysis

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_jumps.py
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def jumps_vs_time(*,
                  jumps: Jumps,
                  bins: int = 8,
                  n_parts: int = 1) -> go.Figure:
    """Plot jumps vs. time histogram.

    The jumps data is split into `n_parts` parts and the start times of
    each part are binned into `bins` equal bins covering the length of one
    part. The bars show the mean count over the parts; for more than one
    part the standard deviation is shown as error bar.

    Parameters
    ----------
    jumps : Jumps
        Input jumps data
    bins : int, optional
        Number of bins
    n_parts : int
        Number of parts for error analysis

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    maxlen = len(jumps.trajectory) / n_parts

    # Width of one histogram bin; must match the binning done below so
    # that the bars are drawn at the true bin centres.
    bin_width = maxlen / bins

    histograms = [
        np.histogram(jumps_part.data['start time'],
                     bins=bins,
                     range=(0., maxlen))[0]
        for jumps_part in jumps.split(n_parts=n_parts)
    ]

    # Rows are parts, columns are bins.
    hist_df = pd.DataFrame(data=histograms)

    centers = [(col + 0.5) * bin_width for col in range(bins)]

    df = pd.DataFrame({
        'time': centers,
        'count': hist_df.mean(axis=0).to_numpy(),
        # NaN for a single part; only used when n_parts > 1.
        'std': hist_df.std(axis=0).to_numpy(),
    })

    fig = px.bar(df,
                 x='time',
                 y='count',
                 error_y='std' if n_parts > 1 else None)

    fig.update_layout(bargap=0.2,
                      title='Jumps vs. time',
                      xaxis_title='Time (steps)',
                      yaxis_title='Number of jumps')

    return fig

collective_jumps(*, jumps)

Plot collective jumps per jump-type combination.

Parameters:

  • jumps (Jumps) –

    Input data

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_jumps.py
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
def collective_jumps(*, jumps: Jumps) -> go.Figure:
    """Plot collective jumps per jump-type combination.

    The site-pair count matrix of the collective jumps is shown as an
    image, with the matching labels used as tick text on both axes.

    Parameters
    ----------
    jumps : Jumps
        Input data

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    # Evaluate the collective jumps once and reuse the result for both the
    # matrix and its labels.
    collective = jumps.collective()
    matrix = collective.site_pair_count_matrix()
    labels = collective.site_pair_count_matrix_labels()

    fig = px.imshow(matrix)

    # Both axes carry the same categorical ticks.
    axis_layout = {
        'tickmode': 'array',
        'tickvals': list(range(len(labels))),
        'ticktext': labels,
    }

    fig.update_layout(xaxis=axis_layout,
                      yaxis=dict(axis_layout),
                      title='Cooperative jumps per jump-type combination')

    return fig

jumps_3d(*, jumps)

Plot jumps in 3D.

Parameters:

  • jumps (Jumps) –

    Input data

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_jumps.py
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
def jumps_3d(*, jumps: Jumps) -> go.Figure:
    """Plot jumps in 3D.

    Delegates to `plot_3d`, passing the jump sites as the structure.

    Parameters
    ----------
    jumps : Jumps
        Input data

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    # Imported locally, as in the rest of this module's 3D helpers.
    from ._plot3d import plot_3d

    site_structure = jumps.sites
    fig = plot_3d(jumps=jumps, structure=site_structure)
    return fig

jumps_3d_animation(*, jumps, t_start, t_stop, decay=0.05, skip=5, interval=20)

Plot jumps in 3D as an animation over time.

Parameters:

  • jumps (Jumps) –

    Input data

  • t_start (int) –

    Time step to start animation (relative to equilibration time)

  • t_stop (int) –

    Time step to stop animation (relative to equilibration time)

  • decay (float, default: 0.05 ) –

    Controls the decay of the line width (higher = faster decay)

  • skip (int, default: 5 ) –

    Skip frames (increase for faster, but less accurate rendering)

  • interval (int, default: 20 ) –

    Delay between frames in milliseconds.

Returns:

  • anim ( FuncAnimation ) –

    Output animation

Source code in src/gemdat/plots/matplotlib/_jumps.py
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
def jumps_3d_animation(
    *,
    jumps: Jumps,
    t_start: int,
    t_stop: int,
    decay: float = 0.05,
    skip: int = 5,
    interval: int = 20,
) -> animation.FuncAnimation:
    """Plot jumps in 3D as an animation over time.

    All jump paths are drawn up front with zero line width. Each frame
    then widens the lines of the jumps that have started, letting the
    width decay with the time elapsed since the jump, and colours the
    start and destination sites by atom index.

    Parameters
    ----------
    jumps : Jumps
        Input data
    t_start : int
        Time step to start animation (relative to equilibration time)
    t_stop : int
        Time step to stop animation (relative to equilibration time)
    decay : float, optional
        Controls the decay of the line width (higher = faster decay)
    skip : int, optional
        Skip frames (increase for faster, but less accurate rendering)
    interval : int, optional
        Delay between frames in milliseconds.

    Returns
    -------
    anim : matplotlib.animation.FuncAnimation
        Output animation
    """
    minwidth = 0.2
    maxwidth = 5.0

    trajectory = jumps.trajectory

    class LabelItems:
        """Minimal mapping-like wrapper exposing `items()` for the plotter."""

        def __init__(self, labels, coords):
            self.labels = labels
            self.coords = coords

        def items(self):
            yield from zip(self.labels, self.coords)

    coords = jumps.sites.frac_coords
    lattice = trajectory.get_lattice()

    color_from = colormaps['Set1'].colors  # type: ignore
    color_to = colormaps['Pastel1'].colors  # type: ignore

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')

    # Origin and axis labels, placed slightly outside the unit cell.
    xyz_labels = LabelItems('OABC', [[-0.1, -0.1, -0.1], [1.1, -0.1, -0.1],
                                     [-0.1, 1.1, -0.1], [-0.1, -0.1, 1.1]])

    plotter.plot_lattice_vectors(lattice, ax=ax, linewidth=1)

    plotter.plot_labels(xyz_labels,
                        lattice=lattice,
                        ax=ax,
                        color='green',
                        size=12)

    # The site markers must be the only collections on the axes, because
    # `update` indexes `points` by site number.
    assert len(ax.collections) == 0
    plotter.plot_points(coords,
                        lattice=lattice,
                        ax=ax,
                        s=50,
                        color='white',
                        edgecolor='black')
    points = ax.collections

    events = jumps.data.sort_values('start time', ignore_index=True)

    # Lines already on the axes (the lattice vectors) are not jump paths;
    # remember how many there are instead of hard-coding the count.
    n_static_lines = len(ax.lines)

    for _, event in events.iterrows():
        site_i = event['start site']
        site_j = event['destination site']

        coord_i = coords[site_i]
        coord_j = coords[site_j]

        # Invisible until the animation reaches the jump's start time.
        lw = 0

        # Draw towards the nearest periodic image of the destination.
        _, image = lattice.get_distance_and_image(coord_i, coord_j)

        line = [coord_i, coord_j + image]

        plotter.plot_path(line,
                          lattice=lattice,
                          ax=ax,
                          color='red',
                          linewidth=lw)

    # One line per event, in the same (time-sorted) order as `events`.
    lines = ax.lines[n_static_lines:]

    ax.set(
        title='Jumps between sites',
        xlabel="x' (ang)",
        ylabel="y' (ang)",
        zlabel="z' (ang)",
    )

    # NOTE(review): the original comment here read "only auto is supported";
    # confirm that the minimum supported matplotlib accepts 'equal' on 3D
    # axes.
    ax.set_aspect('equal')

    def update(frame_no):
        t_frame = t_start + (frame_no * skip)

        # Shown when no jump starts after this frame (including the case
        # that there are no jumps at all).
        next_jump = '-'

        # `events` has a fresh 0..n-1 index, so `i` also indexes `lines`.
        for i, event in events.iterrows():

            if event['start time'] > t_frame:
                # Events are sorted by start time: this is the next jump.
                next_jump = event['start time']
                break

            lw = max(maxwidth - decay * (t_frame - event['start time']),
                     minwidth)

            line = lines[i]
            line.set_color('red')
            line.set_linewidth(lw)

            points[event['start site']].set_facecolor(
                color_from[event['atom index'] % len(color_from)])
            points[event['destination site']].set_facecolor(
                color_to[event['atom index'] % len(color_to)])

        ax.set_title(f'T: {t_frame} | Next jump: {next_jump}')

    n_frames = int((t_stop - t_start) / skip)

    return animation.FuncAnimation(fig=fig,
                                   func=update,
                                   frames=n_frames,
                                   interval=interval,
                                   repeat=False)

Radial distribution plots

This module contains all the plots that Gemdat can generate.

radial_distribution(rdfs)

Plot radial distribution function.

Parameters:

  • rdfs (Iterable[RDFData]) –

    List of RDF data to plot

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/matplotlib/_rdf.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def radial_distribution(rdfs: Iterable[RDFData]) -> plt.Figure:
    """Plot radial distribution function.

    Every RDF is drawn as a line labelled with its symbol. The distinct
    states of all RDFs are listed, sorted, in the title.

    Parameters
    ----------
    rdfs : Iterable[RDFData]
        List of RDF data to plot

    Returns
    -------
    fig : matplotlib.figure.Figure
        Output figure
    """
    # The data is traversed twice (lines, then title), so materialise it
    # first; otherwise a generator would be exhausted after the first pass.
    rdfs = list(rdfs)

    fig, ax = plt.subplots()

    for rdf in rdfs:
        ax.plot(rdf.x, rdf.y, label=rdf.symbol)

    # Sorted so the title does not depend on set iteration order.
    states = ', '.join(sorted({rdf.state for rdf in rdfs}))

    ax.legend()
    ax.set(title=f'Radial distribution function ({states})',
           xlabel='Distance (Ang)',
           ylabel='Counts')

    return fig