A proof-of-concept implementation of XMM.
XMM, short for eXtended Matrix Multiplication, allows you to define your own matmul-like operator, which takes the following form:

$$C_{ij} = \sum_{k} f(A_{ik}, B_{kj})$$

where $f$ is a user-defined binary combinator ($f(a, b) = a \cdot b$ recovers ordinary matrix multiplication).
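For concreteness, here is a minimal, unoptimized torch sketch of that semantics; the helper name `xmm_reference` is made up for illustration and is not part of the package:

```python
import torch

def xmm_reference(A, B, f):
    # Hypothetical reference helper: C[i, j] = sum_k f(A[i, k], B[k, j]).
    # Broadcast A to (M, K, 1) and B to (1, K, N), apply the combinator f
    # elementwise, then reduce over the shared K dimension.
    return f(A.unsqueeze(2), B.unsqueeze(0)).sum(dim=1)

A, B = torch.randn(4, 3), torch.randn(3, 5)
# f = torch.mul recovers ordinary matrix multiplication:
assert torch.allclose(xmm_reference(A, B, torch.mul), A @ B, atol=1e-5)
```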
Please refer to `xmm.pdf` for more details on how it is generalized to both MLPs and KANs, and more.
Notes:

- I believe the parser needs to be made more robust, so please test it and report failures in an issue. Writing extensive tests or refactoring the parser is appreciated ❤️.
- The current CUDA implementation is completely unoptimized. The optimization pattern is nearly identical to standard GEMM optimizations (shared memory -> tiling -> register caching -> prefetch/double buffering, ...), but it may require careful inspection and small adjustments to fit our needs; see the blocked-accumulation sketch after this list.
- Contributions optimizing the CUDA kernels are appreciated ❤️.
- Please refer to `xmm/templates/cpp.py` and `xmm/templates/cuda.py` for the current naive implementation.
- A faster implementation (adapted from `sgemm_nt` in MAGMA) is provided, currently only for `nrow = 1 && ncol = 1`.
- Further allow customization of $\sigma$ (in place of $\sum$) via a binary function (resembling the binary-operation parameter of `std::reduce` in C++); a semantic sketch follows this list.
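As a semantic sketch of that last item, here is one way the generalized reduction could behave in plain torch; the helper and its signature are hypothetical, not the package API:

```python
from functools import reduce

import torch

def xmm_with_reduce(A, B, f, sigma):
    # Hypothetical helper: replace the sum over k with an arbitrary binary
    # reduction `sigma`, analogous to the binary-operation parameter of
    # std::reduce in C++.
    terms = f(A.unsqueeze(2), B.unsqueeze(0))    # (M, K, N): f(A[i,k], B[k,j])
    slabs = [terms[:, k, :] for k in range(terms.shape[1])]
    return reduce(sigma, slabs)                  # fold sigma over the k axis

A, B = torch.randn(4, 3), torch.randn(3, 5)
# sigma = torch.add recovers the default summation:
assert torch.allclose(xmm_with_reduce(A, B, torch.mul, torch.add), A @ B, atol=1e-5)
# sigma = torch.maximum with f = torch.add gives a max-plus ("tropical") matmul.
```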
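And for the CUDA-optimization item above, a shape-level illustration of the K-blocking access pattern that the shared-memory step builds on; `xmm_tiled` and the `tile` size are made up for illustration, and the real kernels live in `xmm/templates/cuda.py`:

```python
import torch

def xmm_tiled(A, B, f, tile=32):
    # Hypothetical helper showing K-blocking only: a CUDA kernel would stage
    # each pair of tiles in shared memory before combining and reducing.
    M, K = A.shape
    _, N = B.shape
    C = torch.zeros(M, N, dtype=A.dtype)
    for k0 in range(0, K, tile):
        a = A[:, k0:k0 + tile]   # A tile (staged in shared memory on GPU)
        b = B[k0:k0 + tile, :]   # matching B tile
        C += f(a.unsqueeze(2), b.unsqueeze(0)).sum(dim=1)  # partial reduce over this K block
    return C
```

Note that this partial-accumulation trick relies on the reduction being an associative sum; a custom $\sigma$ would constrain which blockings remain valid.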
Dependencies:

- `astor`
- `torch`
- `sympy`
- `sortedcontainers`
Quick start:

- Clone this repo.
- Check `nvcc` compatibility with `torch`.
- Run `python test.py`.
- [optional] Adjust the `expression` field in `xmmtest.py` to modify the combinator; a hypothetical example follows this list.
  - Note: currently supported functions are listed at the end.
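For illustration, the `expression` field might hold a sympy-parsable string like the one below; the operand names `x` and `y` are assumptions here, so check `xmmtest.py` for the exact convention:

```python
# Hypothetical value for the `expression` field in xmmtest.py.
# `x`/`y` as operand names are an assumption, not documented API; only
# functions from the supported list (e.g. exp, sin) should appear.
expression = "exp(-(x - y)**2)"
```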
See `xmmtest.py` and `polynomials.py` for example usage. The example code wraps the operator into a `torch.autograd.Function` and builds a layer on top of it. Core methods are `SumOperator.__init__`, `op.compile()`, `op.forward()`, and `op.backward()`.
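A minimal sketch of that wrapping pattern follows; the argument lists of `op.forward` and `op.backward` below are assumptions, so treat `xmmtest.py` as the authoritative version:

```python
from torch.autograd import Function

class XMMFunction(Function):
    # Sketch only: SumOperator's real method signatures live in xmmtest.py.
    @staticmethod
    def forward(ctx, A, B, op):
        ctx.save_for_backward(A, B)
        ctx.op = op                  # a compiled SumOperator instance (assumed)
        return op.forward(A, B)      # assumed signature

    @staticmethod
    def backward(ctx, grad_out):
        A, B = ctx.saved_tensors
        grad_A, grad_B = ctx.op.backward(A, B, grad_out)  # assumed signature
        return grad_A, grad_B, None  # no gradient w.r.t. the operator object
```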
Supported functions:

```python
['exp', 'log', 'sin', 'cos', 'tan', 'asin', 'acos', 'atan2', 'sinh', 'cosh', 'tanh', 'asinh', 'acosh', 'atanh']
```