1
use crate::error::Error;
2
use crate::{types::*, version};
3
use reqwest::{Client, header, header::HeaderMap, header::HeaderName, header::HeaderValue};
4
use tracing::{debug, info};
5

            
6
/// Get an installation token for the GitHub App.
7
/// API endpoint: POST /app/installations/{installation_id}/access_tokens
8
7
pub async fn get_installation_token(
9
7
    endpoint: &str,
10
7
    token: &str,
11
7
    installation_id: u64,
12
7
) -> Result<TokenResponse, Error> {
13
7
    let url = format!("{endpoint}/app/installations/{installation_id}/access_tokens");
14
7
    info!("Fetching installation token from '{url}'");
15

            
16
7
    let client = new_client_with_common_headers(token)?;
17
7
    let response = send_request(client.post(&url)).await?;
18

            
19
6
    let token: TokenResponse = response
20
6
        .json()
21
6
        .await
22
6
        .map_err(|e| Error::Parse("get_installation_token", Box::new(e)))?;
23

            
24
6
    Ok(token)
25
7
}
26

            
27
/// Fetch all check runs for a commit.
28
/// API endpoint: GET /repos/{owner}/{repo}/commits/{ref}/check-runs
29
3
pub async fn get_check_runs(
30
3
    endpoint: &str,
31
3
    token: &str,
32
3
    repo: &str,
33
3
    commit: &str,
34
3
) -> Result<Vec<CheckRun>, Error> {
35
3
    let url = format!("{endpoint}/repos/{repo}/commits/{commit}/check-runs");
36
3
    info!("Fetching check runs from '{url}'");
37

            
38
3
    let client = new_client_with_common_headers(token)?;
39
3
    let response = send_request(client.get(&url)).await?;
40
3
    let response = receive_body(response).await?;
41

            
42
3
    let check_runs: CheckRunsResponse = match serde_json::from_str(&response) {
43
3
        Ok(check_runs) => check_runs,
44
        Err(e) => {
45
            debug!("Response body: '{}'", response);
46
            return Err(Error::Parse("get_check_runs", Box::new(e)));
47
        }
48
    };
49

            
50
3
    Ok(check_runs.check_runs)
51
3
}
52

            
53
/// Create a check run for a specific commit.
54
/// API endpoint: POST /repos/{owner}/{repo}/check-runs
55
3
pub async fn create_check_run(
56
3
    endpoint: &str,
57
3
    token: &str,
58
3
    repo: &str,
59
3
    payload: &CheckRun,
60
3
) -> Result<(), Error> {
61
3
    let url = format!("{endpoint}/repos/{repo}/check-runs");
62
3
    info!("Creating check-run for '{}' at '{url}'", payload.head_sha);
63

            
64
3
    let client = new_client_with_common_headers(token)?;
65
3
    let response = send_request(client.post(&url).json(payload)).await?;
66
3
    let response = receive_body(response).await?;
67

            
68
3
    match serde_json::from_str::<CheckRun>(&response) {
69
3
        Ok(check_run) => {
70
3
            info!(
71
3
                "Created check-run '{}' for commit '{}'",
72
                check_run.id, check_run.head_sha,
73
            );
74
3
            Ok(())
75
        }
76
        Err(e) => {
77
            debug!("Response body: '{}'", response);
78
            Err(Error::Parse("create_check_run", Box::new(e)))
79
        }
80
    }
81
3
}
82

            
83
/// Update a check run for a specific commit.
84
/// API endpoint: PATCH /repos/{owner}/{repo}/check-runs/{check_run_id}
85
pub async fn update_check_run(
86
    endpoint: &str,
87
    token: &str,
88
    repo: &str,
89
    payload: &CheckRun,
90
) -> Result<(), Error> {
91
    let url = format!("{endpoint}/repos/{repo}/check-runs/{}", payload.id);
92
    info!("Updating check-run for '{}' at '{url}'", payload.head_sha);
93

            
94
    let client = new_client_with_common_headers(token)?;
95
    let response = send_request(client.patch(&url).json(payload)).await?;
96
    let response = receive_body(response).await?;
97

            
98
    match serde_json::from_str::<CheckRun>(&response) {
99
        Ok(check_run) => {
100
            info!(
101
                "Updated check-run '{}' for commit '{}'",
102
                check_run.id, check_run.head_sha,
103
            );
104
            Ok(())
105
        }
106
        Err(e) => {
107
            debug!("Response body: '{}'", response);
108
            Err(Error::Parse("update_check_run", Box::new(e)))
109
        }
110
    }
111
}
112

            
113
/// Get the current status of a pull request.
114
/// API endpoint: GET /repos/{owner}/{repo}/pulls/{pull_number}
115
1
pub async fn get_pull_request(
116
1
    endpoint: &str,
117
1
    token: &str,
118
1
    repo: &str,
119
1
    pull_number: u64,
120
1
) -> Result<PullRequestResponse, Error> {
121
1
    let url = format!("{endpoint}/repos/{repo}/pulls/{pull_number}");
122
1
    info!("Fetching pull request from '{url}'");
123

            
124
1
    let client = new_client_with_common_headers(token)?;
125
1
    let response = send_request(client.get(&url)).await?;
126
1
    let response = receive_body(response).await?;
127

            
128
1
    match serde_json::from_str::<PullRequestResponse>(&response) {
129
1
        Ok(pull_request) => Ok(pull_request),
130
        Err(e) => {
131
            debug!("Response body: '{}'", response);
132
            Err(Error::Parse("get_pull_request", Box::new(e)))
133
        }
134
    }
135
1
}
136

            
137
14
fn new_client_with_common_headers(token: &str) -> Result<Client, Error> {
138
14
    let mut headers = HeaderMap::new();
139
14
    headers.insert(
140
14
        header::ACCEPT,
141
14
        HeaderValue::from_static("application/vnd.github+json"),
142
    );
143
14
    headers.insert(
144
14
        HeaderName::from_static("x-github-api-version"),
145
14
        HeaderValue::from_static("2022-11-28"),
146
    );
147
14
    headers.insert(header::USER_AGENT, HeaderValue::from_static(version::NAME));
148
14
    if !token.is_empty() {
149
14
        let bearer = format!("Bearer {token}");
150
14
        let bearer = HeaderValue::from_str(&bearer).map_err(|_| Error::InvalidBearerToken())?;
151
14
        headers.insert(header::AUTHORIZATION, bearer);
152
    }
153
14
    Client::builder()
154
14
        .default_headers(headers)
155
14
        .build()
156
14
        .map_err(Error::CreateRequest)
157
14
}
158

            
159
14
async fn send_request(builder: reqwest::RequestBuilder) -> Result<reqwest::Response, Error> {
160
14
    let response = builder.send().await.map_err(Error::Send)?;
161

            
162
14
    if !response.status().is_success() {
163
1
        let status = response.status();
164
1
        let url = response.url().to_string();
165

            
166
1
        debug!(
167
            "Request failed with: status='{}', body='{}'",
168
            status,
169
            response.text().await.unwrap_or_default(),
170
        );
171
1
        return Err(Error::NonOkStatus(url, status));
172
13
    }
173
13
    Ok(response)
174
14
}
175

            
176
7
async fn receive_body(response: reqwest::Response) -> Result<String, Error> {
177
7
    response.text().await.map_err(Error::ReceiveBody)
178
7
}