-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathCategorify.hs
More file actions
1359 lines (1302 loc) · 66.6 KB
/
Copy pathCategorify.hs
File metadata and controls
1359 lines (1302 loc) · 66.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}
-- | The high-level transformation between GHC's `Plugins.CoreExpr` and the abstract categorical
-- representation, as described in [Compiling to
-- Categories](http://conal.net/papers/compiling-to-categories/compiling-to-categories.pdf).
module Categorifier.Core.Categorify
( AutoInterpreter,
categorify,
applyTyAndPredArgs,
isTypeOrPred,
simplifyFun,
)
where
import qualified Categorifier.Core.Benchmark as Bench
import Categorifier.Core.MakerMap
( MakerMapFun',
composeCat,
curryCat,
forkCat,
handleAdditionalArgs,
makeMaker1,
makeMaker2,
splitNameString,
)
import Categorifier.Core.Makers (Makers (..), isCalledIn, isFreeIn)
import qualified Categorifier.Core.PrimOp as PrimOp
import Categorifier.Core.Trace (maybeTraceWith, maybeTraceWithStack, renderSDoc)
import Categorifier.Core.Types
( AutoInterpreter,
CategoricalFailure (..),
CategoryStack,
CategoryState (..),
DictCacheEntry (..),
DictionaryFailure (..),
DictionaryStack,
liftDictionaryStack,
)
import Categorifier.Duoidal (joinD, sequenceD, traverseD, (<*\>), (<=\<), (=<\<))
import qualified Categorifier.GHC.Builtin as Plugins
import qualified Categorifier.GHC.Core as Plugins
import qualified Categorifier.GHC.Data as Plugins
import qualified Categorifier.GHC.Driver as Plugins
import qualified Categorifier.GHC.Types as Plugins
import qualified Categorifier.GHC.Utils as Plugins
import Categorifier.Hierarchy (BaseIdentifiers (..), getLast, pattern Last)
import qualified Categorifier.TH as TH
import Control.Arrow (Arrow ((&&&)))
import Control.Monad (when, (<=<))
import Control.Monad.Extra (loopM, unless)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Except (ExceptT (..), catchE, runExceptT, throwE, withExceptT)
import Control.Monad.Trans.RWS.Strict (ask, gets, local, modify)
import Data.Bifunctor (Bifunctor (..))
import Data.Bitraversable (bitraverse)
import Data.Bool (bool)
import Data.Functor ((<&>))
import Data.Functor.Alt ((<!>))
import Data.Functor.Transformer (tmap)
import Data.Generics.Uniplate.Data (transformBi, transformM, universeBi)
import Data.List.Extra (isPrefixOf, isSuffixOf, notNull, sortOn)
import Data.List.NonEmpty (NonEmpty ((:|)))
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe, mapMaybe)
import qualified Data.Set as Set
import Data.Traversable (for)
import Data.Tuple.Extra (first3)
import PyF (fmt)
import Prelude hiding (head)
-- Need Uniplate for traversals on GHC-provided recursive types
{-# ANN module ("HLint: ignore Avoid restricted module" :: String) #-}
-- | This is named as a pun on `Categorifier.Categorify.expression`, as it's effectively the "real"
-- implementation of that pseudo-function.
--
-- It categorifies the input function, floats `let` bindings out of the result (see
-- `floatLetsOut`), wraps everything in the bindings for the dictionary variables created during
-- categorification, and calls `Bench.displayTimes` when benchmarking is enabled.
categorify ::
-- | Enable debugging
Bool ->
-- | Enable benchmarking
Bool ->
-- | GHC flags, passed on to `simplifyFun`.
Plugins.DynFlags ->
-- | GHC logger, passed on to `simplifyFun`.
Plugins.Logger ->
-- | Target category
Plugins.Type ->
-- | Builds a dictionary for a predicate type; used, e.g., when the constraint held by a
-- @Data.Constraint.Dict@ pattern can't be found among the variables of the scrutinee.
(Plugins.Type -> DictionaryStack Plugins.CoreExpr) ->
BaseIdentifiers ->
-- | Makers used to coerce the head of an application before its type arguments are applied.
Makers ->
-- | Makers used to build the categorical terms of the result.
Makers ->
AutoInterpreter ->
MakerMapFun' ->
-- | Additional boxers for primitive operations, each paired with a C label string (computed
-- from the makers).
(Makers -> [(Plugins.CLabelString, (PrimOp.Boxer, [Plugins.Type], Plugins.Type))]) ->
-- | The @a -> b@ parameter from `Categorifier.Categorify.expression`.
Plugins.CoreExpr ->
-- | The @c a b@ result of `Categorifier.Categorify.expression` (where @c@ represents the arrow of
-- the target category).
CategoryStack Plugins.CoreExpr
categorify
debug
bench
dflags
logger
cat
buildDictionary
baseIdentifiers
baseMakers
makers
tryAutoInterpret
makerMapFun
additionalBoxers
fun = do
-- Categorify the input function (traced when debugging, billed to `Bench.Categorify` when
-- benchmarking).
res0 <-
maybeTraceWith debug (const "---------- categorify ----------")
. Bench.billToUninterruptible bench Bench.Categorify
$ categorifyFun fun
-- It seems GHC simplifier's memory usage is determined by the size of the largest top level
-- bind. Therefore, we float some local let-binds out to the top level, which reduces the
-- size of the largest top-level bind.
--
-- The `markBindNoInline` saves some work by preventing the simplifier from inlining these
-- binds. The simplifier would otherwise attempt to inline some of them as part of
-- `PreInlineUnconditionally` and `PostInlineUnconditionally`.
(res1, fmap markBindNoInline -> binds) <- floatLetsOut res0
-- Non-recursive bindings for the dictionary variables created during categorification, sorted
-- on the second component returned by `getCreatedDictVars`.
-- NOTE(review): presumably that component is a creation order — confirm.
(fmap markBindNoInline -> dictVarBinds) <-
fmap (uncurry Plugins.NonRec . fst) . sortOn snd <$> getCreatedDictVars
-- The dictionary bindings scope over the floated bindings, which in turn scope over the
-- categorified result.
let res =
maybeTraceWith debug (\res' -> [fmt|result size: {show $ Plugins.exprSize res'}|])
. maybeTraceWith debug (thump "result")
. Plugins.mkCoreLets dictVarBinds
$ Plugins.mkCoreLets binds res1
when bench Bench.displayTimes
pure res
where
-- Pull the non-recursive `let` bindings found along application spines, inside cast bodies and in
-- case scrutinees out of the expression, giving each floated binder a fresh name with
-- `uniquifyVarName`. Returns the stripped expression together with the floated bindings, in an
-- order suitable for re-attaching them with `Plugins.mkCoreLets`.
floatLetsOut :: Plugins.CoreExpr -> CategoryStack (Plugins.CoreExpr, [Plugins.CoreBind])
floatLetsOut = hoist []
  where
    -- @hoist acc expr@ floats bindings out of @expr@, threading through the bindings already
    -- collected in @acc@.
    hoist ::
      [Plugins.CoreBind] ->
      Plugins.CoreExpr ->
      CategoryStack (Plugins.CoreExpr, [Plugins.CoreBind])
    hoist acc expr = case expr of
      Plugins.App fn argument -> do
        (fn', accAfterFn) <- hoist acc fn
        first (Plugins.App fn') <$> hoist accAfterFn argument
      Plugins.Cast inner co ->
        first (`Plugins.Cast` co) <$> hoist acc inner
      Plugins.Case scrut bndr ty alts ->
        first (\scrut' -> Plugins.Case scrut' bndr ty alts) <$> hoist acc scrut
      Plugins.Let (Plugins.NonRec v rhs) body -> do
        -- Rename the binder before it leaves its original scope.
        fresh <- uniquifyVarName v
        second (Plugins.NonRec fresh rhs :)
          <$> hoist acc (subst [(v, Plugins.Var fresh)] body)
      -- Anything else is left where it is. In particular we must stop at `Lam`: a `Let` under a
      -- lambda may refer to the lambda-bound variable, so it can't be floated. Floating the
      -- `Let`s that don't (i.e. full laziness) would be possible, but isn't done here.
      _ -> pure (expr, acc)
-- The `OccName` string of the variable at the head of @fun@ (after stripping its arguments with
-- `Plugins.collectArgs`), or `Nothing` when that head is not a variable.
funName :: Maybe String
funName = case fst (Plugins.collectArgs fun) of
  Plugins.Var v -> Just . Plugins.occNameString . Plugins.nameOccName $ Plugins.varName v
  _ -> Nothing
-- "Since we are translating function-typed terms, we can assume that we have an explicit
-- abstraction, /λ(x :: τ) → U/ for some term /U/; otherwise, simply η-expand." ⸻§3
-- Translate a function-typed Core expression into a morphism of the target category. Lambdas go
-- to `categorifyLambda`, casts and ticks are peeled off, and anything else is eta-expanded once
-- and retried.
categorifyFun :: Plugins.CoreExpr -> CategoryStack Plugins.CoreExpr
categorifyFun =
  maybeTraceWithStack debug (thump "fun") $ \case
    Plugins.Lam bndr body -> categorifyLambda bndr body
    castExpr@(Plugins.Cast inner co) -> categorifyCast castExpr inner co
    Plugins.Tick tickish inner -> Plugins.Tick tickish <$> categorifyFun inner
    -- __NB__: `Plugins.etaExpand` may wrap the new `Plugins.Lam` in a `Plugins.Cast` or a
    -- `Plugins.Tick`, so the recursive call goes back through the cases above.
    other -> categorifyFun $ Plugins.etaExpand 1 other -- NON-INDUCTIVE
  where
    -- A top-level cast is hard: inside `categorifyLambda` a coercion can simply be composed, but
    -- here the coercion is applied to a function that must then be categorified. Only simple
    -- coercions are handled: function coercions become coercions of the domain and codomain
    -- composed around the categorified inner function, reflexive ones are dropped, and
    -- transitive ones are split into two nested casts.
    categorifyCast castExpr inner co =
      case co of
        Plugins.FunCo {} ->
          joinD $ wrapWithCoercions <$> functionTypes inner <*\> functionTypes castExpr
        Plugins.Refl {} -> categorifyFun inner
        Plugins.TransCo co1 co2 ->
          categorifyFun $ Plugins.Cast (Plugins.Cast inner co1) co2 -- NON-INDUCTIVE
        unsupported -> throwE . pure $ UnsupportedCast inner unsupported
      where
        -- With @inner :: innerDom -> innerCod@ and the cast at @outerDom -> outerCod@, build
        -- @coerce innerCod outerCod . categorify inner . coerce outerDom innerDom@.
        wrapWithCoercions (innerDom, innerCod) (outerDom, outerCod) =
          joinD $
            composeCat makers
              <$> mkCoerce makers innerCod outerCod
              <*\> joinD
                ( composeCat makers
                    <$> categorifyFun inner
                    <*\> mkCoerce makers outerDom innerDom
                )
    -- Split the type of a function-typed expression into its domain and codomain, failing with
    -- `NotFunTy` otherwise.
    functionTypes expr =
      let ty = Plugins.exprType expr
       in case Plugins.splitFunTy_maybe ty of
            Just domCod -> pure domCod
            Nothing -> throwE . pure $ NotFunTy expr ty
-- Categorify the body of @\name -> body@, translating a body in which @name@ does not occur free
-- directly into a constant morphism (i.e. `categorifyLambda'` with `MakeConst`).
categorifyLambda :: Plugins.Var -> Plugins.CoreExpr -> CategoryStack Plugins.CoreExpr
categorifyLambda = categorifyLambda' MakeConst
-- Categorify the body of the lambda abstraction @\name -> body@ (the second and third
-- arguments), producing a morphism whose domain is the type of @name@.
--
-- When the first argument is `MakeConst`, a body in which @name@ does not occur free is translated
-- straight to a constant morphism; otherwise that shortcut is skipped and the body is handled by
-- the structural cases below. `categorifyLambda` is this function applied to `MakeConst`.
--
-- Many cases rewrite the body (inlining, simplifying, substituting) and recurse on the result;
-- the `NON-INDUCTIVE` markers appear to flag recursive calls whose argument is not a structural
-- subterm.
categorifyLambda' ::
-- Whether a body that doesn't mention the bound variable becomes a constant morphism.
MakeOrIgnoreConst ->
-- The lambda-bound variable.
Plugins.Var ->
-- The lambda body.
Plugins.CoreExpr ->
CategoryStack Plugins.CoreExpr
categorifyLambda' makeOrIgnoreConst name =
maybeTraceWithStack debug (thump "lam" . Plugins.Lam name) $ \case
Plugins.Coercion co -> throwE . pure . UnsupportedDependentType name $ Left co
Plugins.Type ty -> throwE . pure . UnsupportedDependentType name $ pure ty
-- "The remaining case is a constant as abstraction body, i.e., /λx → c/." ⸻§3
body
| MakeConst <- makeOrIgnoreConst,
not (name `isFreeIn` body) ->
( maybe (tryMkConst name) (mkConstFun (Plugins.varType name) . fst)
. Plugins.splitFunTy_maybe
. Plugins.dropForAlls
$ Plugins.exprType body
)
body
makers
-- "First consider the case that the abstraction body is a variable. Since our terms are
-- closed and well-typed, there is only one possible variable choice, so we must have the
-- identity function on /τ/: /(λx → x) ≡ id :: τ → τ/." ⸻§3
Plugins.Var y ->
if name == y
then mkId makers $ Plugins.varType name
else categorifyLambda name =<\< mkInline (Plugins.Var y)
-- `tagToEnum#` in enum comparisons
--
-- Typically we see `tagToEnum#` within some other primitive-related expression.
-- But in some situations, we see @`tagToEnum#` (`==#` <some tag> <another tag>)@.
--
-- We have two options for detecting this situation:
--
-- - right after the `==` for the enum is inlined, when it's a nested series of `case`
-- expressions to primitivise the enums around the comparison
--
-- - right before we try to inline `tagToEnum#` itself, when it's at the outside of the
-- application
--
-- The advantage of the former is that we can keep all the primop stuff inside
-- `case`-handling in this function, but the latter is preferable because at this
-- point, the `let` and `case` bindings have been moved beneath `tagToEnum#` itself,
-- so we can trigger primop replacement without having to look deeper than the top
-- level of the expression.
e@(Plugins.App _ _)
| Just _t2e <- PrimOp.matchTagToEnumApp e ->
handlePrimOps "applying `tagToEnum#'" name e $ Plugins.exprType e
-- "Translating an application (as abstraction body) is a little more involved, involving
-- the /Category/, /Cartesian/, and /Closed/ instances for functions" ⸻§3
e@(Plugins.App (Plugins.collectArgs -> (head, args')) arg) ->
let args = args' <> [arg]
in case head of
Plugins.Var ident
| ident /= name ->
maybe
-- If the expression is a categorical operation, translate it directly.
(interpretVocabulary name e ident)
(categorifyDataCon name e)
(Plugins.isDataConId_maybe ident)
args
-- This handles `Cast from co` where `from` is a dictionary, and `co` is an
-- `AxiomInstCo`. Such a `Cast` most likely comes from single-method classes
-- (for multi-method classes you'd get the benign `$fFoo_$cfoo` stuff). The way
-- we deal with it is to inline `from` to get `Cast (Cast from' co') co`, where
-- `co` cancels `co'`, and `from'` has the right type (same type as
-- `$fFoo_$cfoo`). If the inlined expression is not in this form, we keep
-- inlining, until it is. Then we simply discard `co` and `co'`, and proceed
-- with `from'`.
Plugins.Cast from0 Plugins.AxiomInstCo {}
| Plugins.isPredTy (Plugins.exprType from0) -> do
-- `loopM` keeps going on `Left` and stops on `Right`: we stop once the inlined
-- expression is a cast of something that is no longer a predicate (dictionary).
inlined <- flip loopM from0 $ \from -> do
( \case
Plugins.Cast from' _ ->
if Plugins.isPredTy (Plugins.exprType from')
then Left from'
else Right from'
other -> Left other
)
<$> (simplifyFun dflags logger [] =<\< mkInline from)
categorifyLambda name
=<\< simplifyFun dflags logger [] (Plugins.mkCoreApps inlined args)
-- Convert all the arguments of an application at once.
_
| let (tyArgs, otherArgs) = spanTypes args,
notNull tyArgs -> do
{-
handleExtraArgs can only handle term args, not type args. This is because
it uses mkApply, which basically creates @(a -> b, a) `k` b@. However,
to apply a type arg, e.g., apply @A@ to @forall a. Maybe a@, we need to
create @((a :: *) -> Maybe a, A) `k` Maybe A@. This doesn't align with
mkApply.
The approach taken here is to coerce the CoreExpr of type
@forall a. Maybe a@ into type @Maybe A@.
-}
headCoerced <-
Plugins.App
<$> mkCoerce
baseMakers
(Plugins.exprType head)
(Plugins.exprType (Plugins.mkTyApps head tyArgs))
<*\> pure head
handleExtraArgs makers name otherArgs =<\< categorifyLambda name headCoerced
| otherwise ->
-- If we are dealing with something like
--
-- @
-- (let ... in (let-binding, 0 or more times)
-- case ... of ... -> (single-alt case, 0 or more times)
-- \x y ... -> body) arg1 arg2 ...
-- @
--
-- then instead of simply categorifying `head` and each arg, we check if we
-- should perform any substitutions, i.e., substitute `arg1` for `x`, `arg2`
-- for `y`, etc.
--
-- This is not just an optimization, but is in fact a necessary step.
-- The `simplifyFun []` after inlining is supposed to perform beta-reductions.
-- However, sometimes a `case` expression with a single `DEFAULT` case may
-- prevent beta-reductions from being performed. As a result, here we may be
-- faced with something like
--
-- @
-- (case eq_sel ($p3(%,,%) ($d(%,,%)_aKHi `cast` <Co:25>)) of co_aKND
-- { __DEFAULT -> let ...
-- in \x -> body
-- }) arg
-- @
--
-- i.e., `simplifyFun` failed to substitute `arg` for `x` here, due to the
-- existence of the `case`.
--
-- This could cause problems. For instance, if `arg` is a constant, then an
-- expression that mentions `arg` and doesn't depend on the input can be
-- categorified via `mkConst`. But if we try to categorify `x -> body`,
-- then since `x` is now an argument, anything in `body` that mentions `x`
-- cannot be categorified via `mkConst`. Therefore, in this case, we must
-- substitute `arg` for `x`.
let (xs, bndrs, body) = collectNestedBinders head
(newBody, remainingBndrs, remainingArgs) = substBndrs body bndrs args
-- If no binder was consumed by `substBndrs`, no substitution happened, so we
-- categorify the unchanged head and apply the arguments afterwards.
in if length bndrs == length remainingBndrs
then handleExtraArgs makers name args =<\< categorifyLambda name head
else
categorifyLambda name . addLetsAndCases xs $
Plugins.mkCoreApps
(Plugins.mkCoreLams remainingBndrs newBody)
remainingArgs
-- "If the body of an abstraction is an abstraction, we can curry a translation of the
-- uncurried form:" ⸻§3
Plugins.Lam name' body -> do
-- A fresh pair-typed variable replaces both binders: `name` becomes its first
-- projection and `name'` its second.
let pair =
freshId
(Plugins.exprFreeVars body)
"pair"
(Plugins.mkBoxedTupleTy [Plugins.varType name, Plugins.varType name'])
sub <-
traverseD
sequenceD
[(name, mkFst makers (Plugins.Var pair)), (name', mkSnd makers (Plugins.Var pair))]
curryCat makers =<\< categorifyLambda pair (subst sub body)
-- The @unsafeBinder@ and @unsafeAlts@ are unsafe because they are not necessarily unique.
--
-- The lack of uniqueness of @unsafeBinder@ can be observed by building
-- //code_generation/generate:CalcTracking on 15465/10, where you'd see
--
-- @
-- lam:
-- \ (wild_aAtb :: Type1) (wild_aAtb :: Type2) -> ...
-- @
--
-- which leads to a core lint error. To get a unique binder, use @withBinder@.
--
-- The lack of uniqueness of @unsafeAlts@ can be observed by building
-- //code_generation/generate:BallFollower on 15833/2, where you'd see @x_azfm@
-- and @y_azfn@ used repeatedly, causing shadowing.
caseExpr@(Plugins.Case scrut (Plugins.zapIdOccInfo -> unsafeBinder) typ unsafeAlts) -> do
-- Give every variable bound by an alternative a fresh name, substituting it in that
-- alternative's right-hand side.
alts <- for unsafeAlts $ \(Plugins.Alt altCon unsafeBoundVars rhs) -> do
boundVars <- traverse uniquifyVarName unsafeBoundVars
pure . Plugins.Alt altCon boundVars $
subst (zip unsafeBoundVars $ fmap Plugins.Var boundVars) rhs
-- `withBinder f` categorifies @let binder = scrut in <result of f binder>@.
let withBinder f = do
binder <-
-- If @unsafeBinder@ occurs in any of the @Alt@s, we don't bother making
-- a new unique binder, because if it is not unique, we don't know which
-- one the occurrence refers to.
if any (\(Plugins.Alt _ _ rhs) -> isFreeIn unsafeBinder rhs) alts
then pure unsafeBinder
else uniquifyVarName unsafeBinder
categorifyLambda name . Plugins.Let (Plugins.NonRec binder scrut) =<\< f binder
case alts of
-- "For __case__ expressions, suppose the scrutinee expression has a product type"
-- ⸻§3
[Plugins.Alt (Plugins.DataAlt dc) [a, b] rhs] | Plugins.isTupleDataCon dc ->
withBinder $ \binder -> do
-- Only bind the components that the right-hand side actually uses.
bindFst <- do
if a `isFreeIn` rhs
then Plugins.Let . Plugins.NonRec a <$> mkFst makers (Plugins.Var binder)
else pure id
bindSnd <-
if b `isFreeIn` rhs
then Plugins.Let . Plugins.NonRec b <$> mkSnd makers (Plugins.Var binder)
else pure id
pure $ bindFst (bindSnd rhs)
-- "Distributive categories enable translation of definition by cases. Consider
-- only __case__ over binary sums /a + b/ for now." ⸻§8
[ Plugins.Alt (Plugins.DataAlt left) [a] lrhs,
Plugins.Alt (Plugins.DataAlt right) [b] rrhs
]
| Plugins.dataConName left == Plugins.leftDataConName
&& Plugins.dataConName right == Plugins.rightDataConName ->
withBinder $
mkEither makers (Plugins.Lam a lrhs) (Plugins.Lam b rrhs) . Plugins.Var
-- @if@ is represented in Core as a @case@ on `Bool`.
--
-- The result is @if . (scrut &&& (rhsT &&& rhsF))@, with both branches categorified.
-- NOTE(review): unlike the product and sum cases above, this case does not go through
-- `withBinder`, so an occurrence of @unsafeBinder@ in @rhsT@ or @rhsF@ would be left
-- unbound — confirm GHC never produces such an occurrence here.
[ Plugins.Alt (Plugins.DataAlt false) [] rhsF,
Plugins.Alt (Plugins.DataAlt true) [] rhsT
]
| false == Plugins.falseDataCon && true == Plugins.trueDataCon ->
joinD $
composeCat makers
<$> mkIf makers typ
<*\> joinD
( forkCat makers
<$> categorifyLambda name scrut
<*\> joinD
( forkCat makers
<$> categorifyLambda name rhsT
<*\> categorifyLambda name rhsF
)
)
-- @Data.Constraint.Dict@ contains a constraint, so it can't have a
-- @HasRep@ instance. Here we handle it as a special case.
[Plugins.Alt (Plugins.DataAlt dc) [v] rhs]
| isDictDataCon dc,
let predTy = Plugins.varType v,
Plugins.isPredTy predTy ->
-- Look for the constraint among the variables of the scrutinee; if none of them
-- provides it, build the dictionary instead.
let findPred [] = liftDictionaryStack predTy scrut $ buildDictionary predTy
findPred (x : xs) =
findTypeInTuple makers predTy (Plugins.Var x) >>= maybe (findPred xs) pure
in do
expr <- findPred $ universeBi scrut
categorifyLambda name $ Plugins.Let (Plugins.NonRec v expr) rhs
-- "One more transformation eliminates the unboxing __case__ scrutinees: transform
-- an expression like “__case__ /a/ __of__ /I# x/ → ... /boxI x/ ...” to “__let__
-- /x′/ = /a/ __in__ ... /x′/ ...”." ⸻§10.1
--
-- Here we are matching a case-statement that matches apart a boxed value into an
-- unboxed (primitive) one, so that @_rhs@ contains some unboxed operations. This
-- type of destructuring bind will eventually be removed in case 3 of
-- `Plugins.checkForUnboxedVars` below.
[Plugins.Alt (Plugins.DataAlt con) [_unboxedV] _rhs]
-- We look for the boxing constructor for the type that this case statement
-- returns.
| con `elem` fmap snd primConMap ->
handlePrimOps "unboxing" name (Plugins.Case scrut unsafeBinder typ alts) typ
-- This case handles calls to `fromInteger` and `fromIntegral` at the top level.
-- It is separate from the preceding case because it involves reboxing rather than
-- unboxing a primitive (the unboxing of the argument typically occurs within the
-- right-hand side of the case alternative); its contents must still be handled
-- with `replacePrimOps`.
[Plugins.Alt Plugins.DEFAULT [] rhs]
| Plugins.isCoVar unsafeBinder ->
if unsafeBinder `isFreeIn` rhs
then do
binder <- uniquifyVarName unsafeBinder
res <-
categorifyLambda name $
transformBi (\case v | v == unsafeBinder -> binder; other -> other) rhs
-- Here we are making a `Plugins.Case`, rather than a `Plugins.Let`
-- (as `withBinder` would do), because core lint complains
-- "bad `let` binding" when a let-binding has a coercion type.
pure $
Plugins.Case
scrut
binder
(Plugins.exprType res)
[Plugins.Alt Plugins.DEFAULT [] res]
else categorifyLambda name rhs
-- `fromInteger`
| Just toTy <- PrimOp.matchOnUniverse PrimOp.matchFloatFromIntegralApp scrut,
Just _dc <- PrimOp.matchOnUniverse (PrimOp.matchBoxingApp primConMap) rhs ->
-- This is the original case at the top level of `categorifyLambda`.
handlePrimOps
"fromInteger"
name
(Plugins.Case scrut unsafeBinder typ alts)
toTy
-- `fromIntegral`
| Just i2iApp <- PrimOp.matchOnUniverse PrimOp.matchIntegerToIntApp scrut,
Just _toIApp <- PrimOp.matchOnUniverse PrimOp.matchToIntegerApp i2iApp,
Just _dc <- PrimOp.matchOnUniverse (PrimOp.matchBoxingApp primConMap) rhs ->
-- This is the original case at the top level of `categorifyLambda`.
handlePrimOps
"fromIntegral"
name
(Plugins.Case scrut unsafeBinder typ alts)
typ
-- Also need to handle the unit case. This pattern matches any remaining single
-- alternative that binds no variables; its right-hand side is categorified under a
-- `let` of the scrutinee (via `withBinder`).
[Plugins.Alt _ [] rhs] -> withBinder $ \_binder -> pure rhs
-- When the scrut's type is a constraint (e.g., `Num (C Double)`), we must
-- specialize the whole case expression, because constraints don't have `HasRep`
-- instances. This is achieved by simplifying it with `Inline` and `Rules`.
[Plugins.Alt {}]
| Plugins.isPredTy (Plugins.varType unsafeBinder) ->
categorifyLambda name
=<\< simplifyFun dflags logger [Plugins.Inline, Plugins.Rules] caseExpr
-- "consider a __case__ expression /case scrut of { p1 → rhs1; ...; pn → rhsn }/,
-- where (the scrutinee) /scrut/ has a non-standard type with a /HasRep/
-- instance. Rewrite /scrut/ to /inline abst (repr scrut)/ (this time inlining
-- /abst/ instead of /repr/). GHC’s usual simplifications will then replace the
-- __case__ over a non-standard type with a __case__ over a standard type or one
-- closer to standard." ⸻§9
-- __NB__: This case can pretty easily cause an infinite loop, so we should be
-- very careful with handling the product and coproduct @case@ cases.
_ -> do
let bindTy =
maybeTraceWith
debug
( \bt ->
[fmt|case> fallback:
scrut: {dbg scrut}
binder: {dbg unsafeBinder}
binder type: {dbg bt}
type: {dbg typ}
alts: {dbg alts}|]
)
-- scrut type sometimes differs from binder type, e.g.,
--
-- scrut type: @Vec ('S ('S ('S 'Z))) (PressureSensorStatus C)@
-- binder type: @Vec n1 (PressureSensorStatus C)@
--
-- In this case we can't use binder type.
$ Plugins.exprType scrut
abst <- inlineHasRep =<\< mkAbst makers bindTy
repr <- mkRepr makers bindTy
-- NON-INDUCTIVE
-- Here we expect `simplifyFun` to apply the `let`-substitution, case-of-case,
-- and case-of-known-constructor transformations.
categorifyLambda name <=\< simplifyFun dflags logger [Plugins.CaseOfCase] $
Plugins.Case (Plugins.App abst (Plugins.App repr scrut)) unsafeBinder typ alts
Plugins.Let bind expr -> case bind of
Plugins.NonRec v rhs ->
if not (name `isFreeIn` rhs)
&&
-- Don't float out join points, because doing so may cause errors like this:
--
-- ghc: panic! (the 'impossible' happened)
-- (GHC version 8.10.1:
-- GHC.StgToCmm.Env: variable not found
-- $j_sdPJ
not (Plugins.isJoinId v)
then -- Float bindings outside of lambdas when possible. This is an optimization,
-- but more importantly it prevents us from trying to inline type class
-- dictionaries (which GHC does not want to do) by moving them outside the term
-- being categorified. E.g.,
--
-- > categorify $ \x -> let $pIsPrimitive = ... in myFun $pIsPrimitive x
--
-- becomes
--
-- > let $pIsPrimitive = ... in categorify $ \x -> myFun $pIsPrimitive x
--
-- rather than
--
-- > categorify $ \x' -> (\(x, $pIsPrimitive) -> myFun $pIsPrimitive x) (x', ...)
--
-- This also stores the binding in case we need to look it up and categorify it
-- later (e.g., in the case of join points).
fmap (Plugins.Let bind) . tmap (local (Map.insert v rhs)) $
categorifyLambda name expr
else -- Either substitute the @rhs@ in @expr@, or rewrite as a lambda.
-- Whether substituting or rewriting is determined by the number of occurrences of
-- `v` in `expr`, which we can obtain from `OccInfo` or by whether `expr` consists
-- only of projections.
--
-- NOTE: The rationale is as follows. @let@ bindings can always be desugared to an
-- application of a lambda, which potentially involves fewer common subexpressions.
-- However, because of the handling of nested lambdas from the paper, that can
-- sometimes lead to an explosion in argument size, eventually resulting in
-- additional curry/uncurry calls and slowing down the process of @buildDictionary@.
-- So here we instead substitute the bound term in some cases, most notably when
-- the term consists only of projections from @name@.
-- TODO (#22): currently the v's `OccInfo` is always `ManyOccs`, probably
-- because we forget to `zapIdOccInfo` somewhere. If the `OccInfo` is accurate,
-- we can obtain `isManyOccs` from it rather than manually counting.
let isManyOccs = case filter (== v) $ universeBi expr of
_ : _ : _ -> True
_ -> False
in if Plugins.isJoinId v
|| not isManyOccs
|| hasOnlyProjections (const True) rhs
|| Plugins.isPredTy (Plugins.varType v)
|| any isDictOrBarbiesDictTyCon (universeBi (Plugins.varType v))
|| Plugins.isForAllTy (Plugins.varType v)
then
categorifyLambda name
=<\< bool
pure
-- If `v` has a polymorphic type, we run the simplifier
-- to apply the type argument(s).
(simplifyFun dflags logger [])
(Plugins.isForAllTy (Plugins.varType v))
(subst [(v, rhs)] expr)
else
categorifyLambda
name
(Plugins.App (Plugins.Lam v expr) rhs) -- NON-INDUCTIVE
-- A self-recursive binding is turned into a non-recursive one via `unfix`.
Plugins.Rec [(v, rhs)] -> do
nonRec <- Plugins.NonRec v <$> unfix v [] rhs
categorifyLambda name $ Plugins.Let nonRec expr
Plugins.Rec binds -> throwE . pure $ UnsupportedMutuallyRecursiveLetBindings binds
-- A cast body is categorified and then post-composed with a coercion from the
-- uncast type to the cast type.
to@(Plugins.Cast from _) ->
joinD $
composeCat makers
<$> mkCoerce makers (Plugins.exprType from) (Plugins.exprType to)
<*\> categorifyLambda name from
Plugins.Tick tickish body -> Plugins.Tick tickish <$> categorifyLambda name body
-- This case is covered by "constant as abstraction body", but hard to convince GHC of
-- that, so we duplicate the relevant logic here.
Plugins.Lit lit -> mkConst' makers (Plugins.varType name) (Plugins.Lit lit)
-- Beta-reduce as many leading binders as possible by substituting the matching arguments into
-- @body@, returning the rewritten body along with the binders and arguments left unconsumed.
--
-- Only `Plugins.Var` arguments are substituted. A non-variable argument would need checks such as
-- how often the binder occurs in @body@ to decide whether substituting pays off, and that
-- overhead isn't worth it here. Substitution stops at the first non-variable argument, or as soon
-- as either list is exhausted.
substBndrs ::
  Plugins.CoreExpr ->
  [Plugins.Var] ->
  [Plugins.CoreExpr] ->
  (Plugins.CoreExpr, [Plugins.Var], [Plugins.CoreExpr])
substBndrs body bndrs args =
  case (bndrs, args) of
    (bndr : laterBndrs, varArg@(Plugins.Var _) : laterArgs) ->
      substBndrs (subst [(bndr, varArg)] body) laterBndrs laterArgs
    _ -> (body, bndrs, args)
-- Categorify @e@, an application of the data constructor @dc@ to @args@, under the lambda-bound
-- variable @name@.
--
-- `Data.Constraint.Dict`, `Left`, `Right` and the boxed pair constructor (the constructors
-- expected in `Categorifier.Client.Rep` terms) are translated directly; any other constructor goes
-- through its `HasRep` instance.
categorifyDataCon ::
  Plugins.Var ->
  Plugins.CoreExpr ->
  Plugins.DataCon ->
  [Plugins.CoreExpr] ->
  CategoryStack Plugins.CoreExpr
categorifyDataCon name e dc args =
  case (splitNameString $ Plugins.dataConName dc, args) of
    ((Just "Data.Constraint", "Dict"), [Plugins.Type _, constraint]) ->
      -- NOTE(review): the constant built here is the constraint argument rather than @e@
      -- itself; confirm this is the type the caller expects.
      mkConst' makers (Plugins.varType name) constraint
    ((Just "Data.Either", "Left"), Plugins.Type a : Plugins.Type b : rest) ->
      injection mkInl a b rest
    ((Just "Data.Either", "Right"), Plugins.Type a : Plugins.Type b : rest) ->
      injection mkInr a b rest
    ((Just "GHC.Tuple", "(,)"), Plugins.Type a : Plugins.Type b : rest) ->
      (makeMaker2 makers (categorifyLambda name) e rest <=\< mkId makers) $
        Plugins.mkBoxedTupleTy [a, b]
    _ -> viaHasRep
  where
    -- Feed the categorified remaining arguments to the sum injection built by @mkInj@.
    injection mkInj a b rest =
      makeMaker1 makers (categorifyLambda name) rest =<\< mkInj makers a b
    -- "Given a saturated constructor application /Con e1...en/, rewrite it to /abst (inline
    -- repr (Con e1...en))/, where /inline e/ tells GHC’s simplifier to inline the expression
    -- /e/." ⸻§9
    --
    -- A partial application is eta-expanded to saturation first, and the introduced binders are
    -- re-abstracted around the rewritten body.
    viaHasRep = do
      let valueArgCount = length $ filter (not . Plugins.isTypeArg) args
          (lamBndrs, saturated) =
            Plugins.collectBinders $
              Plugins.etaExpand (Plugins.dataConRepArity dc - valueArgCount) e
          saturatedTy = Plugins.exprType saturated
      abst <- mkAbst makers saturatedTy
      repr <- inlineHasRep =<\< mkRepr makers saturatedTy
      -- NON-INDUCTIVE
      -- Here we expect `simplifyFun` to apply the `let`-substitution and
      -- case-of-known-constructor transformations.
      categorifyLambda name
        <=\< simplifyFun dflags logger []
        . Plugins.mkLams lamBndrs
        . Plugins.App abst
        $ Plugins.App repr saturated
-- `HasRep` is special to the plugin. We need to ensure the operations /don't/ inline in some
-- cases and that they are /fully/ inlined in others. We wrap the methods in functions so we
-- can prevent specialization, then here we inline twice -- the first inlines the function to
-- the method, and the second inlines the method. This should work as long as the instances
-- are defined correctly (which should be the case, since it's rare to have to define an
-- instance that wouldn't be provided by `deriveHasRep`). Any `Cast` left on top after that is
-- handled by `inlineCast`.
inlineHasRep :: Plugins.CoreExpr -> CategoryStack Plugins.CoreExpr
inlineHasRep wrapped = mkInline wrapped >>= mkInline >>= inlineCast
  where
    -- `inlineCast` is needed to deal with casts resulting from type family definitions. For
    -- example, when we inline `abst` for `KSum2 C () (TimedSensor (Msg MotorStatusMsg C) C)`,
    -- we expect to get something like
    --
    -- ```
    -- \a -> case a of (b, c) -> case c of (d, e) -> UnsafeSum2 b d e
    -- ```
    --
    -- But since we have the following type instance:
    --
    -- ```
    -- type instance Msg MotorStatusMsg f = CobsSensor (MotorStatusMsg f) f
    -- ```
    --
    -- without `inlineCast`, what we get instead is
    --
    -- ```
    -- $fHasRepKSum2_$cabst `cast`
    --   (<Co:19> :: Rep (KSum2 C () (TimedSensor (CobsSensor (MotorStatusMsg C) C) C)) ->
    --               KSum2 C () (TimedSensor (CobsSensor (MotorStatusMsg C) C) C)
    --             ~R# Rep (KSum2 C () (TimedSensor (Msg MotorStatusMsg C) C)) ->
    --               KSum2 C () (TimedSensor (Msg MotorStatusMsg C) C)
    --   )
    -- ```
    --
    -- Because of this extra `cast`, `$fHasRepKSum2_$cabst` isn't inlined, causing infinite
    -- looping.
    --
    -- So `inlineCast` simply drops the outermost `Coercion` and tries inlining once more, via
    -- `mkInline'`. Core lint would complain about this (one doesn't simply drop
    -- `Coercion`s), but otherwise it should be harmless. Expressions that aren't a `Cast`
    -- are returned unchanged.
    inlineCast :: Plugins.CoreExpr -> CategoryStack Plugins.CoreExpr
    inlineCast (Plugins.Cast castee _) =
      -- If `castee` can't be inlined we return it as-is rather than throwing an error,
      -- because for a type like `data Foo = Foo (C Word8)`, we have
      --
      -- ```
      -- $fHasRepFoo_$cabst = Foo `cast`
      --   (<Co:5> :: (C Word8 -> Foo) ~R# (Rep Foo -> Foo))
      -- ```
      --
      -- and since we can't inline `Foo`, we need to return it. The failure handler below
      -- returns its first argument and ignores its second.
      mkInline' (\unchanged _ -> pure unchanged) pure castee
    inlineCast other = pure other
-- Checks whether the input is an application of a `Plugins.Id` that we know how to map
-- directly to a categorical representation, by splitting the identifier's name into its
-- module and base name and handing everything to `findMaker` (see there for the
-- `Plugins.Id`s that we recognize).
interpretVocabulary ::
  Plugins.Var ->
  Plugins.CoreExpr ->
  Plugins.Var ->
  [Plugins.CoreExpr] ->
  CategoryStack Plugins.CoreExpr
interpretVocabulary n expr name args =
  findMaker makers n name expr varName args moduleName
  where
    -- A missing module is replaced by "".
    -- TODO: When doesn't a name have a module? What should we do in those cases?
    (moduleName, varName) =
      first (fromMaybe "") . splitNameString . Plugins.varName $
        maybeTraceWith debug (thump "interpreting" . Plugins.WithIdInfo) name
-- Runs `handleAdditionalArgs m` with `categorifyLambda` specialised to the given lambda-bound
-- variable.
handleExtraArgs m lambdaVar = handleAdditionalArgs m (categorifyLambda lambdaVar)
-- Renders `term` for debug output as "label: term", with the term nested two columns when
-- `sep` has to lay the two pieces out vertically.
thump :: (Plugins.Outputable a) => Plugins.SDoc -> a -> String
thump label term = renderSDoc dflags labelled
  where
    labelled = Plugins.sep [label Plugins.<> ":", Plugins.nest 2 (Plugins.ppr term)]
-- This is where we enumerate the applications of identifiers that have a categorical
-- representation.
--
-- The general approach is this:
-- 1. "apply" enough to get to a single morphism in the target category.
--    - for most functions, this means do nothing, but ones like `fmap` and `either` take
--      one or more morphisms in the target category before returning a morphism in the
--      target category
-- 2. call @maker/n/@ on it where /n/ is the number of parameters (curried in __Hask__, tupled
--    in the target) in the final morphism.
--    - e.g., @(`+`) :: a -> a -> a@ becomes @`plusV` :: c (a, a) a@, so you call @maker2@,
--      not to be confused with @`fst` :: (a, b) -> a@ (becoming @`exlV` :: c (a, b) a@), which
--      would call @maker1@.
findMaker ::
  Makers ->
  -- Lambda-bound var
  Plugins.Var ->
  -- The var being interpreted
  Plugins.Var ->
  -- The expression where the var being interpreted is the head
  Plugins.CoreExpr ->
  -- The base name of the var being interpreted
  String ->
  -- The arguments the var being interpreted is applied to
  [Plugins.CoreExpr] ->
  -- The module of the var being interpreted (`interpretVocabulary` passes "" when there is
  -- none)
  String ->
  CategoryStack Plugins.CoreExpr
findMaker m@Makers {..} n target expr var args modu =
  -- If we can't find a matching case to interpret, we fall back through a few cases:
  --
  -- 1. try to interpret a specialized version
  -- 2. look for a separate categorification of the function
  -- 3. auto-interpret
  -- 4. inline the function
  fromMaybe (maybe nonSpecializedFallback pure =<< interpretSpecialized)
    . (($ args) <=< Map.lookup (modu, var))
    $ Map.mapKeys (fromMaybe "" . TH.nameModule &&& TH.nameBase) makerMap
  where
    -- Fallbacks 2-4, used when `interpretSpecialized` doesn't apply.
    nonSpecializedFallback
      -- This checks whether we're categorifying this exact function. If so, we can't take
      -- advantage of `mkNative'` for separate categorification because it'll loop. Errors
      -- are wrapped in `Last`, so if both alternatives fail, the error from the second one
      -- is the one reported.
      | funName /= Just var =
          withExceptT getLast $
            withExceptT Last mkNative' <!> withExceptT Last autoInterpretOrInline
      | otherwise = autoInterpretOrInline
    -- Fallbacks 3-4: use the automatic interpretation if there is one, otherwise inline the
    -- function and categorify the result. (Previously duplicated in both branches above.)
    autoInterpretOrInline =
      maybe
        -- __FIXME__: This call to @`simplifyFun` []@ causes a lot of our
        -- specialization woes. E.g., prior to this, we'll have an
        -- expression containing `<*>`, `fmap`, etc., but after the
        -- call (despite not passing `Rules`), we have
        -- `$fApplicativeFoo_$c<*>`, `$fFunctorFoo_$cfmap`, etc.
        -- However, if we /don't/ simplify here, then we get
        -- complaints about missing dictionaries elsewhere.
        (categorifyLambda n =<< simplifyFun dflags logger [] =<< mkInline expr)
        (maker1 (dropWhile isTypeOrPred args))
        =<< tryAutoInterpret'
-- Every identifier with a direct categorical interpretation (built by `makerMapFun`), keyed by
-- its Template Haskell name. Each entry takes the arguments and returns `Nothing` when it
-- can't interpret them, which sends `findMaker` to its fallbacks.
makerMap ::
  Map TH.Name ([Plugins.CoreExpr] -> Maybe (CategoryStack Plugins.CoreExpr))
makerMap =
  makerMapFun dflags logger m n target expr cat var args modu categorifyFun $
    categorifyLambda n
-- `makeMaker1` for morphisms with one (tupled) parameter in the target category.
maker1 valueArgs = makeMaker1 m (categorifyLambda n) valueArgs
-- `makeMaker2` (given the whole expression) for morphisms with two tupled parameters.
maker2 valueArgs = makeMaker2 m (categorifyLambda n) expr valueArgs
-- Fallback 2: refer to a separately categorified version of the target, tagged with the
-- type-level string "modu.var". Throws `NotFunTy` if the target, once applied to its type and
-- predicate arguments, doesn't have a function type.
mkNative' = do
  let applied = fst $ applyTyAndPredArgs Plugins.Var (Plugins.Var target) args
      appliedTy = Plugins.exprType applied
  (argTy, resTy) <-
    maybe (throwE . pure $ NotFunTy applied appliedTy) pure $
      Plugins.splitFunTy_maybe appliedTy
  native <- mkNative (Plugins.LitTy $ Plugins.StrTyLit [fmt|{modu}.{var}|]) argTy resTy
  maker1 (dropWhile isTypeOrPred args) native
-- Fallback 1: interpret a class method that GHC has specialized to a particular instance.
-- Such bindings are named like `$fEqInt_$c==`, so the `$f<Class>` prefix picks the
-- interpreter, which then looks at the method suffix. Returns `Nothing` for anything else.
interpretSpecialized :: CategoryStack (Maybe Plugins.CoreExpr)
interpretSpecialized =
  case [interpret | (classPrefix, interpret) <- specializers, classPrefix `isPrefixOf` var] of
    interpret : _ -> interpret specializedTy
    [] -> pure Nothing
  where
    specializers =
      [ ("$fEq", interpretEqSpecialized),
        ("$fFoldable", interpretFoldableSpecialized),
        ("$fIntegral", interpretIntegralSpecialized),
        ("$fNum", interpretNumSpecialized),
        ("$fOrd", interpretOrdSpecialized)
      ]
    -- The method's type once its type and predicate arguments have been applied.
    specializedTy =
      Plugins.exprType . fst $
        applyTyAndPredArgs Plugins.Var (Plugins.Var target) args
-- Specialized `Eq` methods: only `==` is recognized. It becomes `mkEqual` at the method's
-- argument type, finished with `maker2`.
interpretEqSpecialized :: Plugins.Type -> CategoryStack (Maybe Plugins.CoreExpr)
interpretEqSpecialized methodTy
  | "$c==" `isSuffixOf` var = do
      argTy <-
        maybe (throwE . pure $ NotTyConApp "interpreting $c==" methodTy) pure $
          extractTypeFromFunTy [ArgTy] methodTy
      Just <$> (mkEqual argTy >>= maker2 (dropWhile isTypeOrPred args))
  | otherwise = pure Nothing
-- Specialized `Foldable` methods: `maximum` and `minimum`. Both split the method's argument
-- type `t a` into its container `t` and element `a`, then finish with `maker1`.
interpretFoldableSpecialized :: Plugins.Type -> CategoryStack (Maybe Plugins.CoreExpr)
interpretFoldableSpecialized methodTy
  | "$cmaximum" `isSuffixOf` var = overContainer "interpreting $cmaximum" mkMaximum
  | "$cminimum" `isSuffixOf` var = overContainer "interpreting $cminimum" mkMinimum
  | otherwise = pure Nothing
  where
    overContainer label mkMethod = do
      (container, element) <-
        maybe (throwE . pure $ NotTyConApp label methodTy) pure $
          Plugins.splitAppTy_maybe =<< extractTypeFromFunTy [ArgTy] methodTy
      Just <$> (mkMethod container element >>= maker1 (dropWhile isTypeOrPred args))
-- Specialized `Integral` methods: `div` and `mod`. Both are binary and built at the method's
-- argument type, finished with `maker2`.
interpretIntegralSpecialized :: Plugins.Type -> CategoryStack (Maybe Plugins.CoreExpr)
interpretIntegralSpecialized methodTy
  | "$cdiv" `isSuffixOf` var = binaryIntegral "interpreting $cdiv" mkDiv
  | "$cmod" `isSuffixOf` var = binaryIntegral "interpreting $cmod" mkMod
  | otherwise = pure Nothing
  where
    binaryIntegral label mkMethod = do
      argTy <-
        maybe (throwE . pure $ NotTyConApp label methodTy) pure $
          extractTypeFromFunTy [ArgTy] methodTy
      Just <$> (mkMethod argTy >>= maker2 (dropWhile isTypeOrPred args))
-- Specialized `Num` methods:
--   * `+` is built at the method's argument type and finished with `maker2`;
--   * `fromInteger` is built at the method's result type and finished with `maker1`.
interpretNumSpecialized :: Plugins.Type -> CategoryStack (Maybe Plugins.CoreExpr)
interpretNumSpecialized methodTy
  | "$c+" `isSuffixOf` var = do
      argTy <- numComponentTy ArgTy "interpreting $c+"
      Just <$> (mkPlus argTy >>= maker2 numValueArgs)
  | "$cfromInteger" `isSuffixOf` var = do
      resTy <- numComponentTy ResTy "interpreting $cfromInteger"
      Just <$> (mkFromInteger resTy >>= maker1 numValueArgs)
  | otherwise = pure Nothing
  where
    numValueArgs = dropWhile isTypeOrPred args
    numComponentTy position label =
      maybe (throwE . pure $ NotTyConApp label methodTy) pure $
        extractTypeFromFunTy [position] methodTy
-- Specialized `Ord` methods: `<`, `>`, `<=`, `>=`, `max` and `min`. All of them are binary, so
-- each is built at the method's argument type and finished with `maker2`. Suffix matching
-- keeps `$c<` from matching `$c<=` (and likewise for `>`).
interpretOrdSpecialized :: Plugins.Type -> CategoryStack (Maybe Plugins.CoreExpr)
interpretOrdSpecialized methodTy
  | "$c<" `isSuffixOf` var = binaryOrd "interpreting $c<" mkLT
  | "$c>" `isSuffixOf` var = binaryOrd "interpreting $c>" mkGT
  | "$c<=" `isSuffixOf` var = binaryOrd "interpreting $c<=" mkLE
  | "$c>=" `isSuffixOf` var = binaryOrd "interpreting $c>=" mkGE
  | "$cmax" `isSuffixOf` var = binaryOrd "interpreting $cmax" mkMax
  | "$cmin" `isSuffixOf` var = binaryOrd "interpreting $cmin" mkMin
  | otherwise = pure Nothing
  where
    binaryOrd label mkMethod = do
      argTy <-
        maybe (throwE . pure $ NotTyConApp label methodTy) pure $
          extractTypeFromFunTy [ArgTy] methodTy
      Just <$> (mkMethod argTy >>= maker2 (dropWhile isTypeOrPred args))
-- Fallback 3: ask `tryAutoInterpret` for an automatic interpretation of `target`, at its type
-- once its type and predicate arguments have been applied. When one is found, the result is
-- passed through `maybeTraceWith debug` (presumably emitting a trace only when `debug` is set).
tryAutoInterpret' =
  fmap (fmap traceAutoInterpreted) $
    tryAutoInterpret
      buildDictionary
      cat
      appliedTargetTy
      target
      (takeWhile isTypeOrPred args)
  where
    traceAutoInterpreted =
      maybeTraceWith debug (const [fmt|Automatically interpreted {dbg target}|])
    appliedTargetTy =
      Plugins.exprType . fst $ applyTyAndPredArgs Plugins.Var (Plugins.Var target) args