Skip to content

Commit dceecc3

Browse files
committed
Better type inference for V<n> pattern synonyms
Previously, the V<n> pattern synonyms (V2, V4, etc.) had suboptimal typing in two different ways: 1. They did not force the type constructor of the constructed vector to be `Vec`, hence the following type annotation was required absent external type signatures: > generate (I1 10000) (\(I1 idx) -> V2 idx (idx + 1) :: Exp (Vec _ _)) 2. The IsVector instances required it to be known that the types of each of the arguments of V<n> are equal before the instance could be selected. This means that these type annotations were also required: > fold1 (\(V2 a b) (V2 c d) -> V2 (a + c :: Exp Int) (b * d :: Exp Int)) $ > generate _ _ This commit fixes both issues, meaning that the following code is now accepted in ghci: > fold1 (\(V2 a b) (V2 c d) -> V2 (a + c) (b * d)) $ > generate (I1 10000) (\(I1 idx) -> V2 idx (idx + 1))
1 parent 3f681a5 commit dceecc3

2 files changed

Lines changed: 27 additions & 13 deletions

File tree

src/Data/Array/Accelerate/Data/Complex.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ constructComplex :: forall a. Elt a => Exp a -> Exp a -> Exp (Complex a)
178178
constructComplex r i =
179179
case complexR (eltR @a) of
180180
ComplexTup -> coerce $ T2 r i
181-
ComplexVec _ -> V2 (coerce @a @(EltR a) r) (coerce @a @(EltR a) i)
181+
ComplexVec _ -> coerce $ V2 (coerce @a @(EltR a) r) (coerce @a @(EltR a) i)
182182

183183
deconstructComplex :: forall a. Elt a => Exp (Complex a) -> (Exp a, Exp a)
184184
deconstructComplex c@(Exp c') =

src/Data/Array/Accelerate/Pattern.hs

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -188,22 +188,35 @@ runQ $ do
188188
|]
189189

190190
-- Generate instance declarations for IsVector of the form:
191-
-- instance (Elt v, EltR v ~ Vec 2 a, Elt a) => IsVector Exp v (Exp a, Exp a)
191+
-- instance (Elt v, EltR v ~ Vec 2 a1, Elt a1, a1 ~ a2) => IsVector Exp v (Exp a1, Exp a2)
192+
-- The element type is `a1`; we leave the element types for each of the
193+
-- tuple components as separate type variables in the instance head
194+
-- (only to equate them in the instance context) so that type inference
195+
-- can already select this instance before it is known that the tuple
196+
-- elements are indeed homogeneously typed.
192197
mkVecPattern :: Int -> Q [Dec]
198+
mkVecPattern n | n <= 0 = error "mkVecPattern: must be > 0"
193199
mkVecPattern n = do
194-
a <- newName "a"
200+
as@(a1 : a2s) <- mapM (\i -> newName ("a" ++ show i)) [1 .. n]
195201
v <- newName "v"
196202
let
197-
-- Last argument to `IsVector`, eg (Exp, a, Exp a) in the example
198-
tup = tupT (replicate n ([t| Exp $(varT a)|]))
199-
-- Representation as a vector, eg (Vec 2 a)
200-
vec = [t| Vec $(litT (numTyLit (fromIntegral n))) $(varT a) |]
201-
-- Constraints for the type class, consisting of Elt constraints on all type variables,
202-
-- and an equality constraint on the representation type of `a` and the vector representation `vec`.
203-
context = [t| (Elt $(varT v), VecElt $(varT a), EltR $(varT v) ~ $vec) |]
203+
-- Last argument to `IsVector`, eg (Exp a1, Exp a2) in the example
204+
tup = tupT [ [t| Exp $(varT ai) |] | ai <- as ]
205+
-- Representation as a vector, eg (Vec 2 a1)
206+
vec = [t| Vec $(litT (numTyLit (fromIntegral n))) $(varT a1) |]
207+
-- Constraints for the type class, consisting of:
208+
-- - Elt constraints on all type variables;
209+
-- - an equality constraint on the representation type of `a` and
210+
-- the vector representation `vec`;
211+
-- - equality constraints for the `ai` variables.
212+
context = tupT $ [t| Elt $(varT v) |]
213+
: [t| VecElt $(varT a1) |]
214+
: [t| EltR $(varT v) ~ $vec |]
215+
: [ [t| $(varT a1) ~ $(varT ai) |] | ai <- a2s]
204216
--
205-
vecR = foldr appE ([| VecRnil |] `appE` (varE 'singleType `appTypeE` varT a)) (replicate n [| VecRsucc |])
206-
tR = tupT (replicate n (varT a))
217+
vecR = iterate ([| VecRsucc |] `appE`) [| VecRnil (singleType @($(varT a1))) |]
218+
Prelude.!! n
219+
tR = tupT (replicate n (varT a1))
207220
--
208221
[d| instance $context => IsVector Exp $(varT v) $tup where
209222
vpack x = case builder x :: Exp $tR of
@@ -273,12 +286,13 @@ runQ $ do
273286
]
274287

275288
mkV :: Int -> Q [Dec]
289+
mkV n | n <= 0 = error "mkV: must be > 0"
276290
mkV n =
277291
let xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ]
278292
ts = map varT xs
279293
name = mkName ('V':show n)
280294
con = varT (mkName "con")
281-
ty1 = varT (mkName "vec")
295+
ty1 = [t| Vec $(litT (numTyLit (fromIntegral n))) $(varT (head xs)) |]
282296
ty2 = tupT (map (con `appT`) ts)
283297
sig = foldr (\t r -> [t| $con $t -> $r |]) (appT con ty1) ts
284298
in

0 commit comments

Comments
 (0)