Skip to content

Commit

Permalink
Reduce Common Usage between APOC Extended and APOC Core
Browse files Browse the repository at this point in the history
  • Loading branch information
gem-neo4j committed Dec 20, 2024
1 parent 444989b commit ca74218
Show file tree
Hide file tree
Showing 67 changed files with 316 additions and 188 deletions.
40 changes: 39 additions & 1 deletion common/src/main/java/apoc/ApocConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
*/
package apoc;

import static apoc.util.FileUtils.isFile;
import static java.lang.String.format;
import static org.neo4j.configuration.BootloaderSettings.lib_directory;
import static org.neo4j.configuration.BootloaderSettings.run_directory;
Expand All @@ -32,9 +31,14 @@
import static org.neo4j.internal.helpers.ProcessUtils.executeCommandWithOutput;

import apoc.export.util.ExportConfig;
import apoc.util.FileUtils;
import apoc.util.SupportedProtocols;
import apoc.util.Util;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Field;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URL;
import java.nio.file.Path;
import java.time.Duration;
Expand Down Expand Up @@ -337,6 +341,40 @@ public void checkWriteAllowed(ExportConfig exportConfig, String fileName) {
}
}

/**
 * Tells whether the given source string resolves to the {@code file} protocol.
 *
 * @param fileName path or URL string to classify
 * @return {@code true} when {@link #from(String)} yields {@code SupportedProtocols.file}
 */
public static boolean isFile(String fileName) {
    final SupportedProtocols protocol = from(fileName);
    return SupportedProtocols.file == protocol;
}

/**
 * Resolves the protocol of an already parsed URL.
 *
 * @param url the URL whose protocol name is looked up
 * @return whatever {@code FileUtils.of} returns for the URL's protocol name
 */
public static SupportedProtocols from(URL url) {
    final String protocolName = url.getProtocol();
    return FileUtils.of(protocolName);
}

/**
 * Determines the protocol of a source given as a string.
 *
 * <p>The string is first parsed as a {@link URL}. If parsing fails with a message containing
 * "no protocol", the source is treated as a plain file path. For any other parse failure the URI
 * scheme is looked up by name in {@code SupportedProtocols}; when that lookup fails as well, a
 * {@link RuntimeException} is thrown.
 *
 * @param source path or URL string
 * @return the matching protocol, or {@code SupportedProtocols.file} when the source has no protocol
 * @throws RuntimeException when the protocol cannot be resolved; the original
 *     {@link MalformedURLException} is always attached as the cause
 */
public static SupportedProtocols from(String source) {
    try {
        final URL url = new URL(source);
        return from(url);
    } catch (MalformedURLException e) {
        // The message may be null; treat that like any other non-"no protocol" failure
        // instead of failing with a NullPointerException here.
        final String message = e.getMessage();
        if (message == null || !message.contains("no protocol")) {
            try {
                // new URL(source) can fail with e.g. "unknown protocol: hdfs" because of a missing jar.
                // In that case return the enum constant named after the scheme (presumably so that a
                // later stage can report the missing dependency — not visible here); otherwise fall
                // through to the "unknown protocol" error below.
                return SupportedProtocols.valueOf(new URI(source).getScheme());
            } catch (Exception ignored) {
                // Deliberately ignored: the scheme is not a known enum constant or the URI is invalid.
            }

            // A Windows user may have written a path such as `C:/User/...`, where `C` is read as a protocol.
            if (message != null && message.contains("unknown protocol") && Util.isWindows()) {
                // Keep the original exception as the cause so the stack trace is not lost.
                throw new RuntimeException(
                        message
                                + "\n Please note that for Windows absolute paths they have to be explicit by prepending `file:` or supplied without the drive, "
                                + "\n e.g. `file:C:/my/path/file` or `/my/path/file`, instead of `C:/my/path/file`",
                        e);
            }
            throw new RuntimeException(e);
        }
        return SupportedProtocols.file;
    }
}

/**
 * Static accessor for the configuration instance.
 *
 * @return the current value of the static {@code theInstance} field
 */
public static ApocConfig apocConfig() {
    return theInstance;
}
Expand Down
4 changes: 2 additions & 2 deletions common/src/main/java/apoc/result/VirtualPath.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/
package apoc.result;

import apoc.util.CollectionUtils;
import apoc.util.Util;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
Expand Down Expand Up @@ -131,7 +131,7 @@ public String toString() {

/**
 * Checks that the given relationship shares at least one node with the nodes returned by
 * {@code getPreviousNodes()}.
 *
 * @param relationship the relationship to validate
 * @throws IllegalArgumentException when none of the relationship's nodes is among the previous nodes
 */
private void requireConnected(Relationship relationship) {
final List<Node> previousNodes = getPreviousNodes();
// NOTE(review): the two assignments below look like the removed and the added side of a diff whose
// +/- markers were lost (the hunk header reports 2 additions & 2 deletions). Presumably only the
// Util.containsAny variant is meant to remain — confirm before compiling, since the duplicate
// declaration is not valid Java.
boolean isRelConnectedToPrevious = CollectionUtils.containsAny(previousNodes, relationship.getNodes());
boolean isRelConnectedToPrevious = Util.containsAny(previousNodes, relationship.getNodes());
if (!isRelConnectedToPrevious) {
throw new IllegalArgumentException("Relationship is not part of current path.");
}
}
Expand Down
153 changes: 151 additions & 2 deletions common/src/main/java/apoc/util/FileUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@

import static apoc.ApocConfig.APOC_IMPORT_FILE_ALLOW__READ__FROM__FILESYSTEM;
import static apoc.ApocConfig.apocConfig;
import static apoc.export.util.LimitedSizeInputStream.toLimitedIStream;
import static apoc.util.Util.ERROR_BYTES_OR_STRING;
import static apoc.util.Util.REDIRECT_LIMIT;
import static apoc.util.Util.readHttpInputStream;
import static apoc.util.Util.isRedirect;

import apoc.ApocConfig;
import apoc.export.util.CountingInputStream;
Expand All @@ -32,22 +33,32 @@
import apoc.util.s3.S3URLConnection;
import apoc.util.s3.S3UploadUtils;
import java.io.BufferedOutputStream;
import java.io.BufferedWriter;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.StringWriter;
import java.net.HttpURLConnection;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLConnection;
import java.net.URLStreamHandler;
import java.net.URLStreamHandlerFactory;
import java.nio.file.NoSuchFileException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Map;
import java.util.Optional;
import org.apache.commons.compress.archivers.ArchiveEntry;
import org.apache.commons.compress.archivers.ArchiveInputStream;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;
import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.graphdb.security.URLAccessValidationError;
Expand Down Expand Up @@ -187,7 +198,7 @@ public static CountingInputStream inputStreamFor(
if (input instanceof String) {
String fileName = (String) input;
fileName = changeFileUrlIfImportDirectoryConstrained(fileName, urlAccessChecker);
return Util.openInputStream(fileName, headers, payload, compressionAlgo, urlAccessChecker);
return FileUtils.openInputStream(fileName, headers, payload, compressionAlgo, urlAccessChecker);
} else if (input instanceof byte[]) {
return getInputStreamFromBinary((byte[]) input, compressionAlgo);
} else {
Expand Down Expand Up @@ -345,4 +356,142 @@ public static File getLogDirectory() {
/**
 * Wraps binary content in a counting stream using the named compression algorithm.
 *
 * @param urlOrBinary the raw bytes to read from
 * @param compressionAlgo name of a {@code CompressionAlgo} constant, resolved via {@code valueOf}
 * @return the stream produced by the resolved algorithm's {@code toInputStream}
 */
public static CountingInputStream getInputStreamFromBinary(byte[] urlOrBinary, String compressionAlgo) {
    final CompressionAlgo algo = CompressionAlgo.valueOf(compressionAlgo);
    return algo.toInputStream(urlOrBinary);
}

/**
 * Opens a connection to the given address, sends the optional payload and follows redirects
 * manually, one recursive call per redirect.
 *
 * @param urlAddress address to open; it is passed through {@code checkAllowedUrlAndPinToIP} first
 * @param headers request headers handed to {@link #openUrlConnection(URL, Map)}; may be {@code null}
 * @param payload request body, or {@code null} for none
 * @param redirectLimit number of redirects that may still be followed
 * @param urlAccessChecker checker forwarded to {@code checkAllowedUrlAndPinToIP}
 * @return a {@code StreamConnection.UrlStreamConnection} wrapping the final connection
 * @throws IOException on I/O failure, or with "Redirect limit exceeded" when a redirect arrives
 *     while {@code redirectLimit} is 0
 */
public static StreamConnection readHttpInputStream(
        String urlAddress,
        Map<String, Object> headers,
        String payload,
        int redirectLimit,
        URLAccessChecker urlAccessChecker)
        throws IOException {
    final URL checkedUrl = ApocConfig.apocConfig().checkAllowedUrlAndPinToIP(urlAddress, urlAccessChecker);
    final URLConnection connection = openUrlConnection(checkedUrl, headers);
    writePayload(connection, payload);

    final String target = handleRedirect(connection, urlAddress);
    final boolean redirected = target != null && !urlAddress.equals(target);
    if (!redirected) {
        return new StreamConnection.UrlStreamConnection(connection);
    }

    // Release the response of the redirecting request before deciding whether to follow it.
    connection.getInputStream().close();
    if (redirectLimit == 0) {
        throw new IOException("Redirect limit exceeded");
    }
    return readHttpInputStream(target, headers, payload, redirectLimit - 1, urlAccessChecker);
}

/**
 * Opens a connection for the URL and applies the APOC user agent, the optional headers and the
 * configured timeouts.
 *
 * <p>Headers are only applied to {@link HttpURLConnection}s. A {@code "method"} entry, when present,
 * is additionally used as the HTTP request method and switches on chunked streaming.
 *
 * @param src the URL to connect to
 * @param headers request headers; may be {@code null}; {@code null} values are sent as empty strings
 * @return the configured, not yet connected connection
 * @throws IOException when the connection cannot be opened or the request method is rejected
 */
public static URLConnection openUrlConnection(URL src, Map<String, Object> headers) throws IOException {
    final URLConnection connection = src.openConnection();
    connection.setRequestProperty("User-Agent", "APOC Procedures for Neo4j");

    if (connection instanceof HttpURLConnection) {
        final HttpURLConnection httpConnection = (HttpURLConnection) connection;
        // Redirects are not followed automatically by the connection itself.
        httpConnection.setInstanceFollowRedirects(false);
        if (headers != null) {
            final Object requestMethod = headers.get("method");
            if (requestMethod != null) {
                httpConnection.setRequestMethod(requestMethod.toString());
                httpConnection.setChunkedStreamingMode(1024 * 1024);
            }
            for (Map.Entry<String, Object> header : headers.entrySet()) {
                final Object value = header.getValue();
                connection.setRequestProperty(header.getKey(), value == null ? "" : value.toString());
            }
        }
    }

    connection.setConnectTimeout(apocConfig().getInt("apoc.http.timeout.connect", 10_000));
    connection.setReadTimeout(apocConfig().getInt("apoc.http.timeout.read", 60_000));
    return connection;
}

/**
 * Writes the payload, encoded as UTF-8, to the connection's output stream.
 *
 * <p>Does nothing when {@code payload} is {@code null}. Otherwise output is enabled on the
 * connection and the writer is closed even if writing fails.
 *
 * @param con the connection to write to
 * @param payload the request body, or {@code null} for none
 * @throws IOException when the output stream cannot be obtained or written
 */
private static void writePayload(URLConnection con, String payload) throws IOException {
    if (payload == null) {
        return;
    }
    con.setDoOutput(true);
    // try-with-resources closes the writer (and the underlying stream) on failure as well.
    try (BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(con.getOutputStream(), "UTF-8"))) {
        writer.write(payload);
    }
}

/**
 * Returns the address the caller should use after inspecting the response.
 *
 * @param con the connection whose response is inspected
 * @param url the address that was requested
 * @return the {@code Location} header (possibly {@code null}) when the connection is an HTTP
 *     connection for which {@code isRedirect} is true; otherwise {@code url} unchanged
 * @throws IOException declared for {@code isRedirect}
 */
private static String handleRedirect(URLConnection con, String url) throws IOException {
    if (con instanceof HttpURLConnection && isRedirect((HttpURLConnection) con)) {
        return con.getHeaderField("Location");
    }
    return url;
}

/**
 * Opens a counting input stream for either an address string or raw bytes.
 *
 * <p>A string recognised by {@code ArchiveType.from} as an archive is read through
 * {@code getStreamCompressedFile}; any other string is opened via {@code getStreamConnection}.
 * A byte array is handed to {@code getInputStreamFromBinary}.
 *
 * @param input a {@code String} address or a {@code byte[]}
 * @param headers request headers for the string case; may be {@code null}
 * @param payload request body for the string case; may be {@code null}
 * @param compressionAlgo compression algorithm name applied to non-archive input
 * @param urlAccessChecker checker forwarded when opening the connection
 * @return the opened stream
 * @throws RuntimeException with {@code ERROR_BYTES_OR_STRING} when the input is neither type
 */
public static CountingInputStream openInputStream(
        Object input,
        Map<String, Object> headers,
        String payload,
        String compressionAlgo,
        URLAccessChecker urlAccessChecker)
        throws IOException, URISyntaxException, URLAccessValidationError {
    if (input instanceof byte[]) {
        return FileUtils.getInputStreamFromBinary((byte[]) input, compressionAlgo);
    }
    if (!(input instanceof String)) {
        throw new RuntimeException(ERROR_BYTES_OR_STRING);
    }

    final String address = (String) input;
    final ArchiveType archiveType = ArchiveType.from(address);
    if (archiveType.isArchive()) {
        return getStreamCompressedFile(address, headers, payload, archiveType, urlAccessChecker);
    }

    final StreamConnection connection = getStreamConnection(address, headers, payload, urlAccessChecker);
    return connection.toCountingInputStream(compressionAlgo);
}

/**
 * Opens a single entry of an archive addressed as {@code <archive-url>!<entry-name>}.
 *
 * @param urlAddress archive address and entry name separated by {@code !}
 * @param headers request headers; may be {@code null}
 * @param payload request body; may be {@code null}
 * @param archiveType the archive format used to read the entries
 * @param urlAccessChecker checker forwarded when opening the connection
 * @return a counting stream over the entry, sized with the connection's length
 * @throws IllegalArgumentException when the address does not split into exactly two parts
 */
private static CountingInputStream getStreamCompressedFile(
        String urlAddress,
        Map<String, Object> headers,
        String payload,
        ArchiveType archiveType,
        URLAccessChecker urlAccessChecker)
        throws IOException, URISyntaxException, URLAccessValidationError {
    final String[] parts = urlAddress.split("!");
    // Read the first part before validating the count, as the original code does.
    final String archiveUrl = parts[0];
    if (parts.length != 2) {
        throw new IllegalArgumentException("filename can't be null or empty");
    }

    final String entryName = parts[1];
    final StreamConnection connection = getStreamConnection(archiveUrl, headers, payload, urlAccessChecker);
    final InputStream entryStream =
            getFileStreamIntoCompressedFile(connection.getInputStream(), entryName, archiveType);
    final InputStream limitedStream = toLimitedIStream(entryStream, connection.getLength());

    return new CountingInputStream(limitedStream, connection.getLength());
}

/**
 * Convenience overload that derives the protocol from the address via {@code from(urlAddress)}
 * and delegates to the overload taking the protocol as first argument.
 *
 * @param urlAddress the address to open
 * @param headers request headers; may be {@code null}
 * @param payload request body; may be {@code null}
 * @param urlAccessChecker checker forwarded to the delegate
 * @return the connection returned by the delegate
 */
public static StreamConnection getStreamConnection(
        String urlAddress, Map<String, Object> headers, String payload, URLAccessChecker urlAccessChecker)
        throws IOException, URISyntaxException, URLAccessValidationError {
    return getStreamConnection(from(urlAddress), urlAddress, headers, payload, urlAccessChecker);
}

/**
 * Scans the archive for a non-directory entry with exactly the given name and returns its content
 * as an in-memory stream.
 *
 * @param is stream over the archive; consumed and closed via the archive stream
 * @param fileName entry name to look for
 * @param archiveType the archive format used to read the entries
 * @return a stream over a copy of the entry's bytes, or {@code null} when no entry matches
 * @throws IOException when the archive cannot be read
 */
private static InputStream getFileStreamIntoCompressedFile(InputStream is, String fileName, ArchiveType archiveType)
        throws IOException {
    try (ArchiveInputStream archive = archiveType.getInputStream(is)) {
        for (ArchiveEntry entry = archive.getNextEntry(); entry != null; entry = archive.getNextEntry()) {
            if (entry.isDirectory()) {
                continue;
            }
            if (entry.getName().equals(fileName)) {
                // Copy the entry into memory so it stays readable after the archive is closed.
                final byte[] content = IOUtils.toByteArray(archive);
                return new ByteArrayInputStream(content);
            }
        }
    }

    return null;
}

/**
 * Returns the writer's content either as a plain string or compressed, then empties the writer.
 *
 * @param writer source of the text; its buffer is reset to length 0 after a successful read
 * @param config supplies the compression algorithm name and the charset used for compression
 * @return the string itself when the algorithm name equals {@code CompressionAlgo.NONE.name()},
 *     otherwise the result of the named algorithm's {@code compress}
 * @throws RuntimeException wrapping any exception raised while reading or compressing
 */
public static Object getStringOrCompressedData(StringWriter writer, ExportConfig config) {
    try {
        final String algoName = config.getCompressionAlgo();
        final String text = writer.toString();
        final Object result;
        if (algoName.equals(CompressionAlgo.NONE.name())) {
            result = text;
        } else {
            result = CompressionAlgo.valueOf(algoName).compress(text, config.getCharset());
        }
        writer.getBuffer().setLength(0);
        return result;
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
}
Loading

0 comments on commit ca74218

Please sign in to comment.