{-# LANGUAGE BangPatterns #-}
{-# OPTIONS_GHC -Wall     #-}
module Crypto.Random.DRBG.HMAC
        ( State, counter
        , reseedInterval
        , instantiate
        , reseed
        , generate) where

import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import Data.Tagged (proxy)
import Data.Word (Word64)
import Crypto.Classes
import Crypto.HMAC
import Crypto.Random.DRBG.Types

-- | Strict bytes stored in the 'key' field of 'State'; 'update' and
-- 'generate' wrap it in 'MacKey' to use it as the HMAC key.
type Key = B.ByteString
-- | Strict bytes stored in the 'value' field of 'State'; this is the data
-- that gets HMAC'd under the current key.
type Value = B.ByteString

-- | Generator state.  The type parameter @d@ is phantom (no field mentions
-- it); the functions in this module use it to select the digest type that
-- HMAC is computed with.
data State d = St
        { counter :: {-# UNPACK #-} !Word64
          -- ^ Set to 1 by 'instantiate' and 'reseed', and bumped by one on
          -- every successful 'generate'.  'generate' returns 'Nothing' once
          -- this exceeds 'reseedInterval'.

          -- Start admin info
        , value   :: !Value
          -- ^ Current chaining value; replaced by 'update' and 'generate'.
        , key     :: !Key
          -- ^ Current HMAC key; replaced by 'update'.
        }

-- | Return the first argument unchanged.  The second argument is never
-- inspected; it only forces the first argument's type to equal the type
-- parameter of the second.
--
-- Defined locally because (per the original author's note) the tagged
-- package only offers this with the right type from version 0.7 on, and
-- only up to GHC 7.8.
asProxyTypeOf :: d -> state d -> d
asProxyTypeOf x _ = x

-- | Largest 'counter' value for which 'generate' still produces output
-- (2^48); beyond it 'generate' returns 'Nothing'.
reseedInterval :: Word64
reseedInterval = 2 ^ (48 :: Int)

-- | Wrap a strict 'B.ByteString' as a lazy one made of that single chunk
-- ("from chunk").
fc :: B.ByteString -> L.ByteString
fc chunk = L.fromChunks [chunk]

-- | Mix @input@ into the state, producing a new 'key' and 'value'; the
-- 'counter' is left untouched.
--
-- Steps (this looks like the HMAC_DRBG update function of NIST SP 800-90A;
-- confirm against the spec before relying on that):
--
--   1. @k0 = HMAC(key, value || 0x00 || input)@
--   2. @v0 = HMAC(k0, value)@
--   3. If @input@ is empty the result is @(k0, v0)@.
--   4. Otherwise @k1 = HMAC(k0, v0 || 0x01 || input)@ and
--      @v1 = HMAC(k1, v0)@, and the result is @(k1, v1)@.
update :: (Hash c d) => State d -> L.ByteString -> State d
update st input = st { value = newV, key = newK }
  where
  -- HMAC of a lazy message under the given key, serialised to strict bytes.
  -- 'asProxyTypeOf' pins the digest type to the state's type parameter.
  mac macKey msg = encode (hmac (MacKey macKey) msg `asProxyTypeOf` st)

  k0 = mac (key st) (L.concat [fc (value st), L.singleton 0, input])
  v0 = mac k0 (fc (value st))

  -- 'L.null' rather than 'L.length input == 0': the emptiness test should
  -- not have to walk every chunk of the input.
  (newK, newV)
    | L.null input = (k0, v0)
    | otherwise    =
        let k1 = mac k0 (L.concat [fc v0, L.singleton 1, input])
        in (k1, mac k1 (fc v0))

-- | Build a fresh generator state from entropy, a nonce and a
-- personalization string.
--
-- The starting key is all @0x00@ bytes and the starting value all @0x01@
-- bytes, each @outputLength \`div\` 8@ bytes long (so 'outputLength' is
-- presumably a bit count -- confirm against "Crypto.Classes").  The three
-- inputs are concatenated in the order given and mixed in with 'update'.
-- The resulting state has 'counter' 1.
instantiate :: (Hash c d) => Entropy -> Nonce -> PersonalizationString -> State d
instantiate ent nonce perStr = seeded
  where
  -- 'seeded' appears here only to select the digest type; 'proxy' never
  -- evaluates it, so the self-reference is harmless.
  digestBytes = (outputLength `proxy` seeded) `div` 8
  zeroKey     = B.replicate digestBytes 0
  onesValue   = B.replicate digestBytes 1
  seeded      = update (St 1 onesValue zeroKey) (L.fromChunks [ent, nonce, perStr])

-- | Mix fresh entropy followed by the additional input into the state via
-- 'update', then reset 'counter' to 1.
reseed :: (Hash c d) => State d -> Entropy -> AdditionalInput -> State d
reseed st ent ai = mixed { counter = 1 }
  where
  mixed = update st (L.fromChunks [ent, ai])

-- | Produce pseudo-random output and the successor state.
--
-- Returns 'Nothing' when the state's 'counter' already exceeds
-- 'reseedInterval' (the caller has to 'reseed' first).  Otherwise returns
-- @Just (bytes, st')@ where @bytes@ holds @(req + 7) \`div\` 8@ bytes (@req@
-- is a bit count rounded up to whole bytes) and @st'@ has its 'counter'
-- increased by one.
--
-- Non-empty @additionalInput@ is mixed in with 'update' before any output
-- is produced; 'update' is run once more (with the same input, empty or
-- not) after the output has been produced.
--
-- A zero or negative @req@ yields empty output; the state is still advanced.
generate :: (Hash c d) => State d -> BitLength -> AdditionalInput -> Maybe (RandomBits, State d)
generate st req additionalInput
  | counter st > reseedInterval = Nothing
  | otherwise = Just (randBitsFinal, stFinal { counter = 1 + counter st })
  where
  -- State after the optional pre-output mixing step.
  st'
    | B.null additionalInput = st
    | otherwise              = update st (fc additionalInput)

  reqBytes   = (req + 7) `div` 8
  -- Number of digest-sized blocks needed to cover reqBytes (rounded up).
  iterations = (reqBytes + (outlen - 1)) `div` outlen

  -- Original author's note: getV is the main cost.  HMACing and storing
  -- 'iterations' bytestrings of ~64 bytes each is wasteful.  Pre-allocation
  -- and unsafe functions exported from Crypto.HMAC could cut this down, but
  -- that was judged not worth it given that ByteString is a poor fit for
  -- crypto computations in the first place.
  --
  -- Repeatedly replaces the value with its HMAC under kFinal, returning the
  -- last value together with every intermediate value in order.  The base
  -- case is @i <= 0@ (not just @i == 0@) so that a negative iteration count,
  -- which a negative request produces, terminates instead of recursing
  -- forever.
  getV :: Value -> Int -> (Value, [B.ByteString])
  getV !u i
    | i <= 0    = (u, [])
    | otherwise =
        let !vNew = hmac' (MacKey kFinal) u `asProxyTypeOf` st
            !encV = encode vNew
            (uFinal, rest) = getV encV (i - 1)
        in (uFinal, encV : rest)

  (vFinal, randBitsList) = getV (value st') iterations
  randBitsFinal = B.take reqBytes (B.concat randBitsList)
  kFinal = key st'
  -- The record update could otherwise change the phantom type parameter, so
  -- 'asTypeOf' ties the result back to the type of the incoming state.
  stFinal = update (st' { key = kFinal, value = vFinal } `asTypeOf` st) (fc additionalInput)
  outlen = (outputLength `proxy` st) `div` 8