Recently, I ran into a problem integrating the Axum web framework with the rust-s3 library. The task is to create two endpoints: one for uploading a file to S3-compatible storage and one for downloading it back.
Of course, this should be done without local temporary files and without holding the entire file in memory.
Since working with S3 requires some auxiliary objects (the access settings for a specific bucket), we will encapsulate the actual work in an UploadService struct:
#[derive(Clone)]
pub struct UploadService {
    bucket: Arc<s3::Bucket>,
}
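To enable dependency injection (DI) of the service into endpoint handlers, our struct needs to implement the Clone trait (Axum's Extension requires it). Since the service will be cloned for each request, we wrap s3::Bucket in an Arc to make cloning as cheap as possible: cloning an Arc only increments a reference count instead of copying the bucket configuration.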
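A minimal std-only sketch of why the Arc keeps per-request cloning cheap. FakeBucket is a hypothetical stand-in for s3::Bucket; the point is that cloning the service copies a pointer and bumps a reference count, so every clone shares the same bucket.

```rust
use std::sync::Arc;

// Hypothetical stand-in for s3::Bucket (the real type holds credentials,
// region, endpoint and HTTP settings).
#[derive(Debug)]
pub struct FakeBucket {
    pub name: String,
}

// Same shape as UploadService: its only field is an Arc.
#[derive(Clone)]
pub struct UploadService {
    pub bucket: Arc<FakeBucket>,
}

fn main() {
    let service = UploadService {
        bucket: Arc::new(FakeBucket { name: "uploads".to_string() }),
    };
    // Axum's Extension extractor hands each request its own clone of the value.
    let per_request = service.clone();

    // Both handles point at one allocation; nothing inside FakeBucket was copied.
    assert!(Arc::ptr_eq(&service.bucket, &per_request.bucket));
    println!(
        "bucket {:?} shared by {} handles",
        per_request.bucket.name,
        Arc::strong_count(&service.bucket)
    );
}
```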
Now, let's implement the constructor for an instance of the service:
use s3::{Bucket, Region};
use s3::creds::Credentials;
...
impl UploadService {
    pub fn new() -> Self {
        let bucket_name = std::env::var("UPLOAD_BUCKET_NAME")
            .expect("Expected UPLOAD_BUCKET_NAME environment variable");
        // A custom region lets us point the client at a non-Amazon endpoint.
        let region = Region::Custom {
            region: std::env::var("UPLOAD_BUCKET_REGION")
                .expect("Expected UPLOAD_BUCKET_REGION environment variable"),
            endpoint: std::env::var("UPLOAD_BUCKET_ENDPOINT")
                .expect("Expected UPLOAD_BUCKET_ENDPOINT environment variable"),
        };
        let credentials = Credentials::new(
            Some(
                &std::env::var("UPLOAD_BUCKET_ACCESS_KEY")
                    .expect("Expected UPLOAD_BUCKET_ACCESS_KEY environment variable")
            ),
            Some(
                &std::env::var("UPLOAD_BUCKET_SECRET_KEY")
                    .expect("Expected UPLOAD_BUCKET_SECRET_KEY environment variable")
            ),
            None, // security token
            None, // session token
            None, // profile
        ).unwrap();
        // Path-style addressing: <endpoint>/<bucket>/<key> instead of <bucket>.<endpoint>/<key>.
        let bucket = Bucket::new(&bucket_name, region, credentials).unwrap()
            .with_path_style();
        Self {
            bucket: Arc::new(bucket),
        }
    }
    ...
The service is configured through the environment variables UPLOAD_BUCKET_NAME, UPLOAD_BUCKET_REGION, UPLOAD_BUCKET_ACCESS_KEY, UPLOAD_BUCKET_SECRET_KEY, and UPLOAD_BUCKET_ENDPOINT. The last one is necessary because I'm using a different S3-compatible provider (Scaleway). With Amazon S3, you can instead set the region explicitly using one of the values of the s3::Region enum (e.g., s3::Region::UsWest1), or use s3::Region::from_str to parse it from a string such as "us-west-1". Note that the region enum also includes providers other than Amazon, such as DigitalOcean, Wasabi, and Yandex.
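The constructor panics if any variable is missing. As a hedged alternative, here is a std-only sketch that gathers the same five settings and reports a missing one as an error. BucketConfig and from_lookup are hypothetical names; building the actual Region::Custom, Credentials and Bucket from these strings would still be done with rust-s3 as in UploadService::new.

```rust
// Hypothetical plain-data holder for the settings UploadService::new reads.
#[derive(Debug, PartialEq)]
pub struct BucketConfig {
    pub name: String,
    pub region: String,
    pub endpoint: String,
    pub access_key: String,
    pub secret_key: String,
}

impl BucketConfig {
    // `lookup` abstracts over the environment, so the logic can be tested
    // without mutating process-wide environment variables.
    pub fn from_lookup<F>(lookup: F) -> Result<Self, String>
    where
        F: Fn(&str) -> Option<String>,
    {
        let get = |key: &str| {
            lookup(key).ok_or_else(|| format!("Expected {} environment variable", key))
        };
        Ok(BucketConfig {
            name: get("UPLOAD_BUCKET_NAME")?,
            region: get("UPLOAD_BUCKET_REGION")?,
            endpoint: get("UPLOAD_BUCKET_ENDPOINT")?,
            access_key: get("UPLOAD_BUCKET_ACCESS_KEY")?,
            secret_key: get("UPLOAD_BUCKET_SECRET_KEY")?,
        })
    }

    // Production entry point: read the real process environment.
    pub fn from_env() -> Result<Self, String> {
        Self::from_lookup(|key| std::env::var(key).ok())
    }
}

fn main() {
    match BucketConfig::from_env() {
        Ok(config) => println!("bucket {} at {}", config.name, config.endpoint),
        Err(message) => eprintln!("{}", message),
    }
}
```

Returning an error lets main decide how to report misconfiguration, while the panicking version in the article is simpler and fails just as early at startup.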
Now, let's move on to the most complex part: the function for uploading a file to storage.
use std::sync::{Arc, Mutex};
use std::path::Path;
use std::ffi::OsStr;
use axum::http::StatusCode;
use axum::extract::multipart::Field;
use async_hash::{Sha256, Digest};
use async_compat::CompatExt;
use futures::TryStreamExt;
use uuid::Uuid;
...
    pub async fn upload<'a>(&self, field: Field<'a>) -> Result<String, StatusCode> {
        let orig_filename = field.file_name()
            .unwrap_or("file")
            .to_owned();
        let mimetype = field.content_type()
            .unwrap_or("application/octet-stream")
            .to_owned();
        // Running SHA-256 of the uploaded bytes, updated from the stream adapter below.
        let digest = Arc::new(Mutex::new(Sha256::new()));
        let mut reader = field
            .map_ok(|chunk| {
                // Hash each chunk as it flows through to S3.
                if let Ok(mut digest) = digest.lock() {
                    digest.update(&chunk);
                }
                chunk
            })
            // into_async_read() requires std::io::Error as the stream's error type.
            .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))
            .into_async_read() // futures::io::AsyncRead
            .compat(); // adapted to tokio::io::AsyncRead
        // Upload under a random temporary key first; the final name needs the finished hash.
        let tmp_filename = format!("tmp/{}.bin", Uuid::new_v4());
        self.bucket.put_object_stream_with_content_type(
            &mut reader,
            &tmp_filename,
            &mimetype
        )
        .await
        .map_err(|err| {
            log::error!("S3 upload error: {:?}", err);
            StatusCode::INTERNAL_SERVER_ERROR
        })?;
        drop(reader); // Release digest borrow
        let mut result = Err(StatusCode::INTERNAL_SERVER_ERROR);
        // With the borrow released, the Arc has a single owner and the hasher can be taken out.
        if let Some(digest) = Arc::into_inner(digest).and_then(|m| m.into_inner().ok()) {
            let digest = hex::encode(digest.finalize());
            let ext = Path::new(&orig_filename).extension().and_then(OsStr::to_str);
            let mut filename = if let Some(ext) = ext {
                format!("{}.{}", digest, ext)
            } else {
                digest
            };
            filename.make_ascii_lowercase();
            // S3 has no rename, so copy the object to its content-addressed key.
            match self.bucket.copy_object_internal(&tmp_filename, &filename).await {
                Ok(_) => result = Ok(format!("/uploads/{}", &filename)),
                Err(err) => log::error!("S3 copy error: {:?}", err)
            }
        }
        // Remove the temporary object whether or not the copy succeeded.
        if let Err(err) = self.bucket.delete_object(&tmp_filename).await {
            log::error!("S3 delete error: {:?}", err);
        }
        result
    }
...
The function accepts a field of a multipart/form-data request (request handling is covered later) and determines the original filename and MIME type, falling back to "file" and "application/octet-stream" when the client doesn't provide them. The field's data stream is then turned into a futures AsyncRead with TryStreamExt::into_async_read and adapted to Tokio's AsyncRead with the async-compat library. Along the way, each chunk is fed into a SHA-256 hasher, so the file's hash is computed while the stream is being read (we will need it shortly).
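The "hash while streaming" idea can be shown with the standard library's synchronous Read trait. SHA-256 isn't in std, so this sketch swaps in FNV-1a (a simple non-cryptographic hash, not a substitute for SHA-256 in content addressing) as a stand-in; HashingReader and Fnv1a are hypothetical names. The wrapper hashes exactly the bytes its consumer reads, just as the map_ok closure does for each multipart chunk.

```rust
use std::io::{self, Read};

// FNV-1a, 64-bit: a stand-in for SHA-256, which is not in the standard library.
pub struct Fnv1a(u64);

impl Fnv1a {
    pub fn new() -> Self {
        Fnv1a(0xcbf2_9ce4_8422_2325) // FNV offset basis
    }
    pub fn update(&mut self, bytes: &[u8]) {
        for &b in bytes {
            self.0 ^= u64::from(b);
            self.0 = self.0.wrapping_mul(0x0000_0100_0000_01b3); // FNV prime
        }
    }
    pub fn finish(&self) -> u64 {
        self.0
    }
}

// Wraps any reader and feeds every byte it hands out into the digest.
pub struct HashingReader<R> {
    inner: R,
    pub digest: Fnv1a,
}

impl<R: Read> HashingReader<R> {
    pub fn new(inner: R) -> Self {
        HashingReader { inner, digest: Fnv1a::new() }
    }
}

impl<R: Read> Read for HashingReader<R> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let n = self.inner.read(buf)?;
        // Only the bytes actually read (buf[..n]) are hashed.
        self.digest.update(&buf[..n]);
        Ok(n)
    }
}

fn main() -> io::Result<()> {
    let data = b"hello, s3".to_vec();
    let mut reader = HashingReader::new(&data[..]);
    // The consumer (here io::copy into a sink; in the article, the S3 upload)
    // drives the reads, and hashing happens as a side effect.
    io::copy(&mut reader, &mut io::sink())?;
    println!("{:016x}", reader.digest.finish());
    Ok(())
}
```

Because this wrapper owns its hasher, no Arc<Mutex<_>> is needed; in upload() the hasher is shared between the stream closure and the code that finalizes it after the upload.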
Now we can upload the file to S3 under a temporary name of the form "tmp/<UUID>.bin" (where the UUID is randomly generated). If an error occurs at this point, the function returns an Internal Server Error status code.
We now have the file in S3 and the SHA-256 of its data, so we can rename it to its final name. I use the SHA-256 as the file name so that identical uploads end up as the same object and the storage holds no duplicates. The name is the hex representation of the hash followed by the extension of the original filename, if it had one. The result is converted to lowercase (in case the extension wasn't), and since the S3 API has no rename operation, we copy the object to the new key instead. If the copy succeeds, the function returns the file's URL.
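The naming logic is easy to pull out into a pure function and test on its own. This sketch reproduces it with the standard library only; to_hex is a stand-in for hex::encode, final_object_name is a hypothetical helper name, and the 4-byte digest in main is a placeholder (a real SHA-256 digest is 32 bytes, giving a 64-character hex name).

```rust
use std::ffi::OsStr;
use std::path::Path;

// Lowercase hex encoding (stand-in for hex::encode).
pub fn to_hex(bytes: &[u8]) -> String {
    bytes.iter().map(|b| format!("{:02x}", b)).collect()
}

// Hypothetical helper: builds the content-addressed key "<hex digest>[.<extension>]".
pub fn final_object_name(digest: &[u8], orig_filename: &str) -> String {
    let hex = to_hex(digest);
    // Only the last extension is kept: "archive.tar.gz" -> "gz".
    let ext = Path::new(orig_filename).extension().and_then(OsStr::to_str);
    let mut name = match ext {
        Some(ext) => format!("{}.{}", hex, ext),
        None => hex,
    };
    // The hex part is already lowercase; this normalizes extensions such as ".JPG".
    name.make_ascii_lowercase();
    name
}

fn main() {
    let digest: [u8; 4] = [0xde, 0xad, 0xbe, 0xef]; // placeholder, not a real SHA-256
    println!("/uploads/{}", final_object_name(&digest, "Holiday Photo.JPG"));
}
```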
Finally, we delete the temporary object from S3. This happens regardless of whether the copy succeeded.
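The copy-then-delete sequence can also be exercised without S3. In this sketch, MemoryStore is a hypothetical in-memory stand-in for the bucket (not rust-s3's API), and promote is a hypothetical helper playing the role of the copy_object_internal and delete_object calls in upload(); the temporary key is removed whether or not the copy succeeds.

```rust
use std::collections::HashMap;

// Hypothetical in-memory stand-in for the bucket; not the rust-s3 API.
#[derive(Default)]
pub struct MemoryStore {
    pub objects: HashMap<String, Vec<u8>>,
}

impl MemoryStore {
    pub fn put(&mut self, key: &str, data: Vec<u8>) {
        self.objects.insert(key.to_string(), data);
    }
    // There is no rename, so "rename" is a copy followed by a delete.
    pub fn copy(&mut self, from: &str, to: &str) -> Result<(), String> {
        let data = self
            .objects
            .get(from)
            .cloned()
            .ok_or_else(|| format!("no such key: {}", from))?;
        self.objects.insert(to.to_string(), data);
        Ok(())
    }
    pub fn delete(&mut self, key: &str) {
        self.objects.remove(key);
    }
}

// Copy the temporary object to its final key, then always delete the temporary one.
pub fn promote(store: &mut MemoryStore, tmp_key: &str, final_key: &str) -> Result<String, String> {
    let result = store
        .copy(tmp_key, final_key)
        .map(|_| format!("/uploads/{}", final_key));
    store.delete(tmp_key);
    result
}

fn main() {
    let mut store = MemoryStore::default();
    store.put("tmp/example.bin", b"payload".to_vec());
    let url = promote(&mut store, "tmp/example.bin", "deadbeef.txt");
    println!("{:?}, remaining keys: {:?}", url, store.objects.keys().collect::<Vec<_>>());
}
```

Because the final key depends only on the file's content, uploading the same bytes twice overwrites the same object, which is where the deduplication comes from.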
The last method of our service serves a file by its link. In theory, this can be delegated to the web server, but at the very least it's convenient for local development, and it also leaves room for additional business logic, such as access control for the file:
use axum::response::IntoResponse;
use axum::body::StreamBody;
use s3::error::S3Error;
...
    pub async fn download(
        &self,
        filename: &str
    ) -> Result<impl IntoResponse, StatusCode> {
        let stream = self.bucket.get_object_stream(filename)
            .await
            .map_err(|err| match err {
                // A missing object becomes a 404; other HTTP failures are logged and become 500s.
                S3Error::HttpFailWithBody(status_code, body) => match status_code {
                    404 => StatusCode::NOT_FOUND,
                    _ => {
                        log::error!(
                            "S3 download HTTP error with code={} and body={:?}",
                            status_code,
                            body
                        );
                        StatusCode::INTERNAL_SERVER_ERROR
                    }
                },
                err => {
                    log::error!("S3 download error: {:?}", err);
                    StatusCode::INTERNAL_SERVER_ERROR
                }
            })?;
        // Stream the object body to the client chunk by chunk.
        Ok(StreamBody::from(stream.bytes))
    }
}
This part is quite straightforward. We obtain the S3 object stream, map a missing-object error (HTTP 404 from S3) to a 404 response, map any other error to a 500, and return the stream wrapped in a StreamBody.
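The error mapping can be sketched with a plain enum. StorageError is a hypothetical, trimmed-down mirror of the two s3::error::S3Error cases that download() distinguishes, and status codes are plain u16 values instead of axum's StatusCode so the sketch needs only the standard library.

```rust
// Hypothetical mirror of the two S3Error cases download() distinguishes.
#[derive(Debug)]
pub enum StorageError {
    // Mirrors S3Error::HttpFailWithBody(status_code, body).
    HttpFailWithBody(u16, String),
    // Any other failure (network, credentials, ...).
    Other(String),
}

// Maps a storage error to the HTTP status returned to the client.
pub fn status_for(err: &StorageError) -> u16 {
    match err {
        StorageError::HttpFailWithBody(404, _) => 404,
        StorageError::HttpFailWithBody(code, body) => {
            eprintln!("S3 download HTTP error with code={} and body={:?}", code, body);
            500
        }
        StorageError::Other(message) => {
            eprintln!("S3 download error: {}", message);
            500
        }
    }
}

fn main() {
    let missing = StorageError::HttpFailWithBody(404, "not found".to_string());
    println!("{}", status_for(&missing));
}
```

One caveat: on Amazon S3, a caller without the s3:ListBucket permission receives 403 rather than 404 for a missing key, and this mapping would report that case as a 500.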
Now, we just need to implement the handlers for the endpoints themselves:
use axum::{Extension, Json};
use axum::extract::{Multipart, Path};
use axum::http::{HeaderMap, HeaderValue, StatusCode};
use axum::http::header::CACHE_CONTROL;
use axum::response::IntoResponse;
#[derive(Debug, serde::Serialize)]
pub struct UploadResponse {
    pub url: String,
}
pub async fn upload_file(
    Extension(upload_service): Extension<UploadService>,
    mut multipart: Multipart,
) -> Result<impl IntoResponse, StatusCode> {
    while let Some(field) = multipart.next_field().await.map_err(|_|
        StatusCode::INTERNAL_SERVER_ERROR
    )? {
        // Only the first field named "upload" is processed; the rest are ignored.
        if let Some("upload") = field.name() {
            let url = upload_service.upload(field).await?;
            return Ok(Json(UploadResponse { url }));
        }
    }
    Err(StatusCode::BAD_REQUEST)
}
pub async fn download_file(
    Path(path): Path<String>,
    Extension(upload_service): Extension<UploadService>,
) -> Result<impl IntoResponse, StatusCode> {
    let body = upload_service.download(&path).await?;
    let headers = HeaderMap::from_iter([
        (CACHE_CONTROL, HeaderValue::from_static("max-age=31536000")), // One year
    ]);
    Ok((headers, body))
}
The upload handler accepts one file per request: it takes the first form field named "upload", passes it to the service, and returns the file's URL as JSON. If the form has no such field, it responds with 400 Bad Request. The download handler sets the file's cache lifetime to one year, because stored files are not expected to change: a modified file would have a different SHA-256 and therefore a different name.
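Both handler policies (pick the first "upload" field, cache for one year) can be checked in isolation. FormField and pick_upload_field are hypothetical stand-ins for axum's multipart field and for the loop in upload_file, and the numeric 400 replaces StatusCode::BAD_REQUEST.

```rust
// One year in seconds: 365 * 24 * 60 * 60.
pub const ONE_YEAR_SECS: u64 = 365 * 24 * 60 * 60;

pub fn cache_control_value() -> String {
    format!("max-age={}", ONE_YEAR_SECS)
}

// Hypothetical stand-in for a multipart field: form field name plus bytes.
pub struct FormField {
    pub name: Option<String>,
    pub data: Vec<u8>,
}

// Hypothetical mirror of upload_file(): first field named "upload" wins,
// later ones are ignored, and a form without one yields 400.
pub fn pick_upload_field(fields: Vec<FormField>) -> Result<FormField, u16> {
    fields
        .into_iter()
        .find(|field| field.name.as_deref() == Some("upload"))
        .ok_or(400)
}

fn main() {
    let fields = vec![
        FormField { name: Some("title".to_string()), data: b"My file".to_vec() },
        FormField { name: Some("upload".to_string()), data: b"binary data".to_vec() },
    ];
    let chosen = pick_upload_field(fields).map(|f| f.data.len());
    println!("{:?} / Cache-Control: {}", chosen, cache_control_value());
}
```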
The last step is to create a router and start the server:
use std::str::FromStr;
use axum::extract::DefaultBodyLimit;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    env_logger::init_from_env(env_logger::Env::default().default_filter_or("info"));
    let upload_service = UploadService::new();
    let router = axum::Router::new()
        .route("/uploads", axum::routing::post(upload_file))
        .route("/uploads/*path", axum::routing::get(download_file))
        .layer(Extension(upload_service))
        .layer(DefaultBodyLimit::max(8 * 1024 * 1024)); // 8 MB request body limit
    let address = std::env::var("HOST").expect("Expected HOST environment variable");
    let port = std::env::var("PORT").expect("Expected PORT environment variable")
        .parse::<u16>().expect("PORT environment variable must be an integer");
    log::info!("Listening on http://{}:{}/", address, port);
    axum::Server::bind(
        &std::net::SocketAddr::new(
            // HOST must be an IP literal; an invalid value is returned as an error.
            std::net::IpAddr::from_str(&address)?,
            port
        )
    ).serve(router.into_make_service()).await?;
    Ok(())
}
The UploadService instance is passed to the handlers via an Extension (Axum's DI mechanism). It's also worth setting DefaultBodyLimit explicitly, because Axum's default limit of 2 MB may be too small for file uploads; here it is raised to 8 MB. The listen address and port are read from the HOST and PORT environment variables; HOST must be an IP address (e.g., 0.0.0.0), since it is parsed with IpAddr::from_str.
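A sketch of the startup parsing that returns an error instead of panicking. bind_address and MAX_UPLOAD_BYTES are hypothetical names, and the fallback values 127.0.0.1 and 3000 in main are assumptions for local runs, not part of the original setup.

```rust
use std::net::{IpAddr, SocketAddr};

// 8 MiB, the value passed to DefaultBodyLimit::max in main().
pub const MAX_UPLOAD_BYTES: usize = 8 * 1024 * 1024;

// Hypothetical helper: parses HOST and PORT like main() does, but returns errors.
// HOST must be an IP literal such as 0.0.0.0 or ::1; hostnames like "localhost" fail.
pub fn bind_address(host: &str, port: &str) -> Result<SocketAddr, String> {
    let ip: IpAddr = host
        .parse()
        .map_err(|e| format!("invalid HOST {:?}: {}", host, e))?;
    let port: u16 = port
        .parse()
        .map_err(|e| format!("invalid PORT {:?}: {}", port, e))?;
    Ok(SocketAddr::new(ip, port))
}

fn main() {
    // Assumed fallbacks for local runs; the article requires HOST and PORT to be set.
    let host = std::env::var("HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
    let port = std::env::var("PORT").unwrap_or_else(|_| "3000".to_string());
    match bind_address(&host, &port) {
        // SocketAddr's Display adds brackets for IPv6, so the URL stays valid.
        Ok(addr) => println!("Listening on http://{}/ (body limit {} bytes)", addr, MAX_UPLOAD_BYTES),
        Err(message) => eprintln!("{}", message),
    }
}
```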
Additionally, you may need to add some form of authorization checks to the upload (and possibly download) endpoint, but this depends on the specific requirements of your service.
Here is Cargo.toml for the project:
[package]
name = "uploader"
version = "0.1.0"
edition = "2021"
[dependencies]
log = "0.4.20"
env_logger = "0.10.0"
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
axum = { version = "0.6.20", features = ["multipart"] }
serde = { version = "1.0.188", features = ["derive"] }
uuid = { version = "1.4.1", features = ["v4"] }
rust-s3 = "0.34.0-rc1"
futures = "0.3.28"
async-compat = "0.2.2"
async-hash = "0.5.1"
hex = "0.4.3"
As a bonus, here is an example of browser-side TypeScript code that uploads a file:
interface UploadResponse {
  url: string;
}
async function uploadFile(file: Blob, filename?: string): Promise<UploadResponse | "error"> {
  const data = new FormData();
  // The field name must match the one the upload handler looks for.
  data.append("upload", file, filename);
  const response = await fetch("/uploads", {
    method: "POST",
    body: data
  });
  // response.ok is true for any 2xx status.
  if (response.ok) {
    return await response.json();
  } else {
    return "error";
  }
}
Our service is done.