-- RSA LAB

import Data.Char

-- | A message represented as a list of integers. In this file the
-- elements are either character codes (see 'stringToMessage') or
-- packed block values (see 'makeBlocks').
type Message = [Int]

-- | Convert a string into the list of its characters' code points.
stringToMessage :: String -> Message
stringToMessage str = [ord c | c <- str]

-- | Convert a list of code points back into the corresponding string.
messageToString :: Message -> String
messageToString codes = [chr code | code <- codes]

-- | Append padding so that, for a positive @bsize@, the message length
-- becomes a multiple of @bsize@. Every padding element holds the number
-- of elements appended, and a message whose length is already a
-- multiple of @bsize@ receives one whole extra block of padding.
pad :: Int -> Message -> Message
pad bsize msg = msg ++ replicate padding padding
  where
    -- Number of elements needed to reach the next multiple of bsize.
    padding = bsize - length msg `mod` bsize

-- | Strip the padding added by 'pad': the last element of the message
-- gives the number of trailing elements to discard. The worker returns
-- its result reversed, hence the final 'reverse'.
unpad :: Message -> Message
unpad = reverse . subunpad (-1) []

-- | Worker for 'unpad'. Walks the input while building the reversed
-- prefix, then drops the padding from that reversed prefix.
--
-- The first argument is the current state:
--
--   * @-1@: still travelling towards the last element of the input;
--   * @0@: the padding has been removed, the accumulator is the result;
--   * @n@: @n@ elements still have to be removed from the accumulator.
--
-- The second argument is the accumulator (elements seen so far, in
-- reverse order) and the third is the remaining input. The result is
-- returned in reverse order.
subunpad :: Int -> Message -> Message -> Message
subunpad _ [] [] = []
-- The last element is the padding length; it is itself one padding
-- element, so only @s - 1@ more remain to be dropped.
subunpad (-1) ys [s] = subunpad (s - 1) ys []
subunpad (-1) ys (x : xs) = subunpad (-1) (x : ys) xs
subunpad 0 ys _ = ys
subunpad s (_ : ys) _ = subunpad (s - 1) ys []
-- Nothing left in the accumulator to drop from. This case used to be an
-- unmatched pattern (runtime crash); it now mirrors the first equation.
subunpad _ [] _ = []

-- | Pack a list of bytes into a single integer, little-endian: the
-- first element is the least significant base-256 digit.
groupBytes :: Message -> Int
groupBytes = foldr (\byte higher -> byte + 256 * higher) 0

-- | Weighted base-256 accumulation.
--
-- @pow@ is the weight applied to the first element (each following
-- element gets a weight 256 times larger) and @acc@ is the value
-- already accumulated; the weighted elements are added to it.
subgroupBytes :: Int -> Int -> Message -> Int
subgroupBytes pow acc msg = acc + sum (zipWith (*) msg weights)
  where
    -- pow, pow * 256, pow * 256 ^ 2, ...
    weights = iterate (* 256) pow

-- | Split an integer into its base-256 digits, least significant
-- first. Zero yields the empty list, so zero bytes at the most
-- significant end of a packed block are not reproduced.
ungroupBytes :: Int -> Message
ungroupBytes n = map (`mod` 256) (takeWhile (/= 0) quotients)
  where
    -- n, n `div` 256, n `div` 256 ^ 2, ...
    quotients = iterate (`div` 256) n

-- | Cut a message into consecutive chunks of @bsize@ elements; the
-- last chunk may be shorter.
groupN :: Int -> Message -> [Message]
groupN _ [] = []
groupN bsize s = chunk : groupN bsize rest
  where
    (chunk, rest) = splitAt bsize s

-- | Cut the message into chunks of @bsize@ elements and pack each
-- chunk into one integer.
makeBlocks :: Int -> Message -> Message
makeBlocks bsize = map groupBytes . groupN bsize

-- | Expand every block back into its bytes and concatenate the
-- results, keeping the block order.
splitBlocks :: Message -> Message
splitBlocks = concatMap ungroupBytes

-- Prime number arithmetic (reused from slide 42).

-- | Candidates of the form @6k - 1@ and @6k + 1@ for @k >= 1@; every
-- prime greater than 3 has this form.
primecandidates :: [Int]
primecandidates = [6 * k + a | k <- [1 ..], a <- [-1, 1]]

-- | Prime divisors of @n@ that are no larger than its square root.
dividers :: Int -> [Int]
dividers n = filter (\k -> rem n k == 0) smallPrimes
  where
    smallPrimes = takeWhile (\k -> k * k <= n) primeinf

-- | Primality test by trial division. Numbers below 2 are rejected
-- explicitly: they have no divisor to find and would otherwise be
-- reported as prime.
prime :: Int -> Bool
prime n = n > 1 && null (dividers n)

-- | The infinite list of primes, in increasing order.
primeinf :: [Int]
primeinf = 2 : 3 : [n | n <- primecandidates, prime n]

-- | First element of 'primeinf' that is strictly greater than @b@.
choosePrime :: Int -> Int
choosePrime b = head (filter (> b) primeinf)

-- | Extended Euclidean algorithm.
--
-- @euclide a b@ returns @(g, u, v)@ such that @a * u + b * v == g@,
-- where @g@ is the GCD of @a@ and @b@.
euclide :: Int -> Int -> (Int, Int, Int)
euclide a b
  | b == 0 = (a, 1, 0)
  | otherwise = (g, v, u - q * v)
  where
    (q, r) = a `divMod` b
    -- Coefficients for the smaller pair (b, a mod b).
    (g, u, v) = euclide b r

-- | Modular inverse of @e@ modulo @n@: the @d@ in @[0, n)@ such that
-- @(e * d) `mod` n == 1@, for a positive @n@ coprime with @e@.
--
-- The coefficient returned by 'euclide' may be negative (for example
-- -2 for @e = 3@, @n = 7@), so it is reduced modulo @n@. A negative
-- value is congruent to the inverse but cannot be used as an exponent.
-- When @e@ and @n@ are not coprime no inverse exists and the result is
-- meaningless.
modInv :: Int -> Int -> Int
modInv e n = d `mod` n
  where
    (_, d, _) = euclide e n

-- | Modular exponentiation by repeated squaring: @expMod x k n@ is
-- @(x ^ k) `mod` n@, computed without building the full power.
--
-- The exponent must be non-negative; a negative one is rejected with
-- an error instead of recursing forever. Every product is reduced
-- modulo @n@, including the multiplication in the odd branch, so for a
-- positive @n@ the result always lies in @[0, n)@.
--
-- With a fixed-width type such as 'Int' the intermediate products can
-- still overflow once @n * n@ exceeds the range of the type.
expMod :: (Integral a, Integral b) => a -> b -> a -> a
expMod x k n
  | k < 0 = error "expMod: negative exponent"
  | k == 0 = 1 `mod` n
  | even k = expMod square half n
  | otherwise = (base * expMod square half n) `mod` n
  where
    base = x `mod` n
    square = (base * base) `mod` n
    half = k `div` 2

-- | Encrypt a string with exponent @e@ and modulus @n@: the text is
-- turned into character codes, padded with 'pad', packed into blocks
-- of @bsize@ elements, and each block is raised to the power @e@
-- modulo @n@.
encrypt :: Int -> Int -> Int -> String -> Message
encrypt e n bsize smsg = map encryptBlock blocks
  where
    blocks = makeBlocks bsize (pad bsize (stringToMessage smsg))
    encryptBlock block = expMod block e n