Skip to content

Commit 3f1a7df

Browse files
Cata tactic should generalize let and ensure unifiability (haskell#1938)
* Let must be generalized when doing a letForEach * Don't attempt to cata terms which don't unify with any args * Add a test Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 813e5ff commit 3f1a7df

File tree

6 files changed

+107
-4
lines changed

6 files changed

+107
-4
lines changed

plugins/hls-tactics-plugin/src/Wingman/CodeGen.hs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import GhcPlugins (isSymOcc, mkVarOccFS)
3333
import OccName (occName)
3434
import PatSyn
3535
import Type hiding (Var)
36+
import TysPrim (alphaTy)
3637
import Wingman.CodeGen.Utils
3738
import Wingman.GHC
3839
import Wingman.Judgements
@@ -309,7 +310,8 @@ letForEach rename solve (unHypothesis -> hy) jdg = do
309310
let g = jGoal jdg
310311
terms <- fmap sequenceA $ for hy $ \hi -> do
311312
let name = rename $ hi_name hi
312-
res <- tacticToRule jdg $ solve hi
313+
let generalized_let_ty = CType alphaTy
314+
res <- tacticToRule (withNewGoal generalized_let_ty jdg) $ solve hi
313315
pure $ fmap ((name,) . unLoc) res
314316
let hy' = fmap (g <$) $ syn_val terms
315317
matches = fmap (fmap (\(occ, expr) -> valBind (occNameToStr occ) expr)) terms

plugins/hls-tactics-plugin/src/Wingman/Machinery.hs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import Control.Applicative (empty)
66
import Control.Lens ((<>~))
77
import Control.Monad.Error.Class
88
import Control.Monad.Reader
9-
import Control.Monad.State.Class (gets, modify)
9+
import Control.Monad.State.Class (gets, modify, MonadState)
1010
import Control.Monad.State.Strict (StateT (..), execStateT)
1111
import Control.Monad.Trans.Maybe
1212
import Data.Coerce
@@ -217,6 +217,20 @@ unify goal inst = do
217217
Nothing -> throwError (UnificationError inst goal)
218218

219219

220+
------------------------------------------------------------------------------
221+
-- | Attempt to unify two types.
222+
canUnify
223+
:: MonadState TacticState m
224+
=> CType -- ^ The goal type
225+
-> CType -- ^ The type we are trying unify the goal type with
226+
-> m Bool
227+
canUnify goal inst = do
228+
skolems <- gets ts_skolems
229+
case tryUnifyUnivarsButNotSkolems skolems goal inst of
230+
Just _ -> pure True
231+
Nothing -> pure False
232+
233+
220234
------------------------------------------------------------------------------
221235
-- | Prefer the first tactic to the second, if the bool is true. Otherwise, just run the second tactic.
222236
--
@@ -312,6 +326,17 @@ lookupNameInContext name = do
312326
Nothing -> empty
313327

314328

329+
getDefiningType
330+
:: (MonadError TacticError m, MonadReader Context m)
331+
=> m CType
332+
getDefiningType = do
333+
calling_fun_name <- fst . head <$> asks ctxDefiningFuncs
334+
maybe
335+
(throwError $ NotInScope calling_fun_name)
336+
pure
337+
=<< lookupNameInContext calling_fun_name
338+
339+
315340
------------------------------------------------------------------------------
316341
-- | Build a 'HyInfo' for an imported term.
317342
createImportedHyInfo :: OccName -> CType -> HyInfo CType

plugins/hls-tactics-plugin/src/Wingman/Tactics.hs

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ module Wingman.Tactics
88
import ConLike (ConLike(RealDataCon))
99
import Control.Applicative (Alternative(empty))
1010
import Control.Lens ((&), (%~), (<>~))
11+
import Control.Monad (filterM)
1112
import Control.Monad (unless)
1213
import Control.Monad.Except (throwError)
14+
import Control.Monad.Extra (anyM)
1315
import Control.Monad.Reader.Class (MonadReader (ask))
1416
import Control.Monad.State.Strict (StateT(..), runStateT, gets)
1517
import Data.Bool (bool)
@@ -475,21 +477,38 @@ nary n =
475477
mkInvForAllTys [alphaTyVar, betaTyVar] $
476478
mkFunTys' (replicate n alphaTy) betaTy
477479

480+
478481
self :: TacticsM ()
479482
self =
480483
fmap listToMaybe getCurrentDefinitions >>= \case
481484
Just (self, _) -> useNameFromContext apply self
482485
Nothing -> throwError $ TacticPanic "no defining function"
483486

484487

488+
------------------------------------------------------------------------------
489+
-- | Perform a catamorphism when destructing the given 'HyInfo'. This will
490+
-- result in let binding, making values that call the defining function on each
491+
-- destructed value.
485492
cata :: HyInfo CType -> TacticsM ()
486493
cata hi = do
494+
(_, _, calling_args, _)
495+
<- tacticsSplitFunTy . unCType <$> getDefiningType
496+
freshened_args <- traverse freshTyvars calling_args
487497
diff <- hyDiff $ destruct hi
498+
499+
-- For for every destructed term, check to see if it can unify with any of
500+
-- the arguments to the calling function. If it doesn't, we don't try to
501+
-- perform a cata on it.
502+
unifiable_diff <- flip filterM (unHypothesis diff) $ \hi ->
503+
flip anyM freshened_args $ \ty ->
504+
canUnify (hi_type hi) $ CType ty
505+
488506
rule $
489507
letForEach
490508
(mkVarOcc . flip mappend "_c" . occNameString)
491-
(\hi -> self >> commit (apply hi) assumption)
492-
diff
509+
(\hi -> self >> commit (assume $ hi_name hi) assumption)
510+
$ Hypothesis unifiable_diff
511+
493512

494513
collapse :: TacticsM ()
495514
collapse = do

plugins/hls-tactics-plugin/test/CodeAction/RunMetaprogramSpec.hs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,6 @@ spec = do
3333
metaTest 11 11 "MetaUseMethod"
3434
metaTest 9 38 "MetaCataCollapse"
3535
metaTest 7 16 "MetaCataCollapseUnary"
36+
metaTest 21 32 "MetaCataAST"
3637
metaTest 6 46 "MetaPointwise"
3738
metaTest 4 28 "MetaUseSymbol"
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
{-# LANGUAGE GADTs #-}
2+
3+
data AST a where
4+
BoolLit :: Bool -> AST Bool
5+
IntLit :: Int -> AST Int
6+
If :: AST Bool -> AST a -> AST a -> AST a
7+
Equal :: AST a -> AST a -> AST Bool
8+
9+
eval :: AST a -> a
10+
-- NOTE(sandy): There is an unrelated bug that is shown off in this test
11+
-- namely, that
12+
--
13+
-- @eval (IntLit n) = _@
14+
--
15+
-- but should be
16+
--
17+
-- @eval (IntLit n) = n@
18+
--
19+
-- https://github.com/haskell/haskell-language-server/issues/1937
20+
21+
eval (BoolLit b) = b
22+
eval (IntLit n) = _
23+
eval (If ast ast' ast_a)
24+
= let
25+
ast_c = eval ast
26+
ast'_c = eval ast'
27+
ast_a_c = eval ast_a
28+
in _ ast_c ast'_c ast_a_c
29+
eval (Equal ast ast')
30+
= let
31+
ast_c = eval ast
32+
ast'_c = eval ast'
33+
in _ ast_c ast'_c
34+
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
{-# LANGUAGE GADTs #-}
2+
3+
data AST a where
4+
BoolLit :: Bool -> AST Bool
5+
IntLit :: Int -> AST Int
6+
If :: AST Bool -> AST a -> AST a -> AST a
7+
Equal :: AST a -> AST a -> AST Bool
8+
9+
eval :: AST a -> a
10+
-- NOTE(sandy): There is an unrelated bug that is shown off in this test
11+
-- namely, that
12+
--
13+
-- @eval (IntLit n) = _@
14+
--
15+
-- but should be
16+
--
17+
-- @eval (IntLit n) = n@
18+
--
19+
-- https://github.com/haskell/haskell-language-server/issues/1937
20+
21+
eval = [wingman| intros x, cata x; collapse |]
22+

0 commit comments

Comments
 (0)