From 033830c3d745be0b076d53bada0cc18a565af27a Mon Sep 17 00:00:00 2001 From: Joe McCain III Date: Sat, 25 May 2024 10:46:18 -0500 Subject: [PATCH] update Signed-off-by: Joe McCain III --- core/src/func/activate/nl.rs | 62 +++++++++++++++++++++++---- core/src/macros/getters.rs | 41 +++++++++++++++--- models/transformers/src/config/mod.rs | 50 +++++++++++++++++---- models/transformers/src/model/ffn.rs | 10 +---- 4 files changed, 133 insertions(+), 30 deletions(-) diff --git a/core/src/func/activate/nl.rs b/core/src/func/activate/nl.rs index e488575..6901db0 100644 --- a/core/src/func/activate/nl.rs +++ b/core/src/func/activate/nl.rs @@ -3,7 +3,7 @@ Contrib: FL03 */ use crate::math::Exp; -use ndarray::*; +use nd::*; use num::complex::{Complex, ComplexFloat}; use num::traits::Zero; @@ -34,14 +34,16 @@ where &e / e.sum() } -// fn __softmax(args: &I) -> I -// where -// I: Clone + core::ops::Div + Exp, T: Exp + core::iter::Sum , -// for<'a> I: IntoIterator, -// { -// let e = args.exp(); -// e.clone() / e.into_iter().sum::() -// } +fn _softmax_axis(args: &ArrayBase, axis: usize) -> Array +where + A: ComplexFloat + ScalarOperand, + D: RemoveAxis, + S: Data, +{ + let axis = Axis(axis); + let e = args.exp(); + &e / &e.sum_axis(axis) +} fn _tanh(args: T) -> T where @@ -57,6 +59,22 @@ unary!( Tanh::tanh(self), ); +pub trait SoftmaxAxis { + type Output; + + fn softmax_axis(self, axis: usize) -> Self::Output; +} + +pub trait NonLinear { + type Output; + + fn relu(self) -> Self::Output; + fn sigmoid(self) -> Self::Output; + fn softmax(self) -> Self::Output; + fn softmax_axis(self, axis: usize) -> Self::Output; + fn tanh(self) -> Self::Output; +} + /* ********** Implementations ********** */ @@ -230,3 +248,29 @@ where _softmax(self) } } + +impl SoftmaxAxis for ArrayBase +where + A: ComplexFloat + ScalarOperand, + D: RemoveAxis, + S: Data, +{ + type Output = Array; + + fn softmax_axis(self, axis: usize) -> Self::Output { + _softmax_axis(&self, axis) + } +} + +impl<'a, A, S, D> SoftmaxAxis for &'a ArrayBase +where + A: ComplexFloat + ScalarOperand, + D: RemoveAxis, + S: Data, +{ + type Output = Array; + + fn softmax_axis(self, axis: usize) -> Self::Output { + _softmax_axis(&self, axis) + } +} diff --git a/core/src/macros/getters.rs b/core/src/macros/getters.rs index 86d232b..f8bbd3a 100644 --- a/core/src/macros/getters.rs +++ b/core/src/macros/getters.rs @@ -8,12 +8,12 @@ macro_rules! getters { ($($call:ident$(.$field:ident)?<$out:ty>),* $(,)?) => { $($crate::getters!(@impl $call$(.$field)?<$out>);)* }; - ($via:ident::<[$($call:ident$(.$field:ident)?<$out:ty>),* $(,)?]>) => { - $($crate::getters!(@impl $via::$call$(.$field)?<$out>);)* - }; ($($call:ident$(.$field:ident)?),* $(,)? => $out:ty) => { $($crate::getters!(@impl $call$(.$field)?<$out>);)* }; + ($via:ident::<[$($call:ident$(.$field:ident)?<$out:ty>),* $(,)?]>) => { + $($crate::getters!(@impl $via::$call$(.$field)?<$out>);)* + }; ($via:ident::<[$($call:ident$(.$field:ident)?),* $(,)?]> => $out:ty) => { $crate::getters!($via::<[$($call$(.$field)?<$out>),*]>); }; @@ -36,12 +36,43 @@ macro_rules! getters { }; (@impl $via:ident::$call:ident.$field:ident<$out:ty>) => { pub fn $call(&self) -> &$out { - &self.$via.$field + &self.$via.$field() } paste::paste! { pub fn [< $call _mut>](&mut self) -> &mut $out { - &mut self.$via.$field + self.$via.[<$field _mut>]() } } }; } + +#[macro_export] +macro_rules! getter { + ($($($field:ident).*::$call:ident<$out:ty>),* $(,)?) => { + $($crate::getter!(@impl $($field).*::$call<$out>);)* + }; + ($($($field:ident).*::$call:ident),* $(,)? => $out:ty) => { + $($crate::getter!(@impl $($field).*::$call<$out>);)* + }; + + (@impl $($field:ident).*::$call:ident<$out:ty>) => { + pub fn $call(&self) -> &$out { + &self.$($field).* + } + paste::paste! { + pub fn [< $call _mut>](&mut self) -> &mut $out { + &mut self.$($field).* + } + } + }; +} + +#[macro_export] +macro_rules! nested_getter { + ($($field:ident).*::<[$($call:ident<$out:ty>),* $(,)?]>) => { + $($crate::getter!($($field).*::$call<$out>);)* + }; + ($($field:ident).*::<[$($call:ident),* $(,)?]> => $out:ty) => { + $crate::getter!($($($field).*::$call<$out>)*); + }; +} \ No newline at end of file diff --git a/models/transformers/src/config/mod.rs b/models/transformers/src/config/mod.rs index b052876..8f57601 100644 --- a/models/transformers/src/config/mod.rs +++ b/models/transformers/src/config/mod.rs @@ -2,16 +2,45 @@ Appellation: config Contrib: FL03 */ - +use concision::getters; pub struct TransformerConfig { + pub dropout: Option, + pub features: Features, pub heads: usize, + pub layers: usize, +} + +impl TransformerConfig { + pub fn new(dropout: Option, features: Features, heads: usize, layers: usize) -> Self { + Self { + dropout, + features, + heads, + layers, + } + } + + getters!(dropout>, features, heads, layers); + getters!(features::<[d_model, qkv]>); + getters!(features::<[dk, dq, dv]> => usize); } pub struct Features { - pub d_model: usize, + pub qkv: QkvShape, +} +impl Features { + pub fn new(d_model: usize, qkv: QkvShape) -> Self { + Self { + d_model, + qkv, + } + } + + getters!(d_model, qkv); + getters!(qkv::<[dk, dq, dv]> => usize); } pub struct QkvShape { @@ -34,13 +63,18 @@ impl QkvShape { Self::new(dq, dk, dv) } -} - - -pub struct EmbedConfig { + getters!(dk, dq, dv => usize); } -pub struct FFNConfig { +impl From for QkvShape { + fn from(dk: usize) -> Self { + Self::std(dk) + } +} -} \ No newline at end of file +impl From<(usize, usize, usize)> for QkvShape { + fn from((dq, dk, dv): (usize, usize, usize)) -> Self { + Self::new(dq, dk, dv) + } +} diff --git a/models/transformers/src/model/ffn.rs b/models/transformers/src/model/ffn.rs index 00a29ab..ba77ea3 100644 --- a/models/transformers/src/model/ffn.rs +++ b/models/transformers/src/model/ffn.rs @@ -35,7 +35,7 @@ where where A: Clone + Default, { - let dropout = dropout.map(|p| Dropout::new(p)); + let dropout = dropout.map(Dropout::new); let input = Linear::from_features(d_model, features); let output = Linear::from_features(features, d_model); Self { @@ -50,13 +50,7 @@ impl FeedForwardNetwork where D: Dimension, { - pub fn input(&self) -> &Linear { - &self.input - } - - pub fn output(&self) -> &Linear { - &self.output - } + concision::getters!(input, output => Linear); } #[cfg(feature = "rand")]