Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite @ScalaSignature when shading #393

Merged
merged 12 commits into from
May 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,13 @@ To see the verbose output for shading:
logLevel in assembly := Level.Debug
```

#### Scala libraries

Scala classes contain an annotation which, among other things, contain all symbols referenced in that class. As of sbt-assembly XXX the rename rules
will be applied to these annotations as well which makes it possible to compile or reflect against a shaded library.

This is currently limited to renaming packages. Renaming class names will not work and cause compiler errors when compiling against the shaded library.

Excluding JARs and files
------------------------

Expand Down
3 changes: 2 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ lazy val root = (project in file(".")).
scalacOptions := Seq("-deprecation", "-unchecked", "-Dscalac.patmat.analysisBudget=1024", "-Xfuture"),
libraryDependencies ++= Seq(
"org.scalactic" %% "scalactic" % "3.0.8",
"org.pantsbuild" % "jarjar" % "1.7.2"
"org.pantsbuild" % "jarjar" % "1.7.2",
"org.scalatest" %% "scalatest" % "3.1.1" % Test,
),
crossSbtVersions := Seq("0.13.18", "1.2.8"), // https://github.com/sbt/sbt/issues/5049
publishArtifact in (Compile, packageBin) := true,
Expand Down
100 changes: 84 additions & 16 deletions src/main/scala/org/pantsbuild/jarjar/JJProcessor.scala
Original file line number Diff line number Diff line change
@@ -1,31 +1,99 @@
package org.pantsbuild.jarjar

import org.pantsbuild.jarjar.util.{EntryStruct, JarProcessor}
import java.io.IOException

import org.pantsbuild.jarjar.misplaced.MisplacedClassProcessorFactory
import org.pantsbuild.jarjar.util.{EntryStruct, JarProcessor, JarProcessorChain, JarTransformerChain, RemappingClassTransformer, StandaloneJarProcessor}

import scala.collection.JavaConverters._
import scala.collection.mutable

/**
* Creates a new JJProcessor, which automatically generates the standard zap, keep, remap, etc processors.
* This is a copy of the MainProcessor in JarJar with an added ScalaSigProcessor
*
* @param patterns List of rules to parse.
* @param verbose Whether to verbosely log information.
* @param skipManifest If true, omits the manifest file from the processed jar.
* @param misplacedClassStrategy The strategy to use when processing class files that are in the
* wrong package (see MisplacedClassProcessorFactory.STRATEGY_* constants).
*/
class JJProcessor(val patterns: Seq[PatternElement], val verbose: Boolean, val skipManifest: Boolean, val misplacedClassStrategy: String) extends JarProcessor {

val zapList: Seq[Zap] = patterns.collect { case zap: Zap => zap }
val ruleList: Seq[Rule] = patterns.collect { case rule: Rule => rule }
val keepList: Seq[Keep] = patterns.collect { case keep: Keep => keep }
val renames: mutable.Map[String, String] = collection.mutable.HashMap[String, String]()

val kp: KeepProcessor = if (keepList.isEmpty) null else new KeepProcessor(keepList.asJava)

val pr = new PackageRemapper(ruleList.asJava, verbose)

class JJProcessor(val proc: JarProcessor) {
val processors: mutable.ArrayBuffer[JarProcessor] = collection.mutable.ArrayBuffer[JarProcessor]()
if (skipManifest)
processors += ManifestProcessor.getInstance
if (kp != null)
processors += kp

def process(entry: EntryStruct): Boolean = proc.process(entry)
val misplacedClassProcessor: JarProcessor = MisplacedClassProcessorFactory.getInstance.getProcessorForName(misplacedClassStrategy)
processors += new ZapProcessor(zapList.asJava)
processors += misplacedClassProcessor
processors += new JarTransformerChain(Array[RemappingClassTransformer](new RemappingClassTransformer(pr)))

def getExcludes(): Set[String] = {
val field = proc.getClass().getDeclaredField("kp")
field.setAccessible(true)
val keepProcessor = field.get(proc)
val renamer: String => Option[String] = {
val wildcards = PatternElement.createWildcards(ruleList.asJava).asScala

if (keepProcessor == null) Set()
else {
val method = proc.getClass().getDeclaredMethod("getExcludes")
method.setAccessible(true)
method.invoke(proc).asInstanceOf[java.util.Set[String]].asScala.toSet
value: String => {
val result = wildcards.flatMap {
wc =>
val slashed = value.replace('.', '/') // The jarjar wildcards expect slashes instead of dots
// Hack to replace the package object name.
val renamed = Option(wc.replace(slashed)).orElse(Option(wc.replace(slashed + "/")).map(_.dropRight(1)))
renamed.map(_.replace('/', '.')) // Unslash
}.headOption

result
}
}

}
processors += new ScalaSigProcessor(renamer)
processors += new MethodSignatureProcessor(pr)
processors += new ResourceProcessor(pr)
val chain = new JarProcessorChain(processors.toArray)

object JJProcessor {
@throws[IOException]
def strip(file: Nothing): Unit = {
if (kp != null) {
val excludes = getExcludes
if (excludes.nonEmpty) StandaloneJarProcessor.run(file, file, new ExcludeProcessor(excludes.asJava, verbose))
}
}

def apply(patterns: Seq[PatternElement], verbose: Boolean, skipManifest: Boolean): JJProcessor =
new JJProcessor(new MainProcessor(patterns.asJava, verbose, skipManifest))
/**
* Returns the <code>.class</code> files to delete. As well the root-parameter as the rename ones
* are taken in consideration, so that the concerned files are not listed in the result.
*
* @return the paths of the files in the jar-archive, including the <code>.class</code> suffix
*/
def getExcludes: Set[String] = if (kp != null) kp.getExcludes.asScala.map { exclude =>
val name = exclude + ".class"
renames.getOrElse(name, name)
}.toSet else Set.empty

/**
*
* @param struct entry struct to process
* @return <code>true</code> if the entry is to include in the output jar
* @throws IOException
*/
@throws[IOException]
def process(struct: EntryStruct): Boolean = {
val name = struct.name
val keepIt = chain.process(struct)
if (keepIt) if (!name.equals(struct.name)) {
if (kp != null) renames.put(name, struct.name)
if (verbose) System.err.println("Renamed " + name + " -> " + struct.name)
} else if (verbose) System.err.println("Removed " + name)
keepIt
}
}
20 changes: 20 additions & 0 deletions src/main/scala/org/pantsbuild/jarjar/ScalaSigProcessor.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package org.pantsbuild.jarjar

import org.objectweb.asm.{ClassReader, ClassWriter}
import org.pantsbuild.jarjar.util.{EntryStruct, JarProcessor}
import sbtassembly.scalasig.ScalaSigClassVisitor

class ScalaSigProcessor(renamer: String => Option[String]) extends JarProcessor {
override def process(struct: EntryStruct): Boolean = {

if (!struct.name.endsWith(".class") || struct.skipTransform) true
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe allow *.sig files here? - scala/scala#7712

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add it to the list. I have some reading up to do here.

else {
val classWriter = new ClassWriter(ClassWriter.COMPUTE_MAXS)
val reader = new ClassReader(struct.data)

reader.accept(new ScalaSigClassVisitor(classWriter, renamer), ClassReader.EXPAND_FRAMES)
struct.data = classWriter.toByteArray
true
}
}
}
9 changes: 4 additions & 5 deletions src/main/scala/sbtassembly/Shader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@ package sbtassembly

import java.io.File

import org.pantsbuild.jarjar._
import org.pantsbuild.jarjar.{JJProcessor, _}
import org.pantsbuild.jarjar.util.EntryStruct

import sbt._

case class ShadeRule(shadePattern: ShadePattern, targets: Seq[ShadeTarget] = Seq()) {
Expand Down Expand Up @@ -83,7 +82,7 @@ private[sbtassembly] object Shader {
case _ => Nil
}}

val proc = JJProcessor(jjrules, verbose = level == Level.Debug, true)
val proc = new JJProcessor(jjrules, verbose = level == Level.Debug, true, null)

/*
jarjar MisplacedClassProcessor class transforms byte[] to a class using org.objectweb.asm.ClassReader.getClassName
Expand All @@ -104,7 +103,7 @@ private[sbtassembly] object Shader {
IO.write(dir / entry.name, entry.data)
}
}
val excludes = proc.getExcludes()
val excludes = proc.getExcludes
excludes.foreach(exclude => IO.delete(dir / exclude))
}
}
}
22 changes: 22 additions & 0 deletions src/main/scala/sbtassembly/scalasig/ByteArrayReader.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package sbtassembly.scalasig

// Utility class to read the content of a single table entry
class ByteArrayReader(bytes: Array[Byte]) extends Nat.Reader {
private var readIndex = 0

/** Read a byte */
override def readByte(): Int = {
val x = bytes(readIndex).toInt
readIndex += 1
x
}

/** Reads a number of bytes into an array */
def readBytes(len: Int): Array[Byte] = {
val result = bytes.slice(readIndex, readIndex + len)
readIndex += len
result
}

def atEnd: Boolean = readIndex == bytes.length
}
150 changes: 150 additions & 0 deletions src/main/scala/sbtassembly/scalasig/EntryTable.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package sbtassembly.scalasig

import java.io.ByteArrayOutputStream

import scala.collection.mutable
import scala.reflect.internal.pickling.PickleFormat

/**
* Mutable table of tagged entries
* @param majorVersion major table version
* @param minorVersion minor table version
* @param entries initial table entries
*/
class EntryTable(majorVersion: Int, minorVersion: Int, entries: mutable.Buffer[TaggedEntry]) {
// Mapping of known TermName or TypeNames to their index in the table.
private val nameIndices: mutable.Map[NameEntry, Int] = mutable.HashMap(
entries.zipWithIndex.collect {
case (entry: NameEntry, index) => (entry, index)
}:_*
)

/**
* Return the current table entries as an immutable seq.
* @return table entries
*/
def toSeq: Seq[TaggedEntry] = entries.toVector

/**
* Rename term and type entries in this table according to the renamer function.
* A name or type is referred to by a Ref entry. The existing ref entries are reused to references to them will remain intact.
* Unused entries will not be removed from the table.
*
* @param renamer renames a fully qualified type or term name or return None if it does not match.
*/
def renameEntries(renamer: String => Option[String]): Unit = {

entries.zipWithIndex.collect {
case (ref: RefEntry, index) =>
entries(ref.nameRef) match {
case nameEntry: NameEntry =>
for {
fqName <- resolveRef(ref)
renamed <- renamer(fqName)
} {
val parts = renamed.split('.')

val myOwner = parts.init.foldLeft(Option.empty[Int]) { (owner, part) =>
val nameIndex = getOrAppendNameEntry(NameEntry(PickleFormat.TERMname, part))
val nextOwner = appendEntry(RefEntry(PickleFormat.EXTMODCLASSref, nameIndex, owner))
Some(nextOwner)
}

entries(index) = ref.copy(nameRef = getOrAppendNameEntry(nameEntry.copy(name = parts.last)), ownerRef = myOwner)
}

case other =>
throw new RuntimeException(s"Ref entry does not point to a name but to a ${other.tag}")
}
}
}

// Return existing name entry or append a new one.
private def getOrAppendNameEntry(name: NameEntry): Int = {
nameIndices.getOrElse(name, appendEntry(name))
}

private def appendEntry(entry: TaggedEntry): Int = {
val index = entries.size
entries += entry

entry match {
case name: NameEntry =>
nameIndices.put(name, index)
case _ => // NoOp
}

index
}

// Resolves a ref into a fully qualified name
def resolveRef(extMod: RefEntry): Option[String] = {

val myName = entries(extMod.nameRef) match {
case term: NameEntry => term.name
case raw: RawEntry => throw new RuntimeException(s"Unexpected raw type for nameref ${raw.tag}")
case other => throw new RuntimeException(s"Unexpected type for nameref $other")
}
extMod.ownerRef match {
case None => Some(myName)
case Some(owner) =>
entries(owner) match {
case name: NameEntry =>
Some(s"$name/$myName")
case mod: RefEntry =>
resolveRef(mod).map(p => s"$p.$myName")
case raw: RawEntry if raw.tag == PickleFormat.NONEsym =>
None
case raw: RawEntry =>
throw new RuntimeException(s"Not a known owner type tag for $myName : ${raw.tag}")
}
}
}

/**
* Serializes this entry table into a byte array.
*/
def toBytes: Array[Byte] = {
val os = new ByteArrayOutputStream()
val writer = new Nat.Writer {
override def writeByte(b: Int): Unit = os.write(b)
}

writer.writeNat(majorVersion)
writer.writeNat(minorVersion)
writer.writeNat(entries.size)

entries.foreach { entry =>
val payloadBytes = entry.toBytes
writer.writeNat(entry.tag) // Tag of entry
writer.writeNat(payloadBytes.length) // Size of payload
os.write(payloadBytes)
}

os.toByteArray
}
}

object EntryTable {

/**
* Parse bytes into a EntryTable
*/
def fromBytes(bytes: Array[Byte]): EntryTable = {
val reader = new ByteArrayReader(bytes)

val majorVersion = reader.readNat()
val minorVersion = reader.readNat()

val result = new Array[TaggedEntry](reader.readNat())

result.indices foreach { index =>
val tag = reader.readNat()
val len = reader.readNat()

result(index) = TaggedEntry(tag, reader.readBytes(len))
}

new EntryTable(majorVersion, minorVersion, result.toBuffer)
}
}
Loading