Skip to content

Commit

Permalink
Fixed bounds issues with inference
Browse files Browse the repository at this point in the history
  • Loading branch information
dscolby committed Oct 25, 2023
1 parent b7dbe9f commit 4f11df8
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 18 deletions.
8 changes: 4 additions & 4 deletions src/estimators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ mutable struct GComputation <: CausalEstimator
"""Weights learned during training"""
β::Array{Float64}
"""The effect of exposure or treatment"""
causal_effect::Vector{Float64}
causal_effect::Float64

"""
GComputation(X, Y, T, task, quantity_of_interest, regularized, activation, temporal,
Expand Down Expand Up @@ -233,7 +233,7 @@ mutable struct DoublyRobust <: CausalEstimator
"""Predicted outcomes for the treatment group"""
μ₁::Array{Float64}
"""The effect of exposure or treatment"""
causal_effect::Vector{Float64}
causal_effect::Float64

"""
DoublyRobust(X, Xₚ, Y, T, task, quantity_of_interest, regularized, activation,
Expand Down Expand Up @@ -362,7 +362,7 @@ function estimatecausaleffect!(g::GComputation)
end

g.β = fit!(g.learner)
g.causal_effect = [sum(predict(g.learner, Xₜ) - predict(g.learner, Xᵤ))/size(Xₜ, 1)]
g.causal_effect = sum(predict(g.learner, Xₜ) - predict(g.learner, Xᵤ))/size(Xₜ, 1)

return g.causal_effect
end
Expand Down Expand Up @@ -446,7 +446,7 @@ function estimatecausaleffect!(DRE::DoublyRobust)
DRE.μ₁ = reduce(vcat, treatment_predictions)
end

DRE.causal_effect = [mean(fold_level_effects)]
DRE.causal_effect = mean(fold_level_effects)
return DRE.causal_effect
end

Expand Down
10 changes: 5 additions & 5 deletions src/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module Inference
using CausalELM: mean
using ..Metalearners: Metalearner
using ..Estimators: CausalEstimator, InterruptedTimeSeries, GComputation, DoublyRobust,
estimatecausaleffect!, mean
estimatecausaleffect!

import CausalELM: summarize

Expand Down Expand Up @@ -226,7 +226,7 @@ function generatenulldistribution(e::Union{CausalEstimator, Metalearner}, n::Int
for iter in 1:n
m.T = float(rand(0:1, nobs))
estimatecausaleffect!(m)
results[iter] = ifelse(e isa Metalearner, mean(m.causal_effect), m.causal_effect[1])
results[iter] = e isa Metalearner ? mean(m.causal_effect) : m.causal_effect
end
return results
end
Expand Down Expand Up @@ -311,9 +311,9 @@ julia> quantitiesofinterest(g_computer, 1000)
(0.114, 6.953133617011371)
```
"""
function quantitiesofinterest(model::Union{CausalEstimator, Metalearner}, n::Integer=1000)
local null_dist = generatenulldistribution(model, n)
local avg_effect = mean(model.causal_effect)
function quantitiesofinterest(m::Union{CausalEstimator, Metalearner}, n::Integer=1000)
local null_dist = generatenulldistribution(m, n)
local avg_effect = m isa Metalearner ? mean(m.causal_effect) : m.causal_effect

extremes = length(null_dist[abs(avg_effect) .< abs.(null_dist)])
pvalue = extremes/n
Expand Down
14 changes: 11 additions & 3 deletions src/model_validation.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module ModelValidation

using ..Estimators: InterruptedTimeSeries, estimatecausaleffect!
using CausalELM: mean
using ..Estimators: InterruptedTimeSeries, estimatecausaleffect!, GComputation
using CausalELM: mean, var
using LinearAlgebra: norm

"""
Expand Down Expand Up @@ -112,7 +112,7 @@ end
See how an omitted predictor/variable could change the results of an interrupted time series
analysis.
This method reestimates interrupted time series models with normal random variables with
This method reestimates interrupted time series models with normal random variables and
uniform noise. If the included covariates are good predictors of the counterfactual outcome,
adding a random variable as a covariate should not have a large effect on the predicted
counterfactual outcomes and therefore the estimated average effect.
Expand Down Expand Up @@ -260,6 +260,12 @@ function pval(x::Array{Float64}, y::Array{Float64}, β::Float64; n::Int=1000,
return p
end

function counterfactualconsistency(g::GComputation)
treatment_covariates, treatment_outcomes = g.X[g.T == 1, :], g.Y[g.T == 1]
= treatment_covariates\treatment_outcomes
observed_residual_variance = var(ŷ)
end

"""
ned(a, b)
Expand All @@ -283,6 +289,8 @@ function ned(a::Vector{T}, b::Vector{T}) where T <: Number
a = reduce(vcat, (a, zeros(abs(length(a)-length(b)))))
end
end

# Changing NaN to zero fixes divde by zero errors
@fastmath norm(replace(sort(a)./norm(a), NaN=>0) .- replace((sort(b)./norm(b)), NaN=>0))
end

Expand Down
12 changes: 6 additions & 6 deletions test/test_estimators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ end

@testset "G-Computation Estimation" begin
@test isa(g_computer.β, Array)
@test isa(g_computer.causal_effect, Vector{Float64})
@test isa(g_computer.causal_effect, Float64)

# Check that the estimats for ATE and ATT are different
@test g_computer.causal_effect[1] !== gcomputer_att.causal_effect[1]
@test g_computer.causal_effect !== gcomputer_att.causal_effect
end

@testset "Doubly Robust Estimation Structure" begin
Expand Down Expand Up @@ -127,23 +127,23 @@ end
@test dr.ps isa Array{Float64}
@test dr.μ₀ isa Array{Float64}
@test dr.μ₁ isa Array{Float64}
@test dr.causal_effect isa Vector{Float64}
@test dr.causal_effect isa Float64

# No regularization
@test dr_noreg.ps isa Array{Float64}
@test dr_noreg.μ₀ isa Array{Float64}
@test dr_noreg.μ₁ isa Array{Float64}
@test dr_noreg.causal_effect isa Vector{Float64}
@test dr_noreg.causal_effect isa Float64

# Using the ATT
@test dr_att.ps isa Array{Float64}
@test dr_att.μ₀ isa Array{Float64}
@test dr_att.causal_effect isa Vector{Float64}
@test dr_att.causal_effect isa Float64

# Using the ATT with no regularization
@test dr_att_noreg.ps isa Array{Float64}
@test dr_att_noreg.μ₀ isa Array{Float64}
@test dr_att_noreg.causal_effect isa Vector{Float64}
@test dr_att_noreg.causal_effect isa Float64
end

@testset "Quanities of Interest Errors" begin
Expand Down

0 comments on commit 4f11df8

Please sign in to comment.