Skip to content

gemdat.plots

This module contains all the plots that Gemdat can generate.

autocorrelation(*, orientations)

Plot the autocorrelation function of the unit vectors series.

Parameters:

  • orientations (Orientations) –

    The unit vector trajectories

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/matplotlib/_orientations.py (lines 132–164)
def autocorrelation(
    *,
    orientations: Orientations,
) -> plt.Figure:
    """Plot the autocorrelation function of the unit vectors series.

    The mean over axis 0 of ``orientations.autocorrelation()`` is drawn as a
    line, with a shaded band of plus/minus one standard deviation around it.

    Parameters
    ----------
    orientations : Orientations
        The unit vector trajectories

    Returns
    -------
    fig : matplotlib.figure.Figure
        Output figure
    """
    corr = orientations.autocorrelation()
    mean_corr = corr.mean(axis=0)
    spread = corr.std(axis=0)

    # The x axis is expressed in picoseconds: convert the time step (s -> ps)
    step_ps = orientations._time_step * 1e12
    lags_ps = np.arange(mean_corr.shape[0]) * step_ps

    fig, ax = plt.subplots()
    ax.plot(lags_ps, mean_corr, label='FFT-Autocorrelation')
    ax.fill_between(lags_ps, mean_corr - spread, mean_corr + spread, alpha=0.2)
    ax.set_xlabel('Time lag [ps]')
    ax.set_ylabel('Autocorrelation')

    return fig

bond_length_distribution(*, orientations, bins=1000)

Plot the bond length probability distribution.

Parameters:

  • orientations (Orientations) –

    The unit vector trajectories

  • bins (int, default: 1000 ) –

    The number of bins, by default 1000

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/matplotlib/_orientations.py (lines 75–129)
def bond_length_distribution(*,
                             orientations: Orientations,
                             bins: int = 1000) -> plt.Figure:
    """Plot the bond length probability distribution.

    A normalized histogram of all bond lengths is drawn, together with a
    skew-normal probability density fitted (``scipy.optimize.curve_fit``) to
    that histogram.

    Parameters
    ----------
    orientations : Orientations
        The unit vector trajectories
    bins : int, optional
        The number of bins, by default 1000

    Returns
    -------
    fig : matplotlib.figure.Figure
        Output figure
    """
    # The bond length is the last component of the spherical coordinates
    radii = orientations.vectors_spherical.T[-1].flatten()

    fig, ax = plt.subplots()

    # Normalized histogram used as the target of the fit
    density_vals, bin_edges = np.histogram(radii, bins=bins, density=True)
    centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    # Fit skewnorm.pdf(x, a, loc, scale) starting from a=1.5, loc=1, scale=1.5
    fit_params, _ = curve_fit(skewnorm.pdf,
                              centers,
                              density_vals,
                              p0=[1.5, 1, 1.5])

    ax.hist(radii,
            bins=bins,
            density=True,
            color='blue',
            alpha=0.7,
            label='Data')

    # Evaluate the fitted density on a fine grid spanning the bin centers
    x_smooth = np.linspace(min(centers), max(centers), 1000)
    ax.plot(x_smooth,
            skewnorm.pdf(x_smooth, *fit_params),
            'r-',
            label='Skewed Gaussian Fit')

    ax.set_xlabel(r'Bond length $[\AA]$')
    ax.set_ylabel(r'Probability density $[\AA^{-1}]$')
    ax.set_title('Bond Length Probability Distribution')
    ax.legend()
    ax.grid(True)

    return fig

collective_jumps(*, jumps)

Plot collective jumps per jump-type combination.

Parameters:

  • jumps (Jumps) –

    Input data

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_jumps.py (lines 131–164)
def collective_jumps(*, jumps: Jumps) -> go.Figure:
    """Plot collective jumps per jump-type combination.

    The site-pair count matrix of the collective jumps is shown as a heat map,
    with the jump-type labels on both axes.

    Parameters
    ----------
    jumps : Jumps
        Input data

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    # Compute the collective-jump analysis once and derive both the matrix
    # and its labels from the same object.
    collective = jumps.collective()
    matrix = collective.site_pair_count_matrix()
    labels = collective.site_pair_count_matrix_labels()

    fig = px.imshow(matrix)

    # One tick per label, placed at integer positions, on both axes
    ticks = list(range(len(labels)))
    axis_layout = {'tickmode': 'array', 'tickvals': ticks, 'ticktext': labels}

    fig.update_layout(xaxis=dict(axis_layout),
                      yaxis=dict(axis_layout),
                      title='Cooperative jumps per jump-type combination')

    return fig

density(volume, *, structure=None, force_lattice=None)

Create density plot from volume and structure.

Uses plotly as plotting backend.

Parameters:

  • volume (Volume) –

    Input volume

  • structure (Structure | None, default: None ) –

    Input structure

  • force_lattice (Lattice | None, default: None ) –

    Plot volume and structure using this lattice as a basis. Overrides the default, which is to use volume.lattice and structure.lattice where applicable.

Returns:

  • fig ( Figure ) –

    Output as plotly figure

Source code in src/gemdat/plots/plotly/_density.py (lines 12–38)
def density(
    volume: Volume,
    *,
    structure: Structure | None = None,
    force_lattice: Lattice | None = None,
) -> go.Figure:
    """Create density plot from volume and structure.

    Uses plotly as plotting backend; this is a thin wrapper that forwards its
    arguments to ``plot_3d``.

    Parameters
    ----------
    volume : Volume
        Input volume
    structure : Structure, optional
        Input structure
    force_lattice : Lattice | None
        Plot volume and structure using this lattice as a basis. Overrides the
        default, which is to use `volume.lattice` and `structure.lattice`
        where applicable.

    Returns
    -------
    fig : go.Figure
        Output as plotly figure
    """
    from ._plot3d import plot_3d

    plot_kwargs = {
        'volume': volume,
        'structure': structure,
        'lattice': force_lattice,
    }
    return plot_3d(**plot_kwargs)

displacement_histogram(trajectory, n_parts=1)

Plot histogram of total displacement at final timestep.

Parameters:

  • trajectory (Trajectory) –

    Input trajectory, i.e. for the diffusing atom

  • n_parts (int, default: 1 ) –

    Plot error bars by dividing data into n parts

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_displacements.py (lines 174–231)
def displacement_histogram(trajectory: Trajectory,
                           n_parts: int = 1) -> go.Figure:
    """Plot histogram of total displacement at final timestep.

    With ``n_parts == 1`` the displacement counts are drawn as bars stacked
    per element. Otherwise the trajectory is split into ``n_parts`` parts and
    the mean count per (displacement, element) over the parts is drawn as
    grouped bars, with the standard deviation as error bars.

    Parameters
    ----------
    trajectory : Trajectory
        Input trajectory, i.e. for the diffusing atom
    n_parts : int
        Plot error bars by dividing data into n parts

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    axis_titles = {
        'xaxis_title': 'Displacement (Angstrom)',
        'yaxis_title': 'Nr. of atoms',
    }

    if n_parts == 1:
        counts = _trajectory_to_dataframe(trajectory)
        fig = px.bar(counts,
                     x='Displacement',
                     y='count',
                     color='Element',
                     barmode='stack')
        fig.update_layout(title='Displacement per element', **axis_titles)
        return fig

    # Boundaries of the parts (in time steps); only their spacing is used,
    # for the title
    boundaries = np.linspace(0, len(trajectory) - 1, n_parts + 1)

    combined = pd.concat(
        [_trajectory_to_dataframe(part) for part in trajectory.split(n_parts)])

    # Mean and standard deviation of the counts over the parts
    by_bin = combined.groupby(['Displacement', 'Element'])
    means = by_bin.mean().reset_index().rename(columns={'count': 'mean'})
    stds = by_bin.std().reset_index().rename(columns={'count': 'std'})
    stats = means.merge(stds, how='inner')

    fig = px.bar(stats,
                 x='Displacement',
                 y='mean',
                 color='Element',
                 error_y='std',
                 barmode='group')

    part_steps = int(boundaries[1] - boundaries[0])
    fig.update_layout(
        title=f'Displacement per element after {part_steps} timesteps',
        **axis_titles)

    return fig

displacement_per_atom(*, trajectory)

Plot displacement per atom.

Parameters:

  • trajectory (Trajectory) –

    Input trajectory, i.e. for the diffusing atom

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_displacements.py (lines 13–43)
def displacement_per_atom(*, trajectory: Trajectory) -> go.Figure:
    """Plot displacement per atom.

    One line trace is drawn for every entry yielded by
    ``trajectory.distances_from_base_position()``, named by its index.

    Parameters
    ----------
    trajectory : Trajectory
        Input trajectory, i.e. for the diffusing atom

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    fig = go.Figure()

    for atom_index, atom_distance in enumerate(
            trajectory.distances_from_base_position()):
        trace = go.Scatter(y=atom_distance,
                           name=atom_index,
                           mode='lines',
                           line={'width': 1},
                           showlegend=False)
        fig.add_trace(trace)

    fig.update_layout(title='Displacement per atom',
                      xaxis_title='Time step',
                      yaxis_title='Displacement (Angstrom)')

    return fig

displacement_per_element(*, trajectory)

Plot displacement per element.

Parameters:

  • trajectory (Trajectory) –

    Input trajectory

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_displacements.py (lines 46–99)
def displacement_per_element(*, trajectory: Trajectory) -> go.Figure:
    """Plot displacement per element.

    For each element symbol, the mean displacement over its atoms is drawn as
    a thick line, with a filled band of plus/minus one standard deviation.
    The three traces of an element share one legend group, and only the mean
    line is shown in the legend.

    Parameters
    ----------
    trajectory : Trajectory
        Input trajectory

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    fig = go.Figure()

    # Collect the distance series of every atom under its element symbol
    per_symbol = defaultdict(list)
    for specie, dist in zip(trajectory.species,
                            trajectory.distances_from_base_position()):
        per_symbol[specie.symbol].append(dist)

    for symbol, dists in per_symbol.items():
        mean = np.mean(dists, axis=0)
        spread = np.std(dists, axis=0)
        label = symbol + ' + std'

        # Visible mean line (the legend entry for this group)
        fig.add_trace(
            go.Scatter(y=mean,
                       name=label,
                       mode='lines',
                       line={'width': 3},
                       legendgroup=symbol))

        # Invisible upper bound, then the lower bound filled up to it
        # ('tonexty') to draw the std band
        upper = go.Scatter(y=mean + spread,
                           name=label,
                           mode='lines',
                           line={'width': 0},
                           legendgroup=symbol,
                           showlegend=False)
        lower = go.Scatter(y=mean - spread,
                           name=label,
                           mode='lines',
                           line={'width': 0},
                           legendgroup=symbol,
                           showlegend=False,
                           fill='tonexty')
        fig.add_trace(upper)
        fig.add_trace(lower)

    fig.update_layout(title='Displacement per element',
                      xaxis_title='Time step',
                      yaxis_title='Displacement (Angstrom)')

    return fig

energy_along_path(path, *, structure, other_paths=None)

Plot energy along specified path.

Parameters:

  • path (Pathway) –

    Pathway object containing the energy along the path

  • structure (Structure) –

    Structure object to get the site information

  • other_paths (Pathway | list[Pathway], default: None ) –

    Optional list of alternative paths to plot

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/matplotlib/_paths.py (lines 9–100)
def energy_along_path(
    path: Pathway,
    *,
    structure: Structure,
    other_paths: Pathway | list[Pathway] | None = None,
) -> plt.Figure:
    """Plot energy along specified path.

    The bottom x axis is labelled with the coordinates, and the top x axis
    with the labels, of the structure site nearest to each step of the path.
    Consecutive steps at the same site share a single label, and every other
    site region is shaded.

    Parameters
    ----------
    path : Pathway
        Pathway object containing the energy along the path
    structure : Structure
        Structure object to get the site information
    other_paths : Pathway | list[Pathway], optional
        Optional alternative path, or list of paths, to plot

    Returns
    -------
    fig : matplotlib.figure.Figure
        Output figure

    Raises
    ------
    ValueError
        If `path` or any of `other_paths` does not contain energy data.
    """
    if path.energy is None:
        raise ValueError('Pathway does not contain energy data')

    # A single alternative pathway is accepted as well as a list of them
    if isinstance(other_paths, Pathway):
        other_paths = [other_paths]

    fig, ax = plt.subplots(figsize=(8, 4))

    ax.plot(path.energy, marker='o', color='r', label='Optimal path')
    ax.set(ylabel='Free energy [eV]')

    nearest_sites = path.path_over_structure(structure)

    # Create custom labels for the x axis to avoid consecutive repetitions
    site_xlabel = []
    sitecoord_xlabel = []

    prev = nearest_sites[0]
    for i, site in enumerate(nearest_sites):
        # only non repeated labels will get an entry
        if (site.coords != prev.coords).any() or i == 0:
            sitecoord_xlabel.append(', '.join(f'{val:.1f}'
                                              for val in site.coords))
            site_xlabel.append(site.label)
        else:
            sitecoord_xlabel.append('')
            site_xlabel.append('')

        prev = site

    # Steps at which a new site region starts
    non_empty_ticks = [
        i for i, label in enumerate(sitecoord_xlabel) if label != ''
    ]

    # Put each label in the middle of its site region; the last region is
    # taken to extend to the right edge of the axes
    extra_ticks = non_empty_ticks.copy()
    extra_ticks.append(ax.get_xlim()[1])
    centered_ticks = [(start + stop) / 2
                      for start, stop in zip(extra_ticks, extra_ticks[1:])]

    ax.set_xticks(centered_ticks)
    ax.set_xticklabels([sitecoord_xlabel[i] for i in non_empty_ticks],
                       rotation=45)

    # Shade every other site region. A region runs from its first step to the
    # first step of the next region; the final region runs to the last step
    # of the path (previously it had zero width and was never visible).
    last_step = len(path.energy) - 1
    for i in range(0, len(non_empty_ticks), 2):
        start = non_empty_ticks[i]
        if i + 1 < len(non_empty_ticks):
            stop = non_empty_ticks[i + 1]
        else:
            stop = last_step
        ax.axvspan(start, stop, facecolor='lightgray', edgecolor='none')

    # and add on top the site labels
    ax_up = ax.twiny()
    ax_up.set_xlim(ax.get_xlim())
    ax_up.set_xticks(centered_ticks)
    ax_up.set_xticklabels([site_xlabel[i] for i in non_empty_ticks],
                          rotation=45)
    ax_up.get_yaxis().set_visible(False)

    # If available, plot the other pathways (without shadowing `path`)
    if other_paths:
        for idx, alt_path in enumerate(other_paths, start=1):
            if alt_path.energy is None:
                raise ValueError('Pathway does not contain energy data')
            ax.plot(range(len(alt_path.energy)),
                    alt_path.energy,
                    label=f'Alternative {idx}')
        ax.legend(fontsize=8)

    return fig

frequency_vs_occurence(*, trajectory)

Plot attempt frequency vs. occurrence.

Parameters:

  • trajectory (Trajectory) –

    Input trajectory, i.e. for the diffusing atom

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_vibration.py (lines 13–81)
def frequency_vs_occurence(*, trajectory: Trajectory) -> go.Figure:
    """Plot attempt frequency vs occurence.

    The FFT amplitude spectrum of the speeds from ``SimulationMetrics.speed()``
    has two steps applied to it:

    - It is restricted to non-negative frequencies and summed over the first
      axis.
    - It is smoothed with a 51-point moving average.

    The result is plotted against frequency. The attempt frequency is drawn
    as a solid red line and attempt frequency +/- its standard deviation as
    dashed red lines, when those values are truthy.

    Parameters
    ----------
    trajectory : Trajectory
        Input trajectory, i.e. for the diffusing atom

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    metrics = SimulationMetrics(trajectory)
    speed = metrics.speed()

    n_samples = speed.shape[1]
    n_half = n_samples // 2 + 1

    # Normalized amplitude spectrum (FFT along the last axis), keeping only
    # the non-negative frequency half
    amplitude = np.abs(np.fft.fft(speed) / n_samples)
    amplitude = amplitude[:, :n_half]

    freqs = trajectory.sampling_frequency * np.arange(n_half) / n_samples

    total = np.sum(amplitude, axis=0)
    window = 51
    smoothed = np.convolve(total, np.ones(window), 'same') / window
    y_top = np.max(total)

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(x=freqs,
                   y=smoothed,
                   mode='lines',
                   line={
                       'width': 3,
                       'color': 'blue'
                   },
                   showlegend=False))

    attempt_freq, attempt_freq_std = metrics.attempt_frequency()

    if attempt_freq:
        fig.add_vline(x=attempt_freq, line={'width': 2, 'color': 'red'})
        if attempt_freq_std:
            for bound in (attempt_freq + attempt_freq_std,
                          attempt_freq - attempt_freq_std):
                fig.add_vline(x=bound,
                              line={
                                  'width': 2,
                                  'color': 'red',
                                  'dash': 'dash'
                              })

    fig.update_layout(title='Frequency vs Occurence',
                      xaxis_title='Frequency (Hz)',
                      yaxis_title='Occurrence (a.u.)',
                      xaxis_range=[-0.1e13, 2.5e13],
                      yaxis_range=[0, y_top],
                      width=600,
                      height=500)

    return fig

jumps_3d(*, jumps)

Plot jumps in 3D.

Parameters:

  • jumps (Jumps) –

    Input data

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_jumps.py (lines 167–181)
def jumps_3d(*, jumps: Jumps) -> go.Figure:
    """Plot jumps in 3D.

    ``jumps.sites`` is passed to ``plot_3d`` as the structure, and the jumps
    are overlaid on it.

    Parameters
    ----------
    jumps : Jumps
        Input data

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    from ._plot3d import plot_3d

    jump_sites = jumps.sites
    return plot_3d(jumps=jumps, structure=jump_sites)

jumps_3d_animation(*, jumps, t_start, t_stop, decay=0.05, skip=5, interval=20)

Plot jumps in 3D as an animation over time.

Parameters:

  • jumps (Jumps) –

    Input data

  • t_start (int) –

    Time step to start animation (relative to equilibration time)

  • t_stop (int) –

    Time step to stop animation (relative to equilibration time)

  • decay (float, default: 0.05 ) –

    Controls the decay of the line width (higher = faster decay)

  • skip (int, default: 5 ) –

    Skip frames (increase for faster, but less accurate rendering)

  • interval (int, default: 20 ) –

    Delay between frames in milliseconds.

Returns:

  • anim ( FuncAnimation ) –

    Output animation

Source code in src/gemdat/plots/matplotlib/_jumps.py (lines 212–347)
def jumps_3d_animation(
    *,
    jumps: Jumps,
    t_start: int,
    t_stop: int,
    decay: float = 0.05,
    skip: int = 5,
    interval: int = 20,
) -> animation.FuncAnimation:
    """Plot jumps in 3D as an animation over time.

    Every jump is first drawn as a zero-width line between its start and
    destination site. In each frame, every jump that has started is drawn in
    red. Its line width starts at 5.0 and decreases by `decay` per time step
    since the jump started, down to a minimum of 0.2. The start and
    destination site markers of started jumps are coloured by atom index.

    Parameters
    ----------
    jumps : Jumps
        Input data
    t_start : int
        Time step to start animation (relative to equilibration time)
    t_stop : int
        Time step to stop animation (relative to equilibration time)
    decay : float, optional
        Controls the decay of the line width (higher = faster decay)
    skip : int, optional
        Skip frames (increase for faster, but less accurate rendering)
    interval : int, optional
        Delay between frames in milliseconds.

    Returns
    -------
    anim : matplotlib.animation.FuncAnimation
        Output animation
    """
    minwidth = 0.2
    maxwidth = 5.0

    trajectory = jumps.trajectory

    class LabelItems:
        """Pairs labels with coordinates behind an ``items()`` interface."""

        def __init__(self, labels, coords):
            self.labels = labels
            self.coords = coords

        def items(self):
            yield from zip(self.labels, self.coords)

    coords = jumps.sites.frac_coords
    lattice = trajectory.get_lattice()

    color_from = colormaps['Set1'].colors  # type: ignore
    color_to = colormaps['Pastel1'].colors  # type: ignore

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')

    xyz_labels = LabelItems('OABC', [[-0.1, -0.1, -0.1], [1.1, -0.1, -0.1],
                                     [-0.1, 1.1, -0.1], [-0.1, -0.1, 1.1]])

    plotter.plot_lattice_vectors(lattice, ax=ax, linewidth=1)

    plotter.plot_labels(xyz_labels,
                        lattice=lattice,
                        ax=ax,
                        color='green',
                        size=12)

    # The site markers must be the only collections on the axes, because
    # `points[site_index]` is used below to recolour the marker of a site.
    assert len(ax.collections) == 0
    plotter.plot_points(coords,
                        lattice=lattice,
                        ax=ax,
                        s=50,
                        color='white',
                        edgecolor='black')
    points = ax.collections

    events = jumps.data.sort_values('start time', ignore_index=True)

    # Record how many lines already exist, so that exactly the jump lines
    # added below can be sliced out (instead of assuming a fixed count).
    n_lines_before = len(ax.lines)

    for _, event in events.iterrows():
        coord_i = coords[event['start site']]
        coord_j = coords[event['destination site']]

        # Shift the destination by the lattice image returned by
        # get_distance_and_image, so the line is drawn across the boundary
        _, image = lattice.get_distance_and_image(coord_i, coord_j)

        line = [coord_i, coord_j + image]

        # Drawn with zero width; `update` widens it once the jump has started
        plotter.plot_path(line,
                          lattice=lattice,
                          ax=ax,
                          color='red',
                          linewidth=0)

    # One line per event, in the same order as `events`
    lines = ax.lines[n_lines_before:]

    ax.set(
        title='Jumps between sites',
        xlabel="x' (ang)",
        ylabel="y' (ang)",
        zlabel="z' (ang)",
    )

    # 'equal' aspect on 3D axes requires a recent matplotlib (older versions
    # only supported 'auto')
    ax.set_aspect('equal')

    def update(frame_no):
        t_frame = t_start + (frame_no * skip)

        # Start time of the last event inspected: the first jump that has not
        # started yet, or the last jump if all have started. Stays None when
        # there are no jumps (previously this raised NameError).
        start_time = None

        for i, event in events.iterrows():
            start_time = event['start time']

            if start_time > t_frame:
                break

            # Line width decays linearly with the time since the jump started
            lw = max(maxwidth - decay * (t_frame - start_time), minwidth)

            line = lines[i]
            line.set_color('red')
            line.set_linewidth(lw)

            points[event['start site']].set_facecolor(
                color_from[event['atom index'] % len(color_from)])
            points[event['destination site']].set_facecolor(
                color_to[event['atom index'] % len(color_to)])

        ax.set_title(f'T: {t_frame} | Next jump: {start_time}')

    n_frames = int((t_stop - t_start) / skip)

    return animation.FuncAnimation(fig=fig,
                                   func=update,
                                   frames=n_frames,
                                   interval=interval,
                                   repeat=False)

jumps_vs_distance(*, jumps, jump_res=0.1, n_parts=1)

Plot jumps vs. distance histogram.

Parameters:

  • jumps (Jumps) –

    Input jumps data

  • jump_res (float, default: 0.1 ) –

    Resolution of the bins in Angstrom

  • n_parts (int, default: 1 ) –

    Number of parts for error analysis

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_jumps.py (lines 15–76)
def jumps_vs_distance(*,
                      jumps: Jumps,
                      jump_res: float = 0.1,
                      n_parts: int = 1) -> go.Figure:
    """Plot jumps vs. distance histogram.

    The periodic distances between all pairs of sites are binned with a
    resolution of `jump_res`. For each pair of sites, the number of jumps
    between them is added to the bin of their distance, and each bin is
    plotted at its upper edge.

    When `n_parts > 1`, the jumps are split into `n_parts` parts. The bars
    then show the mean count over all parts, with the standard deviation as
    error bars. Bins that are empty in every part are omitted.

    Parameters
    ----------
    jumps : Jumps
        Input jumps data
    jump_res : float, optional
        Resolution of the bins in Angstrom
    n_parts : int
        Number of parts for error analysis

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    sites = jumps.sites
    trajectory = jumps.trajectory
    lattice = trajectory.get_lattice()

    # Distance between every pair of sites
    pdist = lattice.get_all_distances(sites.frac_coords, sites.frac_coords)

    # Bin edges from 0 up to just beyond the largest distance
    bin_max = (1 + pdist.max() // jump_res) * jump_res
    n_bins = int(bin_max / jump_res) + 1
    x = np.linspace(0, bin_max, n_bins)

    # For each site pair, the index of the upper edge of its distance bin
    bin_idx = np.digitize(pdist, bins=x).flatten()

    part_counts = []
    for transitions_part in jumps.split(n_parts=n_parts):
        counts = np.zeros_like(x)
        for idx, n in zip(bin_idx, transitions_part.matrix().flatten()):
            counts[idx] += n
        part_counts.append(counts)

    # Keep every bin that has jumps in at least one part, and record its count
    # in *every* part (zeros included). Dropping the zeros would compute the
    # mean/std over the non-empty parts only and overestimate the mean.
    occupied = [
        idx for idx in range(n_bins)
        if any(counts[idx] > 0 for counts in part_counts)
    ]
    data = [(x[idx], counts[idx]) for counts in part_counts
            for idx in occupied]

    df = pd.DataFrame(data=data, columns=['Displacement', 'count'])

    grouped = df.groupby(['Displacement'])
    mean = grouped.mean().reset_index().rename(columns={'count': 'mean'})
    std = grouped.std().reset_index().rename(columns={'count': 'std'})
    df = mean.merge(std, how='inner')

    if n_parts == 1:
        fig = px.bar(df, x='Displacement', y='mean', barmode='stack')
    else:
        fig = px.bar(df,
                     x='Displacement',
                     y='mean',
                     error_y='std',
                     barmode='stack')

    fig.update_layout(title='Jumps vs. Distance',
                      xaxis_title='Distance (Angstrom)',
                      yaxis_title='Number of jumps')

    return fig

jumps_vs_time(*, jumps, bins=8, n_parts=1)

Plot jumps vs. time histogram.

Parameters:

  • jumps (Jumps) –

    Input jumps data

  • bins (int, default: 8 ) –

    Number of bins

  • n_parts (int, default: 1 ) –

    Number of parts for error analysis

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_jumps.py (lines 79–128)
def jumps_vs_time(*,
                  jumps: Jumps,
                  bins: int = 8,
                  n_parts: int = 1) -> go.Figure:
    """Plot jumps vs. time histogram.

    Jump start times are histogrammed into `bins` equal-width bins spanning
    the length of one part (trajectory length / `n_parts`). Each bar is placed
    at the centre of its bin. When `n_parts > 1`, the bars show the mean count
    per bin over the parts, with the standard deviation as error bars.

    Parameters
    ----------
    jumps : Jumps
        Input jumps data
    bins : int, optional
        Number of bins
    n_parts : int
        Number of parts for error analysis

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    maxlen = len(jumps.trajectory) / n_parts
    data = []

    for jumps_part in jumps.split(n_parts=n_parts):
        data.append(
            np.histogram(jumps_part.data['start time'],
                         bins=bins,
                         range=(0., maxlen))[0])

    # Bar positions are the real bin centres: np.histogram with an integer
    # `bins` and a `range` uses `bins + 1` equally spaced edges over that
    # range. (The previous `maxlen / bins + 1` bin size made the bars drift
    # away from the actual bins.)
    edges = np.linspace(0., maxlen, bins + 1)
    centers = (edges[:-1] + edges[1:]) / 2

    df = pd.DataFrame(data=data)

    mean = [df[col].mean() for col in df.columns]
    std = [df[col].std() for col in df.columns]

    df = pd.DataFrame(data=zip(centers, mean, std),
                      columns=['time', 'count', 'std'])

    if n_parts > 1:
        fig = px.bar(df, x='time', y='count', error_y='std')
    else:
        fig = px.bar(df, x='time', y='count')

    fig.update_layout(bargap=0.2,
                      title='Jumps vs. time',
                      xaxis_title='Time (steps)',
                      yaxis_title='Number of jumps')

    return fig

msd_per_element(*, trajectory)

Plot mean squared displacement per element.

Parameters:

  • trajectory (Trajectory) –

    Input trajectory

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_displacements.py (lines 102–147)
def msd_per_element(*, trajectory: Trajectory) -> go.Figure:
    """Plot mean squared displacement per element.

    For every species, the trajectory is filtered to that element. The
    mean over axis 0 of its ``mean_squared_displacement()`` is plotted
    against the time lag in picoseconds, with the standard deviation over
    axis 0 as error bars. Species are plotted in order of first appearance.

    Parameters
    ----------
    trajectory : Trajectory
        Input trajectory

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    fig = go.Figure()

    # Unique species in order of first appearance. A set would give an
    # arbitrary (and potentially run-dependent) trace order and colouring.
    species = list(dict.fromkeys(trajectory.species))

    # Since we want to plot in picosecond, we convert the time units
    time_ps = trajectory.time_step * 1e12

    for sp in species:
        traj = trajectory.filter(sp.symbol)

        msd = traj.mean_squared_displacement()
        msd_mean = np.mean(msd, axis=0)
        msd_std = np.std(msd, axis=0)
        t_values = np.arange(len(msd_mean)) * time_ps

        fig.add_trace(
            go.Scatter(x=t_values,
                       y=msd_mean,
                       error_y=dict(type='data',
                                    array=msd_std,
                                    width=0.1,
                                    thickness=0.1),
                       name=sp.symbol,
                       mode='lines',
                       line={'width': 3},
                       legendgroup=sp.symbol))

    fig.update_layout(title='Mean squared displacement per element',
                      xaxis_title='Time lag [ps]',
                      yaxis_title='MSD (Angstrom<sup>2</sup>)')

    return fig

path_on_grid(path)

Plot the 3d coordinates of the points that define a path.

Parameters:

  • path (Pathway) –

    Pathway to plot

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/matplotlib/_paths.py (lines 103–140)
def path_on_grid(path: Pathway) -> plt.Figure:
    """Plot the 3d coordinates of the points that define a path.

    Consecutive points of ``path.sites`` are connected by line segments. The
    colour of each segment encodes its step index along the path, as shown
    on the colorbar.

    Parameters
    ----------
    path : Pathway
        Pathway to plot

    Returns
    -------
    fig : matplotlib.figure.Figure
        Output figure
    """
    path_x, path_y, path_z = zip(*path.sites)
    n_points = len(path_x)

    # Create a colormap to visualize the path
    colormap = plt.get_cmap()
    normalize = plt.Normalize(0, n_points)

    # Create only the 3D axes; calling plt.subplots() first would leave an
    # extra, empty 2D axes behind the 3D one.
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')

    if n_points == 1:
        # Nothing to connect: show the single point
        ax.plot(path_x,
                path_y,
                path_z,
                color=colormap(normalize(0)),
                marker='o')

    # One segment per pair of consecutive points (slice i:i+2). A slice of
    # i:i+1 would only plot isolated points and drop the last one.
    for i in range(n_points - 1):
        ax.plot(path_x[i:i + 2],
                path_y[i:i + 2],
                path_z[i:i + 2],
                color=colormap(normalize(i)),
                marker='o',
                linestyle='-')

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    sm = plt.cm.ScalarMappable(cmap=colormap, norm=normalize)
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax)
    cbar.set_label('Steps')

    return fig

plot_3d(*, volume=None, structure=None, paths=None, jumps=None, lattice=None, title='3D plot')

Plot a volume, structure, pathways and/or jumps together in a 3D plotly figure.

Source code in src/gemdat/plots/plotly/_plot3d.py (lines 340–376)
def plot_3d(*,
            volume: Volume | None = None,
            structure: Structure | None = None,
            paths: Pathway | list[Pathway] | None = None,
            jumps: Jumps | None = None,
            lattice: Lattice | None = None,
            title: str = '3D plot') -> go.Figure:
    """Plot 3d.

    Combine the lattice vectors with an optional volume, structure,
    path(s) and jumps into a single 3D figure.

    Parameters
    ----------
    volume : Volume | None
        Volume to plot
    structure : Structure | None
        Structure to plot
    paths : Pathway | list[Pathway] | None
        Path or paths to plot
    jumps : Jumps | None
        Jumps to plot
    lattice : Lattice | None
        Lattice to plot in. If not given, it is taken from `volume`,
        `structure` or `jumps`, in that order.
    title : str
        Title of the figure

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure

    Raises
    ------
    ValueError
        If no lattice is given and none can be derived from the input.
    """
    fig = go.Figure()

    # Only derive a lattice when none was passed in. The error belongs to the
    # case where none of the inputs can provide one, not to an explicit lattice.
    if lattice is None:
        if volume:
            lattice = volume.lattice
        elif structure:
            lattice = structure.lattice
        elif jumps:
            lattice = jumps.trajectory.get_lattice()
        else:
            raise ValueError('Cannot derive lattice from input.')

    plot_lattice_vectors(lattice, fig=fig)

    if volume:
        plot_volume(volume, lattice=lattice, fig=fig)

    if structure:
        plot_structure(structure=structure, lattice=lattice, fig=fig)

    if paths:
        plot_paths(paths=paths, lattice=lattice, fig=fig)

    if jumps:
        plot_jumps(jumps=jumps, fig=fig)

    update_layout(title=title, lattice=lattice, fig=fig)

    return fig

radial_distribution(rdfs)

Plot radial distribution function.

Parameters:

  • rdfs (Iterable[RDFData]) –

    List of RDF data to plot

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/matplotlib/_rdf.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def radial_distribution(rdfs: Iterable[RDFData]) -> plt.Figure:
    """Plot radial distribution function.

    Parameters
    ----------
    rdfs : Iterable[RDFData]
        List of RDF data to plot. Any iterable is accepted; it is consumed
        exactly once.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Output figure
    """
    # Materialize once: `rdfs` is traversed twice below. A generator would
    # otherwise be exhausted by the plotting loop and yield an empty title.
    rdfs = list(rdfs)

    fig, ax = plt.subplots()

    for rdf in rdfs:
        ax.plot(rdf.x, rdf.y, label=rdf.symbol)

    # Unique states in order of first appearance. A plain `set` of strings
    # has no stable order, so the title could change from run to run.
    states = ', '.join(dict.fromkeys(rdf.state for rdf in rdfs))

    ax.legend()
    ax.set(title=f'Radial distribution function ({states})',
           xlabel='Distance (Ang)',
           ylabel='Counts')

    return fig

rectilinear(*, orientations, shape=(90, 360), normalize_histo=True)

Plot a rectilinear projection of a spherical function. This function uses the transformed trajectory.

Parameters:

  • orientations (Orientations) –

    The unit vector trajectories

  • shape (tuple, default: (90, 360) ) –

    The shape of the spherical sector in which the trajectory is plotted

  • normalize_histo (bool, default: True ) –

    If True, normalize the histogram by the area of the bins, by default True

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/matplotlib/_orientations.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def rectilinear(*,
                orientations: Orientations,
                shape: tuple[int, int] = (90, 360),
                normalize_histo: bool = True) -> plt.Figure:
    """Plot a rectilinear projection of a spherical function. This function
    uses the transformed trajectory.

    Parameters
    ----------
    orientations : Orientations
        The unit vector trajectories
    shape : tuple[int, int], optional
        The shape of the spherical sector in which the trajectory is plotted.
        It is used as the number of (elevation, azimuth) histogram bins,
        by default (90, 360)
    normalize_histo : bool, optional
        If True, normalize the histogram by the area of the bins, by default True

    Returns
    -------
    fig : matplotlib.figure.Figure
        Output figure
    """
    # Convert the vectors to spherical coordinates
    az, el, _ = orientations.vectors_spherical.T
    az = az.flatten()
    el = el.flatten()

    hist, xedges, yedges = np.histogram2d(el, az, shape)

    if normalize_histo:
        # Normalize by the area of the bins
        areas = calculate_spherical_areas(shape)
        hist = np.divide(hist, areas)
        # Drop the bins at the poles where normalization is not possible
        hist = hist[1:-1, :]

    values = hist.T
    axis_phi, axis_theta = values.shape

    phi = np.linspace(0, 360, axis_phi)
    theta = np.linspace(0, 180, axis_theta)

    theta, phi = np.meshgrid(theta, phi)

    fig, ax = plt.subplots(subplot_kw=dict(projection='rectilinear'))
    cs = ax.contourf(phi, theta, values, cmap='viridis')
    ax.set_yticks(np.arange(0, 190, 45))
    ax.set_xticks(np.arange(0, 370, 45))

    # Greek letters were mis-encoded; the unit bracket now wraps the
    # math-mode degree symbol instead of straddling the `$` delimiter.
    ax.set_xlabel(r'azimuthal angle φ [$\degree$]')
    ax.set_ylabel(r'elevation θ [$\degree$]')

    ax.grid(visible=True)
    cbar = fig.colorbar(cs, label='areal probability', format='')

    # Flip the colorbar label by 180 degrees relative to the default
    # vertical orientation (rotation=270 instead of 90).
    cbar.ax.yaxis.set_label_coords(2.5,
                                   0.5)  # Adjust the position of the label
    cbar.set_label('areal probability', rotation=270, labelpad=15)
    return fig

shape(shape, bins=50, sites=None)

Plot site cluster shapes.

Parameters:

  • shape (ShapeData) –

    Shape data to plot

  • bins (int | Sequence[float], default: 50 ) –

    Number of bins or sequence of bin edges. See hist() for more info.

  • sites (Collection[PeriodicSite] | None, default: None ) –

    Plot these sites on the shape density

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/matplotlib/_shape.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
def shape(
    shape: ShapeData,
    bins: int | Sequence[float] = 50,
    sites: Collection[PeriodicSite] | None = None,
) -> plt.Figure:
    """Plot site cluster shapes.

    The top row shows 2D density projections onto the XY, YZ and ZX planes,
    with a dashed circle at the mean distance from the origin. The bottom
    row shows 1D densities along the corresponding first axis.

    Parameters
    ----------
    shape : ShapeData
        Shape data to plot
    bins : int | Sequence[float]
        Number of bins or sequence of bin edges.
        See [hist()][matplotlib.pyplot.hist] for more info.
    sites : Collection[PeriodicSite] | None
        Plot these sites on the shape density, positioned relative to
        `shape.origin` and labelled with `site.label`

    Returns
    -------
    fig : matplotlib.figure.Figure
        Output figure
    """
    # Unit labels were mis-encoded; Å is the Ångström symbol.
    x_labels = ('X / Å', 'Y / Å', 'Z / Å')
    y_labels = ('Y / Å', 'Z / Å', 'X / Å')

    fig, axes = plt.subplots(nrows=2,
                             ncols=3,
                             sharex=True,
                             figsize=(12, 5),
                             gridspec_kw={'height_ratios': (4, 1)})

    distances = shape.distances()
    distances_sq = distances**2

    msd = np.mean(distances_sq)
    std = np.std(distances_sq)
    title = f'{shape.name}: MSD = {msd:.3f}$~Å^2$, std = {std:.3f}'

    mean_dist = np.mean(distances)

    # Group site coordinates by label and shift them into the shape's frame
    coords_by_label = defaultdict(list)
    vector_dict = {}
    if sites:
        for site in sites:
            coords_by_label[site.label].append(site.coords)
        for key, values in coords_by_label.items():
            vector_dict[key] = np.array(values) - shape.origin

    coords = shape.coords

    axes[0, 1].set_title(title)

    # Column-wise projections: (x, y), (y, z), (z, x)
    for col, (i, j) in enumerate(((0, 1), (1, 2), (2, 0))):
        ax0 = axes[0, col]
        ax1 = axes[1, col]

        x_coords = coords[:, i]
        y_coords = coords[:, j]

        ax0.hist2d(x=x_coords, y=y_coords, bins=bins)
        ax0.set_ylabel(y_labels[col])

        circle = plt.Circle(
            (0, 0),
            mean_dist,
            color='r',
            linestyle='--',
            fill=False,
        )
        ax0.add_patch(circle)

        ax0.scatter(x=[0], y=[0], color='r', marker='.')

        for label, vects in vector_dict.items():
            x_vs = vects[:, i]
            y_vs = vects[:, j]

            for x, y in zip(x_vs, y_vs):
                ax0.text(x, y, s=label, color='r')

            ax0.scatter(x=x_vs, y=y_vs, color='r', marker='.', label=label)

        ax0.axis('equal')

        ax1.hist(x=x_coords, bins=bins, density=True)
        ax1.set_xlabel(x_labels[col])
        ax1.set_ylabel('density')

    fig.tight_layout()

    return fig

vibrational_amplitudes(*, trajectory, bins=50, n_parts=1)

Plot histogram of vibrational amplitudes with fitted Gaussian.

Parameters:

  • trajectory (Trajectory) –

    Input trajectory, i.e. for the diffusing atom

  • bins (int, default: 50 ) –

    Number of histogram bins

  • n_parts (int, default: 1 ) –

    Number of parts for error analysis

Returns:

  • fig ( Figure ) –

    Output figure

Source code in src/gemdat/plots/plotly/_vibration.py
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
def vibrational_amplitudes(*,
                           trajectory: Trajectory,
                           bins: int = 50,
                           n_parts: int = 1) -> go.Figure:
    """Plot histogram of vibrational amplitudes with fitted Gaussian.

    Parameters
    ----------
    trajectory : Trajectory
        Input trajectory, i.e. for the diffusing atom
    bins : int
        Number of histogram bins, by default 50
    n_parts : int
        Number of parts for error analysis. When larger than 1, the
        per-bin standard deviation across parts is shown as error bars.

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Output figure
    """
    trajectories = trajectory.split(n_parts)
    single_metrics = SimulationMetrics(trajectory)
    # `part`, not `trajectory`, so the loop does not appear to shadow the argument
    metrics = [SimulationMetrics(part).amplitudes() for part in trajectories]

    # Histogram range symmetric around zero, wide enough for every part
    max_amp = max(np.max(metric) for metric in metrics)
    min_amp = min(np.min(metric) for metric in metrics)

    max_amp = max(abs(min_amp), max_amp)
    min_amp = -max_amp

    # One row per part, one column per bin (normalized densities)
    data = [
        np.histogram(metric,
                     bins=bins,
                     range=(min_amp, max_amp),
                     density=True)[0] for metric in metrics
    ]

    df = pd.DataFrame(data=data)

    # offset to middle of bar
    offset = (max_amp - min_amp) / (bins * 2)

    columns = np.linspace(min_amp + offset,
                          max_amp + offset,
                          bins,
                          endpoint=False)

    # Per-bin mean and sample std over the parts. The std is NaN for a single
    # part, which is why error bars are only drawn when n_parts > 1.
    mean = df.mean().to_numpy()
    std = df.std().to_numpy()

    df = pd.DataFrame(data=zip(columns, mean, std),
                      columns=['amplitude', 'count', 'std'])

    if n_parts == 1:
        fig = px.bar(df, x='amplitude', y='count')
    else:
        fig = px.bar(df, x='amplitude', y='count', error_y='std')

    x = np.linspace(min_amp, max_amp, 100)
    y_gauss = stats.norm.pdf(x, 0, single_metrics.vibration_amplitude())
    fig.add_trace(go.Scatter(x=x, y=y_gauss, name='Fitted Gaussian'))

    fig.update_layout(
        title='Histogram of vibrational amplitudes with fitted Gaussian',
        xaxis_title='Amplitude (Ångstrom)',
        yaxis_title='Occurrence (a.u.)')

    return fig