diff --git a/accelerate.cabal b/accelerate.cabal index 118210a38..f8eae9527 100644 --- a/accelerate.cabal +++ b/accelerate.cabal @@ -361,6 +361,10 @@ library , unique , unordered-containers >= 0.2 , vector >= 0.10 + , posable >= 1.0.0.1 + , ghc-typelits-knownnat >= 0.6 + , generics-sop >= 0.4.0 + , finite-typelits >= 0.1.4 exposed-modules: -- The core language and reference implementation @@ -400,6 +404,7 @@ library Data.Array.Accelerate.Lifetime Data.Array.Accelerate.Pretty Data.Array.Accelerate.Representation.Array + Data.Array.Accelerate.Representation.POS Data.Array.Accelerate.Representation.Elt Data.Array.Accelerate.Representation.Shape Data.Array.Accelerate.Representation.Slice @@ -463,6 +468,7 @@ library Data.Array.Accelerate.Lift Data.Array.Accelerate.Orphans Data.Array.Accelerate.Pattern + Data.Array.Accelerate.Pattern.Matchable Data.Array.Accelerate.Pattern.Bool Data.Array.Accelerate.Pattern.Either Data.Array.Accelerate.Pattern.Maybe diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index a6d0f75f7..bdd550562 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -9,6 +9,7 @@ {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeApplications #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.AST @@ -146,6 +147,7 @@ import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Representation.Vec import Data.Array.Accelerate.Sugar.Foreign +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Type import Data.Primitive.Vec @@ -198,9 +200,8 @@ type ArrayVar = Var ArrayR type ArrayVars aenv = Vars ArrayR aenv -- Bool is not a primitive type -type PrimBool = TAG -type PrimMaybe a = (TAG, ((), a)) - +type PrimBool = EltR Bool +type PrimMaybe a = EltR (Maybe a) -- Trace messages data Message a where Message :: (a -> String) -- embedded show @@ -940,8 +941,10 @@ primFunType = \case integral = num . IntegralNumType floating = num . FloatingNumType - tbool = TupRsingle scalarTypeWord8 - tint = TupRsingle scalarTypeInt + tbool :: TypeR PrimBool + tbool = TupRpair (TupRsingle (scalarType @TAG)) TupRunit + tint :: TypeR Int + tint = TupRsingle (scalarType @Int) -- Normal form data diff --git a/src/Data/Array/Accelerate/Pattern.hs b/src/Data/Array/Accelerate/Pattern.hs index e212c0869..dffe94a42 100644 --- a/src/Data/Array/Accelerate/Pattern.hs +++ b/src/Data/Array/Accelerate/Pattern.hs @@ -114,107 +114,107 @@ instance (Elt a, Elt b) => IsPattern Exp (a :. b) (Exp a :. Exp b) where -- IsPattern instances for up to 16-tuples (Acc and Exp). TH takes care of -- the (unremarkable) boilerplate for us. -- -runQ $ do - let - -- Generate instance declarations for IsPattern of the form: - -- instance (Arrays x, ArraysR x ~ (((), ArraysR a), ArraysR b), Arrays a, Arrays b,) => IsPattern Acc x (Acc a, Acc b) - mkAccPattern :: Int -> Q [Dec] - mkAccPattern n = do - a <- newName "a" - let - -- Type variables for the elements - xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] - -- Last argument to `IsPattern`, eg (Acc a, Acc b) in the example - b = tupT (map (\t -> [t| Acc $(varT t)|]) xs) - -- Representation as snoc-list of pairs, eg (((), ArraysR a), ArraysR b) - snoc = foldl (\sn t -> [t| ($sn, ArraysR $(varT t)) |]) [t| () |] xs - -- Constraints for the type class, consisting of Arrays constraints on all type variables, - -- and an equality constraint on the representation type of `a` and the snoc representation `snoc`. - context = tupT - $ [t| Arrays $(varT a) |] - : [t| ArraysR $(varT a) ~ $snoc |] - : map (\t -> [t| Arrays $(varT t)|]) xs - -- - get x 0 = [| Acc (SmartAcc (Aprj PairIdxRight $x)) |] - get x i = get [| SmartAcc (Aprj PairIdxLeft $x) |] (i-1) - -- - _x <- newName "_x" - [d| instance $context => IsPattern Acc $(varT a) $b where - builder $(tupP (map (\x -> [p| Acc $(varP x)|]) xs)) = - Acc $(foldl (\vs v -> [| SmartAcc ($vs `Apair` $(varE v)) |]) [| SmartAcc Anil |] xs) - matcher (Acc $(varP _x)) = - $(tupE (map (get (varE _x)) [(n-1), (n-2) .. 0])) - |] +-- runQ $ do +-- let +-- -- Generate instance declarations for IsPattern of the form: +-- -- instance (Arrays x, ArraysR x ~ (((), ArraysR a), ArraysR b), Arrays a, Arrays b,) => IsPattern Acc x (Acc a, Acc b) +-- mkAccPattern :: Int -> Q [Dec] +-- mkAccPattern n = do +-- a <- newName "a" +-- let +-- -- Type variables for the elements +-- xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] +-- -- Last argument to `IsPattern`, eg (Acc a, Acc b) in the example +-- b = tupT (map (\t -> [t| Acc $(varT t)|]) xs) +-- -- Representation as snoc-list of pairs, eg (((), ArraysR a), ArraysR b) +-- snoc = foldl (\sn t -> [t| ($sn, ArraysR $(varT t)) |]) [t| () |] xs +-- -- Constraints for the type class, consisting of Arrays constraints on all type variables, +-- -- and an equality constraint on the representation type of `a` and the snoc representation `snoc`. +-- context = tupT +-- $ [t| Arrays $(varT a) |] +-- : [t| ArraysR $(varT a) ~ $snoc |] +-- : map (\t -> [t| Arrays $(varT t)|]) xs +-- -- +-- get x 0 = [| Acc (SmartAcc (Aprj PairIdxRight $x)) |] +-- get x i = get [| SmartAcc (Aprj PairIdxLeft $x) |] (i-1) +-- -- +-- _x <- newName "_x" +-- [d| instance $context => IsPattern Acc $(varT a) $b where +-- builder $(tupP (map (\x -> [p| Acc $(varP x)|]) xs)) = +-- Acc $(foldl (\vs v -> [| SmartAcc ($vs `Apair` $(varE v)) |]) [| SmartAcc Anil |] xs) +-- matcher (Acc $(varP _x)) = +-- $(tupE (map (get (varE _x)) [(n-1), (n-2) .. 0])) +-- |] - -- Generate instance declarations for IsPattern of the form: - -- instance (Elt x, EltR x ~ (((), EltR a), EltR b), Elt a, Elt b,) => IsPattern Exp x (Exp a, Exp b) - mkExpPattern :: Int -> Q [Dec] - mkExpPattern n = do - a <- newName "a" - let - -- Type variables for the elements - xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] - -- Variables for sub-pattern matches - ms = [ mkName ('m' : show i) | i <- [0 .. n-1] ] - tags = foldl (\ts t -> [p| $ts `TagRpair` $(varP t) |]) [p| TagRunit |] ms - -- Last argument to `IsPattern`, eg (Exp, a, Exp b) in the example - b = tupT (map (\t -> [t| Exp $(varT t)|]) xs) - -- Representation as snoc-list of pairs, eg (((), EltR a), EltR b) - snoc = foldl (\sn t -> [t| ($sn, EltR $(varT t)) |]) [t| () |] xs - -- Constraints for the type class, consisting of Elt constraints on all type variables, - -- and an equality constraint on the representation type of `a` and the snoc representation `snoc`. - context = tupT - $ [t| Elt $(varT a) |] - : [t| EltR $(varT a) ~ $snoc |] - : map (\t -> [t| Elt $(varT t)|]) xs - -- - get x 0 = [| SmartExp (Prj PairIdxRight $x) |] - get x i = get [| SmartExp (Prj PairIdxLeft $x) |] (i-1) - -- - _x <- newName "_x" - _y <- newName "_y" - [d| instance $context => IsPattern Exp $(varT a) $b where - builder $(tupP (map (\x -> [p| Exp $(varP x)|]) xs)) = - let _unmatch :: SmartExp a -> SmartExp a - _unmatch (SmartExp (Match _ $(varP _y))) = $(varE _y) - _unmatch x = x - in - Exp $(foldl (\vs v -> [| SmartExp ($vs `Pair` _unmatch $(varE v)) |]) [| SmartExp Nil |] xs) - matcher (Exp $(varP _x)) = - case $(varE _x) of - SmartExp (Match $tags $(varP _y)) - -> $(tupE [[| Exp (SmartExp (Match $(varE m) $(get (varE _x) i))) |] | m <- ms | i <- [(n-1), (n-2) .. 0]]) - _ -> $(tupE [[| Exp $(get (varE _x) i) |] | i <- [(n-1), (n-2) .. 0]]) - |] +-- -- Generate instance declarations for IsPattern of the form: +-- -- instance (Elt x, EltR x ~ (((), EltR a), EltR b), Elt a, Elt b,) => IsPattern Exp x (Exp a, Exp b) +-- mkExpPattern :: Int -> Q [Dec] +-- mkExpPattern n = do +-- a <- newName "a" +-- let +-- -- Type variables for the elements +-- xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] +-- -- Variables for sub-pattern matches +-- ms = [ mkName ('m' : show i) | i <- [0 .. n-1] ] +-- tags = foldl (\ts t -> [p| $ts `TagRpair` $(varP t) |]) [p| TagRunit |] ms +-- -- Last argument to `IsPattern`, eg (Exp, a, Exp b) in the example +-- b = tupT (map (\t -> [t| Exp $(varT t)|]) xs) +-- -- Representation as snoc-list of pairs, eg (((), EltR a), EltR b) +-- snoc = foldl (\sn t -> [t| ($sn, EltR $(varT t)) |]) [t| () |] xs +-- -- Constraints for the type class, consisting of Elt constraints on all type variables, +-- -- and an equality constraint on the representation type of `a` and the snoc representation `snoc`. +-- context = tupT +-- $ [t| Elt $(varT a) |] +-- : [t| EltR $(varT a) ~ $snoc |] +-- : map (\t -> [t| Elt $(varT t)|]) xs +-- -- +-- get x 0 = [| SmartExp (Prj PairIdxRight $x) |] +-- get x i = get [| SmartExp (Prj PairIdxLeft $x) |] (i-1) +-- -- +-- _x <- newName "_x" +-- _y <- newName "_y" +-- [d| instance $context => IsPattern Exp $(varT a) $b where +-- builder $(tupP (map (\x -> [p| Exp $(varP x)|]) xs)) = +-- let _unmatch :: SmartExp a -> SmartExp a +-- _unmatch (SmartExp (Match _ $(varP _y))) = $(varE _y) +-- _unmatch x = x +-- in +-- Exp $(foldl (\vs v -> [| SmartExp ($vs `Pair` _unmatch $(varE v)) |]) [| SmartExp Nil |] xs) +-- matcher (Exp $(varP _x)) = +-- case $(varE _x) of +-- SmartExp (Match $tags $(varP _y)) +-- -> $(tupE [[| Exp (SmartExp (Match $(varE m) $(get (varE _x) i))) |] | m <- ms | i <- [(n-1), (n-2) .. 0]]) +-- _ -> $(tupE [[| Exp $(get (varE _x) i) |] | i <- [(n-1), (n-2) .. 0]]) +-- |] - -- Generate instance declarations for IsVector of the form: - -- instance (Elt v, EltR v ~ Vec 2 a, Elt a) => IsVector Exp v (Exp a, Exp a) - mkVecPattern :: Int -> Q [Dec] - mkVecPattern n = do - a <- newName "a" - v <- newName "v" - let - -- Last argument to `IsVector`, eg (Exp, a, Exp a) in the example - tup = tupT (replicate n ([t| Exp $(varT a)|])) - -- Representation as a vector, eg (Vec 2 a) - vec = [t| Vec $(litT (numTyLit (fromIntegral n))) $(varT a) |] - -- Constraints for the type class, consisting of Elt constraints on all type variables, - -- and an equality constraint on the representation type of `a` and the vector representation `vec`. - context = [t| (Elt $(varT v), VecElt $(varT a), EltR $(varT v) ~ $vec) |] - -- - vecR = foldr appE ([| VecRnil |] `appE` (varE 'singleType `appTypeE` varT a)) (replicate n [| VecRsucc |]) - tR = tupT (replicate n (varT a)) - -- - [d| instance $context => IsVector Exp $(varT v) $tup where - vpack x = case builder x :: Exp $tR of - Exp x' -> Exp (SmartExp (VecPack $vecR x')) - vunpack (Exp x) = matcher (Exp (SmartExp (VecUnpack $vecR x)) :: Exp $tR) - |] - -- - es <- mapM mkExpPattern [0..16] - as <- mapM mkAccPattern [0..16] - vs <- mapM mkVecPattern [2,3,4,8,16] - return $ concat (es ++ as ++ vs) +-- -- Generate instance declarations for IsVector of the form: +-- -- instance (Elt v, EltR v ~ Vec 2 a, Elt a) => IsVector Exp v (Exp a, Exp a) +-- mkVecPattern :: Int -> Q [Dec] +-- mkVecPattern n = do +-- a <- newName "a" +-- v <- newName "v" +-- let +-- -- Last argument to `IsVector`, eg (Exp, a, Exp a) in the example +-- tup = tupT (replicate n ([t| Exp $(varT a)|])) +-- -- Representation as a vector, eg (Vec 2 a) +-- vec = [t| Vec $(litT (numTyLit (fromIntegral n))) $(varT a) |] +-- -- Constraints for the type class, consisting of Elt constraints on all type variables, +-- -- and an equality constraint on the representation type of `a` and the vector representation `vec`. +-- context = [t| (Elt $(varT v), VecElt $(varT a), EltR $(varT v) ~ $vec) |] +-- -- +-- vecR = foldr appE ([| VecRnil |] `appE` (varE 'singleType `appTypeE` varT a)) (replicate n [| VecRsucc |]) +-- tR = tupT (replicate n (varT a)) +-- -- +-- [d| instance $context => IsVector Exp $(varT v) $tup where +-- vpack x = case builder x :: Exp $tR of +-- Exp x' -> Exp (SmartExp (VecPack $vecR x')) +-- vunpack (Exp x) = matcher (Exp (SmartExp (VecUnpack $vecR x)) :: Exp $tR) +-- |] +-- -- +-- es <- mapM mkExpPattern [0..16] +-- as <- mapM mkAccPattern [0..16] +-- vs <- mapM mkVecPattern [2,3,4,8,16] +-- return $ concat (es ++ as ++ vs) -- | Specialised pattern synonyms for tuples, which may be more convenient to diff --git a/src/Data/Array/Accelerate/Pattern/Bool.hs b/src/Data/Array/Accelerate/Pattern/Bool.hs index d968aaf34..4b98cbe79 100644 --- a/src/Data/Array/Accelerate/Pattern/Bool.hs +++ b/src/Data/Array/Accelerate/Pattern/Bool.hs @@ -1,9 +1,9 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE DataKinds #-} -- | -- Module : Data.Array.Accelerate.Pattern.Bool -- Copyright : [2018..2020] The Accelerate Team @@ -20,7 +20,33 @@ module Data.Array.Accelerate.Pattern.Bool ( ) where -import Data.Array.Accelerate.Pattern.TH +import Data.Array.Accelerate.Smart as Smart +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Pattern.Matchable +import Generics.SOP as SOP +import Data.Array.Accelerate.Representation.POS as POS -mkPattern ''Bool +{-# COMPLETE False_, True_ #-} +pattern False_ :: Exp Bool +pattern False_ <- (matchFalse -> Just ()) where + False_ = buildFalse +matchFalse :: Exp Bool -> Maybe () +matchFalse x = case match (Proxy @0) x of + Just SOP.Nil -> Just () + Nothing -> Nothing + +buildFalse :: Exp Bool +buildFalse = build (Proxy @0) SOP.Nil + +pattern True_ :: Exp Bool +pattern True_ <- (matchTrue -> Just x) where + True_ = buildTrue + +matchTrue :: Exp Bool -> Maybe () +matchTrue x = case match (Proxy @1) x of + Just SOP.Nil -> Just () + Nothing -> Nothing + +buildTrue :: Exp Bool +buildTrue = build (Proxy @1) SOP.Nil diff --git a/src/Data/Array/Accelerate/Pattern/Either.hs b/src/Data/Array/Accelerate/Pattern/Either.hs index 67c7b3a3f..59e052667 100644 --- a/src/Data/Array/Accelerate/Pattern/Either.hs +++ b/src/Data/Array/Accelerate/Pattern/Either.hs @@ -1,9 +1,9 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE DataKinds #-} -- | -- Module : Data.Array.Accelerate.Pattern.Either -- Copyright : [2018..2020] The Accelerate Team @@ -20,7 +20,45 @@ module Data.Array.Accelerate.Pattern.Either ( ) where -import Data.Array.Accelerate.Pattern.TH +import Data.Array.Accelerate.Smart as Smart +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Pattern.Matchable +import Generics.SOP as SOP +import Data.Array.Accelerate.Representation.POS as POS -mkPattern ''Either +{-# COMPLETE Left_, Right_ #-} +pattern Left_ :: + forall a b . + ( Elt a + , POSable a + , POSable b + , Matchable a + ) => Exp a -> Exp (Either a b) +pattern Left_ x <- (matchLeft -> Just x) where + Left_ = buildLeft +matchLeft :: forall a b . (POSable a, Elt a, POSable b) => Exp (Either a b) -> Maybe (Exp a) +matchLeft x = case match (Proxy @0) x of + Just (x' :* SOP.Nil) -> Just x' + Nothing -> Nothing + +buildLeft :: forall a b . (Elt a, POSable a, POSable b) => Exp a -> Exp (Either a b) +buildLeft x = build (Proxy @0) (x :* SOP.Nil) + +pattern Right_ :: + forall a b . + ( Elt a + , POSable a + , POSable b + , Matchable a + ) => Exp b -> Exp (Either a b) +pattern Right_ x <- (matchRight -> Just x) where + Right_ = buildRight + +matchRight :: forall a b . (Elt a, POSable a, POSable b) => Exp (Either a b) -> Maybe (Exp b) +matchRight x = case match (Proxy @1) x of + Just (x' :* SOP.Nil) -> Just x' + Nothing -> Nothing + +buildRight :: forall a b . (Elt a, POSable a, POSable b) => Exp b -> Exp (Either a b) +buildRight x = build (Proxy @1) (x :* SOP.Nil) diff --git a/src/Data/Array/Accelerate/Pattern/Matchable.hs b/src/Data/Array/Accelerate/Pattern/Matchable.hs new file mode 100644 index 000000000..2c9ef60c6 --- /dev/null +++ b/src/Data/Array/Accelerate/Pattern/Matchable.hs @@ -0,0 +1,309 @@ +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE DataKinds #-} + +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TypeFamilyDependencies #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE ConstraintKinds #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} + + +module Data.Array.Accelerate.Pattern.Matchable (Matchable(..)) where + +import Data.Array.Accelerate.Smart as Smart +import GHC.TypeLits +import Data.Proxy +import Data.Kind +import Generics.SOP as SOP +import Data.Type.Equality +import Data.Array.Accelerate.Representation.POS as POS +import Data.Array.Accelerate.Representation.Tag +import Unsafe.Coerce +import qualified Data.Array.Accelerate.AST as AST +import Data.Array.Accelerate.Type +import Data.Array.Accelerate.AST.Idx +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Representation.Type +-- import Data.Array.Accelerate.Pretty + + +class Matchable a where + type SOPCode a :: [[Type]] + type SOPCode a = Code a + + -- type Choices' a :: Nat + -- type Choices' a = Choices a + + build :: + ( KnownNat n + ) => Proxy n + -> NP Exp (SOPCode a !! n) + -> Exp a + default build :: + ( KnownNat n + , Elt a + ) => Proxy n + -> NP Exp (SOPCode a !! n) + -> Exp a + + build n _ = case sameNat (Proxy :: Proxy (EltChoices a)) (Proxy :: Proxy 1) of + -- no tag + Just Refl -> undefined + -- tagged + Nothing -> undefined + + match :: ( KnownNat n + ) => Proxy n + -> Exp a + -> Maybe (NP Exp (SOPCode a !! n)) + +buildTag :: SOP.All POSable xs => NP Exp xs -> Exp TAG +buildTag SOP.Nil = constant 0 -- exp of 0 :: Finite 1 +buildTag (((Exp x) :: (Exp x)) :* (xs :: xs)) = case sameNat (Proxy @(Choices x)) (Proxy :: Proxy 1) of + -- x doesn't contain a tag, skip + Just Refl + -> buildTag xs + -- x contains a tag, build an Exp to calculate the product + Nothing + | Refl :: (EltR x :~: (TAG, _r)) <- unsafeCoerce Refl + -- TODO: this is incorrect, we need the size of the TAG here (return to Finite?) + -> mkMul (Exp (SmartExp (Prj PairIdxLeft x))) (buildTag xs) + + + +type family (!!) (xs :: [[Type]]) (y :: Nat) :: [Type] where + (x ': xs) !! 0 = x + (x ': xs) !! n = xs !! (n - 1) + +infixl 9 !! + +instance Matchable Bool where + build n _ = Exp (SmartExp (Pair (unExp $ constant @TAG (fromInteger $ natVal n)) (SmartExp Smart.Nil))) + + match n (Exp e) = case sameNat n (Proxy :: Proxy 0) of + Just Refl -> + case e of + SmartExp (Match (TagR l u) _x) + | l == 0 + , u == 1 + -> Just SOP.Nil + + SmartExp Match {} -> Nothing + + _ -> error "Embedded pattern synonym used outside 'match' context." + Nothing -> + case sameNat n (Proxy :: Proxy 1) of + Just Refl -> + case e of + SmartExp (Match (TagR l u) _x) + | l == 1 + , u == 2 + -> Just SOP.Nil + + SmartExp Match {} -> Nothing + + _ -> error "Embedded pattern synonym used outside 'match' context." + + Nothing -> + error "Impossible type encountered" + +makeTag :: TAG -> SmartExp TAG +makeTag x = SmartExp (Const (SingleScalarType (NumSingleType (IntegralNumType TypeTAG))) x) + +instance (POSable (Maybe a), POSable a) => Matchable (Maybe a) where + build n fs = case sameNat (Proxy @(Choices a)) (Proxy @0) of + -- a has 0 valid choices (which means we cannot create a Just of this type) + -- we ignore the implementation for now, because this is not really useful + Just Refl -> undefined + -- a has at least 1 choice. + -- this means that Maybe a always has a tag + Nothing | Refl :: (EltR (Maybe a) :~: (TAG, FlattenProduct (Fields (Maybe a)))) <- unsafeCoerce Refl + -> case sameNat n (Proxy :: Proxy 0) of + -- Produce a Nothing + Just Refl -> Exp (SmartExp (Pair (unExp $ buildTAG fs) (makeLeft @() @a (SmartExp Smart.Nil)))) + Nothing + | Exp x :* SOP.Nil <- fs + -> case sameNat n (Proxy :: Proxy 1) of + -- Add 1 to the tag because we have skipped 1 choice: Nothing + Just Refl -> Exp (SmartExp (Pair (unExp $ mkAdd @TAG (constant 1) (buildTAG fs)) (makeRight @() @a (unTag @a x)))) + Nothing -> error $ "Impossible situation requested: Maybe has 2 constructors, constructor " ++ show (natVal n) ++ "is out of bounds" + Nothing -> error "Impossible situation requested: Just a expects a single value, got 0 or more then 1" + + match n (Exp e) = case sameNat (Proxy @(Choices a)) (Proxy @0) of + -- a has 0 valid choices (which means we cannot create a Just of this type) + -- we ignore the implementation for now, because this is not really useful + Just Refl -> undefined + -- a has at least 1 choice. + -- this means that Maybe a always has a tag + Nothing | Refl :: (EltR (Maybe a) :~: (TAG, FlattenProduct (Fields (Maybe a)))) <- unsafeCoerce Refl + -> case sameNat n (Proxy :: Proxy 0) of + Just Refl -> + case e of + SmartExp (Match (TagR l u) _x) + | l == 0 + , u == 1 + -> Just SOP.Nil + + SmartExp Match {} -> Nothing + + _ -> error "Embedded pattern synonym used outside 'match' context." + Nothing -> -- matchJust + case sameNat n (Proxy :: Proxy 1) of + Just Refl -> + case e of + SmartExp (Match (TagR l u) x) + | l == 1 + , u == tagVal @(Choices a) + -- remove one from the tag as we are not in left anymore + -- the `tag` function will apply the new tag if necessary + -> Just (Exp (tag @a (unExp $ mkSub @TAG (Exp $ prjLeft x) (constant 1)) (splitRight @() @a $ prjRight x)) :* SOP.Nil) + SmartExp Match {} -> Nothing + + _ -> error "Embedded pattern synonym used outside 'match' context." + + Nothing -> + error "Impossible type encountered" + +splitLeft :: forall a b . (POSable a, POSable b) => SmartExp (FlattenProduct (Merge (Fields a ++ '[]) (Fields b ++ '[]))) -> SmartExp (FlattenProduct (Fields a)) +splitLeft x = splitLeft' x (emptyFields @a) (emptyFields @b) + +splitLeft' :: forall a b . SmartExp (FlattenProduct (Merge (a ++ '[]) (b ++ '[]))) -> ProductType a -> ProductType b -> SmartExp (FlattenProduct a) +splitLeft' _ PTNil _ = SmartExp Smart.Nil +splitLeft' x (PTCons _ ls) PTNil = SmartExp $ Pair (SmartExp $ Union (prjLeft x)) (splitLeft' (prjRight x) ls PTNil) +splitLeft' x (PTCons _ ls) (PTCons _ rs) = SmartExp $ Pair (SmartExp $ Union (prjLeft x)) (splitLeft' (prjRight x) ls rs) + +splitRight :: forall a b . (POSable a, POSable b) => SmartExp (FlattenProduct (Merge (Fields a ++ '[]) (Fields b ++ '[]))) -> SmartExp (FlattenProduct (Fields b)) +splitRight x = splitRight' x (emptyFields @a) (emptyFields @b) + +splitRight' :: forall a b . SmartExp (FlattenProduct (Merge (a ++ '[]) (b ++ '[]))) -> ProductType a -> ProductType b -> SmartExp (FlattenProduct b) +splitRight' _ _ PTNil = SmartExp Smart.Nil +splitRight' x PTNil (PTCons _ rs) = SmartExp $ Pair (SmartExp $ Union (prjLeft x)) (splitRight' (prjRight x) PTNil rs) +splitRight' x (PTCons _ ls) (PTCons _ rs) = SmartExp $ Pair (SmartExp $ Union (prjLeft x)) (splitRight' (prjRight x) ls rs) + +makeLeft :: forall a b . (POSable a, POSable b) => SmartExp (FlattenProduct (Fields a)) -> SmartExp (FlattenProduct (Merge (Fields a ++ '[]) (Fields b ++ '[]))) +makeLeft x = makeLeft' x (emptyFields @a) (emptyFields @b) + +makeLeft' :: forall a b . SmartExp (FlattenProduct a) -> ProductType a -> ProductType b -> SmartExp (FlattenProduct (Merge (a ++ '[]) (b ++ '[]))) +makeLeft' _ PTNil PTNil = SmartExp Smart.Nil +makeLeft' x PTNil (PTCons _ rs) = SmartExp (Pair (SmartExp (Union (SmartExp (LiftUnion (SmartExp (Const (SingleScalarType UndefSingleType) POS.Undef)))))) (makeLeft' x PTNil rs)) +makeLeft' x (PTCons _ ls) PTNil = SmartExp (Pair (SmartExp (Union (prjLeft x))) (makeLeft' (prjRight x) ls PTNil)) +makeLeft' x (PTCons _ ls) (PTCons _ rs) = SmartExp (Pair (SmartExp (Union (prjLeft x))) (makeLeft' (prjRight x) ls rs)) + +prjLeft :: SmartExp (x, xs) -> SmartExp x +prjLeft = SmartExp . Prj PairIdxLeft + +prjRight :: SmartExp (x, xs) -> SmartExp xs +prjRight = SmartExp . Prj PairIdxRight + +makeRight :: forall a b . (POSable a, POSable b) => SmartExp (FlattenProduct (Fields b)) -> SmartExp (FlattenProduct (Merge (Fields a ++ '[]) (Fields b ++ '[]))) +makeRight x = makeRight' x (emptyFields @a) (emptyFields @b) + +makeRight' :: forall a b . SmartExp (FlattenProduct b) -> ProductType a -> ProductType b -> SmartExp (FlattenProduct (Merge (a ++ '[]) (b ++ '[]))) +makeRight' _ PTNil PTNil = SmartExp Smart.Nil +makeRight' x PTNil (PTCons _ rs) = SmartExp (Pair (SmartExp (Union (prjLeft x))) (makeRight' (prjRight x) PTNil rs)) +makeRight' x (PTCons _ ls) PTNil = SmartExp (Pair (SmartExp (Union (SmartExp (LiftUnion (SmartExp (Const (SingleScalarType UndefSingleType) POS.Undef)))))) (makeRight' x ls PTNil)) +makeRight' x (PTCons _ ls) (PTCons _ rs) = SmartExp (Pair (SmartExp (Union (prjLeft x))) (makeRight' (prjRight x) ls rs)) + +unTag :: forall x . (POSable x) => SmartExp (EltR x) -> SmartExp (FlattenProduct (Fields x)) +unTag x = case eltRType @x of + SingletonType -> SmartExp (Pair (SmartExp (LiftUnion x)) (SmartExp Smart.Nil)) + TaglessType -> x + TaggedType -> prjRight x + +tag :: forall x . (POSable x) => SmartExp TAG -> SmartExp (FlattenProduct (Fields x)) -> SmartExp (EltR x) +tag t x = case eltRType @x of + SingletonType -> SmartExp $ PrjUnion $ prjLeft x + TaglessType -> x + TaggedType -> SmartExp $ Pair t x + +instance (POSable (Either a b), POSable a, POSable b) => Matchable (Either a b) where + + build n fs = case sameNat (Proxy @(Choices a)) (Proxy @0) of + -- a has 0 valid choices (which means we cannot create a Left of this type) + -- we ignore the implementation for now, because this is not really useful + Just Refl -> undefined + Nothing -> case sameNat (Proxy @(Choices b)) (Proxy @0) of + -- b has 0 valid choices (which means we cannot create a Right of this type) + -- we ignore the implementation too + Just Refl -> undefined + -- a and b have at least 1 choice. + -- this means that Either a b always has a tag + Nothing | Refl :: EltR (Either a b) :~: (TAG, FlattenProduct (Fields (Either a b))) <- unsafeCoerce Refl + -> case sameNat n (Proxy :: Proxy 0) of + -- Product a Left + Just Refl + | Exp x :* SOP.Nil <- fs + -> Exp (SmartExp (Pair (unExp $ buildTAG fs) (makeLeft @a @b (unTag @a x)))) + Nothing + | Exp x :* SOP.Nil <- fs + -> case sameNat n (Proxy :: Proxy 1) of + -- Add natVal @(Choices to the tag) + Just Refl -> Exp (SmartExp (Pair (unExp $ mkAdd @TAG (constant $ tagVal @(Choices a)) (buildTag fs)) (makeRight @a @b (unTag @b x)))) + Nothing -> error $ "Impossible situation requested: Maybe has 2 constructors, constructor " ++ show (natVal n) ++ "is out of bounds" + Nothing -> error "Impossible situation requested: Just a expects a single value, got 0 or more then 1" + + match n (Exp e) = case sameNat (Proxy @(Choices a)) (Proxy @0) of + -- a has 0 valid choices (which means we cannot create a Left of this type) + -- we ignore the implementation for now, because this is not really useful + Just Refl -> undefined + Nothing -> case sameNat (Proxy @(Choices b)) (Proxy @0) of + -- b has 0 valid choices (which means we cannot create a Right of this type) + -- we ignore the implementation too + Just Refl -> undefined + -- a and b have at least 1 choice. + -- this means that Either a b always has a tag + Nothing | Refl :: EltR (Either a b) :~: (TAG, FlattenProduct (Fields (Either a b))) <- unsafeCoerce Refl + -> case sameNat n (Proxy :: Proxy 0) of -- matchLeft + Just Refl -> + case e of + SmartExp (Match (TagR l u) x) + | l == 0 + , u == tagVal @(Choices a) + -> Just (Exp (tag @a (unExp $ mkSub @TAG (Exp $ prjLeft x) (constant $ tagVal @(Choices a))) (splitLeft @a @b $ prjRight x)) :* SOP.Nil) + + SmartExp Match {} -> Nothing + + _ -> error "Embedded pattern synonym used outside 'match' context." + Nothing -> -- matchRight + case sameNat n (Proxy :: Proxy 1) of + Just Refl -> + case e of + SmartExp (Match (TagR l u) x) + | l == tagVal @(Choices a) + , u == tagVal @(Choices b) + -- remove one from the tag as we are not in left anymore + -- the `tag` function will apply the new tag if necessary + -> Just (Exp (tag @b (unExp $ mkSub @TAG (Exp $ prjLeft x) (constant $ tagVal @(Choices a))) (splitRight @a @b $ prjRight x)) :* SOP.Nil) + SmartExp Match {} -> Nothing + + _ -> error "Embedded pattern synonym used outside 'match' context." + + Nothing -> + error "Impossible type encountered" + +-- like combineProducts, but lifted to the AST +buildTAG :: (All POSable xs) => NP Exp xs -> Exp TAG +buildTAG SOP.Nil = Exp $ makeTag 0 +buildTAG (x :* xs) = combineProduct x (buildTAG xs) + +-- like Finite.combineProduct, but lifted to the AST +-- basically `tag x + tag y * natVal x` +combineProduct :: forall x. (POSable x) => Exp x -> Exp TAG -> Exp TAG +combineProduct x y = case sameNat (Proxy @(Choices x)) (Proxy :: Proxy 1) of + -- untagged type: `tag x = 0`, `natVal x = 1` + Just Refl -> y + -- tagged type + Nothing + | Refl :: (EltR x :~: (TAG, y)) <- unsafeCoerce Refl + -> mkAdd (mkExp $ Prj PairIdxLeft (unExp x)) (mkMul y (constant (tagVal @(Choices x)))) + +tagVal :: forall a . (KnownNat a) => TAG +tagVal = fromInteger $ natVal (Proxy @a) diff --git a/src/Data/Array/Accelerate/Pattern/Maybe.hs b/src/Data/Array/Accelerate/Pattern/Maybe.hs index 67e341d64..de6121ea9 100644 --- a/src/Data/Array/Accelerate/Pattern/Maybe.hs +++ b/src/Data/Array/Accelerate/Pattern/Maybe.hs @@ -1,9 +1,9 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE DataKinds #-} -- | -- Module : Data.Array.Accelerate.Pattern.Maybe -- Copyright : [2018..2020] The Accelerate Team @@ -20,7 +20,43 @@ module Data.Array.Accelerate.Pattern.Maybe ( ) where -import Data.Array.Accelerate.Pattern.TH +import Data.Array.Accelerate.Smart as Smart +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Pattern.Matchable +import Generics.SOP as SOP +import Data.Array.Accelerate.Representation.POS as POS -mkPattern ''Maybe +{-# COMPLETE Nothing_, Just_ #-} +pattern Nothing_ :: + forall a . + ( Elt a + , POSable a + , Matchable a + ) => Exp (Maybe a) +pattern Nothing_ <- (matchNothing -> Just ()) where + Nothing_ = buildNothing +matchNothing :: forall a . (POSable a, Elt a) => Exp (Maybe a) -> Maybe () +matchNothing x = case match (Proxy @0) x of + Just SOP.Nil -> Just () + Nothing -> Nothing + +buildNothing :: forall a . (Elt a, POSable a) => Exp (Maybe a) +buildNothing = build (Proxy @0) SOP.Nil + +pattern Just_ :: + forall a . + ( Elt a + , POSable a + , Matchable a + ) => Exp a -> Exp (Maybe a) +pattern Just_ x <- (matchJust -> Just x) where + Just_ = buildJust + +matchJust :: forall a . (Elt a, POSable a) => Exp (Maybe a) -> Maybe (Exp a) +matchJust x = case match (Proxy @1) x of + Just (x' :* SOP.Nil) -> Just x' + Nothing -> Nothing + +buildJust :: forall a . (Elt a, POSable a) => Exp a -> Exp (Maybe a) +buildJust x = build (Proxy @1) (x :* SOP.Nil) diff --git a/src/Data/Array/Accelerate/Pattern/Ordering.hs b/src/Data/Array/Accelerate/Pattern/Ordering.hs index 2407cf9e9..e6c783043 100644 --- a/src/Data/Array/Accelerate/Pattern/Ordering.hs +++ b/src/Data/Array/Accelerate/Pattern/Ordering.hs @@ -16,11 +16,11 @@ module Data.Array.Accelerate.Pattern.Ordering ( - Ordering, pattern LT_, pattern EQ_, pattern GT_, + -- Ordering, pattern LT_, pattern EQ_, pattern GT_, ) where import Data.Array.Accelerate.Pattern.TH -mkPattern ''Ordering +-- mkPattern ''Ordering diff --git a/src/Data/Array/Accelerate/Pattern/TH.hs b/src/Data/Array/Accelerate/Pattern/TH.hs index 0323f8d1a..bf26ee8bb 100644 --- a/src/Data/Array/Accelerate/Pattern/TH.hs +++ b/src/Data/Array/Accelerate/Pattern/TH.hs @@ -12,444 +12,7 @@ module Data.Array.Accelerate.Pattern.TH ( - mkPattern, - mkPatterns, + -- mkPattern, + -- mkPatterns, ) where - -import Data.Array.Accelerate.AST.Idx -import Data.Array.Accelerate.Pattern -import Data.Array.Accelerate.Representation.Tag -import Data.Array.Accelerate.Smart -import Data.Array.Accelerate.Sugar.Elt -import Data.Array.Accelerate.Type - -import Control.Monad -import Data.Bits -import Data.Char -import Data.List ( (\\), foldl' ) -import Language.Haskell.TH.Extra hiding ( Exp, Match, match ) -import Numeric -import Text.Printf -import qualified Language.Haskell.TH.Extra as TH - -import GHC.Stack - - --- | As 'mkPattern', but for a list of types --- -mkPatterns :: [Name] -> DecsQ -mkPatterns nms = concat <$> mapM mkPattern nms - --- | Generate pattern synonyms for the given simple (Haskell'98) sum or --- product data type. --- --- Constructor and record selectors are renamed to add a trailing --- underscore if it does not exist, or to remove it if it does. For infix --- constructors, the name is prepended with a colon ':'. For example: --- --- > data Point = Point { xcoord_ :: Float, ycoord_ :: Float } --- > deriving (Generic, Elt) --- --- Will create the pattern synonym: --- --- > Point_ :: Exp Float -> Exp Float -> Exp Point --- --- together with the selector functions --- --- > xcoord :: Exp Point -> Exp Float --- > ycoord :: Exp Point -> Exp Float --- -mkPattern :: Name -> DecsQ -mkPattern nm = do - info <- reify nm - case info of - TyConI dec -> mkDec dec - _ -> fail "mkPatterns: expected the name of a newtype or datatype" - -mkDec :: Dec -> DecsQ -mkDec dec = - case dec of - DataD _ nm tv _ cs _ -> mkDataD nm tv cs - NewtypeD _ nm tv _ c _ -> mkNewtypeD nm tv c - _ -> fail "mkPatterns: expected the name of a newtype or datatype" - -mkNewtypeD :: Name -> [TyVarBndr ()] -> Con -> DecsQ -mkNewtypeD tn tvs c = mkDataD tn tvs [c] - -mkDataD :: Name -> [TyVarBndr ()] -> [Con] -> DecsQ -mkDataD tn tvs cs = do - (pats, decs) <- unzip <$> go cs - comp <- pragCompleteD pats Nothing - return $ comp : concat decs - where - -- For single-constructor types we create the pattern synonym for the - -- type directly in terms of Pattern - go [] = fail "mkPatterns: empty data declarations not supported" - go [c] = return <$> mkConP tn tvs c - go _ = go' [] (map fieldTys cs) ctags cs - - -- For sum-types, when creating the pattern for an individual - -- constructor we need to know about the types of the fields all other - -- constructors as well - go' prev (this:next) (tag:tags) (con:cons) = do - r <- mkConS tn tvs prev next tag con - rs <- go' (this:prev) next tags cons - return (r : rs) - go' _ [] [] [] = return [] - go' _ _ _ _ = fail "mkPatterns: unexpected error" - - fieldTys (NormalC _ fs) = map snd fs - fieldTys (RecC _ fs) = map (\(_,_,t) -> t) fs - fieldTys (InfixC a _ b) = [snd a, snd b] - fieldTys _ = fail "mkPatterns: only constructors for \"vanilla\" syntax are supported" - - -- TODO: The GTags class demonstrates a way to generate the tags for - -- a given constructor, rather than backwards-engineering the structure - -- as we've done here. We should use that instead! - -- - ctags = - let n = length cs - m = n `quot` 2 - l = take m (iterate (True:) [False]) - r = take (n-m) (iterate (True:) [True]) - -- - bitsToTag = foldl' f 0 - where - f i False = i `shiftL` 1 - f i True = setBit (i `shiftL` 1) 0 - in - map bitsToTag (l ++ r) - - -mkConP :: Name -> [TyVarBndr ()] -> Con -> Q (Name, [Dec]) -mkConP tn' tvs' con' = do - checkExts [ PatternSynonyms ] - case con' of - NormalC cn fs -> mkNormalC tn' cn (map tyVarBndrName tvs') (map snd fs) - RecC cn fs -> mkRecC tn' cn (map tyVarBndrName tvs') (map (rename . fst3) fs) (map thd3 fs) - InfixC a cn b -> mkInfixC tn' cn (map tyVarBndrName tvs') [snd a, snd b] - _ -> fail "mkPatterns: only constructors for \"vanilla\" syntax are supported" - where - mkNormalC :: Name -> Name -> [Name] -> [Type] -> Q (Name, [Dec]) - mkNormalC tn cn tvs fs = do - xs <- replicateM (length fs) (newName "_x") - r <- sequence [ patSynSigD pat sig - , patSynD pat - (prefixPatSyn xs) - implBidir - [p| Pattern $(tupP (map varP xs)) |] - ] - return (pat, r) - where - pat = rename cn - sig = forallT - (map (`plainInvisTV` specifiedSpec) tvs) - (cxt (map (\t -> [t| Elt $(varT t) |]) tvs)) - (foldr (\t ts -> [t| $t -> $ts |]) - [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] - (map (\t -> [t| Exp $(return t) |]) fs)) - - mkRecC :: Name -> Name -> [Name] -> [Name] -> [Type] -> Q (Name, [Dec]) - mkRecC tn cn tvs xs fs = do - r <- sequence [ patSynSigD pat sig - , patSynD pat - (recordPatSyn xs) - implBidir - [p| Pattern $(tupP (map varP xs)) |] - ] - return (pat, r) - where - pat = rename cn - sig = forallT - (map (`plainInvisTV` specifiedSpec) tvs) - (cxt (map (\t -> [t| Elt $(varT t) |]) tvs)) - (foldr (\t ts -> [t| $t -> $ts |]) - [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] - (map (\t -> [t| Exp $(return t) |]) fs)) - - mkInfixC :: Name -> Name -> [Name] -> [Type] -> Q (Name, [Dec]) - mkInfixC tn cn tvs fs = do - mf <- reifyFixity cn - _a <- newName "_a" - _b <- newName "_b" - r <- sequence [ patSynSigD pat sig - , patSynD pat - (infixPatSyn _a _b) - implBidir - [p| Pattern $(tupP [varP _a, varP _b]) |] - ] - r' <- case mf of - Nothing -> return r - Just f -> return (InfixD f pat : r) - return (pat, r') - where - pat = mkName (':' : nameBase cn) - sig = forallT - (map (`plainInvisTV` specifiedSpec) tvs) - (cxt (map (\t -> [t| Elt $(varT t) |]) tvs)) - (foldr (\t ts -> [t| $t -> $ts |]) - [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] - (map (\t -> [t| Exp $(return t) |]) fs)) - -mkConS :: Name -> [TyVarBndr ()] -> [[Type]] -> [[Type]] -> Word8 -> Con -> Q (Name, [Dec]) -mkConS tn' tvs' prev' next' tag' con' = do - checkExts [GADTs, PatternSynonyms, ScopedTypeVariables, TypeApplications, ViewPatterns] - case con' of - NormalC cn fs -> mkNormalC tn' cn tag' (map tyVarBndrName tvs') prev' (map snd fs) next' - RecC cn fs -> mkRecC tn' cn tag' (map tyVarBndrName tvs') (map (rename . fst3) fs) prev' (map thd3 fs) next' - InfixC a cn b -> mkInfixC tn' cn tag' (map tyVarBndrName tvs') prev' [snd a, snd b] next' - _ -> fail "mkPatterns: only constructors for \"vanilla\" syntax are supported" - where - mkNormalC :: Name -> Name -> Word8 -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) - mkNormalC tn cn tag tvs ps fs ns = do - let pat = rename cn - (fun_build, dec_build) <- mkBuild tn (nameBase cn) tvs tag ps fs ns - (fun_match, dec_match) <- mkMatch tn (nameBase pat) (nameBase cn) tvs tag ps fs ns - dec_pat <- mkNormalC_pattern tn pat tvs fs fun_build fun_match - return $ (pat, concat [dec_pat, dec_build, dec_match]) - - mkRecC :: Name -> Name -> Word8 -> [Name] -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) - mkRecC tn cn tag tvs xs ps fs ns = do - let pat = rename cn - (fun_build, dec_build) <- mkBuild tn (nameBase cn) tvs tag ps fs ns - (fun_match, dec_match) <- mkMatch tn (nameBase pat) (nameBase cn) tvs tag ps fs ns - dec_pat <- mkRecC_pattern tn pat tvs xs fs fun_build fun_match - return $ (pat, concat [dec_pat, dec_build, dec_match]) - - mkInfixC :: Name -> Name -> Word8 -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) - mkInfixC tn cn tag tvs ps fs ns = do - let pat = mkName (':' : nameBase cn) - (fun_build, dec_build) <- mkBuild tn (zencode (nameBase cn)) tvs tag ps fs ns - (fun_match, dec_match) <- mkMatch tn ("(" ++ nameBase pat ++ ")") (zencode (nameBase cn)) tvs tag ps fs ns - dec_pat <- mkInfixC_pattern tn cn pat tvs fs fun_build fun_match - return $ (pat, concat [dec_pat, dec_build, dec_match]) - - mkNormalC_pattern :: Name -> Name -> [Name] -> [Type] -> Name -> Name -> Q [Dec] - mkNormalC_pattern tn pat tvs fs build match = do - xs <- replicateM (length fs) (newName "_x") - r <- sequence [ patSynSigD pat sig - , patSynD pat - (prefixPatSyn xs) - (explBidir [clause [] (normalB (varE build)) []]) - (parensP $ viewP (varE match) [p| Just $(tupP (map varP xs)) |]) - ] - return r - where - sig = forallT - (map (`plainInvisTV` specifiedSpec) tvs) - (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs)) - (foldr (\t ts -> [t| $t -> $ts |]) - [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] - (map (\t -> [t| Exp $(return t) |]) fs)) - - mkRecC_pattern :: Name -> Name -> [Name] -> [Name] -> [Type] -> Name -> Name -> Q [Dec] - mkRecC_pattern tn pat tvs xs fs build match = do - r <- sequence [ patSynSigD pat sig - , patSynD pat - (recordPatSyn xs) - (explBidir [clause [] (normalB (varE build)) []]) - (parensP $ viewP (varE match) [p| Just $(tupP (map varP xs)) |]) - ] - return r - where - sig = forallT - (map (`plainInvisTV` specifiedSpec) tvs) - (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs)) - (foldr (\t ts -> [t| $t -> $ts |]) - [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] - (map (\t -> [t| Exp $(return t) |]) fs)) - - mkInfixC_pattern :: Name -> Name -> Name -> [Name] -> [Type] -> Name -> Name -> Q [Dec] - mkInfixC_pattern tn cn pat tvs fs build match = do - mf <- reifyFixity cn - _a <- newName "_a" - _b <- newName "_b" - r <- sequence [ patSynSigD pat sig - , patSynD pat - (infixPatSyn _a _b) - (explBidir [clause [] (normalB (varE build)) []]) - (parensP $ viewP (varE match) [p| Just $(tupP [varP _a, varP _b]) |]) - ] - r' <- case mf of - Nothing -> return r - Just f -> return (InfixD f pat : r) - return r' - where - sig = forallT - (map (`plainInvisTV` specifiedSpec) tvs) - (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs)) - (foldr (\t ts -> [t| $t -> $ts |]) - [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] - (map (\t -> [t| Exp $(return t) |]) fs)) - - mkBuild :: Name -> String -> [Name] -> Word8 -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) - mkBuild tn cn tvs tag fs0 fs fs1 = do - fun <- newName ("_build" ++ cn) - xs <- replicateM (length fs) (newName "_x") - let - vs = foldl' (\es e -> [| SmartExp ($es `Pair` $e) |]) [| SmartExp Nil |] - $ map (\t -> [| unExp $(varE 'undef `appTypeE` return t) |] ) (concat (reverse fs0)) - ++ map varE xs - ++ map (\t -> [| unExp $(varE 'undef `appTypeE` return t) |] ) (concat fs1) - - tagged = [| Exp $ SmartExp $ Pair (SmartExp (Const (SingleScalarType (NumSingleType (IntegralNumType TypeWord8))) $(litE (IntegerL (toInteger tag))))) $vs |] - body = clause (map (\x -> [p| (Exp $(varP x)) |]) xs) (normalB tagged) [] - - r <- sequence [ sigD fun sig - , funD fun [body] - ] - return (fun, r) - where - sig = forallT - (map (`plainInvisTV` specifiedSpec) tvs) - (cxt (map (\t -> [t| Elt $(varT t) |]) tvs)) - (foldr (\t ts -> [t| $t -> $ts |]) - [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] - (map (\t -> [t| Exp $(return t) |]) fs)) - - - mkMatch :: Name -> String -> String -> [Name] -> Word8 -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) - mkMatch tn pn cn tvs tag fs0 fs fs1 = do - fun <- newName ("_match" ++ cn) - e <- newName "_e" - x <- newName "_x" - (ps,es) <- extract vs [| Prj PairIdxRight $(varE x) |] [] [] - unbind <- isExtEnabled RebindableSyntax - let - eqE = if unbind then letE [funD (mkName "==") [clause [] (normalB (varE '(==))) []]] else id - lhs = [p| (Exp $(varP e)) |] - body = normalB $ eqE $ caseE (varE e) - [ TH.match (conP 'SmartExp [(conP 'Match [matchP ps, varP x])]) (normalB [| Just $(tupE es) |]) [] - , TH.match (conP 'SmartExp [(recP 'Match [])]) (normalB [| Nothing |]) [] - , TH.match wildP (normalB [| error $error_msg |]) [] - ] - - r <- sequence [ sigD fun sig - , funD fun [clause [lhs] body []] - ] - return (fun, r) - where - sig = forallT - (map (`plainInvisTV` specifiedSpec) tvs) - (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs)) - [t| Exp $(foldl' appT (conT tn) (map varT tvs)) -> Maybe $(tupT (map (\t -> [t| Exp $(return t) |]) fs)) |] - - matchP us = [p| TagRtag $(litP (IntegerL (toInteger tag))) $pat |] - where - pat = [p| $(foldl (\ps p -> [p| TagRpair $ps $p |]) [p| TagRunit |] us) |] - - extract [] _ ps es = return (ps, es) - extract (u:us) x ps es = do - _u <- newName "_u" - let x' = [| Prj PairIdxLeft (SmartExp $x) |] - if not u - then extract us x' (wildP:ps) es - else extract us x' (varP _u:ps) ([| Exp (SmartExp (Match $(varE _u) (SmartExp (Prj PairIdxRight (SmartExp $x))))) |] : es) - - vs = reverse - $ [ False | _ <- concat fs0 ] ++ [ True | _ <- fs ] ++ [ False | _ <- concat fs1 ] - - error_msg = - let pv = unwords - $ take (length fs + 1) - $ concatMap (map reverse) - $ iterate (concatMap (\xs -> [ x:xs | x <- ['a'..'z'] ])) [""] - in stringE $ unlines - [ "Embedded pattern synonym used outside 'match' context." - , "" - , "To use case statements in the embedded language the case statement must" - , "be applied as an n-ary function to the 'match' operator. For single" - , "argument case statements this can be done inline using LambdaCase, for" - , "example:" - , "" - , "> x & match \\case" - , printf "> %s%s -> ..." pn pv - , printf "> _%s -> ..." (replicate (length pn + length pv - 1) ' ') - ] - -fst3 :: (a,b,c) -> a -fst3 (a,_,_) = a - -thd3 :: (a,b,c) -> c -thd3 (_,_,c) = c - -rename :: Name -> Name -rename nm = - let - split acc [] = (reverse acc, '\0') -- shouldn't happen - split acc [l] = (reverse acc, l) - split acc (l:ls) = split (l:acc) ls - -- - nm' = nameBase nm - (base, suffix) = split [] nm' - in - case suffix of - '_' -> mkName base - _ -> mkName (nm' ++ "_") - -checkExts :: [Extension] -> Q () -checkExts req = do - enabled <- extsEnabled - let missing = req \\ enabled - unless (null missing) . fail . unlines - $ printf "You must enable the following language extensions to generate pattern synonyms:" - : map (printf " {-# LANGUAGE %s #-}" . show) missing - --- A simplified version of that stolen from GHC/Utils/Encoding.hs --- -type EncodedString = String - -zencode :: String -> EncodedString -zencode [] = [] -zencode (h:rest) = encode_digit h ++ go rest - where - go [] = [] - go (c:cs) = encode_ch c ++ go cs - -unencoded_char :: Char -> Bool -unencoded_char 'z' = False -unencoded_char 'Z' = False -unencoded_char c = isAlphaNum c - -encode_digit :: Char -> EncodedString -encode_digit c | isDigit c = encode_as_unicode_char c - | otherwise = encode_ch c - -encode_ch :: Char -> EncodedString -encode_ch c | unencoded_char c = [c] -- Common case first -encode_ch '(' = "ZL" -encode_ch ')' = "ZR" -encode_ch '[' = "ZM" -encode_ch ']' = "ZN" -encode_ch ':' = "ZC" -encode_ch 'Z' = "ZZ" -encode_ch 'z' = "zz" -encode_ch '&' = "za" -encode_ch '|' = "zb" -encode_ch '^' = "zc" -encode_ch '$' = "zd" -encode_ch '=' = "ze" -encode_ch '>' = "zg" -encode_ch '#' = "zh" -encode_ch '.' = "zi" -encode_ch '<' = "zl" -encode_ch '-' = "zm" -encode_ch '!' = "zn" -encode_ch '+' = "zp" -encode_ch '\'' = "zq" -encode_ch '\\' = "zr" -encode_ch '/' = "zs" -encode_ch '*' = "zt" -encode_ch '_' = "zu" -encode_ch '%' = "zv" -encode_ch c = encode_as_unicode_char c - -encode_as_unicode_char :: Char -> EncodedString -encode_as_unicode_char c - = 'z' - : if isDigit (head hex_str) then hex_str - else '0':hex_str - where - hex_str = showHex (ord c) "U" - diff --git a/src/Data/Array/Accelerate/Representation/Array.hs b/src/Data/Array/Accelerate/Representation/Array.hs index d61304e76..c337f301d 100644 --- a/src/Data/Array/Accelerate/Representation/Array.hs +++ b/src/Data/Array/Accelerate/Representation/Array.hs @@ -26,6 +26,7 @@ import Data.Array.Accelerate.Type import Data.Array.Accelerate.Representation.Elt import Data.Array.Accelerate.Representation.Shape hiding ( zip ) import Data.Array.Accelerate.Representation.Type +import Data.Array.Accelerate.Sugar.Elt import Data.List ( intersperse ) import Data.Maybe ( isJust ) @@ -98,7 +99,7 @@ arraysRpair a b = TupRunit `TupRpair` TupRsingle a `TupRpair` TupRsingle b -- allocateArray :: ArrayR (Array sh e) -> sh -> IO (Array sh e) allocateArray (ArrayR shR eR) sh = do - adata <- newArrayData eR (size shR sh) + adata <- newArrayData eR (toElt $ size shR sh) return $! Array sh adata -- | Create an array from its representation function, applied at each @@ -114,13 +115,13 @@ fromFunction repr sh f = unsafePerformIO $! fromFunctionM repr sh (return . f) fromFunctionM :: ArrayR (Array sh e) -> sh -> (sh -> IO e) -> IO (Array sh e) fromFunctionM (ArrayR shR eR) sh f = do let !n = size shR sh - arr <- newArrayData eR n + arr <- newArrayData eR (toElt n) -- let write !i | i >= n = return () | otherwise = do v <- f (fromIndex shR sh i) - writeArrayData eR arr i v + writeArrayData eR arr (toElt i) v write (i+1) -- write 0 @@ -137,9 +138,9 @@ fromList (ArrayR shR eR) sh xs = adata `seq` Array sh adata -- !n = size shR sh (adata, _) = runArrayData @e $ do - arr <- newArrayData eR n + arr <- newArrayData eR (toElt n) let go !i _ | i >= n = return () - go !i (v:vs) = writeArrayData eR arr i v >> go (i+1) vs + go !i (v:vs) = writeArrayData eR arr (toElt i) v >> go (i+1) vs go _ [] = error "Data.Array.Accelerate.fromList: not enough input data" -- go 0 xs @@ -156,16 +157,16 @@ toList (ArrayR shR eR) (Array sh adata) = go 0 -- !n = size shR sh go !i | i >= n = [] - | otherwise = indexArrayData eR adata i : go (i+1) + | otherwise = indexArrayData eR adata (toElt i) : go (i+1) concatVectors :: forall e. TypeR e -> [Vector e] -> Vector e -concatVectors tR vs = adata `seq` Array ((), len) adata +concatVectors tR vs = adata `seq` Array ((), fromElt len) adata where offsets = scanl (+) 0 (map (size dim1 . shape) vs) - len = last offsets + len = toElt $ last offsets (adata, _) = runArrayData @e $ do arr <- newArrayData tR len - sequence_ [ writeArrayData tR arr (i + k) (indexArrayData tR ad i) + sequence_ [ writeArrayData tR arr (toElt (i + k)) (indexArrayData tR ad (toElt i)) | (Array ((), n) ad, k) <- vs `zip` offsets , i <- [0 .. n - 1] ] return (arr, undefined) @@ -217,7 +218,9 @@ showMatrix f (ArrayR _ arrR) arr@(Array sh _) | rows * cols == 0 = "[]" | otherwise = "\n [" ++ ppMat 0 0 where - (((), rows), cols) = sh + rows = toElt rows' + cols = toElt cols' + (((), rows'), cols') = sh lengths = U.generate (rows*cols) (\i -> length (f (linearIndexArray arrR arr i) "")) widths = U.generate cols (\c -> U.maximum (U.generate rows (\r -> lengths U.! (r*cols+c)))) -- @@ -321,7 +324,7 @@ liftArray (ArrayR shR adR) (Array sh adata) = [|| Array $$(liftElt (shapeType shR) sh) $$(liftArrayData sz adR adata) ||] `at` [t| Array $(liftTypeQ (shapeType shR)) $(liftTypeQ adR) |] where sz :: Int - sz = size shR sh + sz = toElt (size shR sh) at :: CodeQ t -> Q Type -> CodeQ t at e t = unsafeCodeCoerce $ sigE (unTypeCode e) t diff --git a/src/Data/Array/Accelerate/Representation/POS.hs b/src/Data/Array/Accelerate/Representation/POS.hs new file mode 100644 index 000000000..fb4f53394 --- /dev/null +++ b/src/Data/Array/Accelerate/Representation/POS.hs @@ -0,0 +1,27 @@ +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_HADDOCK hide #-} + +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +-- This is needed to derive POSable for tuples of size more then 4 +{-# OPTIONS_GHC -fconstraint-solver-iterations=16 #-} +-- | +-- Module : Data.Array.Accelerate.Representation.POS +-- Copyright : [2008..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Representation.POS ( + POSable(..), Product(..), Sum(..), + Ground(..), Finite, ProductType(..), SumType(..), POSable.Generic, type (++), + mkPOSableGround, Undef(..), type Merge) + where + + +import Generics.POSable.POSable as POSable +import Generics.POSable.Representation +import Generics.POSable.Instances () +import Generics.POSable.TH diff --git a/src/Data/Array/Accelerate/Representation/Shape.hs b/src/Data/Array/Accelerate/Representation/Shape.hs index fa3651c03..1abf0ac2f 100644 --- a/src/Data/Array/Accelerate/Representation/Shape.hs +++ b/src/Data/Array/Accelerate/Representation/Shape.hs @@ -2,6 +2,14 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE FlexibleContexts #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.Representation.Shape @@ -19,6 +27,8 @@ module Data.Array.Accelerate.Representation.Shape import Data.Array.Accelerate.Error import Data.Array.Accelerate.Type import Data.Array.Accelerate.Representation.Type +import Data.Array.Accelerate.Representation.POS +import Generics.POSable.Representation import Language.Haskell.TH.Extra import Prelude hiding ( zip ) @@ -195,3 +205,34 @@ liftShapeR :: ShapeR sh -> CodeQ (ShapeR sh) liftShapeR ShapeRz = [|| ShapeRz ||] liftShapeR (ShapeRsnoc sh) = [|| ShapeRsnoc $$(liftShapeR sh) ||] + +instance POSable (ShapeR ()) where + type Choices (ShapeR ()) = 1 + choices _ = 0 + + tags = [1] + + fromPOSable _ _ = ShapeRz + + type Fields (ShapeR ()) = '[] + + fields ShapeRz = Nil + + emptyFields = PTNil + + +instance (POSable (ShapeR sh)) => POSable (ShapeR (sh, Int)) where + type Choices (ShapeR (sh, Int)) = 1 + choices _ = 0 + + tags = [1] + + fromPOSable 0 (Cons _ xs) = ShapeRsnoc (fromPOSable 0 xs) + + type Fields (ShapeR (sh, Int)) = '[Undef] ': Fields (ShapeR sh) + + fields (ShapeRsnoc sh) = Cons (Pick Undef) (fields sh) + + emptyFields = PTCons (STSucc Undef STZero) (emptyFields @(ShapeR sh)) + + diff --git a/src/Data/Array/Accelerate/Representation/Tag.hs b/src/Data/Array/Accelerate/Representation/Tag.hs index ed7e07e80..f61116b79 100644 --- a/src/Data/Array/Accelerate/Representation/Tag.hs +++ b/src/Data/Array/Accelerate/Representation/Tag.hs @@ -11,19 +11,13 @@ -- Portability : non-portable (GHC extensions) -- -module Data.Array.Accelerate.Representation.Tag +module Data.Array.Accelerate.Representation.Tag (TAG, TagR(..)) where -import Data.Array.Accelerate.Type +import Data.Array.Accelerate.Type ( TAG ) -import Language.Haskell.TH.Extra --- | The type of the runtime value used to distinguish constructor --- alternatives in a sum type. --- -type TAG = Word8 - -- | This structure both witnesses the layout of our representation types -- (as TupR does) and represents a complete path of pattern matching -- through this type. It indicates which fields of the structure represent @@ -38,31 +32,5 @@ type TAG = Word8 -- (((),(1#,())),(0#,())) -- (True, False) -- (((),(1#,())),(1#,())) -- (True, True) -- -data TagR a where - TagRunit :: TagR () - TagRsingle :: ScalarType a -> TagR a - TagRundef :: ScalarType a -> TagR a - TagRtag :: TAG -> TagR a -> TagR (TAG, a) - TagRpair :: TagR a -> TagR b -> TagR (a, b) - -instance Show (TagR a) where - show TagRunit = "()" - show TagRsingle{} = "." - show TagRundef{} = "undef" - show (TagRtag v t) = "(" ++ show v ++ "#," ++ show t ++ ")" - show (TagRpair ta tb) = "(" ++ show ta ++ "," ++ show tb ++ ")" - -rnfTag :: TagR a -> () -rnfTag TagRunit = () -rnfTag (TagRsingle t) = rnfScalarType t -rnfTag (TagRundef t) = rnfScalarType t -rnfTag (TagRtag v t) = v `seq` rnfTag t -rnfTag (TagRpair ta tb) = rnfTag ta `seq` rnfTag tb - -liftTag :: TagR a -> CodeQ (TagR a) -liftTag TagRunit = [|| TagRunit ||] -liftTag (TagRsingle t) = [|| TagRsingle $$(liftScalarType t) ||] -liftTag (TagRundef t) = [|| TagRundef $$(liftScalarType t) ||] -liftTag (TagRtag v t) = [|| TagRtag v $$(liftTag t) ||] -liftTag (TagRpair ta tb) = [|| TagRpair $$(liftTag ta) $$(liftTag tb) ||] - +data TagR a = TagR TAG TAG + deriving Show diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 8fa577f41..5c6bfa87e 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -93,6 +93,7 @@ import Data.Array.Accelerate.Representation.Stencil hiding ( Ste import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Representation.Vec +import Data.Array.Accelerate.Representation.POS hiding (Nil, Undef) import Data.Array.Accelerate.Sugar.Array ( Arrays ) import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Sugar.Foreign @@ -510,6 +511,15 @@ data PreSmartExp acc exp t where -> exp (t1, t2) -> PreSmartExp acc exp t + LiftUnion :: exp t1 + -> PreSmartExp acc exp (UnionScalar '[t2]) + + Union :: exp (UnionScalar t1) + -> PreSmartExp acc exp (UnionScalar t2) + + PrjUnion :: exp (UnionScalar '[t1]) + -> PreSmartExp acc exp t1 + VecPack :: KnownNat n => VecR n s tup -> exp tup @@ -627,7 +637,7 @@ class Stencil sh e stencil where -- DIM1 instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e) where type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e) - = EltR (e, e, e) + = ((((), EltR e), EltR e), EltR e) stencilR = StencilRunit3 @(EltR e) $ eltR @e stencilPrj s = (Exp $ prj2 s, Exp $ prj1 s, @@ -635,7 +645,7 @@ instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e) where instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e) where type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e, Exp e, Exp e) - = EltR (e, e, e, e, e) + = ((((((), EltR e), EltR e), EltR e), EltR e), EltR e) stencilR = StencilRunit5 $ eltR @e stencilPrj s = (Exp $ prj4 s, Exp $ prj3 s, @@ -645,7 +655,7 @@ instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e) where instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) where type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) - = EltR (e, e, e, e, e, e, e) + = ((((((((), EltR e), EltR e), EltR e), EltR e), EltR e), EltR e), EltR e) stencilR = StencilRunit7 $ eltR @e stencilPrj s = (Exp $ prj6 s, Exp $ prj5 s, @@ -658,7 +668,7 @@ instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) where type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) - = EltR (e, e, e, e, e, e, e, e, e) + = ((((((((((), EltR e), EltR e), EltR e), EltR e), EltR e), EltR e), EltR e), EltR e), EltR e) stencilR = StencilRunit9 $ eltR @e stencilPrj s = (Exp $ prj8 s, Exp $ prj7 s, @@ -1156,21 +1166,13 @@ mkMin = mkPrimBinary $ PrimMin singleType -- Logical operators mkLAnd :: Exp Bool -> Exp Bool -> Exp Bool -mkLAnd (Exp a) (Exp b) = mkExp $ SmartExp (PrimApp PrimLAnd (SmartExp $ Pair x y)) `Pair` SmartExp Nil - where - x = SmartExp $ Prj PairIdxLeft a - y = SmartExp $ Prj PairIdxLeft b +mkLAnd (Exp a) (Exp b) = mkExp $ PrimApp PrimLAnd (SmartExp $ Pair a b) mkLOr :: Exp Bool -> Exp Bool -> Exp Bool -mkLOr (Exp a) (Exp b) = mkExp $ SmartExp (PrimApp PrimLOr (SmartExp $ Pair x y)) `Pair` SmartExp Nil - where - x = SmartExp $ Prj PairIdxLeft a - y = SmartExp $ Prj PairIdxLeft b +mkLOr (Exp a) (Exp b) = mkExp $ PrimApp PrimLOr (SmartExp $ Pair a b) mkLNot :: Exp Bool -> Exp Bool -mkLNot (Exp a) = mkExp $ SmartExp (PrimApp PrimLNot x) `Pair` SmartExp Nil - where - x = SmartExp $ Prj PairIdxLeft a +mkLNot (Exp a) = mkExp $ PrimApp PrimLNot a -- Numeric conversions @@ -1260,10 +1262,10 @@ mkPrimBinary :: (Elt a, Elt b, Elt c) => PrimFun ((EltR a, EltR b) -> EltR c) -> mkPrimBinary prim (Exp a) (Exp b) = mkExp $ PrimApp prim (SmartExp $ Pair a b) mkPrimUnaryBool :: Elt a => PrimFun (EltR a -> PrimBool) -> Exp a -> Exp Bool -mkPrimUnaryBool = mkCoerce @PrimBool $$ mkPrimUnary +mkPrimUnaryBool = mkPrimUnary mkPrimBinaryBool :: (Elt a, Elt b) => PrimFun ((EltR a, EltR b) -> PrimBool) -> Exp a -> Exp b -> Exp Bool -mkPrimBinaryBool = mkCoerce @PrimBool $$$ mkPrimBinary +mkPrimBinaryBool = mkPrimBinary unPair :: SmartExp (a, b) -> (SmartExp a, SmartExp b) unPair e = (SmartExp $ Prj PairIdxLeft e, SmartExp $ Prj PairIdxRight e) diff --git a/src/Data/Array/Accelerate/Sugar/Elt.hs b/src/Data/Array/Accelerate/Sugar/Elt.hs index b55158900..6cf9fee8b 100644 --- a/src/Data/Array/Accelerate/Sugar/Elt.hs +++ b/src/Data/Array/Accelerate/Sugar/Elt.hs @@ -10,7 +10,11 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE GADTs #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} {-# OPTIONS_HADDOCK hide #-} +{-# OPTIONS_GHC -ddump-splices #-} -- | -- Module : Data.Array.Accelerate.Sugar.Elt -- Copyright : [2008..2020] The Accelerate Team @@ -21,21 +25,25 @@ -- Portability : non-portable (GHC extensions) -- -module Data.Array.Accelerate.Sugar.Elt ( Elt(..) ) +module Data.Array.Accelerate.Sugar.Elt ( Elt(..), eltRType, EltRType(..) ) where -import Data.Array.Accelerate.Representation.Elt -import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type +import Data.Array.Accelerate.Representation.POS +import Data.Array.Accelerate.Representation.Tag +import Data.Array.Accelerate.Sugar.POS () import Data.Array.Accelerate.Type -import Data.Bits import Data.Char import Data.Kind import Language.Haskell.TH.Extra hiding ( Type ) -import GHC.Generics - +import GHC.TypeLits +import Unsafe.Coerce +import Data.Type.Equality +import Data.Proxy +import Data.Typeable +import Data.Finite.Internal (Finite(..)) -- | The 'Elt' class characterises the allowable array element types, and -- hence the types which can appear in scalar Accelerate expressions of @@ -74,184 +82,206 @@ import GHC.Generics -- See the function 'Data.Array.Accelerate.match' for details on how to use -- sum types in embedded code. -- -class Elt a where +class (KnownNat (EltChoices a)) => Elt a where -- | Type representation mapping, which explains how to convert a type -- from the surface type into the internal representation type consisting -- only of simple primitive types, unit '()', and pair '(,)'. -- type EltR a :: Type - type EltR a = GEltR () (Rep a) + type EltR a = POStoEltR (Choices a) (Fields a) + + type EltChoices a :: Nat + type EltChoices a = Choices a + -- eltR :: TypeR (EltR a) tagsR :: [TagR (EltR a)] fromElt :: a -> EltR a toElt :: EltR a -> a - default eltR - :: (GElt (Rep a), EltR a ~ GEltR () (Rep a)) - => TypeR (EltR a) - eltR = geltR @(Rep a) TupRunit - - default tagsR - :: (Generic a, GElt (Rep a), EltR a ~ GEltR () (Rep a)) - => [TagR (EltR a)] - tagsR = gtagsR @(Rep a) TagRunit - - default fromElt - :: (Generic a, GElt (Rep a), EltR a ~ GEltR () (Rep a)) - => a - -> EltR a - fromElt = gfromElt () . from - - default toElt - :: (Generic a, GElt (Rep a), EltR a ~ GEltR () (Rep a)) - => EltR a - -> a - toElt = to . snd . gtoElt @(Rep a) @() - - -class GElt f where - type GEltR t f - geltR :: TypeR t -> TypeR (GEltR t f) - gtagsR :: TagR t -> [TagR (GEltR t f)] - gfromElt :: t -> f a -> GEltR t f - gtoElt :: GEltR t f -> (t, f a) - -- - gundef :: t -> GEltR t f - guntag :: TagR t -> TagR (GEltR t f) - -instance GElt U1 where - type GEltR t U1 = t - geltR t = t - gtagsR t = [t] - gfromElt t U1 = t - gtoElt t = (t, U1) - gundef t = t - guntag t = t - -instance GElt a => GElt (M1 i c a) where - type GEltR t (M1 i c a) = GEltR t a - geltR = geltR @a - gtagsR = gtagsR @a - gfromElt t (M1 x) = gfromElt t x - gtoElt x = let (t, x1) = gtoElt x in (t, M1 x1) - gundef = gundef @a - guntag = guntag @a - -instance Elt a => GElt (K1 i a) where - type GEltR t (K1 i a) = (t, EltR a) - geltR t = TupRpair t (eltR @a) - gtagsR t = TagRpair t <$> tagsR @a - gfromElt t (K1 x) = (t, fromElt x) - gtoElt (t, x) = (t, K1 (toElt x)) - gundef t = (t, undefElt (eltR @a)) - guntag t = TagRpair t (untag (eltR @a)) - -instance (GElt a, GElt b) => GElt (a :*: b) where - type GEltR t (a :*: b) = GEltR (GEltR t a) b - geltR = geltR @b . geltR @a - gtagsR = concatMap (gtagsR @b) . gtagsR @a - gfromElt t (a :*: b) = gfromElt (gfromElt t a) b - gtoElt t = - let (t1, b) = gtoElt t - (t2, a) = gtoElt t1 - in - (t2, a :*: b) - gundef t = gundef @b (gundef @a t) - guntag t = guntag @b (guntag @a t) - -instance (GElt a, GElt b, GSumElt (a :+: b)) => GElt (a :+: b) where - type GEltR t (a :+: b) = (TAG, GSumEltR t (a :+: b)) - geltR t = TupRpair (TupRsingle scalarType) (gsumEltR @(a :+: b) t) - gtagsR t = uncurry TagRtag <$> gsumTagsR @(a :+: b) 0 t - gfromElt = gsumFromElt 0 - gtoElt (k,x) = gsumToElt k x - gundef t = (0xff, gsumUndef @(a :+: b) t) - guntag t = TagRpair (TagRundef scalarType) (gsumUntag @(a :+: b) t) - - -class GSumElt f where - type GSumEltR t f - gsumEltR :: TypeR t -> TypeR (GSumEltR t f) - gsumTagsR :: TAG -> TagR t -> [(TAG, TagR (GSumEltR t f))] - gsumFromElt :: TAG -> t -> f a -> (TAG, GSumEltR t f) - gsumToElt :: TAG -> GSumEltR t f -> (t, f a) - gsumUndef :: t -> GSumEltR t f - gsumUntag :: TagR t -> TagR (GSumEltR t f) - -instance GSumElt U1 where - type GSumEltR t U1 = t - gsumEltR t = t - gsumTagsR n t = [(n, t)] - gsumFromElt n t U1 = (n, t) - gsumToElt _ t = (t, U1) - gsumUndef t = t - gsumUntag t = t - -instance GSumElt a => GSumElt (M1 i c a) where - type GSumEltR t (M1 i c a) = GSumEltR t a - gsumEltR = gsumEltR @a - gsumTagsR = gsumTagsR @a - gsumFromElt n t (M1 x) = gsumFromElt n t x - gsumToElt k x = let (t, x') = gsumToElt k x in (t, M1 x') - gsumUntag = gsumUntag @a - gsumUndef = gsumUndef @a - -instance Elt a => GSumElt (K1 i a) where - type GSumEltR t (K1 i a) = (t, EltR a) - gsumEltR t = TupRpair t (eltR @a) - gsumTagsR n t = (n,) . TagRpair t <$> tagsR @a - gsumFromElt n t (K1 x) = (n, (t, fromElt x)) - gsumToElt _ (t, x) = (t, K1 (toElt x)) - gsumUntag t = TagRpair t (untag (eltR @a)) - gsumUndef t = (t, undefElt (eltR @a)) - -instance (GElt a, GElt b) => GSumElt (a :*: b) where - type GSumEltR t (a :*: b) = GEltR t (a :*: b) - gsumEltR = geltR @(a :*: b) - gsumTagsR n t = (n,) <$> gtagsR @(a :*: b) t - gsumFromElt n t (a :*: b) = (n, gfromElt (gfromElt t a) b) - gsumToElt _ t0 = - let (t1, b) = gtoElt t0 - (t2, a) = gtoElt t1 - in - (t2, a :*: b) - gsumUndef = gundef @(a :*: b) - gsumUntag = guntag @(a :*: b) - -instance (GSumElt a, GSumElt b) => GSumElt (a :+: b) where - type GSumEltR t (a :+: b) = GSumEltR (GSumEltR t a) b - gsumEltR = gsumEltR @b . gsumEltR @a - - gsumFromElt n t (L1 a) = let (m,r) = gsumFromElt n t a - in (shiftL m 1, gsumUndef @b r) - gsumFromElt n t (R1 b) = let (m,r) = gsumFromElt n (gsumUndef @a t) b - in (setBit (m `shiftL` 1) 0, r) - - gsumToElt k t0 = - let (t1, b) = gsumToElt (shiftR k 1) t0 - (t2, a) = gsumToElt (shiftR k 1) t1 - in - if testBit k 0 - then (t2, R1 b) - else (t2, L1 a) - - gsumTagsR k t = - let a = gsumTagsR @a k t - b = gsumTagsR @b k (gsumUntag @a t) - in - map (\(x,y) -> (x `shiftL` 1, gsumUntag @b y)) a ++ - map (\(x,y) -> (setBit (x `shiftL` 1) 0, y)) b - - gsumUndef t = gsumUndef @b (gsumUndef @a t) - gsumUntag t = gsumUntag @b (gsumUntag @a t) - - -untag :: TypeR t -> TagR t -untag TupRunit = TagRunit -untag (TupRsingle t) = TagRundef t -untag (TupRpair ta tb) = TagRpair (untag ta) (untag tb) + default eltR :: (POSable a, POStoEltR (Choices a) (Fields a) ~ EltR a) => TypeR (EltR a) + eltR = mkEltRT @a + + default fromElt :: (POSable a, POStoEltR (Choices a) (Fields a) ~ EltR a) => a -> EltR a + fromElt = mkEltR + + default toElt :: (POSable a, POStoEltR (Choices a) (Fields a) ~ EltR a) => EltR a -> a + toElt = fromEltR + + default tagsR :: (POSable a) => [TagR (EltR a)] + tagsR = f 0 (map fromInteger (tags @a)) + where + f :: TAG -> [TAG] -> [TagR (EltR a)] + f n l = case l of + [] -> [] + x : xs -> (TagR n (n + x)) : f (n + x) xs + + +-- function to bring the contraints in scope that are needed to work with EltR, +-- without needing to inspect how POS2EltR works +data EltRType x where + SingletonType :: (EltR x ~ POStoEltR (Choices x) (Fields x), EltR x ~ x, Fields x ~ '[ '[x]]) => EltRType x + TaglessType :: (EltR x ~ POStoEltR (Choices x) (Fields x), EltR x ~ FlattenProduct (Fields x)) => EltRType x + TaggedType :: (EltR x ~ POStoEltR (Choices x) (Fields x), EltR x ~ (TAG, FlattenProduct (Fields x))) => EltRType x + +eltRType :: forall x . POSable x => EltRType x +eltRType = case sameNat (Proxy :: Proxy (Choices x)) (Proxy :: Proxy 1) of + Just Refl -> case emptyFields @x of + PTCons (STSucc _ STZero) PTNil + | Refl :: (EltR x :~: x) <- unsafeCoerce Refl + , Refl :: (Fields x :~: '[ '[x]]) <- unsafeCoerce Refl + -> SingletonType + _ + | Refl :: (EltR x :~: FlattenProduct (Fields x)) <- unsafeCoerce Refl + , Refl :: (POStoEltR 1 (Fields x) :~: EltR x) <- unsafeCoerce Refl + -> TaglessType + Nothing + | Refl :: (EltR x :~: (TAG, FlattenProduct (Fields x))) <- unsafeCoerce Refl + , Refl :: (POStoEltR (Choices x) (Fields x) :~: (TAG, FlattenProduct (Fields x))) <- unsafeCoerce Refl + -> TaggedType + + +flattenProductType :: ProductType a -> TypeR (FlattenProduct a) +flattenProductType PTNil = TupRunit +flattenProductType (PTCons x xs) = TupRpair (TupRsingle (flattenSumType x)) (flattenProductType xs) + +flattenSumType :: SumType a -> ScalarType (UnionScalar a) +flattenSumType STZero = UnionScalarType ZeroScalarType +flattenSumType (STSucc x xs) = case flattenSumType xs of + UnionScalarType xs' -> UnionScalarType (SuccScalarType (mkSingleType x) xs') + +-- This is an unsafe conversion, and should be kept strictly in sync with the +-- set of types that implement Ground +mkScalarType :: forall a . (Typeable a, Ground a) => a -> ScalarType a +mkScalarType _ + | Just Refl <- eqT @a @Int + = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Int8 + = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Int16 + = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Int32 + = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Int64 + = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Word + = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Word8 + = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Word16 + = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Word32 + = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Word64 + = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Half + = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Float + = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Double + = scalarType @a +mkScalarType _ + | Just Refl <- eqT @a @Undef + = scalarType @a + + +-- This is an unsafe conversion, and should be kept strictly in sync with the +-- set of types that implement Ground +mkSingleType :: forall a . (Typeable a, Ground a) => a -> SingleType a +mkSingleType _ + | Just Refl <- eqT @a @Int + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Int8 + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Int16 + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Int32 + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Int64 + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Word + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Word8 + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Word16 + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Word32 + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Word64 + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Half + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Float + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Double + = singleType @a +mkSingleType _ + | Just Refl <- eqT @a @Undef + = singleType @a + + +mkEltRT :: forall a . (POSable a) => TypeR (POStoEltR (Choices a) (Fields a)) +mkEltRT = case eltRType @a of + SingletonType | PTCons (STSucc x STZero) PTNil <- emptyFields @a -> TupRsingle (mkScalarType x) + TaglessType -> flattenProductType (emptyFields @a) + TaggedType -> TupRpair (TupRsingle scalarTypeTAG) (flattenProductType (emptyFields @a)) + + +mkEltR :: forall a . (POSable a) => a -> POStoEltR (Choices a) (Fields a) +mkEltR x = case eltRType @a of + SingletonType | Cons (Pick f) Nil <- fields x -> f + TaglessType -> fs + TaggedType -> (cs, fs) + where + cs = fromInteger @TAG $ toInteger $ choices x + fs = flattenProduct (fields x) + +fromEltR :: forall a . (POSable a) => POStoEltR (Choices a) (Fields a) -> a +fromEltR x = case eltRType @a of + SingletonType -> x + TaglessType -> fromPOSable 0 (unFlattenProduct (emptyFields @a) x) + TaggedType | (t, fs) <- x -> fromPOSable (Finite $ toInteger t) (unFlattenProduct (emptyFields @a) fs) +unFlattenProduct :: ProductType a -> FlattenProduct a -> Product a +unFlattenProduct PTNil () = Nil +unFlattenProduct (PTCons x xs) (y, ys) = Cons (unFlattenSum x y) (unFlattenProduct xs ys) + +unFlattenSum :: SumType a -> UnionScalar a -> Sum a +unFlattenSum (STSucc x xs) (PickScalar y) = Pick y +unFlattenSum (STSucc x xs) (SkipScalar ys) = Skip $ unFlattenSum xs ys + + +flattenProduct :: Product a -> FlattenProduct a +flattenProduct Nil = () +flattenProduct (Cons x xs) = (flattenSum x, flattenProduct xs) + +flattenSum :: Sum a -> UnionScalar a +flattenSum (Pick x) = PickScalar x +flattenSum (Skip xs) = SkipScalar (flattenSum xs) -- Note: [Deriving Elt] -- @@ -284,16 +314,19 @@ untag (TupRpair ta tb) = TagRpair (untag ta) (untag tb) instance Elt () instance Elt Bool instance Elt Ordering -instance Elt a => Elt (Maybe a) -instance (Elt a, Elt b) => Elt (Either a b) +instance (POSable (Maybe a), Elt a) => Elt (Maybe a) +instance (POSable (Either a b), Elt a, Elt b) => Elt (Either a b) instance Elt Char where type EltR Char = Word32 + type EltChoices Char = 1 eltR = TupRsingle scalarType - tagsR = [TagRsingle scalarType] + tagsR = [TagR 0 1] toElt = chr . fromIntegral fromElt = fromIntegral . ord +-- Anything that has a POS instance has a default Elt instance +-- TODO: build instances for the sections of newtypes runQ $ do let -- XXX: we might want to do the digItOut trick used by FromIntegral? @@ -340,12 +373,7 @@ runQ $ do mkSimple name = let t = conT name in - [d| instance Elt $t where - type EltR $t = $t - eltR = TupRsingle scalarType - tagsR = [TagRsingle scalarType] - fromElt = id - toElt = id + [d| instance Elt $t |] mkTuple :: Int -> Q Dec @@ -374,23 +402,23 @@ runQ $ do -- TyConI (NewtypeD [] Foreign.C.Types.CFloat [] Nothing (NormalC Foreign.C.Types.CFloat [(Bang NoSourceUnpackedness NoSourceStrictness,ConT GHC.Types.Float)]) []) -- mkNewtype :: Name -> Q [Dec] - mkNewtype name = do - r <- reify name - base <- case r of - TyConI (NewtypeD _ _ _ _ (NormalC _ [(_, ConT b)]) _) -> return b - _ -> error "unexpected case generating newtype Elt instance" - -- - [d| instance Elt $(conT name) where - type EltR $(conT name) = $(conT base) - eltR = TupRsingle scalarType - tagsR = [TagRsingle scalarType] - fromElt $(conP (mkName (nameBase name)) [varP (mkName "x")]) = x - toElt = $(conE (mkName (nameBase name))) + mkNewtype name = + let t = conT name + in + [d| instance Elt $t |] -- ss <- mapM mkSimple (integralTypes ++ floatingTypes) + -- TODO: ns <- mapM mkNewtype newtypes - ts <- mapM mkTuple [2..16] + -- ts <- mapM mkTuple [2..16] -- vs <- sequence [ mkVecElt t n | t <- integralTypes ++ floatingTypes, n <- [2,3,4,8,16] ] - return (concat ss ++ concat ns ++ ts) + return (concat ss ++ concat ns) + +instance Elt Undef +-- TODO: bring this back into TH +instance (POSable a, POSable b) => Elt (a, b) +instance (POSable a, POSable b, POSable c) => Elt (a, b, c) +instance (POSable a, POSable b, POSable c, POSable d) => Elt (a, b, c, d) +instance (POSable a, POSable b, POSable c, POSable d, POSable e) => Elt (a, b, c, d, e) diff --git a/src/Data/Array/Accelerate/Sugar/POS.hs b/src/Data/Array/Accelerate/Sugar/POS.hs new file mode 100644 index 000000000..f4f36a4e0 --- /dev/null +++ b/src/Data/Array/Accelerate/Sugar/POS.hs @@ -0,0 +1,127 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE TypeFamilyDependencies #-} +{-# OPTIONS_HADDOCK hide #-} +{-# OPTIONS_GHC -ddump-splices #-} + +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +-- This is needed to derive POSable for tuples of size more then 4 +{-# OPTIONS_GHC -fconstraint-solver-iterations=16 #-} +-- | +-- Module : Data.Array.Accelerate.Representation.POS +-- Copyright : [2008..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Sugar.POS + where + +-- import Data.Array.Accelerate.Type + +import Language.Haskell.TH.Extra hiding ( Type ) + +import Generics.POSable.POSable as POSable +import Generics.POSable.Representation +import Generics.POSable.TH + +import Data.Int +import Data.Word +import Numeric.Half +import Foreign.C.Types + +import Data.Array.Accelerate.Type + + +runQ $ do + let + -- XXX: we might want to do the digItOut trick used by FromIntegral? + -- + integralTypes :: [Name] + integralTypes = + [ ''Int + , ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Word + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + ] + + floatingTypes :: [Name] + floatingTypes = + [ ''Half + , ''Float + , ''Double + ] + + newtypes :: [Name] + newtypes = + [ ''CShort + , ''CUShort + , ''CInt + , ''CUInt + , ''CLong + , ''CULong + , ''CLLong + , ''CULLong + , ''CFloat + , ''CDouble + , ''CChar + , ''CSChar + , ''CUChar + ] + + mkSimple :: Name -> Name -> Name -> Q [Dec] + mkSimple typ val name = + let t = conT name + -- tr = pure $ AppE (ConE val) (ConE $ mkName ("Type" ++ nameBase name)) + in + [d| + instance Ground $t where + mkGround = 0 + |] + + mkTuple :: Int -> Q Dec + mkTuple n = + let + xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + ts = map varT xs + res = tupT ts + ctx = mapM (appT [t| POSable |]) ts + in + instanceD ctx [t| POSable $res |] [] + + mkNewtype :: Name -> Q [Dec] + mkNewtype name = + let t = conT name + in + [d| + instance Ground $t where + mkGround = 0 + |] + + -- + si <- mapM (mkSimple ''IntegralType 'IntegralNumType) integralTypes + sf <- mapM (mkSimple ''FloatingType 'FloatingNumType) floatingTypes + ns <- mapM mkPOSableGround (floatingTypes ++ integralTypes) + ts <- mapM mkNewtype newtypes + nts <- mapM mkPOSableGround newtypes + -- ts <- mapM mkTuple [2..16] + -- vs <- sequence [ mkVecElt t n | t <- integralTypes ++ floatingTypes, n <- [2,3,4,8,16] ] + return (concat si ++ concat sf ++ concat ns ++ concat ts ++ concat nts) diff --git a/src/Data/Array/Accelerate/Sugar/Shape.hs b/src/Data/Array/Accelerate/Sugar/Shape.hs index 1ac8bd0c4..debd97360 100644 --- a/src/Data/Array/Accelerate/Sugar/Shape.hs +++ b/src/Data/Array/Accelerate/Sugar/Shape.hs @@ -8,6 +8,8 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE DataKinds #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.Sugar.Shape @@ -33,11 +35,13 @@ module Data.Array.Accelerate.Sugar.Shape import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type +import Data.Array.Accelerate.Representation.POS as POS import qualified Data.Array.Accelerate.Representation.Shape as R import qualified Data.Array.Accelerate.Representation.Slice as R import Data.Kind -import GHC.Generics +import GHC.Generics as GHC +import GHC.TypeLits (type (+)) -- Shorthand for common shape types @@ -56,14 +60,14 @@ type DIM9 = DIM8 :. Int -- | Rank-0 index -- data Z = Z - deriving (Show, Eq, Generic, Elt) + deriving (Show, Eq, GHC.Generic, POS.Generic, POSable, Elt) -- | Increase an index rank by one dimension. The ':.' operator is used to -- construct both values and types. -- infixl 3 :. data tail :. head = !tail :. !head - deriving (Eq, Generic) -- Not deriving Elt or Show + deriving (Eq, GHC.Generic) -- Not deriving Elt or Show -- We don't we use a derived Show instance for (:.) because this will insert -- parenthesis to demonstrate which order the operator is applied, i.e.: @@ -97,7 +101,7 @@ instance (Show sh, Show sz) => Show (sh :. sz) where -- 'Data.Array.Accelerate.Language.replicate' for examples. -- data All = All - deriving (Show, Eq, Generic, Elt) + deriving (Show, Eq, GHC.Generic, POS.Generic, POSable, Elt) -- | Marker for arbitrary dimensions in 'Data.Array.Accelerate.Language.slice' -- and 'Data.Array.Accelerate.Language.replicate' descriptors. @@ -109,7 +113,7 @@ data All = All -- 'Data.Array.Accelerate.Language.replicate' for examples. -- data Any sh = Any - deriving (Show, Eq, Generic) + deriving (Show, Eq, GHC.Generic) -- | Marker for splitting along an entire dimension in division descriptors. -- @@ -305,16 +309,21 @@ class (Slice (DivisionSlice sl)) => Division sl where instance (Elt t, Elt h) => Elt (t :. h) where type EltR (t :. h) = (EltR t, EltR h) + type EltChoices (t :. h) = 1 eltR = TupRpair (eltR @t) (eltR @h) - tagsR = [TagRpair t h | t <- tagsR @t, h <- tagsR @h] + tagsR = [TagR 0 1] fromElt (t:.h) = (fromElt t, fromElt h) toElt (t, h) = toElt t :. toElt h +instance POS.Generic (Any Z) +instance POSable (Any Z) instance Elt (Any Z) + instance Shape sh => Elt (Any (sh :. Int)) where type EltR (Any (sh :. Int)) = (EltR (Any sh), ()) + type EltChoices (Any (sh :. Int)) = 1 eltR = TupRpair (eltR @(Any sh)) TupRunit - tagsR = [TagRpair t TagRunit | t <- tagsR @(Any sh)] + tagsR = [TagR 0 1] fromElt _ = (fromElt (Any :: Any sh), ()) toElt _ = Any diff --git a/src/Data/Array/Accelerate/Sugar/Vec.hs b/src/Data/Array/Accelerate/Sugar/Vec.hs index 723d32c7b..4972793a3 100644 --- a/src/Data/Array/Accelerate/Sugar/Vec.hs +++ b/src/Data/Array/Accelerate/Sugar/Vec.hs @@ -1,9 +1,15 @@ {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE MagicHash #-} {-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TypeFamilyDependencies #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE FlexibleInstances #-} {-# OPTIONS_HADDOCK hide #-} {-# OPTIONS_GHC -fno-warn-orphans #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -- | -- Module : Data.Array.Accelerate.Sugar.Vec -- Copyright : [2008..2020] The Accelerate Team @@ -18,22 +24,47 @@ module Data.Array.Accelerate.Sugar.Vec where import Data.Array.Accelerate.Sugar.Elt -import Data.Array.Accelerate.Representation.Tag -import Data.Array.Accelerate.Representation.Type +import Data.Array.Accelerate.Representation.POS import Data.Array.Accelerate.Type import Data.Primitive.Types import Data.Primitive.Vec -import GHC.TypeLits -import GHC.Prim +type VecElt a = (Elt a, Prim a, IsSingle a, Ground a, Num a) -type VecElt a = (Elt a, Prim a, IsSingle a, EltR a ~ a) +instance VecElt a => POSable (Vec2 a) where + type Choices (Vec2 a) = 1 -instance (KnownNat n, VecElt a) => Elt (Vec n a) where - type EltR (Vec n a) = Vec n a - eltR = TupRsingle (VectorScalarType (VectorType (fromIntegral (natVal' (proxy# :: Proxy# n))) singleType)) - tagsR = [TagRsingle (VectorScalarType (VectorType (fromIntegral (natVal' (proxy# :: Proxy# n))) singleType))] - toElt = id - fromElt = id + choices _ = 0 + tags = [1] -- TODO: can a Vec contain non-singleton values? + + fromPOSable 0 (Cons (Pick a) (Cons (Pick b) Nil)) = Vec2 a b + + type Fields (Vec2 a) = '[ '[a], '[a]] + fields (Vec2 a b) = Cons (Pick a) (Cons (Pick b) Nil) + + emptyFields = PTCons (STSucc (mkGround @a) STZero) (PTCons (STSucc (mkGround @a) STZero) PTNil) + +-- Elt instance automatically derived from POSable instance +instance VecElt a => Elt (Vec2 a) + + +instance VecElt a => POSable (Vec4 a) where + type Choices (Vec4 a) = 1 + + choices _ = 0 + + tags = [1] -- TODO: can a Vec contain non-singleton values? + + fromPOSable 0 ( Cons (Pick a) (Cons (Pick b) (Cons (Pick c) (Cons (Pick d) Nil)))) = Vec4 a b c d + + type Fields (Vec4 a) = '[ '[a], '[a], '[a], '[a]] + fields (Vec4 a b c d) = Cons (Pick a) (Cons (Pick b) (Cons (Pick c) (Cons (Pick d) Nil))) + + emptyFields = PTCons (STSucc (mkGround @a) STZero) (PTCons (STSucc (mkGround @a) STZero) (PTCons (STSucc (mkGround @a) STZero) (PTCons (STSucc (mkGround @a) STZero) PTNil))) + +-- Elt instance automatically derived from POSable instance +instance VecElt a => Elt (Vec4 a) + +-- TODO: instances for 8 and 16, probably with some TH diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index 94e891cc1..7b1f03d0d 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -1,19 +1,25 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RoleAnnotations #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} {-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE TypeFamilyDependencies #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.Type @@ -66,6 +72,7 @@ module Data.Array.Accelerate.Type ( ) where import Data.Array.Accelerate.Orphans () -- Prim Half +import Data.Array.Accelerate.Representation.POS import Data.Primitive.Vec import Data.Bits @@ -73,10 +80,11 @@ import Data.Int import Data.Primitive.Types import Data.Type.Equality import Data.Word +import Data.Kind import Foreign.C.Types import Foreign.Storable ( Storable ) import Formatting -import Language.Haskell.TH.Extra +import Language.Haskell.TH.Extra hiding (Type) import Numeric.Half import Text.Printf @@ -84,6 +92,26 @@ import GHC.Prim import GHC.TypeLits + +-- | The type of the runtime value used to distinguish constructor +-- alternatives in a sum type. +-- +type TAG = Word8 + + +type family POStoEltR (cs :: Nat) fs :: Type where + POStoEltR 1 '[ '[x]] = x -- singletontypes + POStoEltR 1 x = FlattenProduct x -- tagless types + POStoEltR n x = (TAG, FlattenProduct x) -- all other types + +type family FlattenProduct (xss :: f [a]) = (r :: Type) | r -> f where + FlattenProduct '[] = () + FlattenProduct (x ': xs) = (UnionScalar x, FlattenProduct xs) + +type family FlattenProductType (xss :: [[a]]) :: Type where + FlattenProductType '[] = () + FlattenProductType (x ': xs) = (UnionScalarType x, FlattenProductType xs) + -- Scalar types -- ------------ @@ -120,6 +148,7 @@ data IntegralType a where TypeWord16 :: IntegralType Word16 TypeWord32 :: IntegralType Word32 TypeWord64 :: IntegralType Word64 + TypeTAG :: IntegralType TAG -- | Floating-point types supported in array computations. -- @@ -144,9 +173,22 @@ data BoundedType a where data ScalarType a where SingleScalarType :: SingleType a -> ScalarType a VectorScalarType :: VectorType (Vec n a) -> ScalarType (Vec n a) + UnionScalarType :: UnionScalarType a -> ScalarType (UnionScalar a) + +class IsUnionScalar a where + unionScalarType :: UnionScalarType a + +data UnionScalar x where + PickScalar :: x -> UnionScalar (x ': xs) + SkipScalar :: UnionScalar xs -> UnionScalar (x ': xs) + +data UnionScalarType a where + SuccScalarType :: SingleType x -> UnionScalarType xs -> UnionScalarType (x ': xs) + ZeroScalarType :: UnionScalarType '[] data SingleType a where NumSingleType :: NumType a -> SingleType a + UndefSingleType :: SingleType Undef data VectorType a where VectorType :: KnownNat n => {-# UNPACK #-} !Int -> SingleType a -> VectorType (Vec n a) @@ -162,6 +204,7 @@ instance Show (IntegralType a) where show TypeWord16 = "Word16" show TypeWord32 = "Word32" show TypeWord64 = "Word64" + show TypeTAG = "TAG" instance Show (FloatingType a) where show TypeHalf = "Half" @@ -177,6 +220,7 @@ instance Show (BoundedType a) where instance Show (SingleType a) where show (NumSingleType ty) = show ty + show UndefSingleType = "Undef" instance Show (VectorType a) where show (VectorType n ty) = printf "<%d x %s>" n (show ty) @@ -184,6 +228,12 @@ instance Show (VectorType a) where instance Show (ScalarType a) where show (SingleScalarType ty) = show ty show (VectorScalarType ty) = show ty + show (UnionScalarType ty) = show ty + +instance Show (UnionScalarType a) where + show ZeroScalarType = "" + show (SuccScalarType x (ZeroScalarType)) = show x + show (SuccScalarType x xs) = show x ++ " + " ++ show xs formatIntegralType :: Format r (IntegralType a -> r) formatIntegralType = later $ \case @@ -197,6 +247,7 @@ formatIntegralType = later $ \case TypeWord16 -> "Word16" TypeWord32 -> "Word32" TypeWord64 -> "Word64" + TypeTAG -> "TAG" formatFloatingType :: Format r (FloatingType a -> r) formatFloatingType = later $ \case @@ -269,6 +320,7 @@ integralDict TypeWord8 = IntegralDict integralDict TypeWord16 = IntegralDict integralDict TypeWord32 = IntegralDict integralDict TypeWord64 = IntegralDict +integralDict TypeTAG = IntegralDict floatingDict :: FloatingType a -> FloatingDict a floatingDict TypeHalf = FloatingDict @@ -318,6 +370,9 @@ scalarTypeWord8 = SingleScalarType $ NumSingleType $ IntegralNumType TypeWord8 scalarTypeWord32 :: ScalarType Word32 scalarTypeWord32 = SingleScalarType $ NumSingleType $ IntegralNumType TypeWord32 +scalarTypeTAG :: ScalarType TAG +scalarTypeTAG = SingleScalarType $ NumSingleType $ IntegralNumType TypeTAG + rnfScalarType :: ScalarType t -> () rnfScalarType (SingleScalarType t) = rnfSingleType t rnfScalarType (VectorScalarType t) = rnfVectorType t @@ -518,3 +573,18 @@ runQ $ do -- return (concat is ++ concat fs ++ concat vs) + +instance IsSingle Undef where + singleType = UndefSingleType + +instance IsScalar Undef where + scalarType = SingleScalarType singleType + +instance (IsUnionScalar a) => IsScalar (UnionScalar a) where + scalarType = UnionScalarType (unionScalarType @a) + +instance IsUnionScalar '[] where + unionScalarType = ZeroScalarType + +instance (IsSingle x, IsUnionScalar xs) => IsUnionScalar (x ': xs) where + unionScalarType = SuccScalarType (singleType @x) (unionScalarType @xs) diff --git a/src/Data/Primitive/Vec.hs b/src/Data/Primitive/Vec.hs index 0342f401c..f0a65ca2d 100644 --- a/src/Data/Primitive/Vec.hs +++ b/src/Data/Primitive/Vec.hs @@ -10,6 +10,7 @@ {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE UnboxedTuples #-} {-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE TypeApplications #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Primitive.Vec @@ -34,6 +35,8 @@ module Data.Primitive.Vec ( listOfVec, liftVec, + replicateVecN, + ) where import Control.Monad.ST @@ -48,6 +51,7 @@ import GHC.Prim import GHC.TypeLits import GHC.Word +import Data.Proxy -- Note: [Representing SIMD vector types] -- @@ -259,6 +263,14 @@ packVec16 a b c d e f g h i j k l m n o p = runST $ do ByteArray ba# <- unsafeFreezeByteArray mba return $! Vec ba# +replicateVecN :: forall a n . (KnownNat n, Prim a) => a -> Vec n a +replicateVecN x = runST $ do + let n = fromInteger $ natVal (Proxy :: Proxy n) + mba <- newByteArray (n * sizeOf x) + mapM_ (\n' -> writeByteArray mba n' x) [0..n] + ByteArray ba# <- unsafeFreezeByteArray mba + return $! Vec ba# + -- O(n) at runtime to copy from the Addr# to the ByteArray#. We should be able -- to do this without copying, but I don't think the definition of ByteArray# is -- exported (or it is deeply magical). diff --git a/stack-8.10.yaml b/stack-8.10.yaml index d0823dcbd..dd0347506 100644 --- a/stack-8.10.yaml +++ b/stack-8.10.yaml @@ -7,7 +7,8 @@ resolver: lts-18.25 packages: - . -# extra-deps: +extra-deps: +- posable-1.0.0.1 # Override default flag values for local packages and extra-deps # flags: {} diff --git a/stack-8.6.yaml b/stack-8.6.yaml index 5d3724662..4e7118946 100644 --- a/stack-8.6.yaml +++ b/stack-8.6.yaml @@ -12,6 +12,7 @@ extra-deps: - prettyprinter-ansi-terminal-1.1.3 - tasty-rerun-1.1.18 - text-1.2.4.1 +- posable-1.0.0.1 # Override default flag values for local packages and extra-deps # flags: {} diff --git a/stack-8.8.yaml b/stack-8.8.yaml index f9565e8b4..d4a9f59a0 100644 --- a/stack-8.8.yaml +++ b/stack-8.8.yaml @@ -9,6 +9,7 @@ packages: extra-deps: - formatting-7.1.3 - prettyprinter-1.7.1 +- posable-1.0.0.1 # Override default flag values for local packages and extra-deps # flags: {} diff --git a/stack-9.0.yaml b/stack-9.0.yaml index 1349abd27..9579a9e81 100644 --- a/stack-9.0.yaml +++ b/stack-9.0.yaml @@ -7,7 +7,9 @@ resolver: nightly-2022-02-16 packages: - . -# extra-deps: [] +extra-deps: +- posable-1.0.0.1 + # Override default flag values for local packages and extra-deps # flags: {}