Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
Add bootstrap Param to RandomForest
Browse files Browse the repository at this point in the history
Signed-off-by: zero323 <mszymkiewicz@gmail.com>
  • Loading branch information
zero323 committed Jan 23, 2020
1 parent d457efe commit f8f6773
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
5 changes: 3 additions & 2 deletions third_party/3/pyspark/ml/classification.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,8 @@ class DecisionTreeClassificationModel(_DecisionTreeModel, JavaProbabilisticClass
class _RandomForestClassifierParams(_RandomForestParams, _TreeClassifierParams): ...

class RandomForestClassifier(JavaProbabilisticClassifier[RandomForestClassificationModel], _RandomForestClassifierParams, JavaMLWritable, JavaMLReadable[RandomForestClassifier]):
def __init__(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., probabilityCol: str = ..., rawPredictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., impurity: str = ..., numTrees: int = ..., featureSubsetStrategy: str = ..., seed: Optional[int] = ..., subsamplingRate: float = ..., leafCol: str = ..., minWeightFractionPerNode: float = ..., weightCol: Optional[str] = ...) -> None: ...
def setParams(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., probabilityCol: str = ..., rawPredictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., seed: Optional[int] = ..., impurity: str = ..., numTrees: int = ..., featureSubsetStrategy: str = ..., subsamplingRate: float = ..., leafCol: str = ..., minWeightFractionPerNode: float = ..., weightCol: Optional[str] = ...) -> RandomForestClassifier: ...
def __init__(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., probabilityCol: str = ..., rawPredictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., impurity: str = ..., numTrees: int = ..., featureSubsetStrategy: str = ..., seed: Optional[int] = ..., subsamplingRate: float = ..., leafCol: str = ..., minWeightFractionPerNode: float = ..., weightCol: Optional[str] = ..., bootstrap: Optional[bool] = ...) -> None: ...
def setParams(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., probabilityCol: str = ..., rawPredictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., seed: Optional[int] = ..., impurity: str = ..., numTrees: int = ..., featureSubsetStrategy: str = ..., subsamplingRate: float = ..., leafCol: str = ..., minWeightFractionPerNode: float = ..., weightCol: Optional[str] = ..., bootstrap: Optional[bool] = ...) -> RandomForestClassifier: ...
def setMaxDepth(self, value: int) -> RandomForestClassifier: ...
def setMaxBins(self, value: int) -> RandomForestClassifier: ...
def setMinInstancesPerNode(self, value: int) -> RandomForestClassifier: ...
Expand All @@ -196,6 +196,7 @@ class RandomForestClassifier(JavaProbabilisticClassifier[RandomForestClassificat
def setCacheNodeIds(self, value: bool) -> RandomForestClassifier: ...
def setImpurity(self, value: str) -> RandomForestClassifier: ...
def setNumTrees(self, value: int) -> RandomForestClassifier: ...
def setBootstrap(self, value: bool) -> RandomForestClassifier: ...
def setSubsamplingRate(self, value: float) -> RandomForestClassifier: ...
def setFeatureSubsetStrategy(self, value: str) -> RandomForestClassifier: ...
def setSeed(self, value: int) -> RandomForestClassifier: ...
Expand Down
5 changes: 3 additions & 2 deletions third_party/3/pyspark/ml/regression.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ class DecisionTreeRegressionModel(_DecisionTreeModel[T], JavaMLWritable, JavaMLR
class _RandomForestRegressorParams(_RandomForestParams, _TreeRegressorParams): ...

class RandomForestRegressor(JavaPredictor[RandomForestRegressionModel], _RandomForestRegressorParams, JavaMLWritable, JavaMLReadable[RandomForestRegressor]):
def __init__(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., impurity: str = ..., subsamplingRate: float = ..., seed: Optional[int] = ..., numTrees: int = ..., featureSubsetStrategy: str = ..., minWeightFractionPerNode: float = ..., weightCol: Optional[str] = ...) -> None: ...
def setParams(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., impurity: str = ..., subsamplingRate: float = ..., seed: Optional[int] = ..., numTrees: int = ..., featureSubsetStrategy: str = ..., minWeightFractionPerNode: float = ..., weightCol: Optional[str] = ...) -> RandomForestRegressor: ...
def __init__(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., impurity: str = ..., subsamplingRate: float = ..., seed: Optional[int] = ..., numTrees: int = ..., featureSubsetStrategy: str = ..., minWeightFractionPerNode: float = ..., weightCol: Optional[str] = ..., bootstrap: Optional[bool] = ...) -> None: ...
def setParams(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., impurity: str = ..., subsamplingRate: float = ..., seed: Optional[int] = ..., numTrees: int = ..., featureSubsetStrategy: str = ..., minWeightFractionPerNode: float = ..., weightCol: Optional[str] = ..., bootstrap: Optional[bool] = ...) -> RandomForestRegressor: ...
def setMaxDepth(self, value: int) -> RandomForestRegressor: ...
def setMaxBins(self, value: int) -> RandomForestRegressor: ...
def setMinInstancesPerNode(self, value: int) -> RandomForestRegressor: ...
Expand All @@ -150,6 +150,7 @@ class RandomForestRegressor(JavaPredictor[RandomForestRegressionModel], _RandomF
def setCacheNodeIds(self, value: bool) -> RandomForestRegressor: ...
def setImpurity(self, value: str) -> RandomForestRegressor: ...
def setNumTrees(self, value: int) -> RandomForestRegressor: ...
def setBootstrap(self, value: bool) -> RandomForestRegressor: ...
def setSubsamplingRate(self, value: float) -> RandomForestRegressor: ...
def setFeatureSubsetStrategy(self, value: str) -> RandomForestRegressor: ...
def setCheckpointInterval(self, value: int) -> RandomForestRegressor: ...
Expand Down
4 changes: 3 additions & 1 deletion third_party/3/pyspark/ml/tree.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,10 @@ class _TreeEnsembleParams(_DecisionTreeParams):

class _RandomForestParams(_TreeEnsembleParams):
numTrees: Param[int]
bootstrap: Param[bool]
def __init__(self) -> None: ...
def getNumTrees(self) -> int: ...
def getBootstrap(self) -> bool: ...

class _GBTParams(_TreeEnsembleParams, HasMaxIter, HasStepSize, HasValidationIndicatorCol):
stepSize: Param[float]
Expand All @@ -70,7 +72,7 @@ class _HasVarianceImpurity(Params):
def __init__(self) -> None: ...
def getImpurity(self) -> str: ...

class _TreeClassifierParams:
class _TreeClassifierParams(Params):
supportedImpurities: List[str]
impurity: Param[str]
def __init__(self) -> None: ...
Expand Down

0 comments on commit f8f6773

Please sign in to comment.