Skip to content

Commit

Permalink
lasso-regularization (#3)
Browse files Browse the repository at this point in the history
* c++ stochastic multiplex implementation

* wider network

* Add LASSO regularization
  • Loading branch information
connormcmonigle authored Sep 28, 2023
1 parent 24984cf commit d9ec877
Show file tree
Hide file tree
Showing 16 changed files with 202 additions and 107 deletions.
22 changes: 12 additions & 10 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,27 @@ find_package (Threads REQUIRED)
set (CMAKE_CXX_STANDARD 17)

set (OPS_LIMIT 1000000000)

set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Ofast -flto -g -DNDEBUG -march=native -mtune=native -fopenmp -Wall -Wextra")
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fconstexpr-ops-limit=${OPS_LIMIT}")

if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fconstexpr-ops-limit=${OPS_LIMIT}")
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fconstexpr-steps=${OPS_LIMIT}")
endif ()

message (STATUS "EVALFILE=${EVALFILE}")
add_compile_definitions (EVALFILE="${EVALFILE}")

include_directories (include seer-nnue/include/ seer-nnue/incbin/ seer-nnue/syzygy/)

add_executable (test_sample_reader test/test_sample_reader.cc seer-nnue/syzygy/tbchess.c seer-nnue/syzygy/tbprobe.c)
target_link_libraries (test_sample_reader Threads::Threads)
file(GLOB CHESS_SRC_FILES seer-nnue/src/chess/*.cc)
file(GLOB SEARCH_SRC_FILES seer-nnue/src/search/*.cc)

add_executable (test_data_gen test/test_data_gen.cc seer-nnue/syzygy/tbchess.c seer-nnue/syzygy/tbprobe.c)
target_link_libraries (test_data_gen Threads::Threads)

add_subdirectory (pybind11)
pybind11_add_module (seer_train src/seer_train.cc seer-nnue/syzygy/tbchess.c seer-nnue/syzygy/tbprobe.c)
pybind11_add_module (
seer_train
src/seer_train.cc

${CHESS_SRC_FILES}
${SEARCH_SRC_FILES}
seer-nnue/syzygy/tbprobe.cc
)

Empty file removed build/.gitkeep
Empty file.
24 changes: 14 additions & 10 deletions include/data_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@
#include <memory>
#include <random>

#include <embedded_weights.h>
#include <nnue_util.h>
#include <chess/board_history.h>
#include <nnue/eval.h>
#include <search/search_constants.h>
#include <search/search_worker.h>
#include <nnue/embedded_weights.h>
#include <nnue/weights_streamer.h>

#include <sample.h>
#include <sample_reader.h>
Expand Down Expand Up @@ -66,13 +70,13 @@ struct data_generator{
data_generator& generate_data() {
const nnue::weights weights = [&, this]{
nnue::weights result{};
nnue::embedded_weight_streamer embedded(embed::weights_file_data);
nnue::embedded_weight_streamer embedded(nnue::embed::weights_file_data);
result.load(embedded);
return result;
}();

auto generate = [&, this]{
using worker_type = search::search_worker<false>;
using worker_type = search::search_worker;
auto gen = std::mt19937(std::random_device()());


Expand All @@ -87,7 +91,7 @@ struct data_generator{

std::vector<sample> block{};

chess::position_history hist{};
chess::board_history hist{};
state_type state = state_type::start_pos();

const result_type game_result = [&]{
Expand All @@ -96,11 +100,11 @@ struct data_generator{
const auto mv_list = state.generate_moves();
const size_t idx = std::uniform_int_distribution<size_t>(0, mv_list.size()-1)(gen);

hist.push_(state.hash());
hist.push(state.hash());
state = state.forward(mv_list[idx]);
} else {
worker->go(hist, state, 1);
worker->iterative_deepening_loop_();
worker->iterative_deepening_loop();
worker->stop();

const auto best_move = worker->best_move();
Expand All @@ -111,8 +115,8 @@ struct data_generator{

const auto view = search::stack_view::root((worker->internal).stack);
const auto evaluator = [&] {
nnue::eval result(&weights);
state.feature_full_refresh(result);
nnue::eval result(&weights, &worker->internal.scratchpad, 0, 0);
state.feature_full_reset(result);
return result;
}();

Expand All @@ -125,7 +129,7 @@ struct data_generator{

if (!state.is_check() && static_eval == q_eval) { block.emplace_back(state, best_score); }

hist.push_(state.hash());
hist.push(state.hash());
state = state.forward(best_move);
}
}
Expand Down
2 changes: 1 addition & 1 deletion include/file_reader_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ struct file_reader_iterator{
using value_type = T;
using pointer = const T*;
using reference = const T&;
using iterator_category = std::output_iterator_tag;
using iterator_category = std::input_iterator_tag;

std::optional<T> current_{std::nullopt};
std::function<std::optional<T>(std::ifstream&)> read_element_;
Expand Down
1 change: 1 addition & 0 deletions include/sample_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
namespace train{

struct sample_reader : line_count_size<sample_reader> {
using iterator = file_reader_iterator<sample>;
std::string path_;

file_reader_iterator<sample> begin() const { return file_reader_iterator<sample>(to_line_reader<sample>(sample::from_string), path_); }
Expand Down
85 changes: 85 additions & 0 deletions include/stochastic_multiplex_sample_reader.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#pragma once

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <optional>
#include <random>
#include <vector>

#include <sample.h>
#include <sample_reader.h>


namespace train {

// Input iterator that merges samples from several underlying readers,
// drawing the next sample from reader i with probability proportional to
// the number of samples reader i has remaining. Two iterators compare
// equal iff both are exhausted, so a default-constructed instance acts as
// the end sentinel.
struct stochastic_multiplex_sample_reader_iterator {
  using difference_type = long;
  using value_type = sample;
  using pointer = const sample*;
  using reference = const sample&;
  using iterator_category = std::input_iterator_tag;

  using distribution_type = std::discrete_distribution<size_t>;

  // Index of the reader the next dereference / advance applies to.
  size_t advance_idx_{};
  // Samples left across all readers; zero means the stream is exhausted.
  size_t total_remaining_{};

  // Deliberately default-seeded: the multiplexing order is reproducible
  // across runs for a fixed set of input sizes.
  std::mt19937 random_number_generator_;
  distribution_type distribution_;
  std::vector<size_t> remaining_;
  std::vector<sample_reader::iterator> reader_begins_;

  // True once every underlying reader has been fully consumed.
  constexpr bool empty() const { return total_remaining_ == 0; }

  // Sentinel-style equality: iterators are equal only when both are exhausted.
  constexpr bool operator==(const stochastic_multiplex_sample_reader_iterator& other) const { return empty() && other.empty(); }
  constexpr bool operator!=(const stochastic_multiplex_sample_reader_iterator& other) const { return !(*this == other); }

  stochastic_multiplex_sample_reader_iterator& operator++() {
    // Consume one sample from the currently selected reader ...
    ++reader_begins_[advance_idx_];
    --remaining_[advance_idx_];
    --total_remaining_;

    // ... then reselect, reweighting by what each reader still has left so
    // the merged stream stays proportional as readers drain.
    if (!empty()) {
      const auto params = distribution_type::param_type(remaining_.begin(), remaining_.end());
      advance_idx_ = distribution_(random_number_generator_, params);
    }

    return *this;
  }

  sample operator*() const { return *reader_begins_[advance_idx_]; }

  // End sentinel: empty() holds, so it compares equal to any drained iterator.
  stochastic_multiplex_sample_reader_iterator() = default;

  // sizes[i] must be the number of samples reader_begins[i] will yield.
  stochastic_multiplex_sample_reader_iterator(
      const std::vector<size_t>& sizes,
      const std::vector<sample_reader::iterator>& reader_begins)
      : random_number_generator_(),
        distribution_(sizes.begin(), sizes.end()),
        remaining_{sizes},
        reader_begins_{reader_begins} {
    // remaining_ is declared after total_remaining_, so the sum is computed
    // here in the body rather than in the member-initializer list.
    total_remaining_ = std::reduce(remaining_.begin(), remaining_.end());
    if (!empty()) { advance_idx_ = distribution_(random_number_generator_); }
  }
};


// Presents a set of sample_readers as one stream, interleaving their samples
// at random in proportion to the per-reader sample counts given at
// construction time.
struct stochastic_multiplex_sample_reader {
  std::vector<size_t> sizes_;
  std::vector<sample_reader> readers_;

  // Builds an iterator positioned at the start of every underlying reader.
  stochastic_multiplex_sample_reader_iterator begin() const {
    std::vector<sample_reader::iterator> starts;
    starts.reserve(readers_.size());
    for (const sample_reader& reader : readers_) { starts.push_back(reader.begin()); }
    return stochastic_multiplex_sample_reader_iterator(sizes_, starts);
  }

  // A default-constructed iterator is the exhausted-stream sentinel.
  stochastic_multiplex_sample_reader_iterator end() const { return stochastic_multiplex_sample_reader_iterator(); }

  // sizes[i] must equal the number of samples readers[i] will yield.
  stochastic_multiplex_sample_reader(const std::vector<size_t>& sizes, const std::vector<sample_reader>& readers)
      : sizes_{sizes}, readers_{readers} {}
};

}
31 changes: 16 additions & 15 deletions include/training.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
#pragma once

#include <chess_types.h>
#include <feature_util.h>
#include <move.h>
#include <nnue_model.h>
#include <search_constants.h>
#include <search_worker.h>
#include <embedded_weights.h>
#include <weights_streamer.h>
#include <chess/types.h>
#include <feature/util.h>
#include <chess/move.h>
#include <chess/board_history.h>
#include <nnue/eval.h>
#include <search/search_constants.h>
#include <search/search_worker.h>
#include <nnue/embedded_weights.h>
#include <nnue/weights_streamer.h>

#include <atomic>
#include <chrono>
Expand Down Expand Up @@ -74,15 +75,15 @@ struct feature_set : chess::sided<feature_set, std::set<size_t>> {
feature_set() : white{}, black{} {}
};

bool is_terminal(const chess::position_history& hist, const state_type& state) {
if (hist.is_two_fold(state.hash())) { return true; }
bool is_terminal(const chess::board_history& hist, const state_type& state) {
if (hist.count(state.hash())) { return true; }
if (state.generate_moves().size() == 0) { return true; }

return false;
}

result_type get_result(const chess::position_history& hist, const state_type& state) {
if (hist.is_two_fold(state.hash())) { return result_type::draw; }
result_type get_result(const chess::board_history& hist, const state_type& state) {
if (hist.count(state.hash())) { return result_type::draw; }
if (state.generate_moves().size() == 0) {
if (state.is_check()) { return result_type::loss; }
return result_type::draw;
Expand All @@ -97,13 +98,13 @@ result_type relative_result(const bool& pov_a, const bool& pov_b, const result_t

feature_set get_features(const state_type& state) {
feature_set features{};
state.feature_full_refresh(features);
state.feature_full_reset(features);
return features;
}

std::tuple<std::vector<nnue::weights::parameter_type>, std::vector<nnue::weights::parameter_type>> feature_transformer_parameters() {
nnue::embedded_weight_streamer streamer(embed::weights_file_data);
using feature_transformer_type = nnue::big_affine<nnue::weights::parameter_type, feature::half_ka::numel, nnue::weights::base_dim>;
nnue::embedded_weight_streamer streamer(nnue::embed::weights_file_data);
using feature_transformer_type = nnue::sparse_affine_layer<nnue::weights::parameter_type, feature::half_ka::numel, nnue::weights::base_dim>;
feature_transformer_type feature_transformer = nnue::weights{}.load(streamer).shared;
std::vector<nnue::weights::parameter_type> weights(feature_transformer.W, feature_transformer.W + feature_transformer_type::W_numel);
std::vector<nnue::weights::parameter_type> bias(feature_transformer.b, feature_transformer.b + feature_transformer_type::b_numel);
Expand Down
2 changes: 1 addition & 1 deletion pybind11
Submodule pybind11 updated 241 files
29 changes: 16 additions & 13 deletions scripts/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,35 @@ fine_tune: False
model_save_path: "model/save.pt"
bin_model_save_path: "model/save.bin"

concurrency: 12
concurrency: 16
fixed_depth: 128
fixed_nodes: 5120
data_write_path: "/media/connor/7F35A067038168A9/seer_train3/data.txt"
fixed_nodes: 10240
eval_limit: 1792
data_write_path: "/home/connor/steam_games/seer_train/data.txt"
tt_mb_size: 2048
target_sample_count: 600000000

data_read_paths:
[
/home/connor/steam_games/seer_train_attu/terashuf/data_shuf_n10240_wdl_latest.txt,
/home/connor/steam_games/seer_train_attu/terashuf/data_shuf_n20480_wdl.txt,
/media/connor/7F35A067038168A9/seer_train3/terashuf/data_shuf_n5120_wdl.txt,
/media/connor/7F35A067038168A91/seer_train3/terashuf/data_shuf_n5120_wdl.txt,
/home/connor/steam_games/seer_train_attu/terashuf/data_shuf_n5120_wdl2.txt,
/home/connor/steam_games/seer_train_attu/terashuf/data_shuf_n5120_wdl3.txt,
/home/connor/steam_games/seer_train_attu/terashuf/data_shuf_n5120_wdl4.txt,
/home/connor/steam_games/seer_train_attu/terashuf/data_shuf_n5120_wdl5.txt,
/media/connor/7F35A067038168A9/seer_train3/terashuf/data_shuf_old_d6_wdl2.txt,
/media/connor/7F35A067038168A9/seer_train3/terashuf/data_shuf_old_d6_wdl3.txt,
/media/connor/7F35A067038168A9/seer_train3/terashuf/data_shuf_old_d10_wdl.txt,
/media/connor/7F35A067038168A9/seer_train3/terashuf/data_shuf_old_d10_wdl2.txt,
/media/connor/7F35A067038168A9/seer_train3/terashuf/data_shuf_old_d10_wdl3.txt,
/media/connor/7F35A067038168A9/seer_train3/terashuf/data_shuf_old_d8_wdl.txt,
/media/connor/7F35A067038168A9/seer_train3/terashuf/data_shuf_old_d8_wdl2.txt,
/media/connor/7F35A067038168A91/seer_train3/terashuf/data_shuf_old_d6_wdl2.txt,
/media/connor/7F35A067038168A91/seer_train3/terashuf/data_shuf_old_d6_wdl3.txt,
/media/connor/7F35A067038168A91/seer_train3/terashuf/data_shuf_old_d10_wdl.txt,
/media/connor/7F35A067038168A91/seer_train3/terashuf/data_shuf_old_d10_wdl2.txt,
/media/connor/7F35A067038168A91/seer_train3/terashuf/data_shuf_old_d10_wdl3.txt,
/media/connor/7F35A067038168A91/seer_train3/terashuf/data_shuf_old_d8_wdl.txt,
/media/connor/7F35A067038168A91/seer_train3/terashuf/data_shuf_old_d8_wdl2.txt,
]

data_read_lengths:
[
600000000,
1078540650,
1139897403,
883361685,
Expand All @@ -48,7 +51,7 @@ epochs: 100000
shuffle_buffer_size: 500_000

mirror_probability: 0.5
batch_size: 1024
batch_size: 4096

visual_directory: "visual"

Expand All @@ -57,7 +60,7 @@ report_rate: 50
test_rate: 50
max_queue_size: 128

gamma: 0.05
gamma: 1.0
step_size: 100

learning_rate: 0.07
Loading

0 comments on commit d9ec877

Please sign in to comment.