1
use crate::{
2
    client::Client,
3
    error::Error,
4
    types::{CheckRunEvent, IssueCommentEvent, PullRequestEvent},
5
};
6
use axum::{
7
    Json, Router,
8
    extract::State,
9
    http::{HeaderMap, HeaderValue, StatusCode},
10
    routing::{get, post},
11
};
12
use hmac::{Hmac, Mac};
13
use serde::{Deserialize, Serialize};
14
use std::net::SocketAddr;
15
use std::sync::Arc;
16
use tokio::{net::TcpListener, signal, sync::Mutex, time::Duration};
17
use tower_http::trace::TraceLayer;
18
use tracing::{debug, error, info, warn};
19

            
20
mod hex;
21
#[cfg(test)]
22
mod test;
23
mod tls;
24

            
25
/// Value of `Response::status` for a successfully handled request.
pub const SERVER_STATUS_OK: &str = "ok";
/// Message attached to successful responses built with `Response::new`.
pub const SERVER_MESSAGE_OK: &str = "Server is running fine";
/// Value of `Response::status` for a failed request.
pub const SERVER_STATUS_ERROR: &str = "error";
28

            
29
/// Options for the http server
30
#[derive(Serialize, Deserialize, Debug)]
31
#[serde(default, rename_all = "kebab-case")]
32
pub struct ServerOptions {
33
    /// Port to bind to, defaults to 8080
34
    #[serde(default = "default_port")]
35
    pub port: u16,
36

            
37
    /// Optional ssl configuration for the server
38
    pub ssl: SSLOptions,
39

            
40
    /// Shared webhook secret for verifying the webhook sender
41
    pub webhook_secret: Option<String>,
42

            
43
    /// Refresh check runs periodically instead of on every webhook event
44
    /// This is useful for reducing the number of API calls to GitHub.
45
    /// When set to zero, periodic refresh is disabled.
46
    /// Unit is in seconds.
47
    #[serde(default = "Default::default")]
48
    pub periodic_refresh: u64,
49
}
50

            
51
10
/// Port the server listens on when none is configured.
fn default_port() -> u16 {
    const FALLBACK_PORT: u16 = 8080;
    FALLBACK_PORT
}
54

            
55
impl ServerOptions {
56
    /// Validate the server options
57
5
    pub fn validate(&self) -> Result<(), &'static str> {
58
5
        if self.port == 0 {
59
            return Err("Port can't be 0");
60
5
        }
61
5
        self.ssl.validate()
62
5
    }
63
}
64

            
65
impl Default for ServerOptions {
66
8
    fn default() -> Self {
67
8
        Self {
68
8
            port: default_port(),
69
8
            webhook_secret: std::env::var("CERBERUS_WEBHOOK_SECRET").ok(),
70
8
            ssl: SSLOptions::default(),
71
8
            periodic_refresh: 0,
72
8
        }
73
8
    }
74
}
75

            
76
/// SSL configuration for the server
77
#[derive(Serialize, Deserialize, Debug, Default)]
78
#[serde(default)]
79
pub struct SSLOptions {
80
    /// Whether to enable SSL, defaults to false
81
    pub enabled: bool,
82
    /// Path to the SSL private key file
83
    pub key: String,
84
    /// Path to the SSL certificate file
85
    pub cert: String,
86
}
87

            
88
impl SSLOptions {
89
    /// Validate the SSL options
90
5
    pub fn validate(&self) -> Result<(), &'static str> {
91
5
        if !self.enabled {
92
5
            return Ok(());
93
        }
94
        if self.key.is_empty() || self.cert.is_empty() {
95
            return Err("Incomplete SSL configuration: cert and key must be set if SSL is enabled");
96
        }
97
        Ok(())
98
5
    }
99
}
100

            
101
/// A queued check-run refresh for one commit of one repository.
///
/// The derived ordering compares `app_installation_id`, then `repo`, then `commit`. Sorting
/// therefore places identical jobs next to each other, which `deduplicate_jobs` relies on.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
struct Job {
    /// Installation id passed to the GitHub client when the job runs.
    app_installation_id: u64,
    /// Repository full name the commit belongs to.
    repo: String,
    /// Commit SHA whose check run should be refreshed.
    commit: String,
}
108

            
109
/// HTTP Server for receiving webhook events from GitHub
110
pub struct Server {
111
    options: ServerOptions,
112
}
113

            
114
#[derive(Clone)]
115
struct ServerState {
116
    webhook_secret: Option<String>,
117
    github: Arc<Client>,
118
    job_queue: Arc<Mutex<Vec<Job>>>,
119
    use_job_queue: bool,
120
}
121

            
122
impl ServerState {
123
    /// Create a new server state with the given webhook secret and GitHub client
124
9
    fn new(webhook_secret: Option<String>, github: Client) -> Self {
125
9
        let github = Arc::new(github);
126
9
        Self {
127
9
            webhook_secret,
128
9
            github,
129
9
            job_queue: Arc::new(Mutex::new(Vec::new())),
130
9
            use_job_queue: false,
131
9
        }
132
9
    }
133

            
134
    /// Create a new pending job and add it to the job queue
135
2
    async fn new_job(&self, app_installation_id: u64, repo: &str, commit: &str) {
136
2
        let job = Job {
137
2
            app_installation_id,
138
2
            repo: repo.to_string(),
139
2
            commit: commit.to_string(),
140
2
        };
141
2
        let mut job_queue = self.job_queue.lock().await;
142
2
        job_queue.push(job);
143
2
    }
144

            
145
    /// Start a background task that periodically runs all jobs in the queue
146
1
    fn periodically_run_job_queue(&mut self, period: u64) {
147
1
        let job_queue = self.job_queue.clone();
148
1
        let github = self.github.clone();
149

            
150
1
        info!(
151
            "Periodic refresh of check runs enabled with a period of {} seconds",
152
            period,
153
        );
154

            
155
1
        self.use_job_queue = true;
156
1
        tokio::spawn(async move {
157
1
            let period = Duration::from_secs(period);
158
            loop {
159
2
                tokio::time::sleep(period).await;
160

            
161
1
                let mut job_queue = job_queue.lock().await;
162
1
                if job_queue.is_empty() {
163
                    continue;
164
1
                }
165

            
166
1
                deduplicate_jobs(job_queue.as_mut());
167

            
168
1
                info!("Running {} jobs in the queue", job_queue.len());
169

            
170
1
                for job in job_queue.drain(..) {
171
1
                    if let Err(e) = github
172
1
                        .refresh_check_run_status(job.app_installation_id, &job.repo, &job.commit)
173
1
                        .await
174
                    {
175
                        error!(
176
                            "Failed to refresh check run status for job: '{}' - '{}': {}",
177
                            job.repo, job.commit, e
178
                        );
179
1
                    }
180
                }
181
            }
182
        });
183
1
    }
184
}
185

            
186
impl Server {
187
    /// Create a new server with the given options and GitHub client
188
3
    pub fn new(options: ServerOptions) -> Self {
189
3
        Self { options }
190
3
    }
191

            
192
    /// Run the server
193
    /// Server will shutdown gracefully on Ctrl+C or SIGTERM
194
3
    pub async fn run(&self, github: Client) -> Result<(), Error> {
195
3
        let mut state = ServerState::new(self.options.webhook_secret.clone(), github);
196
3
        if self.options.periodic_refresh > 0 {
197
            state.periodically_run_job_queue(self.options.periodic_refresh);
198
3
        }
199
3
        let router = new_router(state);
200

            
201
3
        let addr = SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], self.options.port));
202
3
        info!("Starting server on {}", addr);
203

            
204
3
        if self.options.ssl.enabled {
205
            let listener =
206
                tls::TlsListener::bind(addr, &self.options.ssl.key, &self.options.ssl.cert)
207
                    .await
208
                    .map_err(|e| Error::BindPort(Box::new(e)))?;
209

            
210
            axum::serve(listener, router)
211
                .with_graceful_shutdown(shutdown_signal())
212
                .await
213
                .map_err(Error::Serve)
214
        } else {
215
3
            let listener = TcpListener::bind(addr)
216
3
                .await
217
3
                .map_err(|e| Error::BindPort(Box::new(e)))?;
218

            
219
3
            axum::serve(listener, router)
220
3
                .with_graceful_shutdown(shutdown_signal())
221
3
                .await
222
                .map_err(Error::Serve)
223
        }
224
    }
225
}
226

            
227
3
fn new_router(state: ServerState) -> Router {
228
3
    let webhook_router: Router = Router::new()
229
3
        .route("/webhook", post(webhook_handler))
230
3
        .with_state(state)
231
3
        .layer(TraceLayer::new_for_http());
232

            
233
    // Do not use tracing for the health check endpoint
234
3
    let health_router: Router = Router::new().route("/healthz", get(healthz));
235

            
236
3
    Router::new().merge(webhook_router).merge(health_router)
237
3
}
238

            
239
/// Expose health check endpoint
240
/// Can be used when running under kubernetes to check if the server is running
241
/// GET /healthz
242
async fn healthz() -> (StatusCode, Json<Response>) {
243
    (StatusCode::OK, Json(Response::new()))
244
}
245

            
246
/// Handle the webhook events send from GitHub
247
/// POST /webhook
248
7
async fn webhook_handler(
249
7
    headers: HeaderMap,
250
7
    state: State<ServerState>,
251
7
    payload: String,
252
7
) -> (StatusCode, Json<Response>) {
253
7
    let event = match headers.get("X-GitHub-Event") {
254
7
        Some(event) => event
255
7
            .to_str()
256
7
            .unwrap_or("could not read X-GitHub-Event header"),
257
        None => {
258
            return (
259
                StatusCode::BAD_REQUEST,
260
                Json(Response::error("Missing X-GitHub-Event header")),
261
            );
262
        }
263
    };
264
7
    debug!("Received webhook event: {}", event);
265
7
    if let Err(e) = verify_webhook(
266
7
        headers.get("X-Hub-Signature-256"),
267
7
        state.webhook_secret.as_deref(),
268
7
        &payload,
269
7
    ) {
270
        warn!("Failed to verify webhook signature: {}", e.1.message);
271
        return e;
272
7
    }
273

            
274
7
    match event {
275
7
        "check_run" => handle_check_run_event(state.0, &payload).await,
276
4
        "pull_request" => handle_pull_request_event(&state.github, &payload).await,
277
3
        "issue_comment" => handle_issue_comment_event(&state.github, &payload).await,
278
1
        "check_suite" => (StatusCode::OK, Json(Response::new())), // Ignore check_suite events
279
        event => {
280
            let message = format!("Received unsupported event: {event}");
281
            info!("{message}");
282
            (StatusCode::NOT_IMPLEMENTED, Json(Response::error(&message)))
283
        }
284
    }
285
7
}
286

            
287
/// Verify the webhook request against the shared secret
288
13
fn verify_webhook(
289
13
    signature: Option<&HeaderValue>,
290
13
    secret: Option<&str>,
291
13
    payload: &str,
292
13
) -> Result<(), (StatusCode, Json<Response>)> {
293
13
    let secret = match secret {
294
4
        Some(s) => s,
295
        None => {
296
9
            return Ok(());
297
        }
298
    };
299

            
300
4
    let signature = match signature {
301
3
        Some(s) => s.to_str().map_err(|e| {
302
            info!("Failed to read X-Hub-Signature-256 header: {e}");
303
            (
304
                StatusCode::FORBIDDEN,
305
                Json(Response::error("Invalid X-Hub-Signature-256 header")),
306
            )
307
        })?,
308
        None => {
309
1
            return Err((
310
1
                StatusCode::FORBIDDEN,
311
1
                Json(Response::error("Missing X-Hub-Signature-256 header")),
312
1
            ));
313
        }
314
    };
315
3
    let signature = signature.strip_prefix("sha256=").unwrap_or(signature);
316
3
    let signature = hex::decode_hex(signature).map_err(|_| {
317
1
        (
318
1
            StatusCode::FORBIDDEN,
319
1
            Json(Response::error("Invalid X-Hub-Signature-256 header")),
320
1
        )
321
1
    })?;
322

            
323
2
    let mut mac = Hmac::<sha2::Sha256>::new_from_slice(secret.as_bytes()).map_err(|e| {
324
        error!("Failed to create HMAC from secret: {e}");
325
        (
326
            StatusCode::INTERNAL_SERVER_ERROR,
327
            Json(Response::error("Failed to create HMAC from secret")),
328
        )
329
    })?;
330
2
    mac.update(payload.as_bytes());
331

            
332
2
    mac.verify_slice(signature.as_slice()).map_err(|_| {
333
1
        (
334
1
            StatusCode::FORBIDDEN,
335
1
            Json(Response::error("Invalid webhook signature")),
336
1
        )
337
1
    })?;
338

            
339
1
    Ok(())
340
13
}
341

            
342
/// Handle webhook pull_request events
343
1
async fn handle_pull_request_event(client: &Client, payload: &str) -> (StatusCode, Json<Response>) {
344
1
    let payload: PullRequestEvent = match serde_json::from_str(payload) {
345
1
        Ok(event) => event,
346
        Err(e) => {
347
            warn!("Failed to parse pull_request event payload: {e}");
348
            return (
349
                StatusCode::BAD_REQUEST,
350
                Json(Response::error("Invalid pull_request event payload")),
351
            );
352
        }
353
    };
354

            
355
1
    match payload.action.as_str() {
356
1
        "opened" | "synchronize" => {}
357
        action => {
358
            debug!("Ignoring pull_request event with action: {action}");
359
            return (StatusCode::OK, Json(Response::new()));
360
        }
361
    }
362

            
363
1
    let app_id = match payload.installation {
364
1
        Some(installation) => installation.id,
365
        None => {
366
            warn!("Missing app installation id in pull_request event");
367
            return (
368
                StatusCode::BAD_REQUEST,
369
                Json(Response::error("Missing app installation id")),
370
            );
371
        }
372
    };
373

            
374
1
    if let Err(e) = client
375
1
        .create_check_run(
376
1
            app_id,
377
1
            &payload.repository.full_name,
378
1
            &payload.pull_request.head.sha,
379
        )
380
1
        .await
381
    {
382
        error!("Failed to create check run: {e}");
383
        return (
384
            StatusCode::INTERNAL_SERVER_ERROR,
385
            Json(Response::error("Failed to create check-run")),
386
        );
387
1
    };
388
1
    info!(
389
1
        "Created check run for pull request {} - {}",
390
        payload.repository.full_name, payload.pull_request.number
391
    );
392
1
    (StatusCode::OK, Json(Response::new()))
393
1
}
394

            
395
/// Handle webhook check_run events
396
4
async fn handle_check_run_event(state: ServerState, payload: &str) -> (StatusCode, Json<Response>) {
397
4
    let payload: CheckRunEvent = match serde_json::from_str(payload) {
398
4
        Ok(event) => event,
399
        Err(e) => {
400
            warn!("Failed to parse check_run event payload: {e}");
401
            return (
402
                StatusCode::BAD_REQUEST,
403
                Json(Response::error("Invalid check_run event payload")),
404
            );
405
        }
406
    };
407

            
408
4
    if payload
409
4
        .check_run
410
4
        .app
411
4
        .is_some_and(|app| app.client_id == state.github.client_id())
412
    {
413
2
        debug!("Ignoring check_run event from our own app");
414
2
        return (StatusCode::OK, Json(Response::new()));
415
2
    }
416

            
417
2
    let app_id = match payload.installation {
418
2
        Some(installation) => installation.id,
419
        None => {
420
            warn!("Missing app installation id in check_run event");
421
            return (
422
                StatusCode::BAD_REQUEST,
423
                Json(Response::error("Missing app installation id")),
424
            );
425
        }
426
    };
427

            
428
2
    if state.use_job_queue {
429
1
        state
430
1
            .new_job(
431
1
                app_id,
432
1
                &payload.repository.full_name,
433
1
                &payload.check_run.head_sha,
434
1
            )
435
1
            .await;
436
1
        return (StatusCode::OK, Json(Response::new()));
437
1
    }
438

            
439
1
    match state
440
1
        .github
441
1
        .refresh_check_run_status(
442
1
            app_id,
443
1
            &payload.repository.full_name,
444
1
            &payload.check_run.head_sha,
445
1
        )
446
1
        .await
447
    {
448
1
        Ok(_) => (StatusCode::OK, Json(Response::new())),
449
        Err(e) => {
450
            error!("Failed to refresh check-run status: {e}");
451
            (
452
                StatusCode::INTERNAL_SERVER_ERROR,
453
                Json(Response::error("Failed to refresh check-run status")),
454
            )
455
        }
456
    }
457
4
}
458

            
459
/// Handle webhook issue_comment events
460
2
async fn handle_issue_comment_event(
461
2
    client: &Client,
462
2
    payload: &str,
463
2
) -> (StatusCode, Json<Response>) {
464
2
    let payload: IssueCommentEvent = match serde_json::from_str(payload) {
465
2
        Ok(event) => event,
466
        Err(e) => {
467
            warn!("Failed to parse issue_comment event payload: {e}");
468
            return (
469
                StatusCode::BAD_REQUEST,
470
                Json(Response::error("Invalid issue_comment event payload")),
471
            );
472
        }
473
    };
474

            
475
2
    let app_id = match payload.installation {
476
2
        Some(installation) => installation.id,
477
        None => {
478
            warn!("Missing app installation id in issue_comment event");
479
            return (
480
                StatusCode::BAD_REQUEST,
481
                Json(Response::error("Missing app installation id")),
482
            );
483
        }
484
    };
485

            
486
2
    if payload.action != "created" {
487
        debug!(
488
            "Ignoring issue_comment event with action: {}",
489
            payload.action
490
        );
491
        return (StatusCode::OK, Json(Response::new()));
492
2
    }
493

            
494
2
    if !payload.comment.body.contains("/cerberus refresh") {
495
1
        debug!("Ignoring issue comment without '/cerberus' command");
496
1
        return (StatusCode::OK, Json(Response::new()));
497
1
    }
498
1
    info!(
499
        "Received issue_comment event for issue {}: {}",
500
        payload.issue.number, payload.comment.body
501
    );
502

            
503
1
    let commit = match client
504
1
        .get_pull_request_head_commit(app_id, &payload.repository.full_name, payload.issue.number)
505
1
        .await
506
    {
507
1
        Ok(commit) => commit,
508
        Err(e) => {
509
            error!("Failed to get pull request head commit: {e}");
510
            return (
511
                StatusCode::INTERNAL_SERVER_ERROR,
512
                Json(Response::error("Failed to get pull request head commit")),
513
            );
514
        }
515
    };
516

            
517
1
    if let Err(e) = client
518
1
        .refresh_check_run_status(app_id, &payload.repository.full_name, &commit)
519
1
        .await
520
    {
521
        error!("Failed to refresh check-run status: {e}");
522
        return (
523
            StatusCode::INTERNAL_SERVER_ERROR,
524
            Json(Response::error("Failed to refresh check-run status")),
525
        );
526
1
    }
527

            
528
1
    (StatusCode::OK, Json(Response::new()))
529
2
}
530

            
531
/// Detailed status of the Webserver
532
#[derive(Debug, Serialize, Deserialize)]
533
pub struct Response {
534
    /// Status of the server.
535
    /// "ok" if everything is running fine, "error" if something is wrong.
536
    pub status: String,
537
    /// Optional message providing more details about the status.
538
    pub message: String,
539
}
540

            
541
impl Response {
542
    /// Create a new response with ok status.
543
8
    pub fn new() -> Self {
544
8
        Self {
545
8
            status: SERVER_STATUS_OK.to_string(),
546
8
            message: SERVER_MESSAGE_OK.to_string(),
547
8
        }
548
8
    }
549

            
550
    /// Create a new response with the error status.
551
6
    pub fn error(message: &str) -> Self {
552
6
        Self {
553
6
            status: SERVER_STATUS_ERROR.to_string(),
554
6
            message: message.to_string(),
555
6
        }
556
6
    }
557
}
558

            
559
/// Asynchronously wait for a shutdown signal (Ctrl+C or SIGTERM).
560
3
async fn shutdown_signal() {
561
3
    let ctrl_c = async {
562
3
        tokio::signal::ctrl_c()
563
3
            .await
564
            .expect("failed to install Ctrl+C handler");
565
    };
566

            
567
    #[cfg(unix)]
568
3
    let terminate = async {
569
3
        signal::unix::signal(signal::unix::SignalKind::terminate())
570
3
            .expect("failed to install signal handler")
571
3
            .recv()
572
3
            .await;
573
    };
574

            
575
    #[cfg(not(unix))]
576
    let terminate = std::future::pending::<()>();
577

            
578
3
    tokio::select! {
579
3
        _ = ctrl_c => {},
580
3
        _ = terminate => {},
581
    }
582
}
583

            
584
/// Remove duplicates from job queue
585
2
fn deduplicate_jobs(job_queue: &mut Vec<Job>) {
586
2
    job_queue.sort();
587
2
    job_queue.dedup();
588
2
}