Add more tests, implement /register
and /heartbeat
This commit is contained in:
parent
baedf9ba16
commit
cd25b0092d
|
@ -5,11 +5,12 @@ import io.ktor.serialization.kotlinx.json.*
|
|||
import io.ktor.server.application.*
|
||||
import io.ktor.server.engine.*
|
||||
import io.ktor.server.netty.*
|
||||
import io.ktor.server.plugins.*
|
||||
import io.ktor.server.plugins.contentnegotiation.*
|
||||
import io.ktor.server.request.*
|
||||
import io.ktor.server.request.ContentTransformationException
|
||||
import io.ktor.server.response.*
|
||||
import io.ktor.server.routing.*
|
||||
import java.util.*
|
||||
|
||||
fun main() {
|
||||
embeddedServer(Netty, port = 8080, host = "0.0.0.0", module = Application::module)
|
||||
|
@ -42,14 +43,46 @@ fun Application.module() {
|
|||
} catch (e: Exception) {
|
||||
call.respond(HttpStatusCode.BadRequest, "Invalid request")
|
||||
}
|
||||
|
||||
if (request != null) {
|
||||
var client = DatabaseHandler.getClient(request.name)
|
||||
|
||||
if (client == null) {
|
||||
client = DatabaseHandler.getClient(request.id)
|
||||
}
|
||||
|
||||
if (client != null) {
|
||||
call.respond(HttpStatusCode.BadRequest, "Invalid request")
|
||||
} else {
|
||||
val id = DatabaseHandler.addClient(request.id, request.name,this.call.request.origin.remoteHost+":${this.call.request.origin.port}")
|
||||
call.respond(HttpStatusCode.OK, RegisterResponse(id))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
post("/heartbeat") {
|
||||
var request: HeartbeatRequest? = null
|
||||
|
||||
try {
|
||||
request = call.receive()
|
||||
} catch (e: Exception) {
|
||||
call.respond(HttpStatusCode.BadRequest, "Invalid request")
|
||||
}
|
||||
|
||||
try {
|
||||
if (DatabaseHandler.checkCredentials(request!!.id, request!!.password_hash)) {
|
||||
val password = DatabaseHandler.doHeartbeat(request!!.id)
|
||||
call.respond(HttpStatusCode.OK, HeartbeatResponse(password))
|
||||
} else {
|
||||
call.respond(HttpStatusCode.Forbidden, "Invalid credentials")
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
call.respond(HttpStatusCode.BadRequest, "Invalid request")
|
||||
}
|
||||
}
|
||||
|
||||
post("/exit") {
|
||||
var request: ExitRequest? = null
|
||||
val request: ExitRequest?
|
||||
|
||||
try {
|
||||
request = call.receive()
|
||||
|
@ -59,11 +92,11 @@ fun Application.module() {
|
|||
|
||||
try {
|
||||
if (DatabaseHandler.checkCredentials(
|
||||
request!!.id,
|
||||
Base64.getDecoder().decode(request!!.password_hash)
|
||||
request.id,
|
||||
request.password_hash
|
||||
)
|
||||
) {
|
||||
if (DatabaseHandler.removeClient(request!!.id)) {
|
||||
if (DatabaseHandler.removeClient(request.id)) {
|
||||
call.respondText("Client removed")
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -20,14 +20,14 @@ class Client(
|
|||
var name: String,
|
||||
var timeout: Instant,
|
||||
var ip: String,
|
||||
var password_hash: ByteArray
|
||||
var password_hash: String
|
||||
)
|
||||
|
||||
object Clients: UUIDTable() {
|
||||
var name = varchar("name", 16)
|
||||
var ip = varchar("ip", 21)
|
||||
var timeout = timestamp("timeout")
|
||||
var password_hash = binary("password_hash", 96)
|
||||
var password_hash = varchar("password_hash", 96)
|
||||
}
|
||||
|
||||
object DatabaseHandler {
|
||||
|
@ -41,22 +41,34 @@ object DatabaseHandler {
|
|||
}
|
||||
}
|
||||
|
||||
fun addClient(id: UUID, name: String, ip: String): ByteArray {
|
||||
fun generatePassword(): String {
|
||||
val password = Random.nextBytes(128)
|
||||
|
||||
return Base64.getEncoder().encodeToString(password)
|
||||
}
|
||||
|
||||
fun getPasswordHash(password: String): String {
|
||||
val sha3 = SHA3.getSha3_512()
|
||||
val passwordHash = Base64.getEncoder().encode(sha3.digest(password))
|
||||
val passwordHash = sha3.digest(Base64.getDecoder().decode(password))
|
||||
|
||||
return Base64.getEncoder().encodeToString(passwordHash)
|
||||
}
|
||||
|
||||
fun addClient(id: UUID, name: String, ip: String): String {
|
||||
val pw = generatePassword()
|
||||
val pwHash = getPasswordHash(pw)
|
||||
|
||||
transaction {
|
||||
Clients.insert {
|
||||
it[Clients.id] = id
|
||||
it[Clients.name] = name
|
||||
it[Clients.ip] = ip
|
||||
it[password_hash] = passwordHash
|
||||
it[password_hash] = pwHash
|
||||
it[timeout] = java.time.LocalDateTime.now().toInstant(java.time.ZoneOffset.UTC) + java.time.Duration.ofMinutes(5)
|
||||
}
|
||||
}
|
||||
|
||||
return password
|
||||
return pw
|
||||
}
|
||||
|
||||
fun getClient(name: String): Client? {
|
||||
|
@ -67,17 +79,39 @@ object DatabaseHandler {
|
|||
}
|
||||
}
|
||||
|
||||
fun getClient(id: UUID): Client? {
|
||||
return transaction {
|
||||
Clients.select { Clients.id eq id }.map {
|
||||
Client(it[Clients.id].value, it[Clients.name], it[Clients.timeout], it[Clients.ip], it[Clients.password_hash])
|
||||
}.firstOrNull()
|
||||
}
|
||||
}
|
||||
|
||||
fun doHeartbeat(id: UUID): String {
|
||||
val pw = generatePassword()
|
||||
val pwHash = getPasswordHash(pw)
|
||||
|
||||
transaction {
|
||||
Clients.update({ Clients.id eq id }) {
|
||||
it[timeout] = java.time.LocalDateTime.now().toInstant(java.time.ZoneOffset.UTC) + java.time.Duration.ofMinutes(5)
|
||||
it[password_hash] = pwHash
|
||||
}
|
||||
}
|
||||
|
||||
return pw
|
||||
}
|
||||
|
||||
fun removeClient(id: UUID): Boolean {
|
||||
return transaction {
|
||||
return@transaction Clients.deleteWhere { Clients.id eq id } > 0
|
||||
}
|
||||
}
|
||||
|
||||
fun checkCredentials(id: UUID, password_hash: ByteArray): Boolean {
|
||||
fun checkCredentials(id: UUID, password_hash: String): Boolean {
|
||||
return transaction {
|
||||
return@transaction Clients.select { Clients.id eq id }.map {
|
||||
it[Clients.password_hash]
|
||||
}.firstOrNull().contentEquals(password_hash)
|
||||
}.firstOrNull()?.contentEquals(password_hash) ?: false
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ import java.util.*
|
|||
data class RegisterRequest(
|
||||
val id: UUID,
|
||||
val name: String,
|
||||
val ip: String,
|
||||
val port: Int,
|
||||
)
|
||||
|
||||
@Serializable
|
||||
|
@ -21,6 +21,12 @@ data class RegisterResponse(
|
|||
@Serializable
|
||||
data class HeartbeatRequest(
|
||||
val id: UUID,
|
||||
val password_hash: String,
|
||||
)
|
||||
|
||||
@Serializable
|
||||
data class HeartbeatResponse(
|
||||
val new_password: String,
|
||||
)
|
||||
|
||||
@Serializable
|
||||
|
|
|
@ -31,11 +31,11 @@ class ApplicationKtTest {
|
|||
}
|
||||
|
||||
addTestData()
|
||||
val pw_hash = DatabaseHandler.getClient("TestClient1")?.password_hash
|
||||
val pw_hash = DatabaseHandler.getClient("TestClient1")!!.password_hash
|
||||
|
||||
client.post("/exit") {
|
||||
contentType(ContentType.Application.Json)
|
||||
setBody(ExitRequest(UUID.fromString("00000000-0000-0000-0000-000000000000"), Base64.getEncoder().encodeToString(pw_hash)))
|
||||
setBody(ExitRequest(UUID.fromString("00000000-0000-0000-0000-000000000000"), pw_hash))
|
||||
}.apply {
|
||||
assertEquals("Client not found", body())
|
||||
}
|
||||
|
@ -78,11 +78,11 @@ class ApplicationKtTest {
|
|||
|
||||
addTestData()
|
||||
|
||||
val pw_hash = DatabaseHandler.getClient("TestClient1")?.password_hash
|
||||
val pw_hash = DatabaseHandler.getClient("TestClient1")!!.password_hash
|
||||
|
||||
client.post("/exit") {
|
||||
contentType(ContentType.Application.Json)
|
||||
setBody(ExitRequest(UUID.fromString("00000000-0000-0000-0000-000000000001"), Base64.getEncoder().encodeToString(pw_hash)))
|
||||
setBody(ExitRequest(UUID.fromString("00000000-0000-0000-0000-000000000001"), pw_hash))
|
||||
}.apply {
|
||||
assertEquals("Client removed", body())
|
||||
}
|
||||
|
@ -98,4 +98,100 @@ class ApplicationKtTest {
|
|||
assertEquals(400, this.status.value)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testHeartbeatRequestNoBody() = testApplication {
|
||||
application {
|
||||
module()
|
||||
}
|
||||
|
||||
client.post("/heartbeat").apply {
|
||||
assertEquals(400, this.status.value)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testHeartbeatRequestCorrectBody() = testApplication {
|
||||
val client = createClient {
|
||||
install(ContentNegotiation) {
|
||||
json()
|
||||
}
|
||||
}
|
||||
|
||||
addTestData()
|
||||
|
||||
val pw_hash = DatabaseHandler.getClient("TestClient1")!!.password_hash
|
||||
|
||||
application {
|
||||
module()
|
||||
}
|
||||
|
||||
client.post("/heartbeat") {
|
||||
contentType(ContentType.Application.Json)
|
||||
setBody(HeartbeatRequest(UUID.fromString("00000000-0000-0000-0000-000000000001"), pw_hash))
|
||||
}.apply {
|
||||
assertEquals(200, this.status.value)
|
||||
this.body() as HeartbeatResponse
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testHeartbeatRequestInvalidPassword() = testApplication {
|
||||
val client = createClient {
|
||||
install(ContentNegotiation) {
|
||||
json()
|
||||
}
|
||||
}
|
||||
|
||||
application {
|
||||
module()
|
||||
}
|
||||
|
||||
addTestData()
|
||||
|
||||
client.post("/heartbeat") {
|
||||
contentType(ContentType.Application.Json)
|
||||
setBody(HeartbeatRequest(UUID.fromString("00000000-0000-0000-0000-000000000001"), "heheh"))
|
||||
}.apply {
|
||||
assertEquals(403, this.status.value)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testRegisterRequestNoBody() = testApplication {
|
||||
application {
|
||||
module()
|
||||
}
|
||||
|
||||
client.post("/register").apply {
|
||||
assertEquals(400, this.status.value)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testRegisterRequestCorrectBody() = testApplication {
|
||||
val client = createClient {
|
||||
install(ContentNegotiation) {
|
||||
json()
|
||||
}
|
||||
}
|
||||
|
||||
application {
|
||||
module()
|
||||
}
|
||||
|
||||
client.post("/register") {
|
||||
contentType(ContentType.Application.Json)
|
||||
setBody(
|
||||
RegisterRequest(
|
||||
UUID.fromString("00000000-0000-0000-0000-000000000002"),
|
||||
"TestClient2",
|
||||
1
|
||||
)
|
||||
)
|
||||
}.apply {
|
||||
assertEquals(200, this.status.value)
|
||||
this.body() as RegisterResponse
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,5 +1,6 @@
|
|||
package org.muellerssoftware.openproximitychat
|
||||
|
||||
import org.h2.security.SHA3
|
||||
import org.junit.Before
|
||||
import java.util.*
|
||||
import kotlin.test.Test
|
||||
|
@ -26,7 +27,7 @@ internal class DatabaseHandlerTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
fun addClient() {
|
||||
fun testAddClient() {
|
||||
DatabaseHandler.addClient(UUID.fromString("00000000-0000-0000-0000-000000000010"), "AddTestClient1", "0.0.0.0:1")
|
||||
DatabaseHandler.getClient("AddTestClient1")?.let {
|
||||
assert(it.name == "AddTestClient1")
|
||||
|
@ -36,7 +37,7 @@ internal class DatabaseHandlerTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
fun getClient() {
|
||||
fun testGetClient() {
|
||||
DatabaseHandler.getClient("TestClient1")?.let {
|
||||
assert(it.name == "TestClient1")
|
||||
assert(it.ip == "0.0.0.0:1")
|
||||
|
@ -57,7 +58,7 @@ internal class DatabaseHandlerTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
fun removeClient() {
|
||||
fun testRemoveClient() {
|
||||
DatabaseHandler.removeClient(UUID.fromString("00000000-0000-0000-0000-000000000001"))
|
||||
assert(DatabaseHandler.getClient("TestClient1") == null)
|
||||
|
||||
|
@ -69,7 +70,7 @@ internal class DatabaseHandlerTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
fun checkAuth() {
|
||||
fun testCheckAuth() {
|
||||
DatabaseHandler.getClient("TestClient1")?.let {
|
||||
assert(DatabaseHandler.checkCredentials(it.uuid, it.password_hash))
|
||||
}
|
||||
|
@ -82,4 +83,32 @@ internal class DatabaseHandlerTest {
|
|||
assert(DatabaseHandler.checkCredentials(it.uuid, it.password_hash))
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testHeartbeatHashes() {
|
||||
val sha3 = SHA3.getSha3_512()
|
||||
|
||||
val clients = listOf("TestClient1", "TestClient2", "TestClient3")
|
||||
|
||||
for (client in clients) {
|
||||
DatabaseHandler.getClient(client)?.let {
|
||||
val pw = DatabaseHandler.doHeartbeat(it.uuid)
|
||||
val new_record = DatabaseHandler.getClient(client)
|
||||
assert(Base64.getEncoder().encodeToString(sha3.digest(Base64.getDecoder().decode(pw)))!!.contentEquals(new_record!!.password_hash))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testHeartbeatTimes() {
|
||||
val clients = listOf("TestClient1", "TestClient2", "TestClient3")
|
||||
|
||||
for (client in clients) {
|
||||
DatabaseHandler.getClient(client)?.let {
|
||||
DatabaseHandler.doHeartbeat(it.uuid)
|
||||
val new_record = DatabaseHandler.getClient(client)
|
||||
assert(new_record!!.timeout > it.timeout)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user