{-# LANGUAGE BangPatterns       #-}
{-# LANGUAGE DeriveAnyClass     #-}
{-# LANGUAGE NamedFieldPuns     #-}
{-# LANGUAGE OverloadedStrings  #-}
{-# LANGUAGE RecordWildCards    #-}
{-# LANGUAGE StrictData         #-}
{-# LANGUAGE TypeApplications   #-}
{-# LANGUAGE StandaloneDeriving #-}

-- | "Database.Redis" like interface with connection through Redis Sentinel.
--
-- More details here: <https://redis.io/topics/sentinel>.
--
-- Example:
--
-- @
-- conn <- 'connect' 'SentinelConnectInfo' (("localhost", PortNumber 26379) :| []) "mymaster" 'defaultConnectInfo'
--
-- 'runRedis' conn $ do
--   'set' "hello" "world"
-- @
--
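-- When the connection is opened, the Sentinels are queried for the current master. Subsequent 'runRedis'
-- calls will talk to that master.
--
-- If a 'runRedis' call fails with a connection error, or the server replies with a @READONLY@ error,
-- the next call queries the Sentinels again and switches to the new master if it has changed.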
--
-- This implementation is based on Gist by Emanuel Borsboom
-- at <https://gist.github.com/borsboom/681d37d273d5c4168723>
module Database.Redis.Sentinel
  (
    -- * Connection
    SentinelConnectInfo(..)
  , SentinelConnection
  , connect
    -- * runRedis with Sentinel support
  , runRedis
  , RedisSentinelException(..)

    -- * Re-export Database.Redis
  , module Database.Redis
  ) where

import           Control.Concurrent
import           Control.Exception     (Exception, IOException, bracket,
                                        evaluate, throwIO)
import           Control.Monad
import           Control.Monad.Catch   (Handler (..), MonadCatch, catches, throwM)
import           Control.Monad.Except
import           Control.Monad.IO.Class (MonadIO (liftIO))
import           Data.ByteString       (ByteString)
import qualified Data.ByteString       as BS
import qualified Data.ByteString.Char8 as BS8
import           Data.Foldable         (toList)
import           Data.List             (delete)
import           Data.List.NonEmpty    (NonEmpty (..))
import           Data.Typeable         (Typeable)
import           Data.Unique
import           Network.Socket        (HostName)

import           Database.Redis hiding (Connection, connect, runRedis)
import qualified Database.Redis as Redis

-- | Interact with a Redis datastore.  See 'Database.Redis.runRedis' for details.
--
-- If an earlier call flagged a possible failover, the Sentinels are asked for
-- the current master before @action@ runs. A new base connection is opened
-- only when the reported host or port differs from the one in use. This
-- refresh happens while the connection's 'MVar' is held, so concurrent calls
-- wait for it.
--
-- The next call is flagged to re-check the Sentinels when @action@ throws an
-- 'IOException' or 'ConnectionLostException' (the exception is rethrown), or
-- when the reply is an @Error@ whose message starts with @READONLY @.
--
-- NOTE(review): when a different master is selected, the previous base
-- connection is dropped without being disconnected here.
runRedis :: SentinelConnection
         -> Redis (Either Reply a)
         -> IO (Either Reply a)
runRedis (SentinelConnection connMVar) action = do
  (baseConn, startToken) <- modifyMVar connMVar refreshIfFlagged

  -- 'evaluate' forces the result so that exceptions from 'Redis.runRedis'
  -- are raised inside the handler below rather than later at a use site.
  reply <- (Redis.runRedis baseConn action >>= evaluate)
    `catchRedisRethrow` const (requestFailoverCheck startToken)

  -- A READONLY error means the node we talk to has turned into a replica.
  when (isReadOnlyError reply) $ requestFailoverCheck startToken
  return reply

  where
    -- Returns the (possibly refreshed) state, plus the base connection and
    -- token this call should use.
    refreshIfFlagged
      :: SentinelConnection'
      -> IO (SentinelConnection', (Redis.Connection, Unique))
    refreshIfFlagged st
      | not (rcCheckFailover st) =
          return (st, (rcBaseConnection st, rcToken st))
      | otherwise = do
          (sentinelInfo, reportedMaster) <- updateMaster (rcSentinelConnectInfo st)
          token <- newUnique
          (masterInfo, conn) <-
            if sameHost reportedMaster (rcMasterConnectInfo st)
              then return (rcMasterConnectInfo st, rcBaseConnection st)
              else do
                fresh <- Redis.connect reportedMaster
                return (reportedMaster, fresh)
          let refreshed = SentinelConnection'
                { rcCheckFailover       = False
                , rcToken               = token
                , rcSentinelConnectInfo = sentinelInfo
                , rcMasterConnectInfo   = masterInfo
                , rcBaseConnection      = conn
                }
          return (refreshed, (conn, token))

    -- Two connect infos point at the same server iff host and port match.
    sameHost :: Redis.ConnectInfo -> Redis.ConnectInfo -> Bool
    sameHost l r = (connectHost l, connectPort l) == (connectHost r, connectPort r)

    isReadOnlyError :: Either Reply b -> Bool
    isReadOnlyError (Left (Error msg)) = "READONLY " `BS.isPrefixOf` msg
    isReadOnlyError _                  = False

    -- Raise the failover flag, but only if the token this call started with
    -- is still current. Several calls failing on the same generation
    -- therefore raise it only once.
    requestFailoverCheck :: Unique -> IO ()
    requestFailoverCheck seen = modifyMVar_ connMVar $ \st ->
      if rcToken st /= seen
        then return st
        else do
          token <- newUnique
          return st { rcToken = token, rcCheckFailover = True }


-- | Open a connection through Redis Sentinel.
--
-- The Sentinels are queried for the current master of 'connectMasterName'.
-- 'NoSentinels' is thrown if none of them reports an address. A base
-- connection to the reported master is then opened and wrapped in a fresh,
-- un-flagged connection state.
connect :: SentinelConnectInfo -> IO SentinelConnection
connect origConnectInfo = do
  (sentinelInfo, masterInfo) <- updateMaster origConnectInfo
  baseConn <- Redis.connect masterInfo
  token <- newUnique
  let initial = SentinelConnection'
        { rcCheckFailover       = False
        , rcToken               = token
        , rcSentinelConnectInfo = sentinelInfo
        , rcMasterConnectInfo   = masterInfo
        , rcBaseConnection      = baseConn
        }
  fmap SentinelConnection (newMVar initial)

-- | Ask the configured Sentinels, in list order, for the address of the
-- master named 'connectMasterName'.
--
-- On success it returns two things:
--
-- * an updated 'SentinelConnectInfo' in which the Sentinel that answered has
--   been moved to the front of 'connectSentinels', so it is asked first next
--   time;
-- * the master's 'Redis.ConnectInfo', which is 'connectBaseInfo' with
--   'connectHost' and 'connectPort' replaced by the reported address.
--
-- A Sentinel is skipped when contacting it throws an 'IOException' or a
-- 'ConnectionLostException', or when its reply is not a two-element
-- @[host, port]@ list. Throws 'NoSentinels' if every Sentinel is skipped.
updateMaster :: SentinelConnectInfo
             -> IO (SentinelConnectInfo, Redis.ConnectInfo)
updateMaster sci@SentinelConnectInfo{..} = do
    -- This is using the Either monad "backwards": Left means stop because we
    -- found the master, Right means try the next Sentinel.
    resultEither <- runExceptT $ forM_ connectSentinels $ \(host, port) ->
      trySentinel host port `catchRedis` (\_ -> return ())

    case resultEither of
        Left (conn, sentinelPair) -> return
          ( sci
            { connectSentinels =
                sentinelPair :| delete sentinelPair (toList connectSentinels)
            }
          , conn
          )
        Right () -> throwIO $ NoSentinels connectSentinels
  where
    -- Query a single Sentinel. On a usable reply, short-circuit the
    -- surrounding ExceptT with the master's connect info and the Sentinel
    -- that reported it.
    trySentinel :: HostName -> PortID -> ExceptT (Redis.ConnectInfo, (HostName, PortID)) IO ()
    trySentinel sentinelHost sentinelPort = do
      -- Bang to ensure exceptions from runRedis get thrown immediately, so
      -- the caller's 'catchRedis' sees them.
      !replyE <- liftIO $
        -- 'bracket' disconnects the short-lived Sentinel connection whether
        -- the query succeeds or throws. Previously it was never released.
        bracket
          (Redis.connect $ Redis.defaultConnectInfo
            { connectHost           = sentinelHost
            , connectPort           = sentinelPort
            , connectMaxConnections = 1
            })
          Redis.disconnect
          (\sentinelConn -> Redis.runRedis sentinelConn $ sendRequest
            ["SENTINEL", "get-master-addr-by-name", connectMasterName])

      case replyE of
        Right [host, port] ->
          throwError
            ( connectBaseInfo
              { connectHost = BS8.unpack host
              , connectPort =
                  maybe
                    defaultRedisPort
                    (PortNumber . fromIntegral . fst)
                    $ BS8.readInt port
              }
            , (sentinelHost, sentinelPort)
            )
        _ -> return ()

    -- Fallback used when the reported master port cannot be parsed. It is the
    -- default Redis *server* port. The previous fallback, 26379, is the
    -- default Sentinel port and is never right for a master connection.
    defaultRedisPort :: PortID
    defaultRedisPort = PortNumber 6379

-- | Run @action@. If it throws an 'IOException' or a
-- 'ConnectionLostException', pass the exception's 'show' text to @handler@,
-- then rethrow the original exception unchanged.
catchRedisRethrow :: MonadCatch m => m a -> (String -> m ()) -> m a
catchRedisRethrow action handler = action `catches` [onIOError, onConnectionLost]
  where
    onIOError = Handler $ \e -> do
      handler (show (e :: IOException))
      throwM e
    onConnectionLost = Handler $ \e -> do
      handler (show (e :: ConnectionLostException))
      throwM e

-- | Run @action@, recovering from an 'IOException' or a
-- 'ConnectionLostException'. The exception's 'show' text is passed to
-- @handler@, whose result replaces the failed action's result.
catchRedis :: MonadCatch m => m a -> (String -> m a) -> m a
catchRedis action handler =
  catches action
    [ Handler (handler . describeIO)
    , Handler (handler . describeLost)
    ]
  where
    describeIO :: IOException -> String
    describeIO = show

    describeLost :: ConnectionLostException -> String
    describeLost = show

-- | A Redis connection that finds its master through Redis Sentinel.
-- Create one with 'connect' and use it with 'runRedis'. Its state lives in an
-- 'MVar', which 'runRedis' reads and updates with 'modifyMVar'.
newtype SentinelConnection = SentinelConnection (MVar SentinelConnection')

-- | State of a 'SentinelConnection', kept inside its 'MVar'.
data SentinelConnection' = SentinelConnection'
  { rcCheckFailover       :: Bool
    -- ^ When set, the next 'runRedis' call asks the Sentinels for the current
    -- master before running its action. It is set after a call fails with a
    -- connection error or receives a @READONLY@ reply.
  , rcToken               :: Unique
    -- ^ Generation marker, replaced whenever the failover flag is raised or
    -- the master is refreshed. A failing call raises the flag only if the
    -- token it started with is still current.
  , rcSentinelConnectInfo :: SentinelConnectInfo
    -- ^ Sentinel configuration, with the Sentinel that last answered first.
  , rcMasterConnectInfo   :: Redis.ConnectInfo
    -- ^ Connection parameters of the master currently in use.
  , rcBaseConnection      :: Redis.Connection
    -- ^ Connection to that master.
  }

-- | Configuration of Sentinel hosts.
data SentinelConnectInfo = SentinelConnectInfo
  { connectSentinels  :: NonEmpty (HostName, PortID)
    -- ^ Sentinels to query, tried in order. After a successful lookup, the
    -- copy kept by the connection has the Sentinel that answered moved to the
    -- front.
  , connectMasterName :: ByteString
    -- ^ Name of the master to connect to, as passed to
    -- @SENTINEL get-master-addr-by-name@.
  , connectBaseInfo   :: Redis.ConnectInfo
    -- ^ Used to configure auth and the other parameters of the Redis
    -- connection. Its 'Redis.connectHost' and 'Redis.connectPort' are ignored
    -- and replaced by the master address reported by a Sentinel.
  }
  deriving Show

-- | Exception thrown by "Database.Redis.Sentinel".
data RedisSentinelException
  = NoSentinels (NonEmpty (HostName, PortID))
    -- ^ Thrown when none of the listed Sentinels (carried in the constructor)
    -- could be reached or reported an address for the master.
  deriving (Show, Typeable)

instance Exception RedisSentinelException