Skip to content

Commit

Permalink
Updating one-hot to be a bit better
Browse files Browse the repository at this point in the history
  • Loading branch information
cnuernber committed Feb 21, 2024
1 parent f29cc90 commit daaf470
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 30 deletions.
8 changes: 7 additions & 1 deletion deps.edn
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,13 @@
org.xerial.snappy/snappy-java {:mvn/version "1.1.8.4"}
;org.tribuo/tribuo-all {:mvn/version "4.2.0" :extension "pom"}
}
:extra-paths ["test"]}
:extra-paths ["test"]
:jvm-opts ["-Djdk.attach.allowAttachSelf=true" "--add-opens=java.base/jdk.internal.ref=ALL-UNNAMED"
"--illegal-access=permit" "--add-opens=java.base/java.lang=ALL-UNNAMED"
"--add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED"
"--add-opens=java.base/java.util.concurrent=ALL-UNNAMED"
]
}

:jdk-8 {}
:jdk-11
Expand Down
56 changes: 40 additions & 16 deletions src/tech/v3/dataset/categorical.clj
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@

(:require [tech.v3.dataset.base :as ds-base]
[tech.v3.dataset.protocols :as ds-proto]
[tech.v3.dataset.impl.dataset :as ds-impl]
[tech.v3.dataset.impl.column :as col-impl]
[tech.v3.dataset.impl.column-base :as col-base]
[tech.v3.datatype :as dtype]
[tech.v3.datatype.protocols :as dtype-proto]
[tech.v3.datatype.casting :as casting]
[tech.v3.datatype.errors :as errors]
[tech.v3.datatype.bitmap :as bitmap]
[ham-fisted.lazy-noncaching :as lznc]
[ham-fisted.set :as set]))


Expand Down Expand Up @@ -218,22 +220,44 @@ Non integers found: " (vec bad-mappings)))))
one-hot-fit-data
column (ds-base/column dataset src-column)
missing (ds-proto/missing column)
dataset (dissoc dataset src-column)]
(->> one-hot-table
(mapcat
(fn [[k v]]
[v (col-impl/new-column
v
(dtype/emap
#(if (= % k)
1
0)
result-datatype
column)
(assoc (meta column)
:one-hot-map one-hot-fit-data)
missing)]))
(apply assoc dataset))))
dataset (dissoc dataset src-column)
n-elems (dtype/ecount column)
op-space (casting/simple-operation-space (dtype-proto/operational-elemwise-datatype column))]
(merge dataset
(->> one-hot-table
(lznc/map
(fn [[k v]]
(col-impl/new-column
v
(case op-space
:int64
(let [buf (dtype/->reader column :int64)
k (long k)]
(reify tech.v3.datatype.LongBuffer
(elemwiseDatatype [this] :int8)
(lsize [this] n-elems)
(readLong [this idx]
(if (== k (.readLong buf idx))
1 0))))
:float64
(let [buf (dtype/->reader column :float64)
k (double k)]
(reify tech.v3.datatype.DoubleBuffer
(elemwiseDatatype [this] :int8)
(lsize [this] n-elems)
(readLong [this idx]
(if (== k (.readDouble buf idx))
1 0))))
(dtype/emap
#(if (= % k)
1
0)
result-datatype
column))
(assoc (meta column)
:one-hot-map one-hot-fit-data)
missing)))
(ds-impl/new-dataset)))))


(extend-type OneHotMap
Expand Down
25 changes: 14 additions & 11 deletions src/tech/v3/dataset/reductions.clj
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ user> (ds-reduce/group-by-column-agg
[colname]
(reducer->column-reducer (hamf-rf/reducer-with-finalize
(Sum.)
#((deref %) :sum)) colname))
#((deref %) :sum))
colname))


(defn mean
Expand Down Expand Up @@ -482,16 +483,18 @@ _unnamed [4 5]:
:skip-finalize? true
:min-n 1000}))))
([agg-map]
(let [c ((hamf-proto/->init-val-fn (io-mapseq/mapseq-reducer nil)))]
;;It is also possible to parse N datasets in parallel and do a concat-copying
;;operation, but in my experience this step takes up nearly no time.
(.forEach ^ConcurrentHashMap agg-map 32
(hamf-fn/bi-consumer
k v
(do
(let [vv (finalize-fn v)]
(locking c (.accept ^Consumer c vv))))))
@c))))))
(if (get options :skip-finalize?)
agg-map
(let [c ((hamf-proto/->init-val-fn (io-mapseq/mapseq-reducer nil)))]
;;It is also possible to parse N datasets in parallel and do a concat-copying
;;operation, but in my experience this step takes up nearly no time.
(.forEach ^ConcurrentHashMap agg-map 32
(hamf-fn/bi-consumer
k v
(do
(let [vv (finalize-fn v)]
(locking c (.accept ^Consumer c vv))))))
@c)))))))


(defn group-by-column-agg
Expand Down
4 changes: 2 additions & 2 deletions src/tech/v3/dataset/reductions/impl.clj
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@
acc v (.invokePrim rfn acc (.readLong col v))))
hamf-proto/ParallelReducer
(->merge-fn [r] merge-fn)))
:float32
:float64
(let [rfn (Transformables/toDoubleReductionFn rfn)]
(reify
hamf-proto/Reducer
(->init-val-fn [r] init-fn)
(->rfn [r] (hamf-rf/long-accumulator
(->rfn [r] (hamf-rf/double-accumulator
acc v (.invokePrim rfn acc (.readDouble col v))))
hamf-proto/ParallelReducer
(->merge-fn [r] merge-fn)))
Expand Down

0 comments on commit daaf470

Please sign in to comment.