-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathassociation.R
91 lines (67 loc) · 2.5 KB
/
association.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
library(sparklyr)
library(dplyr)
library(readr)
library(purrr)
library(igraph)
library(visNetwork)
# FPGrowth in Spark -------------------------------------------------------
# Mine association rules from the Instacart prior orders with Spark's FPGrowth
# and bring them back as a local tibble `asso` with one row per
# item -> item rule: antecedent product id, consequent product id, confidence.

# Spark properties: 4 local cores, 8 GB driver memory, and 90% of the heap
# handed to Spark's unified memory region.
conf <- spark_config()
conf$`sparklyr.cores.local` <- 4
conf$`sparklyr.shell.driver-memory` <- "8G"
conf$`spark.memory.fraction` <- 0.9
sc <- spark_connect(master = "local", version = "2.2.0", config = conf)

# this is our data from Instacart
orders <- spark_read_csv(
  sc, "orders", "instacart_2017_05_01/order_products__prior.csv"
)

# One basket per order. Spark's FPGrowth requires the items of a transaction
# to be unique, so collect a set rather than a list.
orders_wide <- orders %>%
  group_by(order_id) %>%
  summarise(items = collect_set(product_id))

# use FP Growth (minimum rule confidence 3%, minimum itemset support 1%)
fpg.fit <- ml_fpgrowth(
  orders_wide,
  items_col = "items",
  min_confidence = .03,
  min_support = .01
)
rules <- ml_association_rules(fpg.fit) %>% collect()

# these are our rules.
# `antecedent` and `consequent` are list columns of itemsets. Unlisting them
# next to the scalar `confidence` column only lines up when every itemset has
# exactly one item, and the graphs below draw one edge per item pair. So keep
# only single-item -> single-item rules.
is_pairwise <- lengths(rules$antecedent) == 1 & lengths(rules$consequent) == 1
asso <-
  tibble(
    antecedent = unlist(rules$antecedent[is_pairwise]),
    consequent = unlist(rules$consequent[is_pairwise]),
    confidence = rules$confidence[is_pairwise]
  )

# remember to close connection
spark_disconnect_all()
# iGraph ------------------------------------------------------------------
# Draw the rules as a directed igraph network: one vertex per product that
# appears in any rule, one edge per antecedent -> consequent rule.

# get product names
products <- read_csv("instacart_2017_05_01/products.csv")

# bind to nodes.
# Every product on EITHER side of a rule needs a vertex, otherwise
# graph_from_data_frame() stops on edge endpoints missing from `vertices`.
# Combine the two columns with c(): the second argument of unique() is
# `incomparables`, not more values to deduplicate.
nodes <- data.frame(id = unique(c(asso$antecedent, asso$consequent))) %>%
  left_join(products, by = c("id" = "product_id")) %>%
  select(id, label = product_name)

# Scale confidence up so it is usable directly as a line width.
edges <- asso %>% mutate(weight = confidence * 10)
df.g <- graph_from_data_frame(edges, directed = TRUE, vertices = nodes)

plot(
  df.g,
  edge.arrow.size = .1,
  edge.curved = .3,
  edge.width = edges$weight,
  vertex.color = "lightblue",
  vertex.label.color = "darkblue",
  vertex.label.cex = .7,
  edge.label.cex = .7
)
# VisNetwork --------------------------------------------------------------
# Interactive version of the rule network.

# Nodes: every product on either side of a rule, labelled with its name.
# Combine both columns with c(): the second argument of unique() is
# `incomparables`, so unique(a, b) would silently drop consequent-only items.
nodes <- data.frame(id = unique(c(asso$antecedent, asso$consequent))) %>%
  left_join(products, by = c("id" = "product_id")) %>%
  select(id, label = product_name)

# Edges: visNetwork reads `from`/`to` plus per-edge styling columns.
edges <- asso %>%
  mutate(
    width = confidence * 20,
    smooth = TRUE,
    arrows = "to",
    # Two significant digits per edge. format() on the whole vector would give
    # every label the digit count required by the smallest value.
    label = as.character(signif(confidence, 2))
  ) %>%
  rename(from = antecedent, to = consequent)

visNetwork(nodes, edges, height = "600px", width = "100%")
# Finding Subgroups -------------------------------------------------------
# Turn the directed rule graph `df.g` into an undirected one, find
# communities by edge betweenness, then plot the dendrogram and the
# clustered network.

# Parallel/reciprocal edges are merged into one edge whose `weight` is the
# sum of the merged weights; all other edge attributes are dropped.
net.sym <- as.undirected(
  df.g,
  edge.attr.comb = list(weight = "sum", "ignore"),
  mode = "collapse"
)

# NOTE(review): igraph documents edge-betweenness weights as distances, while
# `weight` here grows with rule confidence. Confirm this inversion is intended.
ceb <- net.sym %>% cluster_edge_betweenness()

dendPlot(ceb, mode = "hclust")
plot(x = ceb, y = net.sym)