1 module aws.aws;
2 
3 import std.algorithm;
4 import std.datetime;
5 import std.random;
6 import std.range;
7 import std.stdio;
8 import std.string;
9 import std.conv;
10 
11 import std.digest.sha;
12 import aws.sigv4;
13 import std.math;
14 
15 import arsd.dom;
16 
17 public import aws.credentials;
18 
/// Returns the inner text of `node`, or null when `node` is null
/// (as happens when a `querySelector` lookup matched nothing).
auto safeInnerText(Element node) {
    return node is null ? null : node.innerText;
}
24 
/**
  Error reported by (or while talking to) an AWS service.

  `type` is the error code as reported by the service; the exception message
  is `type ~ ": " ~ message`. `retriable` is a flag chosen by whoever creates
  the exception (RESTClient.checkForError sets it for 5xx responses).
 */
class AWSException : Exception
{
    immutable string type;
    immutable bool retriable;

    this(string type, bool retriable, string message, string file = __FILE__, size_t line = __LINE__, Throwable next = null)
    {
        super(type ~ ": " ~ message, file, line, next);
        this.retriable = retriable;
        this.type = type;
    }

    /**
      The part of `type` after the first '#', or the whole of `type` when it
      contains no '#'.
      For 'com.amazon.coral.service#ThrottlingException' this is 'ThrottlingException'.
     */
    @property string simpleType()
    {
        immutable hashPos = type.indexOf('#');
        if (hashPos >= 0)
            return type[hashPos + 1 .. $];
        return type;
    }
}
47 
/**
  Thrown when the signature/authorization information is wrong.
  Always constructed with `retriable == false`.
 */
class AuthorizationException : AWSException
{
    this(string type, string message, string file = __FILE__, size_t line = __LINE__, Throwable next = null)
    {
        // Authorization failures are always reported as non-retriable.
        enum bool isRetriable = false;
        super(type, isRetriable, message, file, line, next);
    }
}
58 
/**
  Base class for HTTP(S) clients of AWS services whose requests are signed
  with AWS Signature Version 4 (see signRequest2 and the aws.sigv4 helpers).
 */
abstract class RESTClient {
    import std.range : isInputRange, hasLength;
    import requests : Response;
    /// Endpoint without the URI scheme; sent as the `host` header.
    immutable string endpoint;
    /// Endpoint including the scheme (`https://` is prepended when none was given).
    immutable string baseUri;
    /// Region used in the signing scope.
    immutable string region;
    /// Service name used in the signing scope.
    immutable string service;

    private AWSCredentialSource m_credsSource;

    /**
      Params:
        endpoint = host, optionally prefixed with `http://` or `https://`
        region = region used for signing
        service = service name used for signing
        credsSource = queried for credentials on every request
     */
    this(string endpoint, string region, string service, AWSCredentialSource credsSource) nothrow @safe
    {
        this.region = region;
        if (endpoint.startsWith("http://")) {
            this.baseUri = endpoint;
            this.endpoint = endpoint[7..$];
        } else if (endpoint.startsWith("https://")) {
            this.baseUri = endpoint;
            this.endpoint = endpoint[8..$];
        } else {
            this.baseUri = "https://"~endpoint;
            this.endpoint = endpoint;
        }
        this.service = service;
        this.m_credsSource = credsSource;
    }

    // URL-encodes keys and values and joins them as `k=v&k=v`.
    private static string buildQueryParameterString(string[string] queryParameters)
    {
        import std.uri : encodeComponent;

        return queryParameters
            .byKeyValue
            .map!(kv => only(encodeComponent(kv.key), encodeComponent(kv.value)).joiner("="))
            .joiner("&")
            .text();
    }

    // Number of hexadecimal digits needed to print `x` (one digit for 0).
    private static ulong hexDigitCount(ulong x) pure nothrow @safe @nogc
    {
        ulong digits = 1;
        while (x >= 16) {
            x /= 16;
            ++digits;
        }
        return digits;
    }

    // Exact size of the aws-chunked body produced by doUpload for `payloadSize`
    // bytes cut into `blockSize`-byte chunks. Every chunk is
    // `<hex size>;chunk-signature=<64 hex chars>\r\n<data>\r\n` and the body
    // ends with a zero-sized chunk.
    private static ulong streamingBodySize(ulong payloadSize, ulong blockSize) pure nothrow @safe @nogc
    {
        static ulong encodedChunkSize(ulong dataSize) pure nothrow @safe @nogc
        {
            enum ulong signatureSize = ";chunk-signature=".length + 64;
            enum ulong crlfSize = 4; // 2 * "\r\n"
            return hexDigitCount(dataSize) + signatureSize + crlfSize + dataSize;
        }

        immutable ulong numFullSizeBlocks = payloadSize / blockSize;
        immutable ulong lastBlockSize = payloadSize % blockSize;
        return numFullSizeBlocks * encodedChunkSize(blockSize)
            + (lastBlockSize ? encodedChunkSize(lastBlockSize) : 0)
            + encodedChunkSize(0); // terminating 0-sized chunk
    }

    Response doRequest(string method, string resource, string[string] queryParameters, string[string] headers) shared {
        return (cast()this).doRequest(method, resource, queryParameters, headers);
    }

    /**
      Sends a signed request without a body and returns the streamed response.

      Params:
        method = HTTP method
        resource = resource path; a leading '/' is added when missing
        queryParameters = URL-encoded into the query string and included in the signature
        headers = extra request headers
     */
    Response doRequest(string method, string resource, string[string] queryParameters, string[string] headers)
    {
        import requests : Request;
        if (!resource.startsWith("/"))
            resource = "/" ~ resource;

        //Initialize credentials
        auto creds = m_credsSource.credentials(region ~ "/" ~ service);
        auto queryString = buildQueryParameterString(queryParameters);

        auto url = baseUri ~ resource;
        // Avoid a dangling '?' when there are no query parameters.
        if (!queryString.empty)
            url ~= "?" ~ queryString;
        auto req = Request();

        req.method = method;
        req.addHeaders(headers);
        req.useStreaming = true;
        req.addHeaders(["host": endpoint]);
        if (creds.sessionToken && !creds.sessionToken.empty)
            req.addHeaders(["x-amz-security-token": creds.sessionToken]);

        req.addHeaders(signRequest2(resource, method, req.headers, queryParameters, null, creds, region, service));

        return req.execute(method, url);
    }

    /// ditto, with the payload size taken from `payload.length`.
    Response doUpload(Range)(string method, string resource, string[string] queryParameters,
                             string[string] headers, in string[] additionalSignedHeaders,
                             scope Range payload, ulong blockSize = 512*1024) if (isInputRange!Range && hasLength!Range) {
        return doUpload(method, resource, queryParameters, headers, additionalSignedHeaders, payload, payload.length, blockSize);
    }

    /**
      Uploads `payload` as an aws-chunked body. Each chunk carries a
      chunk-signature chained from the seed signature placed in the
      `authorization` header.

      Params:
        payloadSize = total number of payload bytes (sent as x-amz-decoded-content-length)
        blockSize = chunk size in bytes; must be greater than zero
        additionalSignedHeaders = names of headers in `req.headers` to include in the seed signature

      NOTE(review): when `payload` is a range of byte arrays, its elements are
      sent as chunks unchanged while Content-Length assumes `blockSize`-byte
      chunks. Confirm that callers provide elements of exactly that size.
     */
    Response doUpload(Range)(string method, string resource, string[string] queryParameters,
                             string[string] headers, in string[] additionalSignedHeaders,
                             scope Range payload, size_t payloadSize, ulong blockSize = 512*1024) if (isInputRange!Range) {
        import requests : Request;

        assert(blockSize > 0, "blockSize must be greater than zero");

        //Calculate the body size upfront for the "Content-Length" header
        immutable ulong bodySize = streamingBodySize(payloadSize, blockSize);

        if (!resource.startsWith("/"))
            resource = "/" ~ resource;

        //Initialize credentials
        auto creds = m_credsSource.credentials(region ~ "/" ~ service);

        auto url = baseUri ~ resource;
        auto req = Request();

        req.method = method;
        req.addHeaders(headers);
        req.useStreaming = true;
        req.addHeaders(["host": endpoint]);
        if (creds.sessionToken && !creds.sessionToken.empty)
            req.addHeaders(["x-amz-security-token": creds.sessionToken]);

        // Sign the resource path (not the full URL), as doRequest does.
        // NOTE(review): the x-amz-date, x-amz-content-sha256 and authorization
        // headers returned here are presumably overwritten by the addHeaders calls below.
        req.addHeaders(signRequest2(resource, method, req.headers, queryParameters, null, creds, region, service));

        auto isoTimeString = currentTimeString();
        req.addHeaders(["x-amz-date": isoTimeString]);
        auto date = isoTimeString.dateFromISOString;
        auto time = isoTimeString.timeFromISOString;

        // Header names are case-insensitive: do not add a second content-type
        // when the caller supplied one with different capitalisation.
        if (!headers.byKey.any!(k => k.toLower == "content-type"))
            req.addHeaders(["content-type": "application/octet-stream"]);

        if (payloadSize > 0)
            req.addHeaders(["x-amz-decoded-content-length": payloadSize.to!string,
                            "x-amz-content-sha256": streaming_payload_hash,
                            "content-length": bodySize.to!string]);
        else {
            // NOTE(review): the chunked body below still contains the
            // zero-sized terminating chunk, which does not match
            // "content-length: 0". Confirm the intended behavior for empty uploads.
            req.addHeaders(["x-amz-content-sha256": "UNSIGNED-PAYLOAD",
                            "content-length": "0"]);
        }

        auto canonicalRequest = CanonicalRequest(
                                                 method,
                                                 resource,
                                                 queryParameters,
                                                 [
                                                  "host":                         req.headers["host"],
                                                  "content-length":               req.headers["content-length"],
                                                  "x-amz-date":                   req.headers["x-amz-date"],
                                                  ]
                                                 );
        canonicalRequest.setStreamingPayloadHash(payloadSize.to!string);

        foreach (headerName; additionalSignedHeaders)
            canonicalRequest.headers[headerName] = req.headers[headerName];

        //Calculate the seed signature
        auto signableRequest = SignableRequest(date, time, region, service, canonicalRequest);
        auto key = signingKey(creds.accessKeySecret, date, region, service);
        auto binarySignature = key.sign(cast(ubyte[])signableRequest.signableString);

        auto credScope = date ~ "/" ~ region ~ "/" ~ service;
        auto authHeader = createSignatureHeader(creds.accessKeyID, credScope, canonicalRequest.headers, binarySignature);
        req.addHeaders(["authorization": authHeader]);

        // Each chunk signature is computed from the previous one, starting
        // with the seed signature; `extension` updates `signature` on every call.
        string signature = binarySignature.toHexString().toLower();
        auto extension = (ubyte[] data) @trusted
            {
                // has to be trusted because the compiler thinks toLower escapes the stack allocated hex-string
                auto chunk = SignableChunk(date, time, region, service, signature, hash(data));
                signature = key.sign(chunk.signableString.representation).toHexString().toLower();
                return text(";chunk-signature=", signature);
            };

        auto chunked = payload.chunkedContent(blockSize, extension);

        return req.execute(method, url, chunked);
    }

    /// Reads the whole response body and parses it as an XML document.
    Document readXML(Response response)
    {
        import std.algorithm : joiner;
        import std.array : array;
        ubyte[] content = response.receiveAsRange().joiner.array;
        return new Document(cast(string)content);
    }

    /**
      Does nothing for status codes below 400. Otherwise parses the error
      document and throws the exception built by makeException, marking
      5xx responses as retriable.
     */
    void checkForError(Response response, string file = __FILE__, size_t line = __LINE__, Throwable next = null)
    {
        if (response.code < 400)
            return; // No error

        auto document = readXML(response);
        auto code = document.querySelector("error code").safeInnerText;
        auto message = document.querySelector("error message").safeInnerText;
        throw makeException(code, response.code / 100 == 5, message, file, line, next);
    }

    /**
      Builds (does not throw) the exception for an error of the given type.
      Returns an AuthorizationException for signature/credential errors and
      an AWSException otherwise.
     */
    AWSException makeException(string type, bool retriable, string message,
        string file = __FILE__, size_t line = __LINE__, Throwable next = null)
    {
        if (type == "UnrecognizedClientException" || type == "InvalidSignatureException")
            return new AuthorizationException(type, message, file, line, next);
        return new AWSException(type, retriable, message, file, line, next);
    }
}
249 
/// Current UTC time as an ISO basic-format string truncated to whole
/// seconds, e.g. "20240131T235959Z".
private auto currentTimeString()
{
    auto now = Clock.currTime(UTC());
    now.fracSecs = Duration.zero;
    return now.toISOString();
}
256 
// Computes SigV4 headers for a request to `uri`.
// Returns the headers to add: "x-amz-date", "x-amz-content-sha256" (hash of
// `requestBody`) and "authorization". The signed headers are "host" plus every
// "x-amz-" prefixed header from `headers`, together with the new x-amz-date.
// NOTE(review): x-amz-content-sha256 is added only after the canonical headers
// are collected, so it is not among the signed headers here. Confirm that
// aws.sigv4 accounts for it.
private string[string] signRequest2(string uri, string method, string[string] headers, string[string] queryParameters,
                         in ubyte[] requestBody, AWSCredentials creds, string region, string service)
{
    // A single timestamp drives both the x-amz-date header and the credential scope.
    auto stamp = currentTimeString();
    auto day = dateFromISOString(stamp);

    SignableRequest toSign;
    toSign.dateString = day;
    toSign.timeStringUTC = timeFromISOString(stamp);
    toSign.region = region;
    toSign.service = service;
    toSign.canonicalRequest.method = method;

    // The canonical URI stops at the query string; parameters are signed separately.
    immutable queryStart = uri.indexOf('?');
    toSign.canonicalRequest.uri = queryStart < 0 ? uri : uri[0 .. queryStart];
    toSign.canonicalRequest.queryParameters = queryParameters;

    string[string] extra = ["x-amz-date": stamp];
    // Caller headers first, then the ones added here (so ours win on collisions).
    foreach (source; [headers, extra])
    {
        foreach (name, value; source)
        {
            auto lowered = name.toLower();
            if (lowered == "host" || lowered.startsWith("x-amz-"))
                toSign.canonicalRequest.headers[lowered] = value;
        }
    }
    toSign.canonicalRequest.setPayload(requestBody);
    extra["x-amz-content-sha256"] = toSign.canonicalRequest.payloadHash;

    ubyte[32] derivedKey = signingKey(creds.accessKeySecret, day, region, service);
    auto signature = sign(derivedKey, cast(ubyte[]) signableString(toSign));
    extra["authorization"] = createSignatureHeader(creds.accessKeyID, day ~ "/" ~ region ~ "/" ~ service,
                                                   toSign.canonicalRequest.headers, signature);

    return extra;
}
298 
/**
  Input range encoding `range` with HTTP chunked transfer encoding.

  Every element is one complete chunk, `<hex length>[extension]\r\n<data>\r\n`.
  The last element is the terminating chunk `0[extension]\r\n\r\n`.
  A range of bytes is cut into `chunkSize`-byte pieces. A range of byte arrays
  is used element by element (`chunkSize` is then ignored), and empty elements
  are skipped because a zero-length chunk would end the stream early.

  When `extension` is set, it is called once per chunk (including the
  terminating one) with that chunk's data, and its result is placed right
  after the hex length.
 */
struct ChunkedContent(Range) if (is(ElementType!Range == ubyte) || is(ElementType!Range == ubyte[])) {
    enum Position {
        data,
        finalizer,
        end
    }
    static if (is(ElementType!Range == ubyte)) {
        import std.range : chunks, Chunks;
        Chunks!Range range;
    } else {
        Range range;
    }
    alias ExtensionCallback = string delegate(ubyte[]);
    ExtensionCallback extension;
    static ubyte[] delimiter = ['\r','\n'];
    Position pos;

    // Encoded form of the current element. It is built lazily by `front` and
    // cached until `popFront`, so repeated `front` calls invoke `extension`
    // only once. A stateful callback (e.g. chained chunk signatures) relies on this.
    private ubyte[] current;
    private bool haveCurrent;

    this(Range range, size_t chunkSize, ExtensionCallback cb) {
        static if (is(ElementType!Range == ubyte)) {
            this.range = range.chunks(chunkSize);
        } else
            this.range = range;
        this.extension = cb;
        skipEmptyAndUpdatePosition();
    }

    // Drops empty elements at the front of `range`, then selects the data or
    // finalizer position depending on whether anything is left.
    private void skipEmptyAndUpdatePosition() {
        while (!range.empty && range.front.length == 0)
            range.popFront();
        pos = range.empty ? Position.finalizer : Position.data;
    }

    bool empty() {
        return pos == Position.end;
    }

    auto front() {
        if (!haveCurrent) {
            import std.format : format;
            import std.array : join;
            ubyte[] data;
            if (pos == Position.data)
                data = range.front;
            string length = format("%x", data.length);
            string headerString = extension !is null ? length ~ extension(data) : length;
            ubyte[] header = cast(ubyte[])headerString.representation;
            current = join([header, delimiter, data, delimiter]);
            haveCurrent = true;
        }
        return current;
    }

    void popFront() {
        haveCurrent = false;
        current = null;
        if (pos != Position.data)
            pos = Position.end;
        else {
            range.popFront();
            skipEmptyAndUpdatePosition();
        }
    }
}
351 
/**
  Wraps `range` in a ChunkedContent that emits it as HTTP chunked-encoding
  chunks. A byte range is cut into `chunkSize`-byte pieces. When `extension`
  is given, its result is placed after each chunk's hex length.
 */
auto chunkedContent(Range)(Range range, size_t chunkSize, string delegate(ubyte[]) extension = null) {
    alias Encoder = ChunkedContent!Range;
    return Encoder(range, chunkSize, extension);
}
355 
// 4-byte chunks without an extension: "hell", "o wo", "rld", then the terminator.
unittest {
    ubyte[] input = "hello world".representation.dup;
    auto expected = ["4\r\nhell\r\n", "4\r\no wo\r\n", "3\r\nrld\r\n", "0\r\n\r\n"]
        .map!(s => s.representation.dup)
        .array;
    assert(input.chunkedContent(4).map!(i => i.array).array() == expected);
}

// A chunk size equal to the input length yields one data chunk (hex length "b").
unittest {
    ubyte[] input = "hello world".representation.dup;
    auto expected = ["b\r\nhello world\r\n", "0\r\n\r\n"]
        .map!(s => s.representation.dup)
        .array;
    assert(input.chunkedContent(11).map!(i => i.array).array() == expected);
}

// An extension returning "" gives the same output as no extension at all.
unittest {
    ubyte[] input = "hello world".representation.dup;
    auto expected = ["4\r\nhell\r\n", "4\r\no wo\r\n", "3\r\nrld\r\n", "0\r\n\r\n"]
        .map!(s => s.representation.dup)
        .array;
    string delegate(ubyte[]) noExtension = (ubyte[] chunk) {
        return "";
    };
    assert(input.chunkedContent(4, noExtension).map!(i => i.array).array() == expected);
}

// The extension text goes right after the hex length of every chunk, including the terminator.
unittest {
    enum ext = ";chunk-signature=CHECKCHECK";
    ubyte[] input = "hello world".representation.dup;
    string delegate(ubyte[]) signer = (ubyte[] chunk) {
        return ext;
    };
    auto expected = ["4" ~ ext ~ "\r\nhell\r\n", "4" ~ ext ~ "\r\no wo\r\n", "3" ~ ext ~ "\r\nrld\r\n", "0" ~ ext ~ "\r\n\r\n"]
        .map!(s => s.representation.dup)
        .array;
    assert(input.chunkedContent(4, signer).map!(i => i.array).array() == expected);
}