"""
Assorted useful helper functions
"""
import os
import time
from typing import Mapping, Tuple
import dolfin as fem
import munch
import numpy as np
from matplotlib import pyplot as plt # type: ignore
from scipy import interpolate
from mtnlion import deprecated_engine
from mtnlion.domain import Domain, eval_domain
@eval_domain("auto")
def interp_time2(sample_time: np.ndarray, data: np.ndarray) -> interpolate.interp1d:
"""
Interpolate time for all elements in a domain.
:param sample_time: Times at which the data is sampled
:param data: Data to interpolate
"""
return interpolate.interp1d(sample_time, data, axis=0, fill_value="extrapolate")
[docs]def set_domain_data(anode=None, cathode=None, separator=None):
"""
Convenience function for unpacking data into a dictionary
:param anode:
:param cathode:
:param separator:
"""
data = Domain()
if anode is not None:
data["anode"] = anode
if separator is not None:
data["separator"] = separator
if cathode is not None:
data["cathode"] = cathode
return data
[docs]def create_solution_matrices(num_rows: int, num_cols: int, num_solutions: int) -> Tuple[np.ndarray, ...]:
"""
Create numpy arrays for storing data.
:param num_rows:
:param num_cols:
:param num_solutions:
:return:
"""
return tuple(np.empty((num_rows, num_cols)) for _ in range(num_solutions))
[docs]def get_1d(func: fem.Function) -> np.ndarray: # pylint: disable=invalid-name
"""
Fetch the one-dimensional solution from a FEniCS function
:param func: FEniCS function
:param V: Function space
"""
return func.compute_vertex_values()
[docs]def save_fig(fig: plt.Figure, local_module_path: str, name: str):
"""
Save a figure to the given path using the given name
:param fig: figure to save
:param local_module_path: path at which to save
:param name: name of the file
"""
file = os.path.join(os.path.dirname(local_module_path), name)
directory = os.path.dirname(os.path.abspath(file))
if not os.path.exists(directory):
os.makedirs(directory)
fig.savefig(name)
[docs]def overlay_plt(
xdata: np.ndarray,
sample_time: np.ndarray,
title: str,
*ydata: np.ndarray,
figsize: Tuple[int, int] = (15, 9),
linestyles: Tuple[str, str] = ("-", "--"),
):
"""
Plot solution data at multiple time slices against a comparison data set.
:param xdata: Common x axis
:param sample_time: Sample times
:param title: Title of the plot
:param ydata: One or more sets of data in both space and time
:param figsize: Size of the figure
:param linestyles: style of the lines
"""
fig, ax = plt.subplots(figsize=figsize) # pylint: disable=invalid-name
new_x = np.repeat([xdata], len(sample_time), axis=0).T
for i, data in enumerate(ydata):
if i == 1:
plt.plot(new_x, data.T, linestyles[i], marker="o")
else:
plt.plot(new_x, data.T, linestyles[i])
plt.gca().set_prop_cycle(None)
plt.grid()
plt.title(title)
legend1 = plt.legend(
["t = {}".format(t) for t in sample_time], title="Time", bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.0
)
ax.add_artist(legend1)
h = [ # pylint: disable=invalid-name
plt.plot([], [], color="gray", ls=linestyles[i])[0] for i in range(len(linestyles))
]
plt.legend(
handles=h, labels=["FEniCS", "COMSOL"], title="Solver", bbox_to_anchor=(1.01, 0), loc=3, borderaxespad=0.0
)
return fig
[docs]def norm_rmse(estimated: np.ndarray, true: np.ndarray):
"""
Calculate the normalized RMSE
:param estimated: Estimated quantity
:param true: True quantity
"""
estimated = estimated[:, ~np.isnan(estimated).any(axis=0)]
true = true[:, ~np.isnan(true).any(axis=0)]
return deprecated_engine.rmse(estimated, true) / (np.max(true) - np.min(true))
[docs]class Timer:
"""
Convenient class for measuring time with `with` operators.
"""
def __init__(self):
self.start = None
self.end = None
self.interval = None
def __enter__(self):
self.start = time.clock()
return self
def __exit__(self, *args):
self.end = time.clock()
self.interval = self.end - self.start
[docs]def gather_expressions() -> Mapping[str, str]:
"""
Collect C++ based expressions
:return: Dictionary of C++ strings
"""
# TODO: read entire directory
localdir = os.path.dirname(__file__)
code = dict()
with open(os.path.join(localdir, "../headers/xbar.h")) as file:
code["xbar"] = "".join(file.readlines())
with open(os.path.join(localdir, "../headers/composition.h")) as file:
code["composition"] = "".join(file.readlines())
with open(os.path.join(localdir, "../headers/piecewise.h")) as file:
code["piecewise"] = "".join(file.readlines())
with open(os.path.join(localdir, "../headers/xbar_simple.h")) as file:
code["xbar_simple"] = "".join(file.readlines())
with open(os.path.join(localdir, "../headers/template.h")) as file:
code["template"] = "".join(file.readlines())
with open(os.path.join(localdir, "../newman/j_newman.h")) as file:
code["j_newman"] = "".join(file.readlines())
return munch.Munch(code)
# TODO: this is ugly
EXPRESSIONS = gather_expressions()
[docs]def build_expression_class(class_name: str, eval_expr: str, **kwargs: Mapping[str, str]):
"""
Create a FEniCS C++ expression from a template
:param class_name: Name of the expression
:param eval_expr: Expression to evaluate
:param kwargs: Required arguments
"""
# pylint: disable=too-many-locals
generic_func_preamble = "std::shared_ptr<dolfin::GenericFunction> "
eigen_map = "Eigen::Map<Eigen::Matrix<double, 1, 1>>"
generic_func_vars = list(kwargs)
generic_func_vars_declare = ["double " + v + ";" for v in generic_func_vars]
generic_funcs = ["generic_function_" + v for v in generic_func_vars]
generic_func_eval = [
f + "->eval({} (&{}), x, cell);".format(eigen_map, v) for f, v in zip(generic_funcs, generic_func_vars)
]
generic_func_expose = [
'.def_readwrite("{}", &{}::{})'.format(v, class_name, g) for v, g in zip(generic_func_vars, generic_funcs)
]
# Find the proper indent level... Not really a requirement, I'm just OCD. But man, this is UGLY.
template_list = EXPRESSIONS["template"].split("\n")
keys = ("{COMMANDS}", "{GENERIC_FUNCTIONS}", "{EXPOSE_GENERIC_FUNCTIONS}")
indents = {}
for template in template_list:
for key in keys:
if key in template:
indents[key] = "\n"
for char in template.split(key)[0]:
if char.isspace():
indents[key] += " "
eval_command = "values[0] = " + eval_expr + ";"
commands = (
indents[keys[0]].join(generic_func_vars_declare)
+ indents[keys[0]]
+ indents[keys[0]].join(generic_func_eval)
+ indents[keys[0]]
+ eval_command
)
generic_functions = indents[keys[1]].join("{} {};".format(generic_func_preamble, t) for t in generic_funcs)
expose_generic_functions = indents[keys[2]].join(generic_func_expose)
return EXPRESSIONS["template"].format(
CLASS_NAME=class_name,
COMMANDS=commands,
GENERIC_FUNCTIONS=generic_functions,
EXPOSE_GENERIC_FUNCTIONS=expose_generic_functions,
)