-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
49 lines (44 loc) · 1.92 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# **************************************************************************** #
# #
# ::: :::::::: #
# train.py :+: :+: :+: #
# +:+ +:+ +:+ #
# By: akaseris <akaseris@student.42.fr> +#+ +:+ +#+ #
# +#+#+#+#+#+ +#+ #
# Created: 2018/10/04 18:23:45 by akaseris #+# #+# #
# Updated: 2018/10/04 20:14:51 by akaseris ### ########.fr #
# #
# **************************************************************************** #
from tools import getFile, estimatePrice, calcError, saveTheta
def calcSums(th0, th1, dataList):
    """Accumulate the two residual sums needed for one gradient-descent step.

    For each row the residual ``estimatePrice(km, th0, th1) - price`` is
    computed, where ``km`` and ``price`` are the row's first two fields
    converted with ``float()``.

    Args:
        th0: Current first parameter, forwarded to ``estimatePrice``.
        th1: Current second parameter, forwarded to ``estimatePrice``.
        dataList: Iterable of indexable rows; ``row[0]`` is read as ``km``
            and ``row[1]`` as ``price``.

    Returns:
        A ``(sum0, sum1)`` tuple: the sum of the residuals, and the sum of
        each residual multiplied by its row's ``km``. Both are zero for an
        empty ``dataList``.
    """
    sum0 = 0.0
    # Must start at zero like sum0: a non-zero seed adds a constant bias to
    # every theta1 update.
    sum1 = 0.0
    for row in dataList:
        km = float(row[0])
        price = float(row[1])
        # Evaluate the model once per row and reuse the residual for both sums.
        residual = estimatePrice(km, th0, th1) - price
        sum0 += residual
        sum1 += residual * km
    return sum0, sum1
def calcTheta(theta0, theta1, learningRate, dataList, tolerance=0.0000001):
    """Repeat gradient-descent steps until both updates are below a tolerance.

    Each iteration obtains the residual sums from ``calcSums``, scales them
    by ``learningRate / len(dataList)`` and subtracts the resulting steps
    from the parameters. Iteration stops as soon as the absolute value of
    both steps is smaller than ``tolerance``.

    Args:
        theta0: Starting value of the first parameter.
        theta1: Starting value of the second parameter.
        learningRate: Factor applied to the averaged sums to get each step.
        dataList: Sized collection of rows, passed unchanged to ``calcSums``.
        tolerance: Step magnitude below which training is considered
            converged. Defaults to the previously hard-coded ``0.0000001``.

    Returns:
        The tuple ``(theta0 * 1000, theta1)``. The first parameter is
        returned multiplied by 1000; ``main`` divides it by 1000 again
        before computing the error.
        NOTE(review): the factor presumably undoes a scaling applied when
        the data is loaded -- confirm against ``tools``.

    Raises:
        ValueError: If ``dataList`` is empty (the average is undefined).
        ArithmeticError: If a step becomes NaN or infinite, i.e. the
            descent is diverging and would otherwise never terminate.
    """
    import math

    M = len(dataList)
    if M == 0:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError("cannot train on an empty data set")
    count = float(M)
    while True:
        sum0, sum1 = calcSums(theta0, theta1, dataList)
        step0 = learningRate * sum0 / count
        step1 = learningRate * sum1 / count
        # A NaN step never satisfies the convergence test, so without this
        # check a diverging run would spin forever.
        if not (math.isfinite(step0) and math.isfinite(step1)):
            raise ArithmeticError(
                "gradient descent diverged; try a smaller learning rate"
            )
        if abs(step0) < tolerance and abs(step1) < tolerance:
            return (theta0 * 1000, theta1)
        theta0 = theta0 - step0
        theta1 = theta1 - step1
def main():
    """Load the data set, train the two parameters, then save and print them."""
    samples = getFile()
    print("Running...")
    # Both parameters start at 0.0 and the learning rate is 0.0001.
    theta0, theta1 = calcTheta(0.0, 0.0, 0.0001, samples)
    print("Saving...")
    # calcTheta returns theta0 multiplied by 1000; the error is computed
    # with the unscaled value.
    sampleCount = len(samples)
    error = calcError(theta0 / 1000, theta1, samples, sampleCount)
    saveTheta(theta0, theta1, error, samples)
    report = "Theta0: {}\nTheta1: {}\nMSE: {}".format(theta0, theta1, error)
    print(report)
# Run the training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()