Add more tests, implement /register and /heartbeat

This commit is contained in:
ProtoByter 2022-12-01 08:25:17 +00:00
parent baedf9ba16
commit cd25b0092d
5 changed files with 220 additions and 22 deletions

View File

@ -5,11 +5,12 @@ import io.ktor.serialization.kotlinx.json.*
import io.ktor.server.application.*
import io.ktor.server.engine.*
import io.ktor.server.netty.*
import io.ktor.server.plugins.*
import io.ktor.server.plugins.contentnegotiation.*
import io.ktor.server.request.*
import io.ktor.server.request.ContentTransformationException
import io.ktor.server.response.*
import io.ktor.server.routing.*
import java.util.*
fun main() {
embeddedServer(Netty, port = 8080, host = "0.0.0.0", module = Application::module)
@ -42,14 +43,46 @@ fun Application.module() {
} catch (e: Exception) {
call.respond(HttpStatusCode.BadRequest, "Invalid request")
}
if (request != null) {
var client = DatabaseHandler.getClient(request.name)
if (client == null) {
client = DatabaseHandler.getClient(request.id)
}
if (client != null) {
call.respond(HttpStatusCode.BadRequest, "Invalid request")
} else {
val id = DatabaseHandler.addClient(request.id, request.name,this.call.request.origin.remoteHost+":${this.call.request.origin.port}")
call.respond(HttpStatusCode.OK, RegisterResponse(id))
}
}
}
post("/heartbeat") {
var request: HeartbeatRequest? = null
try {
request = call.receive()
} catch (e: Exception) {
call.respond(HttpStatusCode.BadRequest, "Invalid request")
}
try {
if (DatabaseHandler.checkCredentials(request!!.id, request!!.password_hash)) {
val password = DatabaseHandler.doHeartbeat(request!!.id)
call.respond(HttpStatusCode.OK, HeartbeatResponse(password))
} else {
call.respond(HttpStatusCode.Forbidden, "Invalid credentials")
}
} catch (e: Exception) {
call.respond(HttpStatusCode.BadRequest, "Invalid request")
}
}
post("/exit") {
var request: ExitRequest? = null
val request: ExitRequest?
try {
request = call.receive()
@ -59,11 +92,11 @@ fun Application.module() {
try {
if (DatabaseHandler.checkCredentials(
request!!.id,
Base64.getDecoder().decode(request!!.password_hash)
request.id,
request.password_hash
)
) {
if (DatabaseHandler.removeClient(request!!.id)) {
if (DatabaseHandler.removeClient(request.id)) {
call.respondText("Client removed")
}
} else {

View File

@ -20,14 +20,14 @@ class Client(
var name: String,
var timeout: Instant,
var ip: String,
var password_hash: ByteArray
var password_hash: String
)
object Clients: UUIDTable() {
var name = varchar("name", 16)
var ip = varchar("ip", 21)
var timeout = timestamp("timeout")
var password_hash = binary("password_hash", 96)
var password_hash = varchar("password_hash", 96)
}
object DatabaseHandler {
@ -41,22 +41,34 @@ object DatabaseHandler {
}
}
fun addClient(id: UUID, name: String, ip: String): ByteArray {
fun generatePassword(): String {
val password = Random.nextBytes(128)
return Base64.getEncoder().encodeToString(password)
}
fun getPasswordHash(password: String): String {
val sha3 = SHA3.getSha3_512()
val passwordHash = Base64.getEncoder().encode(sha3.digest(password))
val passwordHash = sha3.digest(Base64.getDecoder().decode(password))
return Base64.getEncoder().encodeToString(passwordHash)
}
fun addClient(id: UUID, name: String, ip: String): String {
val pw = generatePassword()
val pwHash = getPasswordHash(pw)
transaction {
Clients.insert {
it[Clients.id] = id
it[Clients.name] = name
it[Clients.ip] = ip
it[password_hash] = passwordHash
it[password_hash] = pwHash
it[timeout] = java.time.LocalDateTime.now().toInstant(java.time.ZoneOffset.UTC) + java.time.Duration.ofMinutes(5)
}
}
return password
return pw
}
fun getClient(name: String): Client? {
@ -67,17 +79,39 @@ object DatabaseHandler {
}
}
fun getClient(id: UUID): Client? {
return transaction {
Clients.select { Clients.id eq id }.map {
Client(it[Clients.id].value, it[Clients.name], it[Clients.timeout], it[Clients.ip], it[Clients.password_hash])
}.firstOrNull()
}
}
fun doHeartbeat(id: UUID): String {
val pw = generatePassword()
val pwHash = getPasswordHash(pw)
transaction {
Clients.update({ Clients.id eq id }) {
it[timeout] = java.time.LocalDateTime.now().toInstant(java.time.ZoneOffset.UTC) + java.time.Duration.ofMinutes(5)
it[password_hash] = pwHash
}
}
return pw
}
fun removeClient(id: UUID): Boolean {
return transaction {
return@transaction Clients.deleteWhere { Clients.id eq id } > 0
}
}
fun checkCredentials(id: UUID, password_hash: ByteArray): Boolean {
fun checkCredentials(id: UUID, password_hash: String): Boolean {
return transaction {
return@transaction Clients.select { Clients.id eq id }.map {
it[Clients.password_hash]
}.firstOrNull().contentEquals(password_hash)
}.firstOrNull()?.contentEquals(password_hash) ?: false
}
}

View File

@ -10,7 +10,7 @@ import java.util.*
data class RegisterRequest(
val id: UUID,
val name: String,
val ip: String,
val port: Int,
)
@Serializable
@ -21,6 +21,12 @@ data class RegisterResponse(
@Serializable
data class HeartbeatRequest(
val id: UUID,
val password_hash: String,
)
@Serializable
data class HeartbeatResponse(
val new_password: String,
)
@Serializable

View File

@ -31,11 +31,11 @@ class ApplicationKtTest {
}
addTestData()
val pw_hash = DatabaseHandler.getClient("TestClient1")?.password_hash
val pw_hash = DatabaseHandler.getClient("TestClient1")!!.password_hash
client.post("/exit") {
contentType(ContentType.Application.Json)
setBody(ExitRequest(UUID.fromString("00000000-0000-0000-0000-000000000000"), Base64.getEncoder().encodeToString(pw_hash)))
setBody(ExitRequest(UUID.fromString("00000000-0000-0000-0000-000000000000"), pw_hash))
}.apply {
assertEquals("Client not found", body())
}
@ -78,11 +78,11 @@ class ApplicationKtTest {
addTestData()
val pw_hash = DatabaseHandler.getClient("TestClient1")?.password_hash
val pw_hash = DatabaseHandler.getClient("TestClient1")!!.password_hash
client.post("/exit") {
contentType(ContentType.Application.Json)
setBody(ExitRequest(UUID.fromString("00000000-0000-0000-0000-000000000001"), Base64.getEncoder().encodeToString(pw_hash)))
setBody(ExitRequest(UUID.fromString("00000000-0000-0000-0000-000000000001"), pw_hash))
}.apply {
assertEquals("Client removed", body())
}
@ -98,4 +98,100 @@ class ApplicationKtTest {
assertEquals(400, this.status.value)
}
}
@Test
fun testHeartbeatRequestNoBody() = testApplication {
application {
module()
}
client.post("/heartbeat").apply {
assertEquals(400, this.status.value)
}
}
@Test
fun testHeartbeatRequestCorrectBody() = testApplication {
val client = createClient {
install(ContentNegotiation) {
json()
}
}
addTestData()
val pw_hash = DatabaseHandler.getClient("TestClient1")!!.password_hash
application {
module()
}
client.post("/heartbeat") {
contentType(ContentType.Application.Json)
setBody(HeartbeatRequest(UUID.fromString("00000000-0000-0000-0000-000000000001"), pw_hash))
}.apply {
assertEquals(200, this.status.value)
this.body() as HeartbeatResponse
}
}
@Test
fun testHeartbeatRequestInvalidPassword() = testApplication {
val client = createClient {
install(ContentNegotiation) {
json()
}
}
application {
module()
}
addTestData()
client.post("/heartbeat") {
contentType(ContentType.Application.Json)
setBody(HeartbeatRequest(UUID.fromString("00000000-0000-0000-0000-000000000001"), "heheh"))
}.apply {
assertEquals(403, this.status.value)
}
}
@Test
fun testRegisterRequestNoBody() = testApplication {
application {
module()
}
client.post("/register").apply {
assertEquals(400, this.status.value)
}
}
@Test
fun testRegisterRequestCorrectBody() = testApplication {
val client = createClient {
install(ContentNegotiation) {
json()
}
}
application {
module()
}
client.post("/register") {
contentType(ContentType.Application.Json)
setBody(
RegisterRequest(
UUID.fromString("00000000-0000-0000-0000-000000000002"),
"TestClient2",
1
)
)
}.apply {
assertEquals(200, this.status.value)
this.body() as RegisterResponse
}
}
}

View File

@ -1,5 +1,6 @@
package org.muellerssoftware.openproximitychat
import org.h2.security.SHA3
import org.junit.Before
import java.util.*
import kotlin.test.Test
@ -26,7 +27,7 @@ internal class DatabaseHandlerTest {
}
@Test
fun addClient() {
fun testAddClient() {
DatabaseHandler.addClient(UUID.fromString("00000000-0000-0000-0000-000000000010"), "AddTestClient1", "0.0.0.0:1")
DatabaseHandler.getClient("AddTestClient1")?.let {
assert(it.name == "AddTestClient1")
@ -36,7 +37,7 @@ internal class DatabaseHandlerTest {
}
@Test
fun getClient() {
fun testGetClient() {
DatabaseHandler.getClient("TestClient1")?.let {
assert(it.name == "TestClient1")
assert(it.ip == "0.0.0.0:1")
@ -57,7 +58,7 @@ internal class DatabaseHandlerTest {
}
@Test
fun removeClient() {
fun testRemoveClient() {
DatabaseHandler.removeClient(UUID.fromString("00000000-0000-0000-0000-000000000001"))
assert(DatabaseHandler.getClient("TestClient1") == null)
@ -69,7 +70,7 @@ internal class DatabaseHandlerTest {
}
@Test
fun checkAuth() {
fun testCheckAuth() {
DatabaseHandler.getClient("TestClient1")?.let {
assert(DatabaseHandler.checkCredentials(it.uuid, it.password_hash))
}
@ -82,4 +83,32 @@ internal class DatabaseHandlerTest {
assert(DatabaseHandler.checkCredentials(it.uuid, it.password_hash))
}
}
@Test
fun testHeartbeatHashes() {
val sha3 = SHA3.getSha3_512()
val clients = listOf("TestClient1", "TestClient2", "TestClient3")
for (client in clients) {
DatabaseHandler.getClient(client)?.let {
val pw = DatabaseHandler.doHeartbeat(it.uuid)
val new_record = DatabaseHandler.getClient(client)
assert(Base64.getEncoder().encodeToString(sha3.digest(Base64.getDecoder().decode(pw)))!!.contentEquals(new_record!!.password_hash))
}
}
}
@Test
fun testHeartbeatTimes() {
val clients = listOf("TestClient1", "TestClient2", "TestClient3")
for (client in clients) {
DatabaseHandler.getClient(client)?.let {
DatabaseHandler.doHeartbeat(it.uuid)
val new_record = DatabaseHandler.getClient(client)
assert(new_record!!.timeout > it.timeout)
}
}
}
}