版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u011707542/article/details/78730063
RFormula简单介绍
RFormula通过R模型公式来操作列。
支持R操作中的部分操作包括‘~’, ‘.’, ‘:’, ‘+’以及‘-‘。
1、 ~分隔目标和对象
2、 +合并对象,“+0”意味着删除空格
3、-删除一个对象,“-1”表示删除空格
4、 :交互(数值相乘,类别二值化)
5、 . 除了目标列的全部列
假设a和b为两列:
1、 y ~ a + b表示模型y ~ w0 + w1 * a +w2 * b其中w0为截距,w1和w2为相关系数
2、 y ~a + b + a:b – 1表示模型y ~ w1* a + w2 * b + w3 * a * b,其中w1,w2,w3是相关系数
RFormula产生一个向量特征列以及一个double或者字符串标签列。如果用R进行线性回归,则对String类型的输入列进行one-hot编码、对数值型的输入列进行double类型转化。如果类别列是字符串类型,它将通过StringIndexer转换为double类型。如果标签列不存在,则输出中将通过规定的响应变量创造一个标签列。
代码示例
/**
* Created by hhy on 2017/12/05.
*/
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.sql.SparkSession
object RFormulaDemo {
def main(args: Array[String]): Unit = {
val spark=SparkSession.builder().appName(" ").master("local").getOrCreate()
val dataset = spark.createDataFrame(Seq(
(7,10, "US", 18, 1.0),
(8,22, "CA", 12, 0.0),
(9,100,"CA", 15, 0.0),
(10,29,"CA", 15, 1.0),
(11,88,"CA", 15, 0.0),
(12,99,"CA", 2, 0.0)
)).toDF("id", "count","country", "hour", "clicked")
/**
*country列6个不同取值时候占了五个维度 五个不同取值时候占了四个维度
*四个不同取值时候占了三个维度 三个不同取值占了两维度 两个不同取值占
*了一个维度,另外我们还操作了非StringType类型的hour 和 count列 因此*在country列所占维度基础上 再加上两个维度,就是所形成的新列features
*该列值是一个向量 由上面组成的维度构成
*/
val formula = new RFormula()
.setFormula("clicked ~ country + hour+count")
.setFeaturesCol("features")
.setLabelCol("label")
val output = formula.fit(dataset).transform(dataset)
output.show()
//output.write.json("spark-warehouse/Rformula")
// output.select("features", "label").show()
}
}
其中,如果我们查看输出列features的值时候,是按照系数向量存储的,如果该列维度小于3(或者4 )自己试验 忘记了,,那么正常输出,比如country列只有一个数值 那么我们用0就可以表示了 构成的向量格式如下:
[0,hour列值,count列值]
[0,hour列值,count列值]
[0,hour列值,count列值]
[0,hour列值,count列值]
如果country列不止一个值 在进行onehot编码时候该列肯定不能使用一维就可以表示完所有取值,形成的新列features大于等于4时候就要输出稀疏向量形式了结果如下所示:
源码解析
class RFormula(override val uid: String)
extends Estimator[RFormulaModel] with RFormulaBase with DefaultParamsWritable {
/**
*Identifiable.randomUID("rFormula")的作用是生成一个
*以rFormula为前缀 然后加上下划线_ 然后加上
*UUID.randomUUID().toString.takeRight(12)
*12个随机十六进制字符
*具体形式: “rFormula_12个十六进制字符”
*this指向了当前RFormula 然后为其参数uid赋值
*/
def this() = this(Identifiable.randomUID("rFormula"))
/**
* RFormula参数,String类型的参数
*/
val formula: Param[String] = new Param(this, "formula", "R model formula")
/**
* 设置R公式为RFormula转换器 使用之前必须先调用这个函数设置参数
* 例如"y ~ x + z"
*/
def setFormula(value: String): this.type = set(formula, value)
/** 得到RFormula的参数使用了 ${变量} 的形式*/
def getFormula: String = $(formula)
/** 设置得到的新列的列名字 */
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
/** 针对R公式 设置label列的列名*/
def setLabelCol(value: String): this.type = set(labelCol, value)
/**
* 将label列索引化
* 一般情况我们我们只索引化字符串类型的label列
* 在分类算法中,我们设置其为true,及时该列是数值类型
* 我们也可以索引化
*/
val forceIndexLabel: BooleanParam = new BooleanParam(this, "forceIndexLabel",
"Force to index label whether it is numeric or string")
setDefault(forceIndexLabel -> false)
/**获取forceIndexLabel变量值 */
def getForceIndexLabel: Boolean = $(forceIndexLabel)
/** 设置forceIndexLabel变量的值*/
def setForceIndexLabel(value: Boolean): this.type = set(forceIndexLabel, value)
/**
*是否特殊化拟合截距
*/
private[ml] def hasIntercept: Boolean = {
require(isDefined(formula), "Formula must be defined first.")
RFormulaParser.parse($(formula)).hasIntercept
}
/**对于RFormula核心部分*/
override def fit(dataset: Dataset[_]): RFormulaModel = {
transformSchema(dataset.schema, logging = true)
require(isDefined(formula), "Formula must be defined first.")
/**解析给的R公式,返回类型ParsedRFormula
*RFormulaParser继承了RegexParsers
*/
val parsedFormula = RFormulaParser.parse($(formula))
/** 返回类型ResolvedRFormula :将RFormula terms转为列名
* 该类其中三个参数:
* label:String 列名;
* terms:Seq[Seq[String]] the simplified terms of the R formula
* hasIntercept:Boolean 是否特殊化拟合截距
*/
val resolvedFormula = parsedFormula.resolve(dataset.schema)
/**定义了一个存放PipelineStage的集合*/
val encoderStages = ArrayBuffer[PipelineStage]()
val prefixesToRewrite = mutable.Map[String, String]() //重写前缀
val tempColumns = ArrayBuffer[String]()
def tmpColumn(category: String): String = {
val col = Identifiable.randomUID(category) // 返回String类型 字符串形式 category_12个十六进制字符 然后存储到tempColumns中
tempColumns += col
col
}
// First we index each string column referenced by the input terms.将string类型的列索引化
val indexed: Map[String, String] = resolvedFormula.terms.flatten.distinct.map { term =>
dataset.schema(term) match {
case column if column.dataType == StringType =>
val indexCol = tmpColumn("stridx") //设置每个列操作后的输出列名 stridx_12个16进制字符 组成的字符串
encoderStages += new StringIndexer() // encoderStages是一个集合 存放了PipelineStage主要就是对每一个string类型列创建了一个StringIndexer操作
.setInputCol(term)
.setOutputCol(indexCol)
//替换一下前缀标识
prefixesToRewrite(indexCol + "_") = term + "_"
(term, indexCol)
case _ =>
(term, term)
}
}.toMap
// Then we handle one-hot encoding and interactions between terms.
val encodedTerms = resolvedFormula.terms.map {
case Seq(term) if dataset.schema(term).dataType == StringType =>
val encodedCol = tmpColumn("onehot")
encoderStages += new OneHotEncoder() //对每个string类型的列 继续增加PipelineStage操作 这个才做是onehot的操作 000 001 010 100 可以表示四个不同值不用110 011类似的
.setInputCol(indexed(term))
.setOutputCol(encodedCol)
prefixesToRewrite(encodedCol + "_") = term + "_"
encodedCol
case Seq(term) =>
term
case terms =>
val interactionCol = tmpColumn("interaction")
encoderStages += new Interaction()
.setInputCols(terms.map(indexed).toArray)
.setOutputCol(interactionCol)
prefixesToRewrite(interactionCol + "_") = ""
interactionCol
}
encoderStages += new VectorAssembler(uid) //继续添加通道 操作是将若干个列向量合并为一列 设置输入列 输出列参数
.setInputCols(encodedTerms.toArray)
.setOutputCol($(featuresCol))
encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap)//通过前缀替换重写向量属性名字这里是将StringIndexer操作输出的strid_ 以及onehot操作命名的onehot_前缀替换掉 统一改为了term_ term是R公式的simplified terms
encoderStages += new ColumnPruner(tempColumns.toSet) //移除临时列
/**如果数据集中包含了给出的R公式中的label列,并且该列数据类型是String类型,或者我们设置的参数变量forceIndexLabel为true那么将label列使用StringIndexer索引化,此处并未执行,只是将这个转换器放到了 存储Piplinestage的encoderStages集合中了 */
if ((dataset.schema.fieldNames.contains(resolvedFormula.label) &&
dataset.schema(resolvedFormula.label).dataType == StringType) || $(forceIndexLabel)) {
encoderStages += new StringIndexer()
.setInputCol(resolvedFormula.label)
.setOutputCol($(labelCol))
}
/**调用fit执行encoderStages里面的PipelineStage,Pipeline的参数uid等于RFormula的参数uid*/
val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)
copyValues(new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(this))
}
// optimistic schema; does not contain any ML attributes
override def transformSchema(schema: StructType): StructType = {
require(!hasLabelCol(schema) || !$(forceIndexLabel),
"If label column already exists, forceIndexLabel can not be set with true.")
if (hasLabelCol(schema)) {
StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true))
} else {
StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true) :+
StructField($(labelCol), DoubleType, true))
}
}
@Since("1.5.0")
override def copy(extra: ParamMap): RFormula = defaultCopy(extra)
@Since("2.0.0")
override def toString: String = s"RFormula(${get(formula).getOrElse("")}) (uid=$uid)"
}
val resolvedFormula = parsedFormula.resolve(dataset.schema)代码中ParsedRFormula类的resolve()函数的源码如下:
private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
/**
* Resolves formula terms into column names. A schema is necessary for inferring the meaning
* of the special '.' term. Duplicate terms will be removed during resolution.
*/
def resolve(schema: StructType): ResolvedRFormula = {
val dotTerms = expandDot(schema)
var includedTerms = Seq[Seq[String]]()
terms.foreach {
case col: ColumnRef =>
includedTerms :+= Seq(col.value)
case ColumnInteraction(cols) =>
includedTerms ++= expandInteraction(schema, cols)
case Dot =>
includedTerms ++= dotTerms.map(Seq(_))
case Deletion(term: Term) =>
term match {
case inner: ColumnRef =>
includedTerms = includedTerms.filter(_ != Seq(inner.value))
case ColumnInteraction(cols) =>
val fromInteraction = expandInteraction(schema, cols).map(_.toSet)
includedTerms = includedTerms.filter(t => !fromInteraction.contains(t.toSet))
case Dot =>
// e.g. "- .", which removes all first-order terms
includedTerms = includedTerms.filter {
case Seq(t) => !dotTerms.contains(t)
case _ => true
}
case _: Deletion =>
throw new RuntimeException("Deletion terms cannot be nested")
case _: Intercept =>
}
case _: Intercept =>
}
ResolvedRFormula(label.value, includedTerms.distinct, hasIntercept)
}