diff --git a/server/config.go b/server/config.go index 1d510b1..5fbb892 100644 --- a/server/config.go +++ b/server/config.go @@ -21,6 +21,9 @@ func (c ProducerListenConfig) Addr() string { type ProducerUploadConfig struct { BodyLimit int `envkey:"BODY_LIMIT" default:"1073741824"` BaseURL string `envkey:"BASE_URL"` + // Default expiration time in seconds + DefaultExpirationTime int `envkey:"DEFAULT_EXPIRATION_TIME" default:"1800"` + AdminToken string `envkey:"ADMIN_TOKEN"` } type ProducerS3Config struct { diff --git a/server/files.go b/server/files.go index 2d54ea8..c08a697 100644 --- a/server/files.go +++ b/server/files.go @@ -4,6 +4,7 @@ package server import ( "context" + "crypto/subtle" "errors" "fmt" "io" @@ -11,6 +12,7 @@ import ( "net/url" "strconv" "strings" + "time" "github.com/danielgtaylor/huma/v2" "github.com/gabriel-vasile/mimetype" @@ -22,13 +24,32 @@ type UploadData struct { } type UploadInput struct { + Auth string `header:"Authorization"` + Expires int64 `header:"Expires"` RawBody huma.MultipartFormFiles[UploadData] } type UploadOutput struct { Body struct { - ID string `json:"id"` - URL string `json:"url"` + ID string `json:"id"` + URL string `json:"url"` + Expires int `json:"expires"` + } +} + +var BEARER_PREFIX = "Bearer " + +func isAuthenticated(authorization string) bool { + if CONFIG.Upload.AdminToken == "" { + return false + } + + if authorization[:len(BEARER_PREFIX)] == BEARER_PREFIX { + header_token := authorization[len(BEARER_PREFIX):] + + return subtle.ConstantTimeCompare([]byte(header_token), []byte(CONFIG.Upload.AdminToken)) == 1 + } else { + return false } } @@ -52,7 +73,12 @@ func Upload(ctx context.Context, input *UploadInput) (*UploadOutput, error) { mime := mimetype.Detect(det_buf[:n]) fd.Seek(0, 0) - err = UploadToS3(ctx, fd, file_id.String(), file.Filename, file.Size, mime.String()) + expires := time.Now().Add(time.Duration(CONFIG.Upload.DefaultExpirationTime) * time.Second) + if isAuthenticated(input.Auth) { + expires = time.Unix(input.Expires, 0) + } + + err = UploadToS3(ctx, fd, file_id.String(), file.Filename, file.Size, mime.String(), expires) if err != nil { return nil, err } @@ -60,6 +86,7 @@ func Upload(ctx context.Context, input *UploadInput) (*UploadOutput, error) { resp := &UploadOutput{} resp.Body.ID = file_id.String() resp.Body.URL = fmt.Sprintf("%s/%s", CONFIG.Upload.BaseURL, file_id.String()) + resp.Body.Expires = int(expires.Unix()) return resp, nil } diff --git a/server/producer_test.go b/server/producer_test.go index 8093125..ba2b597 100644 --- a/server/producer_test.go +++ b/server/producer_test.go @@ -11,6 +11,7 @@ import ( "net/url" "regexp" "testing" + "time" "github.com/danielgtaylor/huma/v2/humatest" "github.com/gabriel-vasile/mimetype" @@ -67,6 +68,12 @@ func TestUploadDownload(t *testing.T) { filename := "test file" upload_data := uploadData(t, api, test_data, filename) + expected_expires := time.Now().Add(time.Duration(CONFIG.Upload.DefaultExpirationTime) * time.Second).Unix() + + if expected_expires+1 < int64(upload_data.Body.Expires) || expected_expires-1 > int64(upload_data.Body.Expires) { + t.Fatalf("Expected expire value of %d (+-1), found %d", expected_expires, upload_data.Body.Expires) + } + path := fmt.Sprintf("/%s", upload_data.Body.ID) resp := assertRespCode(t, api.Get(path), 200) diff --git a/server/s3.go b/server/s3.go index f3fe7a3..687fc3e 100644 --- a/server/s3.go +++ b/server/s3.go @@ -6,6 +6,7 @@ package server import ( "context" "io" + "time" "github.com/minio/minio-go/v7" "github.com/minio/minio-go/v7/pkg/credentials" @@ -19,7 +20,7 @@ func getS3Client() (*minio.Client, error) { return client, err } -func UploadToS3(ctx context.Context, file io.Reader, file_id string, filename string, filesize int64, content_type string) error { +func UploadToS3(ctx context.Context, file io.Reader, file_id string, filename string, filesize int64, content_type string, expires time.Time) error { client, err := getS3Client() if err != nil { return err @@ -35,6 +36,7 @@ func UploadToS3(ctx context.Context, file io.Reader, file_id string, filename st "Filename": filename, "Type": content_type, }, + Expires: expires, }, ) getLogger().Printf("upload info: %+v\n", info)