Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions ml/src/main/scala/frameless/ml/feature/TypedOneHotEncoder.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package frameless
Comment thread
cchantep marked this conversation as resolved.
Outdated
package ml
package feature

import frameless.ml.feature.TypedOneHotEncoder.HandleInvalid
import frameless.ml.internals.UnaryInputsChecker
import org.apache.spark.ml.Estimator
import org.apache.spark.ml.feature.{OneHotEncoderEstimator, OneHotEncoderModel}
import org.apache.spark.ml.linalg.Vector

/**
* A one-hot encoder that maps a column of category indices to a column of binary vectors, with
* at most a single one-value per row that indicates the input category index.
*
* @see `TypedStringIndexer` for converting categorical values into category indices
*/
class TypedOneHotEncoder[Inputs] private[ml](oneHotEncoder: OneHotEncoderEstimator, inputCol: String)
extends TypedEstimator[Inputs, TypedOneHotEncoder.Outputs, OneHotEncoderModel] {

override val estimator: Estimator[OneHotEncoderModel] = oneHotEncoder
.setInputCols(Array(inputCol))
.setOutputCols(Array(AppendTransformer.tempColumnName))

def setHandleInvalid(value: HandleInvalid): TypedOneHotEncoder[Inputs] =
copy(oneHotEncoder.setHandleInvalid(value.sparkValue))

def setDropLast(value: Boolean): TypedOneHotEncoder[Inputs] =
copy(oneHotEncoder.setDropLast(value))

private def copy(newOneHotEncoder: OneHotEncoderEstimator): TypedOneHotEncoder[Inputs] =
new TypedOneHotEncoder[Inputs](newOneHotEncoder, inputCol)
}

object TypedOneHotEncoder {

case class Outputs(output: Vector)

sealed abstract class HandleInvalid(val sparkValue: String)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
sealed abstract class HandleInvalid(val sparkValue: String)
final class HandleInvalid private(val sparkValue: String) extends AnyVal

object HandleInvalid {
case object Error extends HandleInvalid("error")
case object Keep extends HandleInvalid("keep")
}

def apply[Inputs](implicit inputsChecker: UnaryInputsChecker[Inputs, Int]): TypedOneHotEncoder[Inputs] = {
Comment thread
cchantep marked this conversation as resolved.
Outdated
new TypedOneHotEncoder[Inputs](new OneHotEncoderEstimator(), inputsChecker.inputCol)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package frameless
Comment thread
cchantep marked this conversation as resolved.
Outdated
package ml
package feature

import frameless.ml.feature.TypedOneHotEncoder.HandleInvalid
import org.apache.spark.ml.linalg._
import org.scalacheck.{Arbitrary, Gen}
import org.scalacheck.Prop._
import shapeless.test.illTyped

class TypedOneHotEncoderTests extends FramelessMlSuite {

test(".fit() returns a correct TypedTransformer") {
implicit val arbInt = Arbitrary(Gen.choose(0, 99))
def prop[A: TypedEncoder : Arbitrary] = forAll { (x2: X2[Int, A], dropLast: Boolean) =>
val encoder = TypedOneHotEncoder[X1[Int]].setDropLast(dropLast)
val inputs = 0.to(x2.a).map(i => X2(i, x2.b))
val ds = TypedDataset.create(inputs)
val model = encoder.fit(ds).run()
val resultDs = model.transform(TypedDataset.create(Seq(x2))).as[X3[Int, A, Vector]]
val result = resultDs.collect.run()
if (dropLast) {
result == Seq (X3(x2.a, x2.b,
Vectors.sparse(x2.a, Array.emptyIntArray, Array.emptyDoubleArray)))
} else {
result == Seq (X3(x2.a, x2.b,
Vectors.sparse(x2.a + 1, Array(x2.a), Array(1.0))))
}
}

check(prop[Double])
check(prop[String])
}

test("param setting is retained") {
implicit val arbHandleInvalid: Arbitrary[HandleInvalid] = Arbitrary {
Gen.oneOf(HandleInvalid.Keep, HandleInvalid.Error)
}

val prop = forAll { handleInvalid: HandleInvalid =>
val encoder = TypedOneHotEncoder[X1[Int]]
.setHandleInvalid(handleInvalid)
val ds = TypedDataset.create(Seq(X1(1)))
val model = encoder.fit(ds).run()

model.transformer.getHandleInvalid == handleInvalid.sparkValue
}

check(prop)
}

test("create() compiles only with correct inputs") {
illTyped("TypedOneHotEncoder.create[Double]()")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you meant to use .apply[] instead of .create[]?

illTyped("TypedOneHotEncoder.create[X1[Double]]()")
illTyped("TypedOneHotEncoder.create[X2[String, Long]]()")
}
}