import gnu.bytecode.*;

/**
 * Example of run-time code generation with gnu.bytecode: compiles a simple
 * glob-style pattern into a JVM class whose {@code match} method tests
 * strings against it.
 *
 * <p>Pattern syntax: {@code '*'} matches any (possibly empty) sequence of
 * characters, {@code '?'} matches exactly one character, and every other
 * character matches itself.
 */
public abstract class CompileRegexp
{
  /**
   * Generate a class that matches strings against {@code pattern}.
   *
   * <p>The generated class is named {@code gen_matcher}, is public, extends
   * {@code Pattern}, and has a public no-argument constructor plus a method
   * {@code boolean match(String string, int curPos, int limit)}.  That method
   * returns true iff the characters of {@code string} from index
   * {@code curPos} up to (not including) {@code limit} match the whole
   * pattern.
   *
   * @param pattern the glob pattern to compile
   * @return the generated class (not yet loaded)
   */
  public static ClassType compile (String pattern)
  {
    ClassType genClass = new ClassType("gen_matcher");
    genClass.setModifiers(Access.PUBLIC);
    ClassType typePattern = ClassType.make("Pattern");
    genClass.setSuper(typePattern);

    addDefaultConstructor(genClass, typePattern);

    // Generate 'match' method:
    Method matchMethod
      = genClass.addMethod("match", Access.PUBLIC,
                           new Type[] { Type.string_type, Type.int_type, Type.int_type },
                           Type.boolean_type);
    CodeAttr code = matchMethod.startCode();

    // Argument 0 is 'this'.
    Variable string = code.getArg(1);
    Variable curPos = code.getArg(2);
    Variable limit = code.getArg(3);

    Method charAtMethod = Type.string_type.getDeclaredMethod("charAt", 1);

    // 'onFail' is where generated code jumps when the current attempt fails:
    // initially "return false", later the most recent '*' backtrack point.
    Label failureReturn = new Label(code);
    Label onFail = failureReturn;
    int patLength = pattern.length();
    for (int i = 0;  i < patLength;  i++)
      {
        char ch = pattern.charAt(i);
        if (ch == '*')
          {
            if (i == patLength - 1)
              {
                // Special case if pattern ends with '*': whatever remains
                // of the string matches, so succeed immediately.
                // (The end-of-string check emitted after the loop is then
                // unreachable.)
                code.emitPushInt(1);
                code.emitReturn();
              }
            else if (pattern.indexOf('*', i+1) < 0)
              {
                // There are no more '*', so the rest of the pattern has a
                // fixed length and must match exactly at the end of the
                // string; no backtracking is needed here.
                int restLength = patLength - i - 1;
                // Emit: if (curPos > limit - restLength) goto onFail.
                code.emitLoad(curPos);
                code.emitLoad(limit);
                code.emitPushInt(restLength);
                code.emitSub(Type.int_type);
                code.emitGotoIfGt(onFail);
                // Emit: curPos = limit - restLength;
                code.emitLoad(limit);
                code.emitPushInt(restLength);
                code.emitSub(Type.int_type);
                code.emitStore(curPos);
              }
            else
              {
                // Another '*' follows, so we may have to backtrack: try
                // letting this '*' match 0, 1, 2, ... characters until the
                // rest of the pattern succeeds.
                // Emit:  int savePos = curPos;
                Variable savePos = code.addLocal(Type.int_type);
                code.emitLoad(curPos);
                code.emitStore(savePos);
                // Emit:  onFailHere:
                Label onFailHere = new Label(code);
                onFailHere.define(code);
                // Emit:  curPos = savePos;
                code.emitLoad(savePos);
                code.emitDup();
                code.emitStore(curPos);
                // Emit:  if (curPos > limit) goto onFail;
                // This must be '>' rather than '==': an empty match at the
                // very end of the string (curPos == limit) is still valid,
                // e.g. "a**" must match "a" and "**" must match "".
                // Later checks already handle curPos == limit correctly.
                code.emitLoad(limit);
                code.emitGotoIfGt(onFail);
                // Emit:  savePos++;  (next retry consumes one more char)
                code.emitInc(savePos, (short) 1);
                // Make onFailHere the new place to backtrack to.
                onFail = onFailHere;
              }
          }
        else
          {
            // Emit: if (curPos == limit) goto onFail.
            code.emitLoad(curPos);
            code.emitLoad(limit);
            code.emitGotoIfEq(onFail);
            if (ch != '?')
              {
                // Emit: if (string.charAt(curPos) != ch) goto onFail
                code.emitLoad(string);
                code.emitLoad(curPos);
                code.emitInvoke(charAtMethod);
                code.emitPushInt(ch);
                code.emitGotoIfNE(onFail);
              }
            // Emit:  curPos++
            code.emitInc(curPos, (short) 1);
          }
      }
    // We've matched the whole pattern.
    // Check that we're at the end of the string.
    code.emitLoad(curPos);
    code.emitLoad(limit);
    code.emitGotoIfNE(onFail);
    code.emitPushInt(1);  // push true
    code.emitReturn();
    failureReturn.define(code);
    code.emitPushInt(0);  // push false
    code.emitReturn();

    return genClass;
  }

  /** Add to genClass a public no-argument constructor that just calls the
   * no-argument constructor of superClass. */
  private static void addDefaultConstructor (ClassType genClass,
                                             ClassType superClass)
  {
    Method initMethod
      = genClass.addMethod("<init>", Access.PUBLIC,
                           new Type[0], Type.void_type);
    CodeAttr code = initMethod.startCode();
    code.emitPushThis();
    Method superConstructor
      = superClass.getDeclaredMethod("<init>", Type.typeArray0);
    code.emitInvoke(superConstructor);
    code.emitReturn();
  }

  /** Usage:  java CompileRegexp pattern string1 ... stringn.
   * Compile pattern, and then match it against each of the stringi
   * arguments, printing whether each one matches. */
  public static void main(String[] args) throws Throwable
  {
    if (args.length == 0)
      {
        System.err.println("Usage: java CompileRegexp pattern string1 ... stringn");
        System.exit(1);
      }
    ClassType gtype = compile (args[0]);
    // Uncomment for debugging - writes out an actual .class file.
    // gtype.writeToFile();
    ArrayClassLoader loader = new ArrayClassLoader();
    loader.addClass(gtype);
    Class gclass = loader.loadClass(gtype.getName(), true);
    // Class.newInstance() is deprecated; go through the no-arg constructor.
    Pattern rexp = (Pattern) gclass.getDeclaredConstructor().newInstance();
    for (int i = 1;  i < args.length;  i++)
      {
        String str = args[i];
        System.out.println("match '"+ str +"'? "+ rexp.match(str));
      }
  }
}
