1
use crate::{
2
    client::Client,
3
    error::Error,
4
    types::{CheckRunEvent, IssueCommentEvent, PullRequestEvent},
5
};
6
use axum::{
7
    Json, Router,
8
    extract::State,
9
    http::{HeaderMap, HeaderValue, StatusCode},
10
    routing::{get, post},
11
};
12
use hmac::{Hmac, Mac};
13
use serde::{Deserialize, Serialize};
14
use std::net::SocketAddr;
15
use std::sync::Arc;
16
use tokio::{net::TcpListener, signal, sync::Mutex, time::Duration};
17
use tower_http::trace::TraceLayer;
18
use tracing::{debug, error, info, warn};
19

            
20
mod hex;
21
#[cfg(test)]
22
mod test;
23
mod tls;
24

            
25
/// `status` value of a `Response` that reports success.
pub const SERVER_STATUS_OK: &str = "ok";
/// `status` value of a `Response` that reports a failure.
pub const SERVER_STATUS_ERROR: &str = "error";
/// `message` carried by a successful `Response` built with `Response::new`.
pub const SERVER_MESSAGE_OK: &str = "Server is running fine";
28

            
29
/// Options for the http server
30
#[derive(Serialize, Deserialize, Debug)]
31
#[serde(default, rename_all = "kebab-case")]
32
pub struct ServerOptions {
33
    /// Port to bind to, defaults to 8080
34
    #[serde(default = "default_port")]
35
    pub port: u16,
36

            
37
    /// Optional ssl configuration for the server
38
    pub ssl: SSLOptions,
39

            
40
    /// Shared webhook secret for verifying the webhook sender
41
    pub webhook_secret: Option<String>,
42

            
43
    /// Refresh check runs periodically instead of on every webhook event
44
    /// This is useful for reducing the number of API calls to GitHub.
45
    /// When set to zero, periodic refresh is disabled.
46
    /// Unit is in seconds.
47
    #[serde(default = "Default::default")]
48
    pub periodic_refresh: u64,
49
}
50

            
51
10
/// Port the server binds to when none is configured.
fn default_port() -> u16 {
    const DEFAULT_PORT: u16 = 8080;
    DEFAULT_PORT
}
54

            
55
impl ServerOptions {
56
    /// Validate the server options
57
5
    pub fn validate(&self) -> Result<(), &'static str> {
58
5
        if self.port == 0 {
59
            return Err("Port can't be 0");
60
5
        }
61
5
        self.ssl.validate()
62
5
    }
63
}
64

            
65
impl Default for ServerOptions {
66
8
    fn default() -> Self {
67
8
        Self {
68
8
            port: default_port(),
69
8
            webhook_secret: std::env::var("CERBERUS_WEBHOOK_SECRET").ok(),
70
8
            ssl: SSLOptions::default(),
71
8
            periodic_refresh: 0,
72
8
        }
73
8
    }
74
}
75

            
76
/// SSL configuration for the server
77
#[derive(Serialize, Deserialize, Debug, Default)]
78
#[serde(default)]
79
pub struct SSLOptions {
80
    /// Whether to enable SSL, defaults to false
81
    pub enabled: bool,
82
    /// Path to the SSL private key file
83
    pub key: String,
84
    /// Path to the SSL certificate file
85
    pub cert: String,
86
}
87

            
88
impl SSLOptions {
89
    /// Validate the SSL options
90
5
    pub fn validate(&self) -> Result<(), &'static str> {
91
5
        if !self.enabled {
92
5
            return Ok(());
93
        }
94
        if self.key.is_empty() || self.cert.is_empty() {
95
            return Err("Incomplete SSL configuration: cert and key must be set if SSL is enabled");
96
        }
97
        Ok(())
98
5
    }
99
}
100

            
101
/// A pending check-run refresh for one commit.
///
/// The derived ordering compares `app_installation_id`, then `repo`, then
/// `commit`. This is what lets the queue be sorted and deduplicated.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
struct Job {
    /// Installation id passed to the refresh call.
    app_installation_id: u64,
    /// Full repository name, as taken from the event's `repository.full_name`.
    repo: String,
    /// SHA of the commit whose check run should be refreshed.
    commit: String,
}
108

            
109
/// HTTP Server for receiving webhook events from GitHub
110
pub struct Server {
111
    options: ServerOptions,
112
}
113

            
114
#[derive(Clone)]
115
struct ServerState {
116
    webhook_secret: Option<String>,
117
    github: Arc<Client>,
118
    job_queue: Arc<Mutex<Vec<Job>>>,
119
    use_job_queue: bool,
120
}
121

            
122
impl ServerState {
123
    /// Create a new server state with the given webhook secret and GitHub client
124
8
    fn new(webhook_secret: Option<String>, github: Client) -> Self {
125
8
        let github = Arc::new(github);
126
8
        Self {
127
8
            webhook_secret,
128
8
            github,
129
8
            job_queue: Arc::new(Mutex::new(Vec::new())),
130
8
            use_job_queue: false,
131
8
        }
132
8
    }
133

            
134
    /// Create a new pending job and add it to the job queue
135
2
    async fn new_job(&self, app_installation_id: u64, repo: &str, commit: &str) {
136
2
        let job = Job {
137
2
            app_installation_id,
138
2
            repo: repo.to_string(),
139
2
            commit: commit.to_string(),
140
2
        };
141
2
        let mut job_queue = self.job_queue.lock().await;
142
2
        job_queue.push(job);
143
2
    }
144

            
145
    /// Start a background task that periodically runs all jobs in the queue
146
1
    fn periodically_run_job_queue(&mut self, period: u64) {
147
1
        let job_queue = self.job_queue.clone();
148
1
        let github = self.github.clone();
149

            
150
1
        info!(
151
1
            "Periodic refresh of check runs enabled with a period of {} seconds",
152
            period,
153
        );
154

            
155
1
        self.use_job_queue = true;
156
1
        tokio::spawn(async move {
157
1
            let period = Duration::from_secs(period);
158
            loop {
159
2
                tokio::time::sleep(period).await;
160

            
161
1
                let mut job_queue = job_queue.lock().await;
162
1
                if job_queue.is_empty() {
163
                    continue;
164
1
                }
165

            
166
1
                deduplicate_jobs(job_queue.as_mut());
167

            
168
1
                info!("Running {} jobs in the queue", job_queue.len());
169

            
170
1
                for job in job_queue.drain(..) {
171
1
                    if let Err(e) = github
172
1
                        .refresh_check_run_status(job.app_installation_id, &job.repo, &job.commit)
173
1
                        .await
174
                    {
175
                        error!(
176
                            "Failed to refresh check run status for job: '{}' - '{}': {}",
177
                            job.repo, job.commit, e
178
                        );
179
1
                    }
180
                }
181
            }
182
        });
183
1
    }
184
}
185

            
186
impl Server {
187
    /// Create a new server with the given options and GitHub client
188
3
    pub fn new(options: ServerOptions) -> Self {
189
3
        Self { options }
190
3
    }
191

            
192
    /// Run the server
193
    /// Server will shutdown gracefully on Ctrl+C or SIGTERM
194
3
    pub async fn run(&self, github: Client) -> Result<(), Error> {
195
3
        let mut state = ServerState::new(self.options.webhook_secret.clone(), github);
196
3
        if self.options.periodic_refresh > 0 {
197
            state.periodically_run_job_queue(self.options.periodic_refresh);
198
3
        }
199
3
        let router = new_router(state);
200

            
201
3
        let addr = SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], self.options.port));
202
3
        info!("Starting server on {}", addr);
203

            
204
3
        if self.options.ssl.enabled {
205
            let listener =
206
                tls::TlsListener::bind(addr, &self.options.ssl.key, &self.options.ssl.cert)
207
                    .await
208
                    .map_err(|e| Error::BindPort(Box::new(e)))?;
209

            
210
            axum::serve(listener, router)
211
                .with_graceful_shutdown(shutdown_signal())
212
                .await
213
                .map_err(Error::Serve)
214
        } else {
215
3
            let listener = TcpListener::bind(addr)
216
3
                .await
217
3
                .map_err(|e| Error::BindPort(Box::new(e)))?;
218

            
219
3
            axum::serve(listener, router)
220
3
                .with_graceful_shutdown(shutdown_signal())
221
3
                .await
222
                .map_err(Error::Serve)
223
        }
224
    }
225
}
226

            
227
3
fn new_router(state: ServerState) -> Router {
228
3
    let webhook_router: Router = Router::new()
229
3
        .route("/webhook", post(webhook_handler))
230
3
        .with_state(state)
231
3
        .layer(TraceLayer::new_for_http());
232

            
233
    // Do not use tracing for the health check endpoint
234
3
    let health_router: Router = Router::new().route("/healthz", get(healthz));
235

            
236
3
    Router::new().merge(webhook_router).merge(health_router)
237
3
}
238

            
239
/// Expose health check endpoint
240
/// Can be used when running under kubernetes to check if the server is running
241
/// GET /healthz
242
async fn healthz() -> (StatusCode, Json<Response>) {
243
    (StatusCode::OK, Json(Response::new()))
244
}
245

            
246
/// Handle the webhook events send from GitHub
247
/// POST /webhook
248
6
async fn webhook_handler(
249
6
    headers: HeaderMap,
250
6
    state: State<ServerState>,
251
6
    payload: String,
252
6
) -> (StatusCode, Json<Response>) {
253
6
    let event = match headers.get("X-GitHub-Event") {
254
6
        Some(event) => event
255
6
            .to_str()
256
6
            .unwrap_or("could not read X-GitHub-Event header"),
257
        None => {
258
            return (
259
                StatusCode::BAD_REQUEST,
260
                Json(Response::error("Missing X-GitHub-Event header")),
261
            );
262
        }
263
    };
264
6
    debug!("Received webhook event: {}", event);
265
6
    if let Err(e) = verify_webhook(
266
6
        headers.get("X-Hub-Signature-256"),
267
6
        state.webhook_secret.as_deref(),
268
6
        &payload,
269
6
    ) {
270
        warn!("Failed to verify webhook signature: {}", e.1.message);
271
        return e;
272
6
    }
273

            
274
6
    match event {
275
6
        "check_run" => handle_check_run_event(state.0, &payload).await,
276
3
        "pull_request" => handle_pull_request_event(&state.github, &payload).await,
277
2
        "issue_comment" => handle_issue_comment_event(&state.github, &payload).await,
278
        event => {
279
            let message = format!("Received unsupported event: {event}");
280
            info!("{message}");
281
            (StatusCode::NOT_IMPLEMENTED, Json(Response::error(&message)))
282
        }
283
    }
284
6
}
285

            
286
/// Verify the webhook request against the shared secret
287
12
fn verify_webhook(
288
12
    signature: Option<&HeaderValue>,
289
12
    secret: Option<&str>,
290
12
    payload: &str,
291
12
) -> Result<(), (StatusCode, Json<Response>)> {
292
12
    let secret = match secret {
293
4
        Some(s) => s,
294
        None => {
295
8
            return Ok(());
296
        }
297
    };
298

            
299
4
    let signature = match signature {
300
3
        Some(s) => s.to_str().map_err(|e| {
301
            info!("Failed to read X-Hub-Signature-256 header: {e}");
302
            (
303
                StatusCode::FORBIDDEN,
304
                Json(Response::error("Invalid X-Hub-Signature-256 header")),
305
            )
306
        })?,
307
        None => {
308
1
            return Err((
309
1
                StatusCode::FORBIDDEN,
310
1
                Json(Response::error("Missing X-Hub-Signature-256 header")),
311
1
            ));
312
        }
313
    };
314
3
    let signature = signature.strip_prefix("sha256=").unwrap_or(signature);
315
3
    let signature = hex::decode_hex(signature).map_err(|_| {
316
1
        (
317
1
            StatusCode::FORBIDDEN,
318
1
            Json(Response::error("Invalid X-Hub-Signature-256 header")),
319
1
        )
320
1
    })?;
321

            
322
2
    let mut mac = Hmac::<sha2::Sha256>::new_from_slice(secret.as_bytes()).map_err(|e| {
323
        error!("Failed to create HMAC from secret: {e}");
324
        (
325
            StatusCode::INTERNAL_SERVER_ERROR,
326
            Json(Response::error("Failed to create HMAC from secret")),
327
        )
328
    })?;
329
2
    mac.update(payload.as_bytes());
330

            
331
2
    mac.verify_slice(signature.as_slice()).map_err(|_| {
332
1
        (
333
1
            StatusCode::FORBIDDEN,
334
1
            Json(Response::error("Invalid webhook signature")),
335
1
        )
336
1
    })?;
337

            
338
1
    Ok(())
339
12
}
340

            
341
/// Handle webhook pull_request events
342
1
async fn handle_pull_request_event(client: &Client, payload: &str) -> (StatusCode, Json<Response>) {
343
1
    let payload: PullRequestEvent = match serde_json::from_str(payload) {
344
1
        Ok(event) => event,
345
        Err(e) => {
346
            warn!("Failed to parse pull_request event payload: {e}");
347
            return (
348
                StatusCode::BAD_REQUEST,
349
                Json(Response::error("Invalid pull_request event payload")),
350
            );
351
        }
352
    };
353

            
354
1
    match payload.action.as_str() {
355
1
        "opened" | "synchronize" => {}
356
        action => {
357
            debug!("Ignoring pull_request event with action: {action}");
358
            return (StatusCode::OK, Json(Response::new()));
359
        }
360
    }
361

            
362
1
    let app_id = match payload.installation {
363
1
        Some(installation) => installation.id,
364
        None => {
365
            warn!("Missing app installation id in pull_request event");
366
            return (
367
                StatusCode::BAD_REQUEST,
368
                Json(Response::error("Missing app installation id")),
369
            );
370
        }
371
    };
372

            
373
1
    if let Err(e) = client
374
1
        .create_check_run(
375
1
            app_id,
376
1
            &payload.repository.full_name,
377
1
            &payload.pull_request.head.sha,
378
        )
379
1
        .await
380
    {
381
        error!("Failed to create check run: {e}");
382
        return (
383
            StatusCode::INTERNAL_SERVER_ERROR,
384
            Json(Response::error("Failed to create check-run")),
385
        );
386
1
    };
387
1
    info!(
388
1
        "Created check run for pull request {} - {}",
389
        payload.repository.full_name, payload.pull_request.number
390
    );
391
1
    (StatusCode::OK, Json(Response::new()))
392
1
}
393

            
394
/// Handle webhook check_run events
395
4
async fn handle_check_run_event(state: ServerState, payload: &str) -> (StatusCode, Json<Response>) {
396
4
    let payload: CheckRunEvent = match serde_json::from_str(payload) {
397
4
        Ok(event) => event,
398
        Err(e) => {
399
            warn!("Failed to parse check_run event payload: {e}");
400
            return (
401
                StatusCode::BAD_REQUEST,
402
                Json(Response::error("Invalid check_run event payload")),
403
            );
404
        }
405
    };
406

            
407
4
    if payload
408
4
        .check_run
409
4
        .app
410
4
        .is_some_and(|app| app.client_id == state.github.client_id())
411
    {
412
2
        debug!("Ignoring check_run event from our own app");
413
2
        return (StatusCode::OK, Json(Response::new()));
414
2
    }
415

            
416
2
    let app_id = match payload.installation {
417
2
        Some(installation) => installation.id,
418
        None => {
419
            warn!("Missing app installation id in check_run event");
420
            return (
421
                StatusCode::BAD_REQUEST,
422
                Json(Response::error("Missing app installation id")),
423
            );
424
        }
425
    };
426

            
427
2
    if state.use_job_queue {
428
1
        state
429
1
            .new_job(
430
1
                app_id,
431
1
                &payload.repository.full_name,
432
1
                &payload.check_run.head_sha,
433
1
            )
434
1
            .await;
435
1
        return (StatusCode::OK, Json(Response::new()));
436
1
    }
437

            
438
1
    match state
439
1
        .github
440
1
        .refresh_check_run_status(
441
1
            app_id,
442
1
            &payload.repository.full_name,
443
1
            &payload.check_run.head_sha,
444
1
        )
445
1
        .await
446
    {
447
1
        Ok(_) => (StatusCode::OK, Json(Response::new())),
448
        Err(e) => {
449
            error!("Failed to refresh check-run status: {e}");
450
            (
451
                StatusCode::INTERNAL_SERVER_ERROR,
452
                Json(Response::error("Failed to refresh check-run status")),
453
            )
454
        }
455
    }
456
4
}
457

            
458
/// Handle webhook issue_comment events
459
2
async fn handle_issue_comment_event(
460
2
    client: &Client,
461
2
    payload: &str,
462
2
) -> (StatusCode, Json<Response>) {
463
2
    let payload: IssueCommentEvent = match serde_json::from_str(payload) {
464
2
        Ok(event) => event,
465
        Err(e) => {
466
            warn!("Failed to parse issue_comment event payload: {e}");
467
            return (
468
                StatusCode::BAD_REQUEST,
469
                Json(Response::error("Invalid issue_comment event payload")),
470
            );
471
        }
472
    };
473

            
474
2
    let app_id = match payload.installation {
475
2
        Some(installation) => installation.id,
476
        None => {
477
            warn!("Missing app installation id in issue_comment event");
478
            return (
479
                StatusCode::BAD_REQUEST,
480
                Json(Response::error("Missing app installation id")),
481
            );
482
        }
483
    };
484

            
485
2
    if payload.action != "created" {
486
        debug!(
487
            "Ignoring issue_comment event with action: {}",
488
            payload.action
489
        );
490
        return (StatusCode::OK, Json(Response::new()));
491
2
    }
492

            
493
2
    if !payload.comment.body.contains("/cerberus refresh") {
494
1
        debug!("Ignoring issue comment without '/cerberus' command");
495
1
        return (StatusCode::OK, Json(Response::new()));
496
1
    }
497
1
    info!(
498
        "Received issue_comment event for issue {}: {}",
499
        payload.issue.number, payload.comment.body
500
    );
501

            
502
1
    let commit = match client
503
1
        .get_pull_request_head_commit(app_id, &payload.repository.full_name, payload.issue.number)
504
1
        .await
505
    {
506
1
        Ok(commit) => commit,
507
        Err(e) => {
508
            error!("Failed to get pull request head commit: {e}");
509
            return (
510
                StatusCode::INTERNAL_SERVER_ERROR,
511
                Json(Response::error("Failed to get pull request head commit")),
512
            );
513
        }
514
    };
515

            
516
1
    if let Err(e) = client
517
1
        .refresh_check_run_status(app_id, &payload.repository.full_name, &commit)
518
1
        .await
519
    {
520
        error!("Failed to refresh check-run status: {e}");
521
        return (
522
            StatusCode::INTERNAL_SERVER_ERROR,
523
            Json(Response::error("Failed to refresh check-run status")),
524
        );
525
1
    }
526

            
527
1
    (StatusCode::OK, Json(Response::new()))
528
2
}
529

            
530
/// Detailed status of the Webserver
531
#[derive(Debug, Serialize, Deserialize)]
532
pub struct Response {
533
    /// Status of the server.
534
    /// "ok" if everything is running fine, "error" if something is wrong.
535
    pub status: String,
536
    /// Optional message providing more details about the status.
537
    pub message: String,
538
}
539

            
540
impl Response {
541
    /// Create a new response with ok status.
542
7
    pub fn new() -> Self {
543
7
        Self {
544
7
            status: SERVER_STATUS_OK.to_string(),
545
7
            message: SERVER_MESSAGE_OK.to_string(),
546
7
        }
547
7
    }
548

            
549
    /// Create a new response with the error status.
550
6
    pub fn error(message: &str) -> Self {
551
6
        Self {
552
6
            status: SERVER_STATUS_ERROR.to_string(),
553
6
            message: message.to_string(),
554
6
        }
555
6
    }
556
}
557

            
558
/// Asynchronously wait for a shutdown signal (Ctrl+C or SIGTERM).
559
3
async fn shutdown_signal() {
560
3
    let ctrl_c = async {
561
3
        tokio::signal::ctrl_c()
562
3
            .await
563
            .expect("failed to install Ctrl+C handler");
564
    };
565

            
566
    #[cfg(unix)]
567
3
    let terminate = async {
568
3
        signal::unix::signal(signal::unix::SignalKind::terminate())
569
3
            .expect("failed to install signal handler")
570
3
            .recv()
571
3
            .await;
572
    };
573

            
574
    #[cfg(not(unix))]
575
    let terminate = std::future::pending::<()>();
576

            
577
3
    tokio::select! {
578
3
        _ = ctrl_c => {},
579
3
        _ = terminate => {},
580
    }
581
}
582

            
583
/// Remove duplicates from job queue
584
2
fn deduplicate_jobs(job_queue: &mut Vec<Job>) {
585
2
    job_queue.sort();
586
2
    job_queue.dedup();
587
2
}