Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ val scalatest = "3.0.1"
val shapeless = "2.3.2"
val scalacheck = "1.13.4"

// spark has scalatest and scalactic as a runtime dependency
// which can mess things up if you use a different version in your project
val exclusions = Seq(ExclusionRule("org.scalatest"), ExclusionRule("org.scalactic"))
Copy link
Copy Markdown
Contributor

@imarios imarios Feb 20, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this fix the IntelliJ issue with ScalaTest?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep!


lazy val root = Project("frameless", file("." + "frameless")).in(file("."))
.aggregate(core, cats, dataset, docs)
.settings(framelessSettings: _*)
Expand All @@ -22,7 +26,7 @@ lazy val cats = project
.settings(publishSettings: _*)
.settings(libraryDependencies ++= Seq(
"org.typelevel" %% "cats" % catsv,
"org.apache.spark" %% "spark-core" % sparkVersion % "provided"))
"org.apache.spark" %% "spark-core" % sparkVersion % "provided" excludeAll(exclusions: _*)))

lazy val dataset = project
.settings(name := "frameless-dataset")
Expand All @@ -31,8 +35,8 @@ lazy val dataset = project
.settings(framelessTypedDatasetREPL: _*)
.settings(publishSettings: _*)
.settings(libraryDependencies ++= Seq(
"org.apache.spark" %% "spark-core" % sparkVersion % "provided",
"org.apache.spark" %% "spark-sql" % sparkVersion % "provided"
"org.apache.spark" %% "spark-core" % sparkVersion % "provided" excludeAll(exclusions: _*),
"org.apache.spark" %% "spark-sql" % sparkVersion % "provided" excludeAll(exclusions: _*)
))
.dependsOn(core % "test->test;compile->compile")

Expand Down
3 changes: 2 additions & 1 deletion dataset/src/main/scala/frameless/TypedColumn.scala
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ object TypedColumn {
lgen: LabelledGeneric.Aux[T, H],
selector: Selector.Aux[H, K, V]
): Exists[T, K, V] = new Exists[T, K, V] {}

}

implicit class OrderedTypedColumnSyntax[T, U: CatalystOrdered](col: TypedColumn[T, U]) {
Expand All @@ -279,4 +280,4 @@ object TypedColumn {
def >(other: TypedColumn[T, U]): TypedColumn[T, Boolean] = (col.untyped > other.untyped).typed
def >=(other: TypedColumn[T, U]): TypedColumn[T, Boolean] = (col.untyped >= other.untyped).typed
}
}
}
9 changes: 5 additions & 4 deletions dataset/src/main/scala/frameless/TypedDataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,9 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
*
* It is statically checked that column with such name exists and has type `A`.
*/
def apply[A](column: Witness.Lt[Symbol])(
implicit
exists: TypedColumn.Exists[T, column.T, A],
def apply[A](selector: T => A)(implicit
encoder: TypedEncoder[A]
): TypedColumn[T, A] = col(column)
): TypedColumn[T, A] = macro frameless.macros.ColumnMacros.fromFunction[T, A]

/** Returns `TypedColumn` of type `A` given its name.
*
Expand Down Expand Up @@ -319,6 +317,9 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
}
}

def selectExpr[B](expr: T => B)(implicit encoder: TypedEncoder[B]): TypedDataset[B] =
macro frameless.macros.ColumnMacros.fromExpr[T, B]

/** Type-safe projection from type T to Tuple2[A,B]
* {{{
* d.select( d('a), d('a)+d('b), ... )
Expand Down
88 changes: 78 additions & 10 deletions dataset/src/main/scala/frameless/TypedEncoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import shapeless._

import scala.reflect.ClassTag

abstract class TypedEncoder[T](implicit val classTag: ClassTag[T]) extends Serializable {
Expand Down Expand Up @@ -264,15 +265,30 @@ object TypedEncoder {
def targetDataType: DataType = DataTypes.createArrayType(underlying.targetDataType)

def constructorFor(path: Expression): Expression = {
val arrayData = Invoke(
MapObjects(
underlying.constructorFor,
path,
underlying.targetDataType
),
"array",
ScalaReflection.dataTypeFor[Array[AnyRef]]
)
val arrayData = Option(underlying.sourceDataType)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoids boxing the primitives in the array

.filter(ScalaReflection.isNativeType)
.filter(_ == underlying.targetDataType)
.collect {
case BooleanType => "toBooleanArray" -> ScalaReflection.dataTypeFor[Array[Boolean]]
case ByteType => "toByteArray" -> ScalaReflection.dataTypeFor[Array[Byte]]
case ShortType => "toShortArray" -> ScalaReflection.dataTypeFor[Array[Short]]
case IntegerType => "toIntArray" -> ScalaReflection.dataTypeFor[Array[Int]]
case LongType => "toLongArray" -> ScalaReflection.dataTypeFor[Array[Long]]
case FloatType => "toFloatArray" -> ScalaReflection.dataTypeFor[Array[Float]]
case DoubleType => "toDoubleArray" -> ScalaReflection.dataTypeFor[Array[Double]]
}.map {
case (method, typ) => Invoke(path, method, typ)
}.getOrElse {
Invoke(
MapObjects(
underlying.constructorFor,
path,
underlying.targetDataType
),
"array",
ScalaReflection.dataTypeFor[Array[AnyRef]]
)
}

StaticInvoke(
TypedEncoderUtils.getClass,
Expand All @@ -296,6 +312,58 @@ object TypedEncoder {
}
}

implicit def arrayEncoder[A](
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New encoder for Array, could be a separate PR but I needed it for this branch

implicit
underlying: TypedEncoder[A],
classTag: ClassTag[Array[A]]
): TypedEncoder[Array[A]] = new TypedEncoder[Array[A]]() {
def nullable: Boolean = false

def sourceDataType: DataType = FramelessInternals.objectTypeFor[Array[A]](classTag)

def targetDataType: DataType = DataTypes.createArrayType(underlying.targetDataType)

def constructorFor(path: Expression): Expression = {
Option(underlying.sourceDataType)
.filter(ScalaReflection.isNativeType)
.filter(_ == underlying.targetDataType)
.collect {
case BooleanType => "toBooleanArray" -> ScalaReflection.dataTypeFor[Array[Boolean]]
case ByteType => "toByteArray" -> ScalaReflection.dataTypeFor[Array[Byte]]
case ShortType => "toShortArray" -> ScalaReflection.dataTypeFor[Array[Short]]
case IntegerType => "toIntArray" -> ScalaReflection.dataTypeFor[Array[Int]]
case LongType => "toLongArray" -> ScalaReflection.dataTypeFor[Array[Long]]
case FloatType => "toFloatArray" -> ScalaReflection.dataTypeFor[Array[Float]]
case DoubleType => "toDoubleArray" -> ScalaReflection.dataTypeFor[Array[Double]]
}.map {
case (method, typ) => Invoke(path, method, typ)
}.getOrElse {
Invoke(
MapObjects(
underlying.constructorFor,
path,
underlying.targetDataType
),
"array",
ScalaReflection.dataTypeFor[Array[AnyRef]]
)
}
}

def extractorFor(path: Expression): Expression = {
// if source `path` is already native for Spark, no need to `map`
if (ScalaReflection.isNativeType(underlying.sourceDataType)) {
NewInstance(
classOf[GenericArrayData],
path :: Nil,
dataType = ArrayType(underlying.targetDataType, underlying.nullable)
)
} else {
MapObjects(underlying.extractorFor, path, underlying.sourceDataType)
}
}
}

/** Encodes things using injection if there is one defined */
implicit def usingInjection[A: ClassTag, B]
(implicit inj: Injection[A, B], trb: TypedEncoder[B]): TypedEncoder[A] =
Expand All @@ -322,4 +390,4 @@ object TypedEncoder {
recordEncoder: Lazy[RecordEncoderFields[G]],
classTag: ClassTag[F]
): TypedEncoder[F] = new RecordEncoder[F, G]
}
}
2 changes: 2 additions & 0 deletions dataset/src/main/scala/frameless/functions/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@ package frameless

package object functions extends Udf {
object aggregate extends AggregateFunctions


}
Loading