diff --git a/include/tvm/te/schedule.h b/include/tvm/te/schedule.h
index ee4fb33349f7..6f26d07dc8a5 100644
--- a/include/tvm/te/schedule.h
+++ b/include/tvm/te/schedule.h
@@ -378,6 +378,18 @@ class Schedule : public ObjectRef {
    * \return A normalized schedule, can be same as current one.
    */
   Schedule normalize();
+
+  /*!
+   * \brief Normalize the schedule for feature extraction in auto-scheduler.
+   * This is similar to `Schedule::normalize`, but we do aggressive simplification
+   * to the TE compute with const_matrix=True for faster compilation and feature extraction.
+   * The resulted schedule may be wrong, but it is good enough for feature extraction
+   * purposes.
+   *
+   * \return A normalized schedule, can be same as current one.
+   */
+  Schedule normalize_for_feature_extraction();
+
   /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py
index a5df788d38cb..c3e14eff3919 100644
--- a/python/tvm/topi/utils.py
+++ b/python/tvm/topi/utils.py
@@ -337,7 +337,7 @@ def select_array(i, j):
                 )
         return now
 
-    return te.compute(matrix.shape, select_array, name=name)
+    return te.compute(matrix.shape, select_array, name=name, attrs={"const_matrix": True})
 
 
 def get_max_power2_factor(n, max_value=None):
diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc
index 27a30127ba65..e57fc8c9c2d9 100755
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -1235,7 +1235,7 @@ State ComputeDAG::InferBound(const State& state) const {
   Array<te::Tensor> tensors;
   // Replay steps to tvm::Schedule
   std::tie(sch, tensors) = ApplySteps(pstate->transform_steps, &stages, &stage_to_axes);
-  sch = sch.normalize();
+  sch = sch.normalize_for_feature_extraction();
   // Get bound information from TVM schedule
   Map<IterVar, Range> bounds = te::InferBound(sch);
 
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
index a60c87cc600d..0df69b967d3b 100755
--- a/src/auto_scheduler/feature.cc
+++ b/src/auto_scheduler/feature.cc
@@ -669,7 +669,7 @@ class PerStoreFeatureExtractor : public StmtExprVisitor {
     math_op_counter(node->value);
     std::vector<float> mem_bytes_list;
     std::vector<float> compute_ops_list;
-    int cur_compute_ops;
+    double cur_compute_ops;
 
     // Group 1: Computation related features
     ExtractComputationFeature(node, math_op_counter);
@@ -768,7 +768,7 @@ class PerStoreFeatureExtractor : public StmtExprVisitor {
 
   // Extract buffer access related features (group 2)
   void ExtractBufferAccessFeature(const BufferStoreNode* node, const MathOpCounter& math_op_counter,
-                                  int* cur_compute_ops, std::vector<float>* compute_ops_list,
+                                  double* cur_compute_ops, std::vector<float>* compute_ops_list,
                                   std::vector<float>* mem_bytes_list) {
     FeatureSet& fea = buffer_features[node->buffer];
 
@@ -920,7 +920,7 @@ class PerStoreFeatureExtractor : public StmtExprVisitor {
   }
 
   // Extract arithmetic intensity related feature (group 3)
-  void ExtractArithmeticIntensityFeature(const BufferStoreNode* node, int cur_compute_ops,
+  void ExtractArithmeticIntensityFeature(const BufferStoreNode* node, double cur_compute_ops,
                                          const std::vector<float>& compute_ops_list,
                                          const std::vector<float>& mem_bytes_list) {
     FeatureSet& fea = buffer_features[node->buffer];
@@ -1267,7 +1267,7 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i
   Array<te::Tensor> tensors;
 
   std::tie(sch, tensors) = task->compute_dag.ApplySteps(state->transform_steps);
-  sch = sch.normalize();
+  sch = sch.normalize_for_feature_extraction();
   auto bounds = te::InferBound(sch);
 
   try {
diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc
index 6aac3b769a47..bae8e069bcdb 100644
--- a/src/te/schedule/schedule_dataflow_rewrite.cc
+++ b/src/te/schedule/schedule_dataflow_rewrite.cc
@@ -502,7 +502,7 @@ void RebaseNonZeroMinLoop(ScheduleNode* sch) {
   }
 }
 
-void InjectInline(ScheduleNode* sch) {
+void InjectInline(ScheduleNode* sch, bool feature_extraction_mode) {
   sch->InvalidateCache();
 
   std::vector<Array<PrimExpr> > new_body(sch->stages.size());
@@ -524,7 +524,15 @@ void InjectInline(ScheduleNode* sch) {
           args.push_back(iv->var);
         }
         ICHECK_EQ(compute->body.size(), 1U) << "can only inline compute op with 1 output";
-        body = compute->body[0];
+
+        if (feature_extraction_mode && compute->attrs.count("const_matrix")) {
+          // Use constant value to replace access of const matrices.
+          // This produces wrong IR but is good enough for feature extraction purposes.
+          // This simplification can accelerate the feature extration and evolutionary search.
+          body = make_const(compute->output_dtype(0), 1.0f);
+        } else {
+          body = compute->body[0];
+        }
       }
       for (size_t j = i; j < sch->stages.size(); ++j) {
         Stage s = sch->stages[j];
@@ -700,7 +708,15 @@ void LegalizeInvalidAttach(ScheduleNode* sch) {
 
 Schedule Schedule::normalize() {
   Schedule sn = copy();
-  InjectInline(sn.operator->());
+  InjectInline(sn.operator->(), false);
+  RebaseNonZeroMinLoop(sn.operator->());
+  LegalizeInvalidAttach(sn.operator->());
+  return sn;
+}
+
+Schedule Schedule::normalize_for_feature_extraction() {
+  Schedule sn = copy();
+  InjectInline(sn.operator->(), true);
   RebaseNonZeroMinLoop(sn.operator->());
   LegalizeInvalidAttach(sn.operator->());
   return sn;
diff --git a/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py b/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
index a28e98b8792a..9aeea8487444 100644
--- a/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
+++ b/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
@@ -85,7 +85,7 @@ def conv2d_layer(N, H, W, CO, CI, KH, KW, stride, padding):
 #   during measurement and avoid other runtime conflicts.
 # * :code:`min_repeat_ms` defines the minimum duration of one "repeat" in every measurement.
 #   This can warmup the GPU, which is necessary to get accurate measurement results.
-#   Typically, we recommend a value > 300 ms.
+#   Typically, we recommend a value >= 300 ms.
 # * :code:`num_measure_trials` is the number of measurement trials we can use during the search.
 #   We only make 10 trials in this tutorial for a fast demonstration. In practice, 1000 is a
 #   good value for the search to converge. You can do more trials according to your time budget.
diff --git a/tutorials/auto_scheduler/tune_network_cuda.py b/tutorials/auto_scheduler/tune_network_cuda.py
index 723b8d15ea88..8f8cf7f1e99a 100644
--- a/tutorials/auto_scheduler/tune_network_cuda.py
+++ b/tutorials/auto_scheduler/tune_network_cuda.py
@@ -167,7 +167,7 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"):
 #   during measurement and avoid other runtime conflicts.
 # * :code:`min_repeat_ms` defines the minimum duration of one "repeat" in every measurement.
 #   This can warmup the GPU, which is necessary to get accurate measurement results.
-#   Typically, we recommend a value > 300 ms.
+#   Typically, we recommend a value >= 300 ms.
 # * :code:`num_measure_trials` is the number of measurement trials we can use during the tuning.
 #   You can set it to a small number (e.g., 200) for a fast demonstrative run.
 #   In practice, we recommend setting it around :code:`1000 * len(tasks)`,
@@ -184,7 +184,7 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"):
 
 def run_tuning():
     print("Begin tuning...")
-    measure_ctx = auto_scheduler.LocalRPCMeasureContext(repeat=1, min_repeat_ms=400, timeout=10)
+    measure_ctx = auto_scheduler.LocalRPCMeasureContext(repeat=1, min_repeat_ms=300, timeout=10)
 
     tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
     tune_option = auto_scheduler.TuningOptions(