Skip to content

Commit

Permalink
Reorganizes const generics to clean a bit of code dup
Browse files Browse the repository at this point in the history
  • Loading branch information
gkanwar committed Sep 12, 2021
1 parent 3fc5f4a commit e924e5d
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 100 deletions.
8 changes: 5 additions & 3 deletions src/bin/test_simple.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use lqft::{c64, Action, Boundaries, ColorMat, Lattice4ColorMat, PureGaugeAction};
use ndarray::{Array4};
use lqft::{c64, Action, Boundaries, ColorMat, ColorMatrix, Lattice4ColorMat, PureGaugeAction};
use ndarray::Array4;

fn main() {
let shape = [8, 8, 8, 8];
Expand Down Expand Up @@ -36,7 +36,9 @@ fn main() {
beta: 5.5,
bcs: Boundaries::Periodic
};
println!("Action = {:?}", action.action(&u));
// FIXME: Nightly rustc gives ICE on finding impl of action??
// println!("Action = {:?}", action.action(&u));
println!("TEST: {:?}", u.u[0][[0,0,0,0]].dot(&u.u[0][[0,0,0,0]].adjoint()));

// println!("Initial U = {:?}, detU = {:?}", arr2(&U.data), U.det());
// let start = Instant::now();
Expand Down
115 changes: 42 additions & 73 deletions src/lib/color_matrix.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
use crate::math::{c64, CMat, CVec};
use crate::rand_dist::{SampleHaar};
use crate::{IntoBytes};
use crate::rand_dist::SampleHaar;
use crate::IntoBytes;

use std::ops::{Add, Index, IndexMut};

// Color matrices
#[derive(Clone)]
#[derive(Clone,Debug)]
pub struct ColorMat<const N: usize> {
pub data: CMat<N>
}

impl<const N: usize> ColorMat<N> {
pub fn new(data: CMat<N>) -> ColorMat<N> {
ColorMat::<N> { data: data }
}

pub fn new_diag(diag_data: CVec<N>) -> ColorMat<N> {
let mut data: CMat<N> = [[0_f64.into(); N]; N];
for i in 0..N {
Expand All @@ -26,7 +22,7 @@ impl<const N: usize> ColorMat<N> {

impl<const N: usize> IntoBytes for ColorMat<N> {
fn into_bytes(&self) -> Vec<u8> {
let mut bs: Vec<u8> = vec!();
let mut bs: Vec<u8> = vec![];
for i in 0..N {
for j in 0..N {
bs.extend_from_slice(&self.data[i][j].re.to_le_bytes());
Expand Down Expand Up @@ -84,66 +80,59 @@ impl<const N: usize> num::Zero for ColorMat<N> {
}

pub trait ColorMatrix: SampleHaar + num::Zero + IndexMut<(usize, usize)> + Clone {
type Info: ColorInfo;
fn dot(&self, other: &Self) -> Self;
fn adjoint(&self) -> Self;
fn conj(&self) -> Self;
fn det(&self) -> c64;
fn tr(&self) -> c64;
fn reunit(&mut self) -> ();
}
const NC: usize;

pub trait ColorInfo {
fn nc() -> usize;
}
fn new(data: CMat<{ Self::NC }>) -> Self;
fn view_data(&self) -> &CMat<{ Self::NC }>;
fn view_data_mut(&mut self) -> &mut CMat<{ Self::NC }>;

pub struct Nc<const N: usize>;
impl<const N: usize> ColorInfo for Nc<N> {
fn nc() -> usize {
N
fn dot(&self, other: &Self) -> Self where [(); Self::NC]: {
Self::new(crate::math::matmul(self.view_data(), other.view_data()))
}
fn adjoint(&self) -> Self where [(); Self::NC]: {
Self::new(crate::math::adjoint(self.view_data()))
}
fn conj(&self) -> Self where [(); Self::NC]: {
Self::new(crate::math::conj(self.view_data()))
}
fn tr(&self) -> c64 where [(); Self::NC]: {
crate::math::tr(self.view_data())
}
fn reunit(&mut self) -> () where [(); Self::NC]: {
crate::math::reunit(self.view_data_mut());
}

fn det(&self) -> c64;
}

impl ColorMatrix for ColorMat<2> {
type Info = Nc<2>;

fn dot(&self, other: &Self) -> Self {
Self {
data: crate::math::matmul(&self.data, &other.data)
}
}

fn det(&self) -> c64 {
self.data[0][0] * self.data[1][1] - self.data[0][1] * self.data[1][0]
}
const NC: usize = 2;

fn tr(&self) -> c64 {
crate::math::tr(&self.data)
fn new(data: CMat<{ Self::NC }>) -> Self {
Self { data: data }
}

fn adjoint(&self) -> Self {
Self {
data: crate::math::adjoint(&self.data)
}
fn view_data(&self) -> &CMat<{ Self::NC }> {
&self.data
}

fn conj(&self) -> Self {
Self {
data: crate::math::conj(&self.data)
}
fn view_data_mut(&mut self) -> &mut CMat<{ Self::NC }> {
&mut self.data
}

fn reunit(&mut self) -> () {
crate::math::reunit(&mut self.data);
fn det(&self) -> c64 {
self.data[0][0] * self.data[1][1] - self.data[0][1] * self.data[1][0]
}
}
impl ColorMatrix for ColorMat<3> {
type Info = Nc<3>;
const NC: usize = 3;

fn dot(&self, other: &Self) -> Self {
Self {
data: crate::math::matmul(&self.data, &other.data)
}
fn new(data: CMat<{ Self::NC }>) -> Self {
Self { data: data }
}
fn view_data(&self) -> &CMat<{ Self::NC }> {
&self.data
}
fn view_data_mut(&mut self) -> &mut CMat<{ Self::NC }> {
&mut self.data
}

fn det(&self) -> c64 {
Expand All @@ -154,24 +143,4 @@ impl ColorMatrix for ColorMat<3> {
- self.data[0][1] * self.data[1][0] * self.data[2][2]
- self.data[0][2] * self.data[1][1] * self.data[2][0]
}

fn tr(&self) -> c64 {
crate::math::tr(&self.data)
}

fn adjoint(&self) -> Self {
Self {
data: crate::math::adjoint(&self.data)
}
}

fn conj(&self) -> Self {
Self {
data: crate::math::conj(&self.data)
}
}

fn reunit(&mut self) -> () {
crate::math::reunit(&mut self.data);
}
}
53 changes: 30 additions & 23 deletions src/lib/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
#![feature(const_evaluatable_checked)]
#![feature(const_generics)]
#![feature(generic_const_exprs)]
#![allow(incomplete_features)]

mod color_matrix;
mod math;
mod rand_dist;

pub use color_matrix::{ColorInfo, ColorMat, ColorMatrix};
pub use color_matrix::{ColorMat, ColorMatrix};
pub use math::{c64, CMat, CVec};
pub use rand_dist::{sample_su2_theta, SampleHaar};

Expand All @@ -29,10 +28,7 @@ pub struct Lattice4ColorMat<T> {
pub u: [Array4<T>; 4]
}

impl<T> LatticeObject<4> for Lattice4ColorMat<T>
where
T: ColorMatrix
{
impl<T> LatticeObject<4> for Lattice4ColorMat<T> {
fn latt_shape(&self) -> [usize; 4] {
self.u[0]
.shape()
Expand Down Expand Up @@ -66,7 +62,8 @@ impl Direction {
// TODO: The cloning is a bit wasteful
fn apply_gauge_bcs<T>(u: &T, bcs: Boundaries, dir: Direction, crossed: bool) -> T
where
T: ColorMatrix
T: ColorMatrix,
[(); T::NC]:
{
match bcs {
Boundaries::Periodic => u.clone(),
Expand All @@ -88,7 +85,7 @@ pub trait Action {
fn action<T>(&self, u: &Lattice4ColorMat<T>) -> f64
where
T: ColorMatrix,
<T as ColorMatrix>::Info: ColorInfo;
[(); T::NC]: ;
}

pub struct PureGaugeAction {
Expand All @@ -100,7 +97,7 @@ impl Action for PureGaugeAction {
fn action<T>(&self, u: &Lattice4ColorMat<T>) -> f64
where
T: ColorMatrix,
<T as ColorMatrix>::Info: ColorInfo
[(); T::NC]:
{
let shape = u.latt_shape();
let mut tot_action: f64 = 0.0;
Expand Down Expand Up @@ -132,17 +129,25 @@ impl Action for PureGaugeAction {
}
}
}
return -(self.beta / (<T as ColorMatrix>::Info::nc() as f64)) * tot_action;
return -(self.beta / T::NC as f64) * tot_action;
}
}

pub fn pseudo_heatbath<T>(u: &mut T, a: &T, beta: f64) -> ()
where
T: ColorMatrix,
<T as ColorMatrix>::Info: ColorInfo,
<T as Index<(usize, usize)>>::Output: From<c64>
<T as Index<(usize, usize)>>::Output: From<c64> + Into<c64> + Copy
{
let a_beta = beta * a.det().norm().sqrt();
// FORNOW: 2x2
let mut a22: CMat<2> = [[0_f64.into(); 2]; 2];
for i in 0..2 {
for j in 0..2 {
a22[i][j] = a[(i,j)].into();
}
}
let a22 = ColorMat::<2>::new(a22);
let sqrt_det_a = a22.det().norm().sqrt();
let a_beta = beta * sqrt_det_a;
let theta = sample_su2_theta(a_beta);
let w = ColorMat::<2>::sample_haar();
let lam1 = c64 {
Expand All @@ -151,12 +156,11 @@ where
};
let eigs = [lam1, lam1.conj()];
let d = ColorMat::<2>::new_diag(eigs);
let up = w.dot(&d).dot(&w.adjoint());
let n = T::Info::nc();
assert_eq!(n, 2); // FORNOW
for i in 0..n {
for j in 0..n {
u[(i, j)] = up[(i, j)].into();
let up_unnorm = w.dot(&d).dot(&w.adjoint()).dot(&a22);
assert_eq!(T::NC, 2); // FORNOW
for i in 0..T::NC {
for j in 0..T::NC {
u[(i, j)] = (up_unnorm[(i, j)] / sqrt_det_a).into();
}
}
}
Expand All @@ -166,7 +170,8 @@ fn compute_staple<T>(
(mu, x, y, z, t): (usize, usize, usize, usize, usize)
) -> T
where
T: ColorMatrix
T: ColorMatrix,
[(); T::NC]:
{
let shape = u.latt_shape();
let coord = [x, y, z, t];
Expand Down Expand Up @@ -211,7 +216,8 @@ where
pub fn heatbath_sweep<T>(u: &mut Lattice4ColorMat<T>, s: &PureGaugeAction) -> ()
where
T: ColorMatrix,
<T as Index<(usize, usize)>>::Output: From<c64>
<T as Index<(usize, usize)>>::Output: From<c64> + Into<c64> + Copy,
[(); T::NC]:
{
let shape = u.latt_shape();
for mu in 0..4 {
Expand All @@ -235,7 +241,8 @@ pub fn update_sweep<T>(
) -> ()
where
T: ColorMatrix,
<T as Index<(usize, usize)>>::Output: From<c64>
<T as Index<(usize, usize)>>::Output: From<c64> + Into<c64> + Copy,
[(); T::NC]:
{
for _ in 0..n_heatbath {
heatbath_sweep(u, s);
Expand Down
2 changes: 1 addition & 1 deletion src/lib/rand_dist.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::color_matrix::ColorMat;
use crate::color_matrix::{ColorMat, ColorMatrix};
use crate::math::{c64, CMat};
use ndarray::arr2;
use rand::{thread_rng, Rng};
Expand Down

0 comments on commit e924e5d

Please sign in to comment.