diff --git a/src/estimators.jl b/src/estimators.jl
index afb45f0a..888698ab 100644
--- a/src/estimators.jl
+++ b/src/estimators.jl
@@ -145,7 +145,7 @@ mutable struct GComputation <: CausalEstimator
     """Weights learned during training"""
     β::Array{Float64}
     """The effect of exposure or treatment"""
-    causal_effect::Vector{Float64}
+    causal_effect::Float64
 
     """
     GComputation(X, Y, T, task, quantity_of_interest, regularized, activation, temporal, 
@@ -233,7 +233,7 @@ mutable struct DoublyRobust <: CausalEstimator
     """Predicted outcomes for the treatment group"""
     μ₁::Array{Float64}
     """The effect of exposure or treatment"""
-    causal_effect::Vector{Float64}
+    causal_effect::Float64
 
     """
     DoublyRobust(X, Xₚ, Y, T, task, quantity_of_interest, regularized, activation, 
@@ -362,7 +362,7 @@ function estimatecausaleffect!(g::GComputation)
     end
 
     g.β = fit!(g.learner)
-    g.causal_effect = [sum(predict(g.learner, Xₜ) - predict(g.learner, Xᵤ))/size(Xₜ, 1)]
+    g.causal_effect = sum(predict(g.learner, Xₜ) - predict(g.learner, Xᵤ))/size(Xₜ, 1)
 
     return g.causal_effect
 end
@@ -446,7 +446,7 @@ function estimatecausaleffect!(DRE::DoublyRobust)
         DRE.μ₁ = reduce(vcat, treatment_predictions)
     end
 
-    DRE.causal_effect = [mean(fold_level_effects)]
+    DRE.causal_effect = mean(fold_level_effects)
 
     return DRE.causal_effect
 end
diff --git a/src/inference.jl b/src/inference.jl
index b9d4ba9e..8edefeda 100644
--- a/src/inference.jl
+++ b/src/inference.jl
@@ -4,7 +4,7 @@ module Inference
 using CausalELM: mean
 using ..Metalearners: Metalearner
 using ..Estimators: CausalEstimator, InterruptedTimeSeries, GComputation, DoublyRobust, 
-    estimatecausaleffect!, mean
+    estimatecausaleffect!
 
 import CausalELM: summarize
 
@@ -226,7 +226,7 @@ function generatenulldistribution(e::Union{CausalEstimator, Metalearner}, n::Int
     for iter in 1:n
         m.T = float(rand(0:1, nobs))
         estimatecausaleffect!(m)
-        results[iter] = ifelse(e isa Metalearner, mean(m.causal_effect), m.causal_effect[1])
+        results[iter] = e isa Metalearner ? mean(m.causal_effect) : m.causal_effect
     end
     return results
 end
@@ -311,9 +311,9 @@ julia> quantitiesofinterest(g_computer, 1000)
 (0.114, 6.953133617011371)
 ```
 """
-function quantitiesofinterest(model::Union{CausalEstimator, Metalearner}, n::Integer=1000)
-    local null_dist = generatenulldistribution(model, n)
-    local avg_effect = mean(model.causal_effect)
+function quantitiesofinterest(m::Union{CausalEstimator, Metalearner}, n::Integer=1000)
+    local null_dist = generatenulldistribution(m, n)
+    local avg_effect = m isa Metalearner ? mean(m.causal_effect) : m.causal_effect
 
     extremes = length(null_dist[abs(avg_effect) .< abs.(null_dist)])
     pvalue = extremes/n
diff --git a/src/model_validation.jl b/src/model_validation.jl
index 1c70c701..902b3830 100644
--- a/src/model_validation.jl
+++ b/src/model_validation.jl
@@ -1,7 +1,7 @@
 module ModelValidation
 
-using ..Estimators: InterruptedTimeSeries, estimatecausaleffect!
-using CausalELM: mean
+using ..Estimators: InterruptedTimeSeries, estimatecausaleffect!, GComputation
+using CausalELM: mean, var
 using LinearAlgebra: norm
 
 """
@@ -112,7 +112,7 @@ end
 See how an omitted predictor/variable could change the results of an interrupted time 
 series analysis.
 
-This method reestimates interrupted time series models with normal random variables with 
+This method reestimates interrupted time series models with normal random variables and 
 uniform noise. If the included covariates are good predictors of the counterfactual 
 outcome, adding a random variable as a covariate should not have a large effect on the 
 predicted counterfactual outcomes and therefore the estimated average effect.
@@ -260,6 +260,12 @@ function pval(x::Array{Float64}, y::Array{Float64}, β::Float64; n::Int=1000, 
     return p
 end
 
+function counterfactualconsistency(g::GComputation)
+    treatment_covariates, treatment_outcomes = g.X[g.T .== 1, :], g.Y[g.T .== 1]
+    ŷ = treatment_covariates * (treatment_covariates \ treatment_outcomes)
+    observed_residual_variance = var(treatment_outcomes .- ŷ)
+end
+
 """
     ned(a, b)
 
@@ -283,6 +289,8 @@ function ned(a::Vector{T}, b::Vector{T}) where T <: Number
             a = reduce(vcat, (a, zeros(abs(length(a)-length(b)))))
         end
     end
+
+    # Changing NaN to zero fixes divide by zero errors
     @fastmath norm(replace(sort(a)./norm(a), NaN=>0) .- replace((sort(b)./norm(b)), NaN=>0))
 end
 
diff --git a/test/test_estimators.jl b/test/test_estimators.jl
index f2bd6b2e..9371ba19 100644
--- a/test/test_estimators.jl
+++ b/test/test_estimators.jl
@@ -81,10 +81,10 @@ end
 
 @testset "G-Computation Estimation" begin
     @test isa(g_computer.β, Array)
-    @test isa(g_computer.causal_effect, Vector{Float64})
+    @test isa(g_computer.causal_effect, Float64)
 
     # Check that the estimats for ATE and ATT are different
-    @test g_computer.causal_effect[1] !== gcomputer_att.causal_effect[1]
+    @test g_computer.causal_effect !== gcomputer_att.causal_effect
 end
 
 @testset "Doubly Robust Estimation Structure" begin
@@ -127,23 +127,23 @@
     @test dr.ps isa Array{Float64}
     @test dr.μ₀ isa Array{Float64}
     @test dr.μ₁ isa Array{Float64}
-    @test dr.causal_effect isa Vector{Float64}
+    @test dr.causal_effect isa Float64
 
     # No regularization
     @test dr_noreg.ps isa Array{Float64}
     @test dr_noreg.μ₀ isa Array{Float64}
     @test dr_noreg.μ₁ isa Array{Float64}
-    @test dr_noreg.causal_effect isa Vector{Float64}
+    @test dr_noreg.causal_effect isa Float64
 
     # Using the ATT
     @test dr_att.ps isa Array{Float64}
     @test dr_att.μ₀ isa Array{Float64}
-    @test dr_att.causal_effect isa Vector{Float64}
+    @test dr_att.causal_effect isa Float64
 
     # Using the ATT with no regularization
     @test dr_att_noreg.ps isa Array{Float64}
     @test dr_att_noreg.μ₀ isa Array{Float64}
-    @test dr_att_noreg.causal_effect isa Vector{Float64}
+    @test dr_att_noreg.causal_effect isa Float64
 end
 
 @testset "Quanities of Interest Errors" begin
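
For reference, a minimal usage sketch of the scalar `causal_effect` API introduced by this change. Only `GComputation`, `estimatecausaleffect!`, `causal_effect`, and `quantitiesofinterest` appear in the diff; the import paths, the availability of default values for the remaining constructor arguments, the synthetic data, and the interpretation of the second tuple element are assumptions, not part of this patch.

```julia
# Hedged sketch, not the package's documented usage: constructor defaults and
# exports are assumed; data below is purely illustrative.
using CausalELM: GComputation, estimatecausaleffect!, quantitiesofinterest

X = rand(1000, 5)                         # illustrative covariates
T = float(rand(0:1, 1000))                # binary treatment stored as Float64
Y = X * rand(5) .+ 2 .* T .+ rand(1000)   # synthetic outcome

g = GComputation(X, Y, T)                 # assumes defaults for the other arguments
estimatecausaleffect!(g)

@assert g.causal_effect isa Float64       # scalar now; g.causal_effect[1] is no longer needed
pvalue, qoi = quantitiesofinterest(g, 1000)  # returns a tuple, as in the doctest above
```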