From 9634dd30e09738ffd79511e96df63843b6b0374e Mon Sep 17 00:00:00 2001 From: Isaac Adams Date: Sun, 26 Nov 2023 15:33:49 -0500 Subject: [PATCH] implemented using traits --- src/activations.rs | 19 +++++++++++++++++++ src/main.rs | 4 ++-- src/matrix.rs | 2 +- src/network.rs | 20 ++++++++------------ 4 files changed, 30 insertions(+), 15 deletions(-) diff --git a/src/activations.rs b/src/activations.rs index 48adbd3..a95e269 100644 --- a/src/activations.rs +++ b/src/activations.rs @@ -1,5 +1,23 @@ use std::f64::consts::E; +pub trait Activation { + fn function(x: f64) -> f64; + fn derivative(x: f64) -> f64; +} + +pub struct Sigmoid; + +impl Activation for Sigmoid { + fn function(x: f64) -> f64 { + 1.0 / (1.0 + E.powf(-x)) + } + + fn derivative(x: f64) -> f64 { + x * (1.0 - x) + } +} + +/* #[derive(Clone)] pub struct Activation<'a> { pub function: &'a dyn Fn(f64) -> f64, @@ -25,3 +43,4 @@ pub const RELU: Activation = Activation { function: &|x| x.max(0.0), derivative: &|x| if x > 0.0 { 1.0 } else { 0.0 }, }; + */ diff --git a/src/main.rs b/src/main.rs index ddebeff..7bf4969 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use activations::SIGMOID; +use activations::Sigmoid; use network::Network; pub mod activations; @@ -14,7 +14,7 @@ fn main() { ]; let targets = vec![vec![0.0], vec![1.0], vec![1.0], vec![0.0]]; - let mut network = Network::new(vec![2, 3, 1], 0.5, SIGMOID); + let mut network = Network::::new(vec![2, 3, 1], 0.5); network.train(inputs, targets, 1000); diff --git a/src/matrix.rs b/src/matrix.rs index 979944a..4e7f1ed 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -107,7 +107,7 @@ impl Matrix { res } - pub fn map(&self, function: &dyn Fn(f64) -> f64) -> Matrix { + pub fn map(&self, function: impl Fn(f64) -> f64) -> Matrix { Matrix::from( (self.data) .clone() diff --git a/src/network.rs b/src/network.rs index f9345be..a6f01ca 100644 --- a/src/network.rs +++ b/src/network.rs @@ -8,13 +8,13 @@ use serde_json::{from_str, json}; use super::{activations::Activation, matrix::Matrix}; -pub struct Network<'a> { +pub struct Network { layers: Vec, weights: Vec, biases: Vec, data: Vec, learning_rate: f64, - activation: Activation<'a>, + activation: std::marker::PhantomData, } #[derive(Serialize, Deserialize)] @@ -23,12 +23,8 @@ struct SaveData { biases: Vec>>, } -impl Network<'_> { - pub fn new<'a>( - layers: Vec, - learning_rate: f64, - activation: Activation<'a>, - ) -> Network<'a> { +impl Network { + pub fn new<'a>(layers: Vec, learning_rate: f64) -> Self { let mut weights = vec![]; let mut biases = vec![]; @@ -43,7 +39,7 @@ impl Network<'_> { biases, data: vec![], learning_rate, - activation, + activation: std::marker::PhantomData, } } @@ -59,7 +55,7 @@ impl Network<'_> { current = self.weights[i] .multiply(¤t) .add(&self.biases[i]) - .map(self.activation.function); + .map(T::function); self.data.push(current.clone()); } @@ -73,7 +69,7 @@ impl Network<'_> { let parsed = Matrix::from(vec![outputs]).transpose(); let mut errors = Matrix::from(vec![targets]).transpose().subtract(&parsed); - let mut gradients = parsed.map(self.activation.derivative); + let mut gradients = parsed.map(T::derivative); for i in (0..self.layers.len() - 1).rev() { gradients = gradients @@ -84,7 +80,7 @@ impl Network<'_> { self.biases[i] = self.biases[i].add(&gradients); errors = self.weights[i].transpose().multiply(&errors); - gradients = self.data[i].map(self.activation.derivative); + gradients = self.data[i].map(T::derivative); } }