Skip to content

Commit

Permalink
Merge branch 'features/spline' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
Axect committed May 1, 2024
2 parents 6804677 + 1e91054 commit 26906d5
Showing 1 changed file with 145 additions and 1 deletion.
146 changes: 145 additions & 1 deletion src/numerical/spline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,15 @@
//! }
//! ```
//!
//! # B-Spline (incomplete)
//!
//! - `UnitCubicBasis`: Single cubic B-Spline basis function
//! - `CubicBSplineBases`: Uniform Cubic B-Spline basis functions
//!
//! # References
//!
//! * Gary D. Knott, *Interpolating Splines*, Birkhäuser Boston, MA, (2000).
//! - Gary D. Knott, *Interpolating Splines*, Birkhäuser Boston, MA, (2000).
/// - [Wikipedia - Irwin-Hall distribution](https://en.wikipedia.org/wiki/Irwin%E2%80%93Hall_distribution#Special_cases)
use self::SplineError::{NotEnoughNodes, NotEqualNodes, NotEqualSlopes, RedundantNodeX};
#[allow(unused_imports)]
Expand Down Expand Up @@ -843,3 +849,141 @@ fn quadratic_slopes(x: &[f64], y: &[f64]) -> Result<Vec<f64>> {

Ok(m)
}

// =============================================================================
// B-Spline
// =============================================================================
/// Unit Cubic Basis Function
///
/// # Description
/// Unit cubic basis function from Irwin-Hall distribution (n=4).
/// For general interval, we substitute t = 4 * (x - a) / (b - a).
///
/// # Reference
/// [Wikipedia](https://en.wikipedia.org/wiki/Irwin%E2%80%93Hall_distribution#Special_cases)
#[derive(Debug, Copy, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct UnitCubicBasis {
pub x_min: f64,
pub x_max: f64,
pub scale: f64,
}

impl UnitCubicBasis {
pub fn new(x_min: f64, x_max: f64, scale: f64) -> Self {
Self { x_min, x_max, scale }
}

pub fn eval(&self, x: f64) -> f64 {
let t = 4f64 * (x - self.x_min) / (self.x_max - self.x_min);

let result = if (0f64..1f64).contains(&t) {
t.powi(3) / 6f64
} else if (1f64..2f64).contains(&t) {
(-3f64 * t.powi(3) + 12f64 * t.powi(2) - 12f64 * t + 4f64) / 6f64
} else if (2f64..3f64).contains(&t) {
(3f64 * t.powi(3) - 24f64 * t.powi(2) + 60f64 * t - 44f64) / 6f64
} else if (3f64..4f64).contains(&t) {
(4f64 - t).powi(3) / 6f64
} else {
0f64
};

self.scale * result
}

pub fn eval_vec(&self, x: &[f64]) -> Vec<f64> {
x.iter().map(|x| self.eval(*x)).collect()
}
}

/// Uniform Cubic B-Spline basis functions
///
/// # Example
///
/// ```rust
/// use peroxide::fuga::*;
/// use core::ops::Range;
///
/// # #[allow(unused_variables)]
/// fn main() -> anyhow::Result<()> {
/// let cubic_b_spline = CubicBSplineBases::from_interval((0f64, 1f64), 5);
/// let x = linspace(0f64, 1f64, 1000);
/// let y = cubic_b_spline.eval_vec(&x);
///
/// # #[cfg(feature = "plot")] {
/// let mut plt = Plot2D::new();
/// plt.set_domain(x.clone());
///
/// for basis in &cubic_b_spline.bases {
/// plt.insert_image(basis.eval_vec(&x));
/// }
///
/// plt
/// .insert_image(y)
/// .set_xlabel(r"$x$")
/// .set_ylabel(r"$y$")
/// .set_style(PlotStyle::Nature)
/// .tight_layout()
/// .set_dpi(600)
/// .set_path("example_data/cubic_b_spline.png")
/// .savefig()?;
/// # }
/// Ok(())
/// }
/// ```
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct CubicBSplineBases {
pub ranges: Vec<Range<f64>>,
pub bases: Vec<UnitCubicBasis>,
}

impl CubicBSplineBases {
/// Create new Cubic B-Spline basis functions
pub fn new(ranges: Vec<Range<f64>>, bases: Vec<UnitCubicBasis>) -> Self {
Self { ranges, bases }
}

/// Create new Cubic B-Spline basis functions for `[a, b]`
pub fn from_interval((a, b): (f64, f64), num_bases: usize) -> Self {
let nodes = linspace(a, b, num_bases + 4);
let (ranges, bases) = nodes
.iter()
.zip(nodes.iter().skip(4))
.map(|(a, b)| (Range { start: *a, end: *b }, UnitCubicBasis::new(*a, *b, 1f64)))
.unzip();

Self::new(ranges, bases)
}

/// Rescale all basis functions
///
/// # Arguments
/// - `scale_vec` - scale vector
pub fn rescale(&mut self, scale_vec: &[f64]) -> Result<()> {
if scale_vec.len() != self.bases.len() {
bail!("The number of scales should be equal to the number of basis functions");
}

for (basis, scale) in self.bases.iter_mut().zip(scale_vec) {
basis.scale = *scale;
}

Ok(())
}

pub fn eval(&self, x: f64) -> f64 {
self.ranges.iter()
.enumerate()
.filter(|(_, range)| range.contains(&x))
.fold(0f64, |acc, (i, _)| {
let basis = &self.bases[i];
acc + basis.eval(x)
})
}

pub fn eval_vec(&self, x: &[f64]) -> Vec<f64> {
x.iter().map(|x| self.eval(*x)).collect()
}
}

0 comments on commit 26906d5

Please sign in to comment.