Opened 2 years ago

Last modified 2 years ago

#14287 new bug

Early inlining causes potential join points to be missed

Reported by: jheek Owned by:
Priority: normal Milestone:
Component: Compiler Version: 8.2.1
Keywords: JoinPoints Cc: nomeata
Operating System: Unknown/Multiple Architecture: Unknown/Multiple
Type of failure: None/Unknown Test Case:
Blocked By: Blocking:
Related Tickets: Differential Rev(s):
Wiki Page:

Description

While trying to make stream fusion work with recursive step functions I noticed that the following filter implementation did not fuse nicely.

{-# LANGUAGE BangPatterns #-}
import GHC.Exts (oneShot)

data Stream s a = Stream (s -> Step s a) s
data Step s a = Done | Yield a s

sfilter :: (a -> Bool) -> Stream s a -> Stream s a
sfilter pred (Stream step s0) = Stream filterStep s0 where
  filterStep s = case step s of
    Done -> Done
    Yield x ns
      | pred x    -> Yield x ns
      | otherwise -> filterStep ns

fromTo :: Int -> Int -> Stream Int Int
{-# INLINE fromTo #-}
fromTo from to = Stream step from where
  step i
    | i > to    = Done
    | otherwise = Yield i (i + 1)

sfoldl :: (b -> a -> b) -> b -> Stream s a -> b
{-# INLINE sfoldl #-}
sfoldl acc z (Stream !step s0) = oneShot go z s0 where
  go !y s = case step s of
    Done       -> y
    Yield x ns -> go (acc y x) ns

ssum :: (Num a) => Stream s a -> a
ssum = sfoldl (+) 0

filterTest :: Int
filterTest = ssum $ sfilter even (fromTo 1 101)

For this code to optimise well, GHC should detect that filterStep is a join point. However, within the definition of sfilter it is not one, because not all of its occurrences are saturated tail calls (it is also passed as an argument to the Stream constructor).

After sfilter is inlined and some trivial case-of-case transformations are applied, filterStep should become a join point. But the simplifier never seems to get the chance to do this, because the float-out pass first turns filterStep into a top-level binding. With -fno-full-laziness, filterStep does become a join point at the call site, but of course this is not really a solution.
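To make the target concrete, here is a hand-written sketch of roughly the loop one would hope GHC reaches for filterTest once filterStep is a join point inside the fold. It is not actual GHC output; the names filterTestFused, go and skip are hypothetical, and only the bounds 1 and 101 and the even predicate come from filterTest above.

```haskell
{-# LANGUAGE BangPatterns #-}
module Main (main) where

-- Hand-fused sketch of filterTest (not real GHC output).
-- 'go' is sfoldl's loop with fromTo's step inlined; 'skip' plays the
-- role of filterStep. Every call to 'skip' is a saturated tail call,
-- so it can be compiled as a join point (a jump), and no Step value
-- appears at all.
filterTestFused :: Int
filterTestFused = go 0 1
  where
    go :: Int -> Int -> Int
    go !acc i = skip i
      where
        skip :: Int -> Int
        skip j
          | j > 101   = acc                  -- fromTo is Done: fold returns acc
          | even j    = go (acc + j) (j + 1) -- filter keeps j: fold consumes it
          | otherwise = skip (j + 1)         -- filter drops j: jump back to skip

main :: IO ()
main = print filterTestFused
```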

Then I found that the following also works:

sfilter :: (a -> Bool) -> Stream s a -> Stream s a
sfilter pred (Stream step s0) = Stream filterStep s0 where
  {-# INLINE [2] filterStep #-}
  filterStep s = case step s of
    Done -> Done
    Yield x ns
      | pred x    -> Yield x ns
      | otherwise -> filterStep ns

Simply adding an INLINE [2] pragma disables inlining in the early (gentle) run of the simplifier. As a result, the float-out pass does not get the chance to float filterStep out before it is recognized as a join point. At least, that is my interpretation of what is going on.

What surprises me about this issue is that the gentle run seems to perform inlining while the wiki mentions that inlining is not performed in this stage: https://ghc.haskell.org/trac/ghc/wiki/Commentary/Compiler/Core2CorePipeline

Intuitively, I would expect floating out to be sub-optimal when the simplifier has not yet used all its tricks, because inlining typically opens up opportunities for simplification while floating out typically reduces them.

Change History (8)

comment:1 Changed 2 years ago by bgamari

> What surprises me about this issue is that the gentle run seems to perform inlining while the wiki mentions that inlining is not performed in this stage: https://ghc.haskell.org/trac/ghc/wiki/Commentary/Compiler/Core2CorePipeline

Indeed, as of commit 2effe18ab51d66474724d38b20e49cc1b8738f60 the wiki's statement is no longer true: the gentle run does perform inlining.

comment:2 Changed 2 years ago by jheek

I’m not sure whether early inlining is really the issue here, but preventing it does seem to avoid the problem. I also found arity analysis to be important in this case. Would it make sense to (optionally) run the float-out pass only after a full simplify, arity analysis, simplify cycle?

comment:3 Changed 2 years ago by mpickering

Another way to make the program faster is to write

sfilter3 :: (a -> Bool) -> Stream s a -> Stream s a
sfilter3 pred (Stream step s0) = Stream filterStep s0 where
  filterStep s =
    let go s =
          case step s of
            Done -> Done
            Yield x ns
              | pred x    -> Yield x ns
              | otherwise -> go ns
    in go s

or to perform the transformation described by Simon in #13966.

Both lead to precisely the same Core as the INLINE [2] version.

comment:4 Changed 2 years ago by jheek

Actually, sfilter3 does not produce optimal code for me. go is correctly recognized as a join point, but it gets floated out to become a top-level binding before the recursive case-of-case transformation kicks in.

comment:5 Changed 2 years ago by mpickering

How are you compiling, and with which version of the compiler? All three filterTest variants were CSEd into the same definition when I tried with my modified compiler.

Here is the test file and results.

https://gist.github.com/7c7cb362206f60bc85a76ceb30d786c3

Observe that in the loopification.simpl dump file all three are CSEd, whereas in no-loopification.simpl only 1 and 3 are.

comment:6 Changed 2 years ago by simonpj

A major goal in GHC is to avoid sensitivity to ordering of transformations. Otherwise things become too finely balanced.

Here are three programs

-- Recursive function
-- go is not a join point
f1 x = letrec go s = case s of
                        Done -> Done
                        Yield x s' | pred x    -> Yield x s'
                                   | otherwise -> go s'
       in case go s2 of ...

-- Same but float inwards
-- Now go becomes a join point
f2 x = case letrec go s = case s of
                        Done -> Done
                        Yield x s' | pred x    -> Yield x s'
                                   | otherwise -> go s'
             in go s2 of ...

-- Same but float outwards
-- Now go becomes top-level
go pred s = case s of
               Done -> Done
               Yield x s' | pred x    -> Yield x s'
                          | otherwise -> go pred s'
f3 x = case go pred s2 of ...

Points to note:

  • Float-in can create join points; see the transition from f1 to f2
  • Even though go is a join point in f2, we need a run of the Simplifier to mark it as such. Once it is marked as a join point, it'll stay that way.
  • Float-out is currently pretty aggressive about floating things to top level, and so will tend to generate f3. By itself that is not too bad. But now the case in f3 can't fuse with the loop.
  • Don't forget that the user might write the program in the f3 form in the first place. Ideally we want all forms to optimise the same way.
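The three placements can be made concrete as ordinary Haskell. This is a sketch under stated assumptions: the pseudocode above elides the step function and the case continuation, so step (counting up to 10), keep (standing in for pred, to avoid clashing with Prelude's pred), go3, and the Maybe-returning continuation are all hypothetical. All three forms compute the same result; the differences described above show up only in the optimised Core (for example with -ddump-simpl).

```haskell
module Main (main) where

data Step s a = Done | Yield a s

-- Hypothetical stand-ins for the parts the pseudocode leaves out.
step :: Int -> Step Int Int
step i
  | i > 10    = Done
  | otherwise = Yield i (i + 1)

keep :: Int -> Bool
keep = even

-- f1: go is let-bound outside the case, and 'go s2' is the scrutinee
-- (not a tail call), so go is not a join point.
f1 :: Int -> Maybe Int
f1 s2 =
  let go s = case step s of
        Done -> Done
        Yield x s' | keep x    -> Yield x s'
                   | otherwise -> go s'
  in case go s2 of
       Done      -> Nothing
       Yield x _ -> Just x

-- f2: the binding is floated into the scrutinee; every call to go is
-- now a saturated tail call, so go is a join point candidate.
f2 :: Int -> Maybe Int
f2 s2 =
  case (let go s = case step s of
              Done -> Done
              Yield x s' | keep x    -> Yield x s'
                         | otherwise -> go s'
        in go s2) of
    Done      -> Nothing
    Yield x _ -> Just x

-- f3: go floated to top level and abstracted over the predicate; the
-- case in f3 now scrutinises the result of an ordinary function call.
go3 :: (Int -> Bool) -> Int -> Step Int Int
go3 p s = case step s of
  Done -> Done
  Yield x s' | p x       -> Yield x s'
             | otherwise -> go3 p s'

f3 :: Int -> Maybe Int
f3 s2 = case go3 keep s2 of
  Done      -> Nothing
  Yield x _ -> Just x

main :: IO ()
main = print [ (f1 s, f2 s, f3 s) | s <- [0 .. 12 :: Int] ]
```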

I think that the Right Solution to these fragilities is the loopification plan in #14068. Then, even in the top-level form we'd get

go pred s = letrec go' s = case s of
                             Done -> Done
                             Yield x s' | pred x    -> Yield x s'
                                        | otherwise -> go' s'
            in go' s
f3 x = case go pred s2 of ...

Now (a) go' is a join point, and (b) go is non-recursive and will inline.
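A hand-written sketch of the before and after shapes of loopification, assuming the step function is passed explicitly (the pseudocode above scrutinises s directly and elides the step). The names skipRec, skipLoop and upTo10 are hypothetical, and this only illustrates the shape of the result; it is not how #14068 would implement the transformation inside GHC.

```haskell
module Main (main) where

data Step s a = Done | Yield a s
  deriving (Eq, Show)

-- Before: self-recursive top-level function (like f3's go). GHC does
-- not inline self-recursive functions, so a caller's case cannot be
-- pushed into this loop.
skipRec :: (s -> Step s a) -> (a -> Bool) -> s -> Step s a
skipRec step p s = case step s of
  Done -> Done
  Yield x s' | p x       -> Yield x s'
             | otherwise -> skipRec step p s'

-- After (loopified by hand): a non-recursive wrapper around a local
-- loop. Every call to go' is a saturated tail call, so go' can be a
-- join point, and the wrapper itself is free to inline at call sites.
skipLoop :: (s -> Step s a) -> (a -> Bool) -> s -> Step s a
skipLoop step p s0 = go' s0
  where
    go' s = case step s of
      Done -> Done
      Yield x s' | p x       -> Yield x s'
                 | otherwise -> go' s'
{-# INLINE skipLoop #-}

-- Hypothetical step function: yield the state, counting up to 10.
upTo10 :: Int -> Step Int Int
upTo10 i
  | i > 10    = Done
  | otherwise = Yield i (i + 1)

main :: IO ()
main = do
  print (skipRec upTo10 even 3)   -- Yield 4 5
  print (skipLoop upTo10 even 3)  -- Yield 4 5
```

Once skipLoop is inlined into a caller such as `case skipLoop step p s2 of ...`, the result has the f2 shape, where case-of-case can push the outer case into the join point go'.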

comment:7 Changed 2 years ago by simonpj

Keywords: JoinPoints added

comment:8 Changed 2 years ago by nomeata

Cc: nomeata added