Multiple misc changes + /protected/search now implements origin policy

This commit is contained in:
ProtoByter 2022-12-03 09:00:20 +00:00
parent 91d2f2b87e
commit 58d0438b14
6 changed files with 159 additions and 36 deletions

View File

@ -32,6 +32,7 @@ fun Application.module() {
}
}
}
routing {
get("/info") {
call.respondText("""
@ -68,7 +69,7 @@ fun Application.module() {
if (client != null) {
call.respond(HttpStatusCode.BadRequest, "Invalid request")
} else {
val id = DatabaseHandler.addClient(request.id, request.name,this.call.request.origin.remoteHost+":${this.call.request.origin.port}")
val id = DatabaseHandler.addClient(request.id, request.name,this.call.request.origin.remoteHost+":${this.call.request.origin.port}", request.sampledClients)
call.respond(HttpStatusCode.OK, RegisterResponse(id))
}
}
@ -93,6 +94,26 @@ fun Application.module() {
call.respondText("Client not found")
}
}
get("/protected/search") {
var request: SearchRequest? = null
try {
request = call.receive()
} catch (e: Exception) {
call.respond(HttpStatusCode.BadRequest, "Invalid request")
}
if (request != null) {
val client = DatabaseHandler.getClient(request!!.id)
if (client!!.sampledClients == request!!.sampledClients) {
call.respond(HttpStatusCode.OK, SearchResponse(client.ip))
} else {
call.respond(HttpStatusCode.BadRequest, "Invalid request")
}
}
}
}
}

View File

@ -20,14 +20,23 @@ class Client(
var name: String,
var timeout: Instant,
var ip: String,
var password_hash: String
var passwordHash: String,
var sampledClients: List<UUID>
)
object Clients: UUIDTable() {
var name = varchar("name", 16)
var ip = varchar("ip", 21)
var timeout = timestamp("timeout")
var password_hash = varchar("password_hash", 96)
var passwordHash = varchar("passwordHash", 96)
var sampledClient1 = uuid("sampledClient1")
var sampledClient2 = uuid("sampledClient2")
var sampledClient3 = uuid("sampledClient3")
var sampledClient4 = uuid("sampledClient4")
var sampledClient5 = uuid("sampledClient5")
var sampledClient6 = uuid("sampledClient6")
var sampledClient7 = uuid("sampledClient7")
var sampledClient8 = uuid("sampledClient8")
}
object DatabaseHandler {
@ -41,20 +50,20 @@ object DatabaseHandler {
}
}
fun generatePassword(): String {
private fun generatePassword(): String {
val password = Random.nextBytes(128)
return Base64.getEncoder().encodeToString(password)
}
fun getPasswordHash(password: String): String {
private fun getPasswordHash(password: String): String {
val sha3 = SHA3.getSha3_512()
val passwordHash = sha3.digest(Base64.getDecoder().decode(password))
return Base64.getEncoder().encodeToString(passwordHash)
}
fun addClient(id: UUID, name: String, ip: String): String {
fun addClient(id: UUID, name: String, ip: String, sampledClients: List<UUID>): String {
val pw = generatePassword()
val pwHash = getPasswordHash(pw)
@ -63,28 +72,55 @@ object DatabaseHandler {
it[Clients.id] = id
it[Clients.name] = name
it[Clients.ip] = ip
it[password_hash] = pwHash
it[passwordHash] = pwHash
it[timeout] = java.time.LocalDateTime.now().toInstant(java.time.ZoneOffset.UTC) + java.time.Duration.ofMinutes(5)
it[sampledClient1] = sampledClients[0]
it[sampledClient2] = sampledClients[1]
it[sampledClient3] = sampledClients[2]
it[sampledClient4] = sampledClients[3]
it[sampledClient5] = sampledClients[4]
it[sampledClient6] = sampledClients[5]
it[sampledClient7] = sampledClients[6]
it[sampledClient8] = sampledClients[7]
}
}
return pw
}
fun getClient(name: String): Client? {
return transaction {
Clients.select { Clients.name eq name }.map {
Client(it[Clients.id].value, it[Clients.name], it[Clients.timeout], it[Clients.ip], it[Clients.password_hash])
}.firstOrNull()
private fun checkTimeout(client: Client): Client? {
if (client.timeout < Instant.now()) {
removeClient(client.uuid)
return null
}
return client
}
fun getClient(name: String): Client? {
val client = transaction {
Clients.select { Clients.name eq name }.map {
Client(it[Clients.id].value, it[Clients.name], it[Clients.timeout], it[Clients.ip], it[Clients.passwordHash], listOf(
it[Clients.sampledClient1],it[Clients.sampledClient2],it[Clients.sampledClient3],it[Clients.sampledClient4],
it[Clients.sampledClient5],it[Clients.sampledClient6],it[Clients.sampledClient7],it[Clients.sampledClient8]
))
}.firstOrNull()
} ?: return null
return checkTimeout(client)
}
fun getClient(id: UUID): Client? {
return transaction {
val client = transaction {
Clients.select { Clients.id eq id }.map {
Client(it[Clients.id].value, it[Clients.name], it[Clients.timeout], it[Clients.ip], it[Clients.password_hash])
Client(it[Clients.id].value, it[Clients.name], it[Clients.timeout], it[Clients.ip], it[Clients.passwordHash], listOf(
it[Clients.sampledClient1],it[Clients.sampledClient2],it[Clients.sampledClient3],it[Clients.sampledClient4],
it[Clients.sampledClient5],it[Clients.sampledClient6],it[Clients.sampledClient7],it[Clients.sampledClient8]
))
}.firstOrNull()
}
} ?: return null
return checkTimeout(client)
}
fun doHeartbeat(id: UUID): String {
@ -94,7 +130,7 @@ object DatabaseHandler {
transaction {
Clients.update({ Clients.id eq id }) {
it[timeout] = java.time.LocalDateTime.now().toInstant(java.time.ZoneOffset.UTC) + java.time.Duration.ofMinutes(5)
it[password_hash] = pwHash
it[passwordHash] = pwHash
}
}
@ -107,11 +143,13 @@ object DatabaseHandler {
}
}
fun checkCredentials(id: UUID, password_hash: String): Boolean {
return transaction {
return@transaction Clients.select { Clients.id eq id }.map {
it[Clients.password_hash]
}.firstOrNull()?.contentEquals(password_hash) ?: false
fun checkCredentials(id: UUID, passwordHash: String): Boolean {
val client = getClient(id) ?: return false
return if (checkTimeout(client) != null) {
client.passwordHash == passwordHash
} else {
false
}
}

View File

@ -14,6 +14,7 @@ data class RegisterRequest(
val id: UUID,
val name: String,
val port: Int,
val sampledClients: List<UUID>
)
@Serializable
@ -23,15 +24,16 @@ data class RegisterResponse(
@Serializable
data class HeartbeatResponse(
val new_password: String,
val newPassword: String,
)
@Serializable
data class SearchRequest(
val id: UUID,
val sampledClients: List<UUID>,
)
@Serializable
data class SearchResponse(
val clients: List<Client>
val ip: String,
)

View File

@ -16,7 +16,6 @@ object Versions {
init {
Application::class.java.getResource("/versions.txt")!!.openStream().use {
it.bufferedReader().useLines { lines ->
lines.forEach {
val split = it.split("=")
when (split[0]) {

View File

@ -12,10 +12,11 @@ import kotlin.test.Test
import kotlin.test.assertEquals
class ApplicationKtTest {
val sampledClientsTestVal = listOf(UUID.randomUUID(), UUID.randomUUID(), UUID.randomUUID(), UUID.randomUUID(), UUID.randomUUID(), UUID.randomUUID(), UUID.randomUUID(), UUID.randomUUID())
@Before
fun addTestData() {
DatabaseHandler.clear()
DatabaseHandler.addClient(UUID.fromString("00000000-0000-0000-0000-000000000001"), "TestClient1", "0.0.0.0:1")
DatabaseHandler.addClient(UUID.fromString("00000000-0000-0000-0000-000000000001"), "TestClient1", "0.0.0.0:1", sampledClientsTestVal)
}
@Test
@ -31,7 +32,7 @@ class ApplicationKtTest {
}
addTestData()
val pw_hash = DatabaseHandler.getClient("TestClient1")!!.password_hash
val pw_hash = DatabaseHandler.getClient("TestClient1")!!.passwordHash
client.post("/protected/exit") {
basicAuth("00000000-0000-0000-0000-000000000000", pw_hash)
@ -76,7 +77,7 @@ class ApplicationKtTest {
addTestData()
val pw_hash = DatabaseHandler.getClient("TestClient1")!!.password_hash
val pw_hash = DatabaseHandler.getClient("TestClient1")!!.passwordHash
client.post("/protected/exit") {
basicAuth("00000000-0000-0000-0000-000000000001", pw_hash)
@ -117,7 +118,7 @@ class ApplicationKtTest {
addTestData()
val pw_hash = DatabaseHandler.getClient("TestClient1")!!.password_hash
val pw_hash = DatabaseHandler.getClient("TestClient1")!!.passwordHash
application {
module()
@ -181,7 +182,8 @@ class ApplicationKtTest {
RegisterRequest(
UUID.fromString("00000000-0000-0000-0000-000000000002"),
"TestClient2",
1
1,
sampledClientsTestVal
)
)
}.apply {
@ -189,4 +191,47 @@ class ApplicationKtTest {
this.body() as RegisterResponse
}
}
@Test
fun testSearchRequestNoBody() = testApplication {
application {
module()
}
client.get("/protected/search").apply {
assertEquals(401, this.status.value)
}
}
@Test
fun testSearchRequestCorrectBody() = testApplication {
val client = createClient {
install(ContentNegotiation) {
json()
}
}
application {
module()
}
addTestData()
val pw_hash = DatabaseHandler.getClient("TestClient1")!!.passwordHash
client.get("/protected/search") {
contentType(ContentType.Application.Json)
setBody(
SearchRequest(
UUID.fromString("00000000-0000-0000-0000-000000000001"),
sampledClientsTestVal
)
)
basicAuth("00000000-0000-0000-0000-000000000001", pw_hash)
}.apply {
val clientCorrect = DatabaseHandler.getClient(UUID.fromString("00000000-0000-0000-0000-000000000001"))
assertEquals(200, this.status.value)
assertEquals((this.body() as SearchResponse).ip, clientCorrect!!.ip)
}
}
}

View File

@ -6,6 +6,17 @@ import java.util.*
import kotlin.test.Test
internal class DatabaseHandlerTest {
private val sampledClientsTestVal = listOf(
UUID.fromString("00000000-0000-0000-0000-000000000002"),
UUID.fromString("00000000-0000-0000-0000-000000000003"),
UUID.fromString("00000000-0000-0000-0000-000000000004"),
UUID.fromString("00000000-0000-0000-0000-000000000005"),
UUID.fromString("00000000-0000-0000-0000-000000000006"),
UUID.fromString("00000000-0000-0000-0000-000000000007"),
UUID.fromString("00000000-0000-0000-0000-000000000008"),
UUID.fromString("00000000-0000-0000-0000-000000000009"),
)
@Before
fun setUp() {
DatabaseHandler.clear()
@ -13,26 +24,30 @@ internal class DatabaseHandlerTest {
UUID.fromString("00000000-0000-0000-0000-000000000001"),
"TestClient1",
"0.0.0.0:1",
sampledClientsTestVal
)
DatabaseHandler.addClient(
UUID.fromString("00000000-0000-0000-0000-000000000002"),
"TestClient2",
"0.0.0.0:1"
"0.0.0.0:1",
sampledClientsTestVal
)
DatabaseHandler.addClient(
UUID.fromString("00000000-0000-0000-0000-000000000003"),
"TestClient3",
"0.0.0.0:1"
"0.0.0.0:1",
sampledClientsTestVal
)
}
@Test
fun testAddClient() {
DatabaseHandler.addClient(UUID.fromString("00000000-0000-0000-0000-000000000010"), "AddTestClient1", "0.0.0.0:1")
DatabaseHandler.addClient(UUID.fromString("00000000-0000-0000-0000-000000000010"), "AddTestClient1", "0.0.0.0:1", sampledClientsTestVal)
DatabaseHandler.getClient("AddTestClient1")?.let {
assert(it.name == "AddTestClient1")
assert(it.ip == "0.0.0.0:1")
assert(it.uuid == UUID.fromString("00000000-0000-0000-0000-000000000010"))
assert(it.sampledClients == sampledClientsTestVal)
}
}
@ -42,18 +57,21 @@ internal class DatabaseHandlerTest {
assert(it.name == "TestClient1")
assert(it.ip == "0.0.0.0:1")
assert(it.uuid == UUID.fromString("00000000-0000-0000-0000-000000000001"))
assert(it.sampledClients == sampledClientsTestVal)
}
DatabaseHandler.getClient("TestClient2")?.let {
assert(it.name == "TestClient2")
assert(it.ip == "0.0.0.0:1")
assert(it.uuid == UUID.fromString("00000000-0000-0000-0000-000000000002"))
assert(it.sampledClients == sampledClientsTestVal)
}
DatabaseHandler.getClient("TestClient3")?.let {
assert(it.name == "TestClient3")
assert(it.ip == "0.0.0.0:1")
assert(it.uuid == UUID.fromString("00000000-0000-0000-0000-000000000003"))
assert(it.sampledClients == sampledClientsTestVal)
}
}
@ -72,15 +90,15 @@ internal class DatabaseHandlerTest {
@Test
fun testCheckAuth() {
DatabaseHandler.getClient("TestClient1")?.let {
assert(DatabaseHandler.checkCredentials(it.uuid, it.password_hash))
assert(DatabaseHandler.checkCredentials(it.uuid, it.passwordHash))
}
DatabaseHandler.getClient("TestClient2")?.let {
assert(DatabaseHandler.checkCredentials(it.uuid, it.password_hash))
assert(DatabaseHandler.checkCredentials(it.uuid, it.passwordHash))
}
DatabaseHandler.getClient("TestClient3")?.let {
assert(DatabaseHandler.checkCredentials(it.uuid, it.password_hash))
assert(DatabaseHandler.checkCredentials(it.uuid, it.passwordHash))
}
}
@ -94,7 +112,7 @@ internal class DatabaseHandlerTest {
DatabaseHandler.getClient(client)?.let {
val pw = DatabaseHandler.doHeartbeat(it.uuid)
val new_record = DatabaseHandler.getClient(client)
assert(Base64.getEncoder().encodeToString(sha3.digest(Base64.getDecoder().decode(pw)))!!.contentEquals(new_record!!.password_hash))
assert(Base64.getEncoder().encodeToString(sha3.digest(Base64.getDecoder().decode(pw)))!!.contentEquals(new_record!!.passwordHash))
}
}
}