Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: Joe McCain III <jo3mccain@icloud.com>
  • Loading branch information
FL03 committed May 25, 2024
1 parent 3bfd068 commit 033830c
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 30 deletions.
62 changes: 53 additions & 9 deletions core/src/func/activate/nl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Contrib: FL03 <jo3mccain@icloud.com>
*/
use crate::math::Exp;
use ndarray::*;
use nd::*;
use num::complex::{Complex, ComplexFloat};
use num::traits::Zero;

Expand Down Expand Up @@ -34,14 +34,16 @@ where
&e / e.sum()
}

// fn __softmax<T, I>(args: &I) -> I
// where
// I: Clone + core::ops::Div<T, Output = I> + Exp<Output = I>, T: Exp<Output = T> + core::iter::Sum ,
// for<'a> I: IntoIterator<Item = &'a T>,
// {
// let e = args.exp();
// e.clone() / e.into_iter().sum::<T>()
// }
fn _softmax_axis<A, S, D>(args: &ArrayBase<S, D>, axis: usize) -> Array<A, D>
where
A: ComplexFloat + ScalarOperand,
D: RemoveAxis,
S: Data<Elem = A>,
{
let axis = Axis(axis);
let e = args.exp();
&e / &e.sum_axis(axis)
}

fn _tanh<T>(args: T) -> T
where
Expand All @@ -57,6 +59,22 @@ unary!(
Tanh::tanh(self),
);

pub trait SoftmaxAxis {
type Output;

fn softmax_axis(self, axis: usize) -> Self::Output;
}

pub trait NonLinear {
type Output;

fn relu(self) -> Self::Output;
fn sigmoid(self) -> Self::Output;
fn softmax(self) -> Self::Output;
fn softmax_axis(self, axis: usize) -> Self::Output;
fn tanh(self) -> Self::Output;
}

/*
********** Implementations **********
*/
Expand Down Expand Up @@ -230,3 +248,29 @@ where
_softmax(self)
}
}

impl<A, S, D> SoftmaxAxis for ArrayBase<S, D>
where
A: ComplexFloat + ScalarOperand,
D: RemoveAxis,
S: Data<Elem = A>,
{
type Output = Array<A, D>;

fn softmax_axis(self, axis: usize) -> Self::Output {
_softmax_axis(&self, axis)
}
}

impl<'a, A, S, D> SoftmaxAxis for &'a ArrayBase<S, D>
where
A: ComplexFloat + ScalarOperand,
D: RemoveAxis,
S: Data<Elem = A>,
{
type Output = Array<A, D>;

fn softmax_axis(self, axis: usize) -> Self::Output {
_softmax_axis(&self, axis)

Check warning

Code scanning / clippy

this expression creates a reference which is immediately dereferenced by the compiler Warning

this expression creates a reference which is immediately dereferenced by the compiler
}
}
41 changes: 36 additions & 5 deletions core/src/macros/getters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ macro_rules! getters {
($($call:ident$(.$field:ident)?<$out:ty>),* $(,)?) => {
$($crate::getters!(@impl $call$(.$field)?<$out>);)*
};
($via:ident::<[$($call:ident$(.$field:ident)?<$out:ty>),* $(,)?]>) => {
$($crate::getters!(@impl $via::$call$(.$field)?<$out>);)*
};
($($call:ident$(.$field:ident)?),* $(,)? => $out:ty) => {
$($crate::getters!(@impl $call$(.$field)?<$out>);)*
};
($via:ident::<[$($call:ident$(.$field:ident)?<$out:ty>),* $(,)?]>) => {
$($crate::getters!(@impl $via::$call$(.$field)?<$out>);)*
};
($via:ident::<[$($call:ident$(.$field:ident)?),* $(,)?]> => $out:ty) => {
$crate::getters!($via::<[$($call$(.$field)?<$out>),*]>);
};
Expand All @@ -36,12 +36,43 @@ macro_rules! getters {
};
(@impl $via:ident::$call:ident.$field:ident<$out:ty>) => {
pub fn $call(&self) -> &$out {
&self.$via.$field
&self.$via.$field()
}
paste::paste! {
pub fn [< $call _mut>](&mut self) -> &mut $out {
&mut self.$via.$field
self.$via.[<$field _mut>]()
}
}
};
}

#[macro_export]
macro_rules! getter {
($($($field:ident).*::$call:ident<$out:ty>),* $(,)?) => {
$($crate::getter!(@impl $($field).*::$call<$out>);)*
};
($($($field:ident).*::$call:ident),* $(,)? => $out:ty) => {
$($crate::getter!(@impl $($field).*::$call<$out>);)*
};

(@impl $($field:ident).*::$call:ident<$out:ty>) => {
pub fn $call(&self) -> &$out {
&self.$($field).*
}
paste::paste! {
pub fn [< $call _mut>](&mut self) -> &mut $out {
&mut self.$($field).*
}
}
};
}

#[macro_export]
macro_rules! nested_getter {
($($field:ident).*::<[$($call:ident<$out:ty>),* $(,)?]>) => {
$($crate::getter!($($field).*::$call<$out>);)*
};
($($field:ident).*::<[$($call:ident),* $(,)?]> => $out:ty) => {
$crate::getter!($($($field).*::$call<$out>)*);
};
}
50 changes: 42 additions & 8 deletions models/transformers/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,45 @@
Appellation: config <module>
Contrib: FL03 <jo3mccain@icloud.com>
*/

use concision::getters;

pub struct TransformerConfig {
pub dropout: Option<f64>,
pub features: Features,
pub heads: usize,
pub layers: usize,
}

impl TransformerConfig {
pub fn new(dropout: Option<f64>, features: Features, heads: usize, layers: usize) -> Self {
Self {
dropout,
features,
heads,
layers,
}
}

getters!(dropout<Option<f64>>, features<Features>, heads<usize>, layers<usize>);
getters!(features::<[d_model<usize>, qkv<QkvShape>]>);
getters!(features::<[dk, dq, dv]> => usize);
}

pub struct Features {

pub d_model: usize,
pub qkv: QkvShape,
}

impl Features {
pub fn new(d_model: usize, qkv: QkvShape) -> Self {
Self {
d_model,
qkv,
}
}

getters!(d_model<usize>, qkv<QkvShape>);
getters!(qkv::<[dk, dq, dv]> => usize);
}

pub struct QkvShape {
Expand All @@ -34,13 +63,18 @@ impl QkvShape {

Self::new(dq, dk, dv)
}
}


pub struct EmbedConfig {

getters!(dk, dq, dv => usize);
}

pub struct FFNConfig {
impl From<usize> for QkvShape {
fn from(dk: usize) -> Self {
Self::std(dk)
}
}

}
impl From<(usize, usize, usize)> for QkvShape {
fn from((dq, dk, dv): (usize, usize, usize)) -> Self {
Self::new(dq, dk, dv)
}
}
10 changes: 2 additions & 8 deletions models/transformers/src/model/ffn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ where
where
A: Clone + Default,
{
let dropout = dropout.map(|p| Dropout::new(p));
let dropout = dropout.map(Dropout::new);
let input = Linear::from_features(d_model, features);
let output = Linear::from_features(features, d_model);
Self {
Expand All @@ -50,13 +50,7 @@ impl<A, D, K> FeedForwardNetwork<A, K, D>
where
D: Dimension,
{
pub fn input(&self) -> &Linear<A, K, D> {
&self.input
}

pub fn output(&self) -> &Linear<A, K, D> {
&self.output
}
concision::getters!(input, output => Linear<A, K, D>);
}

#[cfg(feature = "rand")]
Expand Down

0 comments on commit 033830c

Please sign in to comment.