Skip to content

Commit

Permalink
Add helper function CreationExtras.withCreationCallback() and Mutable…
Browse files Browse the repository at this point in the history
…CreationExtras.addCreationCallback()

These can be used to pass creation callbacks with functions like `viewModels()`.

RELNOTES=Add helper function CreationExtras.withCreationCallback() and MutableCreationExtras.addCreationCallback()
PiperOrigin-RevId: 577890589
  • Loading branch information
kuanyingchou authored and Dagger Team committed Oct 30, 2023
1 parent 05f2b70 commit 6fe4a23
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 29 deletions.
4 changes: 3 additions & 1 deletion java/dagger/hilt/android/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
# Description:
# A library based on Hilt that provides standard components and automated injection for Android.
load("//:build_defs.bzl", "POM_VERSION")
load("//tools:maven.bzl", "gen_maven_artifact")
load("//tools:bazel_compat.bzl", "compat_kt_android_library")
load("//tools:maven.bzl", "gen_maven_artifact")

package(default_visibility = ["//:src"])

Expand All @@ -39,6 +39,7 @@ android_library(
"//java/dagger/hilt/android/internal/managers:component_supplier",
"//java/dagger/hilt/android/internal/modules",
"//java/dagger/hilt/android/lifecycle:hilt_view_model",
"//java/dagger/hilt/android/lifecycle:hilt_view_model_extensions",
"//java/dagger/hilt/codegen:originating_element",
"//java/dagger/hilt/internal:component_entry_point",
"//java/dagger/hilt/internal:component_manager",
Expand Down Expand Up @@ -197,6 +198,7 @@ gen_maven_artifact(
"//java/dagger/hilt/android/internal/migration:injected_by_hilt",
"//java/dagger/hilt/android/internal/modules",
"//java/dagger/hilt/android/lifecycle:hilt_view_model",
"//java/dagger/hilt/android/lifecycle:hilt_view_model_extensions",
"//java/dagger/hilt/android/lifecycle:package_info",
"//java/dagger/hilt/android/lifecycle:retained_lifecycle",
"//java/dagger/hilt/android/migration:custom_inject",
Expand Down
13 changes: 13 additions & 0 deletions java/dagger/hilt/android/lifecycle/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# Description:
# Hilt ViewModel integration.

load("//tools:bazel_compat.bzl", "compat_kt_android_library")

package(default_visibility = ["//:src"])

java_library(
Expand Down Expand Up @@ -54,6 +56,17 @@ android_library(
],
)

compat_kt_android_library(
name = "hilt_view_model_extensions",
srcs = ["HiltViewModelExtensions.kt"],
deps = [
":package_info",
"//java/dagger/hilt/android/internal/lifecycle",
"@maven//:androidx_annotation_annotation",
"@maven//:androidx_lifecycle_lifecycle_viewmodel",
],
)

filegroup(
name = "srcs_filegroup",
srcs = glob(["*"]),
Expand Down
41 changes: 40 additions & 1 deletion java/dagger/hilt/android/lifecycle/HiltViewModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,46 @@
* }
* </pre>
*
* <p>Exactly one constructor in the {@code ViewModel} must be annotated with {@code Inject}.
* <p>{@code ViewModel}s annotated with {@link HiltViewModel} can also be used with assisted
* injection:
*
* <pre>
* &#64;HiltViewModel(assistedFactory = DonutViewModel.Factory.class)
* public class DonutViewModel extends ViewModel {
* &#64;AssistedInject
* public DonutViewModel(
* SavedStateHandle handle,
* RecipeRepository repository,
* $#64;Assisted int donutId
* ) {
* // ...
* }
*
* &#64;AssistedFactory
* public interface Factory {
* DonutViewModel create(int donutId);
* }
* }
* </pre>
*
* <pre>
* &#64;AndroidEntryPoint
* public class CookingActivity extends AppCompatActivity {
* public void onCreate(Bundle savedInstanceState) {
* DonutViewModel vm = new ViewModelProvider(
* getViewModelStore(),
* getDefaultViewModelProviderFactory(),
* HiltViewModelExtensions.withCreationCallback(
* getDefaultViewModelCreationExtras(),
* (DonutViewModel.Factory factory) -> factory.create(1)
* )
* ).get(DonutViewModel.class);
* }
* }
* </pre>
*
* <p>Exactly one constructor in the {@code ViewModel} must be annotated with {@code Inject} or
* {@code AssistedInject}.
*
* <p>Only dependencies available in the {@link dagger.hilt.android.components.ViewModelComponent}
* can be injected into the {@code ViewModel}.
Expand Down
48 changes: 48 additions & 0 deletions java/dagger/hilt/android/lifecycle/HiltViewModelExtensions.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright (C) 2023 The Dagger Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

@file:JvmName("HiltViewModelExtensions")

package dagger.hilt.android.lifecycle

import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewmodel.CreationExtras
import androidx.lifecycle.viewmodel.MutableCreationExtras
import dagger.hilt.android.internal.lifecycle.HiltViewModelFactory

/**
* Returns a new {@code CreationExtras} with the original entries plus the passed in creation
* callback. The callback is used by Hilt to create {@link AssistedInject}-annotated {@link
* HiltViewModel}s.
*
* @param callback A creation callback that takes an assisted factory and returns a {@code
* ViewModel}.
*/
fun <VMF> CreationExtras.withCreationCallback(callback: (VMF) -> ViewModel): CreationExtras =
MutableCreationExtras(this).addCreationCallback(callback)

/**
* Returns the {@code MutableCreationExtras} with the passed in creation callback added. The
* callback is used by Hilt to create {@link AssistedInject}-annotated {@link HiltViewModel}s.
*
* @param callback A creation callback that takes an assisted factory and returns a {@code
* ViewModel}.
*/
@Suppress("UNCHECKED_CAST")
fun <VMF> MutableCreationExtras.addCreationCallback(callback: (VMF) -> ViewModel): CreationExtras =
this.apply {
this[HiltViewModelFactory.CREATION_CALLBACK_KEY] = { factory -> callback(factory as VMF) }
}
1 change: 1 addition & 0 deletions javatests/dagger/hilt/android/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ android_local_test(
"//java/dagger/hilt/android:view_model_lifecycle",
"//java/dagger/hilt/android/internal/lifecycle",
"//java/dagger/hilt/android/lifecycle:hilt_view_model",
"//java/dagger/hilt/android/lifecycle:hilt_view_model_extensions",
"//java/dagger/hilt/android/scopes",
"//java/dagger/hilt/android/testing:hilt_android_test",
"//third_party/java/jsr330_inject",
Expand Down
49 changes: 22 additions & 27 deletions javatests/dagger/hilt/android/ViewModelAssistedTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,21 @@
import androidx.fragment.app.Fragment;
import androidx.fragment.app.FragmentActivity;
import androidx.annotation.Nullable;
import androidx.lifecycle.HasDefaultViewModelProviderFactory;
import androidx.lifecycle.SavedStateHandle;
import androidx.lifecycle.ViewModel;
import androidx.lifecycle.ViewModelProvider;
import androidx.lifecycle.viewmodel.CreationExtras;
import androidx.lifecycle.viewmodel.MutableCreationExtras;
import androidx.test.core.app.ActivityScenario;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import dagger.assisted.Assisted;
import dagger.assisted.AssistedFactory;
import dagger.assisted.AssistedInject;
import dagger.hilt.android.internal.lifecycle.HiltViewModelFactory;
import dagger.hilt.android.lifecycle.HiltViewModel;
import dagger.hilt.android.lifecycle.HiltViewModelExtensions;
import dagger.hilt.android.scopes.ViewModelScoped;
import dagger.hilt.android.testing.HiltAndroidRule;
import dagger.hilt.android.testing.HiltAndroidTest;
import dagger.hilt.android.testing.HiltTestApplication;
import javax.inject.Inject;
import kotlin.jvm.functions.Function1;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -158,16 +154,18 @@ protected void onCreate(@Nullable Bundle savedInstanceState) {
new ViewModelProvider(
getViewModelStore(),
getDefaultViewModelProviderFactory(),
getCreationExtrasWithCreationCallback(
this, factory -> ((MyViewModel.Factory) factory).create("foo")))
HiltViewModelExtensions.withCreationCallback(
getDefaultViewModelCreationExtras(),
(MyViewModel.Factory factory) -> factory.create("foo")))
.get(MyViewModel.class);
} else {
vm =
new ViewModelProvider(
getViewModelStore(),
getDefaultViewModelProviderFactory(),
getCreationExtrasWithCreationCallback(
this, factory -> ((MyViewModel.Factory) factory).create("bar")))
HiltViewModelExtensions.withCreationCallback(
getDefaultViewModelCreationExtras(),
(MyViewModel.Factory factory) -> factory.create("bar")))
.get(MyViewModel.class);
}
}
Expand All @@ -187,16 +185,18 @@ protected void onCreate(@Nullable Bundle savedInstanceState) {
new ViewModelProvider(
getViewModelStore(),
getDefaultViewModelProviderFactory(),
getCreationExtrasWithCreationCallback(
this, factory -> ((MyViewModel.Factory) factory).create("foo")))
HiltViewModelExtensions.withCreationCallback(
getDefaultViewModelCreationExtras(),
(MyViewModel.Factory factory) -> factory.create("foo")))
.get("a", MyViewModel.class);

vm2 =
new ViewModelProvider(
getViewModelStore(),
getDefaultViewModelProviderFactory(),
getCreationExtrasWithCreationCallback(
this, factory -> ((MyViewModel.Factory) factory).create("bar")))
HiltViewModelExtensions.withCreationCallback(
getDefaultViewModelCreationExtras(),
(MyViewModel.Factory factory) -> factory.create("bar")))
.get("b", MyViewModel.class);
}
}
Expand Down Expand Up @@ -228,8 +228,9 @@ protected void onCreate(@Nullable Bundle savedInstanceState) {
new ViewModelProvider(
getViewModelStore(),
getDefaultViewModelProviderFactory(),
getCreationExtrasWithCreationCallback(
this, factory -> ((MyViewModel.Factory) factory).create("bar")))
HiltViewModelExtensions.withCreationCallback(
getDefaultViewModelCreationExtras(),
(MyViewModel.Factory factory) -> factory.create("bar")))
.get(MyInjectedViewModel.class);
}
}
Expand Down Expand Up @@ -266,8 +267,9 @@ protected void onCreate(@Nullable Bundle savedInstanceState) {
new ViewModelProvider(
getViewModelStore(),
getDefaultViewModelProviderFactory(),
getCreationExtrasWithCreationCallback(
this, factory -> ((MyViewModel.AnotherFactory) factory).create("foo")))
HiltViewModelExtensions.withCreationCallback(
getDefaultViewModelCreationExtras(),
(MyViewModel.AnotherFactory factory) -> factory.create("foo")))
.get(MyViewModel.class);
}
}
Expand All @@ -284,20 +286,13 @@ public void onCreate(@Nullable Bundle bundle) {
new ViewModelProvider(
getViewModelStore(),
getDefaultViewModelProviderFactory(),
getCreationExtrasWithCreationCallback(
this, factory -> ((MyViewModel.Factory) factory).create("foo")))
HiltViewModelExtensions.withCreationCallback(
getDefaultViewModelCreationExtras(),
(MyViewModel.Factory factory) -> factory.create("foo")))
.get(MyViewModel.class);
}
}

private static CreationExtras getCreationExtrasWithCreationCallback(
HasDefaultViewModelProviderFactory owner, Function1<Object, ViewModel> callback) {
MutableCreationExtras extras =
new MutableCreationExtras(owner.getDefaultViewModelCreationExtras());
extras.set(HiltViewModelFactory.CREATION_CALLBACK_KEY, callback);
return extras;
}

@HiltViewModel(assistedFactory = MyViewModel.Factory.class)
static class MyViewModel extends ViewModel {

Expand Down

0 comments on commit 6fe4a23

Please sign in to comment.