# plot2Dclassifier.jl
# Forked from mark-schmidt/Class.
using PyCall
using PyPlot
"""
    plot2Dclassifier(X, y, model; proba=false)

Open a new figure, scatter the rows of `X` by label, and shade the visible axes
area with the output of `model` evaluated on a regular grid.

# Arguments
- `X`: feature matrix; only columns 1 and 2 are read, as the horizontal and
  vertical coordinates.
- `y`: labels aligned with the rows of `X`. Rows labelled `-1` are drawn as blue
  circles and rows labelled `1` as red triangles; other labels are not drawn.
- `model`: object with a callable `predict` property (and `predict_proba` when
  `proba=true`) that takes a two-column matrix of grid points and returns one
  prediction per row.

# Keywords
- `proba=false`: when `true`, shade with the last column returned by
  `model.predict_proba` (15 filled contour levels, `"RdBu_r"` colormap) and add
  a colorbar instead of drawing hard class regions.

# Returns
The colorbar object when `proba=true`, otherwise the value returned by `contourf`.

# Throws
- `AssertionError` if the model does not return exactly one value per grid point.
"""
function plot2Dclassifier(X,y,model;proba=false)
    increment = 100  # number of grid points along each axis

    figure()
    plot(X[y.==-1,1],X[y.==-1,2],"bo")
    plot(X[y.==1,1],X[y.==1,2],"r^")

    # Build the grid over the axis limits chosen by the scatter plots above.
    (xmin,xmax) = xlim()
    xDomain = range(xmin,stop=xmax,length=increment)
    (ymin,ymax) = ylim()
    yDomain = range(ymin,stop=ymax,length=increment)

    # xValues varies along rows, yValues along columns (meshgrid-style matrices).
    xValues = repeat(xDomain,1,length(xDomain))
    yValues = repeat(yDomain',length(yDomain),1)

    # Evaluate only the model output that will actually be drawn.
    gridPoints = [xValues[:] yValues[:]]
    if proba
        z = model.predict_proba(gridPoints)[:,end]
    else
        z = model.predict(gridPoints)
    end
    @assert(length(z) == length(xValues),"Size of model function's output is wrong")
    zValues = reshape(z,size(xValues))

    if proba
        cs = contourf(xValues, yValues, zValues, levels=15, cmap="RdBu_r")
        return colorbar(cs)
    end

    # Region colours follow the marker colours: dark blue for -1, dark red for 1.
    # The comparison must be elementwise (`.==`); `array == 1` is a single `false`,
    # which would make the single-class branches unreachable.
    negColor = (0,0,.5)
    posColor = (.5,0,0)
    if all(zValues .== 1)
        cm = [posColor]
    elseif all(zValues .== -1)
        cm = [negColor]
    else
        # ListedColormap assigns colours from low to high values: -1 first, then 1.
        cm = [negColor, posColor]
    end

    matcolors = pyimport("matplotlib.colors")
    cmap = matcolors.ListedColormap(cm,"A")
    return contourf(xValues,yValues,zValues,cmap=cmap)
end