"""
Tools for plotting and reporting solutions, optionally compared against reference (COMSOL) data.
"""
import json
from typing import Optional
import numpy as np
from matplotlib import pyplot as plt # type: ignore
from mtnlion.tools.helpers import norm_rmse, overlay_plt, save_fig
class Report:
    """
    Simplify the reporting of the gathered solutions.

    On construction the solutions (and the optional reference data) are evaluated at
    ``sample_times``. When ``split`` is false, the subdomain arrays of every variable are
    joined side by side in ``order`` (anode, separator, cathode). Subdomains without data
    are filled with NaN columns.
    """

    def __init__(self, solution, sample_times, split=False, comsol_data=None):
        """
        Create a reporting object.

        :param solution: solution data; must provide ``mesh``, ``time``, ``solutions`` and
            ``interp_time(time, solutions)``
        :param sample_times: times at which to resample the data
        :param split: if true split the report by subdomain
        :param comsol_data: data to compare against, structured as
            ``{name: {domain: callable(times)}}``; the mapping is shallow-copied
        """
        self.sample_times = sample_times
        self.solution = solution
        self.solutions = self.solution.interp_time(solution.time, self.solution.solutions)
        self.split = split
        self.comsol_data = {} if comsol_data is None else comsol_data.copy()
        self.order = ["anode", "separator", "cathode"]
        self._handle_join()

    def _handle_join(self):
        """
        Evaluate the stored interpolants at ``self.sample_times``.

        Split mode keeps the ``{name: {domain: ...}}`` structure and replaces every
        interpolant with its sampled values. Each subdomain that has no reference data
        gets an all-NaN array of the same shape, so plotting and RMSE reporting can
        still index it.

        Otherwise each variable's subdomains are joined into a single array, and
        ``self.mesh`` is set to the matching concatenated mesh.
        """
        if self.split:
            self.solutions = {
                name: {
                    domain: interpolant(self.sample_times)
                    for domain, interpolant in domains.items()
                    if interpolant is not None
                }
                for name, domains in self.solutions.items()
            }
            sampled_ref = {
                name: {
                    domain: interpolant(self.sample_times)
                    for domain, interpolant in domains.items()
                    if interpolant is not None
                }
                for name, domains in self.comsol_data.items()
            }
            for name, domains in self.solutions.items():
                reference = sampled_ref.setdefault(name, {})
                for domain, values in domains.items():
                    if domain not in reference:
                        # No reference supplied: NaN placeholder instead of a later KeyError.
                        reference[domain] = np.full(np.shape(values), np.nan)
            self.comsol_data = sampled_ref
            return

        # Offset each subdomain's copy of the mesh by its position so they sit side by side.
        self.mesh = np.concatenate([self.solution.mesh + offset for offset in range(len(self.order))])
        # Start from the (shallow-copied) reference mapping so entries for variables that
        # have no solution are kept, as before.
        joined_ref = dict(self.comsol_data)
        joined_sim = {}
        for name, domains in self.solutions.items():
            joined_sim[name], joined_ref[name] = self._join_domains(name, domains)
        self.solutions = joined_sim
        self.comsol_data = joined_ref

    def _join_domains(self, name, domains):
        """
        Join one variable's subdomains, in ``self.order``, along axis 1.

        :param name: variable name, used to look up matching reference data
        :param domains: mapping of subdomain name to interpolant (missing or ``None`` means
            no data for that subdomain)
        :return: ``(solution, reference)`` arrays; any subdomain missing from either source
            is filled with a NaN block of ``len(self.solution.mesh)`` columns
        """
        width = len(self.solution.mesh)
        reference = self.comsol_data.get(name) or {}
        sim_parts = []
        ref_parts = []
        for domain in self.order:
            interpolant = domains.get(domain)
            if interpolant is None:
                sim_parts.append(None)
                ref_parts.append(None)
                continue
            sim_parts.append(interpolant(self.sample_times))
            ref_interpolant = reference.get(domain)
            ref_parts.append(None if ref_interpolant is None else ref_interpolant(self.sample_times))

        # The row count comes from the first available subdomain, so a missing leading
        # subdomain no longer crashes. The fallback (every subdomain missing) assumes one
        # row per sample time - TODO confirm.
        rows = next((np.shape(part)[0] for part in sim_parts if part is not None), np.size(self.sample_times))
        filler = np.full((rows, width), np.nan)

        def _stack(parts):
            # Substitute the NaN block for absent subdomains, then join column-wise.
            return np.concatenate([filler if part is None else part for part in parts], axis=1)

        return _stack(sim_parts), _stack(ref_parts)

    @staticmethod
    def _format_name(name, domain):
        """
        Build a plot title for a variable.

        :param name: variable name; known names are rendered as matplotlib mathtext
        :param domain: subdomain name, or an empty string to omit it
        :return: ``"<label>: <domain>"`` or just ``"<label>"``
        """
        lookup = {
            "phis": r"$\Phi_s$",
            "phie": r"$\Phi_e$",
            "cs": "$c_s$",
            "ce": "$c_e$",
            "j": "$j$",
            "cse": r"$c_{s,e}$",
            "js": "$j_s$",
        }
        label = lookup.get(name, name)
        return "{}: {}".format(label, domain) if domain else "{}".format(label)

    def plot(self, local_path: Optional[str] = None, save: Optional[str] = None):
        """
        Plot the stored solutions against the reference data.

        Produces one figure per variable, or one per variable and subdomain when split.

        :param local_path: Path of the calling module
        :param save: directory prefix for the saved images (``<save>/<name>.png``, or
            ``<save>/<name>_<domain>.png`` when split); nothing is saved unless both
            ``save`` and ``local_path`` are given
        """
        if self.split:
            for name, domains in self.solutions.items():
                for domain, data in domains.items():
                    fig = overlay_plt(
                        self.solution.mesh,
                        self.sample_times,
                        self._format_name(name, domain),
                        data,
                        self.comsol_data[name][domain],
                    )
                    plt.show()
                    if save is not None and local_path is not None:
                        save_fig(fig, local_path, "{}/{}_{}.png".format(save, name, domain))
        else:
            for name, data in self.solutions.items():
                fig = overlay_plt(
                    self.mesh, self.sample_times, self._format_name(name, ""), data, self.comsol_data[name]
                )
                plt.show()
                if save is not None and local_path is not None:
                    save_fig(fig, local_path, "{}/{}.png".format(save, name))

    def report_rmse(self):
        """
        Report the normalized RMSE of each solution against the reference data.

        Entries whose reference data is entirely NaN (e.g. no reference was supplied) are
        omitted. In split mode, a variable with no remaining subdomains is omitted as well.

        :return: JSON string mapping name (and, when split, subdomain) to the RMSE values
        """
        if self.split:
            report = {}
            for name, domains in self.solutions.items():
                per_domain = {
                    domain: norm_rmse(estimated, self.comsol_data[name][domain]).tolist()
                    for domain, estimated in domains.items()
                    if not np.all(np.isnan(self.comsol_data[name][domain]))
                }
                if per_domain:
                    report[name] = per_domain
            return json.dumps(report, indent=4)
        return json.dumps(
            {
                name: norm_rmse(estimated, self.comsol_data[name]).tolist()
                for name, estimated in self.solutions.items()
                if not np.all(np.isnan(self.comsol_data[name]))
            },
            indent=4,
        )