{-# LANGUAGE FlexibleContexts #-}
module Data.Conduit.OpenPGP.Decrypt
( conduitDecrypt
) where
import Control.Monad (when)
import Control.Monad.Fail (MonadFail)
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.IO.Unlift (MonadUnliftIO)
import Control.Monad.Trans.Resource (MonadResource, MonadThrow)
import qualified Crypto.Hash as CH
import qualified Crypto.Hash.Algorithms as CHA
import Data.Binary (get)
import qualified Data.ByteArray as BA
import qualified Data.ByteString as B
import qualified Data.ByteString.Base16.Lazy as B16L
import qualified Data.ByteString.Lazy as BL
import Data.Conduit
import qualified Data.Conduit.Binary as CB
import qualified Data.Conduit.Combinators as CC
import qualified Data.Conduit.List as CL
import Data.Conduit.OpenPGP.Compression (conduitDecompress)
import Data.Conduit.Serialization.Binary (conduitGet)
import Data.Maybe (fromJust, isNothing)
import Codec.Encryption.OpenPGP.CFB (decryptOpenPGPCfb, decryptPreservingNonce)
import Codec.Encryption.OpenPGP.S2K (skesk2Key)
import Codec.Encryption.OpenPGP.Types
-- | State threaded through the packet stream while decrypting.
--
-- A fresh copy (with an incremented '_depth') is handed to the nested
-- conduit that processes the contents of each decrypted packet.
data RecursorState = RecursorState
  { _depth :: Int
    -- ^ Nesting level of the stream being processed; bumped by one for
    -- every decrypted payload, and processing fails once it exceeds 42.
  , _lastPKESK :: Maybe PKESK
    -- ^ Initialised to 'Nothing'; not read or updated by the code in
    -- this module.
  , _lastSKESK :: Maybe SKESK
    -- ^ Most recent symmetric-key encrypted session key packet seen in
    -- the current stream, if any.
  , _lastNonce :: Maybe B.ByteString
    -- ^ Nonce returned by 'decryptPreservingNonce' for the enclosing
    -- integrity-protected packet, if any.
  , _lastClearText :: Maybe B.ByteString
    -- ^ Decrypted contents of the enclosing integrity-protected packet,
    -- kept so a trailing MDC packet can be checked against it.
  } deriving (Eq, Show)
-- | Initial state for a top-level stream: depth zero and nothing
-- remembered from earlier packets.
def :: RecursorState
def =
  RecursorState
    { _depth = 0
    , _lastPKESK = Nothing
    , _lastSKESK = Nothing
    , _lastNonce = Nothing
    , _lastClearText = Nothing
    }
-- | Callback used to obtain a passphrase: it is given a prompt string and
-- its result is fed to 'skesk2Key' as the passphrase.
type InputCallback m = String -> m BL.ByteString
-- | Conduit that consumes OpenPGP packets and yields the packets found
-- inside any symmetrically encrypted data packets, passing every other
-- packet through unchanged (session-key packets are swallowed).
--
-- The callback is invoked to obtain a passphrase each time an encrypted
-- data packet is encountered. Processing starts from the empty state
-- 'def'.
conduitDecrypt ::
     (MonadFail m, MonadUnliftIO m, MonadResource m, MonadThrow m)
  => InputCallback IO
  -> ConduitT Pkt Pkt m ()
conduitDecrypt passphraseCallback = conduitDecrypt' def passphraseCallback
-- | Worker behind 'conduitDecrypt': processes a packet stream starting
-- from the given 'RecursorState'.
--
-- Per incoming packet:
--
-- * 'SKESKPkt' is remembered in '_lastSKESK' and not emitted.
-- * 'SymEncDataPkt' / 'SymEncIntegrityProtectedDataPkt' are decrypted
--   with the remembered SKESK and replaced by the packets they contain.
-- * 'ModificationDetectionCodePkt' is checked against the digest
--   computed by 'calculateMDC' and then emitted.
-- * Anything else is emitted unchanged.
--
-- Fails (via 'MonadFail') when the nesting depth exceeds 42, when an
-- encrypted data packet arrives before any SKESK packet, when an MDC
-- packet has no cleartext to refer to, or when the MDC does not match.
conduitDecrypt' ::
     (MonadFail m, MonadUnliftIO m, MonadResource m, MonadThrow m)
  => RecursorState
  -> InputCallback IO
  -> ConduitT Pkt Pkt m ()
conduitDecrypt' rs cb = CC.concatMapAccumM push rs
  where
    push ::
         (MonadFail m, MonadUnliftIO m, MonadResource m, MonadThrow m)
      => Pkt
      -> RecursorState
      -> m (RecursorState, [Pkt])
    push i s
      -- Bound the recursion so a message that decrypts to itself cannot
      -- make us loop forever.
      | _depth s > 42 = fail "I think we've been quine-attacked"
      | otherwise =
        case i of
          SKESKPkt {} -> return (s {_lastSKESK = Just (fromPkt i)}, [])
          SymEncDataPkt bs -> do
            skesk <- requireSKESK s
            d <- decryptSEDP s cb skesk bs
            return (s, d)
          SymEncIntegrityProtectedDataPkt _ bs -> do
            skesk <- requireSKESK s
            d <- decryptSEIPDP s cb skesk bs
            return (s, d)
          m@(ModificationDetectionCodePkt mdc) -> do
            when (isNothing (_lastClearText s)) $ fail "MDC with no referent"
            let mcalculated =
                  calculateMDC <$> _lastNonce s <*> _lastClearText s
            -- Only the two digests are reported: the nonce and the
            -- unauthenticated cleartext are deliberately kept out of
            -- the error message.
            when (mcalculated /= Just mdc) $
              fail $
              "MDC indicates tampering: " ++
              show (B16L.encode mdc) ++
              " versus " ++ maybe "<empty>" (show . B16L.encode) mcalculated
            return (s, [m])
          p -> return (s, [p])
    -- Fetch the remembered SKESK, failing cleanly (and before any
    -- passphrase prompt) when no SKESK packet preceded the encrypted data.
    requireSKESK :: MonadFail m => RecursorState -> m SKESK
    requireSKESK =
      maybe
        (fail "Encrypted data packet without a preceding SKESK packet")
        return .
      _lastSKESK
-- | Decrypt the payload of a symmetrically encrypted data packet and
-- return the packets it contains.
--
-- The passphrase is obtained from the callback, turned into a key with
-- 'skesk2Key', and used with the SKESK's symmetric algorithm. The
-- decrypted bytes are parsed into packets, decompressed, and run through
-- 'conduitDecrypt'' again with the depth increased by one.
--
-- A decryption error is reported through 'fail' in @m@ rather than as a
-- lazily thrown 'error', so it surfaces at a predictable point and in the
-- same way as the other failures in this module.
decryptSEDP ::
     (MonadFail m, MonadUnliftIO m, MonadIO m, MonadThrow m)
  => RecursorState
  -> InputCallback IO
  -> SKESK
  -> BL.ByteString
  -> m [Pkt]
decryptSEDP rs cb skesk bs = do
  passphrase <- liftIO $ cb "Input the passphrase I want"
  let key = skesk2Key skesk passphrase
  decrypted <-
    either fail return $
    decryptOpenPGPCfb (_skeskSymmetricAlgorithm skesk) (BL.toStrict bs) key
  runConduitRes $
    CB.sourceLbs (BL.fromStrict decrypted) .| conduitGet get .|
    conduitDecompress .|
    conduitDecrypt' rs {_depth = _depth rs + 1} cb .|
    CL.consume
-- | Decrypt the payload of a symmetrically encrypted, integrity-protected
-- data packet and return the packets it contains.
--
-- Works like 'decryptSEDP', but uses 'decryptPreservingNonce' and records
-- the returned nonce and cleartext in the state passed to the nested
-- 'conduitDecrypt'', so that an MDC packet inside the payload can be
-- verified against them.
--
-- A decryption error is reported through 'fail' in @m@ rather than as a
-- lazily thrown 'error', so it surfaces at a predictable point and in the
-- same way as the other failures in this module.
decryptSEIPDP ::
     (MonadFail m, MonadUnliftIO m, MonadIO m, MonadThrow m)
  => RecursorState
  -> InputCallback IO
  -> SKESK
  -> BL.ByteString
  -> m [Pkt]
decryptSEIPDP rs cb skesk bs = do
  passphrase <- liftIO $ cb "Input the passphrase I want"
  let key = skesk2Key skesk passphrase
  (nonce, decrypted) <-
    either fail return $
    decryptPreservingNonce
      (_skeskSymmetricAlgorithm skesk)
      (BL.toStrict bs)
      key
  runConduitRes $
    CB.sourceLbs (BL.fromStrict decrypted) .| conduitGet get .|
    conduitDecompress .|
    conduitDecrypt'
      rs
        { _depth = _depth rs + 1
        , _lastNonce = Just nonce
        , _lastClearText = Just decrypted
        }
      cb .|
    CL.consume
-- | Compute the digest an MDC packet is expected to carry.
--
-- The result is the SHA-1 hash of the nonce, followed by the second
-- argument with its final 22 bytes removed, followed by the two bytes
-- @211, 20@ (0xD3 0x14; presumably the MDC packet's own tag and length
-- octets, cf. RFC 4880 section 5.14 -- confirm against the spec).
--
-- When the second argument is shorter than 23 bytes the result is empty.
calculateMDC :: B.ByteString -> B.ByteString -> BL.ByteString
calculateMDC nonce garbage
  | B.length garbage < 23 = mempty
  | otherwise = BL.fromStrict (BA.convert digest)
  where
    digest :: CH.Digest CHA.SHA1
    digest = CH.hash hashed
    hashed :: B.ByteString
    hashed = nonce <> body <> trailer
    -- Everything except the trailing 22 bytes of the decrypted data.
    body = B.take (B.length garbage - 22) garbage
    trailer = B.pack [211, 20]