1
use crate::{
2
    api,
3
    error::Error,
4
    types::{CHECK_RUN_CONCLUSION, CheckRun, TokenResponse},
5
};
6
use serde::{Deserialize, Serialize};
7
use std::collections::HashMap;
8
use tokio::sync::Mutex;
9
use tracing::{debug, warn};
10

            
11
#[cfg(test)]
12
mod test;
13

            
14
/// Configuration options for creating the github client
15
#[derive(Serialize, Deserialize, Debug)]
16
#[serde(rename_all = "kebab-case")]
17
pub struct ClientOptions {
18
    /// Client ID for the GitHub App
19
    pub client_id: String,
20

            
21
    /// Private key for the GitHub App
22
    pub private_key: String,
23

            
24
    /// URL to github api, defaults to "https://api.github.com"
25
    #[serde(skip_serializing_if = "str::is_empty", default = "default_api_url")]
26
    pub api: String,
27
}
28

            
29
2
fn default_api_url() -> String {
30
2
    "https://api.github.com".to_string()
31
2
}
32

            
33
impl ClientOptions {
34
    /// Validate the client options
35
5
    pub fn validate(&self) -> Result<(), &'static str> {
36
5
        if self.client_id.is_empty() {
37
            return Err("GitHub Client ID must be set in the configuration");
38
5
        }
39
5
        Ok(())
40
5
    }
41
}
42

            
43
pub struct Client {
44
    client_id: String,
45
    key: jsonwebtoken::EncodingKey,
46
    api: String,
47
    token_cache: Mutex<HashMap<u64, TokenResponse>>,
48
}
49

            
50
impl Client {
51
    /// Create a new GitHub client with the provided options.
52
    /// Will read the private key from the file system.
53
9
    pub fn build(options: ClientOptions) -> Result<Self, Error> {
54
9
        let key = std::fs::read_to_string(&options.private_key)
55
9
            .map_err(|e| Error::ReadPrivateKey(options.private_key.clone(), e))?;
56
9
        let key =
57
9
            jsonwebtoken::EncodingKey::from_rsa_pem(key.as_bytes()).map_err(Error::EncodingKey)?;
58
9
        Ok(Client {
59
9
            client_id: options.client_id,
60
9
            key,
61
9
            api: options.api,
62
9
            token_cache: Mutex::new(HashMap::new()),
63
9
        })
64
9
    }
65

            
66
    /// Return a reference to the client ID.
67
3
    pub fn client_id(&self) -> &str {
68
3
        &self.client_id
69
3
    }
70

            
71
    /// Get an installations token for the GitHub App.
72
12
    async fn get_token(&self, app_installation_id: u64) -> Result<String, Error> {
73
12
        if let Some(token) = self.get_cached_token(app_installation_id).await {
74
5
            return Ok(token);
75
7
        }
76

            
77
7
        let claims = JWTClaims::new(&self.client_id);
78
7
        let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256);
79
7
        let jwt = jsonwebtoken::encode(&header, &claims, &self.key).map_err(Error::JWT)?;
80
7
        let token = api::get_installation_token(&self.api, &jwt, app_installation_id).await?;
81

            
82
6
        let mut cache = self.token_cache.lock().await;
83
6
        let token_value = token.token.clone();
84
6
        cache.insert(app_installation_id, token);
85

            
86
6
        Ok(token_value)
87
12
    }
88

            
89
    /// Create a new pending check run for a commit in a repository.
90
    /// Needs to use the GitHub App installation token to authenticate.
91
1
    pub async fn create_check_run(
92
1
        &self,
93
1
        app_installation_id: u64,
94
1
        repo: &str,
95
1
        commit: &str,
96
1
    ) -> Result<(), Error> {
97
1
        let token = self.get_token(app_installation_id).await?;
98

            
99
1
        api::create_check_run(&self.api, &token, repo, &CheckRun::new(commit)).await
100
1
    }
101

            
102
    /// Refresh the check_run status based on the current status.
103
    /// Will fetch the current check-runs first and then update the check-run status.
104
    /// This means 2 API calls will be made.
105
3
    pub async fn refresh_check_run_status(
106
3
        &self,
107
3
        app_id: u64,
108
3
        repo: &str,
109
3
        commit: &str,
110
3
    ) -> Result<(), Error> {
111
3
        let (uncompleted, own_run) = self.get_check_run_status(app_id, repo, commit).await?;
112
3
        self.update_check_run(app_id, repo, commit, uncompleted, own_run)
113
3
            .await
114
3
    }
115

            
116
    /// Get the combined status of all check-runs for a commit.
117
3
    pub async fn get_check_run_status(
118
3
        &self,
119
3
        app_installation_id: u64,
120
3
        repo: &str,
121
3
        commit: &str,
122
3
    ) -> Result<(u32, Option<CheckRun>), Error> {
123
3
        let check_runs = self
124
3
            .get_check_runs(app_installation_id, repo, commit)
125
3
            .await?;
126
3
        debug!(
127
2
            "Found {} check runs for commit '{}' in repository '{}'",
128
2
            check_runs.len(),
129
            commit,
130
            repo
131
        );
132

            
133
3
        Ok(self.overall_check_status(&check_runs))
134
3
    }
135

            
136
    /// Update the status of the check-run if necessary.
137
3
    pub async fn update_check_run(
138
3
        &self,
139
3
        app_installation_id: u64,
140
3
        repo: &str,
141
3
        commit: &str,
142
3
        count: u32,
143
3
        check_run: Option<CheckRun>,
144
3
    ) -> Result<(), Error> {
145
3
        let token = self.get_token(app_installation_id).await?;
146

            
147
3
        match check_run {
148
1
            Some(mut run) => {
149
1
                if run.update_status(count) {
150
                    api::update_check_run(&self.api, &token, repo, &run).await
151
                } else {
152
1
                    debug!("No changes to check run status, skipping update");
153
1
                    Ok(())
154
                }
155
            }
156
            None => {
157
2
                warn!("No check run found to update, creating a new one");
158
2
                let mut run = CheckRun::new(commit);
159
2
                run.update_status(count);
160
2
                api::create_check_run(&self.api, &token, repo, &run).await
161
            }
162
        }
163
3
    }
164

            
165
    /// Get the current head commit for a pull request.
166
1
    pub async fn get_pull_request_head_commit(
167
1
        &self,
168
1
        app_installation_id: u64,
169
1
        repo: &str,
170
1
        pull_number: u64,
171
1
    ) -> Result<String, Error> {
172
1
        let token = self.get_token(app_installation_id).await?;
173

            
174
1
        let pr = api::get_pull_request(&self.api, &token, repo, pull_number).await?;
175

            
176
1
        Ok(pr.head.sha)
177
1
    }
178

            
179
    /// Return a list of current check runs for a commit in a repository.
180
    /// Needs to use the GitHub App installation token to authenticate.
181
3
    async fn get_check_runs(
182
3
        &self,
183
3
        app_installation_id: u64,
184
3
        repo: &str,
185
3
        commit: &str,
186
3
    ) -> Result<Vec<CheckRun>, Error> {
187
3
        let token = self.get_token(app_installation_id).await?;
188

            
189
3
        api::get_check_runs(&self.api, &token, repo, commit).await
190
3
    }
191

            
192
    /// Check a collection of check runs and returns the number of uncompleted check runs.
193
    /// Additionally returns the check run created by this app. If there are multiple check-runs, the first will be returned.
194
3
    fn overall_check_status(&self, check_runs: &[CheckRun]) -> (u32, Option<CheckRun>) {
195
3
        if check_runs.is_empty() {
196
            warn!("Received empty check-runs list");
197
            return (0, None);
198
3
        }
199
3
        let mut uncompleted = 0;
200
3
        let mut own_check_run: Option<CheckRun> = None;
201

            
202
6
        for run in check_runs {
203
3
            if run
204
3
                .app
205
3
                .as_ref()
206
3
                .is_some_and(|app| app.client_id == self.client_id)
207
            {
208
                // This is a check run created by this app
209
1
                if own_check_run.is_none() {
210
1
                    own_check_run = Some(run.clone());
211
1
                } else {
212
                    warn!(
213
                        "Found multiple check runs created by this app: '{}' and '{}, commit: '{}'",
214
                        own_check_run.as_ref().unwrap().name,
215
                        run.name,
216
                        run.head_sha
217
                    );
218
                }
219
1
                debug!("Found own check run: {}", run.id);
220
1
                continue;
221
2
            }
222
2
            match run.status.as_str() {
223
2
                "completed" => {
224
                    if run
225
                        .conclusion
226
                        .as_ref()
227
                        .is_some_and(|v| v == CHECK_RUN_CONCLUSION || v == "skipped")
228
                    {
229
                        debug!("Check run '{}' is completed successfully", run.name);
230
                    } else {
231
                        debug!(
232
                            "Check run '{}' is completed not successfull: '{}'",
233
                            run.name,
234
                            run.conclusion.as_deref().unwrap_or("unknown")
235
                        );
236
                        uncompleted += 1;
237
                    }
238
                }
239
                _ => {
240
2
                    debug!(
241
2
                        "Check run '{}' is not completed, status: {}",
242
                        run.name, run.status
243
                    );
244
2
                    uncompleted += 1;
245
                }
246
            }
247
        }
248
3
        (uncompleted, own_check_run)
249
3
    }
250

            
251
    /// Check the cache for a token and return it if it exists.
252
12
    async fn get_cached_token(&self, app_installation_id: u64) -> Option<String> {
253
12
        let cache = self.token_cache.lock().await;
254
12
        if let Some(token) = cache.get(&app_installation_id) {
255
6
            let now = chrono::Utc::now() + chrono::Duration::seconds(30);
256
6
            if token.expires_at.ge(&now) {
257
5
                debug!(
258
2
                    "Using cached token for installation ID: {}",
259
                    app_installation_id
260
                );
261
5
                return Some(token.token.clone());
262
1
            }
263
1
            debug!(
264
                "Cached token for installation ID {} is expired, fetching a new one",
265
                app_installation_id
266
            );
267
6
        }
268
7
        None
269
12
    }
270

            
271
    #[cfg(test)]
272
4
    pub fn new_for_testing(client_id: &str, secret: &str, api: &str) -> Self {
273
4
        let key = jsonwebtoken::EncodingKey::from_secret(secret.as_bytes());
274

            
275
4
        Client {
276
4
            client_id: client_id.to_string(),
277
4
            key,
278
4
            api: api.to_string(),
279
4
            token_cache: Mutex::new(HashMap::new()),
280
4
        }
281
4
    }
282
}
283

            
284
#[derive(Debug, Serialize, Deserialize)]
285
struct JWTClaims {
286
    /// Issued At
287
    /// Recommended to be 60 seconds in the past to account for clock drift
288
    iat: u64,
289
    /// Expires At
290
    /// Maximum of 10 minutes in the future
291
    exp: u64,
292
    /// Issuer
293
    /// The GitHub App's client ID
294
    iss: String,
295
}
296

            
297
impl JWTClaims {
298
    /// Create a new JWT claims object with the issued time 30s in the past
299
7
    pub fn new(client_id: &str) -> Self {
300
7
        debug!("Creating JWT claims for client ID: {}", client_id);
301
7
        let now = jsonwebtoken::get_current_timestamp();
302
7
        let iat = now - 30;
303
7
        let exp = now + 2 * 60;
304
7
        JWTClaims {
305
7
            iat,
306
7
            exp,
307
7
            iss: client_id.to_string(),
308
7
        }
309
7
    }
310
}