Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/activations.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,23 @@
use std::f64::consts::E;

pub trait Activation {
fn function(x: f64) -> f64;
fn derivative(x: f64) -> f64;
}

pub struct Sigmoid;

impl Activation for Sigmoid {
fn function(x: f64) -> f64 {
1.0 / (1.0 + E.powf(-x))
}

fn derivative(x: f64) -> f64 {
x * (1.0 - x)
}
}

/*
#[derive(Clone)]
pub struct Activation<'a> {
pub function: &'a dyn Fn(f64) -> f64,
Expand All @@ -25,3 +43,4 @@ pub const RELU: Activation = Activation {
function: &|x| x.max(0.0),
derivative: &|x| if x > 0.0 { 1.0 } else { 0.0 },
};
*/
4 changes: 2 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use activations::SIGMOID;
use activations::Sigmoid;
use network::Network;

pub mod activations;
Expand All @@ -14,7 +14,7 @@ fn main() {
];
let targets = vec![vec![0.0], vec![1.0], vec![1.0], vec![0.0]];

let mut network = Network::new(vec![2, 3, 1], 0.5, SIGMOID);
let mut network = Network::<Sigmoid>::new(vec![2, 3, 1], 0.5);

network.train(inputs, targets, 1000);

Expand Down
2 changes: 1 addition & 1 deletion src/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ impl Matrix {
res
}

pub fn map(&self, function: &dyn Fn(f64) -> f64) -> Matrix {
pub fn map(&self, function: impl Fn(f64) -> f64) -> Matrix {
Matrix::from(
(self.data)
.clone()
Expand Down
20 changes: 8 additions & 12 deletions src/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ use serde_json::{from_str, json};

use super::{activations::Activation, matrix::Matrix};

pub struct Network<'a> {
pub struct Network<T: Activation> {
layers: Vec<usize>,
weights: Vec<Matrix>,
biases: Vec<Matrix>,
data: Vec<Matrix>,
learning_rate: f64,
activation: Activation<'a>,
activation: std::marker::PhantomData<T>,
}

#[derive(Serialize, Deserialize)]
Expand All @@ -23,12 +23,8 @@ struct SaveData {
biases: Vec<Vec<Vec<f64>>>,
}

impl Network<'_> {
pub fn new<'a>(
layers: Vec<usize>,
learning_rate: f64,
activation: Activation<'a>,
) -> Network<'a> {
impl<T: Activation> Network<T> {
pub fn new<'a>(layers: Vec<usize>, learning_rate: f64) -> Self {
let mut weights = vec![];
let mut biases = vec![];

Expand All @@ -43,7 +39,7 @@ impl Network<'_> {
biases,
data: vec![],
learning_rate,
activation,
activation: std::marker::PhantomData,
}
}

Expand All @@ -59,7 +55,7 @@ impl Network<'_> {
current = self.weights[i]
.multiply(&current)
.add(&self.biases[i])
.map(self.activation.function);
.map(T::function);
self.data.push(current.clone());
}

Expand All @@ -73,7 +69,7 @@ impl Network<'_> {

let parsed = Matrix::from(vec![outputs]).transpose();
let mut errors = Matrix::from(vec![targets]).transpose().subtract(&parsed);
let mut gradients = parsed.map(self.activation.derivative);
let mut gradients = parsed.map(T::derivative);

for i in (0..self.layers.len() - 1).rev() {
gradients = gradients
Expand All @@ -84,7 +80,7 @@ impl Network<'_> {
self.biases[i] = self.biases[i].add(&gradients);

errors = self.weights[i].transpose().multiply(&errors);
gradients = self.data[i].map(self.activation.derivative);
gradients = self.data[i].map(T::derivative);
}
}

Expand Down