-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.cpp
98 lines (75 loc) · 2.21 KB
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
/* g++ -O2 -flto main.cpp -lsfml-system -lsfml-window -lsfml-graphics -pthread */
#include "header.hpp"
/// Linearly re-maps `n` from the range [start1, stop1] onto [start2, stop2]
/// (same formula as Processing's `map()`). Values outside the input range are
/// extrapolated, not clamped; a reversed output range (start2 > stop2) flips
/// the direction.
///
/// The range bounds are `float` so fractional ranges (e.g. [0, 0.5]) are not
/// truncated to integers; integer arguments still convert implicitly.
///
/// @pre start1 != stop1, otherwise the division yields inf/NaN.
static inline float mnp(float n, float start1, float stop1, float start2, float stop2) {
return ((n - start1) / (stop1 - start1)) * (stop2 - start2) + start2;
}
/// One-time initialisation: (re)builds the global network `nn` and applies the
/// global `learning_rate` to it.
/// NOTE(review): `nn` is not referenced anywhere else in the visible code.
void setup() {
    // Presumably (inputs, hidden, outputs); confirm against NeuralNetwork's ctor.
    constexpr int layerA = 2;
    constexpr int layerB = 4;
    constexpr int layerC = 1;
    nn = NeuralNetwork(layerA, layerB, layerC);
    nn.setLearningRate(learning_rate);
}
void gradientDescent() {
for (const auto& shape : shapes) {
const float& x = shape.getPosition().x;
const float& y = shape.getPosition().y;
const float& guess = m * x + b;
const float& error = y - guess;
m += error / x * learning_rate;
b += error * learning_rate;
}
}
/// Draws the current fitted line y = m*x + b across the full window width.
///
/// The line is evaluated at normalised x = 0 (left edge) and x = 1 (right
/// edge). Both endpoints then go through the same pixel mapping:
/// x [0,1] -> [0, widthI] and y [0,1] -> [heightI, 0], because screen y
/// grows downward. Evaluating both ends at the same x, or mapping the two y
/// values with opposite orientations, would lose the slope `m` entirely.
void drawLine() {
    const float leftX = 0.0f;
    const float rightX = 1.0f;
    const float x1 = mnp(leftX, 0, 1, 0, widthI);
    const float y1 = mnp(m * leftX + b, 0, 1, heightI, 0);
    const float x2 = mnp(rightX, 0, 1, 0, widthI);
    const float y2 = mnp(m * rightX + b, 0, 1, heightI, 0);
    const sf::Vertex vertices[2] = {
        sf::Vertex(sf::Vector2f(x1, y1), sf::Color::White),
        sf::Vertex(sf::Vector2f(x2, y2), sf::Color::White),
    };
    window.draw(vertices, 2, sf::Lines);
}
void mousePressed() {
const auto& mousepos = window.mapPixelToCoords(sf::Mouse::getPosition(window));
const double x = mousepos.x;
const double y = mousepos.y;
sf::CircleShape point;
point.setRadius(8);
point.setOutlineColor(sf::Color::White);
point.setPosition(x, y);
shapes.push_back(point);
}
/// Trains and draws the regression line, but only once at least two points
/// have been placed. With zero or one point nothing is done.
void itLegal() {
    const bool haveEnoughPoints = shapes.size() > 1;
    if (!haveEnoughPoints) {
        return;
    }
    gradientDescent();
    drawLine();
}
void draw() {
window.clear({red, green, blue});
for (const auto& shape : shapes) {
window.draw(shape);
}
std::thread statements(itLegal);
statements.join();
window.display();
}
int main() {
std::thread pre(setup);
pre.join();
while (window.isOpen()) {
std::thread first(setup);
first.join();
sf::Event event;
while (window.pollEvent(event)) {
if (event.type == sf::Event::Closed)
window.close();
else if (event.type == sf::Event::MouseButtonPressed)
mousePressed();
}
std::thread second(draw);
second.join();
}
return 0;
}